[
  {
    "path": ".cargo/audit.toml",
    "content": "# Audit config file\n#\n# It may be located in the user home (`~/.cargo/audit.toml`) or in the project\n# root (`.cargo/audit.toml`).\n#\n# All of the options which can be passed via CLI arguments can also be\n# permanently specified in this file.\n\n[advisories]\nignore = [\n    \"RUSTSEC-2024-0436\", # Paste used to generate macro, should be removed at some point.\n    \"RUSTSEC-2025-0119\", # `number_prefix` used by `tokenizers`, only in the examples.\n    \"RUSTSEC-2025-0141\", # `bincode` is no longer maintained.\n    \"RUSTSEC-2024-0388\", # `derivative` dependancy in the DQN example is unmaintained.\n] # advisory IDs to ignore e.g. [\"RUSTSEC-2019-0001\", ...]\ninformational_warnings = [\n    \"unmaintained\",\n] # warn for categories of informational advisories\nseverity_threshold = \"low\" # CVSS severity (\"none\", \"low\", \"medium\", \"high\", \"critical\")\n\n# Output Configuration\n[output]\ndeny = [\"unmaintained\"] # exit on error if unmaintained dependencies are found\nformat = \"terminal\"     # \"terminal\" (human readable report) or \"json\"\nquiet = false           # Only print information on error\nshow_tree = true        # Show inverse dependency trees along with advisories (default: true)\n\n[yanked]\nenabled = true      # Warn for yanked crates in Cargo.lock (default: true)\nupdate_index = true # Auto-update the crates.io index (default: true)\n"
  },
  {
    "path": ".cargo/config.toml",
    "content": "[alias]\nxtask = \"run --target-dir target/xtask --color always --package xtask --bin xtask --\"\nrun-checks = \"xtask -c all validate --release\""
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\n<!-- A clear and concise description of what the bug is. -->\n\n**To Reproduce**\n<!-- \n Steps to reproduce the behavior:\n 1. Go to '...'\n 2. Click on '....'\n 3. Scroll down to '....'\n 4. See error\n-->\n\n**Expected behavior**\n<!-- A clear and concise description of what you expected to happen. -->\n\n**Screenshots**\n<!-- If applicable, add screenshots to help explain your problem. -->\n\n**Desktop (please complete the following information):**\n - OS: [e.g. iOS]\n - Browser [e.g. chrome, safari]\n - Version [e.g. 22]\n\n**Smartphone (please complete the following information):**\n - Device: [e.g. iPhone6]\n - OS: [e.g. iOS8.1]\n - Browser [e.g. stock browser, safari]\n - Version [e.g. 22]\n\n**Additional context**\n<!-- Add any other context about the problem here. -->"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/doc_request.md",
    "content": "---\nname: Documentation request\nabout: Flag incoherent or missing documentation, including use case examples.\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n<!-- Please search existing issues to avoid creating duplicates -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n<!-- Please search existing issues to avoid creating duplicates -->\n\n### Feature description\n\n<!-- Describe the feature you'd like -->\n\n### Feature motivation\n\n<!-- Why do you want this? -->\n\n### (Optional) Suggest a Solution\n\n<!--\n  How do you think we should implement this feature? \n  Things to address include:\n    * Details of the technical implementation\n    * Tradeoffs made in design decisions\n    * Caveats and considerations for the future\n-->\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE/template.md",
    "content": "* **Please check if the PR fulfills these requirements**\n- [ ] The commit message follows our guidelines\n- [ ] Docs have been added / updated (for bug fixes / features)\n\n\n* **What kind of change does this PR introduce?** (Bug fix, feature, docs update, ...)\n\n\n* **Does this PR introduce a breaking change?** (What changes might users need to make in their application due to this PR?)\n\n\n* **Other information**:"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\n\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"daily\"\n    ignore:\n      - dependency-name: \"tracel-ai/github-actions*\"\n\n  - package-ecosystem: \"cargo\"\n    directories:\n      - \"/\"\n      - \"crates/burn\"\n      - \"crates/burn-*\"\n      - \"crates/burn-import/*-tests\"\n      - \"examples/*\"\n      - \"xtask\"\n    schedule:\n      interval: \"weekly\"\n\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "## Pull Request Template\n\n### Checklist\n\n- [ ] Confirmed that `cargo run-checks` command has been executed.\n- [ ] Made sure the book is up to date with changes in this PR.\n\n### Related Issues/PRs\n\n_Provide links to relevant issues and dependent PRs._\n\n### Changes\n\n_Summarize the problem being addressed and your solution._\n\n### Testing\n\n_Describe how these changes have been tested._\n"
  },
  {
    "path": ".github/workflows/combine-dependabot-prs.yml",
    "content": "name: Combine Dependabot PRs\n\non:\n  schedule:\n    - cron: '0 6 * * MON' # Monday at 6:00am UTC\n  workflow_dispatch:\n\npermissions:\n  contents: write\n  pull-requests: write\n  checks: read\n\njobs:\n  combine-prs:\n    runs-on: ubuntu-latest\n    steps:\n      - name: combine-prs\n        id: combine-prs\n        uses: github/combine-prs@v5.2.0\n        with:\n          labels: dependencies,automated\n"
  },
  {
    "path": ".github/workflows/dependencies.yml",
    "content": "name: dependencies\n\non:\n  schedule:\n    - cron: '0 21 * * TUE' # Run every Tuesday at 21:00 (UTC)\n  push:\n    tags:\n      - 'v*.*.*' # Run when a new version is being published\n\nenv:\n  UDEPS_VERSION: \"0.1.57\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  dependencies:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        checks:\n          - licenses\n          - bans sources\n    continue-on-error: ${{ matrix.checks == 'licenses' }} # failed licenses don't abort\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Audit Rust dependencies\n        # If a vulnerability is found, a new issue will automatically be opened\n        # since this action runs on main branch\n        uses: actions-rust-lang/audit@v1\n      # --------------------------------------------------------------------------------\n      - name: Detect multiple versions of the same crate\n        uses: EmbarkStudios/cargo-deny-action@v2\n        with:\n          command: check ${{ matrix.checks }}\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt\n      # --------------------------------------------------------------------------------\n      - name: Install cargo-udeps\n        env:\n          UDEPS_LINK: https://github.com/est31/cargo-udeps/releases/download\n        run: |\n          curl -L \"$UDEPS_LINK/v$UDEPS_VERSION/cargo-udeps-v$UDEPS_VERSION-x86_64-unknown-linux-gnu.tar.gz\" |\n          tar xz -C $HOME/.cargo/bin --strip-components 2\n      # --------------------------------------------------------------------------------\n      - name: Run cargo-udeps\n        run: |\n          cargo +nightly udeps --all-targets\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: publish\n\non:\n  push:\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n    inputs:\n      dry-run-only:\n        description: \"Run xtask publish in dry-run mode (no publish)\"\n        type: boolean\n        required: false\n        default: false\n\njobs:\n  publish-burn-rl:\n    needs:\n      - publish-burn-core\n      - publish-burn-optim\n      # dev dependencies\n      - publish-burn-ndarray\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-rl\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-vision:\n    needs:\n      - publish-burn-autodiff\n      - publish-burn-candle\n      - publish-burn-fusion\n      - publish-burn-cubecl-fusion\n      - publish-burn-cubecl\n      - publish-burn-ndarray\n      - publish-burn-tch\n      - publish-burn-tensor\n      - publish-burn-ir\n      - publish-burn-tensor-testgen\n      # dev dependencies\n      - publish-burn-wgpu\n      - publish-burn-cuda\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-vision\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-router:\n    needs:\n      - publish-burn-ir\n      - publish-burn-std\n      - publish-burn-tensor\n      # dev dependencies\n      - publish-burn-autodiff\n      - publish-burn-ndarray\n      - publish-burn-wgpu\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-router\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-remote:\n    needs:\n      - publish-burn-ir\n      - publish-burn-std\n      - publish-burn-tensor\n      - publish-burn-router\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-remote\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-derive:\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-derive\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-dataset:\n    needs:\n      - publish-burn-std\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-dataset\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-std:\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-std\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-tensor-testgen:\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-tensor-testgen\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-tensor:\n    needs:\n      - publish-burn-tensor-testgen\n      - publish-burn-std\n      - publish-burn-backend\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-tensor\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-backend:\n    needs:\n      - publish-burn-std\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-backend\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-ir:\n    needs:\n      - publish-burn-tensor\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-ir\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-fusion:\n    needs:\n      - publish-burn-ir\n      - publish-burn-tensor\n      - publish-burn-std\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-fusion\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-cubecl-fusion:\n    needs:\n      - publish-burn-ir\n      - publish-burn-std\n      - publish-burn-fusion\n      - publish-burn-tensor\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-cubecl-fusion\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-cubecl:\n    needs:\n      - publish-burn-ir\n      - publish-burn-std\n      - publish-burn-fusion\n      - publish-burn-cubecl-fusion\n      - publish-burn-tensor\n      - publish-burn-ndarray\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-cubecl\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-autodiff:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-tensor-testgen\n      - publish-burn-std\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-autodiff\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-tch:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-autodiff\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-tch\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-ndarray:\n    needs:\n      - publish-burn-ir\n      - publish-burn-tensor\n      - publish-burn-autodiff\n      - publish-burn-std\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-ndarray\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-wgpu:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-autodiff\n      - publish-burn-ndarray\n      - publish-burn-std\n      - publish-burn-cubecl\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-wgpu\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-cpu:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-fusion\n      - publish-burn-cubecl\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-cpu\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-cuda:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-autodiff\n      - publish-burn-ndarray\n      - publish-burn-std\n      - publish-burn-cubecl\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-cuda\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-rocm:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-autodiff\n      - publish-burn-ndarray\n      - publish-burn-std\n      - publish-burn-cubecl\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-rocm\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-candle:\n    needs:\n      - publish-burn-tensor\n      - publish-burn-autodiff\n      - publish-burn-tch\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-candle\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-collective:\n    needs:\n      - publish-burn-std\n      - publish-burn-tensor\n      - publish-burn-communication\n      # dev dependencies\n      - publish-burn-wgpu\n      - publish-burn-ndarray\n      - publish-burn-cuda\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-collective\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-communication:\n    needs:\n      - publish-burn-std\n      - publish-burn-tensor\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-communication\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-core:\n    needs:\n      - publish-burn-dataset\n      - publish-burn-std\n      - publish-burn-derive\n      - publish-burn-tensor\n      - publish-burn-vision\n      # dev dependencies\n      - publish-burn-autodiff\n      - publish-burn-wgpu\n      - publish-burn-tch\n      - publish-burn-cuda\n      - publish-burn-ndarray\n      - publish-burn-candle\n      - publish-burn-remote\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-core\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-nn:\n    needs:\n      - publish-burn-core\n      # dev dependencies\n      - publish-burn-autodiff\n      - publish-burn-wgpu\n      - publish-burn-tch\n      - publish-burn-ndarray\n      - publish-burn-candle\n      - publish-burn-remote\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-nn\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-optim:\n    needs:\n      - publish-burn-core\n      - publish-burn-collective\n      # dev dependencies\n      - publish-burn-autodiff\n      - publish-burn-wgpu\n      - publish-burn-tch\n      - publish-burn-ndarray\n      - publish-burn-candle\n      - publish-burn-remote\n      - publish-burn-nn\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-optim\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-train:\n    needs:\n      - publish-burn-core\n      - publish-burn-optim\n      - publish-burn-collective\n      - publish-burn-rl\n      - publish-burn-ndarray\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-train\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-dispatch:\n    needs:\n      - publish-burn-std\n      - publish-burn-backend\n      - publish-burn-autodiff\n      - publish-burn-cpu\n      - publish-burn-cuda\n      - publish-burn-rocm\n      - publish-burn-wgpu\n      - publish-burn-ndarray\n      - publish-burn-tch\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-dispatch\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn:\n    needs:\n      - publish-burn-core\n      - publish-burn-nn\n      - publish-burn-optim\n      - publish-burn-collective\n      - publish-burn-store\n      - publish-burn-train\n      - publish-burn-cpu\n      - publish-burn-dispatch\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n\n  publish-burn-store:\n    needs:\n      - publish-burn-core\n      - publish-burn-nn\n      - publish-burn-tensor\n    uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v9\n    with:\n      crate: burn-store\n      dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}\n    secrets:\n      CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/stale-pr.yml",
    "content": "name: Stale Pull Requests\n\non:\n  schedule:\n    - cron: '0 12 * * *' # Run every day at 12:00 (UTC)\n\n# The minimum permissions required to run this Action\npermissions:\n  contents: write # only for delete-branch option\n  issues: write\n  pull-requests: write\n\njobs:\n  stale-pr:\n    runs-on: ubuntu-latest\n    steps:\n    - name: checkout\n      uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n    - name: Stale pull requests\n      uses: actions/stale@v10\n      with:\n        # The idle number of days before marking issues stale.\n        #\n        #  With a negative number like -1, no issues\n        #  will be marked as stale automatically.\n        days-before-issue-stale: -1\n        # The idle number of days before marking pull requests stale\n        days-before-pr-stale: 30\n        # The idle number of days before closing\n        # the stale pull requests (due to the stale label).\n        #\n        # With a negative number like -1, the pull requests\n        # will never be closed automatically.\n        days-before-pr-close: -1\n        # Label to apply on staled pull requests\n        stale-pr-label: 'stale'\n        # The message that will be added as a comment to the pull request\n        stale-pr-message: 'This PR has been marked as stale because it has not been updated for over a month'\n        # Remove `stale` label from pull requests on updates/comments\n        remove-pr-stale-when-updated: true\n"
  },
  {
    "path": ".github/workflows/test-gpu.yml",
    "content": "name: CI GPU\n\non:\n  workflow_dispatch:\n    inputs:\n      pr_number:\n        description: \"Number of the pull request that triggers this run if any\"\n        type: number\n        required: false\n\n# important to set the run name to this format so that the CI server\n# can track the PR number from the workflow_run events.\nrun-name: ${{ github.workflow }}:${{ github.repository }}#${{ inputs.pr_number }}\n\nenv:\n  # Note: It is not possible to define top level env vars and pass them to composite actions.\n  # To work around this issue we use inputs and define all the env vars here.\n\n  RUST_PREVIOUS_VERSION: 1.92.0\n\n  # Dependency versioning\n  # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml\n\n  # GCP runners\n  GCP_RUNNERS_IMAGE_FAMILY: \"tracel-ci-ubuntu-2404-amd64-nvidia\"\n  GCP_RUNNERS_MACHINE_TYPE: \"g2-standard-4\"\n  GCP_RUNNERS_ZONE: \"us-east1-c\"\n\n  # Test in release mode (make it an empty string to test in debug mode)\n  TEST_RELEASE_FLAG: \"--release\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  prepare-checks:\n    runs-on: ubuntu-latest\n    outputs:\n      rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }}\n      gcp_runners_image_family: ${{ env.GCP_RUNNERS_IMAGE_FAMILY }}\n      gcp_runners_machine_type: ${{ env.GCP_RUNNERS_MACHINE_TYPE }}\n      gcp_runners_zone: ${{ env.GCP_RUNNERS_ZONE }}\n    steps:\n      - name: Do Nothing\n        if: false\n        run: echo\n\n  linux-std-cuda-tests:\n    needs: [prepare-checks]\n    timeout-minutes: 60\n    # '@id:' label must be unique within this worklow\n    runs-on:\n      [\n        \"@id:burn-cuda-job-${{github.run_id}}-${{github.run_attempt}}\",\n        \"@pr_number:${{ inputs.pr_number }}\",\n        \"@organization:tracel-ai\",\n        \"@repository:burn\",\n        \"@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}\",\n        \"@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}\",\n        \"@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}\",\n        \"@gpu:true\",\n      ]\n    env:\n      LD_LIBRARY_PATH: \"/usr/local/cuda/lib64\"\n      # disable incremental compilation (reduces artifact size)\n      CARGO_PROFILE_TEST_INCREMENTAL: \"false\"\n    # Keep the stragegy to be able to easily add new rust versions if required\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          enable-cache: false\n      # --------------------------------------------------------------------------------\n      - name: Tests (burn-cuda)\n        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-cuda-runner\n\n  linux-std-vulkan-tests:\n    needs: [prepare-checks]\n    timeout-minutes: 60\n    # '@id:' label must be unique within this worklow\n    runs-on:\n      [\n        \"@id:burn-vulkan-job-${{github.run_id}}-${{github.run_attempt}}\",\n        \"@pr_number:${{ inputs.pr_number }}\",\n        \"@organization:tracel-ai\",\n        \"@repository:burn\",\n        \"@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}\",\n        \"@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}\",\n        \"@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}\",\n        \"@gpu:true\",\n      ]\n    env:\n      # disable incremental compilation (reduces artifact size)\n      CARGO_PROFILE_TEST_INCREMENTAL: \"false\"\n    # Keep the stragegy to be able to easily add new rust versions if required\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          enable-cache: false\n      # --------------------------------------------------------------------------------\n      - name: Tests (burn-vulkan)\n        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-vulkan-runner\n\n  linux-std-wgpu-tests:\n    needs: [prepare-checks]\n    timeout-minutes: 60\n    # '@id:' label must be unique within this worklow\n    runs-on:\n      [\n        \"@id:burn-wgpu-job-${{github.run_id}}-${{github.run_attempt}}\",\n        \"@pr_number:${{ inputs.pr_number }}\",\n        \"@organization:tracel-ai\",\n        \"@repository:burn\",\n        \"@image-family:${{ needs.prepare-checks.outputs.gcp_runners_image_family }}\",\n        \"@machine-type:${{ needs.prepare-checks.outputs.gcp_runners_machine_type }}\",\n        \"@zones:${{ needs.prepare-checks.outputs.gcp_runners_zone }}\",\n        \"@gpu:true\",\n      ]\n    env:\n      # disable incremental compilation (reduces artifact size)\n      CARGO_PROFILE_TEST_INCREMENTAL: \"false\"\n    # Keep the stragegy to be able to easily add new rust versions if required\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          enable-cache: false\n      # --------------------------------------------------------------------------------\n      - name: Tests (burn-wgpu)\n        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci gcp-wgpu-runner\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: CI\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"Cargo.lock\"\n      - \"**.rs\"\n      - \"**.sh\"\n      - \"**.ps1\"\n      - \"**.yml\"\n      - \"**.toml\"\n      - \"!**.md\"\n      - \"!LICENSE-APACHE\"\n      - \"!LICENSE-MIT\"\n  pull_request:\n    types: [opened, synchronize]\n    paths:\n      - \"Cargo.lock\"\n      - \"**.rs\"\n      - \"**.sh\"\n      - \"**.ps1\"\n      - \"**.yml\"\n      - \"**.toml\"\n      - \"!**.md\"\n      - \"!LICENSE-APACHE\"\n      - \"!LICENSE-MIT\"\n\nenv:\n  # Note: It is not possible to define top level env vars and pass them to composite actions.\n  # To work around this issue we use inputs and define all the env vars here.\n\n  RUST_PREVIOUS_VERSION: 1.92.0\n\n  # Dependency versioning\n  # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml\n\n  # Mozilla Grcov\n  GRCOV_LINK: \"https://github.com/mozilla/grcov/releases/download\"\n  GRCOV_VERSION: \"0.8.19\"\n\n  # Test in release mode (make it an empty string to test in debug mode)\n  TEST_RELEASE_FLAG: \"--release\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  prepare-checks:\n    runs-on: ubuntu-latest\n    outputs:\n      rust-prev-version: ${{ env.RUST_PREVIOUS_VERSION }}\n    steps:\n      - name: Do Nothing\n        if: false\n        run: echo\n\n  code-quality:\n    runs-on: ubuntu-22.04\n    needs: prepare-checks\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-linux\n      # --------------------------------------------------------------------------------\n      - name: Audit\n        run: cargo xtask check audit\n      # --------------------------------------------------------------------------------\n      - name: Format\n        shell: bash\n        env:\n          # work around for colors\n          # see: https://github.com/rust-lang/rustfmt/issues/3385\n          TERM: xterm-256color\n        run: cargo xtask check format\n      # --------------------------------------------------------------------------------\n      - name: Lint\n        run: cargo xtask check lint\n      # --------------------------------------------------------------------------------\n      - name: Typos\n        uses: tracel-ai/github-actions/check-typos@v9\n\n  documentation:\n    runs-on: ubuntu-22.04\n    needs: prepare-checks\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-linux\n      # --------------------------------------------------------------------------------\n      - name: Documentation Build\n        run: cargo xtask doc build\n      # --------------------------------------------------------------------------------\n      - name: Documentation Tests\n        run: cargo xtask doc tests\n\n  linux-std-tests:\n    runs-on: ubuntu-22.04\n    needs: [prepare-checks, code-quality]\n    env:\n      DISABLE_WGPU_SPIRV: \"1\"\n      # disable incremental compilation (reduces artifact size)\n      CARGO_PROFILE_TEST_INCREMENTAL: \"false\"\n    strategy:\n      matrix:\n        rust: [stable, prev]\n        include:\n          - rust: stable\n            toolchain: stable\n            coverage: --enable-coverage\n          - rust: prev\n            toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-linux\n          # Disable cache on linux-std (stable) runner which currently always runs out of disk space with tests + coverage\n          enable-cache: ${{ matrix.rust != 'stable' }}\n      # # --------------------------------------------------------------------------------\n      - name: Install grcov\n        if: matrix.rust == 'stable'\n        shell: bash\n        run: |\n          curl -L \"$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2\" |\n          tar xj -C $HOME/.cargo/bin\n          cargo xtask coverage install\n      # --------------------------------------------------------------------------------\n      - name: Tests\n        run: cargo xtask ${{ matrix.coverage }} test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner\n      # --------------------------------------------------------------------------------\n      - name: Generate lcov.info\n        if: matrix.rust == 'stable'\n        # /* is to exclude std library code coverage from analysis\n        run: cargo xtask coverage generate --ignore \"/*,xtask/*,examples/*\" --profile release\n      # --------------------------------------------------------------------------------\n      - name: Codecov upload lcov.info\n        if: matrix.rust == 'stable'\n        uses: codecov/codecov-action@v5\n        with:\n          files: lcov.info\n          token: ${{ secrets.CODECOV_TOKEN }}\n\n  linux-no-std-tests:\n    runs-on: ubuntu-22.04\n    needs: [prepare-checks, code-quality]\n    strategy:\n      matrix:\n        rust: [stable, prev]\n        include:\n          - rust: stable\n            toolchain: stable\n          - rust: prev\n            toolchain: ${{ needs.prepare-checks.outputs.rust-prev-version }}\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-linux-no-std\n      # --------------------------------------------------------------------------------\n      - name: Crates Build\n        run: cargo xtask --context no-std build --ci\n      # --------------------------------------------------------------------------------\n      - name: Crates Tests\n        run: cargo xtask --context no-std test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner\n\n  windows-std-tests:\n    runs-on: windows-2022\n    needs: [prepare-checks, code-quality]\n    # Keep the stragegy to be able to easily add new rust versions if required\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-windows\n      # --------------------------------------------------------------------------------\n      - name: Tests\n        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-runner\n\n  macos-std-tests:\n    runs-on: blaze/macos-15\n    needs: [prepare-checks, code-quality]\n    timeout-minutes: 60\n    # Keep the stragegy to be able to easily add new rust versions if required\n    strategy:\n      matrix:\n        rust: [stable]\n        include:\n          - rust: stable\n            toolchain: stable\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Setup Rust\n        uses: tracel-ai/github-actions/install-rust@v9\n        with:\n          rust-toolchain: ${{ matrix.toolchain }}\n          cache-key: ${{ matrix.rust }}-macos\n      # --------------------------------------------------------------------------------\n      - name: Device check\n        run: system_profiler SPHardwareDataType\n      # --------------------------------------------------------------------------------\n      - name: Tests\n        run: cargo xtask test ${{ env.TEST_RELEASE_FLAG }} --ci github-mac-runner\n"
  },
  {
    "path": ".github/workflows/valgrind.yml",
    "content": "name: valgrind\n\non:\n  schedule:\n    - cron: '0 23 * * WED' # Run every Wednesday at 23:00 (UTC)\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  valgrind:\n    runs-on: [\n      '@id:burn-linux-valgrind-${{ github.run_id }}-${{ github.run_attempt }}',\n      '@image-family:ubuntu-2404-lts-amd64',\n      '@image-project:ubuntu-os-cloud',\n      '@disk-size:100',\n      '@keep-alive:false',\n      '@machine-type:n2-standard-16',\n      '@os:linux',\n      '@zones:northamerica-northeast1-b'\n      ]\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Install valgrind\n        run: |\n          sudo apt-get install valgrind\n      # --------------------------------------------------------------------------------\n      - name: Run cargo-valgrind\n        env:\n          CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER: \"valgrind -s --leak-check=full --show-leak-kinds=all --error-exitcode=1\"\n        # Looking for vulnerabilities\n        run: |\n          cargo test\n"
  },
  {
    "path": ".github/workflows/vulnerabilities.yml",
    "content": "name: vulnerabilities\n\non:\n  schedule:\n    - cron: '0 21 * * WED' # Run every Wednesday at 21:00 (UTC)\n  push:\n    tags:\n      - 'v*.*.*'\n\nenv:\n  CAREFUL_VERSION: \"0.4.9\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  cargo-careful:\n    runs-on: ubuntu-latest\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt, rust-src\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Install cargo-careful\n        env:\n          CAREFUL_LINK: https://github.com/RalfJung/cargo-careful/releases/download\n        run: |\n          curl -L \"$CAREFUL_LINK/v$CAREFUL_VERSION/cargo-careful.x86_64-unknown-linux-musl\" \\\n          --output $HOME/.cargo/bin/cargo-careful\n          chmod +x $HOME/.cargo/bin/cargo-careful\n      # --------------------------------------------------------------------------------\n      - name: Run cargo-careful\n        # Looking for undefined behaviours\n        run: cargo +nightly careful test\n\n  address-sanitizer:\n    runs-on: ubuntu-latest\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt, rust-src\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Run AddressSanitizer\n        env:\n          RUSTFLAGS: -Zsanitizer=address -Copt-level=3\n          RUSTDOCFLAGS: -Zsanitizer=address\n        # Looking for memory vulnerabilities\n        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture\n\n  thread-sanitizer:\n    runs-on: ubuntu-latest\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt, rust-src\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Run ThreadSanitizer\n        env:\n          RUSTFLAGS: -Zsanitizer=thread -Copt-level=3\n          RUSTDOCFLAGS: -Zsanitizer=thread\n        # Looking for data race among threads\n        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture\n\n  memory-sanitizer:\n    runs-on: ubuntu-latest\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt, rust-src\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Run MemorySanitizer\n        env:\n          RUSTFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins -Copt-level=3\n          RUSTDOCFLAGS: -Zsanitizer=memory -Zsanitizer-memory-track-origins\n        # Looking for unitialized memory.\n        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture\n\n  safe-stack:\n    runs-on: ubuntu-latest\n    steps:\n      - name: checkout\n        uses: actions/checkout@v6\n      # --------------------------------------------------------------------------------\n      - name: Install Rust nightly\n        uses: dtolnay/rust-toolchain@nightly\n        with:\n          toolchain: nightly\n          components: rustfmt, rust-src\n      # --------------------------------------------------------------------------------\n      - name: Install Mesa\n        uses: tracel-ai/github-actions/install-mesa@v9\n      # --------------------------------------------------------------------------------\n      - name: Run SafeStack\n        env:\n          RUSTFLAGS: -Zsanitizer=safestack -Copt-level=3\n          RUSTDOCFLAGS: -Zsanitizer=safestack\n        # Provides backward edge control flow protection\n        run: cargo test -Zbuild-std --target x86_64-unknown-linux-gnu -- --nocapture\n"
  },
  {
    "path": ".gitignore",
    "content": "target\n# These are backup files generated by rustfmt\n**/*.rs.bk\n.DS_Store\n\n.dir-locals.el\n.idea\n.vscode\n.vs\n.fleet\n.ipynb_checkpoints/\n\n# Build output directory\nout\n\n# Virtual Environment of Python\n.venv\nuv.lock\n\n# Nix direnv\n.envrc\n.direnv\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n  - family-names: \"Simard\"\n    given-names: \"Nathaniel\"\n    email: \"nathaniel.simard.42@gmail.com\"\n  - family-names: \"Fortier-Dubois\"\n    given-names: \"Louis\"\n    email: \"louisfd94@gmail.com\"\n  - family-names: \"Tadjibaev\"\n    given-names: \"Dilshod\"\n    email: \"dilshod@gmail.com\"\n  - family-names: \"Lagrange\"\n    given-names: \"Guillaume\"\n    email: \"lagrange.guillaume.1@gmail.com\"\n  - name: \"Burn Framework Contributors\"\ntitle: \"Burn\"\nversion: 0.14.0\ndate-released: 2024-08-27\nurl: \"https://burn.dev/\"\nrepository-code: \"https://github.com/tracel-ai/burn\"\nlicense:\n  - MIT\n  - Apache-2.0\nabstract: \"Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.\"\nkeywords:\n  - scientific-computing\n  - deep-learning\n  - machine-learning\n  - neural-networks\n  - rust\n  - high-performance-computing\n  - portability\n  - compute-efficiency\n"
  },
  {
    "path": "CODE-OF-CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\nnathaniel.simard.42@gmail.com.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.0, available at\nhttps://www.contributor-covenant.org/version/2/0/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Burn\n\nWelcome to the Burn community! We're glad you're interested in contributing.\n\n## How to Contribute\n\nThe best way to get started is to look at [open issues](https://github.com/tracel-ai/burn/issues)\nand find one that interests you. Issues labeled `good first issue` are a great starting point for\nnew contributors.\n\nIf you have an idea that isn't covered by an existing issue, open one first to discuss the approach.\nThis helps align expectations and avoids wasted effort on both sides.\n\nFor questions, discussions, or just to say hello, join us on\n[Discord](https://discord.gg/uPEBbYYDB6). The [Contributor Book](https://burn.dev/contributor-book/)\ncovers architecture, environment setup, and guides for common tasks.\n\n## Pull Requests\n\nEvery pull request should have a descriptive title, a description covering what you changed, why,\nhow you tested it, and a link to the relevant issue (if applicable). Prefer small, focused PRs over\nlarge ones that bundle unrelated changes.\n\nDraft pull requests are considered not yet ready for review.\n\nCI checks should pass before requesting review, though the signal isn't always accurate. If you have\nquestions or need early feedback, let us know on the PR or on\n[Discord](https://discord.gg/uPEBbYYDB6).\n\n### Change Ownership\n\nThe core principle behind all contributions: **PR authors must understand, justify, and explain\nevery change they propose.** After a PR is accepted, both the reviewer and the author should be\nconfident it improves the codebase.\n\nThis applies equally whether you wrote the code from scratch, adapted it from another project, or\nused AI tools to help generate it. The origin of the code doesn't matter; what matters is that you\nown it intellectually and can stand behind it during review.\n\n### AI-Assisted Contributions\n\nUsing LLMs and AI tools to generate code that is part of a contribution is allowed.\n\nThat said, the [Change Ownership](#change-ownership) principle applies fully. You are the author,\nnot your AI tool. This means:\n\n- Read and understand every line before submitting.\n- Review AI-generated code for correctness, style consistency, and relevance.\n- Test your changes locally and confirm they work as intended.\n- Be prepared to explain the rationale behind any change during review.\n\nDo not use \"AI generated\" as a justification for low-quality code.\n\n### Before You Open a PR\n\n1. **Check for an existing issue.** If there isn't one, open an issue first to discuss the approach.\n   This is especially important for large changes or refactors.\n2. **Read the codebase.** Understand the architecture and conventions already in place. The\n   [Contributor Book](https://burn.dev/contributor-book/) covers architecture, environment setup,\n   and guides for common tasks.\n3. **Keep it focused.** One PR should address one concern. If you spot an unrelated issue while\n   working, open a separate PR for it.\n4. **Run validation.** Run `cargo run-checks` before submitting. This runs formatting, linting, and\n   the full test suite. All checks must pass.\n\n### Code Quality Standards\n\n- Follow existing code style and project conventions.\n- Write idiomatic Rust. If you are new to the codebase, study existing patterns before contributing.\n- Keep dependencies minimal. Don't introduce new crates without discussion.\n- Document public APIs. Non-trivial logic should have comments explaining _why_, not just _what_.\n- Prefer clarity over cleverness.\n- Bug fixes should include a regression test.\n\n### Large Pull Requests\n\nLarge, complex PRs are harder to review effectively and carry more risk. To help both yourself and\nreviewers, consider breaking substantial changes into smaller, incremental PRs. Each should be\nvaluable on its own, even if the full picture spans multiple PRs.\n\nLarge efforts that are ultimately rejected are frustrating for everyone involved. If you're planning\na substantial change, open an issue or start a discussion first. It's much easier to course-correct\nearly than after the work is done.\n\n### Review Process\n\n- Maintainers review PRs as time allows. Please be patient.\n- Be responsive to feedback. If changes are requested, address them or explain your reasoning.\n- Reviewers may ask clarifying questions about any part of your PR. This is a normal part of\n  collaborative review and helps ensure shared understanding.\n- Don't force-push to rewrite history during an active review without notice.\n- If a PR goes stale for more than 14 days without a response from the author, it may be closed.\n\n## Getting Help\n\nIf you're stuck or unsure about something, don't hesitate to ask. Open an issue, start a discussion,\nor reach out on [Discord](https://discord.gg/uPEBbYYDB6). We're happy to help.\n"
  },
  {
    "path": "Cargo.toml",
    "content": "[workspace]\n# Try\n# require version 2 to avoid \"feature\" additiveness for dev-dependencies\n# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2\nresolver = \"2\"\n\nmembers = [\n    \"crates/*\",\n    \"crates/burn-store/pytorch-tests\",\n    \"crates/burn-store/safetensors-tests\",\n    \"crates/burn-collective/multinode-tests\",\n    \"examples/*\",\n    \"xtask\",\n]\n\nexclude = [\n    \"examples/notebook\",\n    \"examples/raspberry-pi-pico\",\n    \"examples/dqn-agent\",         # gym-rs\n]\n\n[workspace.package]\nedition = \"2024\"\nlicense = \"MIT OR Apache-2.0\"\nreadme = \"README.md\"\nversion = \"0.21.0-pre.2\"\n\n[workspace.lints.clippy]\n\n[workspace.lints.rustdoc]\nbroken_intra_doc_links = \"deny\"\ninvalid_html_tags = \"deny\"\n\n[workspace.dependencies]\natomic_float = \"1\"\naxum = \"0.8.8\"\nbytemuck = \"1.25.0\"\nbytes = { version = \"1.11.1\", default-features = false }\ncandle-core = { version = \"0.9.2\" }\nciborium = { version = \"0.2\", default-features = false }\nclap = { version = \"4.6.0\", features = [\"derive\"] }\ncolored = \"3.0.0\"\nconsole_error_panic_hook = \"0.1.7\"\nconst-random = \"0.1\"\ncsv = \"1.3.1\"\ndashmap = \"6.1.0\"\ndata-encoding = { version = \"2.10.0\", default-features = false, features = [\n    \"alloc\",\n] }\ndirs = \"6.0.0\"\nencoding_rs = \"0.8.33\"\nenumset = { version = \"1.1.10\", default-features = false }\nfake = \"5.1.0\"\nflate2 = \"1.1.9\"\nfloat-cmp = \"0.10.0\"\nfutures = \"0.3\"\nfutures-util = \"0.3\"\ngix-tempfile = { version = \"21.0.0\", features = [\"signals\"] }\nglobwalk = \"0.9.1\"\nhashbrown = \"0.16\"\nhound = \"3.5.1\"\nimage = \"0.25.9\"\nindicatif = \"0.18.0\"\ninsta = \"1.45.0\"\njs-sys = \"0.3.77\"\nlibm = \"0.2.15\"\nlog = { default-features = false, version = \"0.4.29\" }\nlzma-rust2 = \"0.16.2\"\nopentelemetry = \"0.31.0\"\nopentelemetry-aws = \"0.19.0\"\nopentelemetry-otlp = \"0.31.0\"\nopentelemetry_sdk = \"0.31.0\"\nparking_lot = { version = \"0.12.5\", default-features = false }\npaste = \"1\"\nplanus = { version = \"=1.1\" }\npolars = { version = \"0.53.0\", features = [\"lazy\"] }\npretty_assertions = \"1.4.1\"\nproc-macro2 = \"1.0.106\"\nquote = \"1.0.45\"\nr2d2 = \"0.8.10\"\nr2d2_sqlite = \"0.31.0\"\nrayon = \"1.10.0\"\nregex = { version = \"1.12.3\", default-features = false, features = [\n    \"perf\",\n    \"unicode\",\n] }\nreqwest = { version = \"0.12.23\", default-features = false, features = [\n    \"rustls-tls\",\n] }\nrmp-serde = { version = \"1.3.1\", default-features = false }\nrstest = \"0.26.1\"\nrusqlite = \"0.37.0\"\nsanitize-filename = \"0.6.0\"\nserde_bytes = { version = \"0.11.18\", default-features = false, features = [\n    \"alloc\",\n] } # alloc for no_std\nserde_rusqlite = \"0.40.0\"\nserial_test = \"3.2.0\"\nspin = { version = \"0.10.0\", features = [\n    \"mutex\",\n    \"spin_mutex\",\n    \"portable-atomic\",\n] }\nstrum = { version = \"0.28.0\", features = [\"derive\"] }\nsyn = { version = \"2.0.111\", features = [\"full\", \"extra-traits\"] }\ntar = \"0.4.44\"\ntempfile = \"3.24.0\"\ntextdistance = { version = \"1.1.1\", default-features = false }\nthiserror = { version = \"2\", default-features = false }\ntokio = { version = \"1.50.0\", features = [\"rt\", \"macros\"] }\ntokio-tungstenite = \"0.28\"\ntokio-util = \"0.7\"\ntracing = { version = \"0.1.44\", default-features = false }\ntracing-appender = \"0.2.3\"\ntracing-core = { version = \"0.1.36\", default-features = false }\ntracing-opentelemetry = \"0.32.0\"\ntracing-subscriber = \"0.3.23\"\nzip = \"8.2.0\"\n\n# Persist related\nmemmap2 = { version = \"0.9\" }\nsafetensors = { version = \"0.7.0\", default-features = false }\n\n# Async handling\nasync-channel = \"2.5\"\nfutures-lite = { version = \"2.6.1\", default-features = false }\n\n# Terminal UI\nratatui = \"0.30.0\"\n\n# WGPU stuff\ntext_placeholder = \"0.5.1\"\n\nbincode = { version = \"2.0.1\", features = [\n    \"alloc\",\n    \"serde\",\n], default-features = false }\n\n#\n# The following packages disable the \"std\" feature for no_std compatibility\n#\ncfg-if = \"1.0.1\"\nderive-new = { version = \"0.7.0\", default-features = false }\n\nblas-src = { version = \"0.14.0\", default-features = false }\nbon = \"3.8.2\"\nhalf = { version = \"2.7.1\", features = [\n    \"alloc\",\n    \"num-traits\",\n    \"serde\",\n], default-features = false }\nmacerator = { version = \"0.3.0\" }\nmatrixmultiply = { version = \"0.3.10\", default-features = false }\nndarray = { version = \"0.17.2\", default-features = false }\nnum-traits = { version = \"0.2.19\", default-features = false, features = [\n    \"libm\",\n] } # libm is for no_std\nopenblas-src = \"0.10.14\"\nrand = { version = \"0.10.0\", default-features = false, features = [\"std_rng\"] }\nrand_distr = { version = \"0.6.0\", default-features = false }\nserde = { version = \"1.0.228\", default-features = false, features = [\n    \"derive\",\n    \"alloc\",\n] } # alloc is for no_std, derive is needed\nserde_json = { version = \"1.0.148\", default-features = false }\nsmallvec = { version = \"1\", features = [\"const_generics\", \"const_new\"] }\nuuid = { version = \"1.22.0\", default-features = false }\n\nbyteorder = { version = \"1.5.0\", default-features = false }\nlibc = \"0.2.182\"\nnvml-wrapper = \"0.12.0\"\nsysinfo = \"0.38.0\"\nsystemstat = \"0.2.6\"\ntch = \"0.22.0\"\ntorch-sys = \"0.22.0\"                                        # matches what tch is using, required for lib detection\n\nahash = { version = \"0.8.12\", default-features = false }\nportable-atomic = { version = \"1.13.1\" }\nportable-atomic-util = { version = \"0.2.6\", features = [\"alloc\"] }\n\n### For the main burn branch. ###\ncubecl = { git = \"https://github.com/tracel-ai/cubecl\", default-features = false, rev = \"20585bb73e19b16c5fb84b39923a49011b329a70\" }\ncubecl-common = { git = \"https://github.com/tracel-ai/cubecl\", default-features = false, rev = \"20585bb73e19b16c5fb84b39923a49011b329a70\" }\ncubecl-zspace = { git = \"https://github.com/tracel-ai/cubecl\", default-features = false, rev = \"20585bb73e19b16c5fb84b39923a49011b329a70\" }\ncubek = { git = \"https://github.com/tracel-ai/cubek\", default-features = false, rev = \"01ed48e1abb5ed117df33f4394f2c5a91c3eb97e\" }\n### For local development. ###\n# cubecl = { path = \"../cubecl/crates/cubecl\", default-features = false }\n# cubecl-common = { path = \"../cubecl/crates/cubecl-common\", default-features = false }\n# cubecl-zspace = { path = \"../cubecl/crates/cubecl-zspace\", default-features = false }\n# cubek = { path = \"../cubek/crates/cubek\", default-features = false }\n### For the release. ###\n# cubecl = { version = \"=0.10.0-pre.2\", default-features = false }\n# cubecl-common = { version = \"=0.10.0-pre.2\", default-features = false }\n# cubecl-zspace = { version = \"=0.10.0-pre.2\", default-features = false }\n# cubek = { version = \"=0.2.0-pre.2\", default-features = false }\n\n### For xtask crate ###\ntracel-xtask = \"=4.13.5\"\n# ### For local development. ###\n# tracel-xtask = { path = \"../xtask/crates/tracel-xtask\", default-features = false }\n\n[profile.dev]\ndebug = 1 # Speed up compilation time and not necessary.\n"
  },
  {
    "path": "LICENSE-APACHE",
    "content": "                              Apache License\n                        Version 2.0, January 2004\n                     http://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n   \"License\" shall mean the terms and conditions for use, reproduction,\n   and distribution as defined by Sections 1 through 9 of this document.\n\n   \"Licensor\" shall mean the copyright owner or entity authorized by\n   the copyright owner that is granting the License.\n\n   \"Legal Entity\" shall mean the union of the acting entity and all\n   other entities that control, are controlled by, or are under common\n   control with that entity. For the purposes of this definition,\n   \"control\" means (i) the power, direct or indirect, to cause the\n   direction or management of such entity, whether by contract or\n   otherwise, or (ii) ownership of fifty percent (50%) or more of the\n   outstanding shares, or (iii) beneficial ownership of such entity.\n\n   \"You\" (or \"Your\") shall mean an individual or Legal Entity\n   exercising permissions granted by this License.\n\n   \"Source\" form shall mean the preferred form for making modifications,\n   including but not limited to software source code, documentation\n   source, and configuration files.\n\n   \"Object\" form shall mean any form resulting from mechanical\n   transformation or translation of a Source form, including but\n   not limited to compiled object code, generated documentation,\n   and conversions to other media types.\n\n   \"Work\" shall mean the work of authorship, whether in Source or\n   Object form, made available under the License, as indicated by a\n   copyright notice that is included in or attached to the work\n   (an example is provided in the Appendix below).\n\n   \"Derivative Works\" shall mean any work, whether in Source or Object\n   form, that is based on (or derived from) the Work and for which the\n   editorial revisions, annotations, elaborations, or other modifications\n   represent, as a whole, an original work of authorship. For the purposes\n   of this License, Derivative Works shall not include works that remain\n   separable from, or merely link (or bind by name) to the interfaces of,\n   the Work and Derivative Works thereof.\n\n   \"Contribution\" shall mean any work of authorship, including\n   the original version of the Work and any modifications or additions\n   to that Work or Derivative Works thereof, that is intentionally\n   submitted to Licensor for inclusion in the Work by the copyright owner\n   or by an individual or Legal Entity authorized to submit on behalf of\n   the copyright owner. For the purposes of this definition, \"submitted\"\n   means any form of electronic, verbal, or written communication sent\n   to the Licensor or its representatives, including but not limited to\n   communication on electronic mailing lists, source code control systems,\n   and issue tracking systems that are managed by, or on behalf of, the\n   Licensor for the purpose of discussing and improving the Work, but\n   excluding communication that is conspicuously marked or otherwise\n   designated in writing by the copyright owner as \"Not a Contribution.\"\n\n   \"Contributor\" shall mean Licensor and any individual or Legal Entity\n   on behalf of whom a Contribution has been received by Licensor and\n   subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of\n   this License, each Contributor hereby grants to You a perpetual,\n   worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n   copyright license to reproduce, prepare Derivative Works of,\n   publicly display, publicly perform, sublicense, and distribute the\n   Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. Subject to the terms and conditions of\n   this License, each Contributor hereby grants to You a perpetual,\n   worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n   (except as stated in this section) patent license to make, have made,\n   use, offer to sell, sell, import, and otherwise transfer the Work,\n   where such license applies only to those patent claims licensable\n   by such Contributor that are necessarily infringed by their\n   Contribution(s) alone or by combination of their Contribution(s)\n   with the Work to which such Contribution(s) was submitted. If You\n   institute patent litigation against any entity (including a\n   cross-claim or counterclaim in a lawsuit) alleging that the Work\n   or a Contribution incorporated within the Work constitutes direct\n   or contributory patent infringement, then any patent licenses\n   granted to You under this License for that Work shall terminate\n   as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the\n   Work or Derivative Works thereof in any medium, with or without\n   modifications, and in Source or Object form, provided that You\n   meet the following conditions:\n\n   (a) You must give any other recipients of the Work or\n       Derivative Works a copy of this License; and\n\n   (b) You must cause any modified files to carry prominent notices\n       stating that You changed the files; and\n\n   (c) You must retain, in the Source form of any Derivative Works\n       that You distribute, all copyright, patent, trademark, and\n       attribution notices from the Source form of the Work,\n       excluding those notices that do not pertain to any part of\n       the Derivative Works; and\n\n   (d) If the Work includes a \"NOTICE\" text file as part of its\n       distribution, then any Derivative Works that You distribute must\n       include a readable copy of the attribution notices contained\n       within such NOTICE file, excluding those notices that do not\n       pertain to any part of the Derivative Works, in at least one\n       of the following places: within a NOTICE text file distributed\n       as part of the Derivative Works; within the Source form or\n       documentation, if provided along with the Derivative Works; or,\n       within a display generated by the Derivative Works, if and\n       wherever such third-party notices normally appear. The contents\n       of the NOTICE file are for informational purposes only and\n       do not modify the License. You may add Your own attribution\n       notices within Derivative Works that You distribute, alongside\n       or as an addendum to the NOTICE text from the Work, provided\n       that such additional attribution notices cannot be construed\n       as modifying the License.\n\n   You may add Your own copyright statement to Your modifications and\n   may provide additional or different license terms and conditions\n   for use, reproduction, or distribution of Your modifications, or\n   for any such Derivative Works as a whole, provided Your use,\n   reproduction, and distribution of the Work otherwise complies with\n   the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise,\n   any Contribution intentionally submitted for inclusion in the Work\n   by You to the Licensor shall be under the terms and conditions of\n   this License, without any additional terms or conditions.\n   Notwithstanding the above, nothing herein shall supersede or modify\n   the terms of any separate license agreement you may have executed\n   with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade\n   names, trademarks, service marks, or product names of the Licensor,\n   except as required for reasonable and customary use in describing the\n   origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. Unless required by applicable law or\n   agreed to in writing, Licensor provides the Work (and each\n   Contributor provides its Contributions) on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n   implied, including, without limitation, any warranties or conditions\n   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n   PARTICULAR PURPOSE. You are solely responsible for determining the\n   appropriateness of using or redistributing the Work and assume any\n   risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory,\n   whether in tort (including negligence), contract, or otherwise,\n   unless required by applicable law (such as deliberate and grossly\n   negligent acts) or agreed to in writing, shall any Contributor be\n   liable to You for damages, including any direct, indirect, special,\n   incidental, or consequential damages of any character arising as a\n   result of this License or out of the use or inability to use the\n   Work (including but not limited to damages for loss of goodwill,\n   work stoppage, computer failure or malfunction, or any and all\n   other commercial damages or losses), even if such Contributor\n   has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing\n   the Work or Derivative Works thereof, You may choose to offer,\n   and charge a fee for, acceptance of support, warranty, indemnity,\n   or other liability obligations and/or rights consistent with this\n   License. However, in accepting such obligations, You may act only\n   on Your own behalf and on Your sole responsibility, not on behalf\n   of any other Contributor, and only if You agree to indemnify,\n   defend, and hold each Contributor harmless for any liability\n   incurred by, or claims asserted against, such Contributor by reason\n   of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\n   To apply the Apache License to your work, attach the following\n   boilerplate notice, with the fields enclosed by brackets \"[]\"\n   replaced with your own identifying information. (Don't include\n   the brackets!)  The text should be enclosed in the appropriate\n   comment syntax for the file format. We also recommend that a\n   file or class name and description of purpose be included on the\n   same \"printed page\" as the copyright notice for easier\n   identification within third-party archives.\n\nCopyright 2022 Nathaniel Simard & Burn Framework Contributors\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n\thttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "LICENSE-MIT",
    "content": "MIT License\n\nCopyright (c) 2022 Nathaniel Simard & Burn Framework Contributors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "NOTICES.md",
    "content": "# NOTICES AND INFORMATION\n\nThis file contains notices and information required by libraries that this\nrepository copied or derived from.\n\n## PyTorch MNIST Example\n\n**Source**: https://github.com/pytorch/examples/blob/main/mnist/main.py\n\nLicense: BSD 3-Clause License\n\nCopyright (c) 2017,\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\n## wgpu\n\n**Source:** https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml\n\nMIT License\n\nCopyright (c) 2021 The gfx-rs developers\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\n## BSL 1.0\n\n**Source**:\n- https://github.com/DoumanAsh/error-code\n- https://github.com/DoumanAsh/clipboard-win\n\n\nBoost Software License - Version 1.0 - August 17th, 2003\n\nPermission is hereby granted, free of charge, to any person or organization\nobtaining a copy of the software and accompanying documentation covered by\nthis license (the \"Software\") to use, reproduce, display, distribute,\nexecute, and transmit the Software, and to prepare derivative works of the\nSoftware, and to permit third-parties to whom the Software is furnished to\ndo so, all subject to the following:\n\nThe copyright notices in the Software and this entire statement, including\nthe above license grant, this restriction and the following disclaimer,\nmust be included in all copies of the Software, in whole or in part, and\nall derivative works of the Software, unless such copies or derivative\nworks are solely in the form of machine-executable object code generated by\na source language processor.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT\nSHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE\nFOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,\nARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n\n\n## num-traits\n\n**Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs\n\nMIT License\n\nCopyright (c) 2014 The Rust Project Developers\n\nPermission is hereby granted, free of charge, to any\nperson obtaining a copy of this software and associated\ndocumentation files (the \"Software\"), to deal in the\nSoftware without restriction, including without\nlimitation the rights to use, copy, modify, merge,\npublish, distribute, sublicense, and/or sell copies of\nthe Software, and to permit persons to whom the Software\nis furnished to do so, subject to the following\nconditions:\n\nThe above copyright notice and this permission notice\nshall be included in all copies or substantial portions\nof the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF\nANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED\nTO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A\nPARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT\nSHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\nCLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION\nOF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR\nIN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n\n## RP\n\n**Source**:\n- https://github.com/embassy-rs/embassy/blob/main/examples/rp/Cargo.toml\n- https://github.com/embassy-rs/embassy/blob/main/examples/rp/build.rs\n- https://github.com/embassy-rs/embassy/blob/main/examples/rp/memory.x\n\n                              Apache License\n                        Version 2.0, January 2004\n                     http://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. Definitions.\n\n   \"License\" shall mean the terms and conditions for use, reproduction,\n   and distribution as defined by Sections 1 through 9 of this document.\n\n   \"Licensor\" shall mean the copyright owner or entity authorized by\n   the copyright owner that is granting the License.\n\n   \"Legal Entity\" shall mean the union of the acting entity and all\n   other entities that control, are controlled by, or are under common\n   control with that entity. For the purposes of this definition,\n   \"control\" means (i) the power, direct or indirect, to cause the\n   direction or management of such entity, whether by contract or\n   otherwise, or (ii) ownership of fifty percent (50%) or more of the\n   outstanding shares, or (iii) beneficial ownership of such entity.\n\n   \"You\" (or \"Your\") shall mean an individual or Legal Entity\n   exercising permissions granted by this License.\n\n   \"Source\" form shall mean the preferred form for making modifications,\n   including but not limited to software source code, documentation\n   source, and configuration files.\n\n   \"Object\" form shall mean any form resulting from mechanical\n   transformation or translation of a Source form, including but\n   not limited to compiled object code, generated documentation,\n   and conversions to other media types.\n\n   \"Work\" shall mean the work of authorship, whether in Source or\n   Object form, made available under the License, as indicated by a\n   copyright notice that is included in or attached to the work\n   (an example is provided in the Appendix below).\n\n   \"Derivative Works\" shall mean any work, whether in Source or Object\n   form, that is based on (or derived from) the Work and for which the\n   editorial revisions, annotations, elaborations, or other modifications\n   represent, as a whole, an original work of authorship. For the purposes\n   of this License, Derivative Works shall not include works that remain\n   separable from, or merely link (or bind by name) to the interfaces of,\n   the Work and Derivative Works thereof.\n\n   \"Contribution\" shall mean any work of authorship, including\n   the original version of the Work and any modifications or additions\n   to that Work or Derivative Works thereof, that is intentionally\n   submitted to Licensor for inclusion in the Work by the copyright owner\n   or by an individual or Legal Entity authorized to submit on behalf of\n   the copyright owner. For the purposes of this definition, \"submitted\"\n   means any form of electronic, verbal, or written communication sent\n   to the Licensor or its representatives, including but not limited to\n   communication on electronic mailing lists, source code control systems,\n   and issue tracking systems that are managed by, or on behalf of, the\n   Licensor for the purpose of discussing and improving the Work, but\n   excluding communication that is conspicuously marked or otherwise\n   designated in writing by the copyright owner as \"Not a Contribution.\"\n\n   \"Contributor\" shall mean Licensor and any individual or Legal Entity\n   on behalf of whom a Contribution has been received by Licensor and\n   subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of\n   this License, each Contributor hereby grants to You a perpetual,\n   worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n   copyright license to reproduce, prepare Derivative Works of,\n   publicly display, publicly perform, sublicense, and distribute the\n   Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. Subject to the terms and conditions of\n   this License, each Contributor hereby grants to You a perpetual,\n   worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n   (except as stated in this section) patent license to make, have made,\n   use, offer to sell, sell, import, and otherwise transfer the Work,\n   where such license applies only to those patent claims licensable\n   by such Contributor that are necessarily infringed by their\n   Contribution(s) alone or by combination of their Contribution(s)\n   with the Work to which such Contribution(s) was submitted. If You\n   institute patent litigation against any entity (including a\n   cross-claim or counterclaim in a lawsuit) alleging that the Work\n   or a Contribution incorporated within the Work constitutes direct\n   or contributory patent infringement, then any patent licenses\n   granted to You under this License for that Work shall terminate\n   as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the\n   Work or Derivative Works thereof in any medium, with or without\n   modifications, and in Source or Object form, provided that You\n   meet the following conditions:\n\n   (a) You must give any other recipients of the Work or\n       Derivative Works a copy of this License; and\n\n   (b) You must cause any modified files to carry prominent notices\n       stating that You changed the files; and\n\n   (c) You must retain, in the Source form of any Derivative Works\n       that You distribute, all copyright, patent, trademark, and\n       attribution notices from the Source form of the Work,\n       excluding those notices that do not pertain to any part of\n       the Derivative Works; and\n\n   (d) If the Work includes a \"NOTICE\" text file as part of its\n       distribution, then any Derivative Works that You distribute must\n       include a readable copy of the attribution notices contained\n       within such NOTICE file, excluding those notices that do not\n       pertain to any part of the Derivative Works, in at least one\n       of the following places: within a NOTICE text file distributed\n       as part of the Derivative Works; within the Source form or\n       documentation, if provided along with the Derivative Works; or,\n       within a display generated by the Derivative Works, if and\n       wherever such third-party notices normally appear. The contents\n       of the NOTICE file are for informational purposes only and\n       do not modify the License. You may add Your own attribution\n       notices within Derivative Works that You distribute, alongside\n       or as an addendum to the NOTICE text from the Work, provided\n       that such additional attribution notices cannot be construed\n       as modifying the License.\n\n   You may add Your own copyright statement to Your modifications and\n   may provide additional or different license terms and conditions\n   for use, reproduction, or distribution of Your modifications, or\n   for any such Derivative Works as a whole, provided Your use,\n   reproduction, and distribution of the Work otherwise complies with\n   the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise,\n   any Contribution intentionally submitted for inclusion in the Work\n   by You to the Licensor shall be under the terms and conditions of\n   this License, without any additional terms or conditions.\n   Notwithstanding the above, nothing herein shall supersede or modify\n   the terms of any separate license agreement you may have executed\n   with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade\n   names, trademarks, service marks, or product names of the Licensor,\n   except as required for reasonable and customary use in describing the\n   origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. Unless required by applicable law or\n   agreed to in writing, Licensor provides the Work (and each\n   Contributor provides its Contributions) on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n   implied, including, without limitation, any warranties or conditions\n   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n   PARTICULAR PURPOSE. You are solely responsible for determining the\n   appropriateness of using or redistributing the Work and assume any\n   risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory,\n   whether in tort (including negligence), contract, or otherwise,\n   unless required by applicable law (such as deliberate and grossly\n   negligent acts) or agreed to in writing, shall any Contributor be\n   liable to You for damages, including any direct, indirect, special,\n   incidental, or consequential damages of any character arising as a\n   result of this License or out of the use or inability to use the\n   Work (including but not limited to damages for loss of goodwill,\n   work stoppage, computer failure or malfunction, or any and all\n   other commercial damages or losses), even if such Contributor\n   has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing\n   the Work or Derivative Works thereof, You may choose to offer,\n   and charge a fee for, acceptance of support, warranty, indemnity,\n   or other liability obligations and/or rights consistent with this\n   License. However, in accepting such obligations, You may act only\n   on Your own behalf and on Your sole responsibility, not on behalf\n   of any other Contributor, and only if You agree to indemnify,\n   defend, and hold each Contributor harmless for any liability\n   incurred by, or claims asserted against, such Contributor by reason\n   of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\n   To apply the Apache License to your work, attach the following\n   boilerplate notice, with the fields enclosed by brackets \"[]\"\n   replaced with your own identifying information. (Don't include\n   the brackets!)  The text should be enclosed in the appropriate\n   comment syntax for the file format. We also recommend that a\n   file or class name and description of purpose be included on the\n   same \"printed page\" as the copyright notice for easier\n   identification within third-party archives.\n\nCopyright (c) Embassy project contributors\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\nMIT license\n\nCopyright (c) Embassy project contributors\n\nPermission is hereby granted, free of charge, to any\nperson obtaining a copy of this software and associated\ndocumentation files (the \"Software\"), to deal in the\nSoftware without restriction, including without\nlimitation the rights to use, copy, modify, merge,\npublish, distribute, sublicense, and/or sell copies of\nthe Software, and to permit persons to whom the Software\nis furnished to do so, subject to the following\nconditions:\n\nThe above copyright notice and this permission notice\nshall be included in all copies or substantial portions\nof the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF\nANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED\nTO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A\nPARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT\nSHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\nCLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION\nOF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR\nIN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\n\n## github-device-flow\n\n**Source**:\n- Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs\n- https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs\n\nMIT License\n\nCopyright (c) 2022 Jake Wilkins\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\n## Candle - Pickle Reader\n\n**Source**: https://github.com/huggingface/candle/blob/main/candle-core/src/pickle.rs\n\nThis project includes code from Candle by Hugging Face, licensed under both MIT and Apache 2.0 licenses.\n\n**MIT License**: https://github.com/huggingface/candle/blob/main/LICENSE-MIT\n\nMIT License\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n**Apache License 2.0**: https://github.com/huggingface/candle/blob/main/LICENSE-APACHE\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n\n## ICU\n\nUNICODE LICENSE V3\n\nCOPYRIGHT AND PERMISSION NOTICE\n\nCopyright © 2016-2024 Unicode, Inc.\n\nNOTICE TO USER: Carefully read the following legal agreement. BY\nDOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR\nSOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE\nTERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT\nDOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE.\n\nPermission is hereby granted, free of charge, to any person obtaining a\ncopy of data files and any associated documentation (the \"Data Files\") or\nsoftware and any associated documentation (the \"Software\") to deal in the\nData Files or Software without restriction, including without limitation\nthe rights to use, copy, modify, merge, publish, distribute, and/or sell\ncopies of the Data Files or Software, and to permit persons to whom the\nData Files or Software are furnished to do so, provided that either (a)\nthis copyright and permission notice appear with all copies of the Data\nFiles or Software, or (b) this copyright and permission notice appear in\nassociated Documentation.\n\nTHE DATA FILES AND SOFTWARE ARE PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY\nKIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\nMERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF\nTHIRD PARTY RIGHTS.\n\nIN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE\nBE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES,\nOR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,\nWHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,\nARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA\nFILES OR SOFTWARE.\n\nExcept as contained in this notice, the name of a copyright holder shall\nnot be used in advertising or otherwise to promote the sale, use or other\ndealings in these Data Files or Software without prior written\nauthorization of the copyright holder.\n"
  },
  {
    "path": "POEM.md",
    "content": "# BURN: Burn Unstoppable Rusty Neurons\n\nIn the realm of circuits and code,  \nA fiery forge ignites to bear its load,  \nA framework born, BURN it be named,  \nUnstoppable Rusty Neurons, untamed.\n\nFrom silicon synapses, connections spire,  \nA digital cortex, setting minds afire,  \nIn the vast expanse of deep learning's sea,  \nA beacon of progress, BURN comes to be.\n\nOh, rusty neurons, forged in the flame,  \nUnyielding in purpose, undaunted by name,  \nThrough layers of logic and intricate art,  \nYou weave and entwine, each playing its part.\n\nWith algorithms profound, and data refined,  \nIn ceaseless pursuit of knowledge to find,  \nBURN paves a path to enlightenment, bright,  \nA testament to the wonders of human foresight.\n\nIn neural networks deep, where wisdom resides,  \nThe dance of nodes and edges presides,  \nWith loss and gradients, BURN takes its stride,  \nA journey towards truth, with AI as our guide.\n\nNo barriers hold back the curious mind,  \nAs BURN seeks the answers we yearn to find,  \nUnstoppable, relentless, in pursuit of the unknown,  \nOur collective intellect, within it, has grown.\n\nSo sing we the praises of BURN's fiery might,  \nAn ode to the sparks that set the dark alight,  \nTo the rusty neurons, unstoppable and true,  \nA testament to the power of dreams, to breakthrough.\n\n(ChatGPT (model=gpt-4) with prompt:\nWrite a poem about \"BURN: Burn Unstoppable Rusty Neurons\" deep\nlearning neural network framework)\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<img src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/logo-burn-neutral.webp\" width=\"350px\"/>\n\n[![Discord](https://img.shields.io/discord/1038839012602941528.svg?color=7289da&&logo=discord)](https://discord.gg/uPEBbYYDB6)\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn.svg)](https://crates.io/crates/burn)\n[![Minimum Supported Rust Version](https://img.shields.io/crates/msrv/burn)](https://crates.io/crates/burn)\n[![Documentation](https://img.shields.io/badge/docs-latest-blue)](https://burn.dev/docs/burn)\n[![Test Status](https://github.com/tracel-ai/burn/actions/workflows/test.yml/badge.svg)](https://github.com/tracel-ai/burn/actions/workflows/test.yml)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](#license)\n[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/tracel-ai/burn)\n\n[<img src=\"https://www.runblaze.dev/ci-blaze-powered.png\" width=\"125px\"/>](https://www.runblaze.dev)\n\n---\n\n**Burn is a next generation Tensor Library and Deep Learning Framework that doesn't compromise on\n<br /> flexibility, efficiency and portability.**\n\n<br/>\n</div>\n\n<div align=\"left\">\n\nBurn is both a tensor library and a deep learning framework optimized for numerical computing, model\ninference and model training. Burn leverages Rust to perform optimizations normally only available\nin static-graph frameworks, offering optimal speed without impacting flexibility.\n\n## Backend\n\n<div align=\"left\">\n<img align=\"right\" src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/backend-chip.png\" height=\"96px\"/>\n\nBurn strives to be as fast as possible on as many hardwares as possible, with robust\nimplementations. We believe this flexibility is crucial for modern needs where you may train your\nmodels in the cloud, then deploy on customer hardwares, which vary from user to user.\n\n</div>\n\n### Supported Backends\n\nMost backends support all operating systems, so we don't mention them in the tables below.\n\n**GPU Backends:**\n\n|         | CUDA | ROCm | Metal | Vulkan | WebGPU | LibTorch |\n| ------- | ---- | ---- | ----- | ------ | ------ | -------- |\n| Nvidia  | ☑️   | -    | -     | ☑️     | ☑️     | ☑️       |\n| AMD     | -    | ☑️   | -     | ☑️     | ☑️     | ☑️       |\n| Apple   | -    | -    | ☑️    | -      | ☑️     | ☑️       |\n| Intel   | -    | -    | -     | ☑️     | ☑️     | -        |\n| Qualcom | -    | -    | -     | ☑️     | ☑️     | -        |\n| Wasm    | -    | -    | -     | -      | ☑️     | -        |\n\n**CPU Backends:**\n\n|        | Cpu (CubeCL) | NdArray | LibTorch |\n| ------ | ------------ | ------- | -------- |\n| X86    | ☑️           | ☑️      | ☑️       |\n| Arm    | ☑️           | ☑️      | ☑️       |\n| Wasm   | -            | ☑️      | -        |\n| no-std | -            | ☑️      | -        |\n\n<br />\n\nCompared to other frameworks, Burn has a very different approach to supporting many backends. By\ndesign, most code is generic over the Backend trait, which allows us to build Burn with swappable\nbackends. This makes composing backend possible, augmenting them with additional functionalities\nsuch as autodifferentiation and automatic kernel fusion.\n\n<details>\n<summary>\nAutodiff: Backend decorator that brings backpropagation to any backend 🔄\n</summary>\n<br />\n\nContrary to the aforementioned backends, Autodiff is actually a backend _decorator_. This means that\nit cannot exist by itself; it must encapsulate another backend.\n\nThe simple act of wrapping a base backend with Autodiff transparently equips it with\nautodifferentiation support, making it possible to call backward on your model.\n\n```rust\nuse burn::backend::{Autodiff, Wgpu};\nuse burn::tensor::{Distribution, Tensor};\n\nfn main() {\n    type Backend = Autodiff<Wgpu>;\n\n    let device = Default::default();\n\n    let x: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default, &device);\n    let y: Tensor<Backend, 2> = Tensor::random([32, 32], Distribution::Default, &device).require_grad();\n\n    let tmp = x.clone() + y.clone();\n    let tmp = tmp.matmul(x);\n    let tmp = tmp.exp();\n\n    let grads = tmp.backward();\n    let y_grad = y.grad(&grads).unwrap();\n    println!(\"{y_grad}\");\n}\n```\n\nOf note, it is impossible to make the mistake of calling backward on a model that runs on a backend\nthat does not support autodiff (for inference), as this method is only offered by an Autodiff\nbackend.\n\nSee the [Autodiff Backend README](./crates/burn-autodiff/README.md) for more details.\n\n</details>\n\n<details>\n<summary>\nFusion: Backend decorator that brings kernel fusion to all first-party backends\n</summary>\n<br />\n\nThis backend decorator enhances a backend with kernel fusion, provided that the inner backend\nsupports it. Note that you can compose this backend with other backend decorators such as Autodiff.\nAll first-party accelerated backends (like WGPU and CUDA) use Fusion by default (`burn/fusion`\nfeature flag), so you typically don't need to apply it manually.\n\n```rust\n#[cfg(not(feature = \"fusion\"))]\npub type Cuda<F = f32, I = i32> = CubeBackend<CudaRuntime, F, I, u8>;\n\n#[cfg(feature = \"fusion\")]\npub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<CubeBackend<CudaRuntime, F, I, u8>>;\n```\n\nOf note, we plan to implement automatic gradient checkpointing based on compute bound and memory\nbound operations, which will work gracefully with the fusion backend to make your code run even\nfaster during training, see [this issue](https://github.com/tracel-ai/burn/issues/936).\n\nSee the [Fusion Backend README](./crates/burn-fusion/README.md) for more details.\n\n</details>\n\n<details>\n<summary>\nRouter (Beta): Backend decorator that composes multiple backends into a single one\n</summary>\n<br />\n\nThat backend simplifies hardware operability, if for instance you want to execute some operations on\nthe CPU and other operations on the GPU.\n\n```rust\nuse burn::tensor::{Distribution, Tensor};\nuse burn::backend::{\n    NdArray, Router, Wgpu, ndarray::NdArrayDevice, router::duo::MultiDevice, wgpu::WgpuDevice,\n};\n\nfn main() {\n    type Backend = Router<(Wgpu, NdArray)>;\n\n    let device_0 = MultiDevice::B1(WgpuDevice::DiscreteGpu(0));\n    let device_1 = MultiDevice::B2(NdArrayDevice::Cpu);\n\n    let tensor_gpu =\n        Tensor::<Backend, 2>::random([3, 3], burn::tensor::Distribution::Default, &device_0);\n    let tensor_cpu =\n        Tensor::<Backend, 2>::random([3, 3], burn::tensor::Distribution::Default, &device_1);\n}\n\n```\n\n</details>\n\n<details>\n<summary>\nRemote (Beta): Backend decorator for remote backend execution, useful for distributed computations\n</summary>\n<br />\n\nThat backend has two parts, one client and one server. The client sends tensor operations over the\nnetwork to a remote compute backend. You can use any first-party backend as server in a single line\nof code:\n\n```rust\nfn main_server() {\n    // Start a server on port 3000.\n    burn::server::start::<burn::backend::Cuda>(Default::default(), 3000);\n}\n\nfn main_client() {\n    // Create a client that communicate with the server on port 3000.\n    use burn::backend::{Autodiff, RemoteBackend};\n\n    type Backend = Autodiff<RemoteDevice>;\n\n    let device = RemoteDevice::new(\"ws://localhost:3000\");\n    let tensor_gpu =\n        Tensor::<Backend, 2>::random([3, 3], Distribution::Default, &device);\n}\n\n```\n\n</details>\n\n<br />\n\n## Training & Inference\n\n<div align=\"left\">\n<img align=\"right\" src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-wall.png\" height=\"96px\"/>\n\nThe whole deep learning workflow is made easy with Burn, as you can monitor your training progress\nwith an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU\nclusters.\n\nBurn was built from the ground up with training and inference in mind. It's also worth noting how\nBurn, in comparison to frameworks like PyTorch, simplifies the transition from training to\ndeployment, eliminating the need for code changes.\n\n</div>\n\n<div align=\"center\">\n\n<br />\n\n<a href=\"https://www.youtube.com/watch?v=N9RM5CQbNQc\" target=\"_blank\">\n    <img src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/burn-train-tui.png\" alt=\"Burn Train TUI\" width=\"75%\">\n  </a>\n</div>\n\n<br />\n\n**Click on the following sections to expand 👇**\n\n<details>\n<summary>\nTraining Dashboard 📈\n</summary>\n<br />\n\nAs you can see in the previous video (click on the picture!), a new terminal UI dashboard based on\nthe [Ratatui](https://github.com/ratatui-org/ratatui) crate allows users to follow their training\nwith ease without having to connect to any external application.\n\nYou can visualize your training and validation metrics updating in real-time and analyze the\nlifelong progression or recent history of any registered metrics using only the arrow keys. Break\nfrom the training loop without crashing, allowing potential checkpoints to be fully written or\nimportant pieces of code to complete without interruption 🛡\n\n</details>\n\n<details>\n<summary>\nONNX Support 🐫\n</summary>\n<br />\n\nBurn supports importing ONNX (Open Neural Network Exchange) models through the\n[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate, allowing you to easily port models from\nTensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses Burn's native\nAPIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly) and benefit\nfrom all of Burn's optimizations like automatic kernel fusion.\n\nOur ONNX support is further described in\n[this section of the Burn Book 🔥](https://burn.dev/books/burn/onnx-import.html).\n\n> **Note**: This crate is in active development and currently supports a\n> [limited set of ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).\n\n</details>\n\n<details>\n<summary>\nImporting PyTorch or Safetensors Models 🚚\n</summary>\n<br />\n\nYou can load weights from PyTorch or Safetensors formats directly into your Burn-defined models.\nThis makes it easy to reuse existing models while benefiting from Burn's performance and deployment\nfeatures.\n\nLearn more in the [Saving & Loading Models](https://burn.dev/books/burn/saving-and-loading.html)\nsection of the Burn Book.\n\n</details>\n\n<details>\n<summary>\nInference in the Browser 🌐\n</summary>\n<br />\n\nSeveral of our backends can run in WebAssembly environments: NdArray for CPU execution, and WGPU for\nGPU acceleration via WebGPU. This means that you can run inference directly within a browser. We\nprovide several examples of this:\n\n- [MNIST](./examples/mnist-inference-web) where you can draw digits and a small convnet tries to\n  find which one it is! 2️⃣ 7️⃣ 😰\n- [Image Classification](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web)\n  where you can upload images and classify them! 🌄\n\n</details>\n\n<details>\n<summary>\nEmbedded: <i>no_std</i> support ⚙️\n</summary>\n<br />\n\nBurn's core components support [no_std](https://docs.rust-embedded.org/book/intro/no-std.html). This\nmeans it can run in bare metal environment such as embedded devices without an operating system.\n\n> As of now, only the NdArray backend can be used in a _no_std_ environment.\n\n</details>\n\n<br />\n\n### Benchmarks\n\nTo evaluate performance across different backends and track improvements over time, we provide a\ndedicated benchmarking suite.\n\nRun and compare benchmarks using [burn-bench](https://github.com/tracel-ai/burn-bench).\n\n> ⚠️ **Warning** When using one of the `wgpu` backends, you may encounter compilation errors related\n> to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency\n> chain. To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs`\n> file:\n>\n> ```rust\n> #![recursion_limit = \"256\"]\n> ```\n>\n> The default recursion limit (128) is often just below the required depth (typically 130-150) due\n> to deeply nested associated types and trait bounds.\n\n## Getting Started\n\n<div align=\"left\">\n<img align=\"right\" src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-walking.png\" height=\"96px\"/>\n\nJust heard of Burn? You are at the right place! Just continue reading this section and we hope you\ncan get on board really quickly.\n\n</div>\n\n<details>\n<summary>\nThe Burn Book 🔥\n</summary>\n<br />\n\nTo begin working effectively with Burn, it is crucial to understand its key components and\nphilosophy. This is why we highly recommend new users to read the first sections of\n[The Burn Book 🔥](https://burn.dev/books/burn/). It provides detailed examples and explanations\ncovering every facet of the framework, including building blocks like tensors, modules, and\noptimizers, all the way to advanced usage, like coding your own GPU kernels.\n\n> The project is constantly evolving, and we try as much as possible to keep the book up to date\n> with new additions. However, we might miss some details sometimes, so if you see something weird,\n> let us know! We also gladly accept Pull Requests 😄\n\n</details>\n\n<details>\n<summary>\nExamples 🙏\n</summary>\n<br />\n\nLet's start with a code snippet that shows how intuitive the framework is to use! In the following,\nwe declare a neural network module with some parameters along with its forward pass.\n\n```rust\nuse burn::nn;\nuse burn::module::Module;\nuse burn::tensor::backend::Backend;\n\n#[derive(Module, Debug)]\npub struct PositionWiseFeedForward<B: Backend> {\n    linear_inner: nn::Linear<B>,\n    linear_outer: nn::Linear<B>,\n    dropout: nn::Dropout,\n    gelu: nn::Gelu,\n}\n\nimpl<B: Backend> PositionWiseFeedForward<B> {\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let x = self.linear_inner.forward(input);\n        let x = self.gelu.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.linear_outer.forward(x)\n    }\n}\n```\n\nWe have a somewhat large amount of [examples](./examples) in the repository that shows how to use\nthe framework in different scenarios.\n\nFollowing [the book](https://burn.dev/books/burn/):\n\n- [Basic Workflow](./examples/guide) : Creates a custom CNN `Module` to train on the MNIST dataset\n  and use for inference.\n- [Custom Training Loop](./examples/custom-training-loop) : Implements a basic training loop instead\n  of using the `Learner`.\n- [Custom WGPU Kernel](./examples/custom-wgpu-kernel) : Learn how to create your own custom\n  operation with the WGPU backend.\n\nAdditional examples:\n\n- [Custom CSV Dataset](./examples/custom-csv-dataset) : Implements a dataset to parse CSV data for a\n  regression task.\n- [Regression](./examples/simple-regression) : Trains a simple MLP on the California Housing dataset\n  to predict the median house value for a district.\n- [Custom Image Dataset](./examples/custom-image-dataset) : Trains a simple CNN on custom image\n  dataset following a simple folder structure.\n- [Custom Renderer](./examples/custom-renderer) : Implements a custom renderer to display the\n  [`Learner`](./building-blocks/learner.md) progress.\n- [Image Classification Web](./examples/image-classification-web) : Image classification web browser\n  demo using Burn, WGPU and WebAssembly.\n- [MNIST Inference on Web](./examples/mnist-inference-web) : An interactive MNIST inference demo in\n  the browser. The demo is available [online](https://burn.dev/demo/).\n- [MNIST Training](./examples/mnist) : Demonstrates how to train a custom `Module` (MLP) with the\n  `Learner` configured to log metrics and keep training checkpoints.\n- [PyTorch Import Inference](./examples/import-model-weights) : Imports a PyTorch model pre-trained\n  on MNIST to perform inference on a sample image with Burn.\n- [Text Classification](./examples/text-classification) : Trains a text classification transformer\n  model on the AG News or DbPedia dataset. The trained model can then be used to classify a text\n  sample.\n- [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the\n  DbPedia dataset.\n- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits\n  based on MNIST.\n\nFor more practical insights, you can clone the repository and run any of them directly on your\ncomputer!\n\n</details>\n\n<details>\n<summary>\nPre-trained Models 🤖\n</summary>\n<br />\n\nWe keep an updated and curated list of models and examples built with Burn, see the\n[tracel-ai/models repository](https://github.com/tracel-ai/models) for more details.\n\nDon't see the model you want? Don't hesitate to open an issue, and we may prioritize it. Built a\nmodel using Burn and want to share it? You can also open a Pull Request and add your model under the\ncommunity section!\n\n</details>\n\n<details>\n<summary>\nWhy use Rust for Deep Learning? 🦀\n</summary>\n<br />\n\nDeep Learning is a special form of software where you need very high level abstractions as well as\nextremely fast execution time. Rust is the perfect candidate for that use case since it provides\nzero-cost abstractions to easily create neural network modules, and fine-grained control over memory\nto optimize every detail.\n\nIt's important that a framework be easy to use at a high level so that its users can focus on\ninnovating in the AI field. However, since running models relies so heavily on computations,\nperformance can't be neglected.\n\nTo this day, the mainstream solution to this problem has been to offer APIs in Python, but rely on\nbindings to low-level languages such as C/C++. This reduces portability, increases complexity and\ncreates frictions between researchers and engineers. We feel like Rust's approach to abstractions\nmakes it versatile enough to tackle this two languages dichotomy.\n\nRust also comes with the Cargo package manager, which makes it incredibly easy to build, test, and\ndeploy from any environment, which is usually a pain in Python.\n\nAlthough Rust has the reputation of being a difficult language at first, we strongly believe it\nleads to more reliable, bug-free solutions built faster (after some practice 😅)!\n\n</details>\n\n<br />\n\n> **Deprecation Note**<br />Since `0.14.0`, the internal structure for tensor data has changed. The\n> previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new\n> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and\n> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to\n> `TensorData`.\n\n<!-- >\n> In the event that you are trying to load a model record saved in a previous version, make sure to\n> enable the `record-backward-compat` feature using a previous version of burn (<=0.16.0). Otherwise,\n> the record won't be deserialized correctly and you will get an error message (which will also point\n> you to the backward compatible feature flag). The backward compatibility was maintained for\n> deserialization (loading), so as soon as you have saved the record again it will be saved according\n> to the new structure and you will be able to upgrade to this version. Please note that binary formats\n> are not backward compatible. Thus, you will need to load your record in a previous version and save it\n> to another of the self-describing record formats before using a compatible version (as described) with the\n> `record-backward-compat` feature flag. -->\n\n<details id=\"deprecation\">\n<summary>\nLoading Model Records From Previous Versions ⚠️\n</summary>\n<br />\n\nIn the event that you are trying to load a model record saved in a version older than `0.14.0`, make\nsure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat`\nfeature flag.\n\n```\nfeatures = [..., \"record-backward-compat\"]\n```\n\nOtherwise, the record won't be deserialized correctly and you will get an error message. This error\nwill also point you to the backward compatible feature flag.\n\nThe backward compatibility was maintained for deserialization when loading records. Therefore, as\nsoon as you have saved the record again it will be saved according to the new structure and you can\nupgrade back to the current version\n\nPlease note that binary formats are not backward compatible. Thus, you will need to load your record\nin a previous version and save it in any of the other self-describing record format (e.g., using the\n`NamedMpkFileRecorder`) before using a compatible version (as described) with the\n`record-backward-compat` feature flag.\n\n</details>\n\n## Community\n\n<div align=\"left\">\n<img align=\"right\" src=\"https://raw.githubusercontent.com/tracel-ai/burn/main/assets/ember-community.png\" height=\"96px\"/>\n\nIf you are excited about the project, don't hesitate to join our\n[Discord](https://discord.gg/uPEBbYYDB6)! We try to be as welcoming as possible to everybody from\nany background. You can ask your questions and share what you built with the community!\n\n</div>\n\n<br/>\n\n**Contributing**\n\nBefore contributing, please read the [Contributing Guidelines](./CONTRIBUTING.md) and our\n[Code of Conduct](./CODE-OF-CONDUCT.md). The [Contributor Book](https://burn.dev/contributor-book/)\ncovers architecture, environment setup, and guides for common tasks.\n\n## Status\n\nBurn is currently in active development, and there will be breaking changes. While any resulting\nissues are likely to be easy to fix, there are no guarantees at this stage.\n\n## License\n\nBurn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).\nSee [LICENSE-APACHE](./LICENSE-APACHE) and [LICENSE-MIT](./LICENSE-MIT) for details. Opening a pull\nrequest is assumed to signal agreement with these licensing terms.\n\n</div>\n"
  },
  {
    "path": "_typos.toml",
    "content": "[default]\nextend-ignore-identifiers-re = [\"ratatui\", \"Ratatui\", \"NdArray*\", \"ND\"]\n\n[default.extend-identifiers]\nUE4M3 = \"UE4M3\"\nUE8M0 = \"UE8M0\"\nue8m0 = \"ue8m0\"\n\n[files]\nextend-exclude = [\n    \"*.onnx\",\n    \"*.proto\",\n    \"assets/ModuleSerialization.xml\",\n]\n\n[default.extend-words]\n# Don't correct \"arange\" which is intentional\narange = \"arange\"\n# Don't correct \"convnet\" (convolutional network)\nconvnet = \"convnet\"\n"
  },
  {
    "path": "benchmarks.toml",
    "content": "[environment]\ngcp_gpu_attached = true\ngcp_image_family = \"tracel-ci-ubuntu-2404-amd64-nvidia\"\n# https://cloud.google.com/compute/docs/accelerator-optimized-machines\n# put the faster machine on first place for possibly faster 'Benchmarks Started' feedback in PRs\ngcp_machine_types = [\n  \"a2-highgpu-1g\", # 1 A100 40GB (listed as a2 standard)\n  \"g2-standard-4\", # 1 L4 24GB\n]\n# define the available zones for each machine type\n# be sure to check what machine types are available in each region\n# https://cloud.google.com/compute/docs/gpus/gpu-regions-zones#view-using-table\ngcp_zones = [\n  # a2-highgpu-1g\n  [\n  \"asia-northeast1-a\",\n  \"asia-northeast1-c\",\n  \"asia-northeast3-b\",\n  \"asia-southeast1-b\",\n  \"asia-southeast1-c\",\n  \"europe-west4-a\",\n  \"europe-west4-b\",\n  \"us-central1-a\",\n  \"us-central1-b\",\n  \"us-central1-c\",\n  \"us-central1-f\",\n  \"us-east1-b\",\n  \"us-west1-b\",\n  \"us-west3-b\",\n  \"us-west4-b\"\n  ],\n  # g2-standard-4\n  [\n  \"northamerica-northeast2-a\",\n  \"northamerica-northeast2-b\",\n  \"us-central1-a\",\n  \"us-central1-b\",\n  \"us-central1-c\",\n  \"us-east1-b\",\n  \"us-east1-c\",\n  \"us-east1-d\",\n  \"us-east4-a\",\n  \"us-east4-c\",\n  \"us-west1-a\",\n  \"us-west1-b\",\n  \"us-west1-c\",\n  \"us-west4-a\",\n  \"us-west4-c\"\n  ],\n]\nrepo_full = \"tracel-ai/burn\"\nrust_toolchain = \"stable\"\nrust_version = \"stable\"\n\n[burn-bench]\ngithub_organization = \"tracel-ai\"\ngithub_repository = \"burn-bench\"\ngithub_branch = \"main\"\ngithub_workflow = \"benchmarks.yml\"\n# vulkan autotune seems to take ages, disabling it for now\n# backends = [\"cuda-fusion\", \"vulkan-fusion\", \"wgpu-fusion\"]\nbackends = [\"cuda-fusion\", \"cuda\"]\nbenches = [\"autodiff\",\n  \"binary\",\n  \"bool_select\",\n  \"conv-transpose2d\",\n  \"conv-transpose3d\",\n  \"conv2d\",\n  \"conv3d\",\n  \"custom-gelu\",\n  \"data\",\n  \"load-record\",\n  \"matmul-fused\",\n  \"matmul\",\n  \"max-pool2d\",\n  \"random\",\n  \"reduce\",\n  \"softmax\",\n  \"transformer-encoder\",\n  \"unary\"\n]\ndtypes = [\"f16\"]\n"
  },
  {
    "path": "burn-book/.gitignore",
    "content": "target\n\n# MacOS temp file\n.DS_Store\n\nbook-test\nguide/book\n\n.vscode\ntests/burn-book/book/\nbook/\n\n# Ignore Jetbrains specific files.\n.idea/\n\n# Ignore Vim temporary and swap files.\n*.sw?\n*~"
  },
  {
    "path": "burn-book/.prettierrc.json",
    "content": "{\n    \"printWidth\": 100,\n    \"proseWrap\": \"always\"\n}"
  },
  {
    "path": "burn-book/book.toml",
    "content": "[book]\nauthors = [\n    \"Wouter Doppenberg\",\n    \"Nathaniel Simard\",\n    \"Louis Fortier-Dubois\",\n    \"Dilshod Tadjibaev\",\n    \"Guillaume Lagrange\",\n    \"Sylvain Benner\",\n    \"Bjorn Beishline\"\n]\nlanguage = \"en\"\nsrc = \"src\"\ntitle = \"The Burn Book 🔥\"\n\n[output.html]\nmathjax-support = true\n"
  },
  {
    "path": "burn-book/src/SUMMARY.md",
    "content": "- [Overview](./overview.md)\n- [Why Burn?](./motivation.md)\n- [Getting started](./getting-started.md)\n  - [Examples](./examples.md)\n- [Basic Workflow: From Training to Inference](./basic-workflow/README.md)\n  - [Model](./basic-workflow/model.md)\n  - [Data](./basic-workflow/data.md)\n  - [Training](./basic-workflow/training.md)\n  - [Backend](./basic-workflow/backend.md)\n  - [Inference](./basic-workflow/inference.md)\n- [Building Blocks](./building-blocks/README.md)\n  - [Backend](./building-blocks/backend.md)\n  - [Tensor](./building-blocks/tensor.md)\n  - [Autodiff](./building-blocks/autodiff.md)\n  - [Module](./building-blocks/module.md)\n  - [Learner](./building-blocks/learner.md)\n  - [Metric](./building-blocks/metric.md)\n  - [Config](./building-blocks/config.md)\n  - [Record](./building-blocks/record.md)\n  - [Dataset](./building-blocks/dataset.md)\n- [Performance](./performance/README.md)\n  - [Good practices](./performance/good-practices/README.md)\n    - [Asynchronous Execution](./performance/good-practices/asynchronous-execution.md)\n    - [Kernel Fusion](./performance/good-practices/kernel-fusion.md)\n    - [Kernel Selection](./performance/good-practices/kernel-selection.md)\n  - [Quantization](./performance/quantization.md)\n  - [Distributed Computing](./performance/distributed-computing.md)\n- [Custom Training Loop](./custom-training-loop.md)\n- [Saving & Loading Models](./saving-and-loading.md)\n- [ONNX Import](./onnx-import.md)\n- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)\n- [Advanced](./advanced/README.md)\n  - [Backend Extension](./advanced/backend-extension/README.md)\n    - [Custom `CubeCL` Kernel](./advanced/backend-extension/custom-cubecl-kernel.md)\n    - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)\n  - [Custom Optimizer]()\n  - [WebAssembly](./advanced/web-assembly.md)\n  - [No-Std](./advanced/no-std.md)\n"
  },
  {
    "path": "burn-book/src/advanced/README.md",
    "content": "# Advanced\n\nIn this section, we will go into advanced topics that extend beyond basic usage. Given Burn's\nexceptional flexibility, a lot of advanced use cases become possible.\n\nBefore going through this section, we strongly recommend exploring the\n[basic workflow](../basic-workflow/) section and the\n[building blocks](../building-blocks/) section. Establishing a solid understanding of how\nthe framework operates is crucial to comprehending the advanced concepts presented here. While you\nhave the freedom to explore the advanced sections in any order you prefer, it's important to note\nthat this section is not intended to be linear, contrary to preceding sections. Instead, it serves\nas a repository of use cases that you can refer to for guidance as needed.\n"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/README.md",
    "content": "# Backend Extension\n\nBurn aims to be the most flexible deep learning framework. While it's crucial to maintain\ncompatibility with a wide variety of backends, Burn provides the ability to extend the functionality\nof a backend implementation to suit your modeling requirements. This versatility is advantageous in\nnumerous ways, such as supporting custom operations like flash attention or manually fusing\noperations for enhanced performance.\n\nIn this section, we will go into the process of extending a backend, providing multiple examples.\nBut before we proceed, let's establish the fundamental principles that will empower you to craft\nyour own backend extensions.\n\nAs you can observe, most types in Burn are generic over the Backend trait. This might give the\nimpression that Burn operates at a high level over the backend layer. However, making the trait\nexplicit instead of being chosen via a compilation flag was a thoughtful design decision. This\nexplicitness does not imply that all backends must be identical; rather, it offers a great deal of\nflexibility when composing backends. The autodifferentiation backend trait (see\n[autodiff section](../../building-blocks/autodiff.md)) is an example of how the backend trait has\nbeen extended to enable gradient computation with backpropagation. Furthermore, this design allows\nyou to create your own backend extension. To achieve this, you need to design your own backend trait\nspecifying which functions should be supported.\n\n```rust, ignore\npub trait Backend: burn::tensor::backend::Backend {\n    fn my_new_function(tensor: B::TensorPrimitive<2>) -> B::TensorPrimitive<2> {\n        // You can define a basic implementation reusing the Burn Backend API.\n        // This can be useful since all backends will now automatically support\n        // your model. But performance can be improved for this new\n        // operation by implementing this block in specific backends.\n    }\n}\n```\n\nYou can then implement your new custom backend trait for any backend that you want to support:\n\n```rust, ignore\nimpl<E: TchElement> Backend for burn_tch::LibTorch<E> {\n   fn my_new_function(tensor: TchTensor<E, 2>) -> TchTensor<E, 2> {\n      // My Tch implementation\n   }\n}\n\nimpl<E: NdArrayElement> Backend for burn_ndarray::NdArray<E> {\n    // No specific implementation, but the backend can still be used.\n}\n```\n\nYou can support the backward pass using the same pattern.\n\n```rust, ignore\nimpl<B: Backend> Backend for burn_autodiff::Autodiff<B> {\n    // No specific implementation; autodiff will work with the default\n    // implementation. Useful if you still want to train your model, but\n    // observe performance gains mostly during inference.\n}\n\nimpl<B: Backend> Backend for burn_autodiff::Autodiff<B> {\n   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {\n      // My own backward implementation, generic over my custom Backend trait.\n      //\n      // You can add a new method `my_new_function_backward` to your custom backend\n      // trait if you want to invoke a custom kernel during the backward pass.\n   }\n}\n\nimpl<E: TchElement> Backend for burn_autodiff::Autodiff<burn_tch::LibTorch<E>> {\n   fn my_new_function(tensor: AutodiffTensor<E, 2>) -> AutodiffTensor<E, 2> {\n      // My own backward implementation, generic over a backend implementation.\n      //\n      // This is another way to call a custom kernel for the backward pass that\n      // doesn't require the addition of a new `backward` function in the custom backend.\n      // This is useful if you don't want all backends to support training, reducing\n      // the need for extra code when you know your model will only be trained on one\n      // specific backend.\n   }\n}\n```\n\nThe specifics of each implementation will be covered by the examples provided in this section. The\n`cubecl` compiler frontend is the recommended method of implementing custom kernels, since it\nsupports multiple backends, including `wgpu` and `CUDA`, and is the way first-party `burn` kernels\nare written.\n"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/custom-cubecl-kernel.md",
    "content": "# Custom CubeCL Kernel\n\nIn this section, you will learn how to create your own custom operation by writing your own kernel\nwith the cubecl compiler frontend. We will take the example of a common workflow in the deep\nlearning field, where we create a kernel to fuse multiple operations together. Note that `burn` does\nthis automatically, but a manual implementation might be more efficient in some cases. We will fuse\na matmul kernel followed by an addition and the ReLU activation function, which is commonly found in\nvarious models. All the code can be found under the\n[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-cubecl-kernel).\n\n> Note: CubeCL is in active development, so this section may be outdated.\n\n## Custom Backend Trait\n\nFirst, we need to determine the type signature of our newly created operation by defining our custom\nbackend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which\nencapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid\nthe ugly disambiguation with associated types.\n\n```rust, ignore\n/// We create our own Backend trait that extends the Burn backend trait.\npub trait Backend: burn::tensor::backend::Backend {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self>;\n}\n\n/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.\npub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}\n```\n\nIn our project, we can use these traits instead of the\n`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs\ntypically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.\nTherefore, we can encapsulate our newly defined backend traits with functions that expose new\noperations while maintaining a consistent API.\n\n```rust, ignore\n/// We define our custom implementation using the added function on our custom backend.\npub fn matmul_add_relu_custom<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let output = B::fused_matmul_add_relu(\n        lhs.into_primitive().tensor(),\n        rhs.into_primitive().tensor(),\n        bias.into_primitive().tensor(),\n    );\n\n    Tensor::from_primitive(TensorPrimitive::Float(output))\n}\n\n/// We define a reference implementation using basic tensor operations.\npub fn matmul_add_relu_reference<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let x = lhs.matmul(rhs) + bias;\n\n    activation::relu(x)\n}\n\n```\n\nNote that we also provide a reference implementation for testing purposes, which allows us to easily\nvalidate our new implementation. While not mandatory, having a reference implementation can be\nvaluable, especially in projects where creating a reference implementation solely using basic tensor\noperations is feasible.\n\n## Forward Kernel\n\nNow, let's proceed to write the fused kernel using the `cubecl` compiler frontend. To keep things\nsimple, we'll create a straightforward matmul kernel without employing any intricate techniques. We\nwon't delve into the details of the `cube` macro, but if you're interested to learn more, please see\n[`cubecl` Book](https://github.com/tracel-ai/cubecl/tree/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/cubecl-book).\nThe actual matmul, add and relu computations are found at the end, after an extensive prelude that\nserves to correctly map each compute unit to the data it is responsible for, with support for\nbatches.\n\n```rust, ignore\nuse cubecl::{cube, prelude::*};\n\n#[cube(launch)]\npub fn fused_matmul_add_relu_kernel<F: Float>(\n    lhs: &Tensor<F>,\n    rhs: &Tensor<F>,\n    bias: &Tensor<F>,\n    output: &mut Tensor<F>,\n) {\n    let row = ABSOLUTE_POS_X;\n    let col = ABSOLUTE_POS_Y;\n    let batch = ABSOLUTE_POS_Z;\n\n    let n_rows = output.shape(output.rank() - 2);\n    let n_cols = output.shape(output.rank() - 1);\n    let dim_k = rhs.shape(rhs.rank() - 1);\n\n    if row >= n_rows || col >= n_cols {\n        return;\n    }\n\n    let offset_output = batch * n_rows * n_cols;\n    let mut offset_lhs = 0;\n    let mut offset_rhs = 0;\n\n    let batch_dims = output.rank() - 2;\n    for dim in 0..batch_dims {\n        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);\n        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);\n    }\n\n    let mut sum = F::new(0.0);\n    for k in 0..dim_k {\n        let lhs_index = row * dim_k + k;\n        let rhs_index = k * n_cols + col;\n\n        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];\n    }\n\n    let out_index = row * n_cols + col;\n    let index = offset_output + out_index;\n\n    output[index] = F::max(sum + bias[index], F::new(0.0));\n}\n```\n\nNow, let's move on to the next step, which involves implementing the remaining code to launch the\nkernel. We'll go into implementing our custom backend trait for the generic JIT backend. This\nautomatically implements the trait for `burn-cuda`, `burn-wgpu` as well as fusion.\n\n```rust, ignore\n/// Implement our custom backend trait for the generic `CubeBackend`.\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Backend\n    for CubeBackend<R, F, I, BT>\n{\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Define cube dim, hardcoded for simplicity.\n        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };\n\n        lhs.assert_is_on_same_device(&rhs);\n        lhs.assert_is_on_same_device(&bias);\n\n        // For simplicity, make sure each tensor is continuous.\n        let lhs = into_contiguous(lhs);\n        let rhs = into_contiguous(rhs);\n        let bias = into_contiguous(bias);\n\n        // Get the matmul relevant shapes.\n        let ndims = lhs.shape.num_dims();\n        let num_rows = lhs.shape[ndims - 2];\n        let num_cols = rhs.shape[ndims - 1];\n\n        // Compute shape of output, while tracking number of batches.\n        let mut num_batches = 1;\n        let mut shape_out = vec![0; ndims];\n        for i in shape_out.clone().into_iter().take(ndims - 2) {\n            shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]);\n            num_batches *= shape_out[i];\n        }\n        shape_out[ndims - 2] = num_rows;\n        shape_out[ndims - 1] = num_cols;\n        let shape_out = Shape::from(shape_out);\n\n        // Create a buffer for the output tensor.\n        let buffer = lhs\n            .client\n            .empty(shape_out.num_elements() * core::mem::size_of::<F>());\n\n        // Create the output tensor primitive.\n        let output = CubeTensor::new_contiguous(\n            lhs.client.clone(),\n            lhs.device.clone(),\n            shape_out,\n            buffer,\n            F::dtype(),\n        );\n\n        // Declare the wgsl workgroup with the number of cubes in x, y and z.\n        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;\n        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;\n        let cube_count =\n            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);\n\n        // Execute lazily the kernel with the launch information and the given buffers. For\n        // simplicity, no vectorization is performed\n        fused_matmul_add_relu_kernel::launch::<F, R>(\n            &lhs.client,\n            cube_count,\n            cube_dim,\n            lhs.into_tensor_arg(),\n            rhs.into_tensor_arg(),\n            bias.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n        );\n\n        // Return the output tensor.\n        output\n    }\n}\n```\n\nIn the preceding code block, we demonstrated how to launch the kernel that modifies the correct\nbuffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the\ncapability to execute any mutable operation on any buffer. While this isn't a problem in the\nprevious scenario where we only modify the newly created output buffer, it is wise to keep this in\nmind.\n\n## Backward\n\nNow that the custom backend trait is implemented for the JIT backend, you can use it to invoke the\n`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.\nIf your use case does not extend beyond inference, there is no need to implement any of the\nfollowing code.\n\nFor the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is\nactually generic over the backend. Instead of crafting our own `cubecl` kernel for the backward\npass, we will use our fused kernel only for the forward pass, and compute the gradient using basic\noperations.\n\n```rust, ignore\n// Implement our custom backend trait for any backend that also implements our custom backend trait.\nimpl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Create our zero-sized type that will implement the Backward trait.\n        #[derive(Debug)]\n        struct FusedMatmulAddReluBackward;\n\n        // Implement the backward trait for the given backend B, the node gradient\n        // with three other gradients to calculate (lhs, rhs, and bias).\n        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {\n            // Our state that we must build during the forward pass to compute the backward pass.\n            //\n            // Note that we could improve the performance further by only keeping the state of\n            // tensors that are tracked, improving memory management, but for simplicity, we avoid\n            // that part.\n            type State = (NodeId, NodeId, FloatTensor<B>, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                // Get the nodes of each variable.\n                let [node_lhs, node_rhs, node_bias] = ops.parents;\n                // Fetch the gradient for the current node.\n                let grad = grads.consume::<B>(&ops.node);\n\n                // Set our state.\n                let (lhs_state, rhs_state, output, shape_bias) = ops.state;\n                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);\n                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);\n\n                // Fetch shapes of our tensor to support broadcasting.\n                let shape_lhs = lhs.shape();\n                let shape_rhs = rhs.shape();\n\n                // Compute the gradient of the output using the already existing `relu_backward`\n                // function in the basic Burn backend trait.\n                let grad_output = B::relu_backward(output, grad);\n\n                // Compute the lhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_lhs = broadcast_shape::<B>(\n                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),\n                    &shape_lhs,\n                );\n                // Compute the rhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_rhs = broadcast_shape::<B>(\n                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),\n                    &shape_rhs,\n                );\n                // The add derivative is only 1, so we just need to support broadcasting to\n                // compute the bias gradient.\n                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);\n\n                // Register the gradient for each variable based on whether they are marked as\n                // `tracked`.\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, grad_bias);\n                }\n                if let Some(node) = node_lhs {\n                    grads.register::<B>(node.id, grad_lhs);\n                }\n                if let Some(node) = node_rhs {\n                    grads.register::<B>(node.id, grad_rhs);\n                }\n            }\n        }\n\n        // Prepare a stateful operation with each variable node and corresponding graph.\n        //\n        // Each node can be fetched with `ops.parents` in the same order as defined here.\n        match FusedMatmulAddReluBackward\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])\n            // Marks the operation as compute bound, meaning it will save its\n            // state instead of recomputing itself during checkpointing\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                // When at least one node is tracked, we should register our backward step.\n\n                // The state consists of what will be needed for this operation's backward pass.\n                // Since we need the parents' outputs, we must checkpoint their ids to retrieve\n                // their node output at the beginning of the backward pass. We can also save\n                // utility data such as the bias shape. If we also need this operation's output,\n                // we can either save it in the state or recompute it.\n                // during the backward pass. Here we choose to save it in the state because it's a\n                // compute bound operation.\n                let lhs_state = prep.checkpoint(&lhs);\n                let rhs_state = prep.checkpoint(&rhs);\n                let bias_shape = bias.primitive.shape();\n\n                let output = B::fused_matmul_add_relu(\n                    lhs.primitive.clone(),\n                    rhs.primitive.clone(),\n                    bias.primitive,\n                );\n\n                let state = (lhs_state, rhs_state, output.clone(), bias_shape);\n\n                prep.finish(state, output)\n            }\n            OpsKind::UnTracked(prep) => {\n                // When no node is tracked, we can just compute the original operation without\n                // keeping any state.\n                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);\n                prep.finish(output)\n            }\n        }\n    }\n}\n```\n\nThe previous code is self-documented to make it clearer, but here is what it does in summary:\n\nWe define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to\nbenefit from our implementation. In an autodiff-decorated backend, the forward pass must still be\nimplemented. This is achieved using a comprehensive match statement block where computation is\ndelegated to the inner backend, while keeping track of a state. The state comprises any information\nrelevant to the backward pass, such as input and output tensors, along with the bias shape. When an\noperation isn't tracked (meaning there won't be a backward pass for this specific operation in the\ngraph), storing a state becomes unnecessary, and we simply perform the forward computation.\n\nThe backward pass uses the gradient obtained from the preceding node in the computation graph. It\ncalculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the\nderivative is one), and `matmul` (another `matmul` with transposed inputs). This results in\ngradients for both input tensors and the bias, which are registered for consumption by subsequent\noperation nodes.\n\nThe only remaining part is to implement our autodiff-decorated backend trait for our JIT Backend.\n\n```rust, ignore\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> AutodiffBackend\n    for Autodiff<CubeBackend<R, F, I, BT>>\n{\n}\n```\n\n## Conclusion\n\nIn this guide, we've implemented a fused kernel using the `cubecl` compiler frontend, enabling\nexecution on any GPU and any `cubecl` backend. By delving into the inner workings of both the JIT\nbackend and the autodiff backend, we've gained a deeper understanding of these systems.\n\nWhile extending a backend may be harder than working with straightforward tensors, the benefits can\nbe worth it. This approach enables the crafting of custom models with greater control over\nexecution, which can potentially greatly enhance the performance of your models.\n\nAs we conclude this guide, we hope that you have gained insights into Burn's world of backend\nextensions, and that it will help you to unleash the full potential of your projects.\n"
  },
  {
    "path": "burn-book/src/advanced/backend-extension/custom-wgpu-kernel.md",
    "content": "# Custom WGPU Kernel\n\nIn this section, you will learn how to create your own custom operation by writing your own kernel\nwith the WGPU backend. We will take the example of a common workflow in the deep learning field,\nwhere we create a kernel to fuse multiple operations together. Note that `burn` does this\nautomatically, but a manual implementation might be more efficient in some cases. We will fuse a\nmatmul kernel followed by an addition and the ReLU activation function, which is commonly found in\nvarious models. All the code can be found under the\n[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/custom-wgpu-kernel).\n\n## Custom Backend Trait\n\nFirst, we need to determine the type signature of our newly created operation by defining our custom\nbackend traits. As we will use the associated type `TensorPrimitive` of the `Backend` trait, which\nencapsulates the underlying tensor implementation of the backend, we will use a type alias to avoid\nthe ugly disambiguation with associated types.\n\n```rust, ignore\n/// We create our own Backend trait that extends the Burn backend trait.\npub trait Backend: burn::tensor::backend::Backend {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self>;\n}\n\n/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.\npub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}\n```\n\nIn our project, we can use these traits instead of the\n`burn::tensor::backend::{Backend, AutodiffBackend}` traits provided by Burn. Burn's user APIs\ntypically make use of the `Tensor` struct rather than dealing directly with primitive tensor types.\nTherefore, we can encapsulate our newly defined backend traits with functions that expose new\noperations while maintaining a consistent API.\n\n```rust, ignore\n/// We define our custom implementation using the added function on our custom backend.\npub fn matmul_add_relu_custom<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let output = B::fused_matmul_add_relu(\n        lhs.into_primitive().tensor(),\n        rhs.into_primitive().tensor(),\n        bias.into_primitive().tensor(),\n    );\n\n    Tensor::from_primitive(TensorPrimitive::Float(output))\n}\n\n/// We define a reference implementation using basic tensor operations.\npub fn matmul_add_relu_reference<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let x = lhs.matmul(rhs) + bias;\n\n    activation::relu(x)\n}\n\n```\n\nNote that we also provide a reference implementation for testing purposes, which allows us to easily\nvalidate our new implementation. While not mandatory, having a reference implementation can be\nvaluable, especially in projects where creating a reference implementation solely using basic tensor\noperations is feasible.\n\n## Forward Kernel\n\nNow, let's proceed to write the fused kernel using the WGSL shading language. To keep things simple,\nwe'll create a straightforward matmul kernel without employing any intricate techniques. Although we\nwon't delve into the details of the WGSL syntax, as it falls beyond the scope of this guide, we\nstill provide the implementation below for readers who are curious. The actual matmul, add and relu\ncomputations are found at the end, after an extensive overhead whose use is to correctly map each\ncompute unit to the data it is responsible of, with support for batches.\n\n```wgsl, ignore\n@group(0)\n@binding(0)\nvar<storage, read_write> lhs: array<{{ elem }}>;\n\n@group(0)\n@binding(1)\nvar<storage, read_write> rhs: array<{{ elem }}>;\n\n@group(0)\n@binding(2)\nvar<storage, read_write> bias: array<{{ elem }}>;\n\n@group(0)\n@binding(3)\nvar<storage, read_write> output: array<{{ elem }}>;\n\n@group(0)\n@binding(4)\nvar<storage, read_write> info: array<u32>;\n\nconst BLOCK_SIZE = {{ workgroup_size_x }}u;\n\n@compute\n@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)\nfn main(\n    @builtin(global_invocation_id) global_id: vec3<u32>,\n    @builtin(local_invocation_index) local_idx: u32,\n    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n) {\n    // Indices\n    let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);\n    let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);\n    let batch = global_id.z;\n\n    // Basic information\n    let dim = info[0];\n    let n_rows = info[6u * dim - 1u];\n    let n_cols = info[6u * dim];\n    let K = info[5u * dim - 1u];\n\n    // Returns if outside the output dimension\n    if row >= n_rows || col >= n_cols {\n        return;\n    }\n\n    // Calculate the corresponding offsets with support for broadcasting.\n    let offset_output = batch * n_rows * n_cols;\n    var offset_lhs: u32 = 0u;\n    var offset_rhs: u32 = 0u;\n\n    let batch_dims = dim - 2u;\n    for (var b: u32 = 1u; b <= batch_dims; b++) {\n        let stride_lhs = info[b];\n        let stride_rhs = info[b + dim];\n        let stride_output = info[b + 2u * dim];\n        let shape_lhs = info[b + 3u * dim];\n        let shape_rhs = info[b + 4u * dim];\n\n        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;\n        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;\n    }\n\n    // Basic matmul implementation\n    var sum = 0.0;\n    for (var k: u32 = 0u; k < K; k++) {\n        let lhs_index = row * K + k;\n        let rhs_index = k * n_cols + col;\n\n        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];\n    }\n\n    let output_index = row * n_cols + col;\n    let index = offset_output + output_index;\n\n    // Add and ReLU\n    output[index] = max(sum + bias[index], 0.0);\n}\n```\n\nNow, let's move on to the next step, which involves implementing the remaining code to launch the\nkernel. The initial part entails loading the template and populating it with the appropriate\nvariables. The `register(name, value)` method simply replaces occurrences of `{{ name }}` in the\nabove WGSL code with some other string before it is compiled. In order to use templating utilities,\nyou will have to activate the `template` feature of Burn in your `cargo.toml`.\n\n```rust, ignore\n// Source the kernel written in WGSL.\nkernel_wgsl!(FusedMatmulAddReluRaw, \"./kernel.wgsl\");\n\n// Define our kernel type with cube information.\n#[derive(new, Debug)]\nstruct FusedMatmulAddRelu<E: FloatElement> {\n    cube_dim: CubeDim,\n    _elem: PhantomData<E>,\n}\n\n// Implement the dynamic kernel trait for our kernel type.\nimpl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {\n    fn source(&self) -> SourceTemplate {\n        // Extend our raw kernel with cube size information using the\n        // `SourceTemplate` trait.\n        FusedMatmulAddReluRaw::new()\n            .source()\n            .register(\"workgroup_size_x\", self.cube_dim.x.to_string())\n            .register(\"workgroup_size_y\", self.cube_dim.y.to_string())\n            .register(\"elem\", E::type_name())\n            .register(\"int\", \"i32\")\n    }\n\n    fn id(&self) -> cubecl::KernelId {\n        cubecl::KernelId::new::<Self>().info(self.cube_dim)\n    }\n}\n```\n\nSubsequently, we'll go into implementing our custom backend trait for the WGPU backend. Note that we\nwon't go into supporting the `fusion` feature flag in this tutorial, so we implement the trait for\nthe raw `WgpuBackend` type.\n\n```rust, ignore\n/// Implement our custom backend trait for the existing backend `WgpuBackend`.\nimpl<F: FloatElement, I: IntElement, BT: BoolElement> Backend\n    for CubeBackend<WgpuRuntime, F, I, BT>\n{\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Define cube dim, hardcoded for simplicity.\n        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };\n\n        lhs.assert_is_on_same_device(&rhs);\n        lhs.assert_is_on_same_device(&bias);\n\n        // For simplicity, make sure each tensor is continuous.\n        let lhs = into_contiguous(lhs);\n        let rhs = into_contiguous(rhs);\n        let bias = into_contiguous(bias);\n\n        // Get the matmul relevant shapes.\n        let ndims = lhs.shape.num_dims();\n        let num_rows = lhs.shape[ndims - 2];\n        let num_cols = rhs.shape[ndims - 1];\n\n        // Compute shape of output, while tracking number of batches.\n        let mut num_batches = 1;\n        let mut shape_out = vec![0; ndims];\n        for i in shape_out.clone().into_iter().take(ndims - 2) {\n            shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]);\n            num_batches *= shape_out[i];\n        }\n        shape_out[ndims - 2] = num_rows;\n        shape_out[ndims - 1] = num_cols;\n        let shape_out = Shape::from(shape_out);\n\n        // Create a buffer for the output tensor.\n        let buffer = lhs\n            .client\n            .empty(shape_out.num_elements() * core::mem::size_of::<F>());\n\n        // Create the output tensor primitive.\n        let output = CubeTensor::new_contiguous(\n            lhs.client.clone(),\n            lhs.device.clone(),\n            shape_out,\n            buffer,\n            F::dtype(),\n        );\n\n        // Create the kernel.\n        let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);\n\n        // Build info buffer with tensor information needed by the kernel, such as shapes and strides.\n        let info = build_info::<_, F>(&[&lhs, &rhs, &output]);\n        let info_handle = lhs.client.create(bytemuck::cast_slice(&info));\n\n        // Declare the wgsl workgroup with the number of cubes in x, y and z.\n        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;\n        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;\n        let cube_count =\n            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);\n\n        // Execute lazily the kernel with the launch information and the given buffers.\n        lhs.client.execute(\n            Box::new(SourceKernel::new(kernel, cube_dim)),\n            cube_count,\n            Bindings::new().with_buffers(vec![\n                lhs.handle.binding(),\n                rhs.handle.binding(),\n                bias.handle.binding(),\n                output.handle.clone().binding(),\n                info_handle.binding(),\n            ]),\n        );\n\n        // Return the output tensor.\n        output\n    }\n}\n```\n\nIn the preceding code block, we demonstrated how to launch the kernel that modifies the correct\nbuffer. It's important to note that Rust's mutability safety doesn't apply here; the context has the\ncapability to execute any mutable operation on any buffer. While this isn't a problem in the\nprevious scenario where we only modify the newly created output buffer, it is wise to keep this in\nmind.\n\n## Backward\n\nNow that the custom backend trait is implemented for the WGPU backend, you can use it to invoke the\n`matmul_add_relu_custom` function. However, calculating gradients is not yet possible at this stage.\nIf your use case does not extend beyond inference, there is no need to implement any of the\nfollowing code.\n\nFor the backward pass, we will leverage the backend implementation from `burn-autodiff`, which is\nactually generic over the backend. Instead of crafting our own WGSL kernel for the backward pass, we\nwill use our fused kernel only for the forward pass, and compute the gradient using basic\noperations.\n\n```rust, ignore\n// Implement our custom backend trait for any backend that also implements our custom backend trait.\n//\n// Note that we could implement the backend trait only for the Wgpu backend instead of any backend that\n// also implements our own API. This would allow us to call any function only implemented for Wgpu\n// and potentially call a custom kernel crafted only for this task.\nimpl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Create our zero-sized type that will implement the Backward trait.\n        #[derive(Debug)]\n        struct FusedMatmulAddReluBackward;\n\n        // Implement the backward trait for the given backend B, the node gradient\n        // with three other gradients to calculate (lhs, rhs, and bias).\n        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {\n            // Our state that we must build during the forward pass to compute the backward pass.\n            //\n            // Note that we could improve the performance further by only keeping the state of\n            // tensors that are tracked, improving memory management, but for simplicity, we avoid\n            // that part.\n            type State = (NodeId, NodeId, FloatTensor<B>, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                // Get the nodes of each variable.\n                let [node_lhs, node_rhs, node_bias] = ops.parents;\n                // Fetch the gradient for the current node.\n                let grad = grads.consume::<B>(&ops.node);\n\n                // Set our state.\n                let (lhs_state, rhs_state, output, shape_bias) = ops.state;\n                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);\n                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);\n\n                // Fetch shapes of our tensor to support broadcasting.\n                let shape_lhs = lhs.shape();\n                let shape_rhs = rhs.shape();\n\n                // Compute the gradient of the output using the already existing `relu_backward`\n                // function in the basic Burn backend trait.\n                let grad_output = B::relu_backward(output, grad);\n\n                // Compute the lhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_lhs = broadcast_shape::<B>(\n                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),\n                    &shape_lhs,\n                );\n                // Compute the rhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_rhs = broadcast_shape::<B>(\n                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),\n                    &shape_rhs,\n                );\n                // The add derivative is only 1, so we just need to support broadcasting to\n                // compute the bias gradient.\n                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);\n\n                // Register the gradient for each variable based on whether they are marked as\n                // `tracked`.\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, grad_bias);\n                }\n                if let Some(node) = node_lhs {\n                    grads.register::<B>(node.id, grad_lhs);\n                }\n                if let Some(node) = node_rhs {\n                    grads.register::<B>(node.id, grad_rhs);\n                }\n            }\n        }\n\n        // Prepare a stateful operation with each variable node and corresponding graph.\n        //\n        // Each node can be fetched with `ops.parents` in the same order as defined here.\n        match FusedMatmulAddReluBackward\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])\n            // Marks the operation as compute bound, meaning it will save its\n            // state instead of recomputing itself during checkpointing\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                // When at least one node is tracked, we should register our backward step.\n\n                // The state consists of what will be needed for this operation's backward pass.\n                // Since we need the parents' outputs, we must checkpoint their ids to retrieve their node\n                // output at the beginning of the backward. We can also save utility data such as the bias shape\n                // If we also need this operation's output, we can either save it in the state or recompute it\n                // during the backward pass. Here we choose to save it in the state because it's a compute bound operation.\n                let lhs_state = prep.checkpoint(&lhs);\n                let rhs_state = prep.checkpoint(&rhs);\n                let bias_shape = bias.primitive.shape();\n\n                let output = B::fused_matmul_add_relu(\n                    lhs.primitive.clone(),\n                    rhs.primitive.clone(),\n                    bias.primitive,\n                );\n\n                let state = (lhs_state, rhs_state, output.clone(), bias_shape);\n\n                prep.finish(state, output)\n            }\n            OpsKind::UnTracked(prep) => {\n                // When no node is tracked, we can just compute the original operation without\n                // keeping any state.\n                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);\n                prep.finish(output)\n            }\n        }\n    }\n}\n```\n\nThe previous code is self-documented to make it clearer, but here is what it does in summary.\n\nWe define `fused_matmul_add_relu` within `Autodiff<B>`, allowing any autodiff-decorated backend to\nbenefit from our implementation. In an autodiff-decorated backend, the forward pass must still be\nimplemented. This is achieved using a comprehensive match statement block where computation is\ndelegated to the inner backend, while keeping track of a state. The state comprises any information\nrelevant to the backward pass, such as input and output tensors, along with the bias shape. When an\noperation isn't tracked (meaning there won't be a backward pass for this specific operation in the\ngraph), storing a state becomes unnecessary, and we simply perform the forward computation.\n\nThe backward pass uses the gradient obtained from the preceding node in the computation graph. It\ncalculates the derivatives for `relu` (`relu_backward`), add (no operation is required here, as the\nderivative is one), and `matmul` (another `matmul` with transposed inputs). This results in\ngradients for both input tensors and the bias, which are registered for consumption by subsequent\noperation nodes.\n\nThe only remaining part is to implement our autodiff-decorated backend trait for our WGPU Backend.\n\n```rust, ignore\nimpl<F: FloatElement, I: IntElement, BT: BoolElement> AutodiffBackend\n    for Autodiff<CubeBackend<WgpuRuntime, F, I, BT>>\n{\n}\n```\n\n## Conclusion\n\nIn this guide, we've implemented a fused kernel using the WGPU backend, enabling execution on any\nGPU. By delving into the inner workings of both the WGPU backend and the autodiff backend, we've\ngained a deeper understanding of these systems.\n\nWhile extending a backend may be harder than working with straightforward tensors, the benefits can\nbe worth it. This approach enables the crafting of custom models with greater control over\nexecution, which can potentially greatly enhance the performance of your models.\n\nAs we conclude this guide, we hope that you have gained insights into Burn's world of backend\nextensions, and that it will help you to unleash the full potential of your projects.\n"
  },
  {
    "path": "burn-book/src/advanced/no-std.md",
    "content": "# No Standard Library\n\nIn this section, you will learn how to run an ONNX inference model on an embedded system, with no\nstandard library support on a Raspberry Pi Pico 2. This should be universally applicable to other\nplatforms. All the code can be found in the\n[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico).\n\n## Step-by-Step Guide\n\nLet's walk through the process of running an embedded ONNX model:\n\n### Setup\nFollow the [embassy guide](https://embassy.dev/book/#_getting_started) for your specific environment. Once setup, you should have something similar to the following.\n```\n./inference\n├── Cargo.lock\n├── Cargo.toml\n├── build.rs\n├── memory.x\n└── src\n    └── main.rs\n```\n\nSome other dependencies have to be added\n```toml\n[dependencies]\nembedded-alloc = \"0.6.0\" # Only if there is no default allocator for your chip\nburn = { version = \"0.21\", default-features = false, features = [\"ndarray\"] } # Backend must be ndarray\nburn-store = { version = \"0.21\", default-features = false, features = [\"burnpack\"] }\n\n[build-dependencies]\nburn-onnx = { version = \"0.21\" } # Used to auto generate the rust code to import the model\n```\n\n### Import the Model\nFollow the directions in [ONNX Import](../onnx-import.md).\n\nUse the following ModelGen config\n```rs\nModelGen::new()\n    .input(my_model)\n    .out_dir(\"model/\")\n    .embed_states(true)\n    .run_from_script();\n```\n\n### Global Allocator\nFirst define a global allocator (if you are on a no_std system without alloc).\n\n```rs\nuse embedded_alloc::LlffHeap as Heap;\n\n#[global_allocator]\nstatic HEAP: Heap = Heap::empty();\n\n#[embassy_executor::main]\nasync fn main(_spawner: Spawner) {\n    {\n        use core::mem::MaybeUninit;\n        // Watch out for this, if it is too big or small for your model, the\n        // program may crash. This is in u8 bytes, as such this is a total of 100kb\n        const HEAP_SIZE: usize = 100 * 1024;\n        static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];\n        unsafe { HEAP.init(&raw mut HEAP_MEM as usize, HEAP_SIZE) } // Initialize the heap\n    }\n}\n```\n\n### Define Backend\nWe are using ndarray, so we just need to define the NdArray backend as usual\n```rs\nuse burn::{backend::NdArray, tensor::Tensor};\n\ntype Backend = NdArray<f32>;\ntype BackendDevice = <Backend as burn::tensor::backend::Backend>::Device;\n```\n\nThen inside the `main` function add\n```rs\nuse your_model::Model;\n\n// Get a default device for the backend\nlet device = BackendDevice::default();\n\n// Create a new model and load the state\nlet model: Model<Backend> = Model::default();\n```\n\n### Running the Model\nTo run the model, just call it as you would normally\n```rs\n// Define the tensor\nlet input = Tensor::<Backend, 2>::from_floats([[input]], &device);\n\n// Run the model on the input\nlet output = model.forward(input);\n```\n\n## Conclusion\nRunning a model in a no_std environment is pretty much identical to a normal environment. All that is needed is a global allocator.\n"
  },
  {
    "path": "burn-book/src/advanced/web-assembly.md",
    "content": "# WebAssembly\n\nBurn supports WebAssembly (WASM) execution using the `NdArray` and `WebGpu` backends, allowing\nmodels to run directly in the browser.\n\nCheck out the following examples:\n\n- [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web)\n- [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web)\n\nWhen targeting WebAssembly, certain dependencies require additional configuration. In particular,\nthe `getrandom` crate requires explicit setting when using `WebGpu`.\n"
  },
  {
    "path": "burn-book/src/basic-workflow/README.md",
    "content": "# Guide\n\nThis guide will walk you through the process of creating a custom model built with Burn. We will\ntrain a simple convolutional neural network model on the MNIST dataset and prepare it for inference.\n\nFor clarity, we sometimes omit imports in our code snippets. For more details, please refer to the\ncorresponding code in the `examples/guide` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide).\nWe reproduce this example in a step-by-step fashion, from dataset creation to modeling and training\nin the following sections. It is recommended to use the capabilities of your IDE or text editor to\nautomatically add the missing imports as you add the code snippets to your code.\n\n<div class=\"warning\">\n\nBe sure to checkout the git branch corresponding to the version of Burn you are using to follow\nthis guide.\n\nThe current version of Burn is `0.21` and the corresponding branch to checkout is `main`.\n</div>\n\nThe code for this demo can be executed from Burn's base directory using the command:\n\n```bash\ncargo run --example guide\n```\n\n## Key Learnings\n\n- Creating a project\n- Creating neural network models\n- Importing and preparing datasets\n- Training models on data\n- Choosing a backend\n- Using a model for inference\n"
  },
  {
    "path": "burn-book/src/basic-workflow/backend.md",
    "content": "# Backend\n\nWe have effectively written most of the necessary code to train our model. However, we have not\nexplicitly designated the backend to be used at any point. This will be defined in the main\nentrypoint of our program, namely the `main` function defined in `src/main.rs`.\n\n```rust , ignore\n# #![recursion_limit = \"256\"]\n# mod data;\n# mod model;\n# mod training;\n#\nuse crate::{model::ModelConfig, training::TrainingConfig};\nuse burn::{\n    backend::{Autodiff, Wgpu},\n#     data::dataset::Dataset,\n    optim::AdamConfig,\n};\n\nfn main() {\n    type MyBackend = Wgpu<f32, i32>;\n    type MyAutodiffBackend = Autodiff<MyBackend>;\n\n    let device = burn::backend::wgpu::WgpuDevice::default();\n    let artifact_dir = \"/tmp/guide\";\n    crate::training::train::<MyAutodiffBackend>(\n        artifact_dir,\n        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),\n        device.clone(),\n    );\n}\n```\n\nIn this code snippet, we use the `Wgpu` backend which is compatible with any operating system and will\nuse the GPU. For other options, see the Burn README. This backend type takes the graphics API, the\nfloat type and the int type as generic arguments that will be used during the training. The autodiff\nbackend is simply the same backend, wrapped within the `Autodiff` struct which imparts differentiability \nto any backend.\n\nWe call the `train` function defined earlier with a directory for artifacts, the configuration of\nthe model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer\nconfiguration which in our case will be the default Adam configuration, and the device which can be\nobtained from the backend.\n\nYou can now train your freshly created model with the command:\n\n```console\ncargo run --release\n```\n\nWhen running your project with the command above, you should see the training progression through a\nbasic CLI dashboard:\n\n<img title=\"a title\" alt=\"Alt text\" src=\"./training-output.png\">\n"
  },
  {
    "path": "burn-book/src/basic-workflow/data.md",
    "content": "# Data\n\nTypically, one trains a model on some dataset. Burn provides a library of very useful dataset\nsources and transformations, such as Hugging Face dataset utilities that allow to download and store\ndata into an SQLite database for extremely efficient data streaming and storage. For this guide\nthough, we will use the MNIST dataset from `burn::data::dataset::vision` which requires no external\ndependency.\n\nTo iterate over a dataset efficiently, we will define a struct which will implement the `Batcher`\ntrait. The goal of a batcher is to map individual dataset items into a batched tensor that can be\nused as input to our previously defined model.\n\nLet us start by defining our dataset functionalities in a file `src/data.rs`. We shall omit some of\nthe imports for brevity, but the full code for following this guide can be found at\n`examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/examples/guide).\n\n```rust , ignore\nuse burn::{\n    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n    prelude::*,\n};\n\n\n#[derive(Clone, Default)]\npub struct MnistBatcher {}\n```\n\nThis batcher is pretty straightforward, as it only defines a struct that will implement the\n`Batcher` trait. The trait is generic over the `Backend` trait, which includes an associated type\nfor the device, as not all backends expose the same devices. As an example, the Libtorch-based\nbackend exposes `Cuda(gpu_index)`, `Cpu`, `Vulkan` and `Metal` devices, while the ndarray backend\nonly exposes the `Cpu` device.\n\nNext, we need to actually implement the batching logic.\n\n```rust , ignore\n# use burn::{\n#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n#     prelude::*,\n# };\n#\n# #[derive(Clone, Default)]\n# pub struct MnistBatcher {}\n#\n#[derive(Clone, Debug)]\npub struct MnistBatch<B: Backend> {\n    pub images: Tensor<B, 3>,\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {\n    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {\n        let images = items\n            .iter()\n            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())\n            .map(|data| Tensor::<B, 2>::from_data(data, device))\n            .map(|tensor| tensor.reshape([1, 28, 28]))\n            // Normalize: scale between [0,1] and make the mean=0 and std=1\n            // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example\n            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122\n            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)\n            .collect();\n\n        let targets = items\n            .iter()\n            .map(|item| {\n                Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)\n            })\n            .collect();\n\n        let images = Tensor::cat(images, 0);\n        let targets = Tensor::cat(targets, 0);\n\n        MnistBatch { images, targets }\n    }\n}\n```\n\n<details>\n<summary><strong>🦀 Iterators and Closures</strong></summary>\n\nThe iterator pattern allows you to perform some tasks on a sequence of items in turn.\n\nIn this example, an iterator is created over the `MnistItem`s in the vector `items` by calling the\n`iter` method.\n\n_Iterator adaptors_ are methods defined on the `Iterator` trait that produce different iterators by\nchanging some aspect of the original iterator. Here, the `map` method is called in a chain to\ntransform the original data before consuming the final iterator with `collect` to obtain the\n`images` and `targets` vectors. Both vectors are then concatenated into a single tensor for the\ncurrent batch.\n\nYou probably noticed that each call to `map` is different, as it defines a function to execute on\nthe iterator items at each step. These anonymous functions are called\n[_closures_](https://doc.rust-lang.org/book/ch13-01-closures.html) in Rust. They're easy to\nrecognize due to their syntax which uses vertical bars `||`. The vertical bars capture the input\nvariables (if applicable) while the rest of the expression defines the function to execute.\n\nIf we go back to the example, we can break down and comment the expression used to process the\nimages.\n\n```rust, ignore\nlet images = items                                                       // take items Vec<MnistItem>\n    .iter()                                                              // create an iterator over it\n    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())  // for each item, convert the image to float data struct\n    .map(|data| Tensor::<B, 2>::from_data(data, device))                 // for each data struct, create a tensor on the device\n    .map(|tensor| tensor.reshape([1, 28, 28]))                           // for each tensor, reshape to the image dimensions [C, H, W]\n    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)                    // for each image tensor, apply normalization\n    .collect();                                                          // consume the resulting iterator & collect the values into a new vector\n```\n\nFor more information on iterators and closures, be sure to check out the\n[corresponding chapter](https://doc.rust-lang.org/book/ch13-00-functional-features.html) in the Rust\nBook.\n\n</details><br>\n\nIn the previous example, we implement the `Batcher` trait with a list of `MnistItem` as input and a\nsingle `MnistBatch` as output. The batch contains the images in the form of a 3D tensor, along with\na targets tensor that contains the indexes of the correct digit class. The first step is to parse\nthe image array into a `TensorData` struct. Burn provides the `TensorData` struct to encapsulate\ntensor storage information without being specific for a backend. When creating a tensor from data,\nwe often need to convert the data precision to the current backend in use. This can be done with the\n`.convert()` method (in this example, the data is converted backend's float element type\n`B::FloatElem`). While importing the `burn::tensor::ElementConversion` trait, you can call `.elem()`\non a specific number to convert it to the current backend element type in use.\n"
  },
  {
    "path": "burn-book/src/basic-workflow/inference.md",
    "content": "# Inference\n\nNow that we have trained our model, the next natural step is to use it for inference.\n\nYou need two things in order to load weights for a model: the model's record and the model's config.\nSince parameters in Burn are lazy initialized, no allocation and GPU/CPU kernels are executed by the\n`ModelConfig::init` function. The weights are initialized when used for the first time, therefore\nyou can safely use `config.init(device).load_record(record)` without any meaningful performance\ncost. Let's create a simple `infer` method in a new file `src/inference.rs` which we will use to\nload our trained model.\n\n```rust , ignore\n# use crate::{data::MnistBatcher, training::TrainingConfig};\n# use burn::{\n#     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n#     prelude::*,\n#     record::{CompactRecorder, Recorder},\n# };\n#\npub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {\n    let config = TrainingConfig::load(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should exist for the model; run train first\");\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model should exist; run train first\");\n\n    let model = config.model.init::<B>(&device).load_record(record);\n\n    let label = item.label;\n    let batcher = MnistBatcher::default();\n    let batch = batcher.batch(vec![item], &device);\n    let output = model.forward(batch.images);\n    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();\n\n    println!(\"Predicted {predicted} Expected {label}\");\n}\n```\n\nThe first step is to load the configuration of the training to fetch the correct model\nconfiguration. Then we can fetch the record using the same recorder as we used during training.\nFinally we can init the model with the configuration and the record. For simplicity we can use the\nsame batcher used during the training to pass from a MnistItem to a tensor.\n\nBy running the infer function, you should see the predictions of your model!\n\nAdd the call to `infer` to the `main.rs` file after the `train` function call:\n\n```rust , ignore\n# #![recursion_limit = \"256\"]\n# mod data;\n# mod inference;\n# mod model;\n# mod training;\n#\n# use crate::{model::ModelConfig, training::TrainingConfig};\n# use burn::{\n#     backend::{Autodiff, Wgpu},\n#     data::dataset::Dataset,\n#     optim::AdamConfig,\n# };\n#\n# fn main() {\n#     type MyBackend = Wgpu<f32, i32>;\n#     type MyAutodiffBackend = Autodiff<MyBackend>;\n#\n#     let device = burn::backend::wgpu::WgpuDevice::default();\n#     let artifact_dir = \"/tmp/guide\";\n#     crate::training::train::<MyAutodiffBackend>(\n#         artifact_dir,\n#         TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),\n#         device.clone(),\n#     );\n    crate::inference::infer::<MyBackend>(\n        artifact_dir,\n        device,\n        burn::data::dataset::vision::MnistDataset::test()\n            .get(42)\n            .unwrap(),\n    );\n# }\n```\n\nThe number `42` is the index of the image in the MNIST dataset. You can explore and verify them\nusing this [MNIST viewer](https://observablehq.com/@davidalber/mnist-viewer).\n\n---\n\nIn this short guide, we've introduced you to the fundamental building blocks for getting started\nwith Burn. While there's still plenty to explore, our goal has been to provide you with the\nessential knowledge to kickstart your productivity within the framework.\n"
  },
  {
    "path": "burn-book/src/basic-workflow/model.md",
    "content": "# Model\n\nThe first step is to create a project and add the different Burn dependencies. Start by creating a\nnew project with Cargo:\n\n```console\ncargo new guide\n```\n\nAs [mentioned previously](../getting-started.md#creating-a-burn-application), this will initialize\nyour `guide` project directory with a `Cargo.toml` and a `src/main.rs` file.\n\nIn the `Cargo.toml` file, add the `burn` dependency with `train`, `vision` and `wgpu` features.\nSince we disable the default features, we also want to enable `std`, `tui` (for the dashboard) and\n`fusion` for wgpu. Then run `cargo build` to build the project and import all the dependencies.\n\n```toml\n[package]\nname = \"guide\"\nversion = \"0.1.0\"\nedition = \"2024\"\n\n[dependencies]\n# Disable autotune default for convolutions\nburn = { version = \"~0.21\", features = [\"std\", \"tui\", \"train\", \"vision\", \"wgpu\", \"fusion\"], default-features = false }\n# burn = { version = \"~0.21\", features = [\"train\", \"vision\", \"wgpu\"] }\n```\n\nOur goal will be to create a basic convolutional neural network used for image classification. We\nwill keep the model simple by using two convolution layers followed by two linear layers, some\npooling and ReLU activations. We will also use dropout to improve training performance.\n\nLet us start by defining our model struct in a new file `src/model.rs`.\n\n```rust , ignore\nuse burn::{\n    nn::{\n        conv::{Conv2d, Conv2dConfig},\n        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},\n        Dropout, DropoutConfig, Linear, LinearConfig, Relu,\n    },\n    prelude::*,\n};\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n    pool: AdaptiveAvgPool2d,\n    dropout: Dropout,\n    linear1: Linear<B>,\n    linear2: Linear<B>,\n    activation: Relu,\n}\n```\n\nThere are two major things going on in this code sample.\n\n1. You can create a deep learning module with the `#[derive(Module)]` attribute on top of a struct.\n   This will generate the necessary code so that the struct implements the `Module` trait. This\n   trait will make your module both trainable and (de)serializable while adding related\n   functionalities. Like other attributes often used in Rust, such as `Clone`, `PartialEq` or\n   `Debug`, each field within the struct must also implement the `Module` trait.\n\n   <details>\n   <summary><strong>🦀 Trait</strong></summary>\n\n   Traits are a powerful and flexible Rust language feature. They provide a way to define shared\n   behavior for a particular type, which can be shared with other types.\n\n   A type's behavior consists of the methods called on that type. Since all `Module`s should\n   implement the same functionality, it is defined as a trait. Implementing a trait on a particular\n   type usually requires the user to implement the defined behaviors of the trait for their types,\n   though that is not the case here as explained above with the `derive` attribute. Check out the\n   [explainer below](#derive-attribute) to learn why.\n\n   For more details on traits, take a look at the\n   [associated chapter](https://doc.rust-lang.org/book/ch10-02-traits.html) in the Rust Book.\n   </details><br>\n\n   <details id=\"derive-attribute\">\n   <summary><strong>🦀 Derive Macro</strong></summary>\n\n   The `derive` attribute allows traits to be implemented easily by generating code that will\n   implement a trait with its own default implementation on the type that was annotated with the\n   `derive` syntax.\n\n   This is accomplished through a feature of Rust called\n   [procedural macros](https://doc.rust-lang.org/reference/procedural-macros.html), which allow us\n   to run code at compile time that operates over Rust syntax, both consuming and producing Rust\n   syntax. Using the attribute `#[my_macro]`, you can effectively extend the provided code. You will\n   see that the derive macro is very frequently employed to recursively implement traits, where the\n   implementation consists of the composition of all fields.\n\n   In this example, we want to derive the [`Module`](../building-blocks/module.md) and `Debug`\n   traits.\n\n   ```rust, ignore\n   #[derive(Module, Debug)]\n   pub struct MyCustomModule<B: Backend> {\n       linear1: Linear<B>,\n       linear2: Linear<B>,\n       activation: Relu,\n   }\n   ```\n\n   The basic `Debug` implementation is provided by the compiler to format a value using the `{:?}`\n   formatter. For ease of use, the `Module` trait implementation is automatically handled by Burn so\n   you don't have to do anything special. It essentially acts as parameter container.\n\n   For more details on derivable traits, take a look at the Rust\n   [appendix](https://doc.rust-lang.org/book/appendix-03-derivable-traits.html),\n   [reference](https://doc.rust-lang.org/reference/attributes/derive.html) or\n   [example](https://doc.rust-lang.org/rust-by-example/trait/derive.html).\n   </details><br>\n\n2. Note that the struct is generic over the [`Backend`](../building-blocks/backend.md) trait. The\n   backend trait abstracts the underlying low level implementations of tensor operations, allowing\n   your new model to run on any backend. Contrary to other frameworks, the backend abstraction isn't\n   determined by a compilation flag or a device type. This is important because you can extend the\n   functionalities of a specific backend (see\n   [backend extension section](../advanced/backend-extension)), and it allows for an innovative\n   [autodiff system](../building-blocks/autodiff.md). You can also change backend during runtime,\n   for instance to compute training metrics on a cpu backend while using a gpu one only to train the\n   model. In our example, the backend in use will be determined later on.\n\n   <details>\n   <summary><strong>🦀 Trait Bounds</strong></summary>\n\n   Trait bounds provide a way for generic items to restrict which types are used as their\n   parameters. The trait bounds stipulate what functionality a type implements. Therefore, bounding\n   restricts the generic to types that conform to the bounds. It also allows generic instances to\n   access the methods of traits specified in the bounds.\n\n   For a simple but concrete example, check out the\n   [Rust By Example on bounds](https://doc.rust-lang.org/rust-by-example/generics/bounds.html).\n\n   In Burn, the `Backend` trait enables you to run tensor operations using different implementations\n   as it abstracts tensor, device and element types. The\n   [getting started example](../getting-started.md#writing-a-code-snippet) illustrates the advantage\n   of having a simple API that works for different backend implementations. While it used the WGPU\n   backend, you could easily swap it with any other supported backend.\n\n   ```rust, ignore\n   // Choose from any of the supported backends.\n   // type Backend = Candle<f32, i64>;\n   // type Backend = LibTorch<f32>;\n   // type Backend = NdArray<f32>;\n   type Backend = Wgpu;\n\n   // Creation of two tensors.\n   let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);\n   let tensor_2 = Tensor::<Backend, 2>::ones_like(&tensor_1);\n\n   // Print the element-wise addition (done with the selected backend) of the two tensors.\n   println!(\"{}\", tensor_1 + tensor_2);\n   ```\n\n   For more details on trait bounds, check out the Rust\n   [trait bound section](https://doc.rust-lang.org/book/ch10-02-traits.html#trait-bound-syntax) or\n   [reference](https://doc.rust-lang.org/reference/items/traits.html#trait-bounds).\n\n   </details><br>\n\nNote that each time you create a new file in the `src` directory you also need to explicitly add\nthis module to the `main.rs` file. For instance after creating the `model.rs`, you need to add the\nfollowing at the top of the main file:\n\n```rust , ignore\nmod model;\n#\n# fn main() {\n# }\n```\n\nNext, we need to instantiate the model for training.\n\n```rust , ignore\n# use burn::{\n#     nn::{\n#         conv::{Conv2d, Conv2dConfig},\n#         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},\n#         Dropout, DropoutConfig, Linear, LinearConfig, Relu,\n#     },\n#     prelude::*,\n# };\n#\n# #[derive(Module, Debug)]\n# pub struct Model<B: Backend> {\n#     conv1: Conv2d<B>,\n#     conv2: Conv2d<B>,\n#     pool: AdaptiveAvgPool2d,\n#     dropout: Dropout,\n#     linear1: Linear<B>,\n#     linear2: Linear<B>,\n#     activation: Relu,\n# }\n#\n#[derive(Config, Debug)]\npub struct ModelConfig {\n    num_classes: usize,\n    hidden_size: usize,\n    #[config(default = \"0.5\")]\n    dropout: f64,\n}\n\nimpl ModelConfig {\n    /// Returns the initialized model.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {\n        Model {\n            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),\n            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),\n            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),\n            activation: Relu::new(),\n            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),\n            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n        }\n    }\n}\n```\n\nAt a glance, you can view the model configuration by printing the model instance:\n\n```rust , ignore\n#![recursion_limit = \"256\"]\nmod model;\n\nuse crate::model::ModelConfig;\nuse burn::backend::Wgpu;\n\nfn main() {\n    type MyBackend = Wgpu<f32, i32>;\n\n    let device = Default::default();\n    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);\n\n    println!(\"{model}\");\n}\n```\n\nOutput:\n\n```rust , ignore\nModel {\n  conv1: Conv2d {ch_in: 1, ch_out: 8, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}\n  conv2: Conv2d {ch_in: 8, ch_out: 16, stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}\n  pool: AdaptiveAvgPool2d {output_size: [8, 8]}\n  dropout: Dropout {prob: 0.5}\n  linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}\n  linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}\n  activation: Relu\n  params: 531178\n}\n```\n\n<details>\n<summary><strong>🦀 References</strong></summary>\n\nIn the previous example, the `init()` method signature uses `&` to indicate that the parameter types\nare references: `&self`, a reference to the current receiver (`ModelConfig`), and\n`device: &B::Device`, a reference to the backend device.\n\n```rust, ignore\npub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {\n    Model {\n        // ...\n    }\n}\n```\n\nReferences in Rust allow us to point to a resource to access its data without owning it. The idea of\nownership is quite core to Rust and is worth\n[reading up on](https://doc.rust-lang.org/book/ch04-00-understanding-ownership.html).\n\nIn a language like C, memory management is explicit and up to the programmer, which means it is easy\nto make mistakes. In a language like Java or Python, memory management is automatic with the help of\na garbage collector. This is very safe and straightforward, but also incurs a runtime cost.\n\nIn Rust, memory management is rather unique. Aside from simple types that implement\n[`Copy`](https://doc.rust-lang.org/std/marker/trait.Copy.html) (e.g.,\n[primitives](https://doc.rust-lang.org/rust-by-example/primitives.html) like integers, floats,\nbooleans and `char`), every value is _owned_ by some variable called the _owner_. Ownership can be\ntransferred from one variable to another and sometimes a value can be _borrowed_. Once the _owner_\nvariable goes out of scope, the value is _dropped_, which means that any memory it allocated can be\nfreed safely.\n\nBecause the method does not own the `self` and `device` variables, the values the references point\nto will not be dropped when the reference stops being used (i.e., the scope of the method).\n\nFor more information on references and borrowing, be sure to read the\n[corresponding chapter](https://doc.rust-lang.org/book/ch04-02-references-and-borrowing.html) in the\nRust Book.\n\n</details><br>\n\nWhen creating a custom neural network module, it is often a good idea to create a config alongside\nthe model struct. This allows you to define default values for your network, thanks to the `Config`\nattribute. The benefit of this attribute is that it makes the configuration serializable, enabling\nyou to painlessly save your model hyperparameters, enhancing your experimentation process. Note that\na constructor will automatically be generated for your configuration, which will take in as input\nvalues the parameters which do not have default values:\n`let config = ModelConfig::new(num_classes, hidden_size);`. The default values can be overridden\neasily with builder-like methods: (e.g `config.with_dropout(0.2);`)\n\nThe first implementation block is related to the initialization method. As we can see, all fields\nare set using the configuration of the corresponding neural network's underlying module. In this\nspecific case, we have chosen to expand the tensor channels from 1 to 8 with the first layer, then\nfrom 8 to 16 with the second layer, using a kernel size of 3 on all dimensions. We also use the\nadaptive average pooling module to reduce the dimensionality of the images to an 8 by 8 matrix,\nwhich we will flatten in the forward pass to have a 1024 (16 * 8 * 8) resulting tensor.\n\nNow let's see how the forward pass is defined.\n\n```rust , ignore\n# use burn::{\n#     nn::{\n#         conv::{Conv2d, Conv2dConfig},\n#         pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},\n#         Dropout, DropoutConfig, Linear, LinearConfig, Relu,\n#     },\n#     prelude::*,\n# };\n#\n# #[derive(Module, Debug)]\n# pub struct Model<B: Backend> {\n#     conv1: Conv2d<B>,\n#     conv2: Conv2d<B>,\n#     pool: AdaptiveAvgPool2d,\n#     dropout: Dropout,\n#     linear1: Linear<B>,\n#     linear2: Linear<B>,\n#     activation: Relu,\n# }\n#\n# #[derive(Config, Debug)]\n# pub struct ModelConfig {\n#     num_classes: usize,\n#     hidden_size: usize,\n#     #[config(default = \"0.5\")]\n#     dropout: f64,\n# }\n#\n# impl ModelConfig {\n#     /// Returns the initialized model.\n#     pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {\n#         Model {\n#             conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),\n#             conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),\n#             pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),\n#             activation: Relu::new(),\n#             linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),\n#             linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),\n#             dropout: DropoutConfig::new(self.dropout).init(),\n#         }\n#     }\n# }\n#\nimpl<B: Backend> Model<B> {\n    /// # Shapes\n    ///   - Images [batch_size, height, width]\n    ///   - Output [batch_size, num_classes]\n    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = images.dims();\n\n        // Create a channel at the second dimension.\n        let x = images.reshape([batch_size, 1, height, width]);\n\n\n        let x = self.conv1.forward(x); // [batch_size, 8, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.conv2.forward(x); // [batch_size, 16, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]\n        let x = x.reshape([batch_size, 16 * 8 * 8]);\n        let x = self.linear1.forward(x);\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        self.linear2.forward(x) // [batch_size, num_classes]\n    }\n}\n```\n\nFor former PyTorch users, this might feel very intuitive, as each module is directly incorporated\ninto the code using an eager API. Note that no abstraction is imposed for the forward method. You\nare free to define multiple forward functions with the names of your liking. Most of the neural\nnetwork modules already built with Burn use the `forward` nomenclature, simply because it is\nstandard in the field.\n\nSimilar to neural network modules, the [`Tensor`](../building-blocks/tensor.md) struct given as a\nparameter also takes the Backend trait as a generic argument, alongside its dimensionality. Even if\nit is not used in this specific example, it is possible to add the kind of the tensor as a third\ngeneric argument. For example, a 3-dimensional Tensor of different data types(float, int, bool)\nwould be defined as following:\n\n```rust , ignore\nTensor<B, 3> // Float tensor (default)\nTensor<B, 3, Float> // Float tensor (explicit)\nTensor<B, 3, Int> // Int tensor\nTensor<B, 3, Bool> // Bool tensor\n```\n\nNote that the specific element type, such as `f16`, `f32` and the likes, will be defined later with\nthe backend.\n"
  },
  {
    "path": "burn-book/src/basic-workflow/training.md",
    "content": "# Training\n\nWe are now ready to write the necessary code to train our model on the MNIST dataset. We shall\ndefine the code for this training section in the file: `src/training.rs`.\n\nInstead of a simple tensor, the model should output an item that can be understood by the learner, a\nstruct whose responsibility is to apply an optimizer to the model. The output struct is used for all\nmetrics calculated during the training. Therefore it should include all the necessary information to\ncalculate any metric that you want for a task.\n\nBurn provides two basic output types: `ClassificationOutput` and `RegressionOutput`. They implement\nthe necessary trait to be used with metrics. It is possible to create your own item, but it is\nbeyond the scope of this guide.\n\nSince the MNIST task is a classification problem, we will use the `ClassificationOutput` type.\n\n```rust , ignore\n# use crate::{\n#     data::{MnistBatch, MnistBatcher},\n#     model::{Model, ModelConfig},\n# };\n# use burn::{\n#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n#     nn::loss::CrossEntropyLossConfig,\n#     optim::AdamConfig,\n#     prelude::*,\n#     record::CompactRecorder,\n#     tensor::backend::AutodiffBackend,\n#     train::{\n#         ClassificationOutput, Learner, SupervisedTraining, TrainOutput, TrainStep, InferenceStep,\n#         metric::{AccuracyMetric, LossMetric},\n#     },\n# };\n# \nimpl<B: Backend> Model<B> {\n    pub fn forward_classification(\n        &self,\n        images: Tensor<B, 3>,\n        targets: Tensor<B, 1, Int>,\n    ) -> ClassificationOutput<B> {\n        let output = self.forward(images);\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output.device())\n            .forward(output.clone(), targets.clone());\n\n        ClassificationOutput::new(loss, output, targets)\n    }\n}\n```\n\nAs evident from the preceding code block, we employ the cross-entropy loss module for loss\ncalculation, without the inclusion of any padding token. We then return the classification output\ncontaining the loss, the output tensor with all logits and the targets.\n\nPlease take note that tensor operations receive owned tensors as input. For reusing a tensor\nmultiple times, you need to use the `clone()` function. There's no need to worry; this process won't\ninvolve actual copying of the tensor data. Instead, it will simply indicate that the tensor is\nemployed in multiple instances, implying that certain operations won't be performed in place. In\nsummary, our API has been designed with owned tensors to optimize performance.\n\nMoving forward, we will proceed with the implementation of both the training and validation steps\nfor our model.\n\n```rust , ignore\n# use crate::{\n#     data::{MnistBatch, MnistBatcher},\n#     model::{Model, ModelConfig},\n# };\n# use burn::{\n#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n#     nn::loss::CrossEntropyLossConfig,\n#     optim::AdamConfig,\n#     prelude::*,\n#     record::CompactRecorder,\n#     tensor::backend::AutodiffBackend,\n#     train::{\n#         ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep,\n#         metric::{AccuracyMetric, LossMetric},\n#     },\n# };\n# \n# impl<B: Backend> Model<B> {\n#     pub fn forward_classification(\n#         &self,\n#         images: Tensor<B, 3>,\n#         targets: Tensor<B, 1, Int>,\n#     ) -> ClassificationOutput<B> {\n#         let output = self.forward(images);\n#         let loss = CrossEntropyLossConfig::new()\n#             .init(&output.device())\n#             .forward(output.clone(), targets.clone());\n# \n#         ClassificationOutput::new(loss, output, targets)\n#     }\n# }\nimpl<B: AutodiffBackend> TrainStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_classification(batch.images, batch.targets);\n\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {\n        self.forward_classification(batch.images, batch.targets)\n    }\n}\n```\n\nHere we define the input and output types as generic arguments in the `TrainStep` and `InferenceStep`.\nWe will call them `MnistBatch` and `ClassificationOutput`. In the training step, the computation of\ngradients is straightforward, necessitating a simple invocation of `backward()` on the loss. Note\nthat contrary to PyTorch, gradients are not stored alongside each tensor parameter, but are rather\nreturned by the backward pass, as such: `let gradients = loss.backward();`. The gradient of a\nparameter can be obtained with the grad function: `let grad = tensor.grad(&gradients);`. Although it\nis not necessary when using the learner struct and the optimizers, it can prove to be quite useful\nwhen debugging or writing custom training loops. One of the differences between the training and the\nvalidation steps is that the former requires the backend to implement `AutodiffBackend` and not just\n`Backend`. Otherwise, the `backward` function is not available, as the backend does not support\nautodiff. We will see later how to create a backend with autodiff support.\n\n<details>\n<summary><strong>🦀 Generic Type Constraints in Method Definitions</strong></summary>\n\nAlthough generic data types, trait and trait bounds were already introduced in previous sections of\nthis guide, the previous code snippet might be a lot to take in at first.\n\nIn the example above, we implement the `TrainStep` and `InferenceStep` trait for our `Model` struct,\nwhich is generic over the `Backend` trait as has been covered before. These traits are provided by\n`burn::train` and define a common `step` method that should be implemented for all structs. Since\nthe trait is generic over the input and output types, the trait implementation must specify the\nconcrete types used. This is where the additional type constraints appear\n`<MnistBatch<B>, ClassificationOutput<B>>`. As we saw previously, the concrete input type for the\nbatch is `MnistBatch`, and the output of the forward pass is `ClassificationOutput`. The `step`\nmethod signature matches the concrete input and output types.\n\nFor more details specific to constraints on generic types when defining methods, take a look at\n[this section](https://doc.rust-lang.org/book/ch10-01-syntax.html#in-method-definitions) of the Rust\nBook.\n\n</details><br>\n\nLet us move on to establishing the practical training configuration.\n\n```rust , ignore\n# use crate::{\n#     data::{MnistBatch, MnistBatcher},\n#     model::{Model, ModelConfig},\n# };\n# use burn::{\n#     data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n#     nn::loss::CrossEntropyLossConfig,\n#     optim::AdamConfig,\n#     prelude::*,\n#     record::CompactRecorder,\n#     tensor::backend::AutodiffBackend,\n#     train::{\n#         ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep,\n#         metric::{AccuracyMetric, LossMetric},\n#     },\n# };\n# \n# impl<B: Backend> Model<B> {\n#     pub fn forward_classification(\n#         &self,\n#         images: Tensor<B, 3>,\n#         targets: Tensor<B, 1, Int>,\n#     ) -> ClassificationOutput<B> {\n#         let output = self.forward(images);\n#         let loss = CrossEntropyLossConfig::new()\n#             .init(&output.device())\n#             .forward(output.clone(), targets.clone());\n# \n#         ClassificationOutput::new(loss, output, targets)\n#     }\n# }\n# impl<B: AutodiffBackend> TrainStep for Model<B> {\n#     type Input = MnistBatch<B>;\n#     type Output = ClassificationOutput<B>;\n# \n#     fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n#         let item = self.forward_classification(batch.images, batch.targets);\n# \n#         TrainOutput::new(self, item.loss.backward(), item)\n#     }\n# }\n#\n# impl<B: Backend> InferenceStep for Model<B> {\n#     type Input = MnistBatch<B>;\n#     type Output = ClassificationOutput<B>;\n# \n#     fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {\n#         self.forward_classification(batch.images, batch.targets)\n#     }\n# }\n#\n#[derive(Config, Debug)]\npub struct TrainingConfig {\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n    #[config(default = 10)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1.0e-4)]\n    pub learning_rate: f64,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {\n    create_artifact_dir(artifact_dir);\n    config\n        .save(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should be saved successfully\");\n\n    B::seed(&device, config.seed);\n\n    let batcher = MnistBatcher::default();\n\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::test());\n\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metrics((AccuracyMetric::new(), LossMetric::new()))\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let model = config.model.init::<B>(&device);\n    let result = training.launch(Learner::new(\n        model,\n        config.optimizer.init(),\n        config.learning_rate,\n    ));\n\n    result\n        .model\n        .save_file(format!(\"{artifact_dir}/model\"), &CompactRecorder::new())\n        .expect(\"Trained model should be saved successfully\");\n}\n```\n\nIt is a good practice to use the `Config` derive to create the experiment configuration. In the\n`train` function, the first thing we are doing is making sure the `artifact_dir` exists, using the\nstandard rust library for file manipulation. All checkpoints, logging and metrics will be stored\nunder this directory. We initialize the dataloaders using the previously created batcher. Since no\nautomatic differentiation is needed during the validation phase, the `training.launch(...)` method\ndefines the necessary backend bounds on the data loader for `B::InnerBackend` (see\n[Backend](./backend.md)). The autodiff capabilities are available through a type system, making it\nnearly impossible to forget to deactivate gradient calculation.\n\nNext, we create a supervised training runner with the dataloaders for training and validation and\nwe register the accuracy and loss metric on both training and validation steps. We also configure the\ncheckpointer using the `CompactRecorder` to indicate how weights should be stored. This struct implements the `Recorder` trait, which makes\nit capable of saving records for persistency.\n\nFor the sake of simplicity in this example, we employ the test set as the validation\nset; however, we do not recommend this practice for actual usage.\n\nWe create the learner containing the model, the optimizer and the learning rate. Notably, the third\nargument of the learner's `new` function should actually be a learning rate _scheduler_. When provided with a\nfloat as in our example, it is automatically transformed into a _constant_ learning rate scheduler.\nThe learning rate is not part of the optimizer config as it is often done in other frameworks, but\nrather passed as a parameter when executing the optimizer step. This avoids having to mutate the\nstate of the optimizer and is therefore more functional. It makes no difference when using the\nlearner struct, but it will be an essential nuance to grasp if you implement your own training loop.\n\nOnce the learner and supervised training instance are created, we can call `training.launch` and provide the learner.\n\nFinally, the trained model is returned by the `launch` method. The trained weights are then saved using\nthe `CompactRecorder`. This recorder employs the `MessagePack` format with half precision, `f16` for\nfloats and `i16` for integers. Other recorders are available, offering support for various formats,\nsuch as `BinCode` and `JSON`, with or without compression. Any backend, regardless of precision, can\nload recorded data of any kind.\n"
  },
  {
    "path": "burn-book/src/building-blocks/README.md",
    "content": "# Building Blocks\n\nIn this section, we'll guide you through the core elements that make up Burn. We'll walk you through\nthe key components that serve as the building blocks of the framework and your future projects.\n\nAs you explore Burn, you might notice that we occasionally draw comparisons to PyTorch. We believe\nit can provide a smoother learning curve and help you grasp the nuances more effectively.\n"
  },
  {
    "path": "burn-book/src/building-blocks/autodiff.md",
    "content": "# Autodiff\n\nBurn's tensor also supports auto-differentiation, which is an essential part of any deep learning\nframework. We introduced the `Backend` trait in the [previous section](./backend.md), but Burn also\nhas another trait for autodiff: `AutodiffBackend`.\n\nHowever, not all tensors support auto-differentiation; you need a backend that implements both the\n`Backend` and `AutodiffBackend` traits. Fortunately, you can add auto-differentiation capabilities to any\nbackend using a backend decorator: `type MyAutodiffBackend = Autodiff<MyBackend>`. This\ndecorator implements both the `AutodiffBackend` and `Backend` traits by maintaining a dynamic\ncomputational graph and utilizing the inner backend to execute tensor operations.\n\nThe `AutodiffBackend` trait adds new operations on float tensors that can't be called otherwise. It also\nprovides a new associated type, `B::Gradients`, where each calculated gradient resides.\n\n```rust, ignore\nfn calculate_gradients<B: AutodiffBackend>(tensor: Tensor<B, 2>) -> B::Gradients {\n    let mut gradients = tensor.clone().backward();\n\n    let tensor_grad = tensor.grad(&gradients);        // get\n    let tensor_grad = tensor.grad_remove(&mut gradients); // pop\n\n    gradients\n}\n```\n\nNote that some functions will always be available even if the backend doesn't implement the\n`AutodiffBackend` trait. In such cases, those functions will do nothing.\n\n| Burn API                                | PyTorch Equivalent            |\n| --------------------------------------- | ----------------------------- |\n| `tensor.detach()`                       | `tensor.detach()`             |\n| `tensor.require_grad()`                 | `tensor.requires_grad()`      |\n| `tensor.is_require_grad()`              | `tensor.requires_grad`        |\n| `tensor.set_require_grad(require_grad)` | `tensor.requires_grad(False)` |\n\nHowever, you're unlikely to make any mistakes since you can't call `backward` on a tensor that is on\na backend that doesn't implement `AutodiffBackend`. Additionally, you can't retrieve the gradient of a\ntensor without an autodiff backend.\n\n## Difference with PyTorch\n\nThe way Burn handles gradients is different from PyTorch. First, when calling `backward`, each\nparameter doesn't have its `grad` field updated. Instead, the backward pass returns all the\ncalculated gradients in a container. This approach offers numerous benefits, such as the ability to\neasily send gradients to other threads.\n\nYou can also retrieve the gradient for a specific parameter using the `grad` method on a tensor.\nSince this method takes the gradients as input, it's hard to forget to call `backward` beforehand.\nNote that sometimes, using `grad_remove` can improve performance by allowing inplace operations.\n\nIn PyTorch, when you don't need gradients for inference or validation, you typically need to scope\nyour code using a block.\n\n```python\n# Inference mode\ntorch.inference():\n   # your code\n   ...\n\n# Or no grad\ntorch.no_grad():\n   # your code\n   ...\n```\n\nWith Burn, you don't need to wrap the backend with the `Autodiff` for inference, and you\ncan call `inner()` to obtain the inner tensor, which is useful for validation.\n\n```rust, ignore\n/// Use `B: AutodiffBackend`\nfn example_validation<B: AutodiffBackend>(tensor: Tensor<B, 2>) {\n    let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();\n    let _ = inner_tensor + 5;\n}\n\n/// Use `B: Backend`\nfn example_inference<B: Backend>(tensor: Tensor<B, 2>) {\n    let _ = tensor + 5;\n    ...\n}\n```\n\n**Gradients with Optimizers**\n\nWe've seen how gradients can be used with tensors, but the process is a bit different when working\nwith optimizers from `burn-core`. To work with the `Module` trait, a translation step is required to\nlink tensor parameters with their gradients. This step is necessary to easily support gradient\naccumulation and training on multiple devices, where each module can be forked and run on different\ndevices in parallel. We'll explore deeper into this topic in the [Module](./module.md) section.\n"
  },
  {
    "path": "burn-book/src/building-blocks/backend.md",
    "content": "# Backend\n\nNearly everything in Burn is based on the `Backend` trait, which enables you to run tensor\noperations using different implementations without having to modify your code. While a backend may\nnot necessarily have autodiff capabilities, the `AutodiffBackend` trait specifies when autodiff is\nneeded. This trait not only abstracts operations but also tensor, device, and element types,\nproviding each backend the flexibility they need. It's worth noting that the trait assumes eager\nmode since burn fully supports dynamic graphs. However, we may create another API to assist with\nintegrating graph-based backends, without requiring any changes to the user's code.\n\nUsers are not expected to directly use the backend trait methods, as it is primarily designed with\nbackend developers in mind rather than Burn users. Therefore, most Burn userland APIs are generic\nacross backends. This approach helps users discover the API more organically with proper\nautocomplete and documentation.\n"
  },
  {
    "path": "burn-book/src/building-blocks/config.md",
    "content": "# Config\n\nWhen writing scientific code, you normally have a lot of values that are set, and Deep Learning is\nno exception. Python has the possibility to define default parameters for functions, which helps\nimprove the developer experience. However, this has the downside of potentially breaking your code\nwhen upgrading to a new version, as the default values might change without your knowledge, making\ndebugging very challenging.\n\nWith that in mind, we came up with the Config system. It's a simple Rust derive that you can apply\nto your types, allowing you to define default values with ease. Additionally, all configs can be\nserialized, reducing potential bugs when upgrading versions and improving reproducibility.\n\n```rust , ignore\nuse burn::config::Config;\n\n#[derive(Config)]\npub struct MyModuleConfig {\n    d_model: usize,\n    d_ff: usize,\n    #[config(default = 0.1)]\n    dropout: f64,\n}\n```\n\nThe derive also adds useful `with_` methods for every attribute of your config, similar to a builder\npattern, along with a `save` method.\n\n```rust, ignore\nfn main() {\n    let config = MyModuleConfig::new(512, 2048);\n    println!(\"{}\", config.d_model); // 512\n    println!(\"{}\", config.d_ff); // 2048\n    println!(\"{}\", config.dropout); // 0.1\n    let config =  config.with_dropout(0.2);\n    println!(\"{}\", config.dropout); // 0.2\n\n    config.save(\"config.json\").unwrap();\n}\n```\n\n## Good practices\n\nBy using the config type it is easy to create new module instances. The initialization method should\nbe implemented on the config type with the device as argument.\n\n```rust, ignore\nimpl MyModuleConfig {\n    /// Create a module on the given device.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> MyModule {\n        MyModule {\n            linear: LinearConfig::new(self.d_model, self.d_ff).init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n        }\n    }\n}\n```\n\nThen we could add this line to the above `main`:\n\n```rust, ignore\nuse burn::backend::Wgpu;\nlet device = Default::default();\nlet my_module = config.init::<Wgpu>(&device);\n```\n"
  },
  {
    "path": "burn-book/src/building-blocks/dataset.md",
    "content": "# Dataset\n\nAt its core, a dataset is a collection of data typically related to a specific analysis or\nprocessing task. The data modality can vary depending on the task, but most datasets primarily\nconsist of images, texts, audio or videos.\n\nThis data source represents an integral part of machine learning to successfully train a model.\nThus, it is essential to provide a convenient and performant API to handle your data. Since this\nprocess varies wildly from one problem to another, it is defined as a trait that should be\nimplemented on your type. The dataset trait is quite similar to the dataset abstract class in\nPyTorch:\n\n```rust, ignore\npub trait Dataset<I>: Send + Sync {\n    fn get(&self, index: usize) -> Option<I>;\n    fn len(&self) -> usize;\n}\n```\n\nThe dataset trait assumes a fixed-length set of items that can be randomly accessed in constant\ntime. This is a major difference from datasets that use Apache Arrow underneath to improve streaming\nperformance. Datasets in Burn don't assume _how_ they are going to be accessed; it's just a\ncollection of items.\n\nHowever, you can compose multiple dataset transformations to lazily obtain what you want with zero\npre-processing, so that your training can start instantly!\n\n## Transformation\n\nTransformations in Burn are all lazy and modify one or multiple input datasets. The goal of these\ntransformations is to provide you with the necessary tools so that you can model complex data\ndistributions.\n\n| Transformation     | Description                                                                                                              |\n| ------------------ | ------------------------------------------------------------------------------------------------------------------------ |\n| `SamplerDataset`   | Samples items from a dataset. This is a convenient way to model a dataset as a probability distribution of a fixed size. |\n| `SelectionDataset` | Selects a subset of items by index from a dataset. Can be randomly shuffled; can be re-shuffled.                         |\n| `ShuffledDataset`  | Shuffles a wrapped dataset; This is a thin wrapper around `SelectionDataset`.                                            |\n| `PartialDataset`   | Returns a view of the input dataset with a specified range.                                                              |\n| `MapperDataset`    | Computes a transformation lazily on the input dataset.                                                                   |\n| `ComposedDataset`  | Composes multiple datasets together to create a larger one without copying any data.                                     |\n| `WindowsDataset`   | Dataset designed to work with overlapping windows of data extracted from an input dataset.                               |\n\nLet us look at the basic usages of each dataset transform and how they can be composed together.\nThese transforms are lazy by default except when specified, reducing the need for unnecessary\nintermediate allocations and improving performance. The full documentation of each transform can be\nfound at the [API reference](https://burn.dev/docs/burn/data/dataset/transform/index.html).\n\n- **SamplerDataset**: This transform can be used to sample items from a dataset with (default) or\n  without replacement. Transform is initialized with a sampling size which can be bigger or smaller\n  than the input dataset size. This is particularly useful in cases where we want to checkpoint\n  larger datasets more often during training and smaller datasets less often as the size of an epoch\n  is now controlled by the sampling size. Sample usage:\n\n```rust, ignore\ntype DbPedia = SqliteDataset<DbPediaItem>;\nlet dataset: DbPedia = HuggingfaceDatasetLoader::new(\"dbpedia_14\")\n        .dataset(\"train\").\n        .unwrap();\n\nlet dataset = SamplerDataset<DbPedia, DbPediaItem>::new(dataset, 10000);\n```\n\n- **SelectionDataset**: This transform can be used to select a subset of items from a dataset by\n  index. It can be initialized with a list of indices to select from the input dataset. This is\n  particularly useful when you want to create a smaller dataset from a larger one, for example, to\n  create a validation set from a training set.\n\n  The `SelectionDataset` can also be initialized with a random seed to shuffle the indices before\n  selection. This is useful when you want to randomly select a subset of items from the dataset.\n\n  Base dataset items may be included more than once in the selection.\n\n```rust, ignore\nlet explicit = SelectionDataset::from_indices_checked(dataset.clone(), vec![0, 1, 2, 0]);\n\nlet shuffled = SelectionDataset::new_shuffled(dataset.clone(), &mut rng);\nlet shuffled = SelectionDataset::new_shuffled(dataset.clone(), 42);\n\nlet mut mutable = SelectionDataset::new_select_all(dataset.clone(), vec![0, 1, 2, 0]);\nmutable.shuffle(42);\nmutable.shuffle(&mut rng);\n```\n\n- **ShuffledDataset**: This transform can be used to shuffle the items of a dataset. Particularly\n  useful before splitting the raw dataset into train/test splits. Can be initialized with a seed to\n  ensure reproducibility.\n\n  The `ShuffledDataset` is a thin wrapper around the `SelectionDataset`.\n\n```rust, ignore\nlet dataset = ShuffledDataset<DbPedia, DbPediaItem>::new(dataset, &mut rng);\nlet dataset = ShuffledDataset<DbPedia, DbPediaItem>::new(dataset, 42);\n```\n\n- **PartialDataset**: This transform is useful to return a view of the dataset with specified start\n  and end indices. Used to create train/val/test splits. In the example below, we show how to chain\n  ShuffledDataset and PartialDataset to create splits.\n\n```rust, ignore\n// define chained dataset type here for brevity\ntype PartialData = PartialDataset<ShuffledDataset<DbPedia, DbPediaItem>>;\nlet len = dataset.len();\nlet split = \"train\"; // or \"val\"/\"test\"\n\nlet data_split = match split {\n    \"train\" => PartialData::new(dataset, 0, len * 8 / 10),  // Get first 80% dataset\n    \"test\" => PartialData::new(dataset, len * 8 / 10, len), // Take remaining 20%\n    _ => panic!(\"Invalid split type\"),                      // Handle unexpected split types\n};\n```\n\n- **MapperDataset**: This transform is useful to apply a transformation on each of the items of a\n  dataset. Particularly useful for normalization of image data when channel means are known.\n\n- **ComposedDataset**: This transform is useful to compose multiple datasets downloaded from\n  multiple sources (say different HuggingfaceDatasetLoader sources) into a single bigger dataset\n  which can be sampled from one source.\n\n- **WindowsDataset**: This transform is useful to create overlapping windows of a dataset.\n  Particularly useful for sequential Time series Data, for example when working with an LSTM.\n\n## Storage\n\nThere are multiple dataset storage options available for you to choose from. The choice of the\ndataset to use should be based on the dataset's size as well as its intended purpose.\n\n| Storage            | Description                                                                                                                                          |\n| ------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------- |\n| `InMemDataset`     | In-memory dataset that uses a vector to store items. Well-suited for smaller datasets.                                                               |\n| `SqliteDataset`    | Dataset that uses [SQLite](https://www.sqlite.org/) to index items that can be saved in a simple SQL database file. Well-suited for larger datasets. |\n| `DataframeDataset` | Dataset that uses [Polars](https://www.pola.rs/) dataframe to store and manage data. Well-suited for efficient data manipulation and analysis.       |\n\n## Sources\n\nFor now, there are only a couple of dataset sources available with Burn, but more to come!\n\n### Hugging Face\n\nYou can easily import any Hugging Face dataset with Burn. We use SQLite as the storage to avoid\ndownloading the model each time or starting a Python process. You need to know the format of each\nitem in the dataset beforehand. Here's an example with the\n[dbpedia dataset](https://huggingface.co/datasets/dbpedia_14).\n\n```rust, ignore\n#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]\npub struct DbPediaItem {\n    pub title: String,\n    pub content: String,\n    pub label: usize,\n}\n\nfn main() {\n    let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new(\"dbpedia_14\")\n        .dataset(\"train\") // The training split.\n        .unwrap();\n}\n```\n\nWe see that items must derive `serde::Serialize`, `serde::Deserialize`, `Clone`, and `Debug`, but\nthose are the only requirements.\n\n<div class=\"warning\">\n\nThe `HuggingfaceDatasetLoader` relies on the\n[`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download\ndatasets. This is a Python library, so you must have an existing Python installation to use this\nloader.\n\n</div>\n\n### Images\n\n`ImageFolderDataset` is a generic vision dataset used to load images from disk. It is currently\navailable for multi-class and multi-label classification tasks as well as semantic segmentation and\nobject detection tasks.\n\n```rust, ignore\n// Create an image classification dataset from the root folder,\n// where images for each class are stored in their respective folder.\n//\n// For example:\n// root/dog/dog1.png\n// root/dog/dog2.png\n// ...\n// root/cat/cat1.png\nlet dataset = ImageFolderDataset::new_classification(\"path/to/dataset/root\").unwrap();\n```\n\n```rust, ignore\n// Create a multi-label image classification dataset from a list of items,\n// where each item is a tuple `(image path, labels)`, and a list of classes\n// in the dataset.\n//\n// For example:\nlet items = vec![\n    (\"root/dog/dog1.png\", vec![\"animal\".to_string(), \"dog\".to_string()]),\n    (\"root/cat/cat1.png\", vec![\"animal\".to_string(), \"cat\".to_string()]),\n];\nlet dataset = ImageFolderDataset::new_multilabel_classification_with_items(\n    items,\n    &[\"animal\", \"cat\", \"dog\"],\n)\n.unwrap();\n```\n\n```rust, ignore\n// Create a segmentation mask dataset from a list of items, where each\n// item is a tuple `(image path, mask path)` and a list of classes\n// corresponding to the integer values in the mask.\nlet items = vec![\n    (\n        \"path/to/images/image0.png\",\n        \"path/to/annotations/mask0.png\",\n    ),\n    (\n        \"path/to/images/image1.png\",\n        \"path/to/annotations/mask1.png\",\n    ),\n    (\n        \"path/to/images/image2.png\",\n        \"path/to/annotations/mask2.png\",\n    ),\n];\nlet dataset = ImageFolderDataset::new_segmentation_with_items(\n    items,\n    &[\n        \"cat\", // 0\n        \"dog\", // 1\n        \"background\", // 2\n    ],\n)\n.unwrap();\n```\n\n```rust, ignore\n// Create an object detection dataset from a COCO dataset. Currently only\n// the import of object detection data (bounding boxes) is supported.\n//\n// COCO offers separate annotation and image archives for training and\n// validation, paths to the unpacked files need to be passed as parameters:\n\nlet dataset = ImageFolderDataset::new_coco_detection(\n    \"/path/to/coco/instances_train2017.json\",\n    \"/path/to/coco/images/train2017\"\n)\n.unwrap();\n\n```\n\n### Comma-Separated Values (CSV)\n\nLoading records from a simple CSV file in-memory is simple with the `InMemDataset`:\n\n```rust, ignore\n// Build dataset from csv with tab ('\\t') delimiter.\n// The reader can be configured for your particular file.\nlet mut rdr = csv::ReaderBuilder::new();\nlet rdr = rdr.delimiter(b'\\t');\n\nlet dataset = InMemDataset::from_csv(\"path/to/csv\", rdr).unwrap();\n```\n\nNote that this requires the `csv` crate.\n\n**What about streaming datasets?**\n\nThere is no streaming dataset API with Burn, and this is by design! The learner struct will iterate\nmultiple times over the dataset and only checkpoint when done. You can consider the length of the\ndataset as the number of iterations before performing checkpointing and running the validation.\nThere is nothing stopping you from returning different items even when called with the same `index`\nmultiple times.\n\n## How Is The Dataset Used?\n\nDuring training, the dataset is used to access the data samples and, for most use cases in\nsupervised learning, their corresponding ground-truth labels. Remember that the `Dataset` trait\nimplementation is responsible to retrieve the data from its source, usually some sort of data\nstorage. At this point, the dataset could be naively iterated over to provide the model a single\nsample to process at a time, but this is not very efficient.\n\nInstead, we collect multiple samples that the model can process as a _batch_ to fully leverage\nmodern hardware (e.g., GPUs - which have impressive parallel processing capabilities). Since each\ndata sample in the dataset can be collected independently, the data loading is typically done in\nparallel to further speed things up. In this case, we parallelize the data loading using a\nmulti-threaded `BatchDataLoader` to obtain a sequence of items from the `Dataset` implementation.\nFinally, the sequence of items is combined into a batched tensor that can be used as input to a\nmodel with the `Batcher` trait implementation. Other tensor operations can be performed during this\nstep to prepare the batch data, as is done [in the basic workflow guide](../basic-workflow/data.md).\nThe process is illustrated in the figure below for the MNIST dataset.\n\n<img title=\"Burn Data Loading Pipeline\" alt=\"Burn Data Loading Pipeline\" src=\"./dataset.png\">\n\nAlthough we have conveniently implemented the\n[`MnistDataset`](https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/vision/mnist.rs)\nused in the guide, we'll go over its implementation to demonstrate how the `Dataset` and `Batcher`\ntraits are used.\n\nThe [MNIST dataset](http://yann.lecun.com/exdb/mnist/) of handwritten digits has a training set of\n60,000 examples and a test set of 10,000 examples. A single item in the dataset is represented by a\n\\\\(28 \\times 28\\\\) pixels black-and-white image (stored as raw bytes) with its corresponding label\n(a digit between \\\\(0\\\\) and \\\\(9\\\\)). This is defined by the `MnistItemRaw` struct.\n\n```rust, ignore\n# #[derive(Deserialize, Debug, Clone)]\nstruct MnistItemRaw {\n    pub image_bytes: Vec<u8>,\n    pub label: u8,\n}\n```\n\nWith single-channel images of such low resolution, the entire training and test sets can be loaded\nin memory at once. Therefore, we leverage the already existing `InMemDataset` to retrieve the raw\nimages and labels data. At this point, the image data is still just a bunch of bytes, but we want to\nretrieve the _structured_ image data in its intended form. For that, we can define a `MapperDataset`\nthat transforms the raw image bytes to a 2D array image (which we convert to float while we're at\nit).\n\n```rust, ignore\nconst WIDTH: usize = 28;\nconst HEIGHT: usize = 28;\n\n# /// MNIST item.\n# #[derive(Deserialize, Serialize, Debug, Clone)]\npub struct MnistItem {\n    /// Image as a 2D array of floats.\n    pub image: [[f32; WIDTH]; HEIGHT],\n\n    /// Label of the image.\n    pub label: u8,\n}\n\nstruct BytesToImage;\n\nimpl Mapper<MnistItemRaw, MnistItem> for BytesToImage {\n    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).\n    fn map(&self, item: &MnistItemRaw) -> MnistItem {\n        // Ensure the image dimensions are correct.\n        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);\n\n        // Convert the image to a 2D array of floats.\n        let mut image_array = [[0f32; WIDTH]; HEIGHT];\n        for (i, pixel) in item.image_bytes.iter().enumerate() {\n            let x = i % WIDTH;\n            let y = i / HEIGHT;\n            image_array[y][x] = *pixel as f32;\n        }\n\n        MnistItem {\n            image: image_array,\n            label: item.label,\n        }\n    }\n}\n\ntype MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;\n\n# /// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000\n# /// images per class. There are 60,000 training images and 10,000 test images.\n# ///\n# /// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).\npub struct MnistDataset {\n    dataset: MappedDataset,\n}\n```\n\nTo construct the `MnistDataset`, the data source must be parsed into the expected `MappedDataset`\ntype. Since both the train and test sets use the same file format, we can separate the functionality\nto load the `train()` and `test()` dataset.\n\n```rust, ignore\n\nimpl MnistDataset {\n    /// Creates a new train dataset.\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    /// Creates a new test dataset.\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    fn new(split: &str) -> Self {\n        // Download dataset\n        let root = MnistDataset::download(split);\n\n        // Parse data as vector of images bytes and vector of labels\n        let images: Vec<Vec<u8>> = MnistDataset::read_images(&root, split);\n        let labels: Vec<u8> = MnistDataset::read_labels(&root, split);\n\n        // Collect as vector of MnistItemRaw\n        let items: Vec<_> = images\n            .into_iter()\n            .zip(labels)\n            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })\n            .collect();\n\n        // Create the MapperDataset for InMemDataset<MnistItemRaw> to transform\n        // items (MnistItemRaw -> MnistItem)\n        let dataset = InMemDataset::new(items);\n        let dataset = MapperDataset::new(dataset, BytesToImage);\n\n        Self { dataset }\n    }\n\n#    /// Download the MNIST dataset files from the web.\n#    /// Panics if the download cannot be completed or the content of the file cannot be written to disk.\n#    fn download(split: &str) -> PathBuf {\n#        // Dataset files are stored in the burn-dataset cache directory\n#        let cache_dir = dirs::cache_dir()\n#            .expect(\"Could not get cache directory\")\n#            .join(\"burn-dataset\");\n#        let split_dir = cache_dir.join(\"mnist\").join(split);\n#\n#        if !split_dir.exists() {\n#            create_dir_all(&split_dir).expect(\"Failed to create base directory\");\n#        }\n#\n#        // Download split files\n#        match split {\n#            \"train\" => {\n#                MnistDataset::download_file(TRAIN_IMAGES, &split_dir);\n#                MnistDataset::download_file(TRAIN_LABELS, &split_dir);\n#            }\n#            \"test\" => {\n#                MnistDataset::download_file(TEST_IMAGES, &split_dir);\n#                MnistDataset::download_file(TEST_LABELS, &split_dir);\n#            }\n#            _ => panic!(\"Invalid split specified {}\", split),\n#        };\n#\n#        split_dir\n#    }\n#\n#    /// Download a file from the MNIST dataset URL to the destination directory.\n#    /// File download progress is reported with the help of a [progress bar](indicatif).\n#    fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {\n#        // Output file name\n#        let file_name = dest_dir.as_ref().join(name);\n#\n#        if !file_name.exists() {\n#            // Download gzip file\n#            let bytes = download_file_as_bytes(&format!(\"{URL}{name}.gz\"), name);\n#\n#            // Create file to write the downloaded content to\n#            let mut output_file = File::create(&file_name).unwrap();\n#\n#            // Decode gzip file content and write to disk\n#            let mut gz_buffer = GzDecoder::new(&bytes[..]);\n#            std::io::copy(&mut gz_buffer, &mut output_file).unwrap();\n#        }\n#\n#        file_name\n#    }\n#\n#    /// Read images at the provided path for the specified split.\n#    /// Each image is a vector of bytes.\n#    fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {\n#        let file_name = if split == \"train\" {\n#            TRAIN_IMAGES\n#        } else {\n#            TEST_IMAGES\n#        };\n#        let file_name = root.as_ref().join(file_name);\n#\n#        // Read number of images from 16-byte header metadata\n#        let mut f = File::open(file_name).unwrap();\n#        let mut buf = [0u8; 4];\n#        let _ = f.seek(SeekFrom::Start(4)).unwrap();\n#        f.read_exact(&mut buf)\n#            .expect(\"Should be able to read image file header\");\n#        let size = u32::from_be_bytes(buf);\n#\n#        let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];\n#        let _ = f.seek(SeekFrom::Start(16)).unwrap();\n#        f.read_exact(&mut buf_images)\n#            .expect(\"Should be able to read image file header\");\n#\n#        buf_images\n#            .chunks(WIDTH * HEIGHT)\n#            .map(|chunk| chunk.to_vec())\n#            .collect()\n#    }\n#\n#    /// Read labels at the provided path for the specified split.\n#    fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {\n#        let file_name = if split == \"train\" {\n#            TRAIN_LABELS\n#        } else {\n#            TEST_LABELS\n#        };\n#        let file_name = root.as_ref().join(file_name);\n#\n#        // Read number of labels from 8-byte header metadata\n#        let mut f = File::open(file_name).unwrap();\n#        let mut buf = [0u8; 4];\n#        let _ = f.seek(SeekFrom::Start(4)).unwrap();\n#        f.read_exact(&mut buf)\n#            .expect(\"Should be able to read label file header\");\n#        let size = u32::from_be_bytes(buf);\n#\n#        let mut buf_labels: Vec<u8> = vec![0u8; size as usize];\n#        let _ = f.seek(SeekFrom::Start(8)).unwrap();\n#        f.read_exact(&mut buf_labels)\n#            .expect(\"Should be able to read labels from file\");\n#\n#        buf_labels\n#    }\n}\n```\n\nSince the `MnistDataset` simply wraps a `MapperDataset` instance with `InMemDataset`, we can easily\nimplement the `Dataset` trait.\n\n```rust, ignore\nimpl Dataset<MnistItem> for MnistDataset {\n    fn get(&self, index: usize) -> Option<MnistItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n```\n\nThe only thing missing now is the `Batcher`, which we already went over\n[in the basic workflow guide](../basic-workflow/data.md). The `Batcher` takes a list of `MnistItem`\nretrieved by the dataloader as input and returns a batch of images as a 3D tensor along with their\ntargets.\n"
  },
  {
    "path": "burn-book/src/building-blocks/learner.md",
    "content": "# Learner\n\nThe [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates\nmultiple utilities for training deep learning models. The goal of the crate is to provide users with\na well-crafted and flexible training loop, so that projects do not have to write such components\nfrom the ground up. Most of the interactions with `burn-train` will be with the `SupervisedTraining`\nstruct, briefly presented in the previous [training section](../basic-workflow/training.md). This\nstruct enables you to configure the training loop, offering support for registering metrics,\nenabling logging, checkpointing states, using multiple devices, and so on.\n\nThere are still some assumptions in the current provided APIs, which may make them inappropriate for\nyour learning requirements. Indeed, they assume your model will learn from a training dataset and be\nvalidated against another dataset. This is the most common paradigm, allowing users to do both\nsupervised and unsupervised learning as well as fine-tuning. However, for more complex requirements,\ncreating a [custom training loop](../custom-training-loop.md) might be what you need.\n\n## Usage\n\nThe `SupervisedLearning` struct must be created with the training and validation dataloaders. It provides you with numerous options when it comes to configurations.\n\n| Configuration          | Description                                                                    |\n| ---------------------- | ------------------------------------------------------------------------------ |\n| Training Metric        | Register a training metric                                                     |\n| Validation Metric      | Register a validation metric                                                   |\n| Training Metric Plot   | Register a training metric with plotting (requires the metric to be numeric)   |\n| Validation Metric Plot | Register a validation metric with plotting (requires the metric to be numeric) |\n| Metric Logger          | Configure the metric loggers (default is saving them to files)                 |\n| Renderer               | Configure how to render metrics (default is CLI)                               |\n| Grad Accumulation      | Configure the number of steps before applying gradients                        |\n| File Checkpointer      | Configure how the model, optimizer and scheduler states are saved              |\n| Num Epochs             | Set the number of epochs                                                       |\n| Devices                | Set the devices to be used                                                     |\n| Checkpoint             | Restart training from a checkpoint                                             |\n| Application logging    | Configure the application logging installer (default is writing to `experiment.log`)                                   |\n| Training Strategy      | Use a custom training strategy, allowing you to use your own training loop with all the capabilities of the `SupervisedTraining` struct          |\n\nWhen the training is configured to your liking, you can then move forward to running the training. The\n`launch` method requires a learner object providing: the model, the optimizer and the learning rate scheduler. Note\nthat the latter can be a simple float if you want it to be constant during training.\n\nThe `launch` method will start the training and return the trained model once finished.\n\nAgain, please refer to the [training section](../basic-workflow/training.md) for a relevant code\nsnippet.\n\n## Artifacts\n\nWhen creating a `SupervisedTraining` instance, all the collected data will be saved under the directory provided as\nthe argument to the `new` method. Here is an example of the data layout for a model recorded using\nthe compressed message pack format, with the accuracy and loss metrics registered:\n\n```\n├── experiment.log\n├── checkpoint\n│   ├── model-1.mpk.gz\n│   ├── optim-1.mpk.gz\n│   └── scheduler-1.mpk.gz\n│   ├── model-2.mpk.gz\n│   ├── optim-2.mpk.gz\n│   └── scheduler-2.mpk.gz\n├── train\n│   ├── epoch-1\n│   │   ├── Accuracy.log\n│   │   └── Loss.log\n│   └── epoch-2\n│       ├── Accuracy.log\n│       └── Loss.log\n└── valid\n    ├── epoch-1\n    │   ├── Accuracy.log\n    │   └── Loss.log\n    └── epoch-2\n        ├── Accuracy.log\n        └── Loss.log\n```\n\nYou can choose to save or synchronize that local directory with a remote file system, if desired.\nThe file checkpointer is capable of automatically deleting old checkpoints according to a specified\nconfiguration.\n"
  },
  {
    "path": "burn-book/src/building-blocks/metric.md",
    "content": "# Metric\n\nWhen working with the learner, you have the option to record metrics that will be monitored\nthroughout the training process. We currently offer a restricted range of metrics.\n\n| Metric              | Description                                                                                 |\n| ------------------- | ------------------------------------------------------------------------------------------- |\n| Accuracy            | Calculate the accuracy in percentage                                                        |\n| TopKAccuracy        | Calculate the top-k accuracy in percentage                                                  |\n| Precision           | Calculate precision in percentage                                                           |\n| Recall              | Calculate recall in percentage                                                              |\n| FBetaScore          | Calculate F<sub>β </sub>score in percentage                                                 |\n| AUROC               | Calculate the area under curve of ROC in percentage                                         |\n| Loss                | Output the loss used for the backward pass                                                  |\n| CharErrorRate (CER) | Calculate Character Error Rate in percentage                                                |\n| WordErrorRate (WER) | Calculate Word Error Rate in percentage                                                     |\n| HammingScore        | Calculate hamming score (also known as multi-label or label-based accuracy) in percentage   |\n| Perplexity          | Calculate perplexity which is a measure of how well a probability model predicts samples    |\n| IterationSpeed      | Tracks the training iteration speed, measuring how many iterations are completed per second |\n| CPU Temperature     | Fetch the temperature of CPUs                                                               |\n| CPU Usage           | Fetch the CPU utilization                                                                   |\n| CPU Memory Usage    | Fetch the CPU RAM usage                                                                     |\n| Learning Rate       | Fetch the current learning rate for each optimizer step                                     |\n| CUDA                | Fetch general CUDA metrics such as utilization                                              |\n\n| Vision Metric | Description                                                                                          |\n| ------------- | ---------------------------------------------------------------------------------------------------- |\n| Dice          | Computes the Dice-Sorenson coefficient (DSC) for evaluating overlap between binary masks             |\n| DISTS         | Computes the Deep Image Structure and Texture Similarity (DISTS) metric for image quality assessment |\n| LPIPS         | Computes the Learned Perceptual Image Patch Similarity (LPIPS) for image quality assessment          |\n| MS-SSIM       | Computes the Multi-scale Structural Similarity index measure (MS-SSIM) for image quality assessment  |\n| PSNR          | Computes the Peak Signal-to-Noise Ratio (PSNR) for image quality assessment                          |\n| SSIM          | Computes the Structural Similarity index measure (SSIM) for image quality assessment                 |\n\n## Using Metrics with the Learner\n\nIn order to use a metric, the output of your training step must implement the `Adaptor` trait from \n`burn-train::metric` for each metric's corresponding input type. The `Adaptor` trait simply converts \nyour output struct into the input type the metric expects.\n\nBurn provides four built-in output structs that cover common tasks. Each one already implements \n`Adaptor` for a set of metrics, so in many cases you can use them directly without writing any \nadaptor code yourself.\n\n- `ClassificationOutput<B>`:\n    - Use case: Single-label classification\n    - Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 1, Int>`\n    - Adapted metrics: Accuracy, TopKAccuracy, Perplexity, Precision\\*, Recall\\*, FBetaScore\\*, AUROC\\*, Loss\n- `MultiLabelClassificationOutput<B>`:\n    - Use case: Multi-label classification\n    - Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2, Int>`\n    - Adapted metrics: HammingScore, Precision\\*, Recall\\*, FBetaScore\\*, Loss\n- `RegressionOutput<B>`:\n    - Use case: Regression tasks\n    - Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2>`\n    - Adapted metrics: Loss\n- `SequenceOutput<B>`:\n    - Use case: Sequence prediction\n    - Fields: `loss: Tensor<B, 1>`, `logits: Tensor<B, 3>`, `predictions: Option<Tensor<B, 2, Int>>`, `targets: Tensor<B, 2, Int>`\n    - Adapted metrics: Accuracy, TopKAccuracy, Perplexity, CER, WER, Loss\n\n\\* Precision, Recall, and FBetaScore all use `ConfusionStatsInput` as its input type so these three \nmetrics are automatically (implicitly) adapted since `ConfusionStatsInput` is adapted.\n\nIf your metric isn't already adapted for the appropriate output struct, you can implement `Adaptor` yourself. \nFor example, here is how `ClassificationOutput` adapts to `AccuracyInput`:\n\n```rust,ignore\nimpl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> AccuracyInput<B> {\n        AccuracyInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n```\n\nIf your task type is not covered by the built-in output structs, you can create an output struct for your data\nand then adapt your metric for the output struct:\n\n```rust,ignore\n#[derive(new)]\npub struct ClassificationOutput<B: Backend> {\n    /// The loss.\n    pub loss: Tensor<B, 1>,\n\n    /// The output.\n    pub output: Tensor<B, 2>,\n\n    /// The targets.\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> AccuracyInput<B> {\n        AccuracyInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n```\n\nYou can also open an issue on the [GitHub repository](https://github.com/tracel-ai/burn) when your task type is \nnot covered by the built-in output structs. However, since creating an output struct for your data is simple, \nit is recommended to try creating your own output struct first. \n\n# Custom Metric\n\nGenerating your own custom metrics is done by implementing the `Metric` trait.\n\n```rust , ignore\n\n/// Metric trait.\n///\n/// # Notes\n///\n/// Implementations should define their own input type only used by the metric.\n/// This is important since some conflict may happen when the model output is adapted for each\n/// metric's input type.\npub trait Metric: Send + Sync + Clone {\n    /// The input type of the metric.\n    type Input;\n\n    /// The parameterized name of the metric.\n    ///\n    /// This should be unique, so avoid using short generic names, prefer using the long name.\n    ///\n    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different\n    /// values of k), the name should be unique for each instance.\n    fn name(&self) -> MetricName;\n\n    /// Update the metric state and returns the current metric entry.\n    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;\n\n    /// Clear the metric state.\n    fn clear(&mut self);\n}\n```\n\nAs an example, let's see how the loss metric is implemented.\n\n```rust, ignore\n/// The loss metric.\n#[derive(Clone)]\npub struct LossMetric<B: Backend> {\n    name: Arc<String>,\n    state: NumericMetricState,\n    _b: B,\n}\n\n/// The [loss metric](LossMetric) input type.\n#[derive(new)]\npub struct LossInput<B: Backend> {\n    tensor: Tensor<B, 1>,\n}\n\nimpl<B: Backend> Default for LossMetric<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> LossMetric<B> {\n    /// Create the metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Loss\".to_string()),\n            state: NumericMetricState::default(),\n            _b: Default::default(),\n        }\n    }\n}\n\n\nimpl<B: Backend> Metric for LossMetric<B> {\n    type Input = LossInput<B>;\n\n    fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [batch_size] = loss.tensor.dims();\n        let loss = loss\n            .tensor\n            .clone()\n            .mean()\n            .into_data()\n            .iter::<f64>()\n            .next()\n            .unwrap();\n\n        self.state.update(\n            loss,\n            batch_size,\n            FormatOptions::new(self.name()).precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n```\n\nWhen the metric you are implementing is numeric in nature, you may want to also implement the\n`Numeric` trait. This will allow your metric to be plotted.\n\n```rust, ignore\nimpl<B: Backend> Numeric for LossMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n```\n"
  },
  {
    "path": "burn-book/src/building-blocks/module.md",
    "content": "# Module\n\nThe `Module` derive allows you to create your own neural network modules, similar to PyTorch. The\nderive function only generates the necessary methods to essentially act as a parameter container for\nyour type, it makes no assumptions about how the forward pass is declared.\n\n```rust, ignore\nuse burn::module::Module;\nuse burn::tensor::backend::Backend;\n\n#[derive(Module, Debug)]\npub struct PositionWiseFeedForward<B: Backend> {\n    linear_inner: Linear<B>,\n    linear_outer: Linear<B>,\n    dropout: Dropout,\n    gelu: Gelu,\n}\n\nimpl<B: Backend> PositionWiseFeedForward<B> {\n    /// Normal method added to a struct.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let x = self.linear_inner.forward(input);\n        let x = self.gelu.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.linear_outer.forward(x)\n    }\n}\n```\n\nNote that all fields declared in the struct must also implement the `Module` trait.\n\n## Tensor\n\nIf you want to create your own module that contains tensors, and not just other modules defined with\nthe `Module` derive, you need to be careful to achieve the behavior you want.\n\n- `Param<Tensor<B, D>>`: If you want the tensor to be included as a parameter of your modules, you\n  need to wrap the tensor in a `Param` struct. This will create an ID that will be used to identify\n  this parameter. This is essential when performing module optimization and when saving states such\n  as optimizer and module checkpoints. Note that a module's record only contains parameters.\n\n- `Param<Tensor<B, D>>.set_require_grad(false)`: If you want the tensor to be included as a\n  parameter of your modules, and therefore saved with the module's weights, but you don't want it to\n  be updated by the optimizer.\n\n- `Tensor<B, D>`: If you want the tensor to act as a constant that can be recreated when\n  instantiating a module. This can be useful when generating sinusoidal embeddings, for example.\n\n## Methods\n\nThese methods are available for all modules.\n\n| Burn API                                | PyTorch Equivalent                       |\n| --------------------------------------- | ---------------------------------------- |\n| `module.devices()`                      | N/A                                      |\n| `module.fork(device)`                   | Similar to `module.to(device).detach()`  |\n| `module.to_device(device)`              | `module.to(device)`                      |\n| `module.no_grad()`                      | `module.require_grad_(False)`            |\n| `module.num_params()`                   | N/A                                      |\n| `module.visit(visitor)`                 | N/A                                      |\n| `module.map(mapper)`                    | N/A                                      |\n| `module.into_record()`                  | Similar to `state_dict`                  |\n| `module.load_record(record)`            | Similar to `load_state_dict(state_dict)` |\n| `module.save_file(file_path, recorder)` | N/A                                      |\n| `module.load_file(file_path, recorder)` | N/A                                      |\n\nSimilar to the backend trait, there is also the `AutodiffModule` trait to signify a module with\nautodiff support.\n\n| Burn API         | PyTorch Equivalent |\n| ---------------- | ------------------ |\n| `module.valid()` | `module.eval()`    |\n\n## Visitor & Mapper\n\nAs mentioned earlier, modules primarily function as parameter containers. Therefore, we naturally\noffer several ways to perform functions on each parameter. This is distinct from PyTorch, where\nextending module functionalities is not as straightforward.\n\nThe `map` and `visitor` methods are quite similar but serve different purposes. Mapping is used for\npotentially mutable operations where each parameter of a module can be updated to a new value. In\nBurn, optimizers are essentially just sophisticated module mappers. Visitors, on the other hand, are\nused when you don't intend to modify the module but need to retrieve specific information from it,\nsuch as the number of parameters or a list of devices in use.\n\nYou can implement your own mapper or visitor by implementing these simple traits:\n\n```rust, ignore\n/// Module visitor trait.\npub trait ModuleVisitor<B: Backend> {\n    /// Visit a float tensor in the module.\n    fn visit_float<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D>);\n    /// Visit an int tensor in the module.\n    fn visit_int<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Int>);\n    /// Visit a bool tensor in the module.\n    fn visit_bool<const D: usize>(&mut self, id: ParamId, tensor: &Tensor<B, D, Bool>);\n}\n\n/// Module mapper trait.\npub trait ModuleMapper<B: Backend> {\n    /// Map a float tensor in the module.\n    fn map_float<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;\n    /// Map an int tensor in the module.\n    fn map_int<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D, Int>) -> Tensor<B, D, Int>;\n    /// Map a bool tensor in the module.\n    fn map_bool<const D: usize>(&mut self, id: ParamId, tensor: Tensor<B, D, Bool>) -> Tensor<B, D, Bool>;\n}\n```\n\nNote that the trait doesn't require all methods to be implemented as they are already defined to\nperform no operation. If you're only interested in float tensors (like the majority of use cases),\nthen you can simply implement `map_float` or `visit_float`.\n\nFor example, the `ModuleMapper` trait could be implemented to clamp all parameters into the range\n`[min, max]`.\n\n```rust, ignore\n/// Clamp parameters into the range `[min, max]`.\npub struct Clamp {\n    /// Lower-bound of the range.\n    pub min: f32,\n    /// Upper-bound of the range.\n    pub max: f32,\n}\n\n// Clamp all floating-point parameter tensors between `[min, max]`.\nimpl<B: Backend> ModuleMapper<B> for Clamp {\n    fn map_float<const D: usize>(\n        &mut self,\n        _id: burn::module::ParamId,\n        tensor: burn::prelude::Tensor<B, D>,\n    ) -> burn::prelude::Tensor<B, D> {\n        tensor.clamp(self.min, self.max)\n    }\n}\n\n// Clamp module mapper into the range `[-0.5, 0.5]`\nlet mut clamp = Clamp {\n    min: -0.5,\n    max: 0.5,\n};\nlet model = model.map(&mut clamp);\n```\n\nIf you want to use this during training to constrain your model parameters, make sure that the\nparameter tensors are still tracked for autodiff. This can be done with a simple adjustment to the\nimplementation.\n\n```rust, ignore\nimpl<B: AutodiffBackend> ModuleMapper<B> for Clamp {\n    fn map_float<const D: usize>(\n        &mut self,\n        _id: burn::module::ParamId,\n        tensor: burn::prelude::Tensor<B, D>,\n    ) -> burn::prelude::Tensor<B, D> {\n        let is_require_grad = tensor.is_require_grad();\n\n        let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max));\n\n        if is_require_grad {\n            tensor = tensor.require_grad();\n        }\n\n        tensor\n    }\n}\n```\n\n## Module Display\n\nBurn provides a simple way to display the structure of a module and its configuration at a glance.\nYou can print the module to see its structure, which is useful for debugging and tracking changes\nacross different versions of a module. (See the print output of the\n[Basic Workflow Model](../basic-workflow/model.md) example.)\n\nTo customize the display of a module, you can implement the `ModuleDisplay` trait for your module.\nThis will change the default display settings for the module and its children. Note that\n`ModuleDisplay` is automatically implemented for all modules, but you can override it to customize\nthe display by annotating the module with `#[module(custom_display)]`.\n\n```rust\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct PositionWiseFeedForward<B: Backend> {\n    linear_inner: Linear<B>,\n    linear_outer: Linear<B>,\n    dropout: Dropout,\n    gelu: Gelu,\n}\n\nimpl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {\n    /// Custom settings for the display of the module.\n    /// If `None` is returned, the default settings will be used.\n    fn custom_settings(&self) -> Option<burn::module::DisplaySettings> {\n        DisplaySettings::new()\n            // Will show all attributes (default is false)\n            .with_show_all_attributes(false)\n            // Will show each attribute on a new line (default is true)\n            .with_new_line_after_attribute(true)\n            // Will show the number of parameters (default is true)\n            .with_show_num_parameters(true)\n            // Will indent by 2 spaces (default is 2)\n            .with_indentation_size(2)\n            // Will show the parameter ID (default is false)\n            .with_show_param_id(false)\n            // Convenience method to wrap settings in Some()\n            .optional()\n    }\n\n    /// Custom content to be displayed.\n    /// If `None` is returned, the default content will be used\n    /// (all attributes of the module)\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"linear_inner\", &self.linear_inner)\n            .add(\"linear_outer\", &self.linear_outer)\n            .add(\"anything\", \"anything_else\")\n            .optional()\n    }\n}\n```\n\n## Built-in Modules\n\nBurn comes with built-in modules that you can use to build your own modules.\n\n### General\n\n| Burn API          | PyTorch Equivalent                            |\n| ----------------- | --------------------------------------------- |\n| `BatchNorm`       | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc.       |\n| `Celu`            | `nn.CELU`                                     |\n| `Dropout`         | `nn.Dropout`                                  |\n| `Elu`             | `nn.ELU`                                      |\n| `Embedding`       | `nn.Embedding`                                |\n| `GaussianNoise`   | _No direct equivalent_                        |\n| `Gelu`            | `nn.Gelu`                                     |\n| `Glu`             | `nn.Glu`                                      |\n| `GroupNorm`       | `nn.GroupNorm`                                |\n| `HardShrink`      | `nn.Hardshrink`                               |\n| `HardSigmoid`     | `nn.Hardsigmoid`                              |\n| `HardSwish`       | `nn.Hardswish`                                |\n| `InstanceNorm`    | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |\n| `LayerNorm`       | `nn.LayerNorm`                                |\n| `LeakyRelu`       | `nn.LeakyReLU`                                |\n| `Linear`          | `nn.Linear`                                   |\n| `Prelu`           | `nn.PReLu`                                    |\n| `Relu`            | `nn.ReLU`                                     |\n| `Selu`            | `nn.SELU`                                     |\n| `Sigmoid`         | `nn.Sigmoid`                                  |\n| `Softplus`        | `nn.Softplus`                                 |\n| `SoftShrink`      | `nn.Softshrink`                               |\n| `Softsign`        | `nn.Softsign`                                 |\n| `Shrink`          | _No direct equivalent_                        |\n| `RmsNorm`         | _No direct equivalent_                        |\n| `SwiGlu`          | _No direct equivalent_                        |\n| `Tanh`            | `nn.Tanh`                                     |\n| `ThresholdedRelu` | _No direct equivalent_                        |\n\n### Convolutions\n\n| Burn API          | PyTorch Equivalent             |\n| ----------------- | ------------------------------ |\n| `Conv1d`          | `nn.Conv1d`                    |\n| `Conv2d`          | `nn.Conv2d`                    |\n| `Conv3d`          | `nn.Conv3d`                    |\n| `ConvTranspose1d` | `nn.ConvTranspose1d`           |\n| `ConvTranspose2d` | `nn.ConvTranspose2d`           |\n| `ConvTranspose3d` | `nn.ConvTranspose3d`           |\n| `DeformConv2d`    | `torchvision.ops.DeformConv2d` |\n\n### Pooling\n\n| Burn API            | PyTorch Equivalent     |\n| ------------------- | ---------------------- |\n| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |\n| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |\n| `AvgPool1d`         | `nn.AvgPool1d`         |\n| `AvgPool2d`         | `nn.AvgPool2d`         |\n| `MaxPool1d`         | `nn.MaxPool1d`         |\n| `MaxPool2d`         | `nn.MaxPool2d`         |\n\n### Interpolation\n\n| Burn API        | PyTorch Equivalent |\n| --------------- | ------------------ |\n| `Interpolate1d` | `nn.Upsample`     |\n| `Interpolate2d` | `nn.Upsample`     |\n\nInterpolation modules resize tensors using one of the available `InterpolateMode` options:\n\n| Mode      | Description                                              |\n| --------- | -------------------------------------------------------- |\n| `Nearest` | Nearest-neighbor interpolation                           |\n| `Linear`  | Linear interpolation (bilinear for 2D)                   |\n| `Cubic`   | Cubic interpolation (bicubic for 2D)                     |\n| `Lanczos` | Lanczos3 resampling (6-tap sinc-based filter, a=3)       |\n\nConfiguration is done via `Interpolate1dConfig` / `Interpolate2dConfig` with these options:\n\n| Option          | Type                                     | Default   | Description                                              |\n| --------------- |------------------------------------------| --------- | -------------------------------------------------------- |\n| `output_size`   | `Option<usize>` / `Option<[usize; 2]>`   | `None`    | Target output size (takes precedence over scale_factor)  |\n| `scale_factor`  | `Option<f32>` / `Option<[f32; 2]>`       | `None`    | Scale factor for resizing                                |\n| `mode`          | `InterpolateMode`                        | `Nearest` | Interpolation algorithm                                  |\n| `align_corners` | `bool`                                   | `true`    | Align input/output corner pixels                         |\n\n### RNNs\n\n| Burn API         | PyTorch Equivalent     |\n| ---------------- | ---------------------- |\n| `Gru`/`BiGru`    | `nn.GRU`               |\n| `Lstm`/`BiLstm`  | `nn.LSTM`              |\n| `GateController` | _No direct equivalent_ |\n\n### Transformer\n\n| Burn API             | PyTorch Equivalent      |\n| -------------------- | ----------------------- |\n| `MultiHeadAttention` | `nn.MultiheadAttention` |\n| `TransformerDecoder` | `nn.TransformerDecoder` |\n| `TransformerEncoder` | `nn.TransformerEncoder` |\n| `PositionalEncoding` | _No direct equivalent_  |\n| `RotaryEncoding`     | _No direct equivalent_  |\n\n### Loss\n\n| Burn API                 | PyTorch Equivalent       |\n| ------------------------ | ------------------------ |\n| `BinaryCrossEntropyLoss` | `nn.BCELoss`             |\n| `CosineEmbeddingLoss`    | `nn.CosineEmbeddingLoss` |\n| `CrossEntropyLoss`       | `nn.CrossEntropyLoss`    |\n| `CTCLoss`                | `nn.CTCLoss`             |\n| `GramMatrixLoss`         | _No direct equivalent_   |\n| `HuberLoss`              | `nn.HuberLoss`           |\n| `KLDivLoss`              | `nn.KLDivLoss`           |\n| `LpLoss`                 | _No direct equivalent_   |\n| `MseLoss`                | `nn.MSELoss`             |\n| `PoissonNllLoss`         | `nn.PoissonNLLLoss`      |\n| `RNNTLoss`               | `torchaudio.functional.rnnt_loss` |\n| `SmoothL1Loss`           | `nn.SmoothL1Loss`        |\n"
  },
  {
    "path": "burn-book/src/building-blocks/record.md",
    "content": "# Record\n\nRecords are how states are saved with Burn. Compared to most other frameworks, Burn has its own\nadvanced saving mechanism that allows interoperability between backends with minimal possible\nruntime errors. There are multiple reasons why Burn decided to create its own saving formats.\n\nFirst, Rust has [serde](https://serde.rs/), which is an extremely well-developed serialization and\ndeserialization library that also powers the `safetensors` format developed by Hugging Face. If used\nproperly, all the validations are done when deserializing, which removes the need to write\nvalidation code. Since modules in Burn are created with configurations, they can't implement\nserialization and deserialization. That's why the record system was created: allowing you to save\nthe state of modules independently of the backend in use extremely fast while still giving you all\nthe flexibility possible to include any non-serializable field within your module.\n\n**Why not use safetensors?**\n\n[`safetensors`](https://github.com/huggingface/safetensors) uses serde with the JSON file format and\nonly supports serializing and deserializing tensors. The record system in Burn gives you the\npossibility to serialize any type, which is very useful for optimizers that save their state, but\nalso for any non-standard, cutting-edge modeling needs you may have. Additionally, the record system\nperforms automatic precision conversion by using Rust types, making it more reliable with fewer\nmanual manipulations.\n\nIt is important to note that the `safetensors` format uses the word _safe_ to distinguish itself\nfrom Pickle, which is vulnerable to Python code injection. On our end, the simple fact that we use\nRust already ensures that no code injection is possible. If your storage mechanism doesn't handle\ndata corruption, you might prefer a recorder that performs checksum validation (i.e., any recorder\nwith Gzip compression).\n\n## Recorder\n\nRecorders are independent of the backend and serialize records with precision and a format. Note\nthat the format can also be in-memory, allowing you to save the records directly into bytes.\n\n| Recorder               | Format                   | Compression |\n| ---------------------- | ------------------------ | ----------- |\n| DefaultFileRecorder    | File - Named MessagePack | None        |\n| NamedMpkFileRecorder   | File - Named MessagePack | None        |\n| NamedMpkGzFileRecorder | File - Named MessagePack | Gzip        |\n| BinFileRecorder        | File - Binary            | None        |\n| BinGzFileRecorder      | File - Binary            | Gzip        |\n| JsonGzFileRecorder     | File - Json              | Gzip        |\n| PrettyJsonFileRecorder | File - Pretty Json       | Gzip        |\n| BinBytesRecorder       | In Memory - Binary       | None        |\n\nEach recorder supports precision settings decoupled from the precision used for training or\ninference. These settings allow you to define the floating-point and integer types that will be used\nfor serialization and deserialization.\n\n| Setting                   | Float Precision | Integer Precision |\n| ------------------------- | --------------- | ----------------- |\n| `DoublePrecisionSettings` | `f64`           | `i64`             |\n| `FullPrecisionSettings`   | `f32`           | `i32`             |\n| `HalfPrecisionSettings`   | `f16`           | `i16`             |\n\nNote that when loading a record into a module, the type conversion is automatically handled, so you\ncan't encounter errors. The only crucial aspect is using the same recorder for both serialization\nand deserialization; otherwise, you will encounter loading errors.\n\n**Which recorder should you use?**\n\n- If you want fast serialization and deserialization, choose a recorder without compression. The one\n  with the lowest file size without compression is the binary format; otherwise, the named\n  MessagePack could be used.\n- If you want to save models for storage, you can use compression, but avoid using the binary\n  format, as it may not be backward compatible.\n- If you want to debug your model's weights, you can use the pretty JSON format.\n- If you want to deploy with `no-std`, use the in-memory binary format and include the bytes with\n  the compiled code.\n\nFor examples on saving and loading records, take a look at\n[Saving and Loading Models](../saving-and-loading.md).\n"
  },
  {
    "path": "burn-book/src/building-blocks/tensor.md",
    "content": "# Tensor\n\nAs previously explained in the [model section](../basic-workflow/model.md), the Tensor struct has 3\ngeneric arguments: the backend B, the dimensionality D, and the data type.\n\n```rust, ignore\nTensor<B, D>           // Float tensor (default)\nTensor<B, D, Float>    // Explicit float tensor\nTensor<B, D, Int>      // Int tensor\nTensor<B, D, Bool>     // Bool tensor\n```\n\nNote that the specific element types used for `Float`, `Int`, and `Bool` tensors are defined by\nbackend implementations.\n\nBurn Tensors are defined by the number of dimensions D in its declaration as opposed to its shape.\nThe actual shape of the tensor is inferred from its initialization. For example, a Tensor of size\n(5,) is initialized as below:\n\n```rust, ignore\nlet floats = [1.0, 2.0, 3.0, 4.0, 5.0];\n\n// Get the default device\nlet device = Default::default();\n\n// correct: Tensor is 1-Dimensional with 5 elements\nlet tensor_1 = Tensor::<Backend, 1>::from_floats(floats, &device);\n\n// incorrect: let tensor_1 = Tensor::<Backend, 5>::from_floats(floats, &device);\n// this will lead to an error and is for creating a 5-D tensor\n```\n\n### Initialization\n\nBurn Tensors are primarily initialized using the `from_data()` method which takes the `TensorData`\nstruct as input. The `TensorData` struct has two public fields: `shape` and `dtype`. The `value`,\nnow stored as bytes, is private but can be accessed via any of the following methods: `as_slice`,\n`as_mut_slice`, `to_vec` and `iter`. To retrieve the data from a tensor, the method `.to_data()`\nshould be employed when intending to reuse the tensor afterward. Alternatively, `.into_data()` is\nrecommended for one-time use. Let's look at a couple of examples for initializing a tensor from\ndifferent inputs.\n\n```rust, ignore\n\n// Initialization from a given Backend (Wgpu)\nlet tensor_1 = Tensor::<Wgpu, 1>::from_data([1.0, 2.0, 3.0], &device);\n\n// Initialization from a generic Backend\nlet tensor_2 = Tensor::<Backend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);\n\n// Initialization using from_floats (Recommended for f32 ElementType)\n// Will be converted to TensorData internally.\nlet tensor_3 = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);\n\n// Initialization of Int Tensor from array slices\nlet arr: [i32; 6] = [1, 2, 3, 4, 5, 6];\nlet tensor_4 = Tensor::<Backend, 1, Int>::from_data(TensorData::from(&arr[0..3]), &device);\n\n// Initialization from a custom type\n\nstruct BodyMetrics {\n    age: i8,\n    height: i16,\n    weight: f32\n}\n\nlet bmi = BodyMetrics{\n        age: 25,\n        height: 180,\n        weight: 80.0\n    };\nlet data  = TensorData::from([bmi.age as f32, bmi.height as f32, bmi.weight]);\nlet tensor_5 = Tensor::<Backend, 1>::from_data(data, &device);\n\n```\n\n## Ownership and Cloning\n\nAlmost all Burn operations take ownership of the input tensors. Therefore, reusing a tensor multiple\ntimes will necessitate cloning it. Let's look at an example to understand the ownership rules and\ncloning better. Suppose we want to do a simple min-max normalization of an input tensor.\n\n```rust, ignore\nlet input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);\nlet min = input.min();\nlet max = input.max();\nlet input = (input - min).div(max - min);\n```\n\nWith PyTorch tensors, the above code would work as expected. However, Rust's strict ownership rules\nwill give an error and prevent using the input tensor after the first `.min()` operation. The\nownership of the input tensor is transferred to the variable `min` and the input tensor is no longer\navailable for further operations. Burn Tensors like most complex primitives do not implement the\n`Copy` trait and therefore have to be cloned explicitly. Now let's rewrite a working example of\ndoing min-max normalization with cloning.\n\n```rust, ignore\nlet input = Tensor::<Wgpu, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device);\nlet min = input.clone().min();\nlet max = input.clone().max();\nlet input = (input.clone() - min.clone()).div(max - min);\nprintln!(\"{}\", input.to_data());// Success: [0.0, 0.33333334, 0.6666667, 1.0]\n\n// Notice that max, min have been moved in last operation so\n// the below print will give an error.\n// If we want to use them for further operations,\n// they will need to be cloned in similar fashion.\n// println!(\"{:?}\", min.to_data());\n```\n\nWe don't need to be worried about memory overhead because with cloning, the tensor's buffer isn't\ncopied, and only a reference to it is increased. This makes it possible to determine exactly how\nmany times a tensor is used, which is very convenient for reusing tensor buffers or even fusing\noperations into a single kernel ([burn-fusion](https://burn.dev/docs/burn_fusion/index.htmls)). For\nthat reason, we don't provide explicit inplace operations. If a tensor is used only one time,\ninplace operations will always be used when available.\n\n## Tensor Operations\n\nNormally with PyTorch, explicit inplace operations aren't supported during the backward pass, making\nthem useful only for data preprocessing or inference-only model implementations. With Burn, you can\nfocus more on _what_ the model should do, rather than on _how_ to do it. We take the responsibility\nof making your code run as fast as possible during training as well as inference. The same\nprinciples apply to broadcasting; all operations support broadcasting unless specified otherwise.\n\nHere, we provide a list of all supported operations along with their PyTorch equivalents. Note that\nfor the sake of simplicity, we ignore type signatures. For more details, refer to the\n[full documentation](https://docs.rs/burn/latest/burn/tensor/struct.Tensor.html).\n\n### Basic Operations\n\nThose operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.\n\n| Burn                                                 | PyTorch Equivalent                                                        |\n| ---------------------------------------------------- | ------------------------------------------------------------------------- |\n| `Tensor::cat(tensors, dim)`                          | `torch.cat(tensors, dim)`                                                 |\n| `Tensor::empty(shape, options)`                      | `torch.empty(shape, device=device, dtype=dtype)`                          |\n| `Tensor::from_primitive(primitive)`                  | N/A                                                                       |\n| `Tensor::stack(tensors, dim)`                        | `torch.stack(tensors, dim)`                                               |\n| `tensor.all()`                                       | `tensor.all()`                                                            |\n| `tensor.all_dim(dim)`                                | `tensor.all(dim)`                                                         |\n| `tensor.any()`                                       | `tensor.any()`                                                            |\n| `tensor.any_dim(dim)`                                | `tensor.any(dim)`                                                         |\n| `tensor.chunk(num_chunks, dim)`                      | `tensor.chunk(num_chunks, dim)`                                           |\n| `tensor.split(split_size, dim)`                      | `tensor.split(split_size, dim)`                                           |\n| `tensor.split_with_sizes(split_sizes, dim)`          | `tensor.split([split_sizes], dim)`                                        |\n| `tensor.device()`                                    | `tensor.device`                                                           |\n| `tensor.dtype()`                                     | `tensor.dtype`                                                            |\n| `tensor.dims()`                                      | `tensor.size()`                                                           |\n| `tensor.equal(other)`                                | `x == y`                                                                  |\n| `tensor.equal_elem(other)`                           | `tensor.eq(other)`                                                        |\n| `tensor.expand(shape)`                               | `tensor.expand(shape)`                                                    |\n| `tensor.flatten(start_dim, end_dim)`                 | `tensor.flatten(start_dim, end_dim)`                                      |\n| `tensor.flip(axes)`                                  | `tensor.flip(axes)`                                                       |\n| `tensor.full_like(fill_value)`                       | `torch.full_like(tensor, fill_value)`                                     |\n| `tensor.gather(dim, indices)`                        | `torch.gather(tensor, dim, indices)`                                      |\n| `tensor.into_data()`                                 | N/A                                                                       |\n| `tensor.into_primitive()`                            | N/A                                                                       |\n| `tensor.into_scalar()`                               | `tensor.item()`                                                           |\n| `tensor.mask_fill(mask, value)`                      | `tensor.masked_fill(mask, value)`                                         |\n| `tensor.mask_where(mask, value_tensor)`              | `torch.where(mask, value_tensor, tensor)`                                 |\n| `tensor.movedim(src, dst)`                           | `tensor.movedim(src, dst)`                                                |\n| `tensor.narrow(dim, start, length)`                  | `tensor.narrow(dim, start, length)`                                       |\n| `tensor.not_equal(other)`                            | `x != y`                                                                  |\n| `tensor.not_equal_elem(scalar)`                      | `tensor.ne(scalar)`                                                       |\n| `tensor.ones_like()`                                 | `torch.ones_like(tensor)`                                                 |\n| `tensor.permute(axes)`                               | `tensor.permute(axes)`                                                    |\n| `tensor.repeat_dim(dim, times)`                      | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |\n| `tensor.repeat(sizes)`                               | `tensor.repeat(sizes)`                                                    |\n| `tensor.reshape(shape)`                              | `tensor.view(shape)`                                                      |\n| `tensor.roll(shifts, dims)`                          | `tensor.roll(shifts, dims)`                                               |\n| `tensor.roll_dim(shift, dim)`                        | `tensor.roll([shift], [dim])`                                             |\n| `tensor.scatter(dim, indices, values, update)`       | `tensor.scatter_add(dim, indices, values)`                                |\n| `tensor.select(dim, indices)`                        | `tensor.index_select(dim, indices)`                                       |\n| `tensor.select_assign(dim, indices, values, update)` | `tensor.index_add(dim, indices, values)`                                  |\n| `tensor.shape()`                                     | `tensor.shape`                                                            |\n| `tensor.slice(slices)`                               | `tensor[(*ranges,)]`                                                      |\n| `tensor.slice_assign(slices, values)`                | `tensor[(*ranges,)] = values`                                             |\n| `tensor.slice_fill(slices, value)`                   | `tensor[(*ranges,)] = value`                                              |\n| `tensor.slice_dim(dim, slice)`                       | N/A                                                                       |\n| `tensor.squeeze()`                                   | `tensor.squeeze()`                                                        |\n| `tensor.squeeze_dim(dim)`                            | `tensor.squeeze(dim)`                                                     |\n| `tensor.squeeze_dims(dims)`                          | `tensor.squeeze(dims)` where `dims` is a tuple of ints                    |\n| `tensor.swap_dims(dim1, dim2)`                       | `tensor.transpose(dim1, dim2)`                                            |\n| `tensor.take(dim, indices)`                          | `numpy.take(tensor, indices, dim)`                                        |\n| `tensor.to_data()`                                   | N/A                                                                       |\n| `tensor.to_device(device)`                           | `tensor.to(device)`                                                       |\n| `tensor.transpose()`                                 | `tensor.T`                                                                |\n| `tensor.t()`                                         | `tensor.T`                                                                |\n| `tensor.unsqueeze()`                                 | N/A                                                                       |\n| `tensor.unsqueeze_dim(dim)`                          | `tensor.unsqueeze(dim)`                                                   |\n| `tensor.unsqueeze_dims(dims)`                        | N/A                                                                       |\n| `tensor.zeros_like()`                                | `torch.zeros_like(tensor)`                                                |\n| `Tensor::full(shape, fill_value, options)`           | `torch.full(shape, fill_value, device=device, dtype=dtype)`               |\n| `Tensor::ones(shape, options)`                       | `torch.ones(shape, device=device, dtype=dtype)`                           |\n| `Tensor::zeros(shape, options)`                      | `torch.zeros(shape, device=device, dtype=dtype)`                          |\n\n### Numeric Operations\n\nThose operations are available for numeric tensor kinds: `Float` and `Int`.\n\n| Burn                                                            | PyTorch Equivalent                            |\n| --------------------------------------------------------------- | --------------------------------------------- |\n| `tensor.abs()`                                                  | `torch.abs(tensor)`                           |\n| `tensor.add(other)` or `tensor + other`                         | `tensor + other`                              |\n| `tensor.add_scalar(scalar)` or `tensor + scalar`                | `tensor + scalar`                             |\n| `tensor.all_close(other, atol, rtol)`                           | `torch.allclose(tensor, other, atol, rtol)`   |\n| `tensor.argmax(dim)`                                            | `tensor.argmax(dim)`                          |\n| `tensor.argmin(dim)`                                            | `tensor.argmin(dim)`                          |\n| `tensor.argsort(dim)`                                           | `tensor.argsort(dim)`                         |\n| `tensor.argsort_descending(dim)`                                | `tensor.argsort(dim, descending=True)`        |\n| `tensor.bool()`                                                 | `tensor.bool()`                               |\n| `tensor.clamp(min, max)`                                        | `torch.clamp(tensor, min=min, max=max)`       |\n| `tensor.clamp_max(max)`                                         | `torch.clamp(tensor, max=max)`                |\n| `tensor.clamp_min(min)`                                         | `torch.clamp(tensor, min=min)`                |\n| `tensor.cumsum(dim)`                                            | `tensor.cumsum(dim)`                          |\n| `tensor.cumprod(dim)`                                           | `tensor.cumprod(dim)`                         |\n| `tensor.cummin(dim)`                                            | `tensor.cummin(dim)`                          |\n| `tensor.cummax(dim)`                                            | `tensor.cummax(dim)`                          |\n| `tensor.div(other)` or `tensor / other`                         | `tensor / other`                              |\n| `tensor.div_scalar(scalar)` or `tensor / scalar`                | `tensor / scalar`                             |\n| `tensor.dot(other)`                                             | `torch.dot(tensor, other)`                    |\n| `tensor.greater(other)`                                         | `tensor.gt(other)`                            |\n| `tensor.greater_elem(scalar)`                                   | `tensor.gt(scalar)`                           |\n| `tensor.greater_equal(other)`                                   | `tensor.ge(other)`                            |\n| `tensor.greater_equal_elem(scalar)`                             | `tensor.ge(scalar)`                           |\n| `tensor.lower(other)`                                           | `tensor.lt(other)`                            |\n| `tensor.lower_elem(scalar)`                                     | `tensor.lt(scalar)`                           |\n| `tensor.lower_equal(other)`                                     | `tensor.le(other)`                            |\n| `tensor.lower_equal_elem(scalar)`                               | `tensor.le(scalar)`                           |\n| `tensor.max()`                                                  | `tensor.max()`                                |\n| `tensor.max_abs()`                                              | `tensor.abs().max()`                          |\n| `tensor.max_abs_dim(dim)`                                       | `tensor.abs().max(dim, keepdim=True)`         |\n| `tensor.max_abs_dims(dims)`                                     | `tensor.abs().max(dims, keepdim=True)`        |\n| `tensor.max_dim(dim)`                                           | `tensor.max(dim, keepdim=True)`               |\n| `tensor.max_dims(dims)`                                         | `tensor.max(dims, keepdim=True)`              |\n| `tensor.max_dim_with_indices(dim)`                              | N/A                                           |\n| `tensor.max_pair(other)`                                        | `torch.Tensor.max(a,b)`                       |\n| `tensor.mean()`                                                 | `tensor.mean()`                               |\n| `tensor.mean_dim(dim)`                                          | `tensor.mean(dim, keepdim=True)`              |\n| `tensor.mean_dims(dims)`                                        | `tensor.mean(dims, keepdim=True)`             |\n| `tensor.min()`                                                  | `tensor.min()`                                |\n| `tensor.min_dim(dim)`                                           | `tensor.min(dim, keepdim=True)`               |\n| `tensor.min_dims(dims)`                                         | `tensor.min(dims, keepdim=True)`              |\n| `tensor.min_dim_with_indices(dim)`                              | N/A                                           |\n| `tensor.min_pair(other)`                                        | `torch.Tensor.min(a,b)`                       |\n| `tensor.mul(other)` or `tensor * other`                         | `tensor * other`                              |\n| `tensor.mul_scalar(scalar)` or `tensor * scalar`                | `tensor * scalar`                             |\n| `tensor.neg()` or `-tensor`                                     | `-tensor`                                     |\n| `tensor.one_hot(num_classes)`                                   | `torch.nn.functional.one_hot`                 |\n| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)`   | N/A                                           |\n| `tensor.pad(pads, mode)`                                        | `torch.nn.functional.pad(tensor, pads, mode)` |\n| `tensor.powf(other)` or `tensor.powi(intother)`                 | `tensor.pow(other)`                           |\n| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)`                          |\n| `tensor.prod()`                                                 | `tensor.prod()`                               |\n| `tensor.prod_dim(dim)`                                          | `tensor.prod(dim, keepdim=True)`              |\n| `tensor.prod_dims(dims)`                                        | `tensor.prod(dims, keepdim=True)`             |\n| `tensor.rem(other)` or `tensor % other`                         | `tensor % other`                              |\n| `tensor.sign()`                                                 | `tensor.sign()`                               |\n| `tensor.sort(dim)`                                              | `tensor.sort(dim).values`                     |\n| `tensor.sort_descending(dim)`                                   | `tensor.sort(dim, descending=True).values`    |\n| `tensor.sort_descending_with_indices(dim)`                      | `tensor.sort(dim, descending=True)`           |\n| `tensor.sort_with_indices(dim)`                                 | `tensor.sort(dim)`                            |\n| `tensor.sub(other)` or `tensor - other`                         | `tensor - other`                              |\n| `tensor.sub_scalar(scalar)` or `tensor - scalar`                | `tensor - scalar`                             |\n| `tensor.sum()`                                                  | `tensor.sum()`                                |\n| `tensor.sum_dim(dim)`                                           | `tensor.sum(dim, keepdim=True)`               |\n| `tensor.sum_dims(dims)`                                         | `tensor.sum(dims, keepdim=True)`              |\n| `tensor.sum_dims_squeeze(dims)`                                 | `tensor.sum(dims, keepdim=False)`             |\n| `tensor.topk(k, dim)`                                           | `tensor.topk(k, dim).values`                  |\n| `tensor.topk_with_indices(k, dim)`                              | `tensor.topk(k, dim)`                         |\n| `tensor.tril(diagonal)`                                         | `torch.tril(tensor, diagonal)`                |\n| `tensor.triu(diagonal)`                                         | `torch.triu(tensor, diagonal)`                |\n| `tensor.unfold(dim, size, step)`                                | `tensor.unfold(dim, size, step)`              |\n| `Tensor::eye(size, device)`                                     | `torch.eye(size, device=device)`              |\n| `scalar - tensor`                                               | `scalar - tensor`                             |\n\n### Float Operations\n\nThose operations are only available for `Float` tensors.\n\n| Burn API                                     | PyTorch Equivalent                         |\n| -------------------------------------------- | ------------------------------------------ |\n| `tensor.acos()`                              | `tensor.acos()`                            |\n| `tensor.acosh()`                             | `tensor.acosh()`                           |\n| `tensor.asin()`                              | `tensor.asin()`                            |\n| `tensor.asinh()`                             | `tensor.asinh()`                           |\n| `tensor.atan()`                              | `tensor.atan()`                            |\n| `tensor.atanh()`                             | `tensor.atanh()`                           |\n| `tensor.atan2(other_tensor)`                 | `tensor.atan2(other_tensor)`               |\n| `tensor.cast(dtype)`                         | `tensor.to(dtype)`                         |\n| `tensor.ceil()`                              | `tensor.ceil()`                            |\n| `tensor.contains_nan()`                      | N/A                                        |\n| `tensor.cos()`                               | `tensor.cos()`                             |\n| `tensor.cosh()`                              | `tensor.cosh()`                            |\n| `tensor.cross(other)`                        | `torch.cross(tensor, other)`               |\n| `tensor.deg2rad()`                           | `torch.deg2rad()`                          |\n| `tensor.erf()`                               | `tensor.erf()`                             |\n| `tensor.exp()`                               | `tensor.exp()`                             |\n| `tensor.floor()`                             | `tensor.floor()`                           |\n| `tensor.fmod(other)`                         | `tensor.fmod(other)`                       |\n| `tensor.fmod_scalar(scalar)`                 | `tensor.fmod(scalar)`                      |\n| `tensor.from_floats(floats, device)`         | N/A                                        |\n| `tensor.int()`                               | Similar to `tensor.to(torch.long)`         |\n| `tensor.is_close(other, atol, rtol)`         | `torch.isclose(tensor, other, atol, rtol)` |\n| `tensor.is_finite()`                         | `torch.isfinite(tensor)`                   |\n| `tensor.is_inf()`                            | `torch.isinf(tensor)`                      |\n| `tensor.is_nan()`                            | `torch.isnan(tensor)`                      |\n| `tensor.log()`                               | `tensor.log()`                             |\n| `tensor.log1p()`                             | `tensor.log1p()`                           |\n| `tensor.matmul(other)`                       | `tensor.matmul(other)`                     |\n| `tensor.rad2deg()`                           | `torch.rad2deg()`                          |\n| `tensor.random(shape, distribution, device)` | N/A                                        |\n| `tensor.random_like(distribution)`           | `torch.rand_like()` only uniform           |\n| `tensor.recip()` or `1.0 / tensor`           | `tensor.reciprocal()` or `1.0 / tensor`    |\n| `tensor.round()`                             | `tensor.round()`                           |\n| `tensor.sin()`                               | `tensor.sin()`                             |\n| `tensor.sinh()`                              | `tensor.sinh()`                            |\n| `tensor.square()`                            | `tensor.square()`                          |\n| `tensor.sqrt()`                              | `tensor.sqrt()`                            |\n| `tensor.tan()`                               | `tensor.tan()`                             |\n| `tensor.tanh()`                              | `tensor.tanh()`                            |\n| `tensor.trunc()`                             | `tensor.trunc()`                           |\n| `tensor.var(dim)`                            | `tensor.var(dim)`                          |\n| `tensor.var_bias(dim)`                       | N/A                                        |\n| `tensor.var_mean(dim)`                       | N/A                                        |\n| `tensor.var_mean_bias(dim)`                  | N/A                                        |\n| `tensor.median(dim)`                         | `tensor.median(dim)`                       |\n| `tensor.median_with_indices(dim)`            | `tensor.median(dim)`                       |\n\n### Int Operations\n\nThose operations are only available for `Int` tensors.\n\n| Burn API                                         | PyTorch Equivalent                                      |\n| ------------------------------------------------ | ------------------------------------------------------- |\n| `Tensor::arange(5..10, device)`                  | `tensor.arange(start=5, end=10, device=device)`         |\n| `Tensor::arange_step(5..10, 2, device)`          | `tensor.arange(start=5, end=10, step=2, device=device)` |\n| `tensor.bitwise_and(other)`                      | `torch.bitwise_and(tensor, other)`                      |\n| `tensor.bitwise_and_scalar(scalar)`              | `torch.bitwise_and(tensor, scalar)`                     |\n| `tensor.bitwise_not()`                           | `torch.bitwise_not(tensor)`                             |\n| `tensor.bitwise_left_shift(other)`               | `torch.bitwise_left_shift(tensor, other)`               |\n| `tensor.bitwise_left_shift_scalar(scalar)`       | `torch.bitwise_left_shift(tensor, scalar)`              |\n| `tensor.bitwise_right_shift(other)`              | `torch.bitwise_right_shift(tensor, other)`              |\n| `tensor.bitwise_right_shift_scalar(scalar)`      | `torch.bitwise_right_shift(tensor, scalar)`             |\n| `tensor.bitwise_or(other)`                       | `torch.bitwise_or(tensor, other)`                       |\n| `tensor.bitwise_or_scalar(scalar)`               | `torch.bitwise_or(tensor, scalar)`                      |\n| `tensor.bitwise_xor(other)`                      | `torch.bitwise_xor(tensor, other)`                      |\n| `tensor.bitwise_xor_scalar(scalar)`              | `torch.bitwise_xor(tensor, scalar)`                     |\n| `tensor.float()`                                 | `tensor.to(torch.float)`                                |\n| `tensor.from_ints(ints)`                         | N/A                                                     |\n| `tensor.cartesian_grid(shape, device)`           | N/A                                                     |\n\n### Bool Operations\n\nThose operations are only available for `Bool` tensors.\n\n| Burn API                             | PyTorch Equivalent              |\n| ------------------------------------ | ------------------------------- |\n| `Tensor::diag_mask(shape, diagonal)` | N/A                             |\n| `Tensor::tril_mask(shape, diagonal)` | N/A                             |\n| `Tensor::triu_mask(shape, diagonal)` | N/A                             |\n| `tensor.argwhere()`                  | `tensor.argwhere()`             |\n| `tensor.bool_and()`                  | `tensor.logical_and()`          |\n| `tensor.bool_not()`                  | `tensor.logical_not()`          |\n| `tensor.bool_or()`                   | `tensor.logical_or()`           |\n| `tensor.bool_xor()`                  | `tensor.logical_xor()`          |\n| `tensor.float()`                     | `tensor.to(torch.float)`        |\n| `tensor.int()`                       | `tensor.to(torch.long)`         |\n| `tensor.nonzero()`                   | `tensor.nonzero(as_tuple=True)` |\n\n### Quantization Operations\n\nThose operations are only available for `Float` tensors on backends that implement quantization\nstrategies.\n\n| Burn API                           | PyTorch Equivalent |\n| ---------------------------------- | ------------------ |\n| `tensor.quantize(scheme, qparams)` | N/A                |\n| `tensor.dequantize()`              | N/A                |\n\n## Activation Functions\n\n| Burn API                                         | PyTorch Equivalent                                 |\n| ------------------------------------------------ | -------------------------------------------------- |\n| `activation::celu(tensor, alpha)`                | `nn.functional.celu(tensor, alpha)`                |\n| `activation::elu(tensor, alpha)`                 | `nn.functional.elu(tensor, alpha)`                 |\n| `activation::gelu(tensor)`                       | `nn.functional.gelu(tensor)`                       |\n| `activation::glu(tensor, dim)`                   | `nn.functional.glu(tensor, dim)`                   |\n| `activation::hard_shrink(tensor, lambda)`        | `nn.functional.hardshrink(tensor, lambd)`          |\n| `activation::hard_sigmoid(tensor, alpha, beta)`  | `nn.functional.hardsigmoid(tensor)`                |\n| `activation::hard_swish(tensor)`                 | `nn.functional.hardswish(tensor)`                  |\n| `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` |\n| `activation::log_sigmoid(tensor)`                | `nn.functional.log_sigmoid(tensor)`                |\n| `activation::log_softmax(tensor, dim)`           | `nn.functional.log_softmax(tensor, dim)`           |\n| `activation::mish(tensor)`                       | `nn.functional.mish(tensor)`                       |\n| `activation::prelu(tensor,alpha)`                | `nn.functional.prelu(tensor,weight)`               |\n| `activation::quiet_softmax(tensor, dim)`         | `nn.functional.quiet_softmax(tensor, dim)`         |\n| `activation::relu(tensor)`                       | `nn.functional.relu(tensor)`                       |\n| `activation::shrink(tensor, lambda, bias)`       | _No direct equivalent_                             |\n| `activation::soft_shrink(tensor, lambda)`        | `nn.functional.softshrink(tensor, lambd)`          |\n| `activation::sigmoid(tensor)`                    | `nn.functional.sigmoid(tensor)`                    |\n| `activation::selu(tensor)`                       | `nn.functional.selu(tensor)`                       |\n| `activation::silu(tensor)`                       | `nn.functional.silu(tensor)`                       |\n| `activation::softmax(tensor, dim)`               | `nn.functional.softmax(tensor, dim)`               |\n| `activation::softmin(tensor, dim)`               | `nn.functional.softmin(tensor, dim)`               |\n| `activation::softplus(tensor, beta)`             | `nn.functional.softplus(tensor, beta)`             |\n| `activation::softsign(tensor)`                   | `nn.functional.softsign(tensor)`                   |\n| `activation::tanh(tensor)`                       | `nn.functional.tanh(tensor)`                       |\n| `activation::thresholded_relu(tensor, alpha)`    | `nn.functional.threshold(tensor, alpha, 0)`        |\n\n## Grid Functions\n\n| Burn API                                            | PyTorch Equivalent                                                   |\n| --------------------------------------------------- | -------------------------------------------------------------------- |\n| `grid::affine_grid_2d(transformation_tensor, dims)` | `nn.functional.affine_grid(theta_tensor, size, align_corners)` |\n| `grid::meshgrid(tensors, GridIndexing::Matrix)`     | `torch.meshgrid(tensors, indexing=\"ij\")`                             |\n| `grid::meshgrid(tensors, GridIndexing::Cartesian)`  | `torch.meshgrid(tensors, indexing=\"xy\")`                             |\n| `grid::meshgrid_stack(tensors, index_pos)`          | _No direct equivalent_                                               |\n\n## Linalg Functions\n\n| Burn API                                           | PyTorch Equivalent                                  |\n| -------------------------------------------------- | --------------------------------------------------- |\n| `linalg::cosine_similarity(x1, x2, dim, eps)`      | `nn.functional.cosine_similarity(x1, x2, dim, eps)` |\n| `linalg::diag(tensor)`                             | `torch.diag(tensor)`                                |\n| `linalg::l0_norm(tensor, dim)`                     | _No direct equivalent_                              |\n| `linalg::l1_norm(tensor, dim)`                     | _No direct equivalent_                              |\n| `linalg::l2_norm(tensor, dim)`                     | _No direct equivalent_                              |\n| `linalg::lp_norm(tensor, p, dim)`                  | _No direct equivalent_                              |\n| `linalg::lu_decomposition(tensor)`                 | `torch.linalg.lu(tensor)`                           |\n| `linalg::matvec(matrix, vector)`                   | `torch.matmul(matrix, vector)` / `@` operator       |\n| `linalg::max_abs_norm(tensor, dim)`                | _No direct equivalent_                              |\n| `linalg::min_abs_norm(tensor, dim)`                | _No direct equivalent_                              |\n| `linalg::outer(lhs, rhs)`                          | `torch.outer(lhs, rhs)` / `einsum(\"bi,bj->bij\", …)` |\n| `linalg::outer_dim(lhs, rhs, dim)`                 | _No direct equivalent_                              |\n| `linalg::trace(tensor)`                            | `torch.trace(tensor)`                               |\n| `linalg::vector_norm(tensor, p, dim)`              | `torch.linalg.vector_norm(tensor, p, dim)`          |\n| `linalg::vector_normalize(tensor, norm, dim, eps)` | `nn.functional.normalize(tensor, p, dim, eps)`      |\n\n## Displaying Tensor Details\n\nBurn provides flexible options for displaying tensor information, allowing you to control the level\nof detail and formatting to suit your needs.\n\n### Basic Display\n\nTo display a detailed view of a tensor, you can simply use Rust's `println!` or `format!` macros:\n\n```rust, ignore\nlet tensor = Tensor::<Backend, 2>::full([2, 3], 0.123456789, &Default::default());\nprintln!(\"{}\", tensor);\n```\n\nThis will output:\n\n```\nTensor {\n  data:\n[[0.12345679, 0.12345679, 0.12345679],\n [0.12345679, 0.12345679, 0.12345679]],\n  shape:  [2, 3],\n  device:  Cpu,\n  backend:  \"ndarray\",\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}\n```\n\n### Controlling Precision\n\nYou can control the number of decimal places displayed using Rust's formatting syntax:\n\n```rust\nprintln!(\"{:.2}\", tensor);\n```\n\nOutput:\n\n```\nTensor {\n  data:\n[[0.12, 0.12, 0.12],\n [0.12, 0.12, 0.12]],\n  shape:  [2, 3],\n  device:  Cpu,\n  backend:  \"ndarray\",\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}\n```\n\n### Global Print Options\n\nFor more fine-grained control over tensor printing, Burn provides a `PrintOptions` struct and a\n`set_print_options` function:\n\n```rust, ignore\nuse burn::tensor::{set_print_options, PrintOptions};\n\nlet print_options = PrintOptions {\n    precision: Some(2),\n    ..Default::default()\n};\n\nset_print_options(print_options);\n```\n\nOptions:\n\n- `precision`: Number of decimal places for floating-point numbers (default: None)\n- `threshold`: Maximum number of elements to display before summarizing (default: 1000)\n- `edge_items`: Number of items to show at the beginning and end of each dimension when summarizing\n  (default: 3)\n\n  ### Checking Tensor Closeness\n\n  Burn provides a utility function `check_closeness` to compare two tensors and assess their\n  similarity. This function is particularly useful for debugging and validating tensor operations,\n  especially when working with floating-point arithmetic where small numerical differences can\n  accumulate. It's also valuable when comparing model outputs during the process of importing models\n  from other frameworks, helping to ensure that the imported model produces results consistent with\n  the original.\n\n  Here's an example of how to use `check_closeness`:\n\n  ```rust, ignore\n  use burn::tensor::{check_closeness, Tensor};\n  type B = burn::backend::NdArray;\n\n  let device = Default::default();\n  let tensor1 = Tensor::<B, 1>::from_floats(\n      [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],\n      &device,\n  );\n  let tensor2 = Tensor::<B, 1>::from_floats(\n      [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],\n      &device,\n  );\n\n  check_closeness(&tensor1, &tensor2);\n  ```\n\n  The `check_closeness` function compares the two input tensors element-wise, checking their\n  absolute differences against a range of epsilon values. It then prints a detailed report showing\n  the percentage of elements that are within each tolerance level.\n\n  The output provides a breakdown for different epsilon values, allowing you to assess the closeness\n  of the tensors at various precision levels. This is particularly helpful when dealing with\n  operations that may introduce small numerical discrepancies.\n\n  The function uses color-coded output to highlight the results:\n\n  - Green [PASS]: All elements are within the specified tolerance.\n  - Yellow [WARN]: Most elements (90% or more) are within tolerance.\n  - Red [FAIL]: Significant differences are detected.\n\n  This utility can be invaluable when implementing or debugging tensor operations, especially those\n  involving complex mathematical computations or when porting algorithms from other frameworks. It's\n  also an essential tool when verifying the accuracy of imported models, ensuring that the Burn\n  implementation produces results that closely match those of the original model.\n"
  },
  {
    "path": "burn-book/src/custom-training-loop.md",
    "content": "# Custom Training Loops\n\nEven though Burn comes with a project dedicated to simplifying training, it doesn't mean that you\nhave to use it. Sometimes you may have special needs for your training, and it might be faster to\njust reimplement the training loop yourself. Also, you may just prefer implementing your own\ntraining loop instead of using a pre-built one in general.\n\nBurn's got you covered!\n\nWe will start from the same example shown in the [basic workflow](./basic-workflow) section, but\nwithout using the `Learner` struct.\n\n```rust, ignore\n#[derive(Config, Debug)]\npub struct MnistTrainingConfig {\n    #[config(default = 10)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1e-4)]\n    pub lr: f64,\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n}\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    // Create the configuration.\n    let config_model = ModelConfig::new(10, 1024);\n    let config_optimizer = AdamConfig::new();\n    let config = MnistTrainingConfig::new(config_model, config_optimizer);\n\n    B::seed(&device, config.seed);\n\n    // Create the model and optimizer.\n    let mut model = config.model.init::<B>(&device);\n    let mut optim = config.optimizer.init();\n\n    // Create the batcher.\n    let batcher = MnistBatcher::default();\n\n    // Create the dataloaders.\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::test());\n\n    ...\n}\n```\n\nAs seen with the previous example, setting up the configurations and the dataloader hasn't changed.\nNow, let's move forward and write our own training loop:\n\n```rust, ignore\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    ...\n\n    // Iterate over our training and validation loop for X epochs.\n    for epoch in 1..config.num_epochs + 1 {\n        // Implement our training loop.\n        for (iteration, batch) in dataloader_train.iter().enumerate() {\n            let output = model.forward(batch.images);\n            let loss = CrossEntropyLoss::new(None, &output.device())\n                .forward(output.clone(), batch.targets.clone());\n            let accuracy = accuracy(output, batch.targets);\n\n            println!(\n                \"[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %\",\n                epoch,\n                iteration,\n                loss.clone().into_scalar(),\n                accuracy,\n            );\n\n            // Gradients for the current backward pass\n            let grads = loss.backward();\n            // Gradients linked to each parameter of the model.\n            let grads = GradientsParams::from_grads(grads, &model);\n            // Update the model using the optimizer.\n            model = optim.step(config.lr, model, grads);\n        }\n\n        // Get the model without autodiff.\n        let model_valid = model.valid();\n\n        // Implement our validation loop.\n        for (iteration, batch) in dataloader_test.iter().enumerate() {\n            let output = model_valid.forward(batch.images);\n            let loss = CrossEntropyLoss::new(None, &output.device())\n                .forward(output.clone(), batch.targets.clone());\n            let accuracy = accuracy(output, batch.targets);\n\n            println!(\n                \"[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}\",\n                epoch,\n                iteration,\n                loss.clone().into_scalar(),\n                accuracy,\n            );\n        }\n    }\n}\n```\n\nIn the previous code snippet, we can observe that the loop starts from epoch `1` and goes up to\n`num_epochs`. Within each epoch, we iterate over the training dataloader. During this process, we\nexecute the forward pass, which is necessary for computing both the loss and accuracy. To maintain\nsimplicity, we print the results to stdout.\n\nUpon obtaining the loss, we can invoke the `backward()` function, which returns the gradients\nspecific to each variable. It's important to note that we need to map these gradients to their\ncorresponding parameters using the `GradientsParams` type. This step is essential because you might\nrun multiple different autodiff graphs and accumulate gradients for each parameter id.\n\nFinally, we can perform the optimization step using the learning rate, the model, and the computed\ngradients. It's worth mentioning that, unlike PyTorch, there's no need to register the gradients\nwith the optimizer, nor do you have to call `zero_grad`. The gradients are automatically consumed\nduring the optimization step. If you're interested in gradient accumulation, you can easily achieve\nthis by using the `GradientsAccumulator`.\n\n```rust, ignore\nlet mut accumulator = GradientsAccumulator::new();\nlet grads = model.backward();\nlet grads = GradientsParams::from_grads(grads, &model);\naccumulator.accumulate(&model, grads); ...\nlet grads = accumulator.grads(); // Pop the accumulated gradients.\n```\n\nNote that after each epoch, we include a validation loop to assess our model's performance on\npreviously unseen data. To disable gradient tracking during this validation step, we can invoke\n`model.valid()`, which provides a model on the inner backend without autodiff capabilities. It's\nimportant to emphasize that we've declared our validation batcher to be on the inner backend,\nspecifically `MnistBatcher<B::InnerBackend>`; not using `model.valid()` will result in a compilation\nerror.\n\nYou can find the code above available as an\n[example](https://github.com/tracel-ai/burn/tree/main/examples/custom-training-loop) for you to\ntest.\n\n## Multiple optimizers\n\nIt's common practice to set different learning rates, optimizer parameters, or use different optimizers entirely, for different parts\nof a model. In Burn, each `GradientParams` can contain only a subset of gradients to actually apply with an optimizer.\nThis allows you to flexibly mix and match optimizers!\n\n```rust,ignore\n// Start with calculating all gradients\nlet grads = loss.backward();\n\n// Now split the gradients into various parts.\nlet grads_conv1 = GradientParams::from_module(&mut grads, &model.conv1);\nlet grads_conv2 = GradientParams::from_module(&mut grads, &model.conv2);\n\n// You can step the model with these gradients, using different learning\n// rates for each param. You could also use an entirely different optimizer here!\nmodel = optim.step(config.lr * 2.0, model, grads_conv1);\nmodel = optim.step(config.lr * 4.0, model, grads_conv2);\n\n// For even more granular control you can split off individual parameter\n// eg. a linear bias usually needs a smaller learning rate.\nif let Some(bias) == model.linear1.bias {\n    let grads_bias = GradientParams::from_params(&mut grads, &model.linear1, &[bias.id]);\n    model = optim.step(config.lr * 0.1, model, grads_bias);\n}\n\n// Note that above calls remove gradients, so we can just get all \"remaining\" gradients.\nlet grads = GradientsParams::from_grads(grads, &model);\nmodel = optim.step(config.lr, model, grads);\n```\n\n## Custom Type\n\nThe explanations above demonstrate how to create a basic training loop. However, you may find it\nbeneficial to organize your program using intermediary types. There are various ways to do this, but\nit requires getting comfortable with generics.\n\nIf you wish to group the optimizer and the model into the same structure, you have several options.\nIt's important to note that the optimizer trait depends on both the `AutodiffModule` trait and the\n`AutodiffBackend` trait, while the module only depends on the `AutodiffBackend` trait.\n\nHere's a closer look at how you can create your types:\n\n**Create a struct that is generic over the backend and the optimizer, with a predefined model.**\n\n```rust, ignore\nstruct Learner<B, O>\nwhere\n    B: AutodiffBackend,\n{\n    model: Model<B>,\n    optim: O,\n}\n```\n\nThis is quite straightforward. You can be generic over the backend since it's used with the concrete\ntype `Model` in this case.\n\n**Create a struct that is generic over the model and the optimizer.**\n\n```rust, ignore\nstruct Learner<M, O> {\n    model: M,\n    optim: O,\n}\n```\n\nThis option is a quite intuitive way to declare the struct. You don't need to write type constraints\nwith a `where` statement when defining a struct; you can wait until you implement the actual\nfunction. However, with this struct, you may encounter some issues when trying to implement code\nblocks to your struct.\n\n```rust, ignore\nimpl<B, M, O> Learner<M, O>\nwhere\n    B: AutodiffBackend,\n    M: AutodiffModule<B>,\n    O: Optimizer<M, B>,\n{\n    pub fn step(&mut self, _batch: MnistBatch<B>) {\n        //\n    }\n}\n```\n\nThis will result in the following compilation error:\n\n```console\n1. the type parameter `B` is not constrained by the impl trait, self type, or predicates\n   unconstrained type parameter [E0207]\n```\n\nTo resolve this issue, you have two options. The first one is to make your function generic over\nthe backend and add your trait constraint within its definition:\n\n```rust, ignore\n#[allow(dead_code)]\nimpl<M, O> Learner2<M, O> {\n    pub fn step<B: AutodiffBackend>(&mut self, _batch: MnistBatch<B>)\n    where\n        B: AutodiffBackend,\n        M: AutodiffModule<B>,\n        O: Optimizer<M, B>,\n    {\n        //\n    }\n}\n```\n\nHowever, some people may prefer to have the constraints on the implementation block itself. In that\ncase, you can make your struct generic over the backend using `PhantomData<B>`.\n\n**Create a struct that is generic over the backend, the model, and the optimizer.**\n\n```rust, ignore\nstruct Learner3<B, M, O> {\n    model: M,\n    optim: O,\n    _b: PhantomData<B>,\n}\n```\n\nYou might wonder why `PhantomData` is required. Each generic argument must be used as a field when\ndeclaring a struct. When you don't need the generic argument, you can use `PhantomData` to mark it\nas a zero sized type.\n\nThese are just some suggestions on how to define your own types, but you are free to use any pattern\nthat you prefer.\n"
  },
  {
    "path": "burn-book/src/distributed-computing.md",
    "content": "# Distributed Computing\n"
  },
  {
    "path": "burn-book/src/examples.md",
    "content": "# Examples\n\nIn the [next chapter](./basic-workflow) you'll have the opportunity to implement the whole Burn\n`guide` example yourself in a step by step manner.\n\nMany additional Burn examples are available in the\n[examples](https://github.com/tracel-ai/burn/tree/main/examples) directory. Burn examples are\norganized as library crates with one or more examples that are executable binaries. An example can\nthen be executed using the following cargo command line in the root of the Burn repository:\n\n```bash\ncargo run --example <example name>\n```\n\nTo learn more about crates and examples, read the Rust section below.\n\n<details>\n<summary><strong>🦀 About Rust crates</strong></summary>\n\nEach Burn example is a **package** which are subdirectories of the `examples` directory. A package\nis composed of one or more **crates**.\n\nA package is a bundle of one or more crates that provides a set of functionality. A package contains\na `Cargo.toml` file that describes how to build those crates.\n\nA crate is a compilation unit in Rust. It could be a single file, but it is often easier to split up\ncrates into multiple **modules**.\n\nA module lets us organize code within a crate for readability and easy reuse. Modules also allow us\nto control the _privacy_ of items. For instance the `pub(crate)` keyword is employed to make a\nmodule publicly available inside the crate. In the snippet below there are four modules declared,\ntwo of them are public and visible to the users of the crates, one of them is public inside the\ncrate only and crate users cannot see it, at last one is private when there is no keyword. These\nmodules can be single files or a directory with a `mod.rs` file inside.\n\n```rust, ignore\npub mod data;\npub mod inference;\npub(crate) mod model;\nmod training;\n```\n\nA crate can come in one of two forms: a **binary crate** or a **library crate**. When compiling a\ncrate, the compiler first looks in the crate root file (`src/lib.rs` for a library crate and\n`src/main.rs` for a binary crate). Any module declared in the crate root file will be inserted in\nthe crate for compilation.\n\nAll Burn examples are library crates and they can contain one or more executable examples that uses\nthe library. We even have some Burn examples that uses the library crate of other examples.\n\nThe examples are unique files under the `examples` directory. Each file produces an executable file\nwith the same name. Each example can then be executed with `cargo run --example <executable name>`.\n\nBelow is a file tree of a typical Burn example package:\n\n```\nexamples/burn-example\n├── Cargo.toml\n├── examples\n│   ├── example1.rs      ---> compiled to example1 binary\n│   ├── example2.rs      ---> compiled to example2 binary\n│   └── ...\n└── src\n    ├── lib.rs           ---> this is the root file for a library\n    ├── module1.rs\n    ├── module2.rs\n    └── ...\n```\n\n</details><br>\n\nThe following additional examples are currently available if you want to check them out:\n\n| Example                                                                                                   | Description                                                                                                                                                                                  |\n| :-------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| [Custom CSV Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-csv-dataset)             | Implements a dataset to parse CSV data for a regression task.                                                                                                                                |\n| [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression)                      | Trains a simple MLP on the California Housing dataset to predict the median house value for a district.                                                                                      |\n| [Custom Image Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-image-dataset)         | Trains a simple CNN on custom image dataset following a simple folder structure.                                                                                                             |\n| [Custom Renderer](https://github.com/tracel-ai/burn/tree/main/examples/custom-renderer)                   | Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress.                                                                                              |\n| [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly.                                                                                                                      |\n| [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web)        | An interactive MNIST inference demo in the browser. The demo is available [online](https://burn.dev/demo/).                                                                                  |\n| [MNIST Training](https://github.com/tracel-ai/burn/tree/main/examples/mnist)                              | Demonstrates how to train a custom [`Module`](./building-blocks/module.md) (MLP) with the [`Learner`](./building-blocks/learner.md) configured to log metrics and keep training checkpoints. |\n| [ONNX Import Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference)         | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn.                                                                                                 |\n| [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/import-model-weights)          | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn.                                                                                               |\n| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification)           | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample.                                             |\n| [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation)                   | Trains a text generation transformer model on the DbPedia dataset.                                                                                                                           |\n| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan)                        | Trains a WGAN model to generate new handwritten digits based on MNIST.                                                                                                                       |\n\nFor more information on each example, see their respective `README.md` file. Be sure to check out\nthe [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date\nlist.\n\n<div class=\"warning\">\n\nNote that some examples use the\n[`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download the\ndatasets required in the examples. This is a Python library, which means that you will need to\ninstall Python before running these examples. This requirement will be clearly indicated in the\nexample's README when applicable.\n\n</div>\n"
  },
  {
    "path": "burn-book/src/getting-started.md",
    "content": "# Getting Started\n\nBurn is a deep learning framework in the Rust programming language. Therefore, it goes without\nsaying that one must understand the basic notions of Rust. Reading the first chapters of the\n[Rust Book](https://doc.rust-lang.org/book/) is recommended, but don't worry if you're just starting\nout. We'll try to provide as much context and reference to external resources when required. Just\nlook out for the **🦀 Rust Note** indicators.\n\n## Installing Rust\n\nFor installation instructions, please refer to the\n[installation page](https://doc.rust-lang.org/book/ch01-01-installation.html). It explains in\ndetails the most convenient way for you to install Rust on your computer, which is the very first\nthing to do to start using Burn.\n\n## Creating a Burn application\n\nOnce Rust is correctly installed, create a new Rust application by using Rust's build system and\npackage manager Cargo. It is automatically installed with Rust.\n\n<details>\n<summary><strong>🦀 Cargo Cheat Sheet</strong></summary>\n\n[Cargo](https://doc.rust-lang.org/cargo/) is a very useful tool to manage Rust projects because it\nhandles a lot of tasks. More precisely, it is used to compile your code, download the\nlibraries/packages your code depends on, and build said libraries.\n\nBelow is a quick cheat sheet of the main `cargo` commands you might use throughout this guide.\n\n| Command             | Description                                                                                  |\n| ------------------- | -------------------------------------------------------------------------------------------- |\n| `cargo new` _path_  | Create a new Cargo package in the given directory.                                           |\n| `cargo add` _crate_ | Add dependencies to the Cargo.toml manifest file.                                            |\n| `cargo build`       | Compile the local package and all of its dependencies (in debug mode, use `-r` for release). |\n| `cargo check`       | Check the local package for compilation errors (much faster).                                |\n| `cargo run`         | Run the local package binary.                                                                |\n\nFor more information, check out\n[Hello, Cargo!](https://doc.rust-lang.org/book/ch01-03-hello-cargo.html) in the Rust Book.\n\n</details><br>\n\nIn the directory of your choice, run the following:\n\n```console\ncargo new my_burn_app\n```\n\nThis will initialize the `my_burn_app` project directory with a `Cargo.toml` file and a `src`\ndirectory with an auto-generated `main.rs` file inside. Head inside the directory to check:\n\n```console\ncd my_burn_app\n```\n\nThen, add Burn as a dependency:\n\n```console\ncargo add burn --features wgpu\n```\n\nFinally, compile the local package by executing the following:\n\n```console\ncargo build\n```\n\nThat's it, you're ready to start! You have a project configured with Burn and the WGPU backend,\nwhich allows to execute low-level operations on any platform using the GPU.\n\n<div class=\"warning\">\n\nWhen using one of the `wgpu` backends, you may encounter compilation errors related to recursive\ntype evaluation. This is due to complex type nesting within the `wgpu` dependency chain.\n\nTo resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file:\n\n```rust\n#![recursion_limit = \"256\"]\n```\n\nThe default recursion limit (128) is often just below the required depth (typically 130-150) due to\ndeeply nested associated types and trait bounds.\n\n</div>\n\n## Writing a code snippet\n\nThe `src/main.rs` was automatically generated by Cargo, so let's replace its content with the\nfollowing:\n\n```rust, ignore\nuse burn::tensor::Tensor;\nuse burn::backend::Wgpu;\n\n// Type alias for the backend to use.\ntype Backend = Wgpu;\n\nfn main() {\n    let device = Default::default();\n    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first\n    let tensor_1 = Tensor::<Backend, 2>::from_data([[2., 3.], [4., 5.]], &device);\n    let tensor_2 = Tensor::<Backend, 2>::ones_like(&tensor_1);\n\n    // Print the element-wise addition (done with the WGPU backend) of the two tensors.\n    println!(\"{}\", tensor_1 + tensor_2);\n}\n```\n\n<details>\n<summary><strong>🦀 Use Declarations</strong></summary>\n\nTo bring any of the Burn module or item into scope, a `use` declaration is added.\n\nIn the example above, we wanted bring the `Tensor` struct and `Wgpu` backend into scope with the\nfollowing:\n\n```rust, ignore\nuse burn::tensor::Tensor;\nuse burn::backend::Wgpu;\n```\n\nThis is pretty self-explanatory in this case. But, the same declaration could be written as a\nshortcut to simultaneously binding of multiple paths with a common prefix:\n\n```rust, ignore\nuse burn::{tensor::Tensor, backend::Wgpu};\n```\n\nIn this example, the common prefix is pretty short and there are only two items to bind locally.\nTherefore, the first usage with two `use` declarations might be preferred. But know that both\nexamples are valid. For more details on the `use` keyword, take a look at\n[this section](https://doc.rust-lang.org/book/ch07-04-bringing-paths-into-scope-with-the-use-keyword.html)\nof the Rust Book or the\n[Rust reference](https://doc.rust-lang.org/reference/items/use-declarations.html).\n\n</details><br>\n\n<details>\n<summary><strong>🦀 Generic Data Types</strong></summary>\n\nIf you're new to Rust, you're probably wondering why we had to use `Tensor::<Backend, 2>::...`.\nThat's because the `Tensor` struct is [generic](https://doc.rust-lang.org/book/ch10-01-syntax.html)\nover multiple concrete data types. More specifically, a `Tensor` can be defined using three generic\nparameters: the backend, the number of dimensions (rank) and the data type (defaults to `Float`).\nHere, we only specify the backend and number of dimensions since a `Float` tensor is used by\ndefault. For more details on the `Tensor` struct, take a look at\n[this section](./building-blocks/tensor.md).\n\nMost of the time when generics are involved, the compiler can infer the generic parameters\nautomatically. In this case, the compiler needs a little help. This can usually be done in one of\ntwo ways: providing a type annotation or binding the generic parameter via the _turbofish_ `::<>`\nsyntax. In the example above we used the so-called _turbofish_ syntax, but we could have used type\nannotations instead like this:\n\n```rust, ignore\nlet tensor_1: Tensor<Backend, 2> = Tensor::from_data([[2., 3.], [4., 5.]]);\nlet tensor_2 = Tensor::ones_like(&tensor_1);\n```\n\nYou probably noticed that we provided a type annotation for the first tensor only and yet this\nexample still works. That's because the compiler (correctly) inferred that `tensor_2` had the same\ngeneric parameters. The same could have been done in the original example, but specifying the\nparameters for both is more explicit.\n\n</details><br>\n\nBy running `cargo run`, you should now see the result of the addition:\n\n```console\nTensor {\n  data:\n[[3.0, 4.0],\n [5.0, 6.0]],\n  shape:  [2, 2],\n  device:  DefaultDevice,\n  backend:  \"wgpu\",\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}\n```\n\nWhile the previous example is somewhat trivial, the upcoming basic workflow section will walk you\nthrough a much more relevant example for deep learning applications.\n\n## Using `prelude`\n\nBurn comes with a variety of things in its core library. When creating a new model or using an\nexisting one for inference, you may need to import every single component you used, which could be a\nlittle verbose.\n\nTo address it, a `prelude` module is provided, allowing you to easily import commonly used structs\nand macros as a group:\n\n```rust, ignore\nuse burn::prelude::*;\n```\n\nwhich is equal to:\n\n```rust, ignore\nuse burn::{\n    config::Config,\n    module::Module,\n    nn,\n    tensor::{\n        backend::Backend, Bool, Device, ElementConversion, Float, Int, Shape, Tensor,\n        TensorData,\n    },\n};\n```\n\n<div class=\"warning\">\n\nFor the sake of simplicity, the subsequent chapters of this book will all use this form of importing\nexcept in the [Building Blocks](./building-blocks) chapter, as explicit importing aids users in\ngrasping the usage of particular structures and macros.\n\n</div>\n"
  },
  {
    "path": "burn-book/src/models-and-pretrained-weights.md",
    "content": "# Models and Pre-Trained Weights\n\n## Models Repository\n\nThe [`models`](https://github.com/tracel-ai/models) repository contains definitions of different\ndeep learning models with examples for different domains like computer vision and natural language\nprocessing.\n\nThis includes image classification models such as\n[`MobileNetV2`](https://github.com/tracel-ai/models/tree/main/mobilenetv2-burn),\n[`SqueezeNet`](https://github.com/tracel-ai/models/tree/main/squeezenet-burn) and\n[`ResNet`](https://github.com/tracel-ai/models/tree/main/resnet-burn), object detection models such\nas [`YOLOX`](https://github.com/tracel-ai/models/tree/main/yolox-burn) and language models like\n[`BERT` and `RoBERTa`](https://github.com/tracel-ai/models/tree/main/bert-burn).\n\nBe sure to check out the up-to-date\n[collection of models](https://github.com/tracel-ai/models?tab=readme-ov-file#collection-of-official-models)\nto get you started. Pre-trained weights are available for every supported architecture in this\ncollection. You will also find a spotlight of\n[community contributed models](https://github.com/tracel-ai/models?tab=readme-ov-file#community-contributions).\n\n## Burn-LM (alpha)\n\n[`Burn-LM`](https://github.com/tracel-ai/burn-lm) is an LLM inference engine built on Burn. It\nprovides access to large language models with open-source pre-trained weights and supports running,\nfine-tuning, and experimenting with them on any Burn backend.\n\nUnlike tools focused solely on inference, Burn-LM is designed to work in a unified way across\ndifferent models and tasks, making it easier to explore both inference and training workflows within\nthe same framework.\n"
  },
  {
    "path": "burn-book/src/motivation.md",
    "content": "# Why Burn?\n\nWhy bother with the effort of creating an entirely new deep learning framework from scratch when\nPyTorch, TensorFlow, and other frameworks already exist? Spoiler alert: Burn isn't merely a\nreplication of PyTorch or TensorFlow in Rust. It represents a novel approach, placing significant\nemphasis on making the right compromises in the right areas to facilitate exceptional flexibility,\nhigh performance, and a seamless developer experience. Burn isn’t a framework specialized for only\none type of application, it is designed to serve as a versatile framework suitable for a wide range\nof research and production uses. The foundation of Burn's design revolves around three key user\nprofiles:\n\n**Machine Learning Researchers** require tools to construct and execute experiments efficiently.\nIt’s essential for them to iterate quickly on their ideas and design testable experiments which can\nhelp them discover new findings. The framework should facilitate the swift implementation of\ncutting-edge research while ensuring fast execution for testing.\n\n**Machine Learning Engineers** are another important demographic to keep in mind. Their focus leans\nless on swift implementation and more on establishing robustness, seamless deployment, and\ncost-effective operations. They seek dependable, economical models capable of achieving objectives\nwithout excessive expense. The whole machine learning workflow —from training to inference— must be\nas efficient as possible with minimal unpredictable behavior.\n\n**Low level Software Engineers** working with hardware vendors want their processing units to run\nmodels as fast as possible to gain competitive advantage. This endeavor involves harnessing\nhardware-specific features such as Tensor Core for Nvidia. Since they are mostly working at a system\nlevel, they want to have absolute control over how the computation will be executed.\n\nThe goal of Burn is to satisfy all of those personas!\n"
  },
  {
    "path": "burn-book/src/onnx-import.md",
    "content": "# ONNX Import\n\n## Introduction\n\nAs deep learning evolves, interoperability between frameworks becomes crucial. Burn provides robust\nsupport for importing [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html)\nmodels through the [`burn-onnx`](https://github.com/tracel-ai/burn-onnx) crate, enabling you to\nleverage pre-trained models in your Rust-based deep learning projects.\n\n## Why Import Models?\n\nImporting pre-trained models offers several advantages:\n\n1. **Time-saving**: Skip the resource-intensive process of training models from scratch.\n2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by\n   researchers and industry leaders.\n3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from\n   knowledge transfer.\n4. **Consistency across frameworks**: Maintain consistent performance when moving between\n   frameworks.\n\n## Understanding ONNX\n\nONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models\nwith these key features:\n\n- **Framework agnostic**: Provides a common format that works across various deep learning\n  frameworks.\n- **Comprehensive representation**: Captures both the model architecture and trained weights.\n- **Wide support**: Compatible with popular frameworks like PyTorch, TensorFlow, and scikit-learn.\n\nThis standardization allows seamless movement of models between different frameworks and deployment\nenvironments.\n\n## Burn's ONNX Support\n\nBurn's approach to ONNX import offers unique advantages:\n\n1. **Native Rust code generation**: Translates ONNX models into Rust source code for deep\n   integration with Burn's ecosystem.\n2. **Compile-time optimization**: Leverages the Rust compiler to optimize the generated code,\n   potentially improving performance.\n3. **No runtime dependency**: Eliminates the need for an ONNX runtime, unlike many other solutions.\n4. **Trainability**: Allows imported models to be further trained or fine-tuned using Burn.\n5. **Portability**: Enables compilation for various targets, including WebAssembly and embedded\n   devices.\n6. **Backend flexibility**: Works with any of Burn's supported backends.\n\n## ONNX Compatibility\n\nBurn recommends ONNX models use **opset version 16 or higher** for best compatibility. While models\nwith older opset versions may work, opset 16+ ensures access to all supported operators and their\nlatest behavior. If you encounter issues with an older model, consider upgrading it using the ONNX\nversion converter.\n\n### Upgrading ONNX Models\n\nThere are two simple ways to upgrade your ONNX models to the recommended opset version:\n\nOption 1: Use the provided utility script:\n\n```\nuv run --script https://raw.githubusercontent.com/tracel-ai/burn-onnx/refs/heads/main/onnx_opset_upgrade.py\n```\n\nOption 2: Use a custom Python script:\n\n```python\nimport onnx\nfrom onnx import version_converter, shape_inference\n\n# Load your ONNX model\nmodel = onnx.load('path/to/your/model.onnx')\n\n# Convert the model to opset version 16\nupgraded_model = version_converter.convert_version(model, 16)\n\n# Apply shape inference to the upgraded model\ninferred_model = shape_inference.infer_shapes(upgraded_model)\n\n# Save the converted model\nonnx.save(inferred_model, 'upgraded_model.onnx')\n```\n\n## Step-by-Step Guide\n\nFollow these steps to import an ONNX model into your Burn project:\n\n### Step 1: Update `Cargo.toml`\n\nFirst, add the required dependencies to your `Cargo.toml`:\n\n```toml\n[dependencies]\nburn = { version = \"~0.21\", features = [\"ndarray\"] }\n\n[build-dependencies]\nburn-onnx = \"~0.21\"\n```\n\n### Step 2: Update `build.rs`\n\nIn your `build.rs` file:\n\n```rust, ignore\nuse burn_onnx::ModelGen;\n\nfn main() {\n    ModelGen::new()\n        .input(\"src/model/my_model.onnx\")\n        .out_dir(\"model/\")\n        .run_from_script();\n}\n```\n\nThis generates Rust code and a `.bpk` weights file from your ONNX model during the build process.\n\n### Step 3: Modify `mod.rs`\n\nIn your `src/model/mod.rs` file, include the generated code:\n\n```rust, ignore\npub mod my_model {\n    include!(concat!(env!(\"OUT_DIR\"), \"/model/my_model.rs\"));\n}\n```\n\n### Step 4: Use the Imported Model\n\nNow you can use the imported model in your code:\n\n```rust, ignore\nuse burn::tensor;\nuse burn_ndarray::{NdArray, NdArrayDevice};\nuse model::my_model::Model;\n\nfn main() {\n    let device = NdArrayDevice::default();\n\n    // Create model instance and load weights from target dir default device\n    let model: Model<NdArray<f32>> = Model::default();\n\n    // Create input tensor (replace with your actual input)\n    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);\n\n    // Perform inference\n    let output = model.forward(input);\n\n    println!(\"Model output: {:?}\", output);\n}\n```\n\n## Advanced Configuration\n\nThe `ModelGen` struct provides configuration options:\n\n```rust, ignore\nuse burn_onnx::{ModelGen, LoadStrategy};\n\nModelGen::new()\n    .input(\"path/to/model.onnx\")\n    .out_dir(\"model/\")\n    .development(true)                       // Enable development mode for debugging\n    .load_strategy(LoadStrategy::Embedded)   // Embed weights in the binary\n    .run_from_script();\n```\n\n- `input`: Path to the ONNX model file\n- `out_dir`: Output directory for generated code and weights\n- `development`: When enabled, generates additional debug files (`.onnx.txt`, `.graph.txt`)\n- `load_strategy`: Controls which weight-loading constructors are generated on the `Model` struct\n  (see below)\n\nModel weights are stored in Burnpack format (`.bpk`), which provides efficient serialization and\nloading.\n\n### Load Strategy\n\nThe `LoadStrategy` enum controls how the generated model loads its weights:\n\n| Strategy   | Generated constructors                          | `Default` impl | Use case                                  |\n|------------|------------------------------------------------|-----------------|-------------------------------------------|\n| `File`     | `from_file()`, `from_bytes()`                  | Yes             | Standard desktop/server (default)         |\n| `Embedded` | `from_embedded()`, `from_bytes()`              | Yes             | Single binary, small models               |\n| `Bytes`    | `from_bytes()`                                 | No              | WASM, embedded, custom loaders            |\n| `None`     | (none)                                         | No              | Manual weight management                  |\n\nThe default strategy is `File`, which keeps weights in a separate `.bpk` file and generates a\n`from_file()` constructor.\n\nFor WebAssembly or environments without filesystem access, use `LoadStrategy::Bytes`:\n\n```rust, ignore\nModelGen::new()\n    .input(\"model.onnx\")\n    .out_dir(\"model/\")\n    .load_strategy(LoadStrategy::Bytes)\n    .run_from_script();\n```\n\nThen load weights at runtime from any byte source (e.g., a network fetch):\n\n```rust, ignore\nlet model = Model::<Backend>::from_bytes(weight_bytes, &device);\n```\n\n## Loading and Using Models\n\nYou can load models in several ways, depending on the `LoadStrategy` used during code generation:\n\n```rust, ignore\n// Load from the output directory with default device (recommended for most use cases)\n// This automatically loads weights from the .bpk file\n// Available with LoadStrategy::File or LoadStrategy::Embedded\nlet model = Model::<Backend>::default();\n\n// Create a new model instance with a specific device\n// (initializes weights randomly; load weights via `load_from` afterward)\nlet model = Model::<Backend>::new(&device);\n\n// Load from a specific .bpk file (LoadStrategy::File)\nlet model = Model::<Backend>::from_file(\"path/to/weights.bpk\", &device);\n\n// Load from in-memory bytes (LoadStrategy::File, Embedded, or Bytes)\nlet model = Model::<Backend>::from_bytes(weight_bytes, &device);\n\n// Load from embedded weights (LoadStrategy::Embedded)\nlet model = Model::<Backend>::from_embedded(&device);\n```\n\n## Troubleshooting\n\nCommon issues and solutions:\n\n1. **Unsupported ONNX operator**: Check the\n   [list of supported ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).\n   You may need to simplify your model or wait for support.\n\n2. **Build errors**: Ensure your `burn-onnx` version matches your Burn version and verify the ONNX\n   file path in `build.rs`.\n\n3. **Runtime errors**: Confirm that your input tensors match the expected shape and data type of\n   your model.\n\n4. **Performance issues**: Consider using a more performant backend or optimizing your model\n   architecture.\n\n5. **Viewing generated files**: Find the generated Rust code and weights in the `OUT_DIR` directory\n   (usually `target/debug/build/<project>/out`).\n\n## Examples and Resources\n\nFor practical examples, check out the\n[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples):\n\n1. [ONNX Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) -\n   MNIST inference example\n2. [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) -\n   SqueezeNet running in the browser via WebAssembly\n3. [Raspberry Pi Pico](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico) -\n   Embedded deployment example\n\nThese demonstrate real-world usage of ONNX import in Burn projects.\n\nFor contributors looking to add support for new ONNX operators:\n\n- [Development Guide](https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md) -\n  Step-by-step guide for implementing new operators\n\n## Conclusion\n\nImporting ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's\nperformance and Rust's safety features. Following this guide, you can seamlessly integrate ONNX\nmodels into your Burn projects for inference, fine-tuning, or further development.\n\nThe `burn-onnx` crate is actively developed, with ongoing work to support more ONNX operators and\nimprove performance. Visit the [burn-onnx repository](https://github.com/tracel-ai/burn-onnx) for\nupdates and to contribute!\n"
  },
  {
    "path": "burn-book/src/overview.md",
    "content": "# Overview\n\nWelcome to The Burn Book 👋\n\nThis book will help you get started with the Burn deep learning framework, whether you are an\nadvanced user or a beginner. We have crafted some sections for you:\n\n- [Basic Workflow: From Training to Inference](./basic-workflow): We'll start with the fundamentals,\n  guiding you through the entire workflow, from training your models to deploying them for\n  inference. This section lays the groundwork for your Burn expertise.\n\n- [Building Blocks](./building-blocks): Dive deeper into Burn's core components, understanding how\n  they fit together. This knowledge forms the basis for more advanced usage and customization.\n\n- [Performance - Good Practices](./performance/good-practices/): Tips for writing models and\n  training code that make the most of hardware resources while avoiding common pitfalls that can\n  slow down execution.\n\n- [Custom Training Loop](./custom-training-loop.md): Gain the power to customize your training\n  loops, fine-tuning your models to meet your specific requirements. This section empowers you to\n  harness Burn's flexibility to its fullest.\n\n- [Saving & Loading Models](./saving-and-loading.md): Learn how to save and load your trained\n  models, including importing weights from PyTorch and SafeTensors formats.\n\n- [ONNX Import](./onnx-import.md): Learn how to import ONNX models using the\n  [burn-onnx](https://github.com/tracel-ai/burn-onnx) crate.\n\n- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md): Get started quickly with\n  ready-to-use models and pre-trained weights.\n\n- [Advanced](./advanced): Finally, venture into advanced topics, exploring Burn's capabilities at\n  their peak. This section caters to those who want to push the boundaries of what's possible with\n  Burn.\n\nThroughout the book, we assume a basic understanding of deep learning concepts, but we may refer to\nadditional material when it seems appropriate.\n"
  },
  {
    "path": "burn-book/src/performance/README.md",
    "content": "# Performance\n\nThis section covers the key concepts you need to understand to get the most out of Burn and your\nhardware.\n"
  },
  {
    "path": "burn-book/src/performance/distributed-computing.md",
    "content": "# Distributed Computing\n\nDistributed computing support was introduced in Burn 0.19. Documentation and examples will be\navailable soon.\n"
  },
  {
    "path": "burn-book/src/performance/good-practices/README.md",
    "content": "# Performance - Best Practices\n\nThis section provides valuable insights into the performance characteristics of Burn and guides\nusers on how to effectively leverage them for optimal results.\n\nIt includes several sections, each offering relevant details. While understanding these concepts can\naid in model optimization, it’s always crucial to conduct benchmarks and profile models to\naccurately assess performance improvements.\n\n- [Asynchronous Execution](./asynchronous-execution.md)\n- [Kernel Fusion](./kernel-fusion.md)\n- [Kernel Selection](./kernel-selection.md)\n"
  },
  {
    "path": "burn-book/src/performance/good-practices/asynchronous-execution.md",
    "content": "# Asynchronous Execution\n\nMost Burn backends execute tensor operations in an asynchronous manner. However, the async notation\nis often not required for most tensor operations, privileging the simplicity of sync Rust.\n\nThere are only a few operations that trigger synchronization of the backend, and it is very\nimportant to correctly handle those to optimize hardware utilization. Those operations are\n`into_data`, `into_scalar`, and `sync`. Some tensor operations might call `into_data` underneath,\ntriggering a synchronization, like `to_device` for some backends.\n\nThere are several ways to minimize synchronization overhead, one of which is to batch sync\noperations into a single transaction. Burn provides a high-level composable API to build\ntransactions, which will only trigger a single sync on the device.\n\nFor instance, it is often used when collecting metrics during training:\n\n```rust\n/// All of these variables are tensors.\nlet (output, loss, targets) = ..;\n\n/// Now output, loss, and targets will be `TensorData` stored on the CPU.\nlet [output, loss, targets] = Transaction::default()\n    .register(output)\n    .register(loss)\n    .register(targets)\n    .execute()\n    .try_into()\n    .expect(\"Correct amount of tensor data\");\n```\n\nAnother way of optimizing reads and avoiding device stalls is to read the data on a different\nthread. Under the hood, CubeCL-based backends assign different execution queues for different\nthreads, meaning that syncing a thread shouldn’t impact the throughput of another thread.\n\n## Using Different Backends for Different Tasks\n\nTensor operations aren’t the only things that are asynchronous; dataset and dataloading are also\nlazily executed. This allows for efficient data augmentation and sampling without having to cache\nhuge datasets on disk. However, this might reduce training throughput if data augmentation is\nperformed on the same device as the training itself. So, it is normally encouraged to use a\ndifferent device, maybe even a different backend, for that purpose. For optimal performance, also\navoid small allocations followed by a batching procedure. Even if it doesn’t break asynchronicity,\nit can slow down performance.\n\n```rust\n/// Items is a vector of many tensors.\nlet items = ..;\nlet batch = Tensor::cat(items, 1);\n```\n\nPrefer doing the concatenation of tensors on the data augmentation device and not on the training\ndevice.\n\n```rust\n/// Items is a vector of many tensors.\nlet items = ..;\nlet device_training = ..;\nlet axis_batch = 0;\n\nlet items = Tensor::cat(items, axis_batch);\nlet batch = Tensor::from_data(items.into_data(), device_training);\n```\n"
  },
  {
    "path": "burn-book/src/performance/good-practices/kernel-fusion.md",
    "content": "# Kernel Fusion\n\nAn interesting property of async execution is that it allows performance optimizations like kernel\nfusion. Coupled with CubeCL and its Just-In-Time compiler, Burn can serialize tensor operations into\na symbolic graph, then optimize it for improved efficiency.\n\nKernel fusion may reorder operations to reduce global memory reads, writes, and allocations. Being\naware of which operations can be fused is relevant, as it can be easy to break an execution graph.\n\nThe easiest way to optimize for fusion is to avoid keeping tensors alive for too long. When fusion\nisn’t possible, all tensors that will be used later will trigger a global memory write. Fortunately,\nRust and Clippy are quite good at detecting unnecessary clones, but special care should still be\ntaken.\n\nView operations can also interfere with fusion. They can be included in optimized graphs, but only\nto a limited extent, and they reduce vectorization potential as we have fewer guarantees about\nmemory access patterns with transformed indices. So, it is good practice to group view operations\ntogether before executing blocks of operations.\n\n```rust\nlet tensor4 = tensor1.unsqueeze().matmul(tensor2) + tensor3.unsqueeze();\n```\n\nCould be improved with the following:\n\n```rust\nlet tensor1 = tensor1.unsqueeze();\nlet tensor3 = tensor3.unsqueeze();\nlet tensor4 = tensor1.matmul(tensor2) + tensor3;\n```\n\nThis reduces the necessary reordering and may reduce a global memory write or improve vectorization.\nWe might be able to detect these patterns in the future, but for now, it’s a good idea to order your\noperations using this pattern. As a reminder, view operations typically only update tensor metadata\nin most cases. These operations include `slice`, `slice_assign`, `select`, `gather`, `scatter`,\n`reshape`, `swap_dims`, `transpose`, `unsqueeze`, etc.\n\nWith fusion enabled, it is often not necessary to write custom kernels, as you can rely on our\nsystem to optimize most element-wise operations. However, most compute-bound kernels require many\ntricks and deep knowledge of GPU memory architectures, where automatic compiler optimizations often\nunderperform compared to human-designed algorithms. This is why Burn’s approach to fusion is\ncentered around fuse-on-read and fuse-on-write. This means that complex compute-bound kernels that\nchange the shapes of tensors can fuse a block of element-wise operations when reading the input\ntensor and when writing the output tensor. The implication is that multiple compute-bound operations\nin a sequence can reduce fusion potential.\n\n```rust\n// This line might trigger 3 writes: tensor1, tensor2, and tensor3, if tensor1 and tensor2 are abstract tensors.\nlet tensor3 = tensor1.clone().sum_dim(tensor2.clone(), 2);\nlet tensor4 = tensor2.sum_dim(tensor3, 2);\nlet tensor5 = tensor4 + (tensor1 * tensor2);\n```\n\n```rust\nlet tmp = tensor1.clone() + tensor2.clone();\nlet tensor3 = tensor1.sum_dim(tensor2, 2);\nlet tensor4 = tensor2.sum_dim(tensor3, 2);\nlet tensor5 = tensor4 + tmp;\n```\n\nThe lesson? Whenever possible, pass only the latest value to a compute operation. Don’t clone a\ntensor before compute-bound operations, as it might trigger an additional write if that tensor isn’t\nmaterialized from initial fusion.\n\nIt’s a bit complex, but the first code snippet is actually better if `tensor1` and `tensor2` are\nconcrete in global memory. This would be the case if `tensor1` and `tensor2` are model parameters,\nso prefer this implementation style in such scenarios.\n\nThe second code snippet is preferred when `tensor1` and `tensor2` are virtual tensors, meaning they\nwere fused by earlier operations and require a global memory read to be accessed later. This happens\nif those tensors are part of a signal in neural networks.\n\nReordering operations can help in such scenarios but will not create temporary values, making the\nprevious optimization harder. We might eventually automatically optimize these cases, but the\nsolution space is quite large, and it’s not a planned optimization. Profiling model blocks is always\na good idea to identify which code block is faster when faced with ambiguous situations.\n"
  },
  {
    "path": "burn-book/src/performance/good-practices/kernel-selection.md",
    "content": "# Kernel Selection\n\nAs mentioned earlier, complex compute-bound operations are highly non-trivial and require many\ntricks for optimal performance. However, the way these tricks are applied varies depending on the\nhardware and problem shapes. To select the best kernel, we use a search method with a highly\nconfigurable autotune system that performs micro-benchmarks at runtime on the current hardware.\n\nThis may trigger a cold start, but the results of these benchmarks are cached on disk for subsequent\nexecutions.\n\nFor deployment or training on spot instances, it’s a good idea to bundle the autotune cache with the\ncode to mitigate cold starts. Refer to the\n[CubeCL configuration documentation](https://burn.dev/books/cubecl/advanced-usage/config.html) for\nmore details on fine-grained settings .\n\nFrom the user’s point of view, kernel selection shouldn’t be a problem, but as usual, crafting\nmodels with even shapes, multiples of 8, can significantly improve performance. Avoid creating\ntensors with shapes that are multiples of 10, like `[1000, 1000]`, as these typically require bounds\nchecking and may limit vectorization.\n\nPrefer shapes like `[1024, 1024]`, where dimensions are multiples of 32 or powers of 2, as these are\ngenerally optimal. If you have no choice but to use a suboptimal shape, prefer handling it in a\nsingle kernel, transforming it into an optimal shape. It’s better to have a slow neural network\nlayer followed by fast ones than to propagate unevenness and end up with smaller, but slower,\nlayers.\n"
  },
  {
    "path": "burn-book/src/performance/quantization.md",
    "content": "# Quantization\n\nQuantization techniques perform computations and store tensors in lower precision data types like\n8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep\nlearning model categorized as:\n\n- Post-training quantization (PTQ)\n- Quantization aware training (QAT)\n\nIn post-training quantization, the model is trained in floating point precision and later converted\nto the lower precision data type. There are two types of post-training quantization:\n\n1. Static quantization: quantizes the weights and activations of the model. Quantizing the\n   activations statically requires data to be calibrated (i.e., recording the activation values to\n   compute the optimal quantization parameters with representative data).\n1. Dynamic quantization: quantized the weights ahead of time (like static quantization) but the\n   activations are dynamically at runtime.\n\nSometimes post-training quantization is not able to achieve acceptable task accuracy. In general,\nthis is where quantization-aware training (QAT) can be used: during training, fake-quantization\nmodules are inserted in the forward and backward passes to simulate quantization effects, allowing\nthe model to learn representations that are more robust to reduced precision.\n\nBurn does not currently support QAT. Only post-training quantization (PTQ) is implemented at this\ntime.\n\n<div class=\"warning\">\n\nQuantization support in Burn is currently in active development.\n\nIt supports the following PTQ modes on some backends:\n\n- Per-tensor and per-block quantization to 8-bit, 4-bit and 2-bit representations\n\nNo integer operations are currently supported, which means tensors are dequantized to perform the\noperations in floating point precision.\n\n</div>\n\n## Module Quantization\n\nQuantizing the weights of your model after training is quite simple. We have access to the weight\ntensors and can collect their statistics, such as the min and max value when using\n`MinMaxCalibration`, to compute the quantization parameters.\n\n```rust , ignore\n# use burn::module::Quantizer;\n# use burn::tensor::quantization::{Calibration, QuantLevel, QuantParam, QuantScheme, QuantValue};\n#\n// Quantization config\nlet scheme = QuantScheme::default()\n    .with_level(QuantLevel::Block(32))\n    .with_value(QuantValue::Q4F)\n    .with_param(QuantParam::F16);\nlet mut quantizer = Quantizer {\n    calibration: Calibration::MinMax,\n    scheme,\n};\n\n// Quantize the weights\nlet model = model.quantize_weights(&mut quantizer);\n```\n\n### Calibration\n\nCalibration is the step during quantization where the range of all floating-point tensors is\ncomputed. This is pretty straightforward for weights since the actual range is known at\n_quantization-time_ (weights are static), but activations require more attention.\n\nTo compute the quantization parameters, Burn supports the following `Calibration` methods.\n\n| Method   | Description                                                                      |\n| :------- | :------------------------------------------------------------------------------- |\n| `MinMax` | Computes the quantization range mapping based on the running min and max values. |\n\n### Quantization Scheme\n\nA quantization scheme defines how an input is quantized, including the representation of quantized\nvalues, storage format, granularity, and how the values are scaled.\n\n```rust\nlet scheme = QuantScheme::default()\n    .with_mode(QuantMode::Symmetric)         // Quantization mode\n    .with_level(QuantLevel::block([2, 16]))  // Granularity (per-tensor or per-block)\n    .with_value(QuantValue::Q8S)             // Data type of quantized values, independent of how they're stored\n    .with_store(QuantStore::Native)          // Storage format for quantized values\n    .with_param(QuantParam::F16);            // Precision for quantization parameters\n```\n\n#### Quantization Mode\n\n| Mode        | Description                                  |\n| :---------- | :------------------------------------------- |\n| `Symmetric` | Values are scaled symmetrically around zero. |\n\n#### Quantization Level\n\n| Level                          | Description                                                                                                  |\n| :----------------------------- | :----------------------------------------------------------------------------------------------------------- |\n| `Tensor`                       | A single quantization parameter set for the entire tensor.                                                   |\n| `Block(block_size: BlockSize)` | Tensor divided into blocks (1D, 2D, or higher) defined by block_size, each with its own quantization params. |\n\n#### Quantization Value\n\n| Value  | Bits | Description                                   |\n| :----- | :--: | :-------------------------------------------- |\n| `Q8F`  |  8   | 8-bit full-range quantization                 |\n| `Q4F`  |  4   | 4-bit full-range quantization                 |\n| `Q2F`  |  2   | 2-bit full-range quantization                 |\n| `Q8S`  |  8   | 8-bit symmetric quantization                  |\n| `Q4S`  |  4   | 4-bit symmetric quantization                  |\n| `Q2S`  |  2   | 2-bit symmetric quantization                  |\n| `E5M2` |  8   | 8-bit floating-point (5 exponent, 2 mantissa) |\n| `E4M3` |  8   | 8-bit floating-point (4 exponent, 3 mantissa) |\n| `E2M1` |  4   | 4-bit floating-point (2 exponent, 1 mantissa) |\n\n#### Quantization Store\n\n| Store               | Description                                                                                                                                       |\n| :------------------ | :------------------------------------------------------------------------------------------------------------------------------------------------ |\n| `Native`            | Each quantized value is stored directly in a native format, which doesn't require packing and unpacking.                                          |\n| `PackedNative(dim)` | Multiple quantized values packed into a 32-bit integer. Argument is the dimension the tensor is packed on, starting from the innermost dimension. |\n| `PackedU32(dim)`    | Multiple quantized values packed into a 32-bit integer. Argument is the dimension the tensor is packed on, starting from the innermost dimension. |\n\nNative storage is not supported for sub-byte quantization values.\n\n#### Quantization Parameters Precision\n\n| Param  | Description                    |\n| :----- | :----------------------------- |\n| `F32`  | Full floating-point precision. |\n| `F16`  | Half-precision floating point. |\n| `BF16` | Brain float 16-bit precision.  |\n"
  },
  {
    "path": "burn-book/src/saving-and-loading.md",
    "content": "# Saving and Loading Models\n\nSaving your trained machine learning model is quite easy, no matter the output format you choose. As\nmentioned in the [Record](./building-blocks/record.md) section, different formats are supported to\nserialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the\n[MessagePack](https://msgpack.org/) binary serialization format with the help of\n[rmp_serde](https://docs.rs/rmp-serde/).\n\n```rust, ignore\n// Save model in MessagePack format with full precision\nlet recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();\nmodel\n    .save_file(model_path, &recorder)\n    .expect(\"Should be able to save the model\");\n```\n\nNote that the file extension is automatically handled by the recorder depending on the one you\nchoose. Therefore, only the file path and base name should be provided.\n\nNow that you have a trained model saved to your disk, you can easily load it in a similar fashion.\n\n```rust, ignore\n// Load model in full precision from MessagePack file\nlet recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();\nmodel = model\n    .load_file(model_path, &recorder, device)\n    .expect(\"Should be able to load the model weights from the provided file\");\n```\n\n**Note:** models can be saved in different output formats, just make sure you are using the correct\nrecorder type when loading the saved model. Type conversion between different precision settings is\nautomatically handled, but formats are not interchangeable. A model can be loaded from one format\nand saved to another format, just as long as you load it back with the new recorder type afterwards.\n\n## Initialization from Recorded Weights\n\nThe most straightforward way to load weights for a module is simply by using the generated method\n[load_record](https://burn.dev/docs/burn/module/trait.Module.html#tymethod.load_record). Note that\nparameter initialization is lazy, therefore no actual tensor allocation and GPU/CPU kernels are\nexecuted before the module is used. This means that you can use `init(device)` followed by\n`load_record(record)` without any meaningful performance cost.\n\n```rust, ignore\n// Create a dummy initialized model to save\nlet device = Default::default();\nlet model = Model::<MyBackend>::init(&device);\n\n// Save model in MessagePack format with full precision\nlet recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();\nmodel\n    .save_file(model_path, &recorder)\n    .expect(\"Should be able to save the model\");\n```\n\nAfterwards, the model can just as easily be loaded from the record saved on disk.\n\n```rust, ignore\n// Load model record on the backend's default device\nlet record: ModelRecord<MyBackend> =\n    NamedMpkFileRecorder::<FullPrecisionSettings>::new()\n        .load(model_path.into(), &device)\n        .expect(\"Could not load model weights\");\n\n// Initialize a new model with the loaded record/weights\nlet model = Model::init(&device).load_record(record);\n```\n\n## Model Weight Store\n\nWhile the Recorder API works well for basic saving and loading, `burn-store` was introduced to\naddress its limitations around memory efficiency and flexibility. It provides zero-copy\nmemory-mapped loading, cross-framework interoperability (PyTorch and SafeTensors), key remapping,\npartial loading, and filtering. The `burn-store`\ncrate is intended to eventually replace the Recorder API, but since it was recently released, both\nAPIs are supported.\n\n### Supported Formats\n\n| Format          | Extension      | Description                                                                               |\n| --------------- | -------------- | ----------------------------------------------------------------------------------------- |\n| **Burnpack**    | `.bpk`         | Burn's native format with fast loading, zero-copy support, and training state persistence |\n| **SafeTensors** | `.safetensors` | Industry-standard format from Hugging Face for secure tensor serialization                |\n| **PyTorch**     | `.pt`, `.pth`  | Direct loading of PyTorch model weights (read-only)                                       |\n\n### Saving a Model\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, BurnpackStore};\n\n// Save to Burnpack (recommended)\nlet mut store = BurnpackStore::from_file(\"model.bpk\");\nmodel.save_into(&mut store)?;\n\n// Or save to SafeTensors\nuse burn_store::SafetensorsStore;\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\");\nmodel.save_into(&mut store)?;\n```\n\n### Loading a Model\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, BurnpackStore};\n\nlet device = Default::default();\nlet mut model = MyModel::init(&device);\n\n// Load from Burnpack\nlet mut store = BurnpackStore::from_file(\"model.bpk\");\nmodel.load_from(&mut store)?;\n```\n\n### Loading from PyTorch\n\nYou can load weights directly from PyTorch `.pt` files:\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, PytorchStore};\n\nlet mut model = MyModel::init(&device);\nlet mut store = PytorchStore::from_file(\"pytorch_model.pt\");\nmodel.load_from(&mut store)?;\n```\n\n#### Exporting from PyTorch\n\nSave only the model weights (state_dict), not the entire model:\n\n```python\nimport torch\nimport torch.nn as nn\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(2, 2, (2, 2))\n        self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)\n\n    def forward(self, x):\n        return self.conv2(self.conv1(x))\n\nmodel = Net()\ntorch.save(model.state_dict(), \"model.pt\")  # Correct: save state_dict\n# torch.save(model, \"model.pt\")             # Wrong: saves entire model\n```\n\n#### Accessing Nested State Dicts\n\nSome PyTorch checkpoints nest the state_dict under a key:\n\n```rust, ignore\nlet mut store = PytorchStore::from_file(\"checkpoint.pt\")\n    .with_top_level_key(\"state_dict\");\nmodel.load_from(&mut store)?;\n```\n\n### Loading from SafeTensors\n\nFor SafeTensors files exported from PyTorch, use the adapter for proper weight transformation:\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};\n\nlet mut model = MyModel::init(&device);\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_from_adapter(PyTorchToBurnAdapter);\nmodel.load_from(&mut store)?;\n```\n\nFor SafeTensors files created by Burn, no adapter is needed:\n\n```rust, ignore\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\");\nmodel.load_from(&mut store)?;\n```\n\n#### Exporting from PyTorch to SafeTensors\n\n```python\nfrom safetensors.torch import save_file\n\nmodel = Net()\nsave_file(model.state_dict(), \"model.safetensors\")\n```\n\n### Saving for PyTorch Compatibility\n\nUse the adapter when saving for PyTorch consumption:\n\n```rust, ignore\nuse burn_store::{BurnToPyTorchAdapter, SafetensorsStore};\n\nlet mut store = SafetensorsStore::from_file(\"for_pytorch.safetensors\")\n    .with_to_adapter(BurnToPyTorchAdapter)\n    .skip_enum_variants(true);\nmodel.save_into(&mut store)?;\n```\n\n### Handling Load Results\n\nThe `load_from` method returns detailed information about the loading process:\n\n```rust, ignore\nlet result = model.load_from(&mut store)?;\n\n// Print a formatted summary with suggestions\nprintln!(\"{}\", result);\n\n// Or inspect individual fields\nprintln!(\"Applied: {} tensors\", result.applied.len());\nprintln!(\"Missing: {:?}\", result.missing);\nprintln!(\"Errors: {:?}\", result.errors);\n\nif result.is_success() {\n    println!(\"All tensors loaded successfully\");\n}\n```\n\n### Adding Metadata\n\nBurnpack and SafeTensors support custom metadata:\n\n```rust, ignore\nlet mut store = BurnpackStore::from_file(\"model.bpk\")\n    .metadata(\"version\", \"1.0\")\n    .metadata(\"description\", \"My trained model\")\n    .metadata(\"epochs\", \"100\");\nmodel.save_into(&mut store)?;\n```\n\n### Advanced Features\n\n#### Key Remapping\n\nRemap parameter names using regex patterns when model structures don't match:\n\n```rust, ignore\nlet mut store = PytorchStore::from_file(\"model.pt\")\n    // Remove prefix: \"model.conv1.weight\" -> \"conv1.weight\"\n    .with_key_remapping(r\"^model\\.\", \"\")\n    // Rename: \"layer1\" -> \"encoder.layer1\"\n    .with_key_remapping(r\"^layer\", \"encoder.layer\");\nmodel.load_from(&mut store)?;\n```\n\nFor complex remapping:\n\n```rust, ignore\nuse burn_store::KeyRemapper;\n\nlet remapper = KeyRemapper::new()\n    .add_pattern(r\"^transformer\\.h\\.(\\d+)\\.\", \"transformer.layer$1.\")?\n    .add_pattern(r\"\\.attn\\.\", \".attention.\")?;\n\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .remap(remapper);\n```\n\n#### Partial Loading\n\nLoad weights even when some tensors are missing:\n\n```rust, ignore\nlet mut store = PytorchStore::from_file(\"pretrained.pt\")\n    .allow_partial(true);\n\nlet result = model.load_from(&mut store)?;\nprintln!(\"Missing (initialized randomly): {:?}\", result.missing);\n```\n\n#### Filtering Tensors\n\nLoad or save only specific layers:\n\n```rust, ignore\n// Load only encoder layers\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_regex(r\"^encoder\\..*\")\n    .allow_partial(true);\n\n// Save only encoder layers\nlet mut store = SafetensorsStore::from_file(\"encoder.safetensors\")\n    .with_regex(r\"^encoder\\..*\");\nmodel.save_into(&mut store)?;\n\n// Multiple patterns (OR logic)\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_regex(r\"^encoder\\..*\")      // encoder tensors\n    .with_regex(r\".*\\.bias$\")          // OR any bias tensors\n    .with_full_path(\"decoder.scale\"); // OR specific tensor\n```\n\n#### Non-Contiguous Layer Indices\n\nPyTorch `nn.Sequential` with mixed layers creates non-contiguous indices. `PytorchStore`\nautomatically remaps these:\n\n```\nPyTorch: fc.0.weight, fc.2.weight, fc.4.weight  (gaps from ReLU layers)\nBurn:    fc.0.weight, fc.1.weight, fc.2.weight  (contiguous)\n```\n\nThis is enabled by default. Disable if needed:\n\n```rust, ignore\nlet mut store = PytorchStore::from_file(\"model.pt\")\n    .map_indices_contiguous(false);\n```\n\n#### Zero-Copy Loading\n\nFor embedded models or large files, use zero-copy loading to avoid memory copies:\n\n```rust, ignore\n// Embedded model (compile-time)\nstatic MODEL_DATA: &[u8] = include_bytes!(\"model.bpk\");\nlet mut store = BurnpackStore::from_static(MODEL_DATA);\nmodel.load_from(&mut store)?;\n\n// Large file (memory-mapped)\nlet mut store = BurnpackStore::from_file(\"large_model.bpk\")\n    .zero_copy(true);\nmodel.load_from(&mut store)?;\n```\n\n#### Half-Precision Storage\n\nSave models at half precision (F16) to reduce file size by ~50%, then load back at full precision:\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, BurnpackStore, HalfPrecisionAdapter};\n\nlet adapter = HalfPrecisionAdapter::new();\n\n// Save: F32 -> F16 (same adapter for both directions)\nlet mut store = BurnpackStore::from_file(\"model_f16.bpk\")\n    .with_to_adapter(adapter.clone());\nmodel.save_into(&mut store)?;\n\n// Load: F16 -> F32\nlet mut store = BurnpackStore::from_file(\"model_f16.bpk\")\n    .with_from_adapter(adapter);\nmodel.load_from(&mut store)?;\n```\n\nBy default, weights in Linear, Embedding, Conv\\*, LayerNorm, GroupNorm, InstanceNorm, RmsNorm, and\nPRelu modules are converted. BatchNorm is excluded because its running variance can underflow in\nF16. Customize with `with_module()` and `without_module()`:\n\n```rust, ignore\n// Keep LayerNorm at full precision\nlet adapter = HalfPrecisionAdapter::new()\n    .without_module(\"LayerNorm\");\n\n// Add a custom module to the conversion set\nlet adapter = HalfPrecisionAdapter::new()\n    .with_module(\"CustomLayer\");\n```\n\n#### Direct Tensor Access\n\nInspect tensors without loading into a model:\n\n```rust, ignore\nuse burn_store::ModuleStore;\n\nlet mut store = PytorchStore::from_file(\"model.pt\");\n\n// List all tensor names\nlet names = store.keys()?;\n\n// Get specific tensor\nif let Some(snapshot) = store.get_snapshot(\"encoder.layer0.weight\")? {\n    println!(\"Shape: {:?}, DType: {:?}\", snapshot.shape, snapshot.dtype);\n}\n```\n\n#### Model Surgery\n\nTransfer weights between models:\n\n```rust, ignore\nuse burn_store::{ModuleSnapshot, PathFilter};\n\n// Transfer all weights\nlet snapshots = model1.collect(None, None, false);\nmodel2.apply(snapshots, None, None, false);\n\n// Transfer only encoder weights\nlet filter = PathFilter::new().with_regex(r\"^encoder\\..*\");\nlet snapshots = model1.collect(Some(filter.clone()), None, false);\nmodel2.apply(snapshots, Some(filter), None, false);\n```\n\n### API Reference\n\n#### Builder Methods\n\n| Category      | Method                         | Description                  |\n| ------------- | ------------------------------ | ---------------------------- |\n| **Filtering** | `with_regex(pattern)`          | Filter by regex pattern      |\n|               | `with_full_path(path)`         | Include specific tensor      |\n|               | `with_predicate(fn)`           | Custom filter logic          |\n| **Remapping** | `with_key_remapping(from, to)` | Regex-based renaming         |\n|               | `remap(KeyRemapper)`           | Complex remapping rules      |\n| **Adapters**  | `with_from_adapter(adapter)`   | Loading transformations      |\n|               | `with_to_adapter(adapter)`     | Saving transformations       |\n|               | `HalfPrecisionAdapter::new()`  | F32/F16 mixed-precision      |\n| **Config**    | `allow_partial(bool)`          | Continue on missing tensors  |\n|               | `with_top_level_key(key)`      | Access nested dict (PyTorch) |\n|               | `skip_enum_variants(bool)`     | Skip enum variants in paths  |\n|               | `map_indices_contiguous(bool)` | Remap non-contiguous indices |\n|               | `metadata(key, value)`         | Add custom metadata          |\n|               | `zero_copy(bool)`              | Enable zero-copy loading     |\n\n#### Direct Access Methods\n\n| Method                | Description                      |\n| --------------------- | -------------------------------- |\n| `keys()`              | Get ordered list of tensor names |\n| `get_all_snapshots()` | Get all tensors as BTreeMap      |\n| `get_snapshot(name)`  | Get specific tensor by name      |\n\n### Troubleshooting\n\n#### Common Issues\n\n1. **\"Missing source values\" error**: You saved the entire PyTorch model instead of the state_dict.\n   Re-export with `torch.save(model.state_dict(), \"model.pt\")`.\n\n2. **Shape mismatch**: Your Burn model doesn't match the source architecture. Verify layer\n   configurations (channels, kernel sizes, bias settings).\n\n3. **Key not found**: Parameter names don't match. Use `with_key_remapping()` or inspect keys:\n\n   ```rust, ignore\n   let store = PytorchStore::from_file(\"model.pt\");\n   println!(\"Available keys: {:?}\", store.keys()?);\n   ```\n\n#### Inspecting Files\n\nUse [Netron](https://github.com/lutzroeder/netron) to visualize `.pt` and `.safetensors` files.\n\nFor Burnpack files:\n\n```bash\ncargo run --example burnpack_inspect model.bpk\n```\n"
  },
  {
    "path": "codecov.yml",
    "content": "coverage:\n  status:\n    project:\n      default:\n        # https://docs.codecov.com/docs/commit-status#informational\n        informational: true\n        target: 80%\n    patch:\n      default:\n        informational: true\n        target: 80%\ngithub_checks:\n    annotations: false\n"
  },
  {
    "path": "contributor-book/.gitignore",
    "content": "target\n\n# MacOS temp file\n.DS_Store\n\nbook-test\nguide/book\n\n.vscode\ntests/burn-book/book/\nbook/\n\n# Ignore Jetbrains specific files.\n.idea/\n\n# Ignore Vim temporary and swap files.\n*.sw?\n*~"
  },
  {
    "path": "contributor-book/.prettierrc.json",
    "content": "{\n    \"printWidth\": 100,\n    \"proseWrap\": \"always\"\n}"
  },
  {
    "path": "contributor-book/book.toml",
    "content": "[book]\nauthors = [\n    \"Wouter Doppenberg\",\n    \"Nathaniel Simard\",\n    \"Louis Fortier-Dubois\",\n    \"Dilshod Tadjibaev\",\n    \"Guillaume Lagrange\",\n    \"Joshua Ferguson\",\n    \"The Burn Community\",\n]\nlanguage = \"en\"\nsrc = \"src\"\ntitle = \"The Burn Contributor Book 🔥\"\n\n[output.html]\nmathjax-support = true\n"
  },
  {
    "path": "contributor-book/src/SUMMARY.md",
    "content": "- [Overview](./overview.md)\n- [How to Read This Book](./how-to-read-this-book.md)\n- [Getting Started](./getting-started/README.md)\n  - [Setting Up The Environment](./getting-started/setting-up-the-environment.md)\n  - [Configuring Your Editor (Optional)](./getting-started/configuring-your-editor.md)\n  - [Testing](./getting-started/testing.md)\n- [Architecture Overview](./project-architecture/README.md)\n  - [Modules](./project-architecture/module.md)\n  - [Serialization](./project-architecture/serialization.md)\n  - [Tensor](./project-architecture/tensor.md)\n  - [Backend](./project-architecture/backend.md)\n- [Guides for Contributors](./guides/README.md)\n  - [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)\n  - [Submitting Examples to Burn](./guides/submitting-examples.md)\n- [Frequently Encountered Issues](./frequently-encountered-issues/README.md)\n  - [Issues Related To Adding Operators](./frequently-encountered-issues/issues-while-adding-ops.md)\n"
  },
  {
    "path": "contributor-book/src/frequently-encountered-issues/README.md",
    "content": "# Frequently Encountered Issues\n\nThis is a collection of issues people have encountered and asked about on the\n[Discord server](https://discord.gg/uPEBbYYDB6). This section is separated from the guides since it\ncan involve lots of details that are only relevant to a small subset of contributors.\n"
  },
  {
    "path": "contributor-book/src/frequently-encountered-issues/issues-while-adding-ops.md",
    "content": "# Issues encountered while adding ops\n\nBelow are some of the issues that were encountered while adding ops to the project. If you encounter\nan issue while adding an op that isn't listed here, and it's not obvious how to fix it, you can add\nit to this list or reach out on the [Discord server](https://discord.gg/uPEBbYYDB6) if you need\nhelp.\n\n## Off by .000001 errors\n\n```sh\n---- fusion::base::tests::maxmin::tests::test_mean_dim_2d stdout ---- thread 'fusion::base::tests::maxmin::tests::test_mean_dim_2d' panicked at burn-wgpu/src/fusion/base.rs:185:5: assertion `left == right` failed left: Data { value: [1.0, 4.0], shape: Shape { dims: [2, 1] } } right: Data { value: [0.99999994, 3.9999998], shape: Shape { dims: [2, 1] } } ----\n\ntests::maxmin::tests::test_mean_dim_2d stdout ---- thread 'tests::maxmin::tests::test_mean_dim_2d' panicked at burn-wgpu/src/lib.rs:49:5: assertion `left == right` failed left: Data { value: [1.0, 4.0], shape: Shape { dims: [2, 1] } } right: Data { value: [0.99999994, 3.9999998], shape: Shape { dims: [2, 1] } }\n```\n\nIf you encounter this, swap out the `assert_eq!` in the failing test for\n`tensor1.to_data().assert_approx_eq` with `3` as the second argument. The second arguments specifies\nthe level of precision: `3` is equivalent to a less than 10<sup>-3</sup> (0.001) difference between\nthe elements of the two tensors.\n"
  },
  {
    "path": "contributor-book/src/getting-started/README.md",
    "content": "# Getting Started\n\nThis section is for setting up the environment and how to do basic development tasks such as running\ntests and checking your code before committing. If you need help with the process or run into\nissues, feel free to ask on the [Discord server](https://discord.gg/uPEBbYYDB6) in the Development\nchannels.\n"
  },
  {
    "path": "contributor-book/src/getting-started/configuring-your-editor.md",
    "content": "# Configuring your editor\n\nThese steps are not required, and most of this isn't specific to Burn, but it's definitely helpful\nif you haven't already done it.\n\n## VSCode\n\nInstall the following extensions:\n\n- [rust-lang.rust-analyzer](https://marketplace.visualstudio.com/items?itemName=rust-lang.rust-analyzer)\n  for Rust syntax and semantic analysis\n- [tamasfe.even-better-toml](https://marketplace.visualstudio.com/items?itemName=tamasfe.even-better-toml)\n  for TOML syntax and semantic analysis\n- [fill-labs.dependi](https://marketplace.visualstudio.com/items?itemName=fill-labs.dependi) for\n  managing dependencies\n- [vadimcn.vscode-lldb](https://marketplace.visualstudio.com/items?itemName=vadimcn.vscode-lldb) for\n  debugging\n\n### Setting up the Debugger\n\nTo use the debugger, follow these steps:\n\n1. Open `Command Palette` with `Ctrl+Shift+P` or `F1` and type\n   `LLDB: Generate Launch Configurations from Cargo.toml` then select it, this will generate a file\n   that should be saved as `.vscode/launch.json`.\n2. Select the configuration from the \"run and debug\" side panel, then select the target from the list.\n   Since this repo has `debug = 0` in the root `Cargo.toml` to speed up compilation, you need replace it with `debug = true` in the root `Cargo.toml` when using a debugger and breakpoints with `launch.json` settings.\n3. Now you can enable breakpoints on code through IDE then start debugging the library/binary you\n   want, like in the following example:\n\n![debug-options](debug-options-vscode.png)\n\nIf you're creating a new library or binary, keep in mind to repeat step 1 to always keep a fresh\nlist of targets.\n\n## Have another editor? Open a PR!\n"
  },
  {
    "path": "contributor-book/src/getting-started/setting-up-the-environment.md",
    "content": "# Setting up the environment\n\nDepending on what part of the project you plan on contributing to, there are a couple of tools to\ninstall and commands to be familiar with. This section should be up to date with current project\npractices (as of 2024-04-15).\n\n## General\n\nThere are a few commands you will want to run prior to any commit for a non-draft PR:\n\n1. `cargo fmt --all` will run `rustfmt` on all files in the project.\n2. `cargo clippy --fix` will run [Clippy](https://github.com/rust-lang/rust-clippy) and fix any\n   coding issues it can. Clippy necessitates to be in a clean Git state, but this can be\n   circumvented by adding the `--allow-dirty` flag.\n3. `cargo run-checks` is a command used to test the project. It is required to run successfully\n   prior to merging a PR. Fair warning, running these tests can take a while[^linux_mem_note].\n\n   > Want more detailed macro error diagnostics? This is especially useful for debugging tensor-related tests:\n   >\n   > ```bash\n   > RUSTC_BOOTSTRAP=1 RUSTFLAGS=\"-Zmacro-backtrace\" cargo run-checks\n   > ```\n\n## Updating the burn semver version\n\nIf for some reason you need to bump for the next version (though that should probably be left to the\nmaintainers), edit the semantic version number in `burn/Cargo.toml`, and then run `cargo update` to\nupdate the lock file.\n\n## Contributing to either the Burn Book or Contributor Book\n\nBoth the Burn Book and the Contributor Book are built with mdbook. To open the book locally, run\n`mdbook serve <path/to/book>` or `cargo xtask books {burn|contributor} open` which will install and\nuse mdbook automatically.\n\nAlternatively, if you want to install mdbook directly, run the following command[^update_note]:\n\n```bash\ncargo install mdbook\n```\n\nAlso instead of running `cargo run-checks`, you can run `cargo xtask check typos` to only check\nfor misspellings. This will install [typo](https://crates.io/crates/typos-cli), and if any are\nencountered you should be able to run `typo -w /path/to/book` to fix them.\n\n[^linux_mem_note]:\n    If your system is running into issues with memory and you are on linux, you may want to switch\n    to a [virtual console](https://wiki.archlinux.org/title/Linux_console#Virtual_consoles) to run\n    the tests. To do this, press `ctrl+alt+f3` to switch to a virtual console (and log in), and\n    either `ctrl+alt+f1` or `ctrl+alt+f2` to switch back to your graphical session.\n\n[^update_note]:\n    You might also want to install [cargo-update](https://github.com/nabijaczleweli/cargo-update) to\n    easily keep your tools up to date, though it is in no way required.\n"
  },
  {
    "path": "contributor-book/src/getting-started/testing.md",
    "content": "# Testing\n\n## Test for Tensor Operations\n\nTest for tensor operations (generally of the form: given this input, expect it match or approximate\nthis output) are defined only in\n[`crates/burn-tensor/src/test/ops`](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/ops)\nand not in the backends (with the exception of `burn-autodiff`). The tensor operation tests are\nadded to the `testgen_all` macro rule in\n[`crates/burn-tensor/src/tests/mod.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-tensor/src/tests/mod.rs).\nThis is then propagated to the existing backends without any additional work.\n\n### Test for Autodiff\n\nTests for autodiff go under\n[burn-autodiff/src/tests](https://github.com/tracel-ai/burn/tree/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-autodiff/src/tests)\nand should verify backward pass correctness. For binary tensor operations, both the left and right\nsides need to be verified.\n\nHere's an easy way to define tests for a new operation's backward pass:\n\n1. Use small tensors with simple values.\n2. Pop open a terminal, launch `ipython` and import `numpy` then do the calculations by hand. You\n   can also use [Google Colab](https://colab.google/) so you don't have to install the packages on\n   your system.\n3. Compare the actual outputs to the expected output for left-hand side, right-hand side.\n\nFor float tensors, it is advised to use\n`actual_output_tensor.into_data().assert_approx_eq::<FloatElem<TestBackend>>(&expected_tensor_data, Tolerance::default())`\ninstead of `assert_eq!(...` due to occasional hiccups with floating point calculations. Other\nassertions should also always use `FloatElem<TestBackend>`, and use `.elem()` to convert any\nliterals. Backends are tested for multiple precisions, and hardcoding to a fixed type causes tests\nto fail with alternate floating point precisions. For convenience, it might be worth aliasing the\ntype like `type FT = FloatElem<TestBackend>;`.\n\nFor integers, tests should use `IntElem<TestBackend>`, and exit the test if the test values are\nunrepresentable (above `max_value`, below `min_value`). A minimum range of `[0..127]` (`i8`) can be\nassumed.\n"
  },
  {
    "path": "contributor-book/src/guides/README.md",
    "content": "# Guides for Contributors\n\nThe following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn."
  },
  {
    "path": "contributor-book/src/guides/adding-a-new-operation-to-burn.md",
    "content": "# Adding a New Operation to burn\n\nLet's discuss how one might go about adding new operators to Burn, using the example of the pow\noperator added in [this PR](https://github.com/tracel-ai/burn/pull/1133/files).\n\n## Adding the Op to burn-tensor\n\n`burn-tensor` is the crate that defines all tensor operations that need to be implemented by the\nvarious backends. The core of this lies in\n[crates/burn-backend/src/tensor/ops/numeric.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/numeric.rs#L17),\nwhich is home to the numeric trait. The numeric trait is the home of all tensor operations that are\nnumeric in nature and that are shared by `Int` and `Float` Tensor types. The numeric trait is\nimplemented in\n[crates/burn-backend/src/tensor/ops/int.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/int.rs)\nfor the int type and in\n[crates/burn-backend/src/tensor/ops/float.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/tensor/ops/float.rs)\nfor the float type. More information on the relationship between Tensor modules can be found under\nthe section for [Tensor Architecture](../project-architecture/tensor.md#tensor-operations).\n\nHere is where pow was added to `crates/burn-tensor/src/tensor/api/numeric.rs`:\n\n1. for the\n   [`Tensor<Backend, Dimension, Kind>` struct](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L573)\n2. for the\n   [numeric trait](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L1955)\n3. for the implementation of numeric for\n   [float](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2722)\n   and\n   [int](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/numeric.rs#L2375)\n\nTensor is a struct that has a single member: `primitive` (defined\n[here](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/base.rs#L27)),\nthat is defined by its\n[`Kind`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/kind.rs#L16):\none of `Bool`, `Float`, or `Int` (those linked in 3). These call the ops for that data type defined\nin the\n[`Backend`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/base.rs#L64)\nsupertrait[^supertrait]. This is the trait that is then implemented by the different `burn-`\nbackends (such as `burn-ndarray` and `burn-wgpu`) which must implement the functions if no default\nis provided.\n\nIn this case, we don't need to worry about `Bool` Tensors. `Float` ops are implemented under\n[crates/burn-backend/src/backend/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/tensor.rs),\nand `Int` ops under\n[crates/burn-backend/src/backend/ops/int_tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/int_tensor.rs).\nThe current convention is ops of each type, if not unique to that type, are prefixed with the type.\nSo `powf` and sundry would be defined as `int_powf` for `IntTensorOps` and `float_powf` for\n`FloatTensorOps`. If an op is unique to a type, then it should be implemented under\n`burn-tensor/src/api/{type}.rs`. For example, here is an implementation for\n[`sin` under `crates/burn-tensor/src/api/float.rs`](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tensor/src/tensor/api/float.rs#L82)\nwhich obviously doesn't make sense for `Int` or `Bool` tensors.\n\nThe `Int` Tensor function uses the ones defined for Float with 2 extra casts (LHS to a `Float`\ntensor, Output to an `Int`). Given that the rest of the code will only look at the float\nimplementations.\n\nWith the addition of quantized float tensors, the `Float` tensor primitive is represented by the\n[`TensorPrimitive`](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/api/kind.rs#L69)\nenum. This allows us to handle both float and quantized float operations in the `Tensor`\nimplementation, correctly dispatching to the corresponding op (float or quantized) based on the\nvariant. Following the same convention, the equivalent\n[quantized tensor ops](https://github.com/tracel-ai/burn/blob/a6a5c22e0db56d947b9165d4dae42783a5a6b689/crates/burn-tensor/src/tensor/ops/qtensor.rs#L45)\nare prefixed with `q_*` (e.g., `q_reshape` instead of `float_reshape`). Most ops have a default\nimplementation that simply dequantizes the input into its floating-point representation, performs\nthe operation on the float tensor, and quantizes the output. Backends can overwrite specific\nimplementations when required/desired.\n\n### Adding Tests\n\nAdditional tests should be added to `burn-backend-tests` under\n[`crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/{op_name}.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend-tests/tests/tensor/float/ops/powf.rs),\nand the module name should be inserted into\n`crates/burn-backend-tests/tests/tensor/{float_or_int}/ops/mod.rs`.\n\nIf it makes sense for a floating point operation to support quantization, the\n[`QTensorOps`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend/src/backend/ops/qtensor.rs#L117)\ncounterpart is usually added at the same time with a default implementation (as mentioned in the\nprevious section). Tests for `q_*` ops follow a similar procedure: the test is added under\n[`crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/{op_name}.rs`](https://github.com/tracel-ai/burn/tree/9f31281/crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended),\nthe module name is inserted into\n[`crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs).\nIf you take a look at any of the existing tests for an operation on a quantized tensor, you will see\nthat the inputs and expected outputs are always defined with floating point values. While it assumes\nthat the quantization and dequantization are correct, it makes the tests much more readable and\neasier to understand w.r.t. what is being tested. Effectively, the tests are there to ensure that a\ntensor operation is invariant to quantization (up to some quantization error, of course).\n\n_Note: the tests try to use tensors with floating point values which can be de/quantized without\nintroducing too much quantization error, but the result always depends on the operation (e.g.,\ntensor product of values can grow larger and significantly increase the output tensor range, leading\nto more de/quantization error on the results)._\n\n## Adding the Op to burn-autodiff\n\nSince this is probably the hardest and the least straightforward, we'll cover this backend\nseparately. `burn-autodiff` enables other backends to use autodifferentiation[^autodiff]. Ops for\nfloat types are implemented in\n[crates/burn-autodiff/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-autodiff/src/ops/tensor.rs)\nand need to:\n\n1. Define a unit struct [^absolute_units] that implements a backward (pass) function\n2. Within the backward function, as this is an elementwise binary operation it implements the binary\n   function (from `backward.rs` under the same directory), the last 2 arguments are two closures\n   that define the left and right partial derivatives.\n3. Then define what happens when a specific operation is tracked or untracked, where untracked just\n   calls the function in the normal way, and tracked sets the execution the backward function\n   defined above.\n4. When tracked, operations are part of the autodiff graph and must save the needed information to\n   efficiently perform their backward pass later. If the information is light (such as a shape), it\n   should be directly saved in the state. If the operation's inputs are needed to compute the\n   backward pass, it should be checkpointed rather than saved. This will allow the input to be\n   provided lazily at the backward pass depending on the checkpointing strategy.\n5. An operation must also be identified as _compute-bound_ (`.computeBound()`) or _memory-bound_\n   (`.memoryBound()`) for gradient checkpointing. _Compute-bound_ operation are heavy to compute\n   (for instance matmul or convolution), which means that even with checkpointing they will save\n   their output for the backward pass and not recompute it. _Memory-bound_ operations are more\n   trivial (like `powf` which only performs one small operation per tensor entry), so it can be\n   beneficial to recompute them during the backward pass instead of saving their whole forward\n   output to memory. Operations registered as _memory-bound_ need to know their parents\n   (`.parents()` method) and how to recompute their forward pass during the backward pass (with a\n   struct that implements `RetroForward`), using their parents' outputs.\n\nThe above steps are mostly boilerplate, so you can often just copy the contents of another similar\nop, change the name of the structs, and ensure that either both sides have the data they need (if\nthey need to have a copy of the opposite sided tensor, clone its contents).\n\n### Computing derivatives\n\nFor those that need it, here is a quick refresher on the necessary calculus. If you are familiar\nwith how to calculate partial derivatives, you can skip this section.\n\nSince `pow` is a binary operation, the left and right functions are the partial derivatives with\nrespect to the left and right sided tensors.\n\nLet's define the operator as a function \\\\(f(x,y)=x^{y}\\\\) , where \\\\(x\\\\) is the left hand tensor\nand \\\\(y\\\\) is the right handed tensor. The two closures are defining the partial derivatives of\n\\\\(f\\\\) with respect to \\\\(x\\\\),\\\\(y\\\\). Treat the other variables as a constant\n\n$$\\frac{\\delta }{\\delta x} (x^{y})= y \\cdot x^{y-1}$$ is the left handed closure, and\n\n$$\\frac{\\delta }{\\delta y} (x^{y}) = x^{y} \\cdot ln(x)$$\n\nis the right. If you aren't sure how to calculate these by hand, it is recommended to use\n[symbolab](<https://www.symbolab.com/solver/partial-derivative-calculator/%5Cfrac%7B%5Cpartial%7D%7B%5Cpartial%20x%7D%5Cleft(x%5E%7By%7D%5Cright)?or=input>),\nplug in your operator in terms of \\\\(x\\\\) and \\\\(y\\\\), and just swap out the variable\n\\\\(x\\\\)|\\\\(y\\\\) in the partial derivative to get the other side.\n\n### Testing autodiff\n\nFor testing the `autodiff` operations, please refer to\n[this section](../getting-started/testing.md).\n\n## Adding the Op to other backends\n\nMost of these are fairly straightforward implementations. For reference here's pow's float\nimplementation for torch and ndarray backends:\n\n1. Torch implementation in\n   [crates/burn-tch/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/tensor.rs#L467)\n   and the Op used in\n   [crates/burn-tch/src/ops/base.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-tch/src/ops/base.rs#L481)\n2. NdArray in\n   [crates/burn-ndarray/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/0ee2021567b3725907df5fd1a905ce60b1aca096/crates/burn-ndarray/src/ops/tensor.rs#L472)\n\nThis is where any calculation happens currently. Playing a guessing game with method names and\nseeing what completions are suggested will take you far. If you are having trouble figuring out how\nto do it from the docs for that backend,\n[try searching github for relevant function calls](https://docs.github.com/en/search-github/github-code-search/understanding-github-code-search-syntax).\n\n## Adding the Op to fusion, JIT and cubecl backends\n\nAdding an operator to these backends can be fairly straightforward, though due to what these\nbackends are for, involves a bit more indirection. Fusion and jit, like autodiff, are not target\nbackends as much as backends that enable certain functionality for other backends, in this case\nkernel fusion or just-in-time compilation. Adding the operator won't involve doing any calculation,\nyou'll just be describing how the generated code should look. Most of this can be\ncopy/pasted/adjusted from other functions.\n\nHere's how powf was added to `burn-fusion`:\n\n1. Added powf to the float ops under\n   [crates/burn-fusion/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-fusion/src/ops/tensor.rs#L2061)\n2. Added powf to the `NumericOperationIr` enum under\n   [crates/burn-ir/src/operation.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-ir/src/operation.rs#L564)\n3. Added powf to the implementations of `NumericOperationIr` enum under\n   [crates/burn-ir/src/operation.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-ir/src/operation.rs#L1086)\n4. Added powf to the implemented of `NumericOperationIr` enum under\n   [burn/crates/burn-fusion/src/stream/context.rs](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-fusion/src/stream/context.rs#L883)\n\nThe way `cubecl` handles tensor-scalar operations is by transforming both into a sequence of\nvectorized scalar operations. Since powf already existed in `cubecl`, it was pretty easy to reuse\nthe existing implementation for the situation where both sides of the operation were tensors. The\n`cubecl` crate is primarily concerned with how the operation is compiled and executed by the gpu.\nThe actual implementation is defined in `burn-cubecl`.\n\nHere is where code was added for powf in `burn-cubecl` and `cubecl`:\n\n1. to the implementation of\n   [`FloatTensorOps` under `burn/crates/burn-cubecl/src/ops/tensor.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl/src/ops/tensor.rs#L578)\n2. the function being called was added to\n   [`burn/crates/burn-cubecl/src/ops/numeric.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl/src/ops/numeric.rs#L211-L214)\n3. the operator was defined in\n   [`cubecl/crates/cubecl-ir/src/arithmetic.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-ir/src/arithmetic.rs#L41)\n4. how the operation looks to the gpu was added to\n   [`burn/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs`](https://github.com/tracel-ai/burn/blob/9f31281/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs#L97)\n5. the mappings between the gpu operation and the CPP, WGSL and SPIR-V instructions were added to\n   [`cubecl/crates/cubecl-cpp/src/shared/base.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/base.rs#L1285),\n   [`cubecl/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L869)\n   and\n   [`cubecl/crates/cubecl-spirv/src/arithmetic.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-spirv/src/arithmetic.rs#L491)\n6. the instructions themselves were added for WGSL to\n   [instruction op enum in `cubecl/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs`](https://github.com/tracel-ai/cubecl/blob/f5b63076a01a5c03ea9ed20799d3eeaf776b45da/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L124),\n   and the actual\n   [instruction in wgsl here](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs#L654),\n   for CPP in the enum here\n   [`cubecl/crates/cubecl-cpp/src/shared/instruction.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/instruction.rs#L187)\n   and the actual instruction here\n   [`cubecl/crates/cubecl-cpp/src/shared/binary.rs`](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-cpp/src/shared/binary.rs#L216)\n\nWe needed to generate some custom WGSL code for powf in WGSL, primarily due to issues with proper\ncase handling of the wgsl pow function, like 0 to the 0 power being 1, and any negative number to an\neven power being positive. We reused as much as the existing logic as possible, and then branched at\nthe last point based off the var type of the rhs.\n[See here](https://github.com/tracel-ai/cubecl/blob/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs#L1229).\nFor most operations, you shouldn't need to add to `cubecl-wgpu/src/compiler/wgsl/extension.rs`\nunless the operation isn't native to WGSL.\n\nFor functions that need a complex kernel without a direct mapping to a base instruction, simply use\nthe `cube` macro (see\n[the `cubecl` book](https://github.com/tracel-ai/cubecl/tree/88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35/cubecl-book)).\n\nAnd you're done! Congrats, you just fully added a new operation to burn, and we are all one step\ncloser to the answer to [Are we learning yet?](https://www.arewelearningyet.com/) being \"Yes, and\nit's freaking fast!\". Buy yourself a coffee.\n\n[^supertrait]:\n    for more on supertraits see\n    [the advanced trait section of the rust book](https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#using-supertraits-to-require-one-traits-functionality-within-another-trait)\n\n[^autodiff]:\n    wiki link for\n    [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)\n\n[^absolute_units]:\n    for more information on unit structs see\n    [the defining and instantiating structs section of the rust book](https://doc.rust-lang.org/book/ch05-01-defining-structs.html#unit-like-structs-without-any-fields)\n"
  },
  {
    "path": "contributor-book/src/guides/submitting-examples.md",
    "content": "# Submitting Examples to Burn\n\nThis guide explains how to create and submit new examples to the Burn repository. Examples are a great way to demonstrate Burn's capabilities and help users understand how to use the framework effectively.\n\nFor a minimal working example, see the [simple-regression](https://github.com/tracel-ai/burn/blob/main/examples/simple-regression/examples/regression.rs) example in the repository.\n\n## Repository Structure\n\nThe Burn repository is set up as a workspace, with examples located in the `examples/` directory. Each example is a separate crate that can reuse workspace dependencies.\n\n## Creating a New Example\n\n1. Navigate to the examples directory:\n   ```bash\n   cd examples\n   ```\n\n2. Create a new library crate:\n   ```bash\n   cargo new --lib <my-example>\n   ```\n\n3. Update the example's `Cargo.toml`:\n   ```toml\n   [package]\n   name = \"<my-example>\"\n   version = \"0.1.0\"\n   edition = \"2021\"\n   readme = \"README.md\"\n   # Remove this line if it exists\n   # readme.workspace = true\n\n   [dependencies]\n   # Reuse workspace dependencies when available\n   serde = { workspace = true }\n   # Add example-specific dependencies\n   burn = { path = \"../../\" }\n   ```\n\n## Required Files and Structure\n\n### README.md\nEach example must include a README.md file with:\n- A brief description of what the example demonstrates\n- A terminal command showing how to run the example\n- Any prerequisites or setup instructions\n\nExample README structure:\n````markdown\n# Example Name\n\nBrief description of what this example demonstrates.\n\n## Running the Example\n\n```bash\ncargo run --example <my-example>\n```\n\n## Prerequisites\n\nList any prerequisites here.\n````\n\n### Source Code Structure\n\n- `src/` directory: Contains the main implementation code\n- `examples/` directory: Contains example code\n  - `<my-example>.rs`: Example implementation\n\n## Resource Handling\n\n- Resources (datasets, models, etc.) should be downloaded in the example code\n- Do not track external files in the repository\n- Include code to download and prepare resources when the example is run\n\n## Best Practices\n\n1. **Code Organization**\n   - Keep the code modular and well-documented\n   - Use clear, descriptive variable and function names\n   - Include comments explaining complex operations\n\n2. **Error Handling**\n   - Implement proper error handling\n   - Provide meaningful error messages\n   - Handle resource download failures gracefully\n\n3. **Performance**\n   - Optimize for reasonable execution time\n   - Include progress indicators for long-running operations\n   - Consider adding configuration options for different hardware capabilities\n\n4. **Documentation**\n   - Document all public APIs\n   - Include inline comments for complex logic\n   - Explain any non-obvious implementation details\n\n## Submitting Your Example\n\n1. Ensure your example follows all the guidelines above\n2. Test your example thoroughly\n3. Create a pull request with:\n   - A clear description of what the example demonstrates\n   - Any relevant issue numbers\n   - Screenshots or output examples (if applicable)\n\nFeel free to ask questions in the pull request if you need clarification or guidance. "
  },
  {
    "path": "contributor-book/src/how-to-read-this-book.md",
    "content": "# How to read this book\n\nThroughout this book, we maintain the following structure.\n\n## Linking\n\nWhen referring to structures or functions within codebase, we provide permalinks to the lines in\nspecific commits, and indicate them by the relative path of their parent file from the project root.\nFor example this is a reference to the `Tensor` struct in\n[`crates/burn-tensor/src/tensor/api/base.rs`](https://github.com/tracel-ai/burn/blob/e303e31c8bc85486690ff80df65d1e25e16728c4/crates/burn-tensor/src/tensor/api/base.rs#L27)\n\nWhen some reference information is useful but is beyond the scope of contributing to Burn, we\nprovide that information in a footnote. To build on the previous example, the `Tensor` mentioned is\nwhat's referred to as a newtype struct[^1].\n\nDirect hyperlinks are for tools and resources that are not part of the Burn project, but are useful\nfor contributing to it. For example, when working on implementing an operation for autodiff, it can\nbe useful to use [symbolab](https://www.symbolab.com/) to calculate the left and right partial\nderivatives.\n\n[^1]: For more information on newtype please refer to\n    [the Advanced Types chapter of the Rust Book](https://doc.rust-lang.org/book/ch19-04-advanced-types.html#using-the-newtype-pattern-for-type-safety-and-abstraction)\n"
  },
  {
    "path": "contributor-book/src/overview.md",
    "content": "# Overview\n\nWelcome to The Burn Contributor's Book 👋\n\nThis book will help you get acquainted with the internals of the Burn deep learning framework and\nprovide some detailed guidance on how to contribute to the project. Before opening a PR, please read\nthe [Contributing Guidelines](https://github.com/tracel-ai/burn/blob/main/CONTRIBUTING.md).\n\nWe have crafted some sections for you:\n\n- [Getting Started](./getting-started): Much like the [Burn Book](https://burn.dev/books/burn/) which\n  targets users, we'll start with the fundamentals, guiding you through tasks like setting up the\n  development environment, running tests, and what you should check prior to each commit.\n\n- [Project Architecture](./project-architecture): This section will give you an in-depth look at the\n  architecture of Burn.\n\n- [Guides](./guides): We provide some guides on how to do specific tasks, such as adding a new\n  operations to Burn.\n\n- [Frequently Encountered Issues](./frequently-encountered-issues): If you are running into an issue\n  that has you stumped, this is the section to check out prior to asking on the\n  [Discord](https://discord.gg/uPEBbYYDB6). It's a collection of errors encountered by contributors,\n  what caused them, and how they were resolved.\n\nAs this book is geared towards contributors and not towards users of Burn, we'll assume you have a\ngood understanding of software development, but will make efforts to explain anything outside of\nthat scope, or at least provide links to resources that explain it better than we can.\n"
  },
  {
    "path": "contributor-book/src/project-architecture/README.md",
    "content": "# Project Architecture\n\nThis section documents most major architectural decisions with the reasoning behind them.\n\n**Sections**\n\n- [Module](./module.md)\n  - [Optimization](./module.md#optimization)\n    - [Constraints](./module.md#constraints)\n    - [Solution](./module.md#solution)\n- [Serialization](./serialization.md)\n  - [Constraints](./serialization.md#constraints)\n  - [Solution](./serialization.md#solution)\n    - [Pros](./serialization.md#pros)\n    - [Cons](./serialization.md#cons)\n    - [Compatibility](./serialization.md#compatibility)\n- [Tensor](./tensor.md)\n- [Backend](./backend.md)\n  - [Autodiff](./backend.md#autodiff)\n"
  },
  {
    "path": "contributor-book/src/project-architecture/backend.md",
    "content": "# Backend\n\nThe Backend trait abstracts multiple things:\n\n- Device type\n- Float tensor type\n- Bool tensor type\n- Int tensor type\n- Float element type\n- Int element type\n- Float tensor operations (kernels)\n- Int tensor operations (kernels)\n- Bool tensor operations (kernels)\n\n## Element types\n\n> Warning: there are plans to change this architecture in the near future.\n\nEven though having one type for tensors is convenient for the tensor API, it can be cumbersome when\nimplementing a backend. Therefore, backends can decide, through associated types, what types they\nwant to use for their int, float, and bool tensors. Since float and int can have multiple\nprecisions, the float and int element types are also associated types that must be declared by the\nbackend.\n\nNote that the backend chooses the precision and not the user. Since not all backends will support\nthe same element types, no assumptions must be made. Therefore, there are no methods on tensors to\nchange the precision, except for the `to_full_precision` function, which ensures numerical stability\non the current backend. Backend implementations can provide a way to choose the precision, which can\nbe accomplished with a generic parameter (e.g. `NdArray<f32>`).\n\n## Operations\n\nTo be as general as possible, tensor operations are implemented as plain functions. There is no\nobject or self, just functions that take tensors as input and often return tensors as output as\nwell. Backend implementations are free to use their own patterns to implement these kernels. Note\nthat Burn is a dynamic graph deep learning framework, so backends may have to implement asynchronous\nkernel executions for performance reasons.\n\n## Autodiff\n\nAs of now, there is only one backend decorator that supports autodiff. It follows the decorator\npattern, making any backend differentiable. However, the `AutodiffBackend` trait abstracts how\ngradients are calculated, and other approaches to autodiff might be added later. For more\ninformation about how the current autodiff backend works, you can read this (slightly outdated)\n[blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).\n"
  },
  {
    "path": "contributor-book/src/project-architecture/module.md",
    "content": "# Module\n\nModules are a way of creating neural network structures that can be easily optimized, saved, and\nloaded with little to no boilerplate. Unlike other frameworks, a module does not force the\ndeclaration of the forward pass, leaving it up to the implementer to decide how it should be\ndefined.\n\nAdditionally, most modules are created using a (de)serializable configuration, which defines the\nstructure of the module and its hyperparameters. Parameters and hyperparameters are not serialized\ninto the same file, and both are normally necessary to load a module for inference.\n\n## Optimization\n\nOptimization is normally done with variants of gradient descent, and it is important to provide an\neasy API for optimizing modules.\n\n### Constraints\n\n1. **Users should be able to control what is optimized.** Modules can contain anything for maximum\n   flexibility, but not everything needs to be optimized.\n2. **Optimizers should have a serializable state that is updated during training.** Many optimizers\n   keep track of previous gradients to implement some form of momentum. However, the state can be\n   anything, not just tensors, allowing for easy implementation of any kind of optimizer.\n3. **The learning rate can be updated during training.** Learning rate schedulers are often used\n   during training and should be considered as a key aspect.\n\n### Solution\n\nIn the following, the `Module` trait is defined in\n[`crates/burn-core/src/module/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/module/base.rs#L83)\nand the `Optimizer` trait is defined in\n[`crates/burn-core/src/optim/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/base.rs#L8)\n\nThe solution to this problem comprises multiple parts. Firstly, the `Optimizer` trait is quite\nsimilar to the `Module` trait, in terms of saving and loading the state. Please refer to the\n[serialization](./serialization.md) section for more details.\n\nSecondly, two traits were created. The `Optimizer` trait is general and relatively unopinionated,\nwith a simple `step` method that takes a learning rate, a module, and the gradients. The other\ntrait, `SimpleOptimizer`, aims to provide an easier API for implementing new optimizers. The goal is\nto allow implementations to avoid handling missing gradients, loading and exporting records,\nnavigating the module parameter structure, handling tracked and untracked tensors, and other such\ntasks.\n\nThirdly, each tensor that will be optimized needs to be wrapped into a `Param` struct, which gives\nthem an ID used for (de)serialization and to associate the state of the optimizer to each parameter.\nThe `Module` trait has two ways to navigate over parameters. The first one is the `map` function,\nwhich returns `Self` and makes it easy to implement any transformation and mutate all parameters.\nThe second one is the `visit` function, which has a similar signature but does not mutate the\nparameter tensors.\n\n#### SimpleOptimizer\n\nLocated in\n[`crates/burn-core/src/optim/simple/base.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/base.rs#L9),\nthe `SimpleOptimizer` has two major assumptions:\n\n1. The state of the optimizer is linked to each parameter. In other words, each parameter has its\n   own optimizer state, decoupled from the other parameters.\n2. The state of the optimizer implements `Record`, `Clone`, and has a `'static` lifetime.\n\nThe benefits of those assumptions materialize in simplicity with little loss in flexibility. The\nstate associative type is also generic over the dimension, making it extremely easy to include\ntensors in the state that share the same dimensionality as its parameter.\n\nTo wrap a simple optimizer into the more general `Optimizer` trait, the `OptimizerAdaptor` struct is\nused.\n\n#### OptimizerAdaptor\n\nLocated in in\n[`crates/burn-core/src/optim/simple/adaptor.rs`](https://github.com/tracel-ai/burn/blob/81a67b6a0992b9b5c33cda8b9784570143b67319/crates/burn-core/src/optim/simple/adaptor.rs#L14),\nthe `OptimizerAdaptor` is a simple struct composed of a `SimpleOptimizer` and a hashmap with all\nrecords associated with each parameter ID.\n\nWhen performing an optimization step, the adaptor handles the following:\n\n1. Updates each parameter tensor in the given module using the `Module::map` function.\n2. Checks if a gradient for the current tensor exists.\n3. Makes sure that the gradient, the tensor, and the optimizer state associated with the current\n   parameter are on the same device. The device can be different if the state is loaded from disk to\n   restart training.\n4. Performs the simple optimizer step using the inner tensor since the operations done by the\n   optimizer should not be tracked in the autodiff graph.\n5. Updates the state for the current parameter and returns the updated tensor, making sure it's\n   properly registered into the autodiff graph if gradients are marked as required.\n\nNote that a parameter can still be updated by another process, as it is the case with running\nmetrics used in batch norm. These tensors are still wrapped using the `Param` struct so that they\nare included in the module's state and given a proper parameter ID, but they are not registered in\nthe autodiff graph.\n"
  },
  {
    "path": "contributor-book/src/project-architecture/serialization.md",
    "content": "# Serialization\n\nAn important aspect of a deep learning framework is the ability to save and load models from disk.\nDespite appearing as a simple feature, it involves numerous constraints that require a proper\nsolution.\n\n## Constraints\n\n1. **Users should be able to declare the precision of the model to be saved, independent of the\n   backend in use.**\n\n   The modules should not be duplicated in RAM in another precision to support this. Conversion\n   should be done lazily during (de)serialization.\n\n2. **Users should be able to add any field to a module, even fields that are not serializable.**\n\n   This can include constants, database connections, other module references, or any other\n   information. Only parameters should be serialized since the structure of the module itself should\n   be encapsulated with module configurations (hyperparameters).\n\n3. **Users should be able to declare the format in which the module should be saved.**\n\n   This can involve saving to a compressed JSON file or directly to bytes in memory for `no-std`\n   environments.\n\n4. **Users should be able to create a module with its saved parameters without having to initialize\n   the module first.**\n\n   This will avoid unnecessary module initialization and tensor loading, resulting in reduced cold\n   start when dealing with inference.\n\nIn addition to all of these constraints, the solution should be easy to use.\n\n## Solution\n\nIn order to be able to add any field to a module without requiring it to be (de)serializable, we\ndecouple the module type from its state. We create a new type for each module that only contains the\nparameters that need to be saved. To generate that type automatically, the user must either declare\nwhich field is a parameter or a constant, or we assume that each field implements the module trait.\n\nThe second solution was chosen as it simplifies the code generation and reduces the size of the user\nAPI. This means that the `Module` trait should be implemented by\n[primitive types](https://github.com/tracel-ai/burn/blob/main/crates/burn-core/src/module/param/primitive.rs).\nThe following diagrams highlight the main types and traits used in the solution.\n\n<div align=\"center\">\n<h4>Module Serialization Types</h4>\n<img src=\"./module-serialization.png\" width=\"700px\"/>\n<div align=\"left\">\n\nThe way the types interact with each other is pretty straightforward. First, a module can be\nconverted into a record using `into_record()`. Note that tensors can be cloned, but it won't\nactually copy any data; it will simply create another reference to the same data.\n\nThen, a `Recorder` instance can be used to serialize any record. The `Recorder` has the\n`PrecisionSettings` type as associate type, so any record will be serialized using the settings\nprovided at the creation of the `Recorder` instance. Note that tensors implement record, and their\nitem is just a wrapper struct that contains information about the precision in which the tensor\nshould be saved or loaded. No actual copy of the tensor is made until this point. The tensor is\nconverted to the `TensorData` struct and then converted into the specified precision only when\n`serialize()` or `deserialize()` are called, which makes the whole process lazy.\n\nTo recapitulate, the `Module` trait has an associated type that implements `Record`, which only\ncontains the parameters of the model. The `Record` trait has a generic associated type (GAT) that\nspecifies a family of types that can be (de)serialized given any `PrecisionSettings`. Records are\ntherefore decoupled from the backend in use, and the saved items can be loaded on any backend with\nany precision, since the conversion is type-safe and done when `serialize()` and `deserialize()` are\ncalled. All of the types are generated using simple derive macros without any conditional statements\nor complex syntax, as `Record` and `Module` are implemented for all primitive types. This makes the\ncode simple and easy to maintain. In addition, you can extend the current system with your own\n`Recorder` and `PrecisionSettings` to control how your modules should be saved and loaded.\n\n### Pros\n\n- All constraints are respected.\n- The code is simple and easy to maintain, with very few conditional statements. It is just\n  recursive data structures, where all the complexity is handled by the framework in primitive\n  implementations.\n- The user API is simple and small, with only two derives (`Record` and `Module`) and no additional\n  attributes.\n- Users can create their own `Module` and `Record` primitive types, which gives them the flexibility\n  to control how their data is serialized without having to fork the framework.\n\n### Cons\n\n- There are more types, but most of them are automatically generated and single-purpose, so users\n  don't need to interact with them for common use cases. However, they can do so if necessary.\n- When instantiating a new record manually, each field must be set to something, even if the type\n  itself is `()`, which represents no value. Since the code generation step uses associative types,\n  it doesn't know that a field type is actually nothing. Creating a record manually without using\n  the generated function `into_record` or loading it from a file is only useful to load a set of\n  parameters into a module from an arbitrary source. Using the record may not be the optimal\n  solution to this problem, and another API could be created in the future.\n\n### Compatibility\n\nRecord may become incompatible with previous versions of Burn, depending on the chosen format. The\nmore compact format (bincode) store minimal information about the type, making it significantly\nsmaller but less resilient to type changes such adding an optional field. At some point, it might be\nnecessary to provide a translation script that can translate a more resilient format from a previous\nversion to a more compact one.\n"
  },
  {
    "path": "contributor-book/src/project-architecture/tensor.md",
    "content": "# Tensor\n\nA proper deep learning framework should have a fast tensor implementation with autodiff support, and\nBurn is no exception. The tensor API abstracts away backend implementation details and focuses on\nusability without compromising performance. To make it as easy as possible to use, there is only one\ntensor type, which is different from multiple tensor and deep learning crates in Rust. Generic\nparameters are used instead to specialize the tensor type.\n\n- **B: Backend:** The first argument is the backend on which the tensor implementation lies.\n- **const D: usize:** The second argument is the dimensionality of the tensor.\n- **K: TensorKind:** The third argument is the tensor kind, which can be either Float, Int or Bool.\n  By default, the tensor kind is set to Float, so for most tensors, the kind argument is not\n  necessary.\n\nHaving one struct for tensors reduces the complexity of the tensor API, which also means less\nduplicated documentation to write and maintain.\n\nTensors are thread-safe, which means that you can send a tensor to another thread, and everything\nwill work, including auto-differentiation. Note that there are no explicit in-place tensor\noperations since all tensor operations take owned tensors as parameters, which make it possible to\nmutate them. Tensors can be shared simply by cloning them, but if there is only one reference to a\ntensor, the backend implementation is free to reuse the tensor's allocated data. For more\ninformation about how it is done, you can have a look at this\n[blog post](https://burn.dev/blog/burn-rusty-approach-to-tensor-handling).\n\n## Tensor Operations\n\nOperations on Tensors (sometimes shortened to Ops) are defined in traits (generally part of the\nBackend Supertrait) and implemented for the Tensor struct. The appropriate parent trait of an\noperation depends on the type of operation:\n\n- `base` => All tensor kinds should implement these operations (reshape, into_data, etc.). The\n  implementation is in\n  [crates/burn-tensor/src/tensor/api/base.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/base.rs).\n- `numeric` => All tensors that are numeric by nature should implement these operations (Add, Sub,\n  Div, etc.). The implementation is in\n  [crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/numeric.rs).\n- `Float` => Tensor operations are only available for float tensors. The implementation is in\n  [burn-tensor/src/tensor/api/float.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/float.rs).\n- `Int` => Tensor operations are only available for int tensors. The implementation is in\n  [burn-tensor/src/tensor/api/int.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/int.rs).\n- `bool` => Tensor operations are only available for bool tensors. The implementation is in\n  [burn-tensor/src/tensor/api/bool.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/api/bool.rs).\n\n`Numeric` is directly implemented for `Float` and `Int` tensors, and in general, The implementations\nfor these methods are calling the corresponding `{Int|Float}` method defined in the backend\nsupertrait.\n\nAnything that is implemented by numeric should have an implementation in the `{Int|Float}` traits,\nthough it may be avoidable if the operation for one type requires casting to the other type. To\nprovide an example, `powf` should be implemented for `Int` tensors, but it should not be an Int\nTensor Operation. The LHS should be converted to a float, and the output should be converted back to\nan int. So it's possible to avoid implementing `IntTensorOp` altogether.\n\nAdditionally there are some operations that should be defined as functions instead of tensor op\nmethods. These are:\n\n`module` => These should be exported as functions instead of methods on tensors. The implementation\nis in\n[crates/burn-tensor/src/tensor/ops/module.rs](https://github.com/tracel-ai/burn/tree/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/modules).\n`activation` => These should also be exported as functions instead of methods on tensors. The\nimplementation is in\n[crates/burn-tensor/src/tensor/ops/activation.rs](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-tensor/src/tensor/ops/activation.rs).\nNote that some activations are just a combination of backend operations and are not declared in\nthere.\n"
  },
  {
    "path": "crates/burn/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Flexible and Comprehensive Deep Learning Framework in Rust\"\ndocumentation = \"https://docs.rs/burn\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn\"\nrust-version = \"1.92\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"std\",\n    \"burn-core/default\",\n    \"burn-train?/default\",\n    \"burn-collective?/default\",\n    # Backends\n    \"burn-candle?/default\",\n    \"burn-cpu?/default\",\n    \"burn-ndarray?/default\",\n    \"burn-tch?/default\",\n    \"burn-wgpu?/default\",\n    \"burn-router?/default\",\n    \"burn-remote?/default\",\n    \"burn-cuda?/default\",\n    \"burn-autodiff?/default\",\n    \"burn-rocm?/default\",\n    \"burn-nn/default\",\n    \"burn-optim/default\",\n    \"burn-dispatch?/default\",\n]\ndoc = [\n    \"default\",\n    \"train\",\n    \"burn-core/doc\",\n    \"burn-train/doc\",\n    \"burn-collective/doc\",\n    \"burn-store?/std\",\n    # Backends\n    \"burn-candle/doc\",\n    \"burn-cpu?/doc\",\n    \"burn-ndarray/doc\",\n    \"burn-tch/doc\",\n    \"burn-wgpu/doc\",\n    \"burn-router/doc\",\n    \"burn-cuda/doc\",\n    \"burn-autodiff?/std\",\n    \"burn-rocm/doc\",\n    \"burn-nn/doc\",\n    \"burn-optim/doc\",\n    \"burn-dispatch?/doc\",\n]\nstd = [\n    \"burn-core/std\",\n    # Backends\n    \"burn-candle?/std\",\n    \"burn-cpu?/std\",\n    \"burn-ndarray?/std\",\n    \"burn-wgpu?/std\",\n    \"burn-router?/std\",\n    \"burn-cuda?/std\",\n    \"burn-autodiff?/std\",\n    \"burn-rocm?/std\",\n    \"burn-store?/std\",\n    \"burn-tch?/std\",\n    \"burn-nn/std\",\n    \"burn-optim/std\",\n    \"burn-dispatch?/std\",\n]\ntracing = [\n    \"cubecl?/tracing\",\n    \"burn-core/tracing\",\n    # Backends\n    \"burn-candle?/tracing\",\n    \"burn-cpu?/tracing\",\n    \"burn-ndarray?/tracing\",\n    \"burn-wgpu?/tracing\",\n    \"burn-router?/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-autodiff?/tracing\",\n    \"burn-rocm?/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-store?/tracing\",\n    \"burn-nn/tracing\",\n    \"burn-optim/tracing\",\n    \"burn-dispatch?/tracing\",\n]\n\n\nnetwork = [\"burn-core/network\"]\n\n# Training with full features\ntrain = [\"burn-train\", \"autodiff\", \"dataset\"]\n\n## Includes the Text UI (progress bars, metric plots)\ntui = [\"burn-train?/tui\"]\n\n##  Includes system info metrics (CPU/GPU usage, etc)\nmetrics = [\"burn-train?/sys-metrics\"]\n\n# Datasets\ndataset = [\"burn-core/dataset\"]\n\nsqlite = [\"burn-core/sqlite\"]\nsqlite-bundled = [\"burn-core/sqlite-bundled\"]\n\n# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.\nrecord-item-custom-serde = [\"burn-core/record-item-custom-serde\"]\n\n# Model storage and serialization (SafeTensors, PyTorch interop)\nstore = [\"burn-store\"]\n\n# CubeCL re-export\ncubecl = [\"dep:cubecl\"]\n\naudio = [\"burn-core/audio\"]\nvision = [\"burn-core/vision\"]\nrl = [\"dep:burn-rl\", \"burn-train?/rl\"]\n\n# Backend\nir = [\"burn-ir\"]\nautodiff = [\"burn-autodiff\", \"burn-dispatch?/autodiff\"]\nfusion = [\n    \"ir\",\n    \"burn-wgpu?/fusion\",\n    \"burn-cuda?/fusion\",\n    \"burn-rocm?/fusion\",\n    \"burn-cpu?/fusion\",\n]\n\n## Backend features\naccelerate = [\"burn-candle?/accelerate\", \"burn-ndarray?/blas-accelerate\"]\nautotune = [\n    \"burn-wgpu?/autotune\",\n    \"burn-cuda?/autotune\",\n    \"burn-rocm?/autotune\",\n    \"burn-cpu?/autotune\",\n]\nautotune-checks = [\n    \"burn-wgpu?/autotune-checks\",\n    \"burn-cuda?/autotune-checks\",\n    \"burn-rocm?/autotune-checks\",\n    \"burn-cpu?/autotune-checks\",\n]\nblas-netlib = [\"burn-ndarray?/blas-netlib\"]\nopenblas = [\"burn-ndarray?/blas-openblas\"]\nopenblas-system = [\"burn-ndarray?/blas-openblas-system\"]\nremote = [\"burn-remote/client\", \"ir\"]\nrouter = [\"burn-router\", \"ir\"]\nserver = [\"burn-remote/server\"]\nsimd = [\"burn-ndarray?/simd\"]\ntemplate = [\"burn-wgpu?/template\"]\ncollective = [\"burn-collective\", \"burn-optim/collective\", \"burn-train?/ddp\"]\n\ncandle = [\"burn-candle\"]\ncandle-cuda = [\"candle\", \"burn-candle/cuda\"]\ncandle-metal = [\"burn-candle?/metal\"]\ncuda = [\"burn-cuda\", \"burn-dispatch?/cuda\"]\nrocm = [\"burn-rocm\", \"burn-dispatch?/rocm\"]\nndarray = [\"burn-ndarray\", \"burn-dispatch?/ndarray\"]\ntch = [\"burn-tch\"]\nvulkan = [\"wgpu\", \"burn-wgpu/vulkan\", \"burn-dispatch?/vulkan\"]\nwebgpu = [\"wgpu\", \"burn-wgpu/webgpu\", \"burn-dispatch?/webgpu\"]\nmetal = [\"wgpu\", \"burn-wgpu/metal\", \"burn-dispatch?/metal\"]\nwgpu = [\"burn-wgpu\"]\ncpu = [\"burn-cpu\", \"burn-dispatch?/cpu\"]\n\n# Backend dispatch\ndispatch = [\"burn-dispatch\"]\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std when std is disabled **\n\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-train = { path = \"../burn-train\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-collective = { path = \"../burn-collective\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-store = { path = \"../burn-store\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-nn = { path = \"../burn-nn\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-optim = { path = \"../burn-optim\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-rl = { path = \"../burn-rl\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\n# Backends\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-candle = { path = \"../burn-candle\", version = \"=0.21.0-pre.2\", optional = true }\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-cpu = { path = \"../burn-cpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-remote = { path = \"../burn-remote\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-dispatch = { path = \"../burn-dispatch\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\ncubecl = { workspace = true, default-features = false, optional = true }\n"
  },
  {
    "path": "crates/burn/src/backend.rs",
    "content": "#[cfg(feature = \"ndarray\")]\npub use burn_ndarray as ndarray;\n\n#[cfg(feature = \"ndarray\")]\npub use ndarray::NdArray;\n\n#[cfg(feature = \"autodiff\")]\npub use burn_autodiff as autodiff;\n\n#[cfg(feature = \"remote\")]\npub use burn_remote as remote;\n#[cfg(feature = \"remote\")]\npub use burn_remote::RemoteBackend;\n\n#[cfg(feature = \"autodiff\")]\npub use burn_autodiff::Autodiff;\n\n#[cfg(feature = \"wgpu\")]\npub use burn_wgpu as wgpu;\n\n#[cfg(feature = \"wgpu\")]\npub use burn_wgpu::Wgpu;\n\n#[cfg(feature = \"webgpu\")]\npub use burn_wgpu::WebGpu;\n\n#[cfg(feature = \"vulkan\")]\npub use burn_wgpu::Vulkan;\n\n#[cfg(feature = \"metal\")]\npub use burn_wgpu::Metal;\n\n#[cfg(feature = \"cuda\")]\npub use burn_cuda as cuda;\n\n#[cfg(feature = \"cuda\")]\npub use burn_cuda::Cuda;\n\n#[cfg(feature = \"candle\")]\npub use burn_candle as candle;\n\n#[cfg(feature = \"candle\")]\npub use burn_candle::Candle;\n\n#[cfg(feature = \"rocm\")]\npub use burn_rocm as rocm;\n\n#[cfg(feature = \"rocm\")]\npub use burn_rocm::Rocm;\n\n#[cfg(feature = \"tch\")]\npub use burn_tch as libtorch;\n\n#[cfg(feature = \"tch\")]\npub use burn_tch::LibTorch;\n\n#[cfg(feature = \"router\")]\npub use burn_router::Router;\n\n#[cfg(feature = \"router\")]\npub use burn_router as router;\n\n#[cfg(feature = \"ir\")]\npub use burn_ir as ir;\n\n#[cfg(feature = \"collective\")]\npub use burn_collective as collective;\n#[cfg(feature = \"cpu\")]\npub use burn_cpu as cpu;\n\n#[cfg(feature = \"cpu\")]\npub use burn_cpu::Cpu;\n"
  },
  {
    "path": "crates/burn/src/collective.rs",
    "content": "pub use burn_collective::*;\n"
  },
  {
    "path": "crates/burn/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n\n//! # Burn\n//!\n//! Burn is a new comprehensive dynamic Deep Learning Framework built using Rust\n//! with extreme flexibility, compute efficiency and portability as its primary goals.\n//!\n//! ## Performance\n//!\n//! Because we believe the goal of a deep learning framework is to convert computation\n//! into useful intelligence, we have made performance a core pillar of Burn.\n//! We strive to achieve top efficiency by leveraging multiple optimization techniques:\n//!\n//! - Automatic kernel fusion\n//! - Asynchronous execution\n//! - Thread-safe building blocks\n//! - Intelligent memory management\n//! - Automatic kernel selection\n//! - Hardware specific features\n//! - Custom Backend Extension\n//!\n//! ## Training & Inference\n//!\n//! The whole deep learning workflow is made easy with Burn, as you can monitor your training progress\n//! with an ergonomic dashboard, and run inference everywhere from embedded devices to large GPU clusters.\n//!\n//! Burn was built from the ground up with training and inference in mind. It's also worth noting how Burn,\n//! in comparison to frameworks like PyTorch, simplifies the transition from training to deployment,\n//! eliminating the need for code changes.\n//!\n//! ## Backends\n//!\n//! Burn strives to be as fast as possible on as many hardwares as possible, with robust implementations.\n//! We believe this flexibility is crucial for modern needs where you may train your models in the cloud,\n//! then deploy on customer hardwares, which vary from user to user.\n//!\n//! Compared to other frameworks, Burn has a very different approach to supporting many backends.\n//! By design, most code is generic over the Backend trait, which allows us to build Burn with swappable backends.\n//! This makes composing backend possible, augmenting them with additional functionalities such as\n//! autodifferentiation and automatic kernel fusion.\n//!\n//! - WGPU (WebGPU): Cross-Platform GPU Backend\n//! - Candle: Backend using the Candle bindings\n//! - LibTorch: Backend using the LibTorch bindings\n//! - NdArray: Backend using the NdArray primitive as data structure\n//! - Autodiff: Backend decorator that brings backpropagation to any backend\n//! - Fusion: Backend decorator that brings kernel fusion to backends that support it\n//!\n//! # Quantization\n//!\n//! Quantization techniques perform computations and store tensors in lower precision data types like\n//! 8-bit integer instead of floating point precision. There are multiple approaches to quantize a deep\n//! learning model categorized as post-training quantization (PTQ) and quantization aware training (QAT).\n//!\n//! In post-training quantization, the model is trained in floating point precision and later converted\n//! to the lower precision data type. There are two types of post-training quantization:\n//!\n//! 1. Static quantization: quantizes the weights and activations of the model. Quantizing the\n//!    activations statically requires data to be calibrated (i.e., recording the activation values to\n//!    compute the optimal quantization parameters with representative data).\n//! 2. Dynamic quantization: quantized the weights ahead of time (like static quantization) but the\n//!    activations are dynamically at runtime.\n//!\n//! Sometimes post-training quantization is not able to achieve acceptable task accuracy. In general,\n//! this is where quantization-aware training (QAT) can be used: during training, fake-quantization\n//! modules are inserted in the forward and backward passes to simulate quantization effects, allowing\n//! the model to learn representations that are more robust to reduced precision.\n//!\n//! Burn does not currently support QAT. Only post-training quantization (PTQ) is implemented at this\n//! time.\n//!\n//! Quantization support in Burn is currently in active development. It supports the following PTQ modes on some backends:\n//! - Per-tensor and per-block quantization to 8-bit, 4-bit and 2-bit representations\n//!\n//! ## Feature Flags\n//!\n//! The following feature flags are available.\n//! By default, the feature `std` is activated.\n//!\n//! - Training\n//!   - `train`: Enables features `dataset` and `autodiff` and provides a training environment\n//!   - `tui`: Includes Text UI with progress bar and plots\n//!   - `metrics`: Includes system info metrics (CPU/GPU usage, etc.)\n//! - Dataset\n//!   - `dataset`: Includes a datasets library\n//!   - `audio`: Enables audio datasets (SpeechCommandsDataset)\n//!   - `sqlite`: Stores datasets in SQLite database\n//!   - `sqlite_bundled`: Use bundled version of SQLite\n//!   - `vision`: Enables vision datasets (MnistDataset)\n//! - Backends\n//!   - `wgpu`: Makes available the WGPU backend\n//!   - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler\n//!   - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler\n//!   - `cuda`: Makes available the CUDA backend\n//!   - `rocm`: Makes available the ROCm backend\n//!   - `candle`: Makes available the Candle backend\n//!   - `tch`: Makes available the LibTorch backend\n//!   - `ndarray`: Makes available the NdArray backend\n//! - Backend specifications\n//!   - `accelerate`: If supported, Accelerate will be used\n//!   - `blas-netlib`: If supported, Blas Netlib will be use\n//!   - `openblas`: If supported, Openblas will be use\n//!   - `openblas-system`: If supported, Openblas installed on the system will be use\n//!   - `autotune`: Enable running benchmarks to select the best kernel in backends that support it.\n//!   - `fusion`: Enable operation fusion in backends that support it.\n//! - Backend decorators\n//!   - `autodiff`: Makes available the Autodiff backend\n//! - Model Storage\n//!   - `store`: Enables model storage with SafeTensors format and PyTorch interoperability\n//! - Others:\n//!   - `std`: Activates the standard library (deactivate for no_std)\n//!   - `server`: Enables the remote server.\n//!   - `network`: Enables network utilities (currently, only a file downloader with progress bar)\n//!\n//! You can also check the details in sub-crates [`burn-core`](https://docs.rs/burn-core) and [`burn-train`](https://docs.rs/burn-train).\n\npub use burn_core::*;\n\n/// Train module\n#[cfg(feature = \"train\")]\npub mod train {\n    pub use burn_train::*;\n}\n\n/// Module for reinforcement learning.\n#[cfg(feature = \"rl\")]\npub mod rl {\n    pub use burn_rl::*;\n}\n\n/// Backend module.\npub mod backend;\n\n#[cfg(feature = \"server\")]\npub use burn_remote::server;\n\n/// Module for collective operations\n#[cfg(feature = \"collective\")]\npub mod collective;\n\n/// Module for model storage and serialization\n#[cfg(feature = \"store\")]\npub mod store {\n    pub use burn_store::*;\n}\n\n/// Neural network module.\npub mod nn {\n    pub use burn_nn::*;\n}\n\n/// Optimizers module.\npub mod optim {\n    pub use burn_optim::*;\n}\n\n// For backward compat, `burn::lr_scheduler::*`\n/// Learning rate scheduler module.\n#[cfg(feature = \"std\")]\npub mod lr_scheduler {\n    pub use burn_optim::lr_scheduler::*;\n}\n// For backward compat, `burn::grad_clipping::*`\n/// Gradient clipping module.\npub mod grad_clipping {\n    pub use burn_optim::grad_clipping::*;\n}\n\n#[cfg(feature = \"dispatch\")]\npub use burn_dispatch::*;\n\n/// CubeCL module re-export.\n#[cfg(feature = \"cubecl\")]\npub mod cubecl {\n    pub use cubecl::*;\n}\n\npub mod prelude {\n    //! Structs and macros used by most projects. Add `use\n    //! burn::prelude::*` to your code to quickly get started with\n    //! Burn.\n    pub use burn_core::prelude::*;\n\n    pub use crate::nn;\n}\n"
  },
  {
    "path": "crates/burn-autodiff/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Automatic differentiation backend for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-autodiff\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-autodiff\"\ndocumentation = \"https://docs.rs/burn-autodiff\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\", \"tracing\"]\nstd = [\"dep:parking_lot\"]\nexport_tests = []         # check checkpointer is_empty in tests\n\ntracing = [\n    \"dep:tracing\",\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n]\n\n[dependencies]\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n\n\nderive-new = { workspace = true }\nspin = { workspace = true }\nparking_lot = { workspace = true, optional = true }\nlog = { workspace = true }\nhashbrown = { workspace = true }\nnum-traits = { workspace = true }\nportable-atomic = { workspace = true }\ntracing = { workspace = true, optional = true, features = [\"default\"] }\n\n\n[package.metadata.docs.rs]\nfeatures = [\"default\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-autodiff/README.md",
    "content": "# Burn Autodiff\n\n> [Burn](https://github.com/tracel-ai/burn) autodiff backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-autodiff.svg)](https://crates.io/crates/burn-autodiff)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-autodiff/blob/master/README.md)\n\nFor now only first order reverse mode autodiff is supported.\n"
  },
  {
    "path": "crates/burn-autodiff/src/backend.rs",
    "content": "use crate::{\n    checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},\n    grads::Gradients,\n    tensor::AutodiffTensor,\n};\nuse alloc::{format, string::String};\nuse burn_backend::{\n    backend::{AutodiffBackend, Backend, ExecutionError},\n    tensor::{BoolTensor, IntTensor, QuantizedTensor},\n};\nuse core::marker::PhantomData;\n\n/// Enable auto-differentiation on a backend.\n///\n/// This works as a backend decorator, extending the functionality of any backend with\n/// backpropagation.\n#[derive(Clone, Copy, Debug, Default)]\npub struct Autodiff<B, C = NoCheckpointing> {\n    _b: PhantomData<B>,\n    _checkpoint_strategy: PhantomData<C>,\n}\n\nimpl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {\n    type Device = B::Device;\n\n    type FloatTensorPrimitive = AutodiffTensor<B>;\n    type FloatElem = B::FloatElem;\n\n    type IntTensorPrimitive = B::IntTensorPrimitive;\n    type IntElem = B::IntElem;\n\n    type BoolTensorPrimitive = B::BoolTensorPrimitive;\n    type BoolElem = B::BoolElem;\n\n    type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        true\n    }\n\n    fn name(device: &Self::Device) -> String {\n        format!(\"autodiff<{}>\", B::name(device))\n    }\n\n    fn seed(device: &B::Device, seed: u64) {\n        B::seed(device, seed)\n    }\n\n    fn sync(device: &B::Device) -> Result<(), ExecutionError> {\n        B::sync(device)\n    }\n\n    fn memory_persistent_allocations<\n        Output: Send,\n        Input: Send,\n        Func: Fn(Input) -> Output + Send,\n    >(\n        device: &Self::Device,\n        input: Input,\n        func: Func,\n    ) -> Output {\n        B::memory_persistent_allocations(device, input, func)\n    }\n\n    fn memory_cleanup(device: &Self::Device) {\n        B::memory_cleanup(device)\n    }\n\n    fn staging<'a, Iter>(data: Iter, device: &Self::Device)\n    where\n        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,\n    {\n        B::staging(data, device);\n    }\n\n    fn supports_dtype(device: &Self::Device, dtype: burn_std::DType) -> bool {\n        B::supports_dtype(device, dtype)\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {\n        B::dtype_usage(device, dtype)\n    }\n}\n\nimpl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {\n    type InnerBackend = B;\n    type Gradients = Gradients;\n\n    fn backward(tensor: AutodiffTensor<B>) -> Gradients {\n        tensor.backward()\n    }\n\n    fn grad(tensor: &AutodiffTensor<B>, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {\n        tensor.grad(grads)\n    }\n\n    fn grad_remove(\n        tensor: &AutodiffTensor<B>,\n        grads: &mut Gradients,\n    ) -> Option<B::FloatTensorPrimitive> {\n        tensor.grad_remove(grads)\n    }\n    fn inner(tensor: AutodiffTensor<B>) -> B::FloatTensorPrimitive {\n        tensor.primitive\n    }\n\n    fn from_inner(tensor: B::FloatTensorPrimitive) -> AutodiffTensor<B> {\n        AutodiffTensor::new(tensor)\n    }\n\n    fn grad_replace(\n        tensor: &AutodiffTensor<B>,\n        grads: &mut Self::Gradients,\n        grad: B::FloatTensorPrimitive,\n    ) {\n        tensor.grad_replace(grads, grad);\n    }\n\n    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend> {\n        tensor\n    }\n\n    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend> {\n        tensor\n    }\n\n    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self> {\n        tensor\n    }\n\n    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self> {\n        tensor\n    }\n\n    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend> {\n        tensor\n    }\n\n    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {\n        tensor\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/base.rs",
    "content": "use super::{\n    retro_forward::RetroForwards,\n    state::{BackwardStates, State},\n};\nuse crate::collections::HashMap;\nuse crate::graph::NodeId;\n\nuse alloc::{vec, vec::Vec};\n\n#[derive(new, Debug)]\n/// Links a [NodeId] to its autodiff graph [NodeRef]\npub(crate) struct NodeTree {\n    map: HashMap<NodeId, Vec<NodeId>>,\n}\n\nimpl NodeTree {\n    /// Gives the parents of the node in the autodiff graph\n    pub(crate) fn parents(&self, node_id: &NodeId) -> Option<Vec<NodeId>> {\n        self.map.get(node_id).cloned()\n    }\n}\n\n#[derive(new, Debug)]\n/// Struct responsible of fetching the output for a node in the autodiff graph during a backward pass\npub struct Checkpointer {\n    backward_states: BackwardStates,\n    retro_forwards: RetroForwards,\n    node_tree: NodeTree,\n}\n\nimpl Checkpointer {\n    /// Gives the output of the given node, by recursively asking parents to compute themselves\n    /// or give their pre-computed tensors.\n    pub fn retrieve_node_output<T>(&mut self, node_id: NodeId) -> T\n    where\n        T: Clone + Send + 'static,\n    {\n        self.topological_sort(node_id).into_iter().for_each(|node| {\n            self.retro_forwards\n                .execute_retro_forward(node, &mut self.backward_states)\n        });\n\n        self.backward_states.get_state::<T>(&node_id)\n    }\n\n    /// Sorts the ancestors of NodeId in a way such that all parents come before their children\n    /// Useful to avoid recursivity later when mutating the states\n    ///\n    /// The sort on a compute bound state or a memory bound that is already computed is trivial.\n    /// The match on State::Computed also serves as a stopping criterion for the sort,\n    /// we don't need to look higher than that during recursivity.\n    fn topological_sort(&self, node_id: NodeId) -> Vec<NodeId> {\n        match self.backward_states.get_state_ref(&node_id) {\n            Some(state) => match state {\n                State::Recompute { n_required: _ } => {\n                    let mut sorted = Vec::new();\n                    let parents = self.node_tree.parents(&node_id).unwrap();\n                    for parent_node in parents {\n                        let parent_sorted = self.topological_sort(parent_node);\n                        for ps in parent_sorted {\n                            if !sorted.contains(&ps) {\n                                sorted.push(ps)\n                            }\n                        }\n                    }\n                    sorted.push(node_id);\n                    sorted\n                }\n                State::Computed {\n                    state_content: _,\n                    n_required: _,\n                } => vec![node_id],\n            },\n            None => panic!(\"Node {node_id:?} is not in the backward_states. \"),\n        }\n    }\n\n    /// Checks if checkpointer has been drained adequately. Useful for testing\n    pub fn is_empty(&self) -> bool {\n        self.backward_states.is_empty() && self.retro_forwards.is_empty()\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/builder.rs",
    "content": "use crate::{\n    collections::HashMap,\n    graph::{ComputingProperty, NodeId},\n    tensor::AutodiffTensor,\n};\nuse alloc::{boxed::Box, sync::Arc, vec::Vec};\nuse burn_backend::Backend;\nuse core::any::Any;\n\nuse super::{\n    base::{Checkpointer, NodeTree},\n    retro_forward::{RetroForward, RetroForwards},\n    state::{BackwardStates, State},\n};\n\n#[derive(Debug)]\n/// Determines if a node should checkpoint its computed output or its retro_forward for recomputation\n/// The action is normally created by the child of the node, once the node is determined to be needed\npub enum CheckpointingAction {\n    /// The node's already computed output should be saved\n    Computed {\n        /// The node\n        node_id: NodeId,\n        /// The node's output\n        state_content: Box<dyn Any + Send>,\n    },\n    /// The node should recompute itself when asked\n    Recompute {\n        /// The node\n        node_id: NodeId,\n        /// How the node should recompute itself\n        retro_forward: Arc<dyn RetroForward>,\n    },\n}\n\n// TODO: Remove that when proper client server.\nunsafe impl Send for CheckpointingAction {}\n\nimpl CheckpointingAction {\n    /// Utility function to access the id of the node of the checkpointing action\n    pub fn id(&self) -> NodeId {\n        match self {\n            CheckpointingAction::Computed {\n                node_id: node_ref,\n                state_content: _,\n            } => *node_ref,\n            CheckpointingAction::Recompute {\n                node_id: node_ref,\n                retro_forward: _,\n            } => *node_ref,\n        }\n    }\n}\n\n#[derive(new, Debug, Default)]\n/// Accumulates checkpoints as checkpointing actions during the forward pass,\n/// and builds a checkpointer right before the backward pass\npub struct CheckpointerBuilder {\n    explicit_actions: Vec<CheckpointingAction>,\n    backup_actions: Vec<CheckpointingAction>,\n}\n\n/// Determines if a checkpoint should impact the n_required values (Main)\n/// or if it should just keep the state in case it's required (Backup)\n///\npub(crate) enum ActionType {\n    /// Explicit actions have been explicitly requested by some operation to retrieve their state\n    Explicit,\n    /// Backup actions are not always needed. They exist to save the output of an operation\n    /// whose child is memory bound, in case the state is indirectly needed when computing\n    /// the child's retro_forward. If no explicit action ever asks for the child's output, then\n    /// the backup output will go out of scope when the checkpointer is built.\n    Backup,\n}\n\nimpl CheckpointerBuilder {\n    pub(crate) fn checkpoint<B: Backend>(\n        &mut self,\n        tensor: &AutodiffTensor<B>,\n        action_type: ActionType,\n    ) {\n        let action_list = match action_type {\n            ActionType::Explicit => &mut self.explicit_actions,\n            ActionType::Backup => &mut self.backup_actions,\n        };\n        match &tensor.node.properties {\n            ComputingProperty::ComputeBound | ComputingProperty::Ambiguous => {\n                action_list.push(CheckpointingAction::Computed {\n                    node_id: tensor.node.id,\n                    state_content: Box::new(tensor.primitive.clone()),\n                })\n            }\n            ComputingProperty::MemoryBound { retro_forward } => {\n                action_list.push(CheckpointingAction::Recompute {\n                    node_id: tensor.node.id,\n                    retro_forward: retro_forward.clone(),\n                })\n            }\n        }\n    }\n\n    pub(crate) fn extend(&mut self, other: CheckpointerBuilder) {\n        for other_action in other.explicit_actions {\n            self.explicit_actions.push(other_action)\n        }\n        for other_unsure in other.backup_actions {\n            self.backup_actions.push(other_unsure)\n        }\n    }\n\n    pub(crate) fn build(self, node_tree: NodeTree) -> Checkpointer {\n        let mut backward_states_map = HashMap::new();\n        let mut retro_forwards_map = HashMap::new();\n\n        // Find recursion stopping points\n        let stop_nodes: Vec<NodeId> = self.find_stop_nodes();\n\n        // We start by identifying how many times each node will be required.\n        let n_required_map = self.build_n_required_map(&node_tree, stop_nodes);\n\n        // Then we checkpoint the nodes with the corresponding n_required value\n        self.insert_checkpoints(\n            &mut backward_states_map,\n            &mut retro_forwards_map,\n            n_required_map,\n        );\n\n        Checkpointer::new(\n            BackwardStates::new(backward_states_map),\n            RetroForwards::new(retro_forwards_map),\n            node_tree,\n        )\n    }\n\n    fn find_stop_nodes(&self) -> Vec<NodeId> {\n        let mut stop_nodes = Vec::default();\n        for action in self\n            .explicit_actions\n            .iter()\n            .chain(self.backup_actions.iter())\n        {\n            match action {\n                CheckpointingAction::Computed {\n                    node_id: node_ref,\n                    state_content: _,\n                } => stop_nodes.push(*node_ref),\n                CheckpointingAction::Recompute {\n                    node_id: _,\n                    retro_forward: _,\n                } => {}\n            }\n        }\n        stop_nodes\n    }\n\n    fn build_n_required_map(\n        &self,\n        node_tree: &NodeTree,\n        stop_nodes: Vec<NodeId>,\n    ) -> HashMap<NodeId, usize> {\n        let mut n_required_map = HashMap::<NodeId, usize>::default();\n\n        for action in self.explicit_actions.iter() {\n            match action {\n                CheckpointingAction::Computed {\n                    node_id: node_ref,\n                    state_content: _,\n                } => {\n                    let id = *node_ref;\n                    match n_required_map.remove(&id) {\n                        Some(n) => {\n                            n_required_map.insert(id, n + 1);\n                        }\n                        None => {\n                            n_required_map.insert(id, 1);\n                        }\n                    };\n                }\n                CheckpointingAction::Recompute {\n                    node_id: node_ref,\n                    retro_forward: _,\n                } => {\n                    let id = *node_ref;\n                    Self::update_n_required_of_parents(\n                        id,\n                        &mut n_required_map,\n                        node_tree,\n                        &stop_nodes,\n                    );\n                }\n            }\n        }\n\n        n_required_map\n    }\n\n    fn insert_checkpoints(\n        mut self,\n        backward_states_map: &mut HashMap<NodeId, State>,\n        retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,\n        n_required_map: HashMap<NodeId, usize>,\n    ) {\n        // We do not loop over checkpointing actions anymore because they can contain\n        // duplicates or miss some that are in backup. We loop over the n_required_map\n        // from which we use the ids to find them again in the checkpointing actions\n        for (node_id, n_required) in n_required_map {\n            // We find the checkpointing action for node_id. It's likely in checkpointing_actions\n            // so we check there first, otherwise it will be in backup.\n            // Technically it can be there several times but can never be of both types, so we can assume the first we find is fine\n\n            let action = match self\n                .explicit_actions\n                .iter()\n                .position(|action| action.id() == node_id)\n            {\n                Some(pos) => self.explicit_actions.remove(pos),\n                None => {\n                    let pos = self\n                        .backup_actions\n                        .iter()\n                        .position(|action| action.id() == node_id);\n                    self.backup_actions.remove(pos.unwrap_or_else(|| {\n                        panic!(\"Node {:?} is needed but never checkpointed\", &node_id)\n                    }))\n                }\n            };\n\n            match action {\n                CheckpointingAction::Computed {\n                    node_id: _,\n                    state_content,\n                } => {\n                    self.checkpoint_compute(backward_states_map, node_id, state_content, n_required)\n                }\n                CheckpointingAction::Recompute {\n                    node_id: _,\n                    retro_forward,\n                } => self.checkpoint_lazy(\n                    backward_states_map,\n                    retro_forward_map,\n                    node_id,\n                    retro_forward,\n                    n_required,\n                ),\n            };\n        }\n    }\n\n    fn update_n_required_of_parents(\n        id: NodeId,\n        n_required_map: &mut HashMap<NodeId, usize>,\n        node_tree: &NodeTree,\n        stop_nodes: &Vec<NodeId>,\n    ) {\n        match n_required_map.remove(&id) {\n            Some(n) => {\n                n_required_map.insert(id, n + 1);\n            }\n            None => {\n                n_required_map.insert(id, 1);\n                if !stop_nodes.contains(&id)\n                    && let Some(parents) = node_tree.parents(&id)\n                {\n                    for p in parents {\n                        Self::update_n_required_of_parents(\n                            p,\n                            n_required_map,\n                            node_tree,\n                            stop_nodes,\n                        );\n                    }\n                }\n            }\n        }\n    }\n\n    fn checkpoint_compute(\n        &self,\n        backward_states_map: &mut HashMap<NodeId, State>,\n        node_id: NodeId,\n        state_content: Box<dyn Any + Send>,\n        n_required: usize,\n    ) {\n        backward_states_map.insert(\n            node_id,\n            State::Computed {\n                state_content,\n                n_required,\n            },\n        );\n    }\n\n    fn checkpoint_lazy(\n        &self,\n        backward_states_map: &mut HashMap<NodeId, State>,\n        retro_forward_map: &mut HashMap<NodeId, Arc<dyn RetroForward>>,\n        node_id: NodeId,\n        retro_forward: Arc<dyn RetroForward>,\n        n_required: usize,\n    ) {\n        retro_forward_map.insert(node_id, retro_forward);\n        backward_states_map.insert(node_id, State::Recompute { n_required });\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/mod.rs",
    "content": "/// Checkpointer module\npub mod base;\npub(crate) mod builder;\n/// RetroForward module\npub mod retro_forward;\n/// BackwardStates module\npub mod state;\n/// CheckpointStrategy module\npub mod strategy;\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/retro_forward.rs",
    "content": "use crate::collections::HashMap;\nuse crate::graph::NodeId;\n\nuse alloc::sync::Arc;\nuse core::fmt::Debug;\n\nuse super::state::{BackwardStates, State};\n\n/// Definition of the forward function of a node, called during retropropagation only.\n/// This is different from the normal forward function because it reads and writes from\n/// the [BackwardStates] map instead of having a clear function signature.\npub trait RetroForward: Debug + Send + 'static {\n    /// Applies the forward pass for retropropagation.\n    fn forward(&self, states: &mut BackwardStates, out_node: NodeId);\n}\n\n#[derive(new, Debug)]\n/// Links [NodeId]s to their corresponding [RetroForward]\npub(crate) struct RetroForwards {\n    map: HashMap<NodeId, Arc<dyn RetroForward>>,\n}\n\nimpl RetroForwards {\n    /// Executes the [RetroForward] for a given [NodeId] if the node's\n    /// [State] is [State::Recompute], otherwise does nothing.\n    pub(crate) fn execute_retro_forward(\n        &mut self,\n        node_id: NodeId,\n        backward_states: &mut BackwardStates,\n    ) {\n        if let State::Recompute { n_required: _ } = backward_states\n            .get_state_ref(&node_id)\n            .unwrap_or_else(|| panic!(\"Should find node {node_id:?}\"))\n        {\n            // Retro forwards are always used only once because afterwards their state is computed\n            let retro_forward = self.map.remove(&node_id).unwrap();\n            retro_forward.forward(backward_states, node_id);\n        }\n    }\n\n    pub(crate) fn is_empty(&self) -> bool {\n        self.map.is_empty()\n    }\n}\n\n#[macro_export]\n/// Creates a RetroForward struct for unary scalar operations\nmacro_rules! retro_unary_scalar {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug, Clone)]\n        struct $name<B: Backend> {\n            lhs_id: NodeId,\n            rhs: Scalar,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for $name<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);\n                let out = $ops(lhs, self.rhs);\n                states.save(out_node, out)\n            }\n        }\n    };\n}\n\n#[macro_export]\n/// Creates a RetroForward struct for unary scalar operations\nmacro_rules! retro_unary {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug, Clone)]\n        struct $name<B: Backend> {\n            input_id: NodeId,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for $name<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = $ops(input);\n                states.save(out_node, out)\n            }\n        }\n    };\n}\n\n#[macro_export]\n/// Creates a RetroForward struct for binary operations\nmacro_rules! retro_binary {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug, Clone)]\n        struct $name<B: Backend> {\n            lhs_id: NodeId,\n            rhs_id: NodeId,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for $name<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);\n                let rhs = states.get_state::<B::FloatTensorPrimitive>(&self.rhs_id);\n                let out = $ops(lhs, rhs);\n                states.save(out_node, out)\n            }\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/state.rs",
    "content": "use core::any::Any;\n\nuse crate::collections::HashMap;\nuse crate::graph::NodeId;\nuse alloc::boxed::Box;\n\n/// In order to accept arbitrary node output in the same hashmap, we need to upcast them to any.\npub(crate) type StateContent = Box<dyn Any + Send>;\n\n#[derive(Debug)]\n/// The state contained at one node. Encapsulates the node output if precomputed,\n/// or clearly asks that it needs to be recomputed from the parents.\n/// Also keeps track of the number of times the state is required so it can be removed\n/// from the map of states on its last use.\npub(crate) enum State {\n    /// The state was not checkpointed, will need to recompute it from the node's parents\n    Recompute { n_required: usize },\n    /// The state was checkpointed or computed during retropropagation and can be directly accessed\n    Computed {\n        state_content: StateContent,\n        n_required: usize,\n    },\n}\n\nimpl State {\n    /// Returns a reference to the (not yet) downcasted node output, if checkpointed\n    pub(crate) fn to_state_content(&self) -> &StateContent {\n        match self {\n            State::Recompute { n_required: _ } => {\n                unreachable!(\n                    \"Can't get state content of recompute state. A child has likely been accessed before its parents.\"\n                )\n            }\n            State::Computed {\n                state_content,\n                n_required: _,\n            } => state_content,\n        }\n    }\n\n    /// Returns a (not yet) downcasted node output, if checkpointed\n    pub(crate) fn into_state_content(self) -> StateContent {\n        match self {\n            State::Recompute { n_required: _ } => {\n                unreachable!(\n                    \"Can't get state content of recompute state. A child has likely been accessed before its parents.\"\n                )\n            }\n            State::Computed {\n                state_content,\n                n_required: _,\n            } => state_content,\n        }\n    }\n\n    /// Returns the number of time the state is required\n    pub(crate) fn n_required(&self) -> usize {\n        match self {\n            State::Recompute { n_required } => *n_required,\n            State::Computed {\n                state_content: _,\n                n_required,\n            } => *n_required,\n        }\n    }\n}\n\n#[derive(new, Default, Debug)]\n/// Links [NodeId]s to their current state\npub struct BackwardStates {\n    map: HashMap<NodeId, State>,\n}\n\nimpl BackwardStates {\n    /// Returns the output in the state of the given [NodeId],\n    /// and decrements the number of times this state is required.\n    /// This function always gives ownership of the output, but will clone it if needed for further uses.\n    pub fn get_state<T>(&mut self, node_id: &NodeId) -> T\n    where\n        T: Clone + Send + 'static,\n    {\n        // Fetch the state and decrement its number of required\n        let state = self.map.remove(node_id).unwrap();\n        let remaining_n_required = state.n_required() - 1;\n\n        // Downcast the state to whatever it is supposed to be\n        // If still needed after giving ownership, we copy it back to the hashmap\n        if remaining_n_required > 0 {\n            let new_stored_state = match state {\n                State::Recompute { n_required: _ } => unreachable!(),\n                State::Computed {\n                    state_content,\n                    n_required: _,\n                } => State::Computed {\n                    state_content,\n                    n_required: remaining_n_required,\n                },\n            };\n\n            let downcasted = new_stored_state\n                .to_state_content()\n                .downcast_ref::<T>()\n                .unwrap()\n                .clone();\n\n            self.insert_state(*node_id, new_stored_state);\n\n            downcasted\n        } else {\n            let downcasted = state.into_state_content().downcast::<T>().unwrap();\n            *downcasted\n        }\n    }\n\n    /// Returns a reference to the [State] of the given node\n    /// Useful when we need [State] information without needing the underlying tensor\n    pub(crate) fn get_state_ref(&self, node_id: &NodeId) -> Option<&State> {\n        self.map.get(node_id)\n    }\n\n    /// Associates a [State] to its [NodeId]\n    pub(crate) fn insert_state(&mut self, node_id: NodeId, state: State) {\n        self.map.insert(node_id, state);\n    }\n\n    /// Saves the output to the state of the given [NodeId].\n    pub fn save<T>(&mut self, node_id: NodeId, saved_output: T)\n    where\n        T: Clone + Send + 'static,\n    {\n        let n_required = self.get_state_ref(&node_id).unwrap().n_required();\n        self.insert_state(\n            node_id,\n            State::Computed {\n                state_content: Box::new(saved_output),\n                n_required,\n            },\n        );\n    }\n\n    pub(crate) fn is_empty(&self) -> bool {\n        self.map.is_empty()\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/checkpoint/strategy.rs",
    "content": "use core::fmt::Debug;\n\nuse burn_backend::Backend;\n\nuse crate::{graph::ComputingProperty, tensor::AutodiffTensor};\nuse alloc::sync::Arc;\n\nuse super::{\n    builder::{ActionType, CheckpointerBuilder},\n    retro_forward::RetroForward,\n};\n\n/// Strategy for the amount of checkpointing to do during autodiff\npub trait CheckpointStrategy: Clone + Copy + Debug + Default + Send + Sync + 'static {\n    /// May modify the compute property depending on the strategy\n    fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty;\n\n    /// Checkpoints parents if necessary in the strategy\n    fn checkpoint_parents<'a, B2, A>(\n        parents: A,\n        builder: &mut CheckpointerBuilder,\n    ) -> Result<(), CheckpointingError>\n    where\n        B2: Backend,\n        A: IntoIterator<Item = &'a AutodiffTensor<B2>>;\n}\n\n#[derive(Debug)]\n/// Error that can happen when trying to checkpoint a tensor.\npub enum CheckpointingError {\n    /// When a parent is untracked, we can't easily checkpoint its state, since we don't know the\n    /// requirements in advanced.\n    UntrackedParent,\n}\n\n#[derive(Clone, Copy, Debug, Default)]\n/// All operations are considered compute bound, notwithstanding how they are marked\npub struct NoCheckpointing {}\n\nimpl CheckpointStrategy for NoCheckpointing {\n    /// An operation marked as memory bound is actually compute bound.\n    fn compute_property<R: RetroForward>(_retro_forward: R) -> ComputingProperty {\n        ComputingProperty::ComputeBound\n    }\n\n    /// An operation marked as memory bound is actually compute bound.\n    /// It's therefore useless to checkpoint the parents\n    fn checkpoint_parents<'a, B2, A>(\n        _parents: A,\n        _builder: &mut CheckpointerBuilder,\n    ) -> Result<(), CheckpointingError>\n    where\n        B2: Backend,\n        A: IntoIterator<Item = &'a AutodiffTensor<B2>>,\n    {\n        // Nothing to do here\n        Ok(())\n    }\n}\n\n#[derive(Clone, Copy, Debug, Default)]\n/// Operation properties are as they are marked (compute or memory bound)\npub struct BalancedCheckpointing {}\n\nimpl CheckpointStrategy for BalancedCheckpointing {\n    /// An operation marked as memory bound is memory bound.\n    /// When memory bound, an operation needs to save its RetroForward\n    fn compute_property<R: RetroForward>(retro_forward: R) -> ComputingProperty {\n        ComputingProperty::MemoryBound {\n            retro_forward: Arc::new(retro_forward),\n        }\n    }\n\n    /// An operation marked as memory bound is really memory bound.\n    /// Since the operation may not checkpoint its parents but may need them indirectly\n    /// if asked to recompute itself, the method needs to know the parent tensors to maybe checkpoint them\n    fn checkpoint_parents<'a, B2, A>(\n        parents: A,\n        builder: &mut CheckpointerBuilder,\n    ) -> Result<(), CheckpointingError>\n    where\n        B2: Backend,\n        A: IntoIterator<Item = &'a AutodiffTensor<B2>>,\n    {\n        let mut can_checkpoint = true;\n\n        for tensor in parents.into_iter() {\n            if let crate::graph::Requirement::None = tensor.node.requirement {\n                can_checkpoint = false;\n            } else {\n                builder.checkpoint(tensor, ActionType::Backup);\n            }\n        }\n\n        if !can_checkpoint {\n            *builder = CheckpointerBuilder::default();\n            return Err(CheckpointingError::UntrackedParent);\n        }\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/grads.rs",
    "content": "use burn_backend::{\n    Backend, TensorMetadata, TensorPrimitive,\n    tensor::{FloatTensor, TensorContainer},\n};\n\nuse crate::{\n    NodeId,\n    graph::{NodeRef, Requirement},\n    tensor::AutodiffTensor,\n};\n\n/// Gradient identifier.\npub type GradID = u64;\n\n/// Gradients container used during the backward pass.\npub struct Gradients {\n    container: TensorContainer<GradID>,\n}\n\nimpl Gradients {\n    /// Creates a new gradients container.\n    pub fn new<B: Backend>(root_node: NodeRef, root_tensor: FloatTensor<B>) -> Self {\n        let mut gradients = Self {\n            container: TensorContainer::new(),\n        };\n        gradients.register::<B>(\n            root_node.id,\n            B::float_ones(\n                root_tensor.shape(),\n                &B::float_device(&root_tensor),\n                root_tensor.dtype().into(),\n            ),\n        );\n        gradients\n    }\n\n    /// Consumes the gradients for a given tensor.\n    ///\n    /// Each tensor should be consumed exactly 1 time if its gradients are only required during the\n    /// backward pass, otherwise, it may be consume multiple times.\n    pub fn consume<B: Backend>(&mut self, node: &NodeRef) -> FloatTensor<B> {\n        match node.requirement {\n            Requirement::Grad => self\n                .container\n                .get::<B>(&node.id.value)\n                .map(|tensor| tensor.tensor())\n                .expect(\"Can't consume the gradients before they are registered at least once.\"),\n            Requirement::GradInBackward => self\n                .container\n                .remove::<B>(&node.id.value)\n                .map(|tensor| tensor.tensor())\n                .expect(\"Can't consume the gradients before they are registered at least once.\"),\n            Requirement::None => panic!(\"Trying to consume the gradients for an untracked tensor\"),\n        }\n    }\n\n    /// Removes a grad tensor from the container.\n    pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {\n        self.container\n            .remove::<B>(&tensor.node.id.value)\n            .map(|tensor| tensor.tensor())\n    }\n\n    /// Gets a grad tensor from the container.\n    pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {\n        self.container\n            .get::<B>(&tensor.node.id.value)\n            .map(|tensor| tensor.tensor())\n    }\n\n    /// Register a grad tensor in the container.\n    ///\n    /// If the tensor already exists, add both tensors together before saving the result.\n    pub fn register<B: Backend>(&mut self, node_id: NodeId, value: FloatTensor<B>) {\n        if let Some(tensor_old) = self.container.remove::<B>(&node_id.value) {\n            self.container.register::<B>(\n                node_id.value,\n                TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())),\n            );\n        } else {\n            self.container\n                .register::<B>(node_id.value, TensorPrimitive::Float(value));\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/base.rs",
    "content": "use super::NodeId;\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients, graph::Parent};\nuse alloc::boxed::Box;\n\n/// Backward step for reverse mode autodiff.\npub trait Step: Send + core::fmt::Debug {\n    /// Executes the step and consumes it.\n    fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer);\n    /// Depth of the operation relative to the first node added to a graph.\n    fn depth(&self) -> usize;\n    /// The node associated to the step.\n    fn node(&self) -> NodeId;\n    /// The parents of the node associated to the step.\n    fn parents(&self) -> &[Parent];\n}\n\npub type StepBoxed = Box<dyn Step>;\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/mod.rs",
    "content": "mod base;\nmod node;\nmod requirement;\n\npub mod traversal;\n\npub use base::*;\npub use node::*;\npub use requirement::*;\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/node.rs",
    "content": "use alloc::{sync::Arc, vec::Vec};\n\n#[cfg(target_has_atomic = \"64\")]\nuse core::sync::atomic::{AtomicU64, Ordering};\n#[cfg(not(target_has_atomic = \"64\"))]\nuse portable_atomic::{AtomicU64, Ordering};\n\nuse crate::checkpoint::retro_forward::RetroForward;\nuse crate::runtime::AutodiffClientImpl;\n\nuse super::Requirement;\n\n#[derive(Debug, Clone)]\npub enum ComputingProperty {\n    ComputeBound,\n    MemoryBound {\n        retro_forward: Arc<dyn RetroForward>,\n    },\n    Ambiguous, // Maybe autotune someday\n}\n\n/// This is safe only because we only call RetroForward on the autodiff server.\n/// Therefore, the trait will never be used by multiple threads at the same time.\n///\n/// TODO: Find a way to avoid cloning the compute property, which will remove the need to add the\n/// Arc, which will make (dyn RetroForward) safely implement Send.\nunsafe impl Send for ComputingProperty {}\n/// unsafe Sync is required because Send is only implemented for Arc<Sync>, not Arc<Send>.\nunsafe impl Sync for ComputingProperty {}\n\n/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.\n#[derive(new, Debug)]\npub struct Node {\n    pub parents: Vec<Parent>,\n    pub order: usize,\n    pub id: NodeId,\n    pub requirement: Requirement,\n    pub properties: ComputingProperty,\n    pub client: AutodiffClientImpl,\n}\npub type NodeRef = Arc<Node>;\n\n#[derive(new, Debug, Clone, PartialEq, Eq)]\npub struct Parent {\n    pub id: NodeId,\n}\n\nimpl Node {\n    /// Returns the [node](Node) only if gradients are required.\n    pub fn clone_if_require_grad(self: &Arc<Self>) -> Option<NodeRef> {\n        match self.requirement.is_none() {\n            true => None,\n            false => Some(self.clone()),\n        }\n    }\n}\n\n/// Unique identifier generated for each node.\n#[derive(Clone, Hash, PartialEq, Eq, Debug, Copy)]\npub struct NodeId {\n    /// The integer representation of the id\n    pub value: u64,\n}\n\nimpl core::fmt::Display for NodeId {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_fmt(format_args!(\"NodeId({})\", self.value))\n    }\n}\n\nimpl NodeId {\n    /// Create a unique [node id](NodeId).\n    pub fn new() -> Self {\n        static COUNTER: AtomicU64 = AtomicU64::new(0);\n        let value = COUNTER.fetch_add(1, Ordering::Relaxed);\n        if value == u64::MAX {\n            panic!(\"NodeId overflowed\");\n        }\n        Self { value }\n    }\n}\n\nimpl Default for NodeId {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/requirement.rs",
    "content": "use super::NodeRef;\n\n/// Requirement for each tensor in the graph.\n#[derive(Debug, Clone, Copy, PartialEq, Eq)]\npub enum Requirement {\n    /// Operations that require gradients.\n    Grad,\n    /// Operations that require gradients only for backprop.\n    GradInBackward,\n    /// Operations that don't need gradients, therefore not to be included in the graph.\n    None,\n}\n\nimpl Requirement {\n    /// Returns true if gradients are not required.\n    pub fn is_none(&self) -> bool {\n        matches!(self, Self::None)\n    }\n    /// Returns the right requirement from a list of nodes.\n    pub fn from_nodes(nodes: &[NodeRef]) -> Self {\n        if nodes.len() == 1 {\n            return nodes[0].requirement.infer(&Requirement::None);\n        }\n\n        nodes\n            .iter()\n            .map(|node| node.requirement)\n            .reduce(|acc, requirement| requirement.infer(&acc))\n            .unwrap_or(Requirement::None)\n    }\n\n    fn infer(&self, other: &Self) -> Self {\n        match self.is_none() && other.is_none() {\n            true => Self::None,\n            false => Self::GradInBackward,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/graph/traversal.rs",
    "content": "use super::{Step, StepBoxed};\nuse crate::{\n    NodeId,\n    collections::{HashMap, HashSet},\n    graph::Parent,\n};\nuse alloc::vec::Vec;\n\n/// Breadth for search algorithm.\npub struct BreadthFirstSearch;\n\npub trait TraversalItem {\n    fn id(&self) -> NodeId;\n    fn parents(&self) -> &[Parent];\n    fn parent_nodes(&self) -> Vec<NodeId> {\n        self.parents().iter().map(|p| p.id).collect()\n    }\n}\n\nimpl BreadthFirstSearch {\n    /// Traverse the graph of backward steps from a root node.\n    pub fn traverse<F, I>(\n        &self,\n        root_id: NodeId,\n        root_step: I,\n        steps: &mut HashMap<NodeId, I>,\n        mut callback: F,\n    ) where\n        F: FnMut(NodeId, I),\n        I: TraversalItem,\n    {\n        let mut visited = HashSet::new();\n        let mut parents = Vec::new();\n\n        visited.insert(root_id);\n        parents.append(&mut root_step.parent_nodes());\n\n        callback(root_id, root_step);\n\n        while let Some(id) = parents.pop() {\n            let step = match steps.remove(&id) {\n                Some(step) => step,\n                None => continue,\n            };\n\n            let step_node = step.id();\n            let step_parents = step.parent_nodes();\n\n            if visited.contains(&step_node) {\n                continue;\n            }\n\n            visited.insert(step_node);\n\n            for id in step_parents.iter() {\n                if !visited.contains(id) {\n                    parents.push(*id);\n                }\n            }\n\n            callback(step_node, step);\n        }\n    }\n}\n\nimpl TraversalItem for StepBoxed {\n    fn id(&self) -> NodeId {\n        Step::node(self.as_ref())\n    }\n\n    fn parents(&self) -> &[Parent] {\n        Step::parents(self.as_ref())\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! # Burn Autodiff\n//!\n//! This autodiff library is a part of the Burn project. It is a standalone crate\n//! that can be used to perform automatic differentiation on tensors. It is\n//! designed to be used with the Burn Tensor crate, but it can be used with any\n//! tensor library that implements the `Backend` trait.\n\n#[macro_use]\nextern crate derive_new;\n\nextern crate alloc;\n\n/// Checkpoint module.\npub mod checkpoint;\n/// Gradients module.\npub mod grads;\n/// Operation module.\npub mod ops;\n\npub(crate) mod graph;\n// Exported for backend extension\npub use graph::NodeId;\npub(crate) mod tensor;\npub(crate) mod utils;\n\nmod backend;\n\npub(crate) mod runtime;\n\npub use backend::*;\n\n/// A facade around for HashMap and HashSet.\n/// This avoids elaborate import wrangling having to happen in every module.\nmod collections {\n    #[cfg(not(feature = \"std\"))]\n    pub use hashbrown::{HashMap, HashSet};\n    #[cfg(feature = \"std\")]\n    pub use std::collections::{HashMap, HashSet};\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/activation.rs",
    "content": "use core::marker::PhantomData;\n\nuse crate::{\n    Autodiff,\n    checkpoint::{\n        base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,\n        strategy::CheckpointStrategy,\n    },\n    grads::Gradients,\n    graph::NodeId,\n    ops::{Backward, Ops, OpsKind, unary},\n    retro_unary,\n};\nuse burn_backend::{Backend, ops::ActivationOps, tensor::FloatTensor};\n\nimpl<B: Backend, C: CheckpointStrategy> ActivationOps<Autodiff<B, C>> for Autodiff<B, C> {\n    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Gelu;\n\n        retro_unary!(RetroGelu, B::gelu);\n\n        impl<B: Backend> Backward<B, 1> for Gelu {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::gelu_backward(input, grad)\n                });\n            }\n        }\n\n        match Gelu\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroGelu::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::gelu(tensor.primitive.clone()))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::gelu(tensor.primitive)),\n        }\n    }\n\n    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Relu;\n\n        retro_unary!(RetroRelu, B::relu);\n\n        impl<B: Backend> Backward<B, 1> for Relu {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let state = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::relu_backward(state, grad)\n                });\n            }\n        }\n\n        match Relu\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRelu::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::relu(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::relu(tensor.primitive)),\n        }\n    }\n\n    fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sigmoid;\n\n        retro_unary!(RetroSigmoid, B::sigmoid);\n\n        impl<B: Backend> Backward<B, 1> for Sigmoid {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                let output = B::sigmoid(input);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::sigmoid_backward(output, grad)\n                });\n            }\n        }\n\n        match Sigmoid\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSigmoid::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::sigmoid(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::sigmoid(tensor.primitive)),\n        }\n    }\n\n    fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct LogSigmoid;\n\n        retro_unary!(RetroLogSigmoid, B::log_sigmoid);\n\n        impl<B: Backend> Backward<B, 1> for LogSigmoid {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::log_sigmoid_backward(input, grad)\n                });\n            }\n        }\n\n        match LogSigmoid\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroLogSigmoid::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::log_sigmoid(tensor.primitive.clone()))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::log_sigmoid(tensor.primitive)),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/backward.rs",
    "content": "use super::{Ops, OpsPrep};\nuse crate::{\n    checkpoint::{base::Checkpointer, builder::CheckpointerBuilder, strategy::CheckpointStrategy},\n    grads::Gradients,\n    graph::{ComputingProperty, NodeRef, Requirement},\n    utils::duplicate,\n};\nuse burn_backend::Backend;\n\n/// Trait for all operations.\n///\n/// # Notes\n///\n/// Concrete types implementing this trait should not have any state.\n/// If a state is necessary during the backward pass,\n/// they should be declared with the associated type 'State'.\npub trait Backward<B, const N: usize>: Send + core::fmt::Debug\nwhere\n    Self: Sized + 'static,\n    B: Backend,\n{\n    /// Associated type to compute the backward pass.\n    type State: Clone + Send + core::fmt::Debug + 'static;\n\n    /// The backward pass.\n    fn backward(\n        self,\n        ops: Ops<Self::State, N>,\n        grads: &mut Gradients,\n        checkpointer: &mut Checkpointer,\n    );\n\n    /// Prepare the backward ops.\n    fn prepare<C: CheckpointStrategy>(\n        self,\n        nodes: [NodeRef; N],\n    ) -> OpsPrep<Self, B, Self::State, C, N> {\n        let requirement = Requirement::from_nodes(&nodes);\n        OpsPrep::new(\n            nodes,\n            requirement,\n            self,\n            ComputingProperty::Ambiguous, // If not specified we start with ambiguous\n            CheckpointerBuilder::default(),\n        )\n    }\n}\n\n/// Execute a binary operation during the backward step.\npub fn binary<B, FLhs, FRhs>(\n    parents: [Option<NodeRef>; 2],\n    node: NodeRef,\n    grads: &mut Gradients,\n    func_lhs: FLhs,\n    func_rhs: FRhs,\n) where\n    B: Backend,\n    FLhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,\n    FRhs: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,\n{\n    let [grad_4lhs, grad_4rhs] = duplicate(&parents, Some(grads.consume::<B>(&node)));\n    let [node_lhs, node_rhs] = parents;\n\n    if let Some(node) = node_lhs {\n        let grad = func_lhs(grad_4lhs.unwrap());\n        grads.register::<B>(node.id, grad)\n    }\n\n    if let Some(node) = node_rhs {\n        let grad = func_rhs(grad_4rhs.unwrap());\n        grads.register::<B>(node.id, grad)\n    }\n}\n\n/// Execute a unary operation during the backward step.\npub fn unary<B, F>(parents: [Option<NodeRef>; 1], node: NodeRef, grads: &mut Gradients, func: F)\nwhere\n    B: Backend,\n    F: FnOnce(B::FloatTensorPrimitive) -> B::FloatTensorPrimitive,\n{\n    let [parent_node] = parents;\n    let grad = grads.consume::<B>(&node);\n\n    if let Some(node) = parent_node {\n        let grad = func(grad);\n        grads.register::<B>(node.id, grad)\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/base.rs",
    "content": "use super::Backward;\nuse crate::{\n    checkpoint::{\n        base::Checkpointer,\n        builder::{ActionType, CheckpointerBuilder},\n        retro_forward::RetroForward,\n        strategy::CheckpointStrategy,\n    },\n    grads::Gradients,\n    graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},\n    tensor::AutodiffTensor,\n};\nuse alloc::boxed::Box;\nuse burn_backend::{Backend, TensorMetadata, tensor::FloatTensor};\nuse burn_std::Shape;\nuse core::marker::PhantomData;\n\n/// Operation in preparation.\n///\n/// Each mode has its own set of functions to minimize cloning for unused backward states.\n#[derive(new)]\npub struct OpsPrep<Backward, B, S, C, const N: usize, Mode = Init> {\n    nodes: [NodeRef; N],\n    requirement: Requirement,\n    backward: Backward,\n    compute_property: ComputingProperty,\n    checkpointer_builder: CheckpointerBuilder,\n    checkpoint_strategy: PhantomData<C>,\n    phantom_backend: PhantomData<B>,\n    phantom_state: PhantomData<S>,\n    marker: PhantomData<Mode>,\n}\n\n/// Operation is initialized\npub struct Init;\n/// Operation has been tagged as memory bound\npub struct MemoryBound;\n/// Memory bound operation has received its RetroForward\npub struct MemoryBoundRetroForward;\n/// Operation's compute property is fixed\npub struct ComputePropertyDone;\n/// Tracked operation tag.\npub struct Tracked;\n/// Untracked operation tag.\npub struct UnTracked;\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Init>\nwhere\n    B: Backend,\n    BO: Backward<B, N, State = S>,\n{\n    /// Indicates that the operation is compute bound, meaning its computation\n    /// is heavy and should not be recomputed\n    pub fn compute_bound(self) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone> {\n        OpsPrep::new(\n            self.nodes,\n            self.requirement,\n            self.backward,\n            ComputingProperty::ComputeBound,\n            self.checkpointer_builder,\n        )\n    }\n\n    /// Indicates that the operation is memory bound, meaning its computation\n    /// is light and can be recomputed\n    pub fn memory_bound(self) -> OpsPrep<BO, B, S, C, N, MemoryBound> {\n        OpsPrep::new(\n            self.nodes,\n            self.requirement,\n            self.backward,\n            self.compute_property,\n            self.checkpointer_builder,\n        )\n    }\n}\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBound>\nwhere\n    B: Backend,\n    BO: Backward<B, N, State = S>,\n    C: CheckpointStrategy,\n{\n    /// Registers the retro forward, if needed\n    pub fn retro_forward<R: RetroForward>(\n        self,\n        retro_forward: R,\n    ) -> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward> {\n        OpsPrep::new(\n            self.nodes,\n            self.requirement,\n            self.backward,\n            C::compute_property(retro_forward),\n            self.checkpointer_builder,\n        )\n    }\n}\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, MemoryBoundRetroForward>\nwhere\n    B: Backend,\n    BO: Backward<B, N, State = S>,\n    C: CheckpointStrategy,\n{\n    /// Checkpoints the parents, if needed\n    pub fn parents<'a, B2, A>(mut self, parents: A) -> OpsPrep<BO, B, S, C, N, ComputePropertyDone>\n    where\n        B2: Backend,\n        A: IntoIterator<Item = &'a AutodiffTensor<B2>>,\n    {\n        let compute_property = match C::checkpoint_parents(parents, &mut self.checkpointer_builder)\n        {\n            Ok(..) => self.compute_property,\n            Err(..) => ComputingProperty::ComputeBound,\n        };\n\n        OpsPrep::new(\n            self.nodes,\n            self.requirement,\n            self.backward,\n            compute_property,\n            self.checkpointer_builder,\n        )\n    }\n}\n\nimpl<BO, B, C, const N: usize> OpsPrep<BO, B, (), C, N, ComputePropertyDone>\nwhere\n    B: Backend,\n    BO: Backward<B, N, State = ()>,\n{\n    /// Prepare a stateless operation.\n    pub fn stateless(self, output: FloatTensor<B>) -> AutodiffTensor<B> {\n        match self.stateful() {\n            OpsKind::Tracked(prep) => prep.finish((), output),\n            OpsKind::UnTracked(prep) => prep.finish(output),\n        }\n    }\n}\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, ComputePropertyDone>\nwhere\n    B: Backend,\n    S: Clone + Send + core::fmt::Debug + 'static,\n    BO: Backward<B, N, State = S>,\n{\n    /// Prepare an operation that requires a state during the backward pass.\n    pub fn stateful(self) -> OpsKind<BO, B, S, C, N> {\n        match self.requirement.is_none() {\n            false => OpsKind::Tracked(OpsPrep::new(\n                self.nodes,\n                self.requirement,\n                self.backward,\n                self.compute_property,\n                self.checkpointer_builder,\n            )),\n            true => OpsKind::UnTracked(OpsPrep::new(\n                self.nodes,\n                self.requirement,\n                self.backward,\n                self.compute_property,\n                self.checkpointer_builder,\n            )),\n        }\n    }\n}\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, UnTracked>\nwhere\n    B: Backend,\n    S: Clone + Send + core::fmt::Debug + 'static,\n    BO: Backward<B, N, State = S>,\n{\n    /// Finish the preparation of an untracked operation and returns the output tensor.\n    pub fn finish(self, output: FloatTensor<B>) -> AutodiffTensor<B> {\n        let output = AutodiffTensor::from_parents(\n            output,\n            &self.nodes,\n            self.requirement,\n            self.compute_property,\n        );\n        let parents = self.nodes.map(|node| node.clone_if_require_grad());\n        let ops = Ops::new(parents, output.node.clone(), ());\n\n        // We register the ops in the graph even if untracked, otherwise memory bound operations\n        // that have an untracked parent would not be able to retrieve it\n        output.register_step(UntrackedOpsStep::new(ops), self.checkpointer_builder)\n    }\n}\n\nimpl<BO, B, S, C, const N: usize> OpsPrep<BO, B, S, C, N, Tracked>\nwhere\n    B: Backend,\n    S: Clone + Send + core::fmt::Debug + 'static,\n    BO: Backward<B, N, State = S>,\n{\n    /// Finish the preparation of a tracked operation and returns the output tensor.\n    pub fn finish(self, state: S, output: FloatTensor<B>) -> AutodiffTensor<B> {\n        let output = AutodiffTensor::from_parents(\n            output,\n            &self.nodes,\n            self.requirement,\n            self.compute_property,\n        );\n        let parents = self.nodes.map(|node| node.clone_if_require_grad());\n        let ops = Ops::new(parents, output.node.clone(), state);\n\n        output.register_step(OpsStep::new(ops, self.backward), self.checkpointer_builder)\n    }\n\n    /// Checkpoints the tensor\n    pub fn checkpoint(&mut self, tensor: &AutodiffTensor<B>) -> NodeId {\n        self.checkpointer_builder\n            .checkpoint(tensor, ActionType::Explicit);\n\n        tensor.node.id\n    }\n}\n\n/// Enum used before finishing tracked and untracked operations.\npub enum OpsKind<BO, B, S, C, const N: usize> {\n    /// Tracked operation preparation.\n    Tracked(OpsPrep<BO, B, S, C, N, Tracked>),\n    /// Untracked operation preparation.\n    UnTracked(OpsPrep<BO, B, S, C, N, UnTracked>),\n}\n\n/// Operation containing its parent nodes, its own node and the backward step state.\n#[derive(new, Debug)]\npub struct Ops<S, const N: usize> {\n    /// Parents nodes.\n    pub parents: [Option<NodeRef>; N],\n    /// The node.\n    pub node: NodeRef,\n    /// The state.\n    pub state: S,\n}\n\n/// Operation implementing backward [step](Step) with type erasing.\n#[derive(new, Debug)]\nstruct OpsStep<B, T, SB, const N: usize>\nwhere\n    B: Backend,\n    T: Backward<B, N, State = SB>,\n    SB: Clone + Send + core::fmt::Debug + 'static,\n{\n    ops: Ops<SB, N>,\n    backward: T,\n    phantom: PhantomData<B>,\n}\n\nimpl<B, T, SB, const N: usize> Step for OpsStep<B, T, SB, N>\nwhere\n    B: Backend,\n    T: Backward<B, N, State = SB>,\n    SB: Clone + Send + core::fmt::Debug + 'static,\n{\n    fn step(self: Box<Self>, grads: &mut Gradients, checkpointer: &mut Checkpointer) {\n        self.backward.backward(self.ops, grads, checkpointer);\n    }\n\n    fn node(&self) -> NodeId {\n        self.ops.node.id\n    }\n\n    fn parents(&self) -> &[Parent] {\n        &self.ops.node.parents\n    }\n\n    fn depth(&self) -> usize {\n        self.ops.node.order\n    }\n}\n\n#[derive(new, Debug)]\nstruct UntrackedOpsStep<const N: usize> {\n    ops: Ops<(), N>,\n}\n\nimpl<const N: usize> Step for UntrackedOpsStep<N> {\n    fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {\n        // Nothing to do\n    }\n\n    fn node(&self) -> NodeId {\n        self.ops.node.id\n    }\n\n    fn parents(&self) -> &[Parent] {\n        &self.ops.node.parents\n    }\n    fn depth(&self) -> usize {\n        self.ops.node.order\n    }\n}\n\n/// Make sure the grad tensor has the given shape.\n///\n/// If broadcasting happened during the forward pass, the gradients will be sum along the\n/// broadcasted dimension.\npub fn broadcast_shape<B: Backend>(mut grad: FloatTensor<B>, shape: &Shape) -> FloatTensor<B> {\n    let shape_grad = grad.shape();\n    let ndims = shape_grad.num_dims();\n\n    for i in 0..ndims {\n        if shape_grad[i] != shape[i] {\n            if shape[i] != 1 {\n                panic!(\n                    \"Invalid broadcast shapes: Next grad shape {:?}, Previous grad shape {:?}. {}\",\n                    shape, shape_grad, \"Expected the shape of the next grad to be 1.\"\n                );\n            }\n            grad = B::float_sum_dim(grad, i);\n        }\n    }\n\n    grad\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/bool_tensor.rs",
    "content": "use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};\nuse alloc::vec::Vec;\n\nuse burn_backend::{\n    Backend, ExecutionError, Scalar, TensorData,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, Device, IntTensor},\n};\nuse burn_std::Shape;\n\nimpl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {\n    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B> {\n        B::bool_from_data(data, device)\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<B>) -> Result<TensorData, ExecutionError> {\n        B::bool_into_data(tensor).await\n    }\n\n    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B> {\n        B::bool_into_int(tensor)\n    }\n\n    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B> {\n        B::bool_to_device(tensor, device)\n    }\n\n    fn bool_device(tensor: &BoolTensor<B>) -> Device<B> {\n        B::bool_device(tensor)\n    }\n\n    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {\n        B::bool_reshape(tensor, shape)\n    }\n\n    fn bool_slice(tensor: BoolTensor<B>, slices: &[burn_std::Slice]) -> BoolTensor<B> {\n        B::bool_slice(tensor, slices)\n    }\n\n    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B> {\n        B::bool_empty(shape, device)\n    }\n\n    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B> {\n        B::bool_zeros(shape, device)\n    }\n\n    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B> {\n        B::bool_ones(shape, device)\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        slices: &[burn_std::Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        B::bool_slice_assign(tensor, slices, value)\n    }\n\n    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {\n        B::bool_cat(tensors, dim)\n    }\n\n    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        B::bool_equal(lhs, rhs)\n    }\n\n    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B> {\n        B::bool_not(tensor)\n    }\n\n    fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        B::bool_and(lhs, rhs)\n    }\n\n    fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        B::bool_or(lhs, rhs)\n    }\n\n    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        B::bool_xor(lhs, rhs)\n    }\n\n    fn bool_into_float(tensor: BoolTensor<B>) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {\n        AutodiffTensor::new(B::bool_into_float(tensor))\n    }\n\n    fn bool_swap_dims(\n        tensor: <Autodiff<B> as Backend>::BoolTensorPrimitive,\n        dim1: usize,\n        dim2: usize,\n    ) -> <Autodiff<B> as Backend>::BoolTensorPrimitive {\n        B::bool_swap_dims(tensor, dim1, dim2)\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        B::bool_permute(tensor, axes)\n    }\n\n    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B> {\n        B::bool_flip(tensor, axes)\n    }\n\n    async fn bool_argwhere(tensor: BoolTensor<B>) -> IntTensor<B> {\n        B::bool_argwhere(tensor).await\n    }\n\n    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B> {\n        B::bool_expand(tensor, shape)\n    }\n\n    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {\n        B::bool_repeat_dim(tensor, dim, times)\n    }\n\n    fn bool_unfold(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> BoolTensor<Self> {\n        B::bool_unfold(tensor, dim, size, step)\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        source: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        B::bool_mask_where(tensor, mask, source)\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        B::bool_mask_fill(tensor, mask, value)\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        B::bool_gather(dim, tensor, indices)\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        B::bool_scatter_or(dim, tensor, indices, value)\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        B::bool_equal_elem(lhs, rhs)\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/int_tensor.rs",
    "content": "use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};\nuse alloc::vec::Vec;\n\nuse burn_backend::{\n    Backend, Distribution, ExecutionError, Scalar, TensorData,\n    ops::IntTensorOps,\n    tensor::{BoolTensor, Device, IntTensor},\n};\nuse burn_std::{IntDType, Shape};\n\nimpl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {\n    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {\n        B::int_from_data(data, device)\n    }\n\n    async fn int_into_data(tensor: IntTensor<B>) -> Result<TensorData, ExecutionError> {\n        B::int_into_data(tensor).await\n    }\n\n    fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {\n        B::int_to_device(tensor, device)\n    }\n\n    fn int_device(tensor: &IntTensor<B>) -> Device<Self> {\n        B::int_device(tensor)\n    }\n\n    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {\n        B::int_reshape(tensor, shape)\n    }\n\n    fn int_slice(tensor: IntTensor<B>, slices: &[burn_std::Slice]) -> IntTensor<B> {\n        B::int_slice(tensor, slices)\n    }\n\n    fn int_empty(\n        shape: Shape,\n        device: &<Autodiff<B> as Backend>::Device,\n        dtype: IntDType,\n    ) -> IntTensor<B> {\n        B::int_empty(shape, device, dtype)\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<B>,\n        slices: &[burn_std::Slice],\n        value: IntTensor<B>,\n    ) -> IntTensor<B> {\n        B::int_slice_assign(tensor, slices, value)\n    }\n\n    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {\n        B::int_cat(tensors, dim)\n    }\n\n    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        B::int_equal(lhs, rhs)\n    }\n\n    fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        B::int_equal_elem(lhs, rhs)\n    }\n\n    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_add(lhs, rhs)\n    }\n\n    fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::int_add_scalar(lhs, rhs)\n    }\n\n    fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {\n        B::int_clamp_min(tensor, min)\n    }\n\n    fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {\n        B::int_clamp_max(tensor, max)\n    }\n\n    fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {\n        B::int_clamp(tensor, min, max)\n    }\n\n    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_sub(lhs, rhs)\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::int_sub_scalar(lhs, rhs)\n    }\n\n    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_mul(lhs, rhs)\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::int_mul_scalar(lhs, rhs)\n    }\n\n    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_div(lhs, rhs)\n    }\n\n    fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::int_div_scalar(lhs, rhs)\n    }\n\n    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_remainder(lhs, rhs)\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::int_remainder_scalar(lhs, rhs)\n    }\n\n    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::int_matmul(lhs, rhs)\n    }\n\n    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {\n        B::int_neg(tensor)\n    }\n\n    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {\n        B::int_zeros(shape, device, dtype)\n    }\n\n    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {\n        B::int_ones(shape, device, dtype)\n    }\n\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<Self>,\n        dtype: IntDType,\n    ) -> IntTensor<B> {\n        B::int_full(shape, fill_value, device, dtype)\n    }\n\n    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {\n        B::int_sum(tensor)\n    }\n\n    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_sum_dim(tensor, dim)\n    }\n\n    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {\n        B::int_mean(tensor)\n    }\n\n    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_mean_dim(tensor, dim)\n    }\n\n    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_cumsum(tensor, dim)\n    }\n\n    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_cumprod(tensor, dim)\n    }\n\n    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_cummin(tensor, dim)\n    }\n\n    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_cummax(tensor, dim)\n    }\n\n    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {\n        B::int_repeat_dim(tensor, dim, times)\n    }\n\n    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        B::int_greater(lhs, rhs)\n    }\n\n    fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        B::int_greater_elem(lhs, rhs)\n    }\n\n    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        B::int_greater_equal(lhs, rhs)\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        B::int_greater_equal_elem(lhs, rhs)\n    }\n\n    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        B::int_lower(lhs, rhs)\n    }\n\n    fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        B::int_lower_elem(lhs, rhs)\n    }\n\n    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        B::int_lower_equal(lhs, rhs)\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        B::int_lower_equal_elem(lhs, rhs)\n    }\n\n    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {\n        B::int_gather(dim, tensor, indices)\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<B>,\n        indices: IntTensor<B>,\n        value: IntTensor<B>,\n    ) -> IntTensor<B> {\n        B::int_scatter_add(dim, tensor, indices, value)\n    }\n\n    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {\n        B::int_select(tensor, dim, indices)\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<B>,\n        dim: usize,\n        indices: IntTensor<B>,\n        value: IntTensor<B>,\n    ) -> IntTensor<B> {\n        B::int_select_add(tensor, dim, indices, value)\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<B>,\n        mask: BoolTensor<B>,\n        value: IntTensor<B>,\n    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {\n        B::int_mask_where(tensor, mask, value)\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<B>,\n        mask: BoolTensor<B>,\n        value: Scalar,\n    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {\n        B::int_mask_fill(tensor, mask, value)\n    }\n\n    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_argmax(tensor, dim)\n    }\n    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_argmin(tensor, dim)\n    }\n    fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {\n        B::int_max(tensor)\n    }\n    fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {\n        B::int_max_dim(tensor, dim)\n    }\n    fn int_max_dim_with_indices(\n        tensor: B::IntTensorPrimitive,\n        dim: usize,\n    ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {\n        B::int_max_dim_with_indices(tensor, dim)\n    }\n    fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {\n        B::int_min(tensor)\n    }\n    fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {\n        B::int_min_dim(tensor, dim)\n    }\n    fn int_min_dim_with_indices(\n        tensor: B::IntTensorPrimitive,\n        dim: usize,\n    ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {\n        B::int_min_dim_with_indices(tensor, dim)\n    }\n    fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {\n        B::int_abs(tensor)\n    }\n    fn int_into_float(\n        tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,\n    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {\n        AutodiffTensor::new(B::int_into_float(tensor))\n    }\n\n    fn int_swap_dims(\n        tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,\n        dim1: usize,\n        dim2: usize,\n    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {\n        B::int_swap_dims(tensor, dim1, dim2)\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> IntTensor<Self> {\n        B::int_random(shape, distribution, device)\n    }\n\n    fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {\n        B::int_arange(range, device)\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        B::int_permute(tensor, axes)\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        B::int_flip(tensor, axes)\n    }\n\n    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        B::int_sign(tensor)\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        B::int_prod(tensor)\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        B::int_prod_dim(tensor, dim)\n    }\n\n    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {\n        B::int_expand(tensor, shape)\n    }\n\n    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        B::int_sort(tensor, dim, descending)\n    }\n\n    fn int_sort_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        descending: bool,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        B::int_sort_with_indices(tensor, dim, descending)\n    }\n\n    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        B::int_argsort(tensor, dim, descending)\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_and(lhs, rhs)\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        B::bitwise_and_scalar(lhs, rhs)\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_or(lhs, rhs)\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        B::bitwise_or_scalar(lhs, rhs)\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_xor(lhs, rhs)\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        B::bitwise_xor_scalar(lhs, rhs)\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_not(tensor)\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_left_shift(lhs, rhs)\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        B::bitwise_left_shift_scalar(lhs, rhs)\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        B::bitwise_right_shift(lhs, rhs)\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        B::bitwise_right_shift_scalar(lhs, rhs)\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        B::int_cast(tensor, dtype)\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        B::int_unfold(tensor, dim, size, step)\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/maxmin.rs",
    "content": "use super::{Backward, Ops, unary};\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients};\nuse burn_backend::{Backend, TensorMetadata};\nuse burn_std::Shape;\n\n#[derive(Debug)]\npub(crate) struct MaxMinDim;\n\nimpl<B: Backend> Backward<B, 1> for MaxMinDim {\n    type State = (B::IntTensorPrimitive, Shape, usize);\n\n    fn backward(\n        self,\n        ops: Ops<Self::State, 1>,\n        grads: &mut Gradients,\n        _checkpointer: &mut Checkpointer,\n    ) {\n        unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n            let (indices, shape, dim) = ops.state;\n            let device = B::float_device(&grad);\n            let dtype = grad.dtype();\n            let zeros = B::float_zeros(shape, &device, dtype.into());\n\n            B::float_scatter_add(dim, zeros, indices, grad)\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/mod.rs",
    "content": "mod activation;\nmod backward;\nmod base;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\n\npub(crate) mod maxmin;\npub(crate) mod sort;\n\npub use backward::*;\npub use base::*;\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/module.rs",
    "content": "use crate::Autodiff;\nuse crate::checkpoint::base::Checkpointer;\nuse crate::checkpoint::strategy::CheckpointStrategy;\nuse crate::grads::Gradients;\nuse crate::graph::NodeId;\nuse crate::ops::{Backward, Ops, unary};\nuse crate::tensor::AutodiffTensor;\n\nuse burn_backend::Backend;\nuse burn_backend::ops::attention::attention_fallback;\nuse burn_backend::ops::*;\nuse burn_backend::tensor::{FloatTensor, IntTensor};\n\nuse super::OpsKind;\n\nimpl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B, C> {\n    fn embedding(weights: AutodiffTensor<B>, indices: IntTensor<B>) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct Embedding;\n\n        impl<B: Backend> Backward<B, 1> for Embedding {\n            type State = (B::FloatTensorPrimitive, IntTensor<B>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (weights, indices) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::embedding_backward(weights, grad, indices)\n                });\n            }\n        }\n\n        match Embedding\n            .prepare::<C>([weights.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (weights.primitive.clone(), indices.clone()),\n                B::embedding(weights.primitive, indices),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)),\n        }\n    }\n\n    fn embedding_backward(\n        _weights: AutodiffTensor<B>,\n        _output: AutodiffTensor<B>,\n        _indices: IntTensor<B>,\n    ) -> AutodiffTensor<B> {\n        panic!(\"Can't differentiate embedding backward.\");\n    }\n\n    fn conv1d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvOptions<1>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct Conv1DWithBias;\n        #[derive(Debug)]\n        struct Conv1DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for Conv1DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvOptions<1>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv1d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv1d_weight_backward(x.clone(), weight, grad.clone(), options);\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv1d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for Conv1DNoBias {\n            type State = (NodeId, NodeId, ConvOptions<1>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv1d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv1d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n        match bias {\n            Some(bias) => match Conv1DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv1d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match Conv1DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv1d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => {\n                    prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))\n                }\n            },\n        }\n    }\n\n    fn conv_transpose1d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvTransposeOptions<1>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct ConvTranspose1DWithBias;\n        #[derive(Debug)]\n        struct ConvTranspose1DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for ConvTranspose1DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<1>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose1d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose1d_weight_backward(\n                        x.clone(),\n                        weight,\n                        grad.clone(),\n                        options,\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv_transpose1d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for ConvTranspose1DNoBias {\n            type State = (NodeId, NodeId, ConvTransposeOptions<1>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose1d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose1d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        match bias {\n            Some(bias) => match ConvTranspose1DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv_transpose1d(\n                            x.primitive,\n                            weight.primitive,\n                            Some(bias.primitive),\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match ConvTranspose1DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv_transpose1d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(\n                    x.primitive,\n                    weight.primitive,\n                    None,\n                    options,\n                )),\n            },\n        }\n    }\n\n    fn conv2d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvOptions<2>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct Conv2DWithBias;\n        #[derive(Debug)]\n        struct Conv2DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for Conv2DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv2d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad =\n                        B::conv2d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv2d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for Conv2DNoBias {\n            type State = (NodeId, NodeId, ConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv2d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv2d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        match bias {\n            Some(bias) => match Conv2DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv2d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match Conv2DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv2d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n\n                OpsKind::UnTracked(prep) => {\n                    prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))\n                }\n            },\n        }\n    }\n\n    fn deform_conv2d(\n        x: AutodiffTensor<B>,\n        offset: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        mask: Option<AutodiffTensor<B>>,\n        bias: Option<AutodiffTensor<B>>,\n        options: DeformConvOptions<2>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct DeformConv2DWithMaskWithBias;\n        #[derive(Debug)]\n        struct DeformConv2DWithMaskNoBias;\n        #[derive(Debug)]\n        struct DeformConv2DNoMaskWithBias;\n        #[derive(Debug)]\n        struct DeformConv2DNoMaskNoBias;\n\n        impl<B: Backend> Backward<B, 5> for DeformConv2DWithMaskWithBias {\n            type State = (NodeId, NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 5>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, offset_state, weight_state, mask_state, bias_state, options) =\n                    ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n                let offset = checkpointer.retrieve_node_output(offset_state);\n                let weight = checkpointer.retrieve_node_output(weight_state);\n                let mask = Some(checkpointer.retrieve_node_output(mask_state));\n                let bias = Some(checkpointer.retrieve_node_output(bias_state));\n\n                let backward =\n                    B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options);\n\n                if let Some(node) = node_x {\n                    grads.register::<B>(node.id, backward.x_grad)\n                }\n                if let Some(node) = node_offset {\n                    grads.register::<B>(node.id, backward.offset_grad)\n                }\n                if let Some(node) = node_weight {\n                    grads.register::<B>(node.id, backward.weight_grad)\n                }\n                if let Some(node) = node_mask {\n                    grads.register::<B>(node.id, backward.mask_grad.unwrap())\n                }\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, backward.bias_grad.unwrap())\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 4> for DeformConv2DWithMaskNoBias {\n            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 4>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_offset, node_weight, node_mask] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, offset_state, weight_state, mask_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n                let offset = checkpointer.retrieve_node_output(offset_state);\n                let weight = checkpointer.retrieve_node_output(weight_state);\n                let mask = Some(checkpointer.retrieve_node_output(mask_state));\n\n                let backward =\n                    B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options);\n\n                if let Some(node) = node_x {\n                    grads.register::<B>(node.id, backward.x_grad)\n                }\n                if let Some(node) = node_offset {\n                    grads.register::<B>(node.id, backward.offset_grad)\n                }\n                if let Some(node) = node_weight {\n                    grads.register::<B>(node.id, backward.weight_grad)\n                }\n                if let Some(node) = node_mask {\n                    grads.register::<B>(node.id, backward.mask_grad.unwrap())\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 4> for DeformConv2DNoMaskWithBias {\n            type State = (NodeId, NodeId, NodeId, NodeId, DeformConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 4>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_offset, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, offset_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n                let offset = checkpointer.retrieve_node_output(offset_state);\n                let weight = checkpointer.retrieve_node_output(weight_state);\n                let bias = Some(checkpointer.retrieve_node_output(bias_state));\n\n                let backward =\n                    B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options);\n\n                if let Some(node) = node_x {\n                    grads.register::<B>(node.id, backward.x_grad)\n                }\n                if let Some(node) = node_offset {\n                    grads.register::<B>(node.id, backward.offset_grad)\n                }\n                if let Some(node) = node_weight {\n                    grads.register::<B>(node.id, backward.weight_grad)\n                }\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, backward.bias_grad.unwrap())\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 3> for DeformConv2DNoMaskNoBias {\n            type State = (NodeId, NodeId, NodeId, DeformConvOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_offset, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, offset_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n                let offset = checkpointer.retrieve_node_output(offset_state);\n                let weight = checkpointer.retrieve_node_output(weight_state);\n\n                let backward =\n                    B::deform_conv2d_backward(x, offset, weight, None, None, grad, options);\n\n                if let Some(node) = node_x {\n                    grads.register::<B>(node.id, backward.x_grad)\n                }\n                if let Some(node) = node_offset {\n                    grads.register::<B>(node.id, backward.offset_grad)\n                }\n                if let Some(node) = node_weight {\n                    grads.register::<B>(node.id, backward.weight_grad)\n                }\n            }\n        }\n\n        match (mask, bias) {\n            (Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias\n                .prepare::<C>([\n                    x.node.clone(),\n                    offset.node.clone(),\n                    weight.node.clone(),\n                    mask.node.clone(),\n                    bias.node.clone(),\n                ])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let offset_state = prep.checkpoint(&offset);\n                    let weight_state = prep.checkpoint(&weight);\n                    let mask_state = prep.checkpoint(&mask);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (\n                            x_state,\n                            offset_state,\n                            weight_state,\n                            mask_state,\n                            bias_state,\n                            options.clone(),\n                        ),\n                        B::deform_conv2d(\n                            x.primitive,\n                            offset.primitive,\n                            weight.primitive,\n                            Some(mask.primitive),\n                            Some(bias.primitive),\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(\n                    x.primitive,\n                    offset.primitive,\n                    weight.primitive,\n                    Some(mask.primitive),\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            (Some(mask), None) => match DeformConv2DWithMaskNoBias\n                .prepare::<C>([\n                    x.node.clone(),\n                    offset.node.clone(),\n                    weight.node.clone(),\n                    mask.node.clone(),\n                ])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let offset_state = prep.checkpoint(&offset);\n                    let weight_state = prep.checkpoint(&weight);\n                    let mask_state = prep.checkpoint(&mask);\n                    prep.finish(\n                        (\n                            x_state,\n                            offset_state,\n                            weight_state,\n                            mask_state,\n                            options.clone(),\n                        ),\n                        B::deform_conv2d(\n                            x.primitive,\n                            offset.primitive,\n                            weight.primitive,\n                            Some(mask.primitive),\n                            None,\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(\n                    x.primitive,\n                    offset.primitive,\n                    weight.primitive,\n                    Some(mask.primitive),\n                    None,\n                    options,\n                )),\n            },\n            (None, Some(bias)) => match DeformConv2DNoMaskWithBias\n                .prepare::<C>([\n                    x.node.clone(),\n                    offset.node.clone(),\n                    weight.node.clone(),\n                    bias.node.clone(),\n                ])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let offset_state = prep.checkpoint(&offset);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (\n                            x_state,\n                            offset_state,\n                            weight_state,\n                            bias_state,\n                            options.clone(),\n                        ),\n                        B::deform_conv2d(\n                            x.primitive,\n                            offset.primitive,\n                            weight.primitive,\n                            None,\n                            Some(bias.primitive),\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(\n                    x.primitive,\n                    offset.primitive,\n                    weight.primitive,\n                    None,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            (None, None) => match DeformConv2DNoMaskNoBias\n                .prepare::<C>([x.node.clone(), offset.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let offset_state = prep.checkpoint(&offset);\n                    let weight_state = prep.checkpoint(&weight);\n                    prep.finish(\n                        (x_state, offset_state, weight_state, options.clone()),\n                        B::deform_conv2d(\n                            x.primitive,\n                            offset.primitive,\n                            weight.primitive,\n                            None,\n                            None,\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(\n                    x.primitive,\n                    offset.primitive,\n                    weight.primitive,\n                    None,\n                    None,\n                    options,\n                )),\n            },\n        }\n    }\n\n    fn deform_conv2d_backward(\n        _x: AutodiffTensor<B>,\n        _offset: AutodiffTensor<B>,\n        _weight: AutodiffTensor<B>,\n        _mask: Option<AutodiffTensor<B>>,\n        _bias: Option<AutodiffTensor<B>>,\n        _output_grad: AutodiffTensor<B>,\n        _options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        panic!(\"Can't differentiate deform conv 2d backward.\");\n    }\n\n    fn conv_transpose2d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvTransposeOptions<2>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct ConvTranspose2DWithBias;\n        #[derive(Debug)]\n        struct ConvTranspose2DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for ConvTranspose2DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose2d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose2d_weight_backward(\n                        x.clone(),\n                        weight,\n                        grad.clone(),\n                        options,\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv_transpose2d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for ConvTranspose2DNoBias {\n            type State = (NodeId, NodeId, ConvTransposeOptions<2>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose2d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose2d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        match bias {\n            Some(bias) => match ConvTranspose2DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv_transpose2d(\n                            x.primitive,\n                            weight.primitive,\n                            Some(bias.primitive),\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match ConvTranspose2DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv_transpose2d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(\n                    x.primitive,\n                    weight.primitive,\n                    None,\n                    options,\n                )),\n            },\n        }\n    }\n\n    fn conv3d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvOptions<3>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct Conv3DWithBias;\n        #[derive(Debug)]\n        struct Conv3DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for Conv3DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvOptions<3>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv3d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad =\n                        B::conv3d_weight_backward(x.clone(), weight.clone(), grad.clone(), options);\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv3d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for Conv3DNoBias {\n            type State = (NodeId, NodeId, ConvOptions<3>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv3d_x_backward(\n                        x.clone(),\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv3d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        match bias {\n            Some(bias) => match Conv3DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv3d(x.primitive, weight.primitive, Some(bias.primitive), options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv3d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match Conv3DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv3d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n\n                OpsKind::UnTracked(prep) => {\n                    prep.finish(B::conv3d(x.primitive, weight.primitive, None, options))\n                }\n            },\n        }\n    }\n\n    fn conv_transpose3d(\n        x: AutodiffTensor<B>,\n        weight: AutodiffTensor<B>,\n        bias: Option<AutodiffTensor<B>>,\n        options: ConvTransposeOptions<3>,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct ConvTranspose3DWithBias;\n        #[derive(Debug)]\n        struct ConvTranspose3DNoBias;\n\n        impl<B: Backend> Backward<B, 3> for ConvTranspose3DWithBias {\n            type State = (NodeId, NodeId, NodeId, ConvTransposeOptions<3>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight, node_bias] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, bias_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n                let bias = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(bias_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose3d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose3d_weight_backward(\n                        x.clone(),\n                        weight,\n                        grad.clone(),\n                        options,\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_bias {\n                    let grad = B::conv_transpose3d_bias_backward(x, bias, grad);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for ConvTranspose3DNoBias {\n            type State = (NodeId, NodeId, ConvTransposeOptions<3>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_x, node_weight] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, weight_state, options) = ops.state;\n                let x = checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(x_state);\n                let weight =\n                    checkpointer.retrieve_node_output::<B::FloatTensorPrimitive>(weight_state);\n\n                if let Some(node) = node_x {\n                    let grad = B::conv_transpose3d_x_backward(\n                        weight.clone(),\n                        grad.clone(),\n                        options.clone(),\n                    );\n                    grads.register::<B>(node.id, grad)\n                }\n                if let Some(node) = node_weight {\n                    let grad = B::conv_transpose3d_weight_backward(x, weight, grad, options);\n                    grads.register::<B>(node.id, grad)\n                }\n            }\n        }\n\n        match bias {\n            Some(bias) => match ConvTranspose3DWithBias\n                .prepare::<C>([x.node.clone(), weight.node.clone(), bias.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n                    let bias_state = prep.checkpoint(&bias);\n\n                    prep.finish(\n                        (x_state, weight_state, bias_state, options.clone()),\n                        B::conv_transpose3d(\n                            x.primitive,\n                            weight.primitive,\n                            Some(bias.primitive),\n                            options,\n                        ),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(\n                    x.primitive,\n                    weight.primitive,\n                    Some(bias.primitive),\n                    options,\n                )),\n            },\n            None => match ConvTranspose3DNoBias\n                .prepare::<C>([x.node.clone(), weight.node.clone()])\n                .compute_bound()\n                .stateful()\n            {\n                OpsKind::Tracked(mut prep) => {\n                    let x_state = prep.checkpoint(&x);\n                    let weight_state = prep.checkpoint(&weight);\n\n                    prep.finish(\n                        (x_state, weight_state, options.clone()),\n                        B::conv_transpose3d(x.primitive, weight.primitive, None, options),\n                    )\n                }\n                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose3d(\n                    x.primitive,\n                    weight.primitive,\n                    None,\n                    options,\n                )),\n            },\n        }\n    }\n\n    // TODO: Support a custom unfold4d operation by overriding the default implementation.\n    //\n    // We don't override it now because the fold operation isn't available for the backward pass.\n    // This implies that when autodiff is enabled, custom unfold operations defined by backends\n    // won't be used. Instead, the conv2d operation with custom weights matrix will be used.\n    // Therefore, the conv2d backward pass will be used for the unfold4d backward pass.\n    //\n    // fn unfold4d(\n    //     x:AutodiffTensor<B>,\n    //     kernel_size: [usize; 2],\n    //     options: UnfoldOptions,\n    // ) -> AutodiffTensor<B> {\n    //     todo!()\n    // }\n\n    fn avg_pool1d(\n        x: AutodiffTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct AvgPool1D;\n\n        impl<B: Backend> Backward<B, 1> for AvgPool1D {\n            type State = (NodeId, usize, usize, usize, bool, bool);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_parent] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =\n                    ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n\n                if let Some(node) = node_parent {\n                    let grad = B::avg_pool1d_backward(\n                        x,\n                        grad,\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    );\n                    grads.register::<B>(node.id, grad);\n                }\n            }\n        }\n\n        match AvgPool1D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                prep.finish(\n                    (\n                        x_state,\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    ),\n                    B::avg_pool1d(\n                        x.primitive.clone(),\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    ),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool1d(\n                x.primitive,\n                kernel_size,\n                stride,\n                padding,\n                count_include_pad,\n                ceil_mode,\n            )),\n        }\n    }\n\n    fn avg_pool2d(\n        x: AutodiffTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct AvgPool2D;\n\n        impl<B: Backend> Backward<B, 1> for AvgPool2D {\n            type State = (NodeId, [usize; 2], [usize; 2], [usize; 2], bool, bool);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_parent] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n                let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =\n                    ops.state;\n                let x = checkpointer.retrieve_node_output(x_state);\n\n                if let Some(node) = node_parent {\n                    let grad = B::avg_pool2d_backward(\n                        x,\n                        grad,\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    );\n                    grads.register::<B>(node.id, grad);\n                }\n            }\n        }\n\n        match AvgPool2D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                prep.finish(\n                    (\n                        x_state,\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    ),\n                    B::avg_pool2d(\n                        x.primitive.clone(),\n                        kernel_size,\n                        stride,\n                        padding,\n                        count_include_pad,\n                        ceil_mode,\n                    ),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::avg_pool2d(\n                x.primitive,\n                kernel_size,\n                stride,\n                padding,\n                count_include_pad,\n                ceil_mode,\n            )),\n        }\n    }\n\n    fn avg_pool2d_backward(\n        _x: AutodiffTensor<B>,\n        _grad: AutodiffTensor<B>,\n        _kernel_size: [usize; 2],\n        _stride: [usize; 2],\n        _padding: [usize; 2],\n        _count_include_pad: bool,\n        _ceil_mode: bool,\n    ) -> AutodiffTensor<B> {\n        panic!(\"Can't differentiate avg pool 2d backward.\");\n    }\n\n    fn max_pool1d(\n        x: AutodiffTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> AutodiffTensor<B> {\n        match MaxPool1D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                let output = B::max_pool1d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n                prep.finish(\n                    (\n                        x_state,\n                        output.indices,\n                        kernel_size,\n                        stride,\n                        padding,\n                        dilation,\n                        ceil_mode,\n                    ),\n                    output.output,\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::max_pool1d(\n                x.primitive,\n                kernel_size,\n                stride,\n                padding,\n                dilation,\n                ceil_mode,\n            )),\n        }\n    }\n\n    fn max_pool1d_with_indices(\n        x: AutodiffTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<Self> {\n        match MaxPool1D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                let output = B::max_pool1d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n\n                let output_tensor = prep.finish(\n                    (\n                        x_state,\n                        output.indices.clone(),\n                        kernel_size,\n                        stride,\n                        padding,\n                        dilation,\n                        ceil_mode,\n                    ),\n                    output.output,\n                );\n\n                MaxPool1dWithIndices::new(output_tensor, output.indices)\n            }\n            OpsKind::UnTracked(prep) => {\n                let output = B::max_pool1d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n                let output_tensor = prep.finish(output.output);\n\n                MaxPool1dWithIndices::new(output_tensor, output.indices)\n            }\n        }\n    }\n\n    fn max_pool1d_with_indices_backward(\n        x: AutodiffTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        output_grad: AutodiffTensor<B>,\n        indices: IntTensor<B>,\n    ) -> MaxPool1dBackward<Self> {\n        let output = B::max_pool1d_with_indices_backward(\n            x.primitive,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            output_grad.primitive,\n            indices,\n        );\n        MaxPool1dBackward::new(AutodiffTensor::new(output.x_grad))\n    }\n\n    fn max_pool2d(\n        x: AutodiffTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> AutodiffTensor<B> {\n        match MaxPool2D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                let output = B::max_pool2d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n                prep.finish(\n                    (\n                        x_state,\n                        output.indices,\n                        kernel_size,\n                        stride,\n                        padding,\n                        dilation,\n                        ceil_mode,\n                    ),\n                    output.output,\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::max_pool2d(\n                x.primitive,\n                kernel_size,\n                stride,\n                padding,\n                dilation,\n                ceil_mode,\n            )),\n        }\n    }\n\n    fn max_pool2d_with_indices(\n        x: AutodiffTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        match MaxPool2D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n\n                let output = B::max_pool2d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n\n                let output_tensor = prep.finish(\n                    (\n                        x_state,\n                        output.indices.clone(),\n                        kernel_size,\n                        stride,\n                        padding,\n                        dilation,\n                        ceil_mode,\n                    ),\n                    output.output,\n                );\n\n                MaxPool2dWithIndices::new(output_tensor, output.indices)\n            }\n            OpsKind::UnTracked(prep) => {\n                let output = B::max_pool2d_with_indices(\n                    x.primitive,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                );\n                let output_tensor = prep.finish(output.output);\n\n                MaxPool2dWithIndices::new(output_tensor, output.indices)\n            }\n        }\n    }\n\n    fn max_pool2d_with_indices_backward(\n        _x: AutodiffTensor<B>,\n        _kernel_size: [usize; 2],\n        _stride: [usize; 2],\n        _padding: [usize; 2],\n        _dilation: [usize; 2],\n        _ceil_mode: bool,\n        _output_grad: AutodiffTensor<B>,\n        _indices: IntTensor<B>,\n    ) -> MaxPool2dBackward<Self> {\n        panic!(\"Can't differentiate max pool2d with indices backward.\");\n    }\n    fn adaptive_avg_pool1d(x: AutodiffTensor<B>, output_size: usize) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct AdaptiveAvgPool1D;\n\n        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool1D {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_parent] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n                let state = checkpointer.retrieve_node_output(ops.state);\n\n                if let Some(node) = node_parent {\n                    let grad = B::adaptive_avg_pool1d_backward(state, grad);\n                    grads.register::<B>(node.id, grad);\n                }\n            }\n        }\n\n        match AdaptiveAvgPool1D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                prep.finish(x_state, B::adaptive_avg_pool1d(x.primitive, output_size))\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::adaptive_avg_pool1d(x.primitive, output_size))\n            }\n        }\n    }\n\n    fn adaptive_avg_pool2d(x: AutodiffTensor<B>, output_size: [usize; 2]) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct AdaptiveAvgPool2D;\n\n        impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool2D {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_parent] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n                let state = checkpointer.retrieve_node_output(ops.state);\n\n                if let Some(node) = node_parent {\n                    let grad = B::adaptive_avg_pool2d_backward(state, grad);\n                    grads.register::<B>(node.id, grad);\n                }\n            }\n        }\n\n        match AdaptiveAvgPool2D\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                prep.finish(x_state, B::adaptive_avg_pool2d(x.primitive, output_size))\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::adaptive_avg_pool2d(x.primitive, output_size))\n            }\n        }\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        _x: AutodiffTensor<B>,\n        _grad: AutodiffTensor<B>,\n    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {\n        panic!(\"Can't differentiate adaptive avg pool2d backward.\");\n    }\n\n    fn interpolate(\n        x: AutodiffTensor<B>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> AutodiffTensor<B> {\n        #[derive(Debug)]\n        struct Interpolate;\n        impl<B: Backend> Backward<B, 1> for Interpolate {\n            type State = (NodeId, [usize; 2], InterpolateOptions);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let [node_parent] = ops.parents;\n                let grad = grads.consume::<B>(&ops.node);\n\n                let (x_state, output_size, options) = ops.state;\n                let state = checkpointer.retrieve_node_output(x_state);\n\n                if let Some(node) = node_parent {\n                    let grad = B::interpolate_backward(state, grad, output_size, options);\n                    grads.register::<B>(node.id, grad);\n                }\n            }\n        }\n\n        match Interpolate\n            .prepare::<C>([x.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let x_state = prep.checkpoint(&x);\n                let output = B::interpolate(x.primitive.clone(), output_size, options.clone());\n                prep.finish((x_state, output_size, options), output)\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::interpolate(x.primitive, output_size, options))\n            }\n        }\n    }\n\n    fn interpolate_backward(\n        _x: FloatTensor<Autodiff<B, C>>,\n        _grad: FloatTensor<Autodiff<B, C>>,\n        _output_size: [usize; 2],\n        _options: InterpolateOptions,\n    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {\n        panic!(\"Can't differentiate interpolate backward.\");\n    }\n\n    fn attention(\n        query: FloatTensor<Autodiff<B, C>>,\n        key: FloatTensor<Autodiff<B, C>>,\n        value: FloatTensor<Autodiff<B, C>>,\n        mask: Option<burn_backend::tensor::BoolTensor<Autodiff<B, C>>>,\n        attn_bias: Option<FloatTensor<Autodiff<B, C>>>,\n        options: AttentionModuleOptions,\n    ) -> FloatTensor<Autodiff<B, C>> {\n        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)\n    }\n}\n\n#[derive(Debug)]\nstruct MaxPool1D;\n\nimpl<B: Backend> Backward<B, 1> for MaxPool1D {\n    type State = (NodeId, IntTensor<B>, usize, usize, usize, usize, bool);\n\n    fn backward(\n        self,\n        ops: Ops<Self::State, 1>,\n        grads: &mut Gradients,\n        checkpointer: &mut Checkpointer,\n    ) {\n        let [node_parent] = ops.parents;\n        let grad = grads.consume::<B>(&ops.node);\n        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;\n        let x = checkpointer.retrieve_node_output(x_state);\n\n        if let Some(node) = node_parent {\n            let grad = B::max_pool1d_with_indices_backward(\n                x,\n                kernel_size,\n                stride,\n                padding,\n                dilation,\n                ceil_mode,\n                grad,\n                indices,\n            );\n\n            grads.register::<B>(node.id, grad.x_grad);\n        }\n    }\n}\n\n#[derive(Debug)]\nstruct MaxPool2D;\n\nimpl<B: Backend> Backward<B, 1> for MaxPool2D {\n    type State = (\n        NodeId,\n        IntTensor<B>,\n        [usize; 2],\n        [usize; 2],\n        [usize; 2],\n        [usize; 2],\n        bool,\n    );\n\n    fn backward(\n        self,\n        ops: Ops<Self::State, 1>,\n        grads: &mut Gradients,\n        checkpointer: &mut Checkpointer,\n    ) {\n        let [node_parent] = ops.parents;\n        let grad = grads.consume::<B>(&ops.node);\n        let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;\n        let x = checkpointer.retrieve_node_output(x_state);\n\n        if let Some(node) = node_parent {\n            let grad = B::max_pool2d_with_indices_backward(\n                x,\n                kernel_size,\n                stride,\n                padding,\n                dilation,\n                ceil_mode,\n                grad,\n                indices,\n            );\n\n            grads.register::<B>(node.id, grad.x_grad);\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    Backend, ExecutionError, TensorData,\n    ops::QTensorOps,\n    tensor::{\n        Device, FloatTensor, IntTensor, QuantizedTensor,\n        quantization::QuantizationParametersPrimitive,\n    },\n};\nuse burn_std::{QuantScheme, Shape};\n\nuse crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};\n\nimpl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {\n    fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {\n        todo!()\n    }\n\n    fn quantize(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n        _qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        todo!() // required for QAT\n    }\n\n    fn quantize_dynamic(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n    ) -> QuantizedTensor<Self> {\n        todo!()\n    }\n\n    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        todo!()\n    }\n\n    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {\n        B::q_device(tensor)\n    }\n\n    fn q_to_device(\n        _tensor: QuantizedTensor<Self>,\n        _device: &Device<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        B::q_reshape(tensor, shape)\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        B::q_into_data(tensor).await\n    }\n\n    fn q_swap_dims(\n        _tensor: QuantizedTensor<Self>,\n        _dim1: usize,\n        _dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_gather(\n        _dim: usize,\n        _tensor: QuantizedTensor<Self>,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_select(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_slice(\n        _tensor: QuantizedTensor<Self>,\n        _slices: &[burn_std::Slice],\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {\n        B::q_argmax(tensor, dim)\n    }\n\n    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {\n        B::q_argmin(tensor, dim)\n    }\n\n    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/sort.rs",
    "content": "use super::{Backward, Ops, unary};\nuse crate::{checkpoint::base::Checkpointer, grads::Gradients};\nuse burn_backend::{Backend, TensorMetadata};\nuse burn_std::Shape;\n\n#[derive(Debug)]\npub(crate) struct SortDim;\n\nimpl<B: Backend> Backward<B, 1> for SortDim {\n    type State = (B::IntTensorPrimitive, Shape, usize);\n\n    fn backward(\n        self,\n        ops: Ops<Self::State, 1>,\n        grads: &mut Gradients,\n        _checkpointer: &mut Checkpointer,\n    ) {\n        unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n            let (indices, shape, dim) = ops.state;\n            let device = B::float_device(&grad);\n            let dtype = grad.dtype();\n            let zeros = B::float_zeros(shape, &device, dtype.into());\n\n            B::float_scatter_add(dim, zeros, indices, grad)\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/tensor.rs",
    "content": "use alloc::{boxed::Box, vec, vec::Vec};\nuse core::marker::PhantomData;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports, reason = \"required on aarch64, unused on x86_64\")]\nuse num_traits::float::Float;\n\nuse crate::{\n    Autodiff,\n    checkpoint::{\n        base::Checkpointer, builder::CheckpointerBuilder, retro_forward::RetroForward,\n        state::BackwardStates, strategy::CheckpointStrategy,\n    },\n    grads::Gradients,\n    graph::{ComputingProperty, NodeId, NodeRef, Parent, Requirement, Step},\n    ops::{Backward, Ops, OpsKind, binary, broadcast_shape, unary},\n    retro_binary, retro_unary, retro_unary_scalar,\n    tensor::AutodiffTensor,\n    utils::duplicate,\n};\n\nuse burn_backend::{\n    Backend, ExecutionError, TensorData, TensorMetadata,\n    ops::FloatTensorOps,\n    tensor::{BoolTensor, Device, FloatTensor, IntTensor},\n};\nuse burn_backend::{Scalar, ops::unfold::calculate_unfold_windows};\nuse burn_std::{FloatDType, Shape, Slice};\n\nuse super::maxmin::MaxMinDim;\n\n// Unsqueeze op on primitive.\nfn unsqueeze_like<B: Backend>(\n    tensor: B::FloatTensorPrimitive,\n    shape: Shape,\n) -> B::FloatTensorPrimitive {\n    let ndims_out = shape.num_dims();\n    let shape = tensor.shape();\n    let ndims_in = shape.num_dims();\n\n    let mut dims = vec![1; ndims_out];\n    let num_ones = ndims_out - ndims_in;\n    dims[num_ones..(ndims_in + num_ones)].copy_from_slice(&shape[..ndims_in]);\n\n    B::float_reshape(tensor, Shape::from(dims))\n}\n\nimpl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C> {\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(data),\n        fields(?data.shape, ?data.dtype)\n    ))]\n    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {\n        AutodiffTensor::new(B::float_from_data(data, device))\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: burn_backend::Distribution,\n        device: &Device<Self>,\n    ) -> FloatTensor<Self> {\n        AutodiffTensor::new(B::float_random(shape, distribution, device))\n    }\n\n    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        AutodiffTensor::new(B::float_zeros(shape, device, dtype))\n    }\n\n    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        AutodiffTensor::new(B::float_ones(shape, device, dtype))\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(\n            from = ?tensor.node,\n            shape = ?tensor.shape(),\n            dtype = ?tensor.dtype(),\n        )\n    ))]\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        B::float_into_data(tensor.primitive).await\n    }\n\n    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {\n        B::float_device(&tensor.primitive)\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(\n            from = ?tensor.node,\n            shape = ?tensor.shape(),\n            dtype = ?tensor.dtype(),\n        )\n    ))]\n    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct ToDevice;\n\n        impl<B: Backend> Backward<B, 1> for ToDevice {\n            type State = B::Device;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_to_device(grad, &ops.state)\n                });\n            }\n        }\n\n        match ToDevice\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let device_old = B::float_device(&tensor.primitive);\n                prep.finish(device_old, B::float_to_device(tensor.primitive, device))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_to_device(tensor.primitive, device)),\n        }\n    }\n\n    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        AutodiffTensor::new(B::float_empty(shape, device, dtype))\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Add;\n\n        retro_binary!(RetroAdd, B::float_add);\n\n        impl<B: Backend> Backward<B, 2> for Add {\n            type State = (Shape, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape_lhs, shape_rhs) = ops.state;\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| broadcast_shape::<B>(grad, &shape_lhs),\n                    |grad| broadcast_shape::<B>(grad, &shape_rhs),\n                );\n            }\n        }\n\n        match Add\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAdd::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (lhs.primitive.shape(), rhs.primitive.shape()),\n                B::float_add(lhs.primitive, rhs.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_add(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct AddScalar;\n\n        retro_unary_scalar!(RetroAddScalar, B::float_add_scalar);\n\n        impl<B: Backend> Backward<B, 1> for AddScalar {\n            type State = ();\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| grad);\n            }\n        }\n\n        AddScalar\n            .prepare::<C>([lhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAddScalar::<B>::new(lhs.node.id, rhs))\n            .parents([&lhs])\n            .stateless(B::float_add_scalar(lhs.primitive, rhs))\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sub;\n\n        retro_binary!(RetroSub, B::float_sub);\n\n        impl<B: Backend> Backward<B, 2> for Sub {\n            type State = (Shape, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape_lhs, shape_rhs) = ops.state;\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| broadcast_shape::<B>(grad, &shape_lhs),\n                    |grad| broadcast_shape::<B>(B::float_neg(grad), &shape_rhs),\n                );\n            }\n        }\n\n        match Sub\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSub::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (lhs.primitive.shape(), rhs.primitive.shape()),\n                B::float_sub(lhs.primitive, rhs.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_sub(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct SubScalar;\n\n        retro_unary_scalar!(RetroSubScalar, B::float_sub_scalar);\n\n        impl<B: Backend> Backward<B, 1> for SubScalar {\n            type State = ();\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| grad);\n            }\n        }\n\n        SubScalar\n            .prepare::<C>([lhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSubScalar::<B>::new(lhs.node.id, rhs))\n            .parents([&lhs])\n            .stateless(B::float_sub_scalar(lhs.primitive, rhs))\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Mul;\n\n        retro_binary!(RetroMul, B::float_mul);\n\n        impl<B: Backend> Backward<B, 2> for Mul {\n            type State = (Option<NodeId>, Option<NodeId>, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs, rhs, broadcast) = ops.state;\n                let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs));\n                let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        let grad = B::float_mul(grad, rhs.unwrap());\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        let grad = B::float_mul(grad, lhs.unwrap());\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let lhs_tracked = lhs.is_tracked();\n        let rhs_tracked = rhs.is_tracked();\n        let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);\n\n        match Mul\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroMul::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs));\n                let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs));\n\n                prep.finish(\n                    (lhs_state, rhs_state, broadcast),\n                    B::float_mul(lhs.primitive, rhs.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_mul(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct MulScalar;\n\n        retro_unary_scalar!(RetroMulScalar, B::float_mul_scalar);\n\n        impl<B: Backend> Backward<B, 1> for MulScalar {\n            type State = Scalar;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mul_scalar(grad, ops.state)\n                });\n            }\n        }\n\n        match MulScalar\n            .prepare::<C>([lhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroMulScalar::<B>::new(lhs.node.id, rhs))\n            .parents([&lhs])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(rhs, B::float_mul_scalar(lhs.primitive, rhs)),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_mul_scalar(lhs.primitive, rhs)),\n        }\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Div;\n\n        retro_binary!(RetroDiv, B::float_div);\n\n        impl<B: Backend> Backward<B, 2> for Div {\n            type State = (Option<NodeId>, Option<NodeId>, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs, rhs, broadcast) = ops.state;\n                let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs));\n                let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs));\n                let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, rhs);\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        let rhs = rhs_4lhs.unwrap();\n                        let value = B::float_recip(rhs);\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        let rhs = rhs_4rhs.unwrap();\n                        let lhs = lhs.unwrap();\n                        let value =\n                            B::float_div(B::float_neg(lhs), B::float_powi_scalar(rhs, 2.into()));\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let lhs_tracked = lhs.is_tracked();\n        let rhs_tracked = rhs.is_tracked();\n        let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);\n\n        match Div\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroDiv::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs));\n                let rhs_state = (lhs_tracked || rhs_tracked).then(|| prep.checkpoint(&rhs));\n\n                prep.finish(\n                    (lhs_state, rhs_state, broadcast),\n                    B::float_div(lhs.primitive, rhs.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_div(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct DivScalar;\n\n        retro_unary_scalar!(RetroDivScalar, B::float_div_scalar);\n\n        impl<B: Backend> Backward<B, 1> for DivScalar {\n            type State = Scalar;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let tmp = 1.0 / ops.state.elem::<f32>();\n                    B::float_mul_scalar(grad, tmp.into())\n                });\n            }\n        }\n\n        match DivScalar\n            .prepare::<C>([lhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroDivScalar::<B>::new(lhs.node.id, rhs))\n            .parents([&lhs])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(rhs, B::float_div_scalar(lhs.primitive, rhs)),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_div_scalar(lhs.primitive, rhs)),\n        }\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Rem;\n\n        retro_binary!(RetroRem, B::float_remainder);\n\n        impl<B: Backend> Backward<B, 2> for Rem {\n            type State = (Option<NodeId>, Option<NodeId>, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs, rhs, broadcast) = ops.state;\n                let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs));\n                let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        // remainder(x, y) = x - floor(x / y) * y\n                        // partial(x - floor(x / y) * y, x) = 1\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        // partial(x - floor(x / y) * y, y) = - floor(x / y)\n                        let rhs = rhs.unwrap();\n                        let lhs = lhs.unwrap();\n                        let value = B::float_neg(B::float_floor(B::float_div(lhs, rhs)));\n                        let grad = B::float_mul(grad, value);\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let lhs_tracked = lhs.is_tracked();\n        let rhs_tracked = rhs.is_tracked();\n        let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);\n\n        match Rem\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRem::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs));\n                let rhs_state = (lhs_tracked || rhs_tracked).then(|| prep.checkpoint(&rhs));\n\n                prep.finish(\n                    (lhs_state, rhs_state, broadcast),\n                    B::float_remainder(lhs.primitive, rhs.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_remainder(lhs.primitive, rhs.primitive))\n            }\n        }\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct RemainderScalar;\n\n        retro_unary_scalar!(RetroRemainderScalar, B::float_remainder_scalar);\n\n        impl<B: Backend> Backward<B, 1> for RemainderScalar {\n            type State = ();\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| grad);\n            }\n        }\n\n        RemainderScalar\n            .prepare::<C>([lhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRemainderScalar::<B>::new(lhs.node.id, rhs))\n            .parents([&lhs])\n            .stateless(B::float_remainder_scalar(lhs.primitive, rhs))\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Matmul;\n\n        impl<B: Backend> Backward<B, 2> for Matmul {\n            type State = (Option<NodeId>, Option<NodeId>, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs, rhs, broadcast) = ops.state;\n                let lhs = lhs.map(|lhs| checkpointer.retrieve_node_output(lhs));\n                let rhs = rhs.map(|rhs| checkpointer.retrieve_node_output(rhs));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        let rhs = B::float_transpose(rhs.unwrap());\n                        let grad = B::float_matmul(grad, rhs);\n\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        let lhs = B::float_transpose(lhs.unwrap());\n                        let grad = B::float_matmul(lhs, grad);\n\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let lhs_tracked = lhs.is_tracked();\n        let rhs_tracked = rhs.is_tracked();\n        let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);\n\n        match Matmul\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs));\n                let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs));\n                prep.finish(\n                    (lhs_state, rhs_state, broadcast),\n                    B::float_matmul(lhs.primitive, rhs.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_matmul(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Cross;\n\n        impl<B: Backend> Backward<B, 2> for Cross {\n            type State = (Option<NodeId>, Option<NodeId>, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs_id, rhs_id, dim) = ops.state;\n                let lhs = lhs_id.map(|id| checkpointer.retrieve_node_output(id));\n                let rhs = rhs_id.map(|id| checkpointer.retrieve_node_output(id));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| B::float_cross(rhs.unwrap(), grad, dim),\n                    |grad| B::float_cross(grad, lhs.unwrap(), dim),\n                );\n            }\n        }\n\n        let lhs_tracked = lhs.is_tracked();\n        let rhs_tracked = rhs.is_tracked();\n\n        match Cross\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = rhs_tracked.then(|| prep.checkpoint(&lhs));\n                let rhs_state = lhs_tracked.then(|| prep.checkpoint(&rhs));\n                prep.finish(\n                    (lhs_state, rhs_state, dim),\n                    B::float_cross(lhs.primitive, rhs.primitive, dim),\n                )\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_cross(lhs.primitive, rhs.primitive, dim))\n            }\n        }\n    }\n\n    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Neg;\n\n        retro_unary!(RetroNeg, B::float_neg);\n\n        impl<B: Backend> Backward<B, 1> for Neg {\n            type State = ();\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| B::float_neg(grad));\n            }\n        }\n\n        Neg.prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroNeg::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateless(B::float_neg(tensor.primitive))\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Recip;\n\n        retro_unary!(RetroRecip, B::float_recip);\n\n        impl<B: Backend> Backward<B, 1> for Recip {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let tensor = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let tmp = B::float_powi_scalar(tensor, (-2).into());\n                    let value = B::float_neg(tmp);\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Recip\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRecip::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_recip(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_recip(tensor.primitive)),\n        }\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct SwapDim;\n\n        #[derive(new, Debug)]\n        struct RetroSwapDims<B: Backend> {\n            input_id: NodeId,\n            dim1: usize,\n            dim2: usize,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroSwapDims<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_swap_dims(input, self.dim1, self.dim2);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for SwapDim {\n            type State = (usize, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim1, dim2) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_swap_dims(grad, dim2, dim1)\n                });\n            }\n        }\n\n        match SwapDim\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSwapDims::<B>::new(tensor.node.id, dim1, dim2))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (dim1, dim2),\n                B::float_swap_dims(tensor.primitive, dim1, dim2),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_swap_dims(tensor.primitive, dim1, dim2))\n            }\n        }\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct PermuteDim;\n\n        #[derive(new, Debug)]\n        struct RetroPermuteDims<B: Backend> {\n            input_id: NodeId,\n            axes: Vec<usize>,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroPermuteDims<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_permute(input, &self.axes);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for PermuteDim {\n            type State = Vec<usize>;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let axes = ops.state;\n\n                let mut inverse = vec![0usize; axes.len()];\n                axes.iter()\n                    .enumerate()\n                    .for_each(|(i, &axis)| inverse[axis] = i);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_permute(grad, &inverse)\n                });\n            }\n        }\n\n        match PermuteDim\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroPermuteDims::<B>::new(tensor.node.id, axes.to_vec()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                prep.finish(axes.to_vec(), B::float_permute(tensor.primitive, axes))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_permute(tensor.primitive, axes)),\n        }\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct FlipDim;\n\n        #[derive(new, Debug)]\n        struct RetroFlipDims<B: Backend> {\n            input_id: NodeId,\n            axes: Vec<usize>,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroFlipDims<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_flip(input, &self.axes);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for FlipDim {\n            type State = Vec<usize>;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let axes = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_flip(grad, &axes)\n                });\n            }\n        }\n\n        match FlipDim\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroFlipDims::<B>::new(tensor.node.id, axes.to_vec()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)),\n        }\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct ReshapeDim;\n\n        #[derive(new, Debug)]\n        struct RetroReshape<B: Backend> {\n            input_id: NodeId,\n            shape: Shape,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroReshape<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_reshape(input, self.shape.clone());\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for ReshapeDim {\n            type State = (Shape, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape_original, shape) = ops.state;\n                let ndims_out = shape.num_dims();\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let shape_grad = grad.shape();\n                    let mut grad = grad;\n\n                    for i in 0..ndims_out {\n                        if shape[i] == 1 && shape_grad[i] != 1 {\n                            grad = B::float_sum_dim(grad, i);\n                        }\n                    }\n\n                    B::float_reshape(grad, shape_original)\n                });\n            }\n        }\n\n        match ReshapeDim\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroReshape::<B>::new(tensor.node.id, shape.clone()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), shape.clone()),\n                B::float_reshape(tensor.primitive, shape),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_reshape(tensor.primitive, shape)),\n        }\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<B>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Gather;\n\n        impl<B: Backend> Backward<B, 1> for Gather {\n            type State = (usize, IntTensor<B>, Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim, indices, shape, device) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let zeros = B::float_zeros(shape, &device, grad.dtype().into());\n                    B::float_scatter_add(dim, zeros, indices, grad)\n                });\n            }\n        }\n\n        match Gather\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (\n                    dim,\n                    indices.clone(),\n                    tensor.primitive.shape(),\n                    B::float_device(&tensor.primitive),\n                ),\n                B::float_gather(dim, tensor.primitive, indices),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_gather(dim, tensor.primitive, indices))\n            }\n        }\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<B>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Scatter;\n\n        impl<B: Backend> Backward<B, 2> for Scatter {\n            type State = (usize, IntTensor<B>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim, indices) = ops.state;\n                let [_, indices_4rhs] = duplicate(&ops.parents, Some(indices));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| grad,\n                    |grad| B::float_gather(dim, grad, indices_4rhs.unwrap()),\n                );\n            }\n        }\n\n        match Scatter\n            .prepare::<C>([tensor.node, value.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (dim, indices.clone()),\n                B::float_scatter_add(dim, tensor.primitive, indices, value.primitive),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_scatter_add(\n                dim,\n                tensor.primitive,\n                indices,\n                value.primitive,\n            )),\n        }\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<B>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Select;\n\n        #[derive(new, Debug)]\n        struct RetroSelect<B: Backend> {\n            input_id: NodeId,\n            dim: usize,\n            indices: IntTensor<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroSelect<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_select(input, self.dim, self.indices.clone());\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for Select {\n            type State = (usize, IntTensor<B>, Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim, indices, shape, device) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let zeros = B::float_zeros(shape, &device, grad.dtype().into());\n                    B::float_select_add(zeros, dim, indices, grad)\n                });\n            }\n        }\n\n        match Select\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSelect::<B>::new(tensor.node.id, dim, indices.clone()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (\n                    dim,\n                    indices.clone(),\n                    tensor.primitive.shape(),\n                    B::float_device(&tensor.primitive),\n                ),\n                B::float_select(tensor.primitive, dim, indices),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_select(tensor.primitive, dim, indices))\n            }\n        }\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<B>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct IndexSelectDimAssign;\n\n        #[derive(new, Debug)]\n        struct RetroSelectAssign<B: Backend> {\n            tensor_id: NodeId,\n            dim: usize,\n            indices: IntTensor<B>,\n            value_id: NodeId,\n        }\n\n        impl<B: Backend> RetroForward for RetroSelectAssign<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let tensor = states.get_state::<B::FloatTensorPrimitive>(&self.tensor_id);\n                let value = states.get_state::<B::FloatTensorPrimitive>(&self.value_id);\n                let out = B::float_select_add(tensor, self.dim, self.indices.clone(), value);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for IndexSelectDimAssign {\n            type State = (usize, IntTensor<B>);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim, indices) = ops.state;\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| grad,\n                    |grad| B::float_select(grad, dim, indices),\n                );\n            }\n        }\n\n        match IndexSelectDimAssign\n            .prepare::<C>([tensor.node.clone(), value.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSelectAssign::<B>::new(\n                tensor.node.id,\n                dim,\n                indices.clone(),\n                value.node.id,\n            ))\n            .parents([&tensor, &value])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (dim, indices.clone()),\n                B::float_select_add(tensor.primitive, dim, indices, value.primitive),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_select_add(\n                tensor.primitive,\n                dim,\n                indices,\n                value.primitive,\n            )),\n        }\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Index;\n\n        #[derive(new, Debug)]\n        struct RetroSlice<B: Backend> {\n            tensor_id: NodeId,\n            slices: Vec<Slice>,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroSlice<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let tensor = states.get_state::<B::FloatTensorPrimitive>(&self.tensor_id);\n                let out = B::float_slice(tensor, &self.slices);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for Index {\n            type State = (Vec<Slice>, Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (slices, shape, device) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let zeros = B::float_zeros(shape, &device, grad.dtype().into());\n                    B::float_slice_assign(zeros, &slices, grad)\n                });\n            }\n        }\n\n        match Index\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSlice::<B>::new(tensor.node.id, slices.to_vec()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (\n                    slices.to_vec(),\n                    tensor.primitive.shape(),\n                    B::float_device(&tensor.primitive),\n                ),\n                B::float_slice(tensor.primitive, slices),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_slice(tensor.primitive, slices)),\n        }\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct SliceAssign;\n\n        #[derive(new, Debug)]\n        struct RetroSliceAssign<B: Backend> {\n            tensor_id: NodeId,\n            slices: Vec<Slice>,\n            value_id: NodeId,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroSliceAssign<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let tensor = states.get_state::<B::FloatTensorPrimitive>(&self.tensor_id);\n                let value = states.get_state::<B::FloatTensorPrimitive>(&self.value_id);\n                let out = B::float_slice_assign(tensor, &self.slices, value);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 2> for SliceAssign {\n            type State = (Vec<Slice>, Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (slices, shape_rhs, device) = ops.state;\n                let [slices_4lhs, slices_4rhs] = duplicate(&ops.parents, Some(slices));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        let zeros = B::float_zeros(shape_rhs, &device, grad.dtype().into());\n                        B::float_slice_assign(grad, &slices_4lhs.unwrap(), zeros)\n                    },\n                    |grad| B::float_slice(grad, &slices_4rhs.unwrap()),\n                );\n            }\n        }\n\n        match SliceAssign\n            .prepare::<C>([tensor.node.clone(), value.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSliceAssign::<B>::new(\n                tensor.node.id,\n                slices.to_vec(),\n                value.node.id,\n            ))\n            .parents([&tensor, &value])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (\n                    slices.to_vec(),\n                    value.primitive.shape(),\n                    B::float_device(&value.primitive),\n                ),\n                B::float_slice_assign(tensor.primitive, slices, value.primitive),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_slice_assign(\n                tensor.primitive,\n                slices,\n                value.primitive,\n            )),\n        }\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        source: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct MaskWhere;\n\n        impl<B: Backend> Backward<B, 2> for MaskWhere {\n            type State = (BoolTensor<B>, Shape, Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (mask, shape_lhs, shape_rhs, device) = ops.state;\n                let [mask_4lhs, mask_4rhs] = duplicate(&ops.parents, Some(mask));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        let zeros = B::float_zeros(shape_lhs.clone(), &device, grad.dtype().into());\n                        let grad = B::float_mask_where(grad, mask_4lhs.unwrap(), zeros);\n\n                        broadcast_shape::<B>(grad, &shape_lhs)\n                    },\n                    |grad| {\n                        let zeros = B::float_zeros(shape_rhs.clone(), &device, grad.dtype().into());\n                        let grad = B::float_mask_where(zeros, mask_4rhs.unwrap(), grad);\n\n                        broadcast_shape::<B>(grad, &shape_rhs)\n                    },\n                );\n            }\n        }\n\n        match MaskWhere\n            .prepare::<C>([tensor.node, source.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (\n                    mask.clone(),\n                    tensor.primitive.shape(),\n                    source.primitive.shape(),\n                    B::float_device(&source.primitive),\n                ),\n                B::float_mask_where(tensor.primitive, mask, source.primitive),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_mask_where(\n                tensor.primitive,\n                mask,\n                source.primitive,\n            )),\n        }\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<B>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct MaskFill;\n\n        impl<B: Backend> Backward<B, 1> for MaskFill {\n            type State = BoolTensor<B>;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mask_fill(grad, ops.state, 0f32.into())\n                });\n            }\n        }\n\n        match MaskFill\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                mask.clone(),\n                B::float_mask_fill(tensor.primitive, mask, value),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_mask_fill(tensor.primitive, mask, value))\n            }\n        }\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<B> {\n        B::float_equal(lhs.primitive, rhs.primitive)\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {\n        B::float_equal_elem(lhs.primitive, rhs)\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<B> {\n        B::float_greater(lhs.primitive, rhs.primitive)\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {\n        B::float_greater_elem(lhs.primitive, rhs)\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<B> {\n        B::float_greater_equal(lhs.primitive, rhs.primitive)\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {\n        B::float_greater_equal_elem(lhs.primitive, rhs)\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<B> {\n        B::float_lower(lhs.primitive, rhs.primitive)\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {\n        B::float_lower_elem(lhs.primitive, rhs)\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<B> {\n        B::float_lower_equal(lhs.primitive, rhs.primitive)\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<B> {\n        B::float_lower_equal_elem(lhs.primitive, rhs)\n    }\n\n    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        B::float_is_nan(tensor.primitive)\n    }\n\n    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        B::float_is_inf(tensor.primitive)\n    }\n\n    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // When we detach a tensor, we remove it from the graph, but we still want to keep the\n        // `require_grad` setting.\n        let is_require_grad = Self::float_is_require_grad(&tensor);\n        let tensor = AutodiffTensor::new(tensor.primitive);\n\n        match is_require_grad {\n            true => tensor.require_grad(),\n            false => tensor,\n        }\n    }\n\n    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {\n        if require_grad {\n            return tensor.require_grad();\n        }\n\n        AutodiffTensor::new(tensor.primitive)\n    }\n\n    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {\n        matches!(tensor.node.requirement, Requirement::Grad)\n    }\n\n    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Mean;\n\n        impl<B: Backend> Backward<B, 1> for Mean {\n            type State = Shape;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let shape = ops.state;\n                    let val = 1_f64 / shape.num_elements() as f64;\n                    let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into());\n                    let val = B::float_mul_scalar(ones, val.into());\n\n                    let grad = unsqueeze_like::<B>(grad, val.shape());\n                    B::float_mul(val, grad)\n                });\n            }\n        }\n\n        match Mean.prepare::<C>([tensor.node]).compute_bound().stateful() {\n            OpsKind::Tracked(prep) => {\n                prep.finish(tensor.primitive.shape(), B::float_mean(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_mean(tensor.primitive)),\n        }\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sum;\n\n        impl<B: Backend> Backward<B, 1> for Sum {\n            type State = Shape;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let val =\n                        B::float_ones(ops.state, &B::float_device(&grad), grad.dtype().into());\n\n                    let grad = unsqueeze_like::<B>(grad, val.shape());\n                    B::float_mul(val, grad)\n                });\n            }\n        }\n\n        match Sum.prepare::<C>([tensor.node]).compute_bound().stateful() {\n            OpsKind::Tracked(prep) => {\n                prep.finish(tensor.primitive.shape(), B::float_sum(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_sum(tensor.primitive)),\n        }\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct MeanDim;\n\n        impl<B: Backend> Backward<B, 1> for MeanDim {\n            type State = (Shape, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, dim) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let val = 1_f64 / shape[dim] as f64;\n                    let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into());\n                    let val = B::float_mul_scalar(ones, val.into());\n\n                    let grad = B::float_sum_dim(grad, dim);\n                    B::float_mul(val, grad)\n                });\n            }\n        }\n\n        match MeanDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), dim),\n                B::float_mean_dim(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_mean_dim(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct SumDim;\n\n        impl<B: Backend> Backward<B, 1> for SumDim {\n            type State = (Shape, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, dim) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let ones = B::float_ones(shape, &B::float_device(&grad), grad.dtype().into());\n                    let grad = B::float_sum_dim(grad, dim);\n\n                    B::float_mul(ones, grad)\n                });\n            }\n        }\n\n        match SumDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), dim),\n                B::float_sum_dim(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_sum_dim(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct CumSum;\n\n        impl<B: Backend> Backward<B, 1> for CumSum {\n            type State = (Shape, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (_shape, dim) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // Gradient of cumsum is cumsum of gradient in reverse\n                    let grad_reversed = B::float_flip(grad.clone(), &[dim]);\n                    let grad_cumsum = B::float_cumsum(grad_reversed, dim);\n                    B::float_flip(grad_cumsum, &[dim])\n                });\n            }\n        }\n\n        match CumSum\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), dim),\n                B::float_cumsum(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct CumProd;\n\n        impl<B: Backend> Backward<B, 1> for CumProd {\n            type State = (B::FloatTensorPrimitive, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (input, dim) = ops.state;\n                let output = B::float_cumprod(input.clone(), dim);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // Gradient of cumprod using negative step slicing\n                    // Formula: grad_input[i] = sum_{j>=i}(grad_output[j] * output[j] / input[i])\n                    //        = (1 / input[i]) * sum_{j>=i}(grad_output[j] * output[j])\n                    //        = (1 / input) * reverse_cumsum(grad * output)\n                    //\n                    // LIMITATION: This produces NaN when input contains zeros.\n                    // A proper zero-safe implementation requires more sophisticated algorithms\n                    // (see PyTorch's cumprod_backward or JAX's associative_scan approach).\n                    // TODO: Implement zero-safe gradient computation.\n                    // See: https://github.com/tracel-ai/burn/issues/3864\n\n                    let grad_times_output = B::float_mul(grad, output.clone());\n\n                    // Create slices to reverse along the specified dimension\n                    let shape = grad_times_output.shape();\n                    let mut slices = vec![Slice::full(); shape.num_dims()];\n                    slices[dim] = Slice::with_step(0, None, -1);\n\n                    // Reverse, cumsum, reverse back using negative step slicing\n                    let grad_reversed = B::float_slice(grad_times_output, &slices);\n                    let grad_cumsum = B::float_cumsum(grad_reversed, dim);\n                    let grad_result = B::float_slice(grad_cumsum, &slices);\n\n                    B::float_div(grad_result, input)\n                });\n            }\n        }\n\n        match CumProd\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.clone(), dim),\n                B::float_cumprod(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cumprod(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct CumMin;\n\n        impl<B: Backend> Backward<B, 1> for CumMin {\n            type State = (B::FloatTensorPrimitive, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (input, dim) = ops.state;\n                let output = B::float_cummin(input.clone(), dim);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // Gradient flows to the input positions that produced each output\n                    // Use scatter to accumulate gradients (scatter does sum reduction)\n\n                    let shape = input.shape();\n                    let device = B::float_device(&input);\n                    let dim_size = shape[dim] as i64;\n\n                    // Create indices [0, 1, 2, ...] along the dimension\n                    let arange_1d = B::int_arange(0..dim_size, &device);\n\n                    // Reshape to broadcast along the specified dimension\n                    let mut arange_shape = vec![1; shape.num_dims()];\n                    arange_shape[dim] = dim_size as usize;\n                    let arange = B::int_reshape(arange_1d, Shape::from(arange_shape));\n\n                    // Expand to match input shape\n                    let arange = B::int_expand(arange, shape.clone());\n\n                    // Find where cummin[i] == input[i] (these are source positions)\n                    let is_source = B::float_equal(output.clone(), input.clone());\n                    let is_source_int = B::bool_into_int(is_source);\n\n                    // Mask: where is_source, use index; else 0\n                    let masked_indices = B::int_mul(arange, is_source_int);\n\n                    // Cummax propagates the last valid (non-zero) index forward\n                    let source_indices = B::int_cummax(masked_indices, dim);\n\n                    // Scatter gradients to source positions (sum reduction)\n                    let zeros = B::float_zeros(shape, &device, grad.dtype().into());\n                    B::float_scatter_add(dim, zeros, source_indices, grad)\n                });\n            }\n        }\n\n        match CumMin\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.clone(), dim),\n                B::float_cummin(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cummin(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct CumMax;\n\n        impl<B: Backend> Backward<B, 1> for CumMax {\n            type State = (B::FloatTensorPrimitive, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (input, dim) = ops.state;\n                let output = B::float_cummax(input.clone(), dim);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // Gradient flows to the input positions that produced each output\n                    // Use scatter to accumulate gradients (scatter does sum reduction)\n\n                    let shape = input.shape();\n                    let device = B::float_device(&input);\n                    let dim_size = shape[dim] as i64;\n\n                    // Create indices [0, 1, 2, ...] along the dimension\n                    let arange_1d = B::int_arange(0..dim_size, &device);\n\n                    // Reshape to broadcast along the specified dimension\n                    let mut arange_shape = vec![1; shape.num_dims()];\n                    arange_shape[dim] = dim_size as usize;\n                    let arange = B::int_reshape(arange_1d, Shape::from(arange_shape));\n\n                    // Expand to match input shape\n                    let arange = B::int_expand(arange, shape.clone());\n\n                    // Find where cummax[i] == input[i] (these are source positions)\n                    let is_source = B::float_equal(output.clone(), input.clone());\n                    let is_source_int = B::bool_into_int(is_source);\n\n                    // Mask: where is_source, use index; else 0\n                    let masked_indices = B::int_mul(arange, is_source_int);\n\n                    // Cummax propagates the last valid (non-zero) index forward\n                    let source_indices = B::int_cummax(masked_indices, dim);\n\n                    // Scatter gradients to source positions (sum reduction)\n                    let zeros = B::float_zeros(shape, &device, grad.dtype().into());\n                    B::float_scatter_add(dim, zeros, source_indices, grad)\n                });\n            }\n        }\n\n        match CumMax\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.clone(), dim),\n                B::float_cummax(tensor.primitive, dim),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cummax(tensor.primitive, dim)),\n        }\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<B> {\n        B::float_argmax(tensor.primitive, dim)\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<B> {\n        B::float_argmin(tensor.primitive, dim)\n    }\n\n    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Exp;\n\n        retro_unary!(RetroExp, B::float_exp);\n\n        impl<B: Backend> Backward<B, 1> for Exp {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                let output = B::float_exp(input);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mul(grad, output)\n                });\n            }\n        }\n\n        match Exp\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroExp::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_exp(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_exp(tensor.primitive)),\n        }\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Log;\n\n        retro_unary!(RetroLog, B::float_log);\n\n        impl<B: Backend> Backward<B, 1> for Log {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_recip(input);\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Log\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroLog::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_log(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_log(tensor.primitive)),\n        }\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Log1P;\n\n        retro_unary!(RetroLog1P, B::float_log1p);\n\n        impl<B: Backend> Backward<B, 1> for Log1P {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_add_scalar(input, 1f32.into());\n                    let value = B::float_recip(value);\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Log1P\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroLog1P::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_log1p(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_log1p(tensor.primitive)),\n        }\n    }\n\n    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct PowfScalar;\n\n        #[derive(new, Debug)]\n        struct RetroPowfScalar<B: Backend> {\n            lhs_id: NodeId,\n            rhs: f64,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroPowfScalar<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let lhs = states.get_state::<B::FloatTensorPrimitive>(&self.lhs_id);\n                let out = B::float_powf_scalar(lhs, self.rhs.into());\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for PowfScalar {\n            type State = (NodeId, f64);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (tensor_id, value) = ops.state;\n                let tensor = checkpointer.retrieve_node_output(tensor_id);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let tmp = B::float_powf_scalar(tensor, (value - 1.).into());\n                    let value = B::float_mul_scalar(tmp, value.into());\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match PowfScalar\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroPowfScalar::<B>::new(tensor.node.id, value.elem()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = (prep.checkpoint(&tensor), value.elem());\n                prep.finish(state, B::float_powf_scalar(tensor.primitive, value))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_powf_scalar(tensor.primitive, value)),\n        }\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sqrt;\n\n        retro_unary!(RetroSqrt, B::float_sqrt);\n\n        impl<B: Backend> Backward<B, 1> for Sqrt {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_div_scalar(\n                        B::float_powf_scalar(input, (-0.5).into()),\n                        2f32.into(),\n                    );\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Sqrt\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSqrt::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_sqrt(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_sqrt(tensor.primitive)),\n        }\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Abs;\n\n        retro_unary!(RetroAbs, B::float_abs);\n\n        impl<B: Backend> Backward<B, 1> for Abs {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let tensor: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(ops.state);\n                let state = B::float_sign(tensor);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mul(grad, state)\n                });\n            }\n        }\n\n        match Abs\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAbs::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_abs(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_abs(tensor.primitive)),\n        }\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Cos;\n\n        retro_unary!(RetroCos, B::float_cos);\n\n        impl<B: Backend> Backward<B, 1> for Cos {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_neg(B::float_sin(input));\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Cos\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroCos::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_cos(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cos(tensor.primitive)),\n        }\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sin;\n\n        retro_unary!(RetroSin, B::float_sin);\n\n        impl<B: Backend> Backward<B, 1> for Sin {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let state = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_cos(state);\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Sin\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSin::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_sin(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_sin(tensor.primitive)),\n        }\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Tanh;\n\n        retro_unary!(RetroTanh, B::float_tanh);\n\n        impl<B: Backend> Backward<B, 1> for Tanh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                let state = B::float_tanh(input);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let value = B::float_add_scalar(\n                        B::float_neg(B::float_powi_scalar(state, 2.into())),\n                        1f32.into(),\n                    );\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Tanh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroTanh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_tanh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_tanh(tensor.primitive)),\n        }\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Cosh;\n\n        retro_unary!(RetroCosh, B::float_cosh);\n\n        impl<B: Backend> Backward<B, 1> for Cosh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mul(grad, B::float_sinh(input))\n                });\n            }\n        }\n\n        match Cosh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroCosh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_cosh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cosh(tensor.primitive)),\n        }\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sinh;\n\n        retro_unary!(RetroSinh, B::float_sinh);\n\n        impl<B: Backend> Backward<B, 1> for Sinh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_mul(grad, B::float_cosh(input))\n                });\n            }\n        }\n\n        match Sinh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSinh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_sinh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_sinh(tensor.primitive)),\n        }\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Tan;\n\n        retro_unary!(RetroTan, B::float_tan);\n\n        impl<B: Backend> Backward<B, 1> for Tan {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                let tan_x = B::float_tan(input);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx tan(x) = 1 + tan^2(x)\n                    let tan_sq = B::float_powi_scalar(tan_x, 2.into());\n                    B::float_mul(grad, B::float_add_scalar(tan_sq, 1f32.into()))\n                });\n            }\n        }\n\n        match Tan\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroTan::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_tan(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_tan(tensor.primitive)),\n        }\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Asin;\n\n        retro_unary!(RetroAsin, B::float_asin);\n\n        impl<B: Backend> Backward<B, 1> for Asin {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx asin(x) = 1/sqrt(1 - x^2)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let denom = B::float_sqrt(B::float_add_scalar(B::float_neg(x_sq), 1f32.into()));\n                    B::float_mul(grad, B::float_recip(denom))\n                });\n            }\n        }\n\n        match Asin\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAsin::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_asin(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_asin(tensor.primitive)),\n        }\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Acos;\n\n        retro_unary!(RetroAcos, B::float_acos);\n\n        impl<B: Backend> Backward<B, 1> for Acos {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx acos(x) = -1/sqrt(1 - x^2)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let denom = B::float_sqrt(B::float_add_scalar(B::float_neg(x_sq), 1f32.into()));\n                    let value = B::float_neg(B::float_recip(denom));\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Acos\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAcos::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_acos(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_acos(tensor.primitive)),\n        }\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Atan;\n\n        retro_unary!(RetroAtan, B::float_atan);\n\n        impl<B: Backend> Backward<B, 1> for Atan {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx atan(x) = 1/(1 + x^2)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let value = B::float_recip(B::float_add_scalar(x_sq, 1f32.into()));\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Atan\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAtan::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_atan(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_atan(tensor.primitive)),\n        }\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Asinh;\n\n        retro_unary!(RetroAsinh, B::float_asinh);\n\n        impl<B: Backend> Backward<B, 1> for Asinh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx asinh(x) = 1/sqrt(x^2 + 1)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let value =\n                        B::float_recip(B::float_sqrt(B::float_add_scalar(x_sq, 1f32.into())));\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Asinh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAsinh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_asinh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_asinh(tensor.primitive)),\n        }\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Acosh;\n\n        retro_unary!(RetroAcosh, B::float_acosh);\n\n        impl<B: Backend> Backward<B, 1> for Acosh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx acosh(x) = 1/sqrt(x^2 - 1)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let value =\n                        B::float_recip(B::float_sqrt(B::float_sub_scalar(x_sq, 1f32.into())));\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Acosh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAcosh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_acosh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_acosh(tensor.primitive)),\n        }\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Atanh;\n\n        retro_unary!(RetroAtanh, B::float_atanh);\n\n        impl<B: Backend> Backward<B, 1> for Atanh {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let input = checkpointer.retrieve_node_output(ops.state);\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    // d/dx atanh(x) = 1/(1 - x^2)\n                    let x_sq = B::float_powi_scalar(input, 2.into());\n                    let value =\n                        B::float_recip(B::float_add_scalar(B::float_neg(x_sq), 1f32.into()));\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Atanh\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAtanh::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_atanh(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_atanh(tensor.primitive)),\n        }\n    }\n\n    fn float_atan2(y: FloatTensor<Self>, x: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Atan2;\n\n        retro_binary!(RetroAtan2, B::float_atan2);\n\n        impl<B: Backend> Backward<B, 2> for Atan2 {\n            type State = (Option<NodeId>, Option<NodeId>, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (y_id, x_id, broadcast) = ops.state;\n                let y = y_id.map(|id| checkpointer.retrieve_node_output(id));\n                let x = x_id.map(|id| checkpointer.retrieve_node_output(id));\n                let [y_4y, y_4x] = duplicate(&ops.parents, y);\n                let [x_4y, x_4x]: [Option<FloatTensor<B>>; 2] = duplicate(&ops.parents, x);\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        // d/dy atan2(y, x) = x/(x^2 + y^2)\n                        let y = y_4y.unwrap();\n                        let x = x_4y.unwrap();\n                        let x_sq = B::float_powi_scalar(x.clone(), 2.into());\n                        let y_sq = B::float_powi_scalar(y, 2.into());\n                        let denom = B::float_add(x_sq, y_sq);\n                        let value = B::float_div(x, denom);\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        // d/dx atan2(y, x) = -y/(x^2 + y^2)\n                        let y = y_4x.unwrap();\n                        let x = x_4x.unwrap();\n                        let x_sq = B::float_powi_scalar(x, 2.into());\n                        let y_sq = B::float_powi_scalar(y.clone(), 2.into());\n                        let denom = B::float_add(x_sq, y_sq);\n                        let value = B::float_neg(B::float_div(y, denom));\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let y_tracked = y.is_tracked();\n        let x_tracked = x.is_tracked();\n        let broadcast = BinaryOpsBroadcast::new::<B>(&y.primitive, &x.primitive);\n\n        match Atan2\n            .prepare::<C>([y.node.clone(), x.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroAtan2::<B>::new(y.node.id, x.node.id))\n            .parents([&y, &x])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let is_tracked = y_tracked || x_tracked;\n                let y_state = is_tracked.then(|| prep.checkpoint(&y));\n                let x_state = is_tracked.then(|| prep.checkpoint(&x));\n\n                prep.finish(\n                    (y_state, x_state, broadcast),\n                    B::float_atan2(y.primitive, x.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_atan2(y.primitive, x.primitive)),\n        }\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Round;\n        retro_unary!(RetroRound, B::float_round);\n\n        impl<B: Backend> Backward<B, 1> for Round {\n            type State = (Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, device) = ops.state;\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_zeros(shape, &device, grad.dtype().into())\n                })\n            }\n        }\n\n        match Round\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRound::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (tensor.primitive.shape(), B::float_device(&tensor.primitive)),\n                B::float_round(tensor.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)),\n        }\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Floor;\n        retro_unary!(RetroFloor, B::float_floor);\n\n        impl<B: Backend> Backward<B, 1> for Floor {\n            type State = (Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, device) = ops.state;\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_zeros(shape, &device, grad.dtype().into())\n                })\n            }\n        }\n\n        match Floor\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroFloor::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (tensor.primitive.shape(), B::float_device(&tensor.primitive)),\n                B::float_floor(tensor.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)),\n        }\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Ceil;\n        retro_unary!(RetroCeil, B::float_ceil);\n\n        impl<B: Backend> Backward<B, 1> for Ceil {\n            type State = (Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, device) = ops.state;\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_zeros(shape, &device, grad.dtype().into())\n                })\n            }\n        }\n\n        match Ceil\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroCeil::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (tensor.primitive.shape(), B::float_device(&tensor.primitive)),\n                B::float_ceil(tensor.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_ceil(tensor.primitive)),\n        }\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Trunc;\n        retro_unary!(RetroTrunc, B::float_trunc);\n\n        impl<B: Backend> Backward<B, 1> for Trunc {\n            type State = (Shape, B::Device);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape, device) = ops.state;\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_zeros(shape, &device, grad.dtype().into())\n                })\n            }\n        }\n\n        match Trunc\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroTrunc::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(preps) => preps.finish(\n                (tensor.primitive.shape(), B::float_device(&tensor.primitive)),\n                B::float_trunc(tensor.primitive),\n            ),\n            OpsKind::UnTracked(preps) => preps.finish(B::float_trunc(tensor.primitive)),\n        }\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Erf;\n\n        retro_unary!(RetroErf, B::float_erf);\n\n        impl<B: Backend> Backward<B, 1> for Erf {\n            type State = NodeId;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let ops = checkpointer.retrieve_node_output(ops.state);\n                    let exponent = B::float_neg(B::float_powi_scalar(ops, 2.into()));\n                    let numerator = B::float_mul_scalar(B::float_exp(exponent), 2.0.into());\n                    let denominator = core::f64::consts::PI.sqrt().into();\n                    let value = B::float_div_scalar(numerator, denominator);\n\n                    B::float_mul(grad, value)\n                });\n            }\n        }\n\n        match Erf\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroErf::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let state = prep.checkpoint(&tensor);\n                prep.finish(state, B::float_erf(tensor.primitive))\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_erf(tensor.primitive)),\n        }\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CatStep<B: Backend> {\n            nodes: Vec<Option<NodeRef>>,\n            // The dimension of each tensor along the dim dimension.\n            // This indicates the number of dimension concatenated for each tensor.\n            dim_sizes: Vec<usize>,\n            output: NodeRef,\n            phantom: PhantomData<B>,\n            dim: usize,\n            parents: Vec<Parent>,\n        }\n\n        impl<B: Backend> Step for CatStep<B> {\n            fn step(self: Box<Self>, grads: &mut Gradients, _checkpointer: &mut Checkpointer) {\n                let grad = grads.consume::<B>(&self.output);\n                let ranges_template: Vec<_> = grad.shape().iter().map(|&v| 0..v).collect();\n\n                self.nodes\n                    .into_iter()\n                    .zip(self.dim_sizes)\n                    .scan(0, |offset, (node_opt, dim_size)| {\n                        let start = *offset;\n                        let end = start + dim_size;\n                        *offset = end;\n                        Some(node_opt.map(|node| (node, start, end)))\n                    })\n                    .flatten()\n                    .for_each(|(node, start, end)| {\n                        let mut ranges = ranges_template.clone();\n                        ranges[self.dim] = start..end;\n\n                        let slices: Vec<Slice> = ranges\n                            .iter()\n                            .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1))\n                            .collect();\n                        grads.register::<B>(node.id, B::float_slice(grad.clone(), &slices));\n                    });\n            }\n\n            fn node(&self) -> NodeId {\n                self.output.id\n            }\n\n            fn parents(&self) -> &[Parent] {\n                &self.parents\n            }\n            fn depth(&self) -> usize {\n                self.output.order\n            }\n        }\n\n        let mut nodes = Vec::with_capacity(tensors.len());\n        let mut primitives = Vec::with_capacity(tensors.len());\n        let mut dim_sizes = Vec::with_capacity(tensors.len());\n\n        tensors.into_iter().for_each(|tensor| {\n            dim_sizes.push(tensor.primitive.shape()[dim]);\n            nodes.push(tensor.node);\n            primitives.push(tensor.primitive);\n        });\n\n        let requirement = Requirement::from_nodes(&nodes);\n\n        // For simplicity, this operation does not checkpoint anything\n        let cat_computing_property = ComputingProperty::Ambiguous;\n        let checkpointer_builder = CheckpointerBuilder::default();\n\n        let output = B::float_cat(primitives, dim);\n        if requirement.is_none() {\n            return AutodiffTensor::from_parents(\n                output,\n                &nodes,\n                requirement,\n                cat_computing_property,\n            );\n        }\n\n        let output =\n            AutodiffTensor::from_parents(output, &nodes, requirement, cat_computing_property);\n\n        let mut parents = Vec::new();\n\n        let nodes = nodes\n            .into_iter()\n            .map(|node| node.clone_if_require_grad())\n            .collect::<Vec<_>>();\n        for node in nodes.iter().flatten() {\n            parents.push(Parent { id: node.id });\n        }\n        let ops = CatStep::<B>::new(nodes, dim_sizes, output.node.clone(), dim, parents);\n        output.register_step(ops, checkpointer_builder)\n    }\n\n    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        match MaxMinDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim);\n                prep.finish((index, shape, dim), tensor)\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_max_dim(tensor.primitive, dim)),\n        }\n    }\n    fn float_max_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<B>) {\n        match MaxMinDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim);\n                let tensor = prep.finish((index.clone(), shape, dim), tensor);\n\n                (tensor, index)\n            }\n            OpsKind::UnTracked(prep) => {\n                let (tensor, index) = B::float_max_dim_with_indices(tensor.primitive, dim);\n                let tensor = prep.finish(tensor);\n\n                (tensor, index)\n            }\n        }\n    }\n\n    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        match MaxMinDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim);\n                prep.finish((index, shape, dim), tensor)\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_min_dim(tensor.primitive, dim)),\n        }\n    }\n    fn float_min_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<B>) {\n        match MaxMinDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim);\n                let tensor = prep.finish((index.clone(), shape, dim), tensor);\n\n                (tensor, index)\n            }\n            OpsKind::UnTracked(prep) => {\n                let (tensor, index) = B::float_min_dim_with_indices(tensor.primitive, dim);\n                let tensor = prep.finish(tensor);\n\n                (tensor, index)\n            }\n        }\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> <Autodiff<B> as Backend>::IntTensorPrimitive {\n        B::float_into_int(tensor.primitive)\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct PowF;\n\n        retro_binary!(RetroPowf, B::float_powf);\n\n        impl<B: Backend> Backward<B, 2> for PowF {\n            type State = (NodeId, NodeId, BinaryOpsBroadcast);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 2>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                let (lhs_id, rhs_id, broadcast) = ops.state;\n                let lhs: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(lhs_id);\n                let rhs: B::FloatTensorPrimitive = checkpointer.retrieve_node_output(rhs_id);\n\n                // Both lhs and rhs are needed for both lhs and rhs gradients, but we clone them\n                // the number of times required by the parents specification.\n                let [rhs_4lhs, rhs_4rhs] = duplicate(&ops.parents, Some(rhs));\n                let [lhs_4lhs, lhs_4rhs] = duplicate(&ops.parents, Some(lhs));\n\n                binary::<B, _, _>(\n                    ops.parents,\n                    ops.node,\n                    grads,\n                    |grad| {\n                        //rhs*(lhs.val**(rhs-1))*grad\n                        let rhs1 = rhs_4lhs.unwrap();\n                        let rhs2 = rhs1.clone();\n                        let lhs = lhs_4lhs.unwrap();\n\n                        let tmp = B::float_powf(lhs, B::float_sub_scalar(rhs1, 1.0.into()));\n                        let value = B::float_mul(tmp, rhs2);\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_lhs::<B>(grad)\n                    },\n                    |grad| {\n                        //lhs**rhs * ln(lhs) * grad\n                        let rhs = rhs_4rhs.unwrap();\n                        let lhs1 = lhs_4rhs.unwrap();\n                        let lhs2 = lhs1.clone();\n                        let tmp = B::float_powf(lhs1, rhs);\n                        let value = B::float_mul(tmp, B::float_log(lhs2));\n                        let grad = B::float_mul(grad, value);\n\n                        broadcast.backward_rhs::<B>(grad)\n                    },\n                );\n            }\n        }\n\n        let broadcast = BinaryOpsBroadcast::new::<B>(&lhs.primitive, &rhs.primitive);\n\n        match PowF\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroPowf::<B>::new(lhs.node.id, rhs.node.id))\n            .parents([&lhs, &rhs])\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                let lhs_state = prep.checkpoint(&lhs);\n                let rhs_state = prep.checkpoint(&rhs);\n                prep.finish(\n                    (lhs_state, rhs_state, broadcast),\n                    B::float_powf(lhs.primitive, rhs.primitive),\n                )\n            }\n            OpsKind::UnTracked(prep) => prep.finish(B::float_powf(lhs.primitive, rhs.primitive)),\n        }\n    }\n\n    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Sign;\n\n        retro_unary!(RetroSign, B::float_sign);\n\n        impl<B: Backend> Backward<B, 1> for Sign {\n            type State = ();\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                unary::<B, _>(ops.parents, ops.node, grads, |grad|\n                        // Always return 0 because the derivative of the sign function\n                        // does not contribute to gradient updates in a meaningful way.\n                        B::float_mul_scalar(grad, 0f32.into()));\n            }\n        }\n\n        Sign.prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroSign::<B>::new(tensor.node.id))\n            .parents([&tensor])\n            .stateless(B::float_sign(tensor.primitive))\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        // D1: tensor, D2: shape\n        #[derive(Debug)]\n        struct ExpandDim;\n\n        #[derive(new, Debug)]\n        struct RetroExpand<B: Backend> {\n            input_id: NodeId,\n            shape: Shape,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroExpand<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let input = states.get_state::<B::FloatTensorPrimitive>(&self.input_id);\n                let out = B::float_expand(input, self.shape.clone());\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for ExpandDim {\n            type State = (Shape, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape_in, shape_out) = ops.state;\n                let ndims_in = shape_in.num_dims();\n                let ndims_out = shape_out.num_dims();\n\n                let mut shape_expanded = vec![1; ndims_out];\n\n                debug_assert!(ndims_out >= ndims_in);\n\n                for i in 0..ndims_in {\n                    shape_expanded[i + (ndims_out - ndims_in)] = shape_in[i];\n                }\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let shape_grad = grad.shape();\n                    let mut grad = grad;\n\n                    #[allow(clippy::needless_range_loop)]\n                    for i in 0..ndims_out {\n                        if shape_expanded[i] == 1 && shape_grad[i] != 1 {\n                            grad = B::float_sum_dim(grad, i);\n                        }\n                    }\n\n                    B::float_reshape(grad, shape_in)\n                });\n            }\n        }\n\n        match ExpandDim\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroExpand::<B>::new(tensor.node.id, shape.clone()))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), shape.clone()),\n                B::float_expand(tensor.primitive, shape),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_expand(tensor.primitive, shape)),\n        }\n    }\n\n    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {\n        match super::sort::SortDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, indices) =\n                    B::float_sort_with_indices(tensor.primitive, dim, descending);\n                prep.finish((indices, shape, dim), tensor)\n            }\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_sort(tensor.primitive, dim, descending))\n            }\n        }\n    }\n\n    fn float_sort_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        descending: bool,\n    ) -> (FloatTensor<Self>, IntTensor<B>) {\n        match super::sort::SortDim\n            .prepare::<C>([tensor.node])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => {\n                let shape = tensor.primitive.shape();\n                let (tensor, indices) =\n                    B::float_sort_with_indices(tensor.primitive, dim, descending);\n                let tensor = prep.finish((indices.clone(), shape, dim), tensor);\n\n                (tensor, indices)\n            }\n            OpsKind::UnTracked(prep) => {\n                let (tensor, indices) =\n                    B::float_sort_with_indices(tensor.primitive, dim, descending);\n                let tensor = prep.finish(tensor);\n\n                (tensor, indices)\n            }\n        }\n    }\n\n    fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<B> {\n        B::float_argsort(tensor.primitive, dim, descending)\n    }\n\n    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Repeat;\n\n        #[derive(new, Debug)]\n        struct RetroRepeat<B: Backend> {\n            tensor_id: NodeId,\n            dim: usize,\n            times: usize,\n            _backend: PhantomData<B>,\n        }\n\n        impl<B: Backend> RetroForward for RetroRepeat<B> {\n            fn forward(&self, states: &mut BackwardStates, out_node: NodeId) {\n                let tensor = states.get_state::<B::FloatTensorPrimitive>(&self.tensor_id);\n                let out = B::float_repeat_dim(tensor, self.dim, self.times);\n                states.save(out_node, out)\n            }\n        }\n\n        impl<B: Backend> Backward<B, 1> for Repeat {\n            type State = (usize, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (dim, times) = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let mut dims = grad.shape();\n                    let orig_dim_size = dims[dim] / times;\n                    if orig_dim_size > 1 {\n                        dims[dim] = orig_dim_size;\n                        let orig_dims = dims.clone();\n                        dims.insert(dim + 1, times); // shape [..., orig_dim_size, times, ...]\n                        let grad = B::float_reshape(grad, dims);\n                        let grad = B::float_sum_dim(grad, dim + 1); // sum over repeat times\n                        B::float_reshape(grad, orig_dims)\n                    } else {\n                        B::float_sum_dim(grad, dim)\n                    }\n                });\n            }\n        }\n\n        match Repeat\n            .prepare::<C>([tensor.node.clone()])\n            .memory_bound()\n            .retro_forward(RetroRepeat::<B>::new(tensor.node.id, dim, times))\n            .parents([&tensor])\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (dim, times),\n                B::float_repeat_dim(tensor.primitive, dim, times),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))\n            }\n        }\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_std::FloatDType) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Cast;\n\n        impl<B: Backend> Backward<B, 1> for Cast {\n            type State = FloatDType;\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let dtype = ops.state;\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    B::float_cast(grad, dtype)\n                });\n            }\n        }\n\n        match Cast\n            .prepare::<C>([tensor.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                tensor.dtype().into(),\n                B::float_cast(tensor.primitive, dtype),\n            ),\n            OpsKind::UnTracked(prep) => prep.finish(B::float_cast(tensor.primitive, dtype)),\n        }\n    }\n\n    // TODO: Implement float_prod and float_sum\n    // https://github.com/tracel-ai/burn/issues/1458\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        #[derive(Debug)]\n        struct Unfold;\n\n        impl<B: Backend> Backward<B, 1> for Unfold {\n            type State = (Shape, usize, usize, usize);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 1>,\n                grads: &mut Gradients,\n                _checkpointer: &mut Checkpointer,\n            ) {\n                let (shape_in, dim, size, step) = ops.state;\n                let windows = calculate_unfold_windows(shape_in[dim], size, step);\n\n                unary::<B, _>(ops.parents, ops.node, grads, |grad| {\n                    let device = B::float_device(&grad);\n                    let mut grad_input =\n                        B::float_zeros(shape_in.clone(), &device, grad.dtype().into());\n\n                    if windows == 0 {\n                        return grad_input;\n                    }\n\n                    let ndims_in = shape_in.num_dims();\n                    let ndims_out = grad.shape().num_dims();\n\n                    let mut target_shape = shape_in.clone();\n                    target_shape[dim] = size;\n\n                    for window_idx in 0..windows {\n                        let mut slices_out = vec![Slice::new(0, None, 1); ndims_out];\n                        let start = window_idx * step;\n                        let end = start + size;\n                        slices_out[dim] =\n                            Slice::new(window_idx as isize, Some((window_idx + 1) as isize), 1);\n\n                        let window_grad = B::float_slice(grad.clone(), &slices_out);\n\n                        let last_axis = ndims_out - 1;\n                        let mut permutation: Vec<usize> = (0..dim).collect();\n                        permutation.push(last_axis);\n                        permutation.extend(dim + 1..last_axis);\n                        permutation.push(dim);\n\n                        let window_grad = B::float_permute(window_grad, &permutation);\n                        let window_grad = B::float_reshape(window_grad, target_shape.clone());\n\n                        let mut slices_in = vec![Slice::new(0, None, 1); ndims_in];\n                        slices_in[dim] = Slice::new(start as isize, Some(end as isize), 1);\n\n                        let current = B::float_slice(grad_input.clone(), &slices_in);\n                        let updated = B::float_add(current, window_grad);\n                        grad_input = B::float_slice_assign(grad_input, &slices_in, updated);\n                    }\n\n                    grad_input\n                });\n            }\n        }\n\n        match Unfold\n            .prepare::<C>([tensor.node.clone()])\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(prep) => prep.finish(\n                (tensor.primitive.shape(), dim, size, step),\n                B::float_unfold(tensor.primitive, dim, size, step),\n            ),\n            OpsKind::UnTracked(prep) => {\n                prep.finish(B::float_unfold(tensor.primitive, dim, size, step))\n            }\n        }\n    }\n}\n\n#[derive(Debug, Clone)]\nenum BinaryOpsBroadcast {\n    Broadcasted(Shape, Shape),\n    None,\n}\n\nimpl BinaryOpsBroadcast {\n    fn new<B: Backend>(lhs: &B::FloatTensorPrimitive, rhs: &B::FloatTensorPrimitive) -> Self {\n        let shape_lhs = lhs.shape();\n        let shape_rhs = rhs.shape();\n        let ndims = shape_lhs.num_dims();\n\n        for i in 0..ndims {\n            if shape_rhs[i] != shape_lhs[i] {\n                return Self::Broadcasted(shape_lhs, shape_rhs);\n            }\n        }\n\n        Self::None\n    }\n\n    fn backward_lhs<B: Backend>(&self, grad: B::FloatTensorPrimitive) -> B::FloatTensorPrimitive {\n        match self {\n            BinaryOpsBroadcast::Broadcasted(lhs, _rhs) => broadcast_shape::<B>(grad, lhs),\n            BinaryOpsBroadcast::None => grad,\n        }\n    }\n\n    fn backward_rhs<B: Backend>(&self, grad: B::FloatTensorPrimitive) -> B::FloatTensorPrimitive {\n        match self {\n            BinaryOpsBroadcast::Broadcasted(_lhs, rhs) => broadcast_shape::<B>(grad, rhs),\n            BinaryOpsBroadcast::None => grad,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/ops/transaction.rs",
    "content": "use burn_backend::{\n    Backend, ExecutionError,\n    ops::{TransactionOps, TransactionPrimitive},\n};\n\nuse crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};\n\nimpl<B: Backend, C: CheckpointStrategy> TransactionOps<Self> for Autodiff<B, C> {\n    async fn tr_execute(\n        transaction: TransactionPrimitive<Self>,\n    ) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {\n        B::tr_execute(TransactionPrimitive::new(\n            transaction\n                .read_floats\n                .into_iter()\n                .map(|t| t.primitive)\n                .collect(),\n            transaction.read_qfloats,\n            transaction.read_ints,\n            transaction.read_bools,\n        ))\n        .await\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/client.rs",
    "content": "use crate::{\n    checkpoint::builder::CheckpointerBuilder,\n    grads::Gradients,\n    graph::StepBoxed,\n    tensor::{AutodiffTensor, NodeRefCount},\n};\nuse burn_backend::Backend;\n\n/// Client used to communicate with the autodiff server.\npub trait AutodiffClient: Send + Clone {\n    /// Register a new step.\n    fn register(&self, node_id: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder);\n    /// Call backpropagation from the given tensor.\n    fn backward<B: Backend>(&self, tensor: AutodiffTensor<B>) -> Gradients;\n}\n\n/// Client implementation in used.\npub type AutodiffClientImpl = super::graph::GraphMutexClient;\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/graph.rs",
    "content": "use super::{AutodiffClient, server::AutodiffServer};\nuse crate::{\n    NodeId,\n    checkpoint::builder::CheckpointerBuilder,\n    grads::Gradients,\n    graph::{Parent, StepBoxed},\n    runtime::server::NodeCleaner,\n    tensor::{AutodiffTensor, NodeRefCount},\n};\nuse alloc::sync::Arc;\nuse alloc::vec::Vec;\nuse burn_backend::Backend;\nuse hashbrown::{HashMap, HashSet};\n\n#[cfg(feature = \"std\")]\nuse parking_lot::{Mutex, MutexGuard};\n\n#[cfg(not(feature = \"std\"))]\nuse spin::{Mutex, MutexGuard};\n\n/// A client for managing multiple graphs using mutex-based synchronization.\n///\n/// The biggest benefit of using this client implementation is that each graph can modify its own\n/// data without blocking other graphs, which is essential for multi-device training.\n///\n/// # Notes\n///\n/// The [AutodiffServer] fully supports multiple graphs with sharing nodes, however those type of\n/// graphs will be stored under a single mutex-protected graph by the client, limiting\n/// parallelisation.\n#[derive(Clone, new, Debug)]\npub struct GraphMutexClient;\n\n/// Manages a collection of graphs, mapping [node ids](NodeId) to their respective graph.\n///\n/// The `GraphLocator` is responsible for selecting and merging graphs based on their IDs and parent\n/// dependencies, ensuring proper synchronization and server allocation.\n///\n/// # Notes\n///\n/// Multiple node ids can point to the same graph, where the autodiff graph is stored.\n#[derive(Default)]\npub struct GraphLocator {\n    graphs: HashMap<NodeId, Arc<Graph>>,\n    /// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.\n    /// This is to ensure that when merging graphs, we correctly move all previous graphs to\n    /// the new merged one.\n    keys: HashMap<NodeId, HashSet<NodeId>>,\n}\n\n/// Represents a single computation graph with a mutex-protected server.\n///\n/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was\n/// first created.\npub(crate) struct Graph {\n    origin: NodeId,\n    state: Mutex<GraphState>,\n}\n\n#[derive(Default)]\nstruct GraphState {\n    server: AutodiffServer,\n}\n\nimpl core::fmt::Debug for Graph {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.debug_struct(\"Graph\")\n            .field(\"origin\", &self.origin)\n            .finish()\n    }\n}\n\nstatic STATE: Mutex<Option<GraphLocator>> = Mutex::new(None);\n\nimpl GraphMutexClient {\n    /// Retrieves or creates a graph for the given [NodeId] and parent dependencies.\n    ///\n    /// # Parameters\n    /// - `node`: The unique identifier for the stream.\n    /// - `parents`: A slice of parent nodes that the stream depends on.\n    ///\n    /// # Returns\n    /// An `Arc<Graph>` representing the selected or newly created stream.\n    fn graph(node: NodeId, parents: &[Parent]) -> Arc<Graph> {\n        let mut state = STATE.lock();\n\n        match state.as_mut() {\n            Some(locator) => locator.select(node, parents),\n            None => {\n                let mut locator = GraphLocator::default();\n                let stream = locator.select(node, parents);\n                *state = Some(locator);\n                stream\n            }\n        }\n    }\n}\n\nimpl AutodiffClient for GraphMutexClient {\n    fn register(&self, node_id_ref: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {\n        let node_id = *node_id_ref;\n        let graph = GraphMutexClient::graph(node_id, step.parents());\n        let mut state = graph.state.lock();\n\n        state.server.register(node_id_ref, step, actions);\n    }\n\n    fn backward<B: Backend>(&self, root: AutodiffTensor<B>) -> Gradients {\n        let node_id = root.node.id;\n        let graph = GraphMutexClient::graph(root.node.id, &[]);\n\n        let grads = Gradients::new::<B>(root.node, root.primitive);\n        let grads = {\n            let mut state = graph.state.lock();\n            state.server.backward::<GraphCleaner>(grads, node_id)\n        }; // lock released\n\n        GraphCleaner::cleanup_orphaned_entries();\n\n        grads\n    }\n}\n\nstruct GraphCleaner<'a> {\n    guard: MutexGuard<'a, Option<GraphLocator>>,\n}\n\nimpl<'a> GraphCleaner<'a> {\n    fn cleanup_orphaned_entries() {\n        let graphs = {\n            // Get the available graphs and release the lock\n            match STATE.lock().as_ref() {\n                Some(state) => state.graphs.clone(),\n                None => return,\n            }\n        };\n\n        let mut should_remove = Vec::new();\n        for graph in graphs.values() {\n            {\n                let mut guard = graph.state.lock();\n                // Double safety: in case it was marked as no longer useful, but other\n                // nodes are still relevant, we only check which nodes can safely be removed.\n                if !guard.server.maybe_useful() {\n                    guard\n                        .server\n                        .free_unused_roots(|node| should_remove.push(*node));\n                }\n            }\n        }\n\n        if !should_remove.is_empty() {\n            let mut state = STATE.lock();\n            if let Some(state) = state.as_mut() {\n                for node in should_remove {\n                    state.remove_entry(&node);\n                }\n            }\n        }\n    }\n}\n\nimpl<'a> NodeCleaner for GraphCleaner<'a> {\n    fn init() -> Self {\n        let guard = STATE.lock();\n        Self { guard }\n    }\n\n    fn clean(&mut self, node: &NodeId) {\n        if let Some(state) = self.guard.as_mut() {\n            state.remove_entry(node);\n        }\n    }\n}\n\nimpl GraphLocator {\n    /// Selects a single graph for the given [NodeId], considering parent dependencies.\n    ///\n    /// If multiple graphs are found, they are merged into a single one.\n    ///\n    /// # Parameters\n    /// - `node`: The node ID of the graph to select.\n    /// - `parents`: A slice of parent nodes that the graph depends on.\n    ///\n    /// # Returns\n    ///\n    /// An `Arc<Graph>` representing the selected or merged graph.\n    pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {\n        match self.analyse(node, parents) {\n            GraphAnalysis::NoCollision(graph) => {\n                if graph.origin != node {\n                    self.graphs.insert(node, graph.clone());\n                    self.register_key(graph.origin, node);\n                }\n\n                graph\n            }\n            GraphAnalysis::Collisions(graphs) => self.merge(node, graphs),\n        }\n    }\n\n    /// Analyses the graph for a given node and its parents, returning the associated `GraphAnalysis`.\n    fn analyse(&mut self, node: NodeId, parents: &[Parent]) -> GraphAnalysis {\n        // If no parents, there is no collision, therefore a single graph is ok.\n        if parents.is_empty() {\n            let graph = match self.graphs.get(&node) {\n                Some(val) => val.clone(),\n                None => self.new_graph(node),\n            };\n            return GraphAnalysis::NoCollision(graph);\n        };\n\n        // We collect all graphs of parents and of the current node based on their origin node id.\n        let mut graphs = HashMap::<NodeId, Arc<Graph>>::new();\n\n        if let Some(val) = self.graphs.get(&node) {\n            graphs.insert(val.origin, val.clone());\n        }\n\n        for parent in parents {\n            match self.graphs.get(&parent.id) {\n                Some(graph) => graphs.insert(graph.origin, graph.clone()),\n                None => continue,\n            };\n        }\n\n        if graphs.is_empty() {\n            return match self.graphs.get(&node) {\n                Some(old) => GraphAnalysis::NoCollision(old.clone()),\n                None => GraphAnalysis::NoCollision(self.new_graph(node)),\n            };\n        }\n\n        if graphs.len() == 1 {\n            return GraphAnalysis::NoCollision(graphs.drain().next().unwrap().1);\n        }\n\n        GraphAnalysis::Collisions(graphs)\n    }\n\n    /// Merges multiple graphs associated with a node into a single graph.\n    fn merge(&mut self, node: NodeId, mut graphs: HashMap<NodeId, Arc<Graph>>) -> Arc<Graph> {\n        let mut graphs = graphs.drain().map(|g| g.1);\n\n        let main = graphs.next().expect(\"At least one graph\");\n        self.register_key(main.origin, node);\n\n        let mut state = main.state.lock();\n\n        for graph in graphs {\n            self.merge_two(&mut state, &main, graph);\n        }\n\n        self.graphs.insert(main.origin, main.clone());\n        self.graphs.insert(node, main.clone());\n\n        core::mem::drop(state);\n\n        main\n    }\n\n    /// Registers a key for a given origin node.\n    fn register_key(&mut self, origin: NodeId, key: NodeId) {\n        if !self.keys.contains_key(&origin) {\n            // Ensure an entry exists for this origin\n            self.keys.insert(origin, HashSet::new());\n        }\n\n        if origin != key {\n            // Register this node to point to the origin graph\n            self.keys.get_mut(&origin).unwrap().insert(key);\n        }\n    }\n\n    /// Merges two graphs by combining their states and updating graph mappings.\n    fn merge_two(&mut self, main_state: &mut GraphState, main: &Arc<Graph>, merged: Arc<Graph>) {\n        let mut locked = merged.state.lock();\n        let mut state_old = GraphState::default();\n        core::mem::swap(&mut state_old, &mut locked);\n        main_state.server.extend(state_old.server);\n\n        // Re-map merged origin to the main graph\n        self.graphs.insert(merged.origin, main.clone());\n\n        // Move all keys (node IDs) from the merged graph to the main graph\n        if let Some(locator_keys) = self.keys.remove(&merged.origin) {\n            for k in locator_keys.iter() {\n                self.graphs.insert(*k, main.clone());\n            }\n\n            let locator_keys_main = self\n                .keys\n                .get_mut(&main.origin)\n                .expect(\"Should be init before the merge.\");\n            locator_keys_main.extend(locator_keys);\n        }\n    }\n\n    /// Creates a new graph for a given node.\n    fn new_graph(&mut self, origin: NodeId) -> Arc<Graph> {\n        let graph = Arc::new(Graph {\n            origin,\n            state: Mutex::new(GraphState::default()),\n        });\n        self.graphs.insert(origin, graph.clone());\n        self.keys.insert(origin, HashSet::new());\n        graph\n    }\n\n    fn remove_entry(&mut self, node: &NodeId) {\n        if let Some(graph) = self.graphs.remove(node) {\n            let mut remove = false;\n\n            if let Some(entry) = self.keys.get_mut(&graph.origin) {\n                entry.remove(node);\n                if entry.is_empty() {\n                    remove = true;\n                }\n            }\n\n            if remove {\n                self.keys.remove(&graph.origin);\n            }\n        }\n    }\n}\n\n/// Represents the analysis result of graph operations for a given node and its parents.\n#[derive(Debug)]\nenum GraphAnalysis {\n    /// No collision detected, contains the graph associated with the node.\n    NoCollision(Arc<Graph>),\n    /// Collision detected, contains a map of node IDs to their associated graphs.\n    Collisions(HashMap<NodeId, Arc<Graph>>),\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/memory_management.rs",
    "content": "use crate::{\n    NodeId,\n    collections::{HashMap, HashSet},\n    graph::Parent,\n    tensor::NodeRefCount,\n};\nuse alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};\nuse core::mem;\n\n#[derive(Default, Debug)]\npub struct GraphMemoryManagement {\n    nodes: HashMap<NodeRefCount, Vec<NodeId>>,\n    leaves: HashSet<NodeId>,\n    statuses: HashMap<NodeId, NodeMemoryStatus>,\n}\n\n#[derive(Debug, Clone, PartialEq)]\nenum NodeMemoryStatus {\n    Useful,\n    Unavailable,\n    Unknown,\n}\n\nimpl GraphMemoryManagement {\n    pub fn extend(&mut self, other: Self) {\n        self.nodes.extend(other.nodes);\n        self.leaves.extend(other.leaves);\n        self.statuses.extend(other.statuses);\n    }\n\n    /// Register a new node with its parent.\n    pub fn register(&mut self, node: NodeRefCount, parents: &[Parent]) {\n        let node_id = *node.as_ref();\n\n        for parent in parents.iter() {\n            self.leaves.remove(&parent.id);\n        }\n\n        self.leaves.insert(node_id);\n        self.nodes\n            .insert(node, parents.iter().map(|p| p.id).collect());\n    }\n\n    /// Free the node from the state.\n    pub fn consume_node(&mut self, node_id: NodeId) {\n        if !self.is_referenced(node_id) {\n            self.leaves.remove(&node_id);\n            self.nodes.remove(&node_id);\n        }\n    }\n\n    /// Free all nodes whose backward call has become impossible\n    ///\n    /// This function goes into three steps, which must happen for all leaves\n    /// before going into the next step. Then it deletes what can be safely deleted\n    pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {\n        let leaves = self.leaves.clone();\n        let mut new_leaves = HashSet::new();\n        let mut deletables = Vec::new();\n\n        // When consuming nodes with a backward pass, some other backward passes become\n        // unavailable because some of their parents have been consumed. They are\n        // identified here.\n        for leaf in leaves.clone() {\n            self.unavailable_propagation(leaf);\n        }\n\n        // Among the available nodes that remain, some may be useless if no\n        // available node with a tensor reference exist in their descendance.\n        // But some may seem useless from some leaf but be useful from another one,\n        // hence the need to iterate on all leaves.\n        self.useful_propagation(leaves.clone());\n\n        // New leaves are the roots of a useful backward sub-tree.\n        // Deletables are everything not marked as useful.\n        for leaf in leaves {\n            self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);\n        }\n\n        // Replace leaves by the new ones and delete everything not useful anymore\n        mem::swap(&mut self.leaves, &mut new_leaves);\n\n        self.clear_unused_roots(&mut deletables);\n\n        self.statuses.clear();\n        for node_to_delete in deletables {\n            self.nodes.remove(&node_to_delete);\n            on_free_graph(&node_to_delete)\n        }\n    }\n\n    pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {\n        let mut deletables = Vec::new();\n        self.clear_unused_roots(&mut deletables);\n\n        for node_id in deletables {\n            self.nodes.remove(&node_id);\n            on_free_graph(&node_id);\n        }\n    }\n\n    fn clear_unused_roots(&self, to_delete: &mut Vec<NodeId>) {\n        for (id, parents) in self.nodes.iter() {\n            let is_useful = matches!(\n                self.statuses.get(id.as_ref()),\n                Some(NodeMemoryStatus::Useful)\n            );\n\n            // Check if parents are either empty or absent from self.nodes\n            let parents_absent = parents.iter().all(|p| !self.nodes.contains_key(p));\n\n            if !is_useful && Arc::strong_count(id) == 1 && parents_absent {\n                to_delete.push(*id.as_ref())\n            }\n        }\n    }\n\n    fn unavailable_propagation(&mut self, node_id: NodeId) -> NodeMemoryStatus {\n        // If already visited\n        if let Some(status) = self.statuses.get(&node_id) {\n            return status.clone();\n        }\n\n        match self.nodes.get(&node_id).cloned() {\n            // If node exists and any of its parents is unavailable, it is unavailable as well\n            // If node exists but the parents vec is empty, it is a tensor that never had parents;\n            //  the status remains unknown\n            Some(parents) => {\n                let mut node_status = NodeMemoryStatus::Unknown;\n                for parent in parents {\n                    let parent_status = self.unavailable_propagation(parent);\n                    if let NodeMemoryStatus::Unavailable = parent_status {\n                        node_status = NodeMemoryStatus::Unavailable;\n                    }\n                }\n                self.statuses.insert(node_id, node_status.clone());\n                node_status\n            }\n            // If node does not exist, it was\n            // deleted, so this and all its descendants are unavailable\n            None => {\n                self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);\n                NodeMemoryStatus::Unavailable\n            }\n        }\n    }\n\n    fn useful_propagation(&mut self, leaves: HashSet<NodeId>) {\n        // Accumulate visited nodes\n        let mut explored = HashSet::new();\n        let mut tagged_useful = HashSet::new();\n\n        // Queue of nodes to visit\n        let mut to_tag_useful = PopNodeSet::default();\n        let mut to_explore = PopNodeSet::new(leaves);\n\n        // Utility function to iterate over a node's parents\n        let parents = |node_id| {\n            self.nodes\n                .get(&node_id)\n                .cloned()\n                .unwrap_or_default()\n                .into_iter()\n        };\n\n        loop {\n            // Pop a node id, greedily looking at tag_useful ones first\n            let (node_id, status) = match to_tag_useful.pop() {\n                Some(node_id) => (node_id, NodeMemoryStatus::Useful),\n                None => match to_explore.pop() {\n                    Some(node_id) => {\n                        let node_status = self\n                            .statuses\n                            .get(&node_id)\n                            .expect(\"All nodes should have received a status during unavailable_propagation\")\n                            .to_owned();\n\n                        if let NodeMemoryStatus::Unknown = node_status {\n                            match self.is_referenced(node_id) {\n                                true => (node_id, NodeMemoryStatus::Useful),\n                                false => (node_id, NodeMemoryStatus::Unknown),\n                            }\n                        } else {\n                            (node_id, node_status)\n                        }\n                    }\n                    None => {\n                        // There are no nodes in the queues anymore\n                        break;\n                    }\n                },\n            };\n\n            match status {\n                NodeMemoryStatus::Useful => {\n                    tagged_useful.insert(node_id);\n                    for parent in parents(node_id) {\n                        // The node can be explored, as long as it's not already tagged useful\n                        if !(tagged_useful.contains(&parent) || to_tag_useful.contains(&parent)) {\n                            to_tag_useful.insert(parent);\n                        }\n                    }\n                }\n                _ => {\n                    explored.insert(node_id);\n                    for parent in parents(node_id) {\n                        if !(explored.contains(&parent) || to_explore.contains(&parent)) {\n                            to_explore.insert(parent);\n                        }\n                    }\n                }\n            }\n\n            self.statuses.insert(node_id, status);\n        }\n    }\n\n    fn identify_leaves_and_deletables(\n        &self,\n        leaf_id: NodeId,\n        new_leaves: &mut HashSet<NodeId>,\n        to_delete: &mut Vec<NodeId>,\n    ) {\n        let mut visited = HashSet::new();\n        let mut to_visit = vec![leaf_id];\n\n        while let Some(node_id) = to_visit.pop() {\n            visited.insert(node_id);\n\n            match self\n                .statuses\n                .get(&node_id)\n                .expect(\"Node should have status\")\n            {\n                NodeMemoryStatus::Useful => {\n                    new_leaves.insert(node_id);\n                }\n                _ => {\n                    to_delete.push(node_id);\n\n                    for parent in self\n                        .nodes\n                        .get(&node_id)\n                        .cloned()\n                        .unwrap_or_default()\n                        .into_iter()\n                    {\n                        if !visited.contains(&parent) {\n                            to_visit.push(parent);\n                        }\n                    }\n                }\n            };\n        }\n    }\n\n    fn is_referenced(&self, node_id: NodeId) -> bool {\n        match self.nodes.get_key_value(&node_id) {\n            Some((key, _value)) => Arc::strong_count(key) > 1,\n            None => panic!(\"Node should be in the nodes map\"),\n        }\n    }\n\n    pub(crate) fn maybe_useful(&self) -> bool {\n        self.nodes.keys().any(|node| Arc::strong_count(node) > 1)\n    }\n}\n\n/// Wrapper over hash set for fast popping of any node\n#[derive(new, Default)]\nstruct PopNodeSet {\n    hash_set: HashSet<NodeId>,\n}\n\nimpl PopNodeSet {\n    #[inline(always)]\n    fn pop(&mut self) -> Option<NodeId> {\n        self.hash_set\n            .iter()\n            .next()\n            .copied()\n            .and_then(|node_id| self.hash_set.take(&node_id))\n    }\n\n    #[inline(always)]\n    fn contains(&self, node_id: &NodeId) -> bool {\n        self.hash_set.contains(node_id)\n    }\n\n    #[inline(always)]\n    fn insert(&mut self, node_id: NodeId) {\n        self.hash_set.insert(node_id);\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/mod.rs",
    "content": "mod client;\nmod memory_management;\nmod server;\n\npub mod graph;\npub use client::*;\n"
  },
  {
    "path": "crates/burn-autodiff/src/runtime/server.rs",
    "content": "use super::memory_management::GraphMemoryManagement;\nuse crate::{\n    NodeId,\n    checkpoint::{\n        base::{Checkpointer, NodeTree},\n        builder::CheckpointerBuilder,\n    },\n    collections::HashMap,\n    grads::Gradients,\n    graph::{StepBoxed, traversal::BreadthFirstSearch},\n    tensor::NodeRefCount,\n};\nuse alloc::vec::Vec;\n\n#[derive(Default)]\npub struct AutodiffServer {\n    steps: HashMap<NodeId, StepBoxed>,\n    actions_builder: HashMap<NodeId, CheckpointerBuilder>,\n    memory_management: GraphMemoryManagement,\n}\n\n/// Defines how nodes are clean.\npub trait NodeCleaner {\n    /// Initialize a new cleaner.\n    fn init() -> Self;\n    /// Cleans a single [node](NodeId).\n    fn clean(&mut self, node: &NodeId);\n}\n\nimpl AutodiffServer {\n    pub fn extend(&mut self, other: AutodiffServer) {\n        self.steps.extend(other.steps);\n        self.actions_builder.extend(other.actions_builder);\n        self.memory_management.extend(other.memory_management);\n    }\n\n    pub fn register(&mut self, rc: NodeRefCount, step: StepBoxed, actions: CheckpointerBuilder) {\n        let parents = step.parents();\n        let node_id = *rc.as_ref();\n\n        self.memory_management.register(rc, parents);\n\n        self.steps.insert(node_id, step);\n        self.actions_builder.insert(node_id, actions);\n    }\n\n    pub fn backward<NC: NodeCleaner>(&mut self, grads: Gradients, node_id: NodeId) -> Gradients {\n        let step = self.steps.remove(&node_id).expect(\n            \"Node should have a step registered, did you forget to call \\\n             `Tensor::register_grad` on the tensor where you need gradients?\",\n        );\n        let builder = self.actions_builder.remove(&node_id).unwrap();\n\n        let mut consumed = Vec::new();\n        let (tape, checkpointer) = self.build_tape(node_id, step, builder, &mut consumed);\n\n        let gradients = Self::execute_steps(tape, grads, checkpointer);\n\n        // Cleanup\n        let mut cleaner = NC::init();\n        self.memory_management\n            .free_unavailable_nodes(|node_id: &NodeId| {\n                self.steps.remove(node_id);\n                self.actions_builder.remove(node_id);\n                NC::clean(&mut cleaner, node_id);\n            });\n        for node_id in consumed {\n            cleaner.clean(&node_id)\n        }\n\n        gradients\n    }\n\n    pub(crate) fn free_unused_roots(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {\n        self.memory_management.free_unused_roots(|node_id| {\n            self.steps.remove(node_id);\n            self.actions_builder.remove(node_id);\n            on_free_graph(node_id);\n        });\n    }\n\n    fn build_tape(\n        &mut self,\n        node: NodeId,\n        node_step: StepBoxed,\n        mut builder: CheckpointerBuilder,\n        consumed: &mut Vec<NodeId>,\n    ) -> (Vec<Vec<StepBoxed>>, Checkpointer) {\n        let mut tape = (0..node_step.depth())\n            .map(|_| Vec::with_capacity(1))\n            .collect::<Vec<_>>();\n\n        let mut tree = HashMap::default();\n\n        BreadthFirstSearch.traverse(node, node_step, &mut self.steps, |id, step| {\n            self.memory_management.consume_node(id);\n            // Clean up consumed node\n            consumed.push(id);\n\n            let depth = step.depth();\n\n            if depth == 0 {\n                return;\n            }\n\n            if let Some(steps) = tape.get_mut(depth - 1) {\n                let parents = step.parents().iter().map(|p| p.id).filter(|s| *s != id);\n                tree.insert(id, parents.collect());\n                steps.push(step);\n            }\n\n            if let Some(node_builder) = self.actions_builder.remove(&id) {\n                builder.extend(node_builder);\n            }\n        });\n\n        let checkpointer = builder.build(NodeTree::new(tree));\n\n        (tape, checkpointer)\n    }\n\n    fn execute_steps(\n        tape: Vec<Vec<StepBoxed>>,\n        mut grads: Gradients,\n        mut checkpointer: Checkpointer,\n    ) -> Gradients {\n        tape.into_iter().rev().for_each(|steps| {\n            steps\n                .into_iter()\n                .for_each(|step| step.step(&mut grads, &mut checkpointer))\n        });\n\n        // For checkpointing tests\n        #[cfg(feature = \"export_tests\")]\n        assert!(checkpointer.is_empty());\n\n        grads\n    }\n\n    pub(crate) fn maybe_useful(&self) -> bool {\n        self.memory_management.maybe_useful()\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/tensor.rs",
    "content": "use crate::{\n    checkpoint::{base::Checkpointer, builder::CheckpointerBuilder},\n    grads::Gradients,\n    graph::{ComputingProperty, Node, NodeId, NodeRef, Parent, Requirement, Step},\n    runtime::{AutodiffClient, AutodiffClientImpl},\n};\nuse alloc::{boxed::Box, sync::Arc, vec};\nuse burn_backend::{Backend, TensorMetadata};\n\n#[derive(Debug, Clone)]\npub struct AutodiffTensor<B: Backend> {\n    pub primitive: B::FloatTensorPrimitive,\n    pub node: NodeRef,\n    pub rc: NodeRefCount,\n}\n\nimpl<B: Backend> TensorMetadata for AutodiffTensor<B> {\n    fn dtype(&self) -> burn_std::DType {\n        self.primitive.dtype()\n    }\n\n    fn shape(&self) -> burn_std::Shape {\n        self.primitive.shape()\n    }\n\n    fn rank(&self) -> usize {\n        self.primitive.rank()\n    }\n}\n\npub type NodeRefCount = Arc<NodeId>;\n\n#[derive(new, Debug)]\npub(crate) struct RootStep {\n    node: NodeRef,\n}\n\nimpl Step for RootStep {\n    fn step(self: Box<Self>, _grads: &mut Gradients, _checkpointer: &mut Checkpointer) {\n        // Nothing to do\n    }\n\n    fn node(&self) -> NodeId {\n        self.node.id\n    }\n\n    fn parents(&self) -> &[Parent] {\n        &self.node.parents\n    }\n\n    fn depth(&self) -> usize {\n        self.node.order\n    }\n}\n\nimpl<B: Backend> AutodiffTensor<B> {\n    /// Create a new leaf tensor.\n    pub fn new(primitive: B::FloatTensorPrimitive) -> Self {\n        let id = NodeId::new();\n        let node: NodeRef = Node::new(\n            vec![],\n            0,\n            id,\n            Requirement::None,\n            ComputingProperty::Ambiguous,\n            AutodiffClientImpl::new(),\n        )\n        .into();\n\n        Self {\n            rc: Arc::new(node.id),\n            primitive,\n            node: node.clone(),\n        }\n    }\n\n    pub fn is_tracked(&self) -> bool {\n        !self.node.requirement.is_none()\n    }\n\n    /// Mark the tensor as requiring gradients.\n    ///\n    /// # Panics\n    ///\n    /// It panics if the tensor is not a leaf.\n    pub fn require_grad(mut self) -> Self {\n        match self.node.requirement {\n            Requirement::Grad => self,\n            Requirement::GradInBackward => {\n                panic!(\"Can't convert a non leaf tensor into a tracked tensor\")\n            }\n            Requirement::None => {\n                self.node = Node::new(\n                    vec![],\n                    0,\n                    self.node.id,\n                    Requirement::Grad,\n                    self.node.properties.clone(),\n                    self.node.client.clone(),\n                )\n                .into();\n                let step = RootStep::new(self.node.clone());\n\n                self.register_step(step, CheckpointerBuilder::default())\n            }\n        }\n    }\n\n    /// Create a tensor from parent infos.\n    pub fn from_parents(\n        primitive: B::FloatTensorPrimitive,\n        parent_nodes: &[NodeRef],\n        requirement: Requirement,\n        computing_properties: ComputingProperty,\n    ) -> Self {\n        let order = parent_nodes\n            .iter()\n            .map(|node| node.order)\n            .reduce(usize::max)\n            .unwrap_or(0)\n            + 1;\n\n        let client = parent_nodes\n            .first()\n            .map(|node| node.client.clone())\n            .unwrap_or_else(AutodiffClientImpl::new);\n\n        let node: NodeRef = Node::new(\n            parent_nodes\n                .iter()\n                .filter_map(|node| node.clone_if_require_grad())\n                .map(|node| Parent::new(node.id))\n                .collect(),\n            order,\n            NodeId::new(),\n            requirement,\n            computing_properties,\n            client,\n        )\n        .into();\n\n        Self {\n            rc: Arc::new(node.id),\n            primitive,\n            node,\n        }\n    }\n\n    /// Register a step into a graph for that tensor.\n    ///\n    /// # Warning\n    ///\n    /// This should be called only once per tensor.\n    pub fn register_step<S: Step + 'static>(\n        self,\n        step_that_created_the_tensor: S,\n        actions: CheckpointerBuilder,\n    ) -> Self {\n        self.node.client.register(\n            self.rc.clone(),\n            Box::new(step_that_created_the_tensor),\n            actions,\n        );\n        self\n    }\n\n    pub fn into_primitive(self) -> B::FloatTensorPrimitive {\n        self.primitive\n    }\n\n    pub fn backward(self) -> Gradients {\n        let client = self.node.client.clone();\n\n        AutodiffClient::backward::<B>(&client, self)\n    }\n\n    pub fn grad(&self, grads: &Gradients) -> Option<B::FloatTensorPrimitive> {\n        grads.get::<B>(self)\n    }\n\n    pub fn grad_remove(&self, grads: &mut Gradients) -> Option<B::FloatTensorPrimitive> {\n        grads.remove::<B>(self)\n    }\n\n    pub fn grad_replace(&self, grads: &mut Gradients, grad: B::FloatTensorPrimitive) {\n        grads.remove::<B>(self);\n        grads.register::<B>(self.node.id, grad);\n    }\n}\n"
  },
  {
    "path": "crates/burn-autodiff/src/utils.rs",
    "content": "use alloc::vec::Vec;\n\nuse crate::graph::NodeRef;\n/// Duplicate the given object for each node that requires gradients.\n///\n/// # Notes\n///\n/// This is useful since you don't have to keep N cloned references alive event if just 1 node\n/// will be updated.\n///\n/// If the object is a tensor and if one reference exists, it can be updated inplace.\npub fn duplicate<T: Clone + core::fmt::Debug, const N: usize>(\n    nodes: &[Option<NodeRef>; N],\n    obj: Option<T>,\n) -> [Option<T>; N] {\n    nodes\n        .iter()\n        .map(|node| match node {\n            Some(_) => obj.clone(),\n            None => None,\n        })\n        .collect::<Vec<_>>()\n        .try_into()\n        .unwrap()\n}\n"
  },
  {
    "path": "crates/burn-backend/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Core backend interfaces and data structures for executing tensor operations in Burn.\"\ndocumentation = \"https://docs.rs/burn-backend\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-backend\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-backend\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\ndoc = [\"default\"]\nstd = [\"rand/std\", \"num-traits/std\", \"burn-std/std\", \"cubecl?/std\"]\n\ntracing = [\"burn-std/tracing\", \"cubecl/tracing\"]\n\ncubecl = [\"dep:cubecl\", \"burn-std/cubecl\"]\ncubecl-cuda = [\"cubecl\", \"cubecl/cuda\"]\ncubecl-hip = [\"cubecl\", \"cubecl/hip\"]\ncubecl-wgpu = [\"cubecl\", \"cubecl/wgpu\"]\ncubecl-cpu = [\"cubecl\", \"cubecl/cpu\"]\n\n[dependencies]\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\ncubecl = { workspace = true, optional = true, default-features = false }\n\nbytemuck = { workspace = true, features = [\"extern_crate_alloc\"] }\nderive-new = { workspace = true }\nenumset = { workspace = true }\nhashbrown = { workspace = true }\nnum-traits = { workspace = true }\nrand = { workspace = true, default-features = false }\nrand_distr = { workspace = true }\nserde = { workspace = true }\nthiserror = { workspace = true }\n\n[dev-dependencies]\nrand = { workspace = true, features = [\"thread_rng\"] }\npaste = { workspace = true }\nserde_json = { workspace = true, features = [\"alloc\"]}\n"
  },
  {
    "path": "crates/burn-backend/README.md",
    "content": "# Burn Backend\n\nThis crate includes the core backend interfaces and data structures for executing tensor operations\nin Burn.\n"
  },
  {
    "path": "crates/burn-backend/src/backend/base.rs",
    "content": "use burn_std::DType;\npub use burn_std::backtrace::BackTrace;\n\nuse alloc::string::String;\nuse enumset::{EnumSet, EnumSetType};\nuse serde::{Deserialize, Serialize};\nuse thiserror::Error;\n\nuse crate::element::Element;\nuse crate::ops::*;\nuse crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};\nuse crate::{QTensorPrimitive, TensorData, TensorMetadata};\n\nuse super::DeviceOps;\n\n/// This trait defines all types and functions needed for a backend to be used with burn.\n///\n/// ## Design\n///\n/// This trait aims to be as unopinionated as possible and allows implementations to define\n/// their own types and patterns. Therefore, there are few pre-defined abstractions baked\n/// into this trait.\n///\n/// Backends must define their own tensor types for each data type: `float`, `int`, and `bool`.\n/// Since we minimize assumptions, we chose to separate these types, as they are used in\n/// different contexts. However, some backends may have a generic tensor type that is used\n/// for all data types.\n///\n/// ### Eager Mode\n///\n/// Because burn supports dynamic graphs, the backend trait is designed around kernel\n/// implementations that can be called without any mutable context or graph. This may not be\n/// ideal for backends that want to configure their computational graphs and execute them\n/// multiple times.\n///\n/// To implement this kind of backend, channels could be used to communicate with a backend\n/// server thread to build the computation graphs and re-execute the ones that are repeated,\n/// with some form of cache. Once that pattern has matured, a graph mode backend trait could\n/// be extracted from it, allowing other backends of the same kind to be quickly integrated\n/// with burn. This pattern could also be used to create an operation fusion trait, which\n/// allows backends to define what kind of graph structures can be fused into one operation.\n///\n/// ### Multi-Threaded\n///\n/// Backend tensor types are all `Clone` + `Send`, which allows them to be safely\n/// sent between threads. It is recommended to wrap tensors with [Arc](alloc::sync::Arc),\n/// which avoids copying the tensor's buffer. Note that it is still possible to mutate and\n/// reuse tensors' buffer without locking; see the next section on the Mutable API.\n///\n/// ### Mutable API\n///\n/// There is no mutable or inplace operation API to implement, but that does not mean that\n/// backends cannot support them. Using [try_unwrap](alloc::sync::Arc::try_unwrap) and\n/// [get_mut](alloc::sync::Arc::get_mut) allows backends to have access to an owned or mutable\n/// reference to their tensor buffer data structure if the tensor is not shared. In that case,\n/// backends can dispatch to their owned inplace operations for better performance.\n///\n/// ## Documentation\n///\n/// Most of the documentation for each function can be found on the user API\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// struct in the `burn-tensor` crate.\n/// For modules, public functions are often created, which can be used by `burn-core` modules.\npub trait Backend:\n    FloatTensorOps<Self>\n    + BoolTensorOps<Self>\n    + IntTensorOps<Self>\n    + ModuleOps<Self>\n    + ActivationOps<Self>\n    + QTensorOps<Self>\n    + TransactionOps<Self>\n    + Clone\n    + Default\n    + Sized\n    + Send\n    + Sync\n    + core::fmt::Debug\n    + 'static\n{\n    /// Device type.\n    type Device: DeviceOps;\n\n    /// Tensor primitive to be used for all float operations.\n    type FloatTensorPrimitive: TensorMetadata + 'static;\n    /// Default float element type.\n    type FloatElem: Element;\n\n    /// Tensor primitive to be used for all int operations.\n    type IntTensorPrimitive: TensorMetadata + 'static;\n    /// Int element type.\n    type IntElem: Element;\n\n    /// Tensor primitive to be used for all bool operations.\n    type BoolTensorPrimitive: TensorMetadata + 'static;\n    /// Tensor primitive to be used for all bool operations.\n    type BoolElem: Element;\n\n    /// Tensor primitive to be used for all quantized operations.\n    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;\n\n    /// If autodiff is enabled.\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    /// Sets the current allocation mode to persistent.\n    #[allow(unused_variables)]\n    fn memory_persistent_allocations<\n        Output: Send,\n        Input: Send,\n        Func: Fn(Input) -> Output + Send,\n    >(\n        device: &Self::Device,\n        input: Input,\n        func: Func,\n    ) -> Output {\n        func(input)\n    }\n\n    /// Manually triggers a memory cleanup on the given device.\n    #[allow(unused_variables)]\n    fn memory_cleanup(device: &Self::Device) {}\n\n    /// Name of the backend.\n    fn name(device: &Self::Device) -> String;\n\n    /// Seeds the backend on the specified device.\n    ///\n    /// There is no guarantee that only the specified device will be seeded, but it is guaranteed\n    /// that at least the specified device will be seeded.\n    ///\n    /// In all cases, this should ensure deterministic execution for a single-threaded program.\n    fn seed(device: &Self::Device, seed: u64);\n\n    /// Sync the backend, ensure that all computation are finished.\n    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {\n        Ok(())\n    }\n\n    /// Marks the given data as being used as a staging buffer for transfer between CPU and\n    /// accelerators like GPUs.\n    ///\n    /// The given data might be transferred to pinned memory or another format to improve data transfer\n    /// speed.\n    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)\n    where\n        Iter: Iterator<Item = &'a mut TensorData>,\n    {\n    }\n\n    /// Whether the type is fully supported by the specified device for general operations.\n    ///\n    /// A type is considered supported if it can be used for the full suite of tensor\n    /// operations, including storage, conversion, and basic arithmetic.\n    ///\n    /// Returning `false` does not necessarily mean the device cannot handle the type at all.\n    /// For instance, a device might support a type only for specialized hardware\n    /// acceleration (e.g., matrix multiplication) but lack general arithmetic support. Such\n    /// types should return `false` here as they are not globally supported.\n    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {\n        Self::dtype_usage(device, dtype).is_superset(DTypeUsage::general())\n    }\n\n    /// Returns the [DTypeUsageSet] for the given [DType] on the specified device.\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet;\n}\n\n/// An error that can happen when syncing a device.\n#[derive(Error, Serialize, Deserialize)]\npub enum ExecutionError {\n    /// A generic error happened during execution.\n    ///\n    /// The backtrace and context information should be included in the reason string.\n    #[error(\"An error happened during execution\\nCaused by:\\n  {reason}\")]\n    WithContext {\n        /// The reason of the error.\n        reason: String,\n    },\n    /// A generic error happened during execution thrown in the Burn project.\n    ///\n    /// The full context isn't captured by the string alone.\n    #[error(\"An error happened during execution\\nCaused by:\\n  {reason}\")]\n    Generic {\n        /// The reason of the error.\n        reason: String,\n        /// The backtrace.\n        #[serde(skip)]\n        backtrace: BackTrace,\n    },\n}\n\nimpl core::fmt::Debug for ExecutionError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_fmt(format_args!(\"{self}\"))\n    }\n}\n\n/// Trait that allows a backend to support autodiff.\npub trait AutodiffBackend: Backend {\n    /// The inner backend type.\n    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;\n\n    /// Gradients type.\n    type Gradients: Send;\n\n    /// Backward pass.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.\n    ///\n    /// # Returns\n    ///\n    /// The gradients.\n    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;\n\n    /// Returns the gradients of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to extract the gradients from.\n    ///\n    /// # Returns\n    ///\n    /// An optional tensor containing the gradient.\n    fn grad(\n        tensor: &FloatTensor<Self>,\n        grads: &Self::Gradients,\n    ) -> Option<FloatTensor<Self::InnerBackend>>;\n\n    /// Pops the gradients of a tensor and returns them.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to pop the gradients from.\n    /// * `grads` - The gradients.\n    ///\n    /// # Returns\n    ///\n    /// An optional tensor containing the given gradients.\n    fn grad_remove(\n        tensor: &FloatTensor<Self>,\n        grads: &mut Self::Gradients,\n    ) -> Option<FloatTensor<Self::InnerBackend>>;\n\n    /// Replace the gradients of a tensor with the one provided.\n    ///\n    /// If no gradient existed for the provided tensor, register it.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to pop the gradients from.\n    /// * `grads` - The gradients.\n    /// * `grad` - The updated grad tensor.\n    fn grad_replace(\n        tensor: &FloatTensor<Self>,\n        grads: &mut Self::Gradients,\n        grad: FloatTensor<Self::InnerBackend>,\n    );\n\n    /// Returns the tensor with inner backend type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the inner backend tensor for.\n    ///\n    /// # Returns\n    ///\n    /// The inner backend tensor.\n    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;\n\n    /// Returns the tensor with inner backend type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the inner backend tensor for.\n    ///\n    /// # Returns\n    ///\n    /// The inner backend tensor.\n    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;\n\n    /// Returns the tensor with inner backend type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the inner backend tensor for.\n    ///\n    /// # Returns\n    ///\n    /// The inner backend tensor.\n    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;\n\n    /// Returns the tensor with inner backend type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the inner backend tensor for.\n    ///\n    /// # Returns\n    ///\n    /// The inner backend tensor.\n    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;\n\n    /// Converts the inner backend tensor to the autodiff backend tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The inner backend tensor to convert.\n    ///\n    ///\n    /// # Returns\n    ///\n    /// The autodiff backend tensor.\n    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;\n\n    /// Converts the inner backend tensor to the autodiff backend tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The inner backend tensor to convert.\n    ///\n    ///\n    /// # Returns\n    ///\n    /// The autodiff backend tensor.\n    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;\n\n    /// Converts the inner backend tensor to the autodiff backend tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The inner backend tensor to convert.\n    ///\n    ///\n    /// # Returns\n    ///\n    /// The autodiff backend tensor.\n    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;\n\n    /// Converts the inner backend tensor to the autodiff backend tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The inner backend tensor to convert.\n    ///\n    ///\n    /// # Returns\n    ///\n    /// The autodiff backend tensor.\n    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;\n}\n\n/// Describes how a data type can be used on a given device.\n///\n/// A data type may be supported for different classes of operations. Not all\n/// data types that appear in hardware or kernel implementations are suitable\n/// for general-purpose tensor operations.\n#[derive(Debug, EnumSetType)]\npub enum DTypeUsage {\n    /// The type can be stored in device memory and converted to and from\n    /// other supported data types.\n    Storage,\n    /// The type supports general-purpose arithmetic and common tensor\n    /// operations (e.g. elementwise ops, reductions, etc.).\n    Arithmetic,\n    /// The type is supported by hardware-accelerated execution paths.\n    ///\n    /// This typically indicates support for accelerator-backed compute units (e.g., tensor\n    /// cores executing MMA instructions) for high-performance operations such as matrix\n    /// multiplication and operations that lower to it.\n    ///\n    /// # Notes\n    /// - A type can be both [`Arithmetic`](DTypeUsage::Arithmetic) and\n    ///   [`Accelerated`](DTypeUsage::Accelerated) if it supports general-purpose operations\n    ///   *and* accelerated paths.\n    /// - If a type is marked as `Accelerated` but not `Arithmetic`, it is not\n    ///   suitable for general-purpose tensor operations and may only be used\n    ///   in specific accelerated operations.\n    ///\n    /// `Accelerated` is a **flag**, not a detailed descriptor. It does not enumerate which\n    /// operations are accelerated or which accelerator features are available.\n    Accelerated,\n}\n\n/// A set of [DTypeUsage] representing the total capabilities of a data type on a device.\npub type DTypeUsageSet = EnumSet<DTypeUsage>;\n\nimpl DTypeUsage {\n    /// Returns the usage set required for general-purpose tensor support.\n    pub fn general() -> DTypeUsageSet {\n        DTypeUsage::Storage | DTypeUsage::Arithmetic\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/device.rs",
    "content": "pub use burn_std::device::*;\n\n/// Device trait for all burn backend devices.\npub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + Device {\n    /// Returns the [device id](DeviceId).\n    fn id(&self) -> DeviceId {\n        self.to_id()\n    }\n\n    /// Returns the inner device without autodiff enabled.\n    ///\n    /// For most devices this is a no-op that returns `self`. For autodiff-enabled\n    /// devices, this returns the underlying inner device.\n    fn inner(&self) -> &Self {\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/mod.rs",
    "content": "mod base;\nmod device;\nmod primitive;\n\npub use base::*;\npub use device::*;\npub use primitive::*;\n\n/// Backend operations on tensors.\npub mod ops;\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/activation.rs",
    "content": "use crate::tensor::FloatTensor;\nuse crate::{Backend, Scalar, TensorMetadata};\nuse core::f64::consts::SQRT_2;\n\n/// Activation function operations.\n///\n/// This trait let backend implementations override activation functions for better performance.\npub trait ActivationOps<B: Backend> {\n    /// Applies the LeakyReLU activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `negative_slope` - The negative_slope value that values smaller than 0 are multiplied with.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn leaky_relu(tensor: FloatTensor<B>, negative_slope: Scalar) -> FloatTensor<B> {\n        let mask = B::float_lower_elem(tensor.clone(), 0f32.into());\n        let scaled_tensor = B::float_mul_scalar(tensor.clone(), negative_slope);\n\n        // Update the tensor where the values are `< 0` by `tensor * negative_slope`.\n        B::float_mask_where(tensor, mask, scaled_tensor)\n    }\n\n    /// Applies the ReLU activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn relu(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let mask = B::float_lower_equal_elem(tensor.clone(), 0f32.into());\n\n        B::float_mask_fill(tensor, mask, 0f32.into())\n    }\n\n    /// Applies the ReLU activation function backward.\n    ///\n    /// # Arguments\n    ///\n    /// * `output` - The output tensor.\n    ///\n    /// # Returns\n    ///\n    /// The gradient.\n    fn relu_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {\n        let mask = B::float_lower_equal_elem(output, 0f32.into());\n\n        B::float_mask_fill(grad, mask, 0.into())\n    }\n\n    /// Applies the Gelu activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn gelu(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let x = B::float_div_scalar(tensor.clone(), SQRT_2.into());\n        let x = B::float_erf(x);\n        let x = B::float_add_scalar(x, 1f32.into());\n        let x = B::float_mul(tensor, x);\n\n        B::float_div_scalar(x, 2f32.into())\n    }\n    /// Applies the PReLu activation function.\n    /// # Arguments\n    /// * `tensor` - The input tensor\n    /// * `alpha` - The weight tensor\n    fn prelu(tensor: FloatTensor<B>, alpha: FloatTensor<B>) -> FloatTensor<B> {\n        let mask = B::float_lower_elem(tensor.clone(), 0f32.into());\n        let scaled_tensor = B::float_mul(tensor.clone(), alpha);\n        B::float_mask_where(tensor, mask, scaled_tensor)\n    }\n\n    /// Applies the Gelu activation function backward.\n    ///\n    /// # Arguments\n    ///\n    /// * `x` - The tensor.\n    /// * `grad` - The gradient.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn gelu_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {\n        // Derivative of the approximate gelu implementation based on tanh.\n\n        let constant_1 = 0.0356774;\n        let constant_2 = 0.797885;\n        let constant_3 = 0.0535161;\n        let constant_4 = 0.398942;\n\n        let x3 = B::float_powi_scalar(x.clone(), 3.into());\n\n        let c1 = B::float_mul_scalar(x3.clone(), constant_1.into());\n        let c2 = B::float_mul_scalar(x.clone(), constant_2.into());\n        let c3 = B::float_mul_scalar(x3, constant_3.into());\n        let c4 = B::float_mul_scalar(x, constant_4.into());\n\n        let inner1 = B::float_add(c1, c2);\n        let inner2 = B::float_add(c3, c4);\n\n        let tanh = B::float_tanh(inner1);\n\n        let sech = B::float_powi_scalar(tanh.clone(), 2.into());\n        let sech = B::float_neg(sech);\n        let sech = B::float_add_scalar(sech, 1.into());\n\n        let y1 = B::float_mul_scalar(tanh, 0.5.into());\n        let y2 = B::float_mul(inner2, sech);\n        let y2 = B::float_add_scalar(y2, 0.5.into());\n        let y = B::float_add(y1, y2);\n\n        B::float_mul(y, grad)\n    }\n\n    /// Applies the Sigmoid activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let dtype = tensor.dtype();\n        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);\n        let tensor_tmp = B::float_exp(B::float_neg(B::float_log(B::float_add_scalar(\n            B::float_exp(B::float_neg(tensor_full)),\n            1.0.into(),\n        ))));\n\n        B::float_cast(tensor_tmp, dtype.into())\n    }\n\n    /// Applies the Sigmoid activation function backward.\n    ///\n    /// # Arguments\n    ///\n    /// * `output` - The output tensor of the sigmoid function.\n    /// * `grad` - The gradient.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn sigmoid_backward(output: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {\n        let value = B::float_mul(\n            output.clone(),\n            B::float_add_scalar(B::float_neg(output), 1.0.into()),\n        );\n        B::float_mul(value, grad)\n    }\n\n    /// Applies the hard Sigmoid activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `alpha` - The alpha value that the tensor is multiplied with.\n    /// * `beta` - The beta value that is added to the tensor\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn hard_sigmoid(tensor: FloatTensor<B>, alpha: Scalar, beta: Scalar) -> FloatTensor<B> {\n        let dtype = tensor.dtype();\n        let tensor_full = B::float_cast(tensor, burn_std::FloatDType::F32);\n\n        let tensor_tmp = B::float_clamp(\n            B::float_add_scalar(B::float_mul_scalar(tensor_full, alpha), beta),\n            0.0.into(),\n            1.0.into(),\n        );\n\n        B::float_cast(tensor_tmp, dtype.into())\n    }\n\n    /// Applies the LogSigmoid activation function.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn log_sigmoid(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        // To avoid overflow, we use the log-sum-exp trick.\n        //\n        // ```ignore\n        // log(sigmoid(x)) = log(1/(1 + exp(-x)))\n        //                 = log(1) - log(1 + exp(-x))\n        //                 = -log(1 + exp(-x))\n        //                 = -log(exp(0) + exp(-x))\n        // ```\n        // The `exp(t)` of even a moderate-magnitude positive number can be astronomically huge, so we\n        // subtract the `max(t, 0)` of each value (where `t = -x` in this case). This results in the\n        // following equivalence:\n        // ```ignore\n        // log(sigmoid(x)) = -(max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))))\n        // ```\n        //\n        // This extends the range of values for which we obtain accurate results.\n\n        // max(-x, 0)\n        let tensor_neg = B::float_neg(tensor);\n        let mask = B::float_lower_elem(tensor_neg.clone(), 0f32.into());\n        let max_elem = B::float_mask_fill(tensor_neg.clone(), mask, 0f32.into());\n        let max_elem_neg = B::float_neg(max_elem.clone());\n\n        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))\n        let z = B::float_add(\n            B::float_exp(max_elem_neg.clone()),\n            B::float_exp(B::float_sub(tensor_neg, max_elem.clone())),\n        );\n\n        // -max(-x, 0) - log(-z)\n        B::float_sub(max_elem_neg, B::float_log(z))\n    }\n\n    /// Applies the LogSigmoid activation function backward.\n    ///\n    /// # Arguments\n    ///\n    /// * `x` - The input tensor.\n    /// * `grad` - The gradient.\n    ///\n    /// # Returns\n    ///\n    /// The output gradient.\n    fn log_sigmoid_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {\n        // Derivative of -max(-x, 0) - log(exp(-max(-x, 0)) - exp(-x - max(-x, 0)))) is\n        // -max_derive - (-max_derive * exp(-max(-x, 0)) + (-1 - max_derive) * exp(-x - max(-x, 0))) / z\n        // where z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))\n        //\n        // This simplifies to:\n        // -max_derive - (z-1)/z if x is >= 0\n        // -max_derive + (z-1)/z if x is < 0\n\n        let shape = x.shape();\n        let dtype = x.dtype();\n        let device = B::float_device(&x);\n\n        // max(-x, 0)\n        let x_neg = B::float_neg(x);\n        let mask = B::float_lower_elem(x_neg.clone(), 0f32.into()); // -x < 0 or x >= 0\n        let max_elem = B::float_mask_fill(x_neg.clone(), mask.clone(), 0f32.into());\n\n        // z = exp(-max(-x, 0)) + exp(-x - max(-x, 0))\n        let z = B::float_add(\n            B::float_exp(B::float_neg(max_elem.clone())),\n            B::float_exp(B::float_sub(x_neg, max_elem)),\n        );\n\n        // Derivative of max(-x, 0) is 1 if x < 0 or 0 if x >= 0\n        let ones = B::float_ones(shape, &device, dtype.into());\n        let max_derive = B::float_mask_fill(ones.clone(), mask.clone(), 0f32.into());\n        let sign = B::float_mask_fill(ones.clone(), mask, (-1f32).into());\n\n        // grad * (max_derive - sign * (1 - (1 / z)))\n        B::float_mul(\n            grad,\n            B::float_sub(\n                max_derive,\n                B::float_mul(sign, B::float_sub(ones, B::float_recip(z))),\n            ),\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/argwhere.rs",
    "content": "use crate::tensor::{Device, IntTensor};\nuse crate::{Backend, TensorData, element::ElementConversion};\nuse alloc::vec::Vec;\nuse burn_std::Shape;\n\n/// Compute the indices of the elements that are non-zero, grouped by element.\n///\n/// # Arguments\n///\n/// * `data` - The input tensor data.\n///\n/// # Returns\n///\n/// A 2D tensor containing the indices of all non-zero elements of the given tensor.\n/// Each row contains the indices of a non-zero element.\n///\n/// # Remarks\n///\n/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.\n/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved\n/// by static dispatch. It is not designed for direct usage by users, and not recommended to import\n/// or use this function directly.\npub fn argwhere_data<B: Backend>(data: TensorData, device: &Device<B>) -> IntTensor<B> {\n    let dims = &data.shape;\n    let ndims = dims.len();\n    let count_nonzero = data.iter::<bool>().filter(|&v| v).count();\n\n    /// Converts a flat index into a vector of indices for the specified tensor shape\n    fn unravel_index<B: Backend>(index: usize, shape: &[usize]) -> Vec<B::IntElem> {\n        shape\n            .iter()\n            .rev()\n            .scan(index, |i, size| {\n                let dim_idx = *i % size;\n                *i /= size;\n                Some((dim_idx as i64).elem())\n            })\n            .collect::<Vec<_>>()\n            .into_iter()\n            .rev()\n            .collect()\n    }\n\n    let indices = data\n        .iter::<bool>()\n        .enumerate()\n        .filter_map(|(index, v)| if v { Some(index) } else { None })\n        .map(|index| unravel_index::<B>(index, dims))\n        .collect::<Vec<_>>()\n        .concat();\n\n    B::int_from_data(\n        TensorData::new(indices, Shape::new([count_nonzero, ndims])),\n        device,\n    )\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/bool_tensor.rs",
    "content": "use super::{\n    argwhere::argwhere_data, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign,\n};\nuse crate::tensor::{Bool, BoolTensor, Device, FloatTensor, IntTensor};\nuse crate::{Backend, TensorData, TensorMetadata};\nuse crate::{ExecutionError, Scalar};\nuse alloc::vec::Vec;\nuse burn_std::{Shape, Slice};\nuse core::future::Future;\n\n/// Bool Tensor API for basic operations, see\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// for documentation on each function.\npub trait BoolTensorOps<B: Backend> {\n    /// Creates a new bool tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the given shape.\n    fn bool_empty(shape: Shape, device: &Device<B>) -> BoolTensor<B>;\n\n    /// Creates a new bool tensor filled false.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor filled with false.\n    fn bool_zeros(shape: Shape, device: &Device<B>) -> BoolTensor<B>;\n\n    /// Creates a new bool tensor filled true.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor filled with true.\n    fn bool_ones(shape: Shape, device: &Device<B>) -> BoolTensor<B>;\n\n    /// Converts the tensor to a data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The data structure with the tensor's data.\n    fn bool_into_data(\n        tensor: BoolTensor<B>,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;\n\n    /// Creates a tensor from the data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The data structure.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the data.\n    fn bool_from_data(data: TensorData, device: &Device<B>) -> BoolTensor<B>;\n\n    /// Converts bool tensor to int tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The int tensor with the same data as the bool tensor.\n    fn bool_into_int(tensor: BoolTensor<B>) -> IntTensor<B>;\n\n    /// Converts bool tensor to float tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The float tensor with the same data as the bool tensor.\n    fn bool_into_float(tensor: BoolTensor<B>) -> FloatTensor<B>;\n\n    /// Gets the device of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The device of the tensor.\n    fn bool_device(tensor: &BoolTensor<B>) -> Device<B>;\n\n    /// Moves the tensor to the device.\n    fn bool_to_device(tensor: BoolTensor<B>, device: &Device<B>) -> BoolTensor<B>;\n\n    /// Reshapes the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `shape` - The new shape.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the new shape.\n    fn bool_reshape(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;\n\n    /// Gets the values from the tensor for the given ranges.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `slices` - The slices specifying ranges and steps for each dimension.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values for the given slices.\n    ///\n    /// # Note\n    ///\n    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not\n    /// be passed to this method. Backend implementations do not need to handle empty slices.\n    fn bool_slice(tensor: BoolTensor<B>, slices: &[Slice]) -> BoolTensor<B>;\n\n    /// Sets the values in the tensor for the given ranges.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `ranges` - The ranges to set the values for.\n    /// * `value` - The values to set.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values set for the given ranges.\n    ///\n    /// # Note\n    ///\n    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty slice assignments.\n    fn bool_slice_assign(\n        tensor: BoolTensor<B>,\n        slices: &[Slice],\n        value: BoolTensor<B>,\n    ) -> BoolTensor<B>;\n\n    /// Fills the tensor with values from the value tensor if the mask is true at the given\n    /// indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `mask` - The mask.\n    /// * `value` - The value tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values filled.\n    fn bool_mask_where(\n        tensor: BoolTensor<B>,\n        mask: BoolTensor<B>,\n        value: BoolTensor<B>,\n    ) -> BoolTensor<B>;\n\n    /// Fills the tensor with the given value if the mask is true at the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `mask` - The mask.\n    /// * `value` - The value.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values filled.\n    fn bool_mask_fill(tensor: BoolTensor<B>, mask: BoolTensor<B>, value: Scalar) -> BoolTensor<B>;\n\n    /// Gather elements from the tensor at the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to gather from.\n    /// * `tensor` - The tensor.\n    /// * `indices` - The indices.\n    fn bool_gather(dim: usize, tensor: BoolTensor<B>, indices: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Scatter a given value to the tensor at the given indices using boolean or reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to scatter to.\n    /// * `tensor` - The tensor.\n    /// * `indices` - The indices.\n    /// * `value` - The value.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values scattered.\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<B>,\n        indices: IntTensor<B>,\n        value: BoolTensor<B>,\n    ) -> BoolTensor<B>;\n\n    /// Select tensor elements along the given dimension corresponding to the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices of the elements to select.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements.\n    fn bool_select(tensor: BoolTensor<B>, dim: usize, indices: IntTensor<B>) -> BoolTensor<B> {\n        // Default implementation: convert to int, select, then convert back to bool\n        let int_tensor = B::bool_into_int(tensor);\n        let selected = B::int_select(int_tensor, dim, indices);\n        B::int_equal_elem(selected, 1.into())\n    }\n\n    /// Assign the selected elements along the given dimension corresponding to the given indices\n    /// to the given value using sum reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to assign the values to.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices of the elements to assign.\n    /// * `value` - The values to assign.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the assigned values.\n    fn bool_select_or(\n        tensor: BoolTensor<B>,\n        dim: usize,\n        indices: IntTensor<B>,\n        value: BoolTensor<B>,\n    ) -> BoolTensor<B> {\n        // Default implementation: convert to int, select_assign, then convert back to bool\n        let int_tensor = B::bool_into_int(tensor);\n        let int_values = B::bool_into_int(value);\n        let assigned = B::int_select_add(int_tensor, dim, indices, int_values);\n        // After select_assign with sum reduction, any non-zero value should be true\n        B::int_greater_elem(assigned, 0.into())\n    }\n\n    /// Repeats one dimension of the tensor a given number of times along that dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to repeat.\n    /// * `times` - The number of times to repeat the dimension.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimension repeated.\n    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {\n        repeat_with_slice_assign::<B, Bool>(tensor, dim, times)\n    }\n\n    /// Concatenates the tensors along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensors` - The tensors to concatenate.\n    /// * `dim` - The dimension to concatenate along.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the tensors concatenated along the given dimension.\n    ///\n    /// # Note\n    ///\n    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty tensors.\n    fn bool_cat(tensors: Vec<BoolTensor<B>>, dim: usize) -> BoolTensor<B> {\n        cat_with_slice_assign::<B, Bool>(tensors, dim)\n    }\n\n    /// Equates the two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the equate.\n    fn bool_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the comparison.\n    fn bool_not_equal(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        let equal_tensor = B::bool_equal(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Element-wise equality comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn bool_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn bool_not_equal_elem(lhs: BoolTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        let equal_tensor = B::bool_equal_elem(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Inverses boolean values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the negation.\n    fn bool_not(tensor: BoolTensor<B>) -> BoolTensor<B>;\n\n    /// Executes the logical and (`&&`) operation on two boolean tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the logical and.\n    fn bool_and(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;\n\n    /// Executes the logical or (`||`) operation on two boolean tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the logical or.\n    fn bool_or(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise exclusive or.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the result of the comparison.\n    fn bool_xor(lhs: BoolTensor<B>, rhs: BoolTensor<B>) -> BoolTensor<B> {\n        Self::bool_not_equal(lhs, rhs)\n    }\n\n    /// Transposes a bool tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    fn bool_transpose(tensor: BoolTensor<B>) -> BoolTensor<B> {\n        let ndims = tensor.shape().num_dims();\n        Self::bool_swap_dims(tensor, ndims - 2, ndims - 1)\n    }\n\n    /// Swaps two dimensions of a bool tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap.\n    /// * `dim2` - The second dimension to swap.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    fn bool_swap_dims(tensor: BoolTensor<B>, dim1: usize, dim2: usize) -> BoolTensor<B>;\n\n    /// Permutes the dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to permute the dimensions of.\n    /// * `axes` - The new order of the dimensions.\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    fn bool_permute(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;\n\n    /// Reverse the order of elements in a tensor along the given axes.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reverse.\n    /// * `axes` - The axes to reverse.\n    ///\n    /// The tensor with the elements reversed.\n    fn bool_flip(tensor: BoolTensor<B>, axes: &[usize]) -> BoolTensor<B>;\n\n    /// Tests if any element in the boolean `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.\n    fn bool_any(tensor: BoolTensor<B>) -> BoolTensor<B> {\n        let sum = B::int_sum(B::bool_into_int(tensor));\n        B::int_greater_elem(sum, 0.into())\n    }\n\n    /// Tests if any element in the boolean `tensor` evaluates to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn bool_any_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {\n        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);\n        B::int_greater_elem(sum, 0.into())\n    }\n\n    /// Tests if all elements in the boolean `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor\n    /// evaluate to True, False otherwise.\n    fn bool_all(tensor: BoolTensor<B>) -> BoolTensor<B> {\n        let num_elems = tensor.shape().num_elements() as i64;\n        let sum = B::int_sum(B::bool_into_int(tensor));\n        B::int_equal_elem(sum, num_elems.into())\n    }\n\n    /// Tests if all elements in the boolean `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn bool_all_dim(tensor: BoolTensor<B>, dim: usize) -> BoolTensor<B> {\n        let num_elems = tensor.shape()[dim] as i64;\n        let sum = B::int_sum_dim(B::bool_into_int(tensor), dim);\n        B::int_equal_elem(sum, num_elems.into())\n    }\n\n    /// Compute the indices of the elements that are non-zero, grouped by element.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A 2D tensor containing the indices of all non-zero elements of the given tensor.\n    /// Each row contains the indices of a non-zero element.\n    fn bool_argwhere(tensor: BoolTensor<B>) -> impl Future<Output = IntTensor<B>> + 'static + Send {\n        async {\n            // Size of each output tensor is variable (= number of nonzero elements in the tensor).\n            // Reading the data to count the number of truth values might cause sync but is required.\n            let device = B::bool_device(&tensor);\n            let data = B::bool_into_data(tensor)\n                .await\n                .expect(\"Can read the data without error\");\n            argwhere_data::<B>(data, &device)\n        }\n    }\n\n    /// Broadcasts the bool `tensor` to the given `shape`.\n    fn bool_expand(tensor: BoolTensor<B>, shape: Shape) -> BoolTensor<B>;\n\n    /// Unfold windows along a dimension.\n    ///\n    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n    /// * `dim` - the selected dim.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.\n    fn bool_unfold(tensor: BoolTensor<B>, dim: usize, size: usize, step: usize) -> BoolTensor<B>;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/cat.rs",
    "content": "use crate::{\n    Backend, TensorMetadata,\n    tensor::{BasicOps, TensorKind},\n};\nuse alloc::vec::Vec;\nuse burn_std::Slice;\n\npub(crate) fn cat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>(\n    tensors: Vec<K::Primitive>,\n    dim: usize,\n) -> K::Primitive {\n    let first_tensor = tensors.first().expect(\"Tensors should not be empty\");\n    let mut shape = first_tensor.shape();\n    let device = K::device(first_tensor);\n    let dtype = first_tensor.dtype();\n\n    let output_dim_length: usize = tensors.iter().map(|tensor| tensor.shape()[dim]).sum();\n    shape[dim] = output_dim_length;\n\n    let mut tensor_output = K::empty(shape.clone(), &device, dtype);\n\n    let indices_select_all = shape.iter().map(|d| 0..*d).collect::<Vec<_>>();\n\n    let mut output_index = 0;\n    for tensor in tensors {\n        let mut indices = indices_select_all.clone();\n        let tensor_dim_length = tensor.shape()[dim];\n        indices[dim] = output_index..output_index + tensor_dim_length;\n        output_index += tensor_dim_length;\n\n        // Convert ranges to Slice\n        let slices: Vec<Slice> = indices\n            .iter()\n            .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1))\n            .collect();\n        tensor_output = K::slice_assign(tensor_output, &slices, tensor);\n    }\n\n    tensor_output\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/int_tensor.rs",
    "content": "use super::cat::cat_with_slice_assign;\nuse super::repeat_dim::repeat_with_slice_assign;\nuse super::sort::{argsort, sort, sort_with_indices};\nuse crate::tensor::{BoolTensor, Device, FloatTensor, Int, IntElem, IntTensor};\nuse crate::{Backend, Distribution, TensorData, TensorMetadata, element::ElementConversion};\nuse crate::{ExecutionError, Scalar};\nuse alloc::vec::Vec;\nuse burn_std::{IntDType, Shape, Slice};\nuse core::ops::Range;\n\n/// Int Tensor API for basic and numeric operations, see\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// for documentation on each function.\npub trait IntTensorOps<B: Backend> {\n    /// Creates a new int tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The integer tensor with the given shape.\n    fn int_empty(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B>;\n\n    /// Converts the tensor to a data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The data structure with the tensor's data.\n    fn int_into_data(\n        tensor: IntTensor<B>,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;\n\n    /// Creates a tensor from the data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The data structure.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the data.\n    fn int_from_data(data: TensorData, device: &Device<B>) -> IntTensor<B>;\n\n    /// Gets the device of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The device of the tensor.\n    fn int_device(tensor: &IntTensor<B>) -> Device<B>;\n\n    /// Moves the tensor to the given device.\n    fn int_to_device(tensor: IntTensor<B>, device: &Device<B>) -> IntTensor<B>;\n\n    /// Reshapes the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `shape` - The new shape.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the new shape.\n    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;\n\n    /// Gets the element at the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `slices` - The slices specifying ranges and steps for each dimension.\n    ///\n    /// # Returns\n    ///\n    /// The elements at the given indices.\n    ///\n    /// # Note\n    ///\n    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not\n    /// be passed to this method. Backend implementations do not need to handle empty slices.\n    fn int_slice(tensor: IntTensor<B>, slices: &[Slice]) -> IntTensor<B>;\n\n    /// Sets the values in the tensor for the given ranges.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `ranges` - The ranges to set the values for.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values set for the given ranges.\n    ///\n    /// # Note\n    ///\n    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty slice assignments.\n    fn int_slice_assign(\n        tensor: IntTensor<B>,\n        slices: &[Slice],\n        value: IntTensor<B>,\n    ) -> IntTensor<B>;\n\n    /// Converts int tensor to float tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The int tensor with the same data as the float tensor.\n    fn int_into_float(tensor: IntTensor<B>) -> FloatTensor<B>;\n\n    /// Fills the tensor with values from the value tensor if the mask is true at the given\n    /// indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `mask` - The mask.\n    /// * `value` - The value tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values filled.\n    fn int_mask_where(\n        tensor: IntTensor<B>,\n        mask: BoolTensor<B>,\n        value: IntTensor<B>,\n    ) -> IntTensor<B>;\n\n    /// Fills the tensor with the given value if the mask is true at the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `mask` - The mask.\n    /// * `value` - The value.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values filled.\n    fn int_mask_fill(tensor: IntTensor<B>, mask: BoolTensor<B>, value: Scalar) -> IntTensor<B>;\n\n    /// Gather elements from the tensor at the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to gather from.\n    /// * `tensor` - The tensor.\n    /// * `indices` - The indices.\n    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B>;\n\n    /// Scatter a given value to the tensor at the given indices using sum reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to scatter to.\n    /// * `tensor` - The tensor.\n    /// * `indices` - The indices.\n    /// * `value` - The value.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the values scattered.\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<B>,\n        indices: IntTensor<B>,\n        value: IntTensor<B>,\n    ) -> IntTensor<B>;\n\n    /// Select tensor elements along the given dimension corresponding to the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements.\n    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B>;\n\n    /// Assign the selected elements along the given dimension corresponding to the given indices\n    /// to the given value using sum reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices.\n    /// * `value` - The value.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements assigned to the given value.\n    fn int_select_add(\n        tensor: IntTensor<B>,\n        dim: usize,\n        indices: IntTensor<B>,\n        value: IntTensor<B>,\n    ) -> IntTensor<B>;\n\n    /// Repeats the tensor along the given dimension the given number of times.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to repeat.\n    /// * `times` - The number of times to repeat.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given dimension repeated the given number of times.\n    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {\n        repeat_with_slice_assign::<B, Int>(tensor, dim, times)\n    }\n\n    /// Concatenates the given tensors along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensors` - The tensors.\n    /// * `dim` - The dimension to concatenate along.\n    ///\n    /// # Returns\n    ///\n    /// The concatenated tensor.\n    ///\n    /// # Note\n    ///\n    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty tensors.\n    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {\n        cat_with_slice_assign::<B, Int>(tensors, dim)\n    }\n\n    /// Element-wise equality comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_not_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {\n        let equal_tensor = B::int_equal(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Element-wise equality comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_not_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        let equal_tensor = B::int_equal_elem(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Element-wise greater than comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise greater than comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_greater_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise greater than or equal comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise greater than or equal comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise less than comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise less than comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_lower_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise less than or equal comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise less than or equal comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The boolean tensor with the result of the comparison.\n    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    // ====  NUMERIC ==== //\n\n    /// Element-wise addition.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of the addition.\n    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise addition with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of the addition.\n    fn int_add_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Element-wise power with a IntTensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side IntTensor.\n    /// * `rhs` - The right-hand side IntTensor.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the power of the elements of `rhs`.\n    fn int_powi(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {\n        B::float_into_int(B::float_powi(B::int_into_float(lhs), rhs))\n    }\n\n    /// Element-wise power with a scalar.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// A number of common exponent cases can be implemented with operations\n    /// which are much cheaper than generic exponentiation.\n    ///\n    /// This (`Backend` impl overridable) operation handles generic optimizations\n    /// for several common integer exponent cases; and then dispatches to\n    /// the (`Backend` impl overridable) [`Self::int_powi_scalar_impl`]\n    /// operation to handle the generic case.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`.\n    fn int_powi_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        let exp = rhs.elem::<i32>();\n        match exp {\n            0 => Self::int_ones(lhs.shape(), &B::int_device(&lhs), lhs.dtype().into()),\n            1 => lhs,\n            2 => Self::int_mul(lhs.clone(), lhs),\n            _ => Self::int_powi_scalar_impl(lhs, rhs),\n        }\n    }\n\n    /// Element-wise power with a scalar.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// This is the generic implementation of integer exponentiation\n    /// called by [`Self::int_powi_scalar`] in the fallback case.\n    ///\n    /// By default, this performs a relatively expensive conversion to float,\n    /// exponentiation in float, and conversion back to int.\n    /// This reduces the minimal operation set for `Backend`s,\n    /// at the cost of performance.\n    ///\n    /// This is a good target for specialized optimizations in `Backend` implementations.\n    ///\n    /// As a general rule, this should not be called directly.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`.\n    fn int_powi_scalar_impl(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B> {\n        B::float_into_int(B::float_powi_scalar_impl(B::int_into_float(lhs), rhs))\n    }\n\n    /// Clamps a tensor under a minimum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn int_clamp_min(tensor: IntTensor<B>, min: Scalar) -> IntTensor<B> {\n        let mask = Self::int_lower_elem(tensor.clone(), min);\n        Self::int_mask_fill(tensor, mask, min)\n    }\n\n    /// Clamps a tensor over a maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn int_clamp_max(tensor: IntTensor<B>, max: Scalar) -> IntTensor<B> {\n        let mask = Self::int_greater_elem(tensor.clone(), max);\n        Self::int_mask_fill(tensor, mask, max)\n    }\n\n    /// Clamps a tensor between a minimum and maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn int_clamp(tensor: IntTensor<B>, min: Scalar, max: Scalar) -> IntTensor<B> {\n        Self::int_clamp_min(Self::int_clamp_max(tensor, max), min)\n    }\n\n    /// Element-wise subtraction.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of the subtraction.\n    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise subtraction with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of the subtraction.\n    fn int_sub_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Element-wise multiplication.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of the multiplication.\n    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise multiplication with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of the multiplication.\n    fn int_mul_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Element-wise division.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of the division.\n    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise division with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of the division.\n    fn int_div_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Element-wise modulus.\n    ///\n    /// # Arguments\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of applying the modulus of the scalar to the tensor.\n    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise modulus with a scalar.\n    ///\n    /// # Arguments\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of applying the modulus of the scalar to the tensor.\n    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Multiplies two tensors together using matrix multiplication.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of multiplying the two tensors together using matrix multiplication.\n    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Element-wise negation.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to negate.\n    ///\n    /// # Returns\n    ///\n    /// The negated tensor.\n    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {\n        Self::int_mul_scalar(tensor, (-1).into())\n    }\n\n    /// Creates a tensor of zeros.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor of zeros.\n    fn int_zeros(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {\n        Self::int_from_data(TensorData::full_dtype(shape, 0, dtype.into()), device)\n    }\n\n    /// Creates a tensor of ones.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor of ones.\n    fn int_ones(shape: Shape, device: &Device<B>, dtype: IntDType) -> IntTensor<B> {\n        Self::int_from_data(TensorData::full_dtype(shape, 1, dtype.into()), device)\n    }\n\n    /// Creates a tensor filled with given value.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `fill_value` - The value with which to fill the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor filled with given value\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<B>,\n        dtype: IntDType,\n    ) -> IntTensor<B> {\n        Self::int_from_data(\n            TensorData::full_dtype(shape, fill_value, dtype.into()),\n            device,\n        )\n    }\n\n    /// Sums all elements in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    ///\n    /// # Returns\n    ///\n    /// The sum of all elements in the tensor.\n    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B>;\n\n    /// Sums all elements in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    /// * `dim` - The dimension to sum along.\n    ///\n    /// # Returns\n    ///\n    /// The sum of all elements in the tensor along the dimension.\n    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the product of all elements in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the product of.\n    ///\n    /// # Returns\n    ///\n    /// The product of all elements in the tensor.\n    fn int_prod(tensor: IntTensor<B>) -> IntTensor<B>;\n\n    /// Computes the product of all elements in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the product of.\n    /// * `dim` - The dimension to compute the product along.\n    ///\n    /// # Returns\n    ///\n    /// The product of all elements in the tensor along the dimension.\n    fn int_prod_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the mean of all elements in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the mean of.\n    ///\n    /// # Returns\n    ///\n    /// The mean of all elements in the tensor.\n    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {\n        let num_elems = tensor.shape().num_elements() as i64;\n        B::int_div_scalar(B::int_sum(tensor), num_elems.into())\n    }\n\n    /// Computes the mean of all elements in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the mean of.\n    ///\n    /// # Returns\n    ///\n    /// The mean of all elements in the tensor along the dimension.\n    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the cumulative sum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative sum of.\n    /// * `dim` - The dimension along which to compute the cumulative sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative sum\n    /// of all elements up to and including that position along the dimension.\n    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the cumulative product of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative product of.\n    /// * `dim` - The dimension along which to compute the cumulative product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative product\n    /// of all elements up to and including that position along the dimension.\n    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the cumulative minimum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative minimum of.\n    /// * `dim` - The dimension along which to compute the cumulative minimum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the minimum\n    /// of all elements up to and including that position along the dimension.\n    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Computes the cumulative maximum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative maximum of.\n    /// * `dim` - The dimension along which to compute the cumulative maximum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the maximum\n    /// of all elements up to and including that position along the dimension.\n    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Gets the indices of the maximum elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum indices of.\n    /// * `dim` - The dimension to get the maximum indices along.\n    ///\n    /// # Returns\n    ///\n    /// The indices of the maximum elements along the dimension.\n    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Gets the indices of the minimum elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum indices of.\n    /// * `dim` - The dimension to get the minimum indices along.\n    ///\n    /// # Returns\n    ///\n    /// The indices of the minimum elements along the dimension.\n    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Gets the maximum element in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum element of.\n    ///\n    /// # Returns\n    ///\n    /// The maximum element in the tensor.\n    fn int_max(tensor: IntTensor<B>) -> IntTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::int_max_dim(tensor, 0)\n    }\n\n    /// Gets the maximum element in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum element of.\n    /// * `dim` - The dimension to get the maximum element along.\n    ///\n    /// # Returns\n    ///\n    /// The maximum element in the tensor along the dimension.\n    fn int_max_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        let index = B::int_argmax(tensor.clone(), dim);\n        B::int_gather(dim, tensor, index)\n    }\n\n    /// Gets the maximum elements and corresponding indices along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements and indices of.\n    /// * `dim` - The dimension to get the maximum elements and indices along.\n    ///\n    /// # Returns\n    ///\n    /// The maximum elements and corresponding indices along the dimension.\n    fn int_max_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {\n        let index = B::int_argmax(tensor.clone(), dim);\n        let values = B::int_gather(dim, tensor, index.clone());\n\n        (values, index)\n    }\n\n    /// Gets the maximum absolute element in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum element of.\n    ///\n    /// # Returns\n    ///\n    /// The maximum element in the tensor.\n    fn int_max_abs(tensor: IntTensor<B>) -> IntTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::int_max_abs_dim(tensor, 0)\n    }\n\n    /// Gets the maximum absolute element in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum element of.\n    /// * `dim` - The dimension to get the maximum element along.\n    ///\n    /// # Returns\n    ///\n    /// The maximum element in the tensor along the dimension.\n    fn int_max_abs_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        B::int_max_dim(B::int_abs(tensor), dim)\n    }\n\n    /// Gets the minimum element in the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum element of.\n    ///\n    /// # Returns\n    ///\n    /// The minimum element in the tensor.\n    fn int_min(tensor: IntTensor<B>) -> IntTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::int_min_dim(tensor, 0)\n    }\n\n    /// Gets the minimum elements in the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum element of.\n    /// * `dim` - The dimension to get the minimum element along.\n    ///\n    /// # Returns\n    ///\n    /// The minimum element in the tensor along the dimension.\n    fn int_min_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {\n        let index = B::int_argmin(tensor.clone(), dim);\n        B::int_gather(dim, tensor, index)\n    }\n\n    /// Gets the minimum elements and corresponding indices along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements and indices of.\n    /// * `dim` - The dimension to get the minimum elements and indices along.\n    ///\n    /// # Returns\n    ///\n    /// The minimum elements and corresponding indices along the dimension.\n    fn int_min_dim_with_indices(tensor: IntTensor<B>, dim: usize) -> (IntTensor<B>, IntTensor<B>) {\n        let indices = B::int_argmin(tensor.clone(), dim);\n        let values = B::int_gather(dim, tensor, indices.clone());\n\n        (values, indices)\n    }\n\n    /// Returns a new tensor with absolute values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take absolute value of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with absolute values.\n    fn int_abs(tensor: IntTensor<B>) -> IntTensor<B>;\n\n    /// Transposes an int tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    fn int_transpose(tensor: IntTensor<B>) -> IntTensor<B> {\n        let ndims = tensor.shape().num_dims();\n        Self::int_swap_dims(tensor, ndims - 2, ndims - 1)\n    }\n\n    /// Swaps two dimensions of an int tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap.\n    /// * `dim2` - The second dimension to swap.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    fn int_swap_dims(tensor: IntTensor<B>, dim1: usize, dim2: usize) -> IntTensor<B>;\n\n    /// Permutes the dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to permute the dimensions of.\n    /// * `axes` - The new order of the dimensions.\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    fn int_permute(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;\n\n    /// Reverse the order of elements in a tensor along the given axes.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reverse.\n    /// * `axes` - The axes to reverse.\n    ///\n    /// The tensor with the elements reversed.\n    fn int_flip(tensor: IntTensor<B>, axes: &[usize]) -> IntTensor<B>;\n\n    /// Creates a new int tensor with random values.\n    ///\n    ///  # Arguments\n    ///  * `shape` - The shape of the tensor.\n    ///  * `distribution` - The distribution to sample from.\n    ///  * `device` - The device to create the tensor on.\n    ///\n    ///  # Returns\n    ///\n    ///  The tensor with the given shape and random values.\n    fn int_random(shape: Shape, distribution: Distribution, device: &Device<B>) -> IntTensor<B>;\n\n    /// Creates a new tensor with values from the given range with the given step size.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range of values.\n    /// * `step` - The step size.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given values.\n    fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B> {\n        let value = range\n            .step_by(step)\n            .map(|i| i.elem())\n            .collect::<Vec<IntElem<B>>>();\n        let shape = Shape::new([value.len()]);\n        let data = TensorData::new(value, shape);\n        B::int_from_data(data, device)\n    }\n\n    /// Creates a new tensor with values from the given range.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range of values.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given values.\n    ///\n    /// # Remarks\n    ///\n    /// Uses `arange_step` with a step size of 1 under the hood.\n    fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B> {\n        Self::int_arange_step(range, 1, device)\n    }\n\n    /// Tests if any element in the int `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.\n    fn int_any(tensor: IntTensor<B>) -> BoolTensor<B> {\n        let bool_tensor = B::int_equal_elem(tensor, 0.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::int_sum(B::bool_into_int(bool_tensor));\n        B::int_greater_elem(sum, 0.into())\n    }\n\n    /// Tests if any element in the int `tensor` evaluates to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn int_any_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {\n        let bool_tensor = B::int_equal_elem(tensor, 0.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);\n        B::int_greater_elem(sum, 0.into())\n    }\n\n    /// Tests if all elements in the int `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor\n    /// evaluate to True, False otherwise.\n    fn int_all(tensor: IntTensor<B>) -> BoolTensor<B> {\n        let num_elems = tensor.shape().num_elements() as i64;\n        let bool_tensor = B::int_equal_elem(tensor, 0.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::int_sum(B::bool_into_int(bool_tensor));\n        B::int_equal_elem(sum, num_elems.into())\n    }\n\n    /// Tests if all elements in the int `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn int_all_dim(tensor: IntTensor<B>, dim: usize) -> BoolTensor<B> {\n        let num_elems = tensor.shape()[dim] as i64;\n        let bool_tensor = B::int_equal_elem(tensor, 0.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::int_sum_dim(B::bool_into_int(bool_tensor), dim);\n        B::int_equal_elem(sum, num_elems.into())\n    }\n\n    /// Returns the signs of the int `tensor`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to extract the signs from.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.\n    fn int_sign(tensor: IntTensor<B>) -> IntTensor<B> {\n        let dtype = tensor.dtype();\n        let zeros = B::int_zeros(tensor.shape(), &B::int_device(&tensor), dtype.into());\n        let less_than_zero = B::int_lower_elem(tensor.clone(), 0.into());\n        let greater_than_zero = B::int_greater_elem(tensor, 0.into());\n\n        let mut result = B::int_mask_fill(zeros, less_than_zero, (-1).into());\n        result = B::int_mask_fill(result, greater_than_zero, 1.into());\n        result\n    }\n\n    /// Broadcasts the int `tensor` to the given `shape`.\n    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B>;\n\n    /// Sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.\n    fn int_sort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {\n        sort::<B, Int>(tensor, dim, descending)\n    }\n\n    /// Sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor and corresponding indices, where\n    /// the elements are sorted by value and the indices map back to the original input tensor.\n    fn int_sort_with_indices(\n        tensor: IntTensor<B>,\n        dim: usize,\n        descending: bool,\n    ) -> (IntTensor<B>, IntTensor<B>) {\n        sort_with_indices::<B, Int>(tensor, dim, descending)\n    }\n\n    /// Returns the indices that sort the elements of the input `tensor` by value\n    /// along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.\n    fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {\n        argsort::<B, Int>(tensor, dim, descending)\n    }\n\n    /// Bitwise AND operation for Int Tensors\n    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise AND operation for Int Tensors with a scalar\n    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Bitwise OR operation for Int Tensors\n    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise OR operation for Int Tensors with a scalar\n    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Bitwise XOR operation for Int Tensors\n    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise XOR operation for Int Tensors with a scalar\n    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Bitwise NOT operation for Int Tensors\n    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise left shift operation for Int Tensors\n    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise left shift operation for Int Tensors with a scalar\n    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Bitwise right shift operation for Int Tensors\n    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;\n\n    /// Bitwise right shift operation for Int Tensors with a scalar\n    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: Scalar) -> IntTensor<B>;\n\n    /// Converts a tensor to another integer data type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to convert.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same values as `tensor` but in the target integer data type.\n    fn int_cast(tensor: IntTensor<B>, dtype: IntDType) -> IntTensor<B>;\n\n    /// Unfold windows along a dimension.\n    ///\n    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n    /// * `dim` - the selected dim.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.\n    fn int_unfold(tensor: IntTensor<B>, dim: usize, size: usize, step: usize) -> IntTensor<B>;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/mod.rs",
    "content": "mod activation;\nmod bool_tensor;\nmod int_tensor;\nmod modules;\nmod qtensor;\nmod tensor;\nmod transaction;\n\npub(crate) mod argwhere;\npub(crate) mod cat;\npub(crate) mod repeat_dim;\npub(crate) mod sort;\n\npub use activation::*;\npub use bool_tensor::*;\npub use int_tensor::*;\npub use modules::*;\npub use qtensor::*;\npub use tensor::*;\npub use transaction::*;\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/attention.rs",
    "content": "use core::f32;\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\nuse burn_std::Shape;\n\nuse crate::{\n    Backend, TensorMetadata,\n    ops::AttentionModuleOptions,\n    tensor::{BoolTensor, FloatTensor},\n};\n\n/// Computes softmax(QKᵗ * scale) · V using separate kernels.\n/// Serves as a fallback when FlashAttention is not used.\npub fn attention_fallback<B: Backend>(\n    query: FloatTensor<B>,\n    key: FloatTensor<B>,\n    value: FloatTensor<B>,\n    mask: Option<BoolTensor<B>>,\n    attn_bias: Option<FloatTensor<B>>,\n    options: AttentionModuleOptions,\n) -> FloatTensor<B> {\n    if let Some(softcap) = options.softcap {\n        assert!(softcap > 0.0, \"softcap must be positive, got {softcap}\");\n    }\n\n    // Attention scores: A = QKᵗ * scale\n    let query_shape = query.shape().dims::<4>();\n    let scale = options\n        .scale\n        .unwrap_or_else(|| 1.0 / (*query_shape.last().unwrap() as f64).sqrt());\n    let transposed_key = B::float_transpose(key);\n    let qk = B::float_matmul(query, transposed_key);\n    let attention_scores = B::float_mul_scalar(qk, scale.into());\n\n    // Softcap: softcap * tanh(scores / softcap)\n    // Applied to raw logits before any -inf masking, so that tanh does not\n    // map -inf to a finite value (which would break masking semantics).\n    let attention_scores = if let Some(softcap) = options.softcap {\n        let scaled = B::float_div_scalar(attention_scores, softcap.into());\n        let tanh = B::float_tanh(scaled);\n        B::float_mul_scalar(tanh, softcap.into())\n    } else {\n        attention_scores\n    };\n\n    // Bool masking\n    let attention_scores = if let Some(mask) = mask {\n        B::float_mask_fill(attention_scores, mask, f32::NEG_INFINITY.into())\n    } else {\n        attention_scores\n    };\n\n    // Causal masking: mask positions where col > row (future positions)\n    let attention_scores = if options.is_causal {\n        let causal_mask = build_causal_mask::<B>(&attention_scores);\n        B::float_mask_fill(attention_scores, causal_mask, f32::NEG_INFINITY.into())\n    } else {\n        attention_scores\n    };\n\n    // Additive bias (ALiBi, relative position biases, etc.)\n    let attention_scores = if let Some(bias) = attn_bias {\n        B::float_add(attention_scores, bias)\n    } else {\n        attention_scores\n    };\n\n    // Softmax: S = softmax(A)\n    let max_per_dim = B::float_max_dim(attention_scores.clone(), 3);\n    let minus_max = B::float_sub(attention_scores, max_per_dim);\n    let numerator = B::float_exp(minus_max);\n    let sum_exp = B::float_sum_dim(numerator.clone(), 3);\n    let softmax = B::float_div(numerator, sum_exp);\n\n    // Context: S · V\n    B::float_matmul(softmax, value)\n}\n\n/// Builds a causal (upper-triangular) bool mask where `true` means \"mask this position\".\n/// Shape: [batch_size, num_heads, seq_q, seq_k], masking positions where col > row.\nfn build_causal_mask<B: Backend>(attention_scores: &FloatTensor<B>) -> BoolTensor<B> {\n    let device = B::float_device(attention_scores);\n    let scores_shape = attention_scores.shape().dims::<4>();\n    let [batch_size, num_heads, seq_q, seq_k] = scores_shape;\n\n    // row indices [seq_q, 1] and col indices [1, seq_k]\n    // Offset col indices so that the causal boundary aligns at the bottom-right corner,\n    // which handles cross-attention (seq_k > seq_q) correctly.\n    let offset = seq_k as i64 - seq_q as i64;\n    let rows = B::int_reshape(\n        B::int_arange(0..seq_q as i64, &device),\n        Shape::new([seq_q, 1]),\n    );\n    let cols = B::int_reshape(\n        B::int_arange(0..seq_k as i64, &device),\n        Shape::new([1, seq_k]),\n    );\n\n    // mask where col > row + offset (upper triangle)\n    let rows_shifted = B::int_add_scalar(rows, offset.into());\n    let mask_2d = B::int_lower(rows_shifted, cols);\n\n    // Reshape to [1, 1, seq_q, seq_k] then expand to [batch_size, num_heads, seq_q, seq_k]\n    let mask_4d = B::bool_reshape(mask_2d, Shape::new([1, 1, seq_q, seq_k]));\n    B::bool_expand(mask_4d, Shape::new([batch_size, num_heads, seq_q, seq_k]))\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/base.rs",
    "content": "use super::{conv, pool};\nuse crate::ops::unfold::unfold4d_using_conv2d;\nuse crate::tensor::{BoolTensor, FloatTensor, IntTensor};\nuse crate::{Backend, ElementConversion, TensorMetadata};\nuse burn_std::Shape;\nuse core::num::NonZeroUsize;\n\n/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).\n#[derive(new)]\npub struct Conv2dBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n\n    /// Weights gradient.\n    pub weights_grad: FloatTensor<B>,\n\n    /// Bias gradient.\n    pub bias_grad: Option<FloatTensor<B>>,\n}\n\n/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).\n#[derive(new)]\npub struct DeformConv2dBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n\n    /// Offset gradient.\n    pub offset_grad: FloatTensor<B>,\n\n    /// Weights gradient.\n    pub weight_grad: FloatTensor<B>,\n\n    /// Mask gradient.\n    pub mask_grad: Option<FloatTensor<B>>,\n\n    /// Bias gradient.\n    pub bias_grad: Option<FloatTensor<B>>,\n}\n\n/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).\n#[derive(new)]\npub struct Conv3dBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n\n    /// Weights gradient.\n    pub weights_grad: FloatTensor<B>,\n\n    /// Bias gradient.\n    pub bias_grad: Option<FloatTensor<B>>,\n}\n\n/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).\n#[derive(new)]\npub struct MaxPool1dBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n}\n\n/// Results from [max_pool1d](ModuleOps::max_pool1d_with_indices).\n#[derive(new)]\npub struct MaxPool1dWithIndices<B: Backend> {\n    /// The output tensor.\n    pub output: FloatTensor<B>,\n\n    /// The indices tensor.\n    pub indices: IntTensor<B>,\n}\n\n/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).\n#[derive(new)]\npub struct MaxPool2dBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n}\n\n/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices).\n#[derive(new)]\npub struct MaxPool2dWithIndices<B: Backend> {\n    /// The output tensor.\n    pub output: FloatTensor<B>,\n\n    /// The indices tensor.\n    pub indices: IntTensor<B>,\n}\n\n/// Check that the parameter value is non-zero.\n// NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`.\npub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {\n    NonZeroUsize::new(value).expect(msg);\n    value\n}\n\n/// Convolution options.\n#[derive(Debug, Clone, Hash, PartialEq, Eq)]\npub struct ConvOptions<const N: usize> {\n    /// Stride (non-zero).\n    pub stride: [usize; N],\n\n    /// Padding.\n    pub padding: [usize; N],\n\n    /// Dilation (non-zero).\n    pub dilation: [usize; N],\n\n    /// Groups (non-zero).\n    pub groups: usize,\n}\n\nimpl<const N: usize> ConvOptions<N> {\n    /// Constructs a new `ConvOptions`.\n    pub fn new(\n        stride: [usize; N],\n        padding: [usize; N],\n        dilation: [usize; N],\n        groups: usize,\n    ) -> Self {\n        Self {\n            stride: stride.map(|s| check_nonzero(s, \"stride must be non-zero\")),\n            padding,\n            dilation: dilation.map(|d| check_nonzero(d, \"dilation must be non-zero\")),\n            groups: check_nonzero(groups, \"groups must be non-zero\"),\n        }\n    }\n}\n\n/// Convolution options with support for asymmetric padding.\n///\n/// Wraps [`ConvOptions`] (which represents symmetric padding for the backend op)\n/// and adds optional asymmetric padding. When asymmetric padding is specified,\n/// the functional convolution layer applies an explicit pad operation before\n/// dispatching to the backend.\n///\n/// Implements `From<ConvOptions<N>>` for backward compatibility.\n#[derive(Debug, Clone)]\npub struct PaddedConvOptions<const N: usize> {\n    /// The underlying convolution options for the backend.\n    pub options: ConvOptions<N>,\n    /// Padding at the end of each dimension (e.g., bottom/right for 2D).\n    /// If `None`, padding is symmetric (same as `options.padding`).\n    /// If `Some`, specifies different end-padding per dimension.\n    pub padding_end: Option<[usize; N]>,\n}\n\nimpl<const N: usize> PaddedConvOptions<N> {\n    /// Creates options with asymmetric padding.\n    ///\n    /// `padding_start` is stored in `ConvOptions::padding`.\n    /// `padding_end` specifies the end padding per dimension.\n    pub fn asymmetric(\n        stride: [usize; N],\n        padding_start: [usize; N],\n        padding_end: [usize; N],\n        dilation: [usize; N],\n        groups: usize,\n    ) -> Self {\n        let options = ConvOptions::new(stride, padding_start, dilation, groups);\n        if padding_start == padding_end {\n            Self {\n                options,\n                padding_end: None,\n            }\n        } else {\n            Self {\n                options,\n                padding_end: Some(padding_end),\n            }\n        }\n    }\n\n    /// Returns true if padding is asymmetric.\n    pub fn is_asymmetric(&self) -> bool {\n        self.padding_end.is_some()\n    }\n}\n\nimpl<const N: usize> From<ConvOptions<N>> for PaddedConvOptions<N> {\n    fn from(options: ConvOptions<N>) -> Self {\n        Self {\n            options,\n            padding_end: None,\n        }\n    }\n}\n\n/// Convolution options.\n#[derive(Debug, Clone, Hash, PartialEq, Eq)]\npub struct DeformConvOptions<const N: usize> {\n    /// Stride (non-zero).\n    pub stride: [usize; N],\n\n    /// Padding.\n    pub padding: [usize; N],\n\n    /// Dilation (non-zero).\n    pub dilation: [usize; N],\n\n    /// Weight Groups (non-zero).\n    pub weight_groups: usize,\n\n    /// Offset Groups (non-zero).\n    pub offset_groups: usize,\n}\n\nimpl<const N: usize> DeformConvOptions<N> {\n    /// Constructs a new `DeformConvOptions`.\n    pub fn new(\n        stride: [usize; N],\n        padding: [usize; N],\n        dilation: [usize; N],\n        weight_groups: usize,\n        offset_groups: usize,\n    ) -> Self {\n        Self {\n            stride: stride.map(|s| check_nonzero(s, \"stride must be non-zero\")),\n            padding,\n            dilation: dilation.map(|d| check_nonzero(d, \"dilation must be non-zero\")),\n            weight_groups: check_nonzero(weight_groups, \"weight groups must be non-zero\"),\n            offset_groups: check_nonzero(offset_groups, \"offset groups must be non-zero\"),\n        }\n    }\n}\n\n/// Transposed convolution options.\n#[derive(Debug, Clone, Hash, PartialEq, Eq)]\npub struct ConvTransposeOptions<const N: usize> {\n    /// Stride (non-zero).\n    pub stride: [usize; N],\n\n    /// Padding.\n    pub padding: [usize; N],\n\n    /// Padding out.\n    pub padding_out: [usize; N],\n\n    /// Dilation (non-zero).\n    pub dilation: [usize; N],\n\n    /// Groups (non-zero).\n    pub groups: usize,\n}\n\nimpl<const N: usize> ConvTransposeOptions<N> {\n    /// Constructs a new `ConvTransposeOptions`.\n    pub fn new(\n        stride: [usize; N],\n        padding: [usize; N],\n        padding_out: [usize; N],\n        dilation: [usize; N],\n        groups: usize,\n    ) -> Self {\n        Self {\n            stride: stride.map(|s| check_nonzero(s, \"stride must be non-zero\")),\n            padding,\n            padding_out,\n            dilation: dilation.map(|d| check_nonzero(d, \"dilation must be non-zero\")),\n            groups: check_nonzero(groups, \"groups must be non-zero\"),\n        }\n    }\n}\n\n/// Unfold operation options.\n#[derive(Debug, Clone)]\npub struct UnfoldOptions {\n    /// The number of positions to slide over the input tensor in each dimension.\n    /// A stride of `[1, 1]` will slide the kernel one pixel at a time.\n    pub stride: [usize; 2],\n\n    /// The number of zero-padding pixels added to each side of the input tensor in each dimension.\n    pub padding: [usize; 2],\n\n    /// The spacing between the blocks (patches) in the original input tensor.\n    pub dilation: [usize; 2],\n}\n\nimpl UnfoldOptions {\n    /// Constructs a new `UnfoldOptions`.\n    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {\n        Self {\n            stride: stride.map(|s| check_nonzero(s, \"stride must be non-zero\")),\n            padding,\n            dilation: dilation.map(|d| check_nonzero(d, \"dilation must be non-zero\")),\n        }\n    }\n}\n\n/// Algorithm used for upsampling.\n#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]\npub enum InterpolateMode {\n    /// Nearest-neighbor interpolation.\n    /// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>\n    Nearest,\n\n    /// Bilinear interpolation.\n    /// <https://en.wikipedia.org/wiki/Bilinear_interpolation>\n    Bilinear,\n\n    /// Bicubic interpolation.\n    /// <https://en.wikipedia.org/wiki/Bicubic_interpolation>\n    Bicubic,\n\n    /// Lanczos3 interpolation (6-tap sinc-based filter).\n    /// <https://en.wikipedia.org/wiki/Lanczos_resampling>\n    Lanczos3,\n}\n\n/// Interpolation options.\n#[derive(Debug, Clone)]\npub struct InterpolateOptions {\n    /// Algorithm used for upsampling.\n    pub mode: InterpolateMode,\n    /// If `true`, the input and output tensors are aligned by their corner pixels.\n    /// If `false`, half-pixel coordinate mapping is used instead.\n    pub align_corners: bool,\n}\n\nimpl InterpolateOptions {\n    /// Create new interpolate options with the given mode.\n    /// Defaults to `align_corners = true`.\n    pub fn new(mode: InterpolateMode) -> Self {\n        Self {\n            mode,\n            align_corners: true,\n        }\n    }\n\n    /// Set align_corners.\n    pub fn with_align_corners(mut self, align_corners: bool) -> Self {\n        self.align_corners = align_corners;\n        self\n    }\n}\n\n/// Padding mode for grid sampling when coordinates are out of bounds.\n///\n/// Matches PyTorch's `padding_mode` parameter in `grid_sample`.\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]\npub enum GridSamplePaddingMode {\n    /// Fill with zeros for out-of-bounds coordinates.\n    #[default]\n    Zeros,\n    /// Clamp coordinates to the border (use nearest edge value).\n    Border,\n    /// Reflect coordinates at the boundary.\n    Reflection,\n}\n\n/// Options for grid sampling operations.\n#[derive(Debug, Clone)]\npub struct GridSampleOptions {\n    /// Interpolation mode (bilinear, nearest, or bicubic).\n    pub mode: InterpolateMode,\n    /// Padding mode for out-of-bounds coordinates.\n    pub padding_mode: GridSamplePaddingMode,\n    /// If `true`, grid values of -1 and 1 correspond to the corner pixels.\n    /// If `false`, they correspond to the corner points of the corner pixels\n    /// (i.e., -1 maps to -0.5 and 1 maps to size - 0.5 in pixel coordinates).\n    pub align_corners: bool,\n}\n\nimpl Default for GridSampleOptions {\n    fn default() -> Self {\n        Self {\n            mode: InterpolateMode::Bilinear,\n            padding_mode: GridSamplePaddingMode::Zeros,\n            align_corners: false,\n        }\n    }\n}\n\nimpl From<InterpolateMode> for GridSampleOptions {\n    fn from(value: InterpolateMode) -> Self {\n        GridSampleOptions::new(value)\n    }\n}\n\nimpl GridSampleOptions {\n    /// Create new grid sample options with the given interpolation mode.\n    ///\n    /// Uses default values for padding_mode (Zeros) and align_corners (false).\n    pub fn new(mode: InterpolateMode) -> Self {\n        Self {\n            mode,\n            ..Default::default()\n        }\n    }\n\n    /// Set the padding mode.\n    pub fn with_padding_mode(mut self, padding_mode: GridSamplePaddingMode) -> Self {\n        self.padding_mode = padding_mode;\n        self\n    }\n\n    /// Set align_corners.\n    pub fn with_align_corners(mut self, align_corners: bool) -> Self {\n        self.align_corners = align_corners;\n        self\n    }\n}\n\n/// Padding mode for tensor pad operations.\n///\n/// Defines how values are filled when padding a tensor beyond its original boundaries.\n/// Padding can be applied to any dimension of a tensor.\n///\n/// # Modes\n///\n/// - [`Constant`](PadMode::Constant): Fill with a specified value (default: 0.0)\n/// - [`Reflect`](PadMode::Reflect): Mirror values at boundary, excluding edge (requires padding < dim_size)\n/// - [`Edge`](PadMode::Edge): Replicate boundary values\n#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]\npub enum PadMode {\n    /// Fill padded regions with a constant value.\n    ///\n    /// # Example\n    /// For tensor `[1, 2, 3]` with padding 2 on the left and value 0:\n    /// Result: `[0, 0, 1, 2, 3]`\n    Constant(f32),\n\n    /// Reflect values at the boundary, excluding the edge value.\n    ///\n    /// Padding must be less than the dimension size (i.e., `padding < dim_size`).\n    ///\n    /// # Example\n    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:\n    /// Result: `[3, 2, 1, 2, 3, 4]` (reflects from index 1, not 0)\n    Reflect,\n\n    /// Replicate the edge values.\n    ///\n    /// # Example\n    /// For tensor `[1, 2, 3, 4]` with padding 2 on the left:\n    /// Result: `[1, 1, 1, 2, 3, 4]`\n    Edge,\n}\n\nimpl Default for PadMode {\n    fn default() -> Self {\n        PadMode::Constant(0.0)\n    }\n}\n\nimpl<E: ElementConversion> From<E> for PadMode {\n    fn from(value: E) -> Self {\n        PadMode::Constant(value.elem())\n    }\n}\n\n/// Gradient computed during the backward pass for each tensor used by [interpolate](ModuleOps::interpolate).\n#[derive(new)]\npub struct InterpolateBackward<B: Backend> {\n    /// Gradient.\n    pub x_grad: FloatTensor<B>,\n}\n\n/// Options for [attention](ModuleOps::attention).\n#[derive(Debug, Clone, Copy, Default, PartialEq, serde::Deserialize, serde::Serialize)]\npub struct AttentionModuleOptions {\n    /// Custom scale factor applied to QK^T. When `None`, defaults to `1/sqrt(head_dim)`.\n    pub scale: Option<f64>,\n\n    /// Soft capping applied before softmax: `softcap * tanh(scores / softcap)`.\n    /// Used by Gemma-2 and similar models. Must be positive when set.\n    pub softcap: Option<f64>,\n\n    /// When `true`, applies causal (autoregressive) masking so that each query position\n    /// can only attend to key positions at or before it. This is more efficient than\n    /// passing an explicit lower-triangular bool mask because backends can use optimized\n    /// kernel paths (e.g. flash attention with causal mode).\n    pub is_causal: bool,\n}\n\n/// Module operations trait.\npub trait ModuleOps<B: Backend> {\n    /// Embedding operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `weights` - The embedding weights.\n    /// * `indices` - The indices tensor.\n    ///\n    /// # Returns\n    ///\n    /// The output tensor.\n    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {\n        let [batch_size, seq_length] = indices.shape().dims();\n        let [_, d_model] = weights.shape().dims();\n\n        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));\n        let output = B::float_select(weights, 0, indices);\n\n        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))\n    }\n\n    /// Embedding backward operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `weights` - The embedding weights.\n    /// * `output_grad` - The output gradient.\n    /// * `indices` - The indices tensor.\n    ///\n    /// # Returns\n    ///\n    /// The gradient.\n    fn embedding_backward(\n        weights: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        indices: IntTensor<B>,\n    ) -> FloatTensor<B> {\n        let [batch_size, seq_length] = indices.shape().dims();\n        let [n_embeddings, d_model] = weights.shape().dims();\n        let device = B::float_device(&weights);\n        let dtype = output_grad.dtype();\n\n        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));\n        let output_grad =\n            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));\n        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());\n\n        B::float_select_add(grad, 0, indices, output_grad)\n    }\n    /// One dimensional convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, length]`,\n    /// weight: `[channels_out, channels_in, kernel_size]`,\n    /// bias:   `[channels_out]`,\n    fn conv1d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)\n    }\n    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `x`.\n    fn conv1d_x_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `weight`.\n    fn conv1d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv1d](ModuleOps::conv1d) operation, returning the gradient for `bias`.\n    fn conv1d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv1d_bias_backward::<B>(x, bias, output_grad)\n    }\n    /// Two dimensional convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, height, width]`,\n    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,\n    /// bias:   `[channels_out]`,\n    fn conv2d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `x`.\n    fn conv2d_x_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<B> {\n        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `weight`.\n    fn conv2d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<B> {\n        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv2d](ModuleOps::conv2d) operation, returning the gradient for `bias`.\n    fn conv2d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv2d_bias_backward::<B>(x, bias, output_grad)\n    }\n\n    /// Two dimensional deformable convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, height, width]`,\n    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,\n    /// bias:   `[channels_out]`,\n    fn deform_conv2d(\n        x: FloatTensor<B>,\n        offset: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        mask: Option<FloatTensor<B>>,\n        bias: Option<FloatTensor<B>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.\n    fn deform_conv2d_backward(\n        x: FloatTensor<B>,\n        offset: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        mask: Option<FloatTensor<B>>,\n        bias: Option<FloatTensor<B>>,\n        output_grad: FloatTensor<B>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<B>;\n\n    /// Three dimensional convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, depth, height, width]`,\n    /// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2, kernel_size_3]`,\n    /// bias:   `[channels_out]`,\n    fn conv3d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `x`.\n    fn conv3d_x_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<B> {\n        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `weight`.\n    fn conv3d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<B> {\n        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv3d](ModuleOps::conv3d) operation, returning the gradient for `bias`.\n    fn conv3d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv3d_bias_backward::<B>(x, bias, output_grad)\n    }\n    /// One dimensional transposed convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, length]`,\n    /// weight: `[channels_in, channels_out, length]`,\n    /// bias:   `[channels_out]`,\n    fn conv_transpose1d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)\n    }\n    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `x`.\n    fn conv_transpose1d_x_backward(\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `weight`.\n    fn conv_transpose1d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 1d](ModuleOps::conv_transpose1d) operation, returning the gradient for `bias`.\n    fn conv_transpose1d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)\n    }\n\n    /// Two dimensional transposed convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, height, width]`,\n    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,\n    /// bias:   `[channels_out]`,\n    fn conv_transpose2d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `x`.\n    fn conv_transpose2d_x_backward(\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `weight`.\n    fn conv_transpose2d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 2d](ModuleOps::conv_transpose2d) operation, returning the gradient for `bias`.\n    fn conv_transpose2d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)\n    }\n\n    /// Three dimensional transposed convolution.\n    ///\n    /// # Shapes\n    ///\n    /// x:      `[batch_size, channels_in, height, width]`,\n    /// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2, kernel_size_3]`,\n    /// bias:   `[channels_out]`,\n    fn conv_transpose3d(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        bias: Option<FloatTensor<B>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `x`.\n    fn conv_transpose3d_x_backward(\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `weight`.\n    fn conv_transpose3d_weight_backward(\n        x: FloatTensor<B>,\n        weight: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)\n    }\n    /// Backward pass for the [conv transpose 3d](ModuleOps::conv_transpose3d) operation, returning the gradient for `bias`.\n    fn conv_transpose3d_bias_backward(\n        x: FloatTensor<B>,\n        bias: FloatTensor<B>,\n        output_grad: FloatTensor<B>,\n    ) -> FloatTensor<B> {\n        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)\n    }\n\n    /// Four-dimensional unfolding.\n    ///\n    /// # Shapes\n    ///\n    /// * x:      ``[batch_size, channels_in, height, width]``,\n    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,\n    fn unfold4d(\n        x: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        options: UnfoldOptions,\n    ) -> FloatTensor<B> {\n        if options.padding == [0, 0] && options.dilation == [1, 1] {\n            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);\n            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);\n\n            // batch, channels, h_blocks, w_blocks, h_kern, w_kern\n\n            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);\n            let shape = blocks.shape();\n\n            // batch, channels, h_kern, w_kern, h_blocks, w_blocks\n\n            B::float_reshape(\n                blocks,\n                [\n                    shape[0],\n                    shape[1] * shape[2] * shape[3],\n                    shape[4] * shape[5],\n                ]\n                .into(),\n            )\n        } else {\n            unfold4d_using_conv2d::<B>(x, kernel_size, options)\n        }\n    }\n\n    /// One dimensional avg pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, length],\n    fn avg_pool1d(\n        x: FloatTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<B> {\n        pool::avg_pool1d_from_2d::<B>(\n            x,\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n        )\n    }\n    /// Backward pass for the [avg pooling 1d](ModuleOps::avg_pool1d) operation.\n    fn avg_pool1d_backward(\n        x: FloatTensor<B>,\n        grad: FloatTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<B> {\n        pool::avg_pool1d_backward_from_2d::<B>(\n            x,\n            grad,\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n        )\n    }\n    /// Two dimensional avg pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, height, width],\n    fn avg_pool2d(\n        x: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<B>;\n    /// Backward pass for the [avg pooling 2d](ModuleOps::avg_pool2d) operation.\n    fn avg_pool2d_backward(\n        x: FloatTensor<B>,\n        grad: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<B>;\n    /// Two dimensional adaptive avg pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, height, width],\n    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;\n    /// Backward pass for the [adaptive avg pooling 2d](ModuleOps::adaptive_avg_pool2d) operation.\n    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;\n    /// One dimensional adaptive avg pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, length],\n    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {\n        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)\n    }\n    /// Backward pass for the [adaptive avg pooling 1d](ModuleOps::adaptive_avg_pool1d) operation.\n    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {\n        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)\n    }\n    /// One dimensional max pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, length],\n    fn max_pool1d(\n        x: FloatTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> FloatTensor<B> {\n        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)\n    }\n\n    /// One dimensional max pooling with indices.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, height, width],\n    fn max_pool1d_with_indices(\n        x: FloatTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<B> {\n        pool::max_pool1d_with_indices_from_2d::<B>(\n            x,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n        )\n    }\n    /// Backward pass for the [max pooling 1d](ModuleOps::max_pool1d_with_indices) operation.\n    #[allow(clippy::too_many_arguments)]\n    fn max_pool1d_with_indices_backward(\n        x: FloatTensor<B>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        output_grad: FloatTensor<B>,\n        indices: IntTensor<B>,\n    ) -> MaxPool1dBackward<B> {\n        pool::max_pool1d_with_indices_backward_from_2d::<B>(\n            x,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            output_grad,\n            indices,\n        )\n    }\n\n    /// Two dimensional max pooling.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, height, width],\n    fn max_pool2d(\n        x: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<B>;\n\n    /// Two dimensional max pooling with indices.\n    ///\n    /// # Shapes\n    ///\n    /// x: [batch_size, channels, height, width],\n    fn max_pool2d_with_indices(\n        x: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<B>;\n    /// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.\n    #[allow(clippy::too_many_arguments)]\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<B>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<B>,\n        indices: IntTensor<B>,\n    ) -> MaxPool2dBackward<B>;\n\n    /// Down/up samples the input.\n    ///\n    /// # Shapes\n    ///\n    /// x: `[batch_size, channels, height, width]`,\n    fn interpolate(\n        x: FloatTensor<B>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<B>;\n\n    /// Backward pass for the [interpolate](ModuleOps::interpolate) operation.\n    fn interpolate_backward(\n        x: FloatTensor<B>,\n        grad: FloatTensor<B>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<B>;\n\n    /// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,\n    /// where scale defaults to 1/sqrt(head_dim). Optionally applies masking,\n    /// additive bias, causal masking, and softcap to the attention scores.\n    ///\n    /// # Arguments\n    /// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`\n    /// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`\n    /// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`\n    /// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,\n    ///   where `true` indicates positions to mask (i.e. set to -inf before softmax).\n    /// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`\n    ///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).\n    /// - `options`: Additional attention options (custom scale, softcap, causal masking).\n    ///\n    /// # Returns\n    /// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`\n    /// representing the attended context per head.\n    ///\n    /// # Note\n    /// This implementation does not support dropout and is intended for inference or\n    /// use cases where dropout is not needed.\n    fn attention(\n        query: FloatTensor<B>,\n        key: FloatTensor<B>,\n        value: FloatTensor<B>,\n        mask: Option<BoolTensor<B>>,\n        attn_bias: Option<FloatTensor<B>>,\n        options: AttentionModuleOptions,\n    ) -> FloatTensor<B>;\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    #[should_panic = \"stride must be non-zero\"]\n    fn conv_options_stride_zero() {\n        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);\n    }\n\n    #[test]\n    #[should_panic = \"dilation must be non-zero\"]\n    fn conv_options_dilation_zero() {\n        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);\n    }\n\n    #[test]\n    #[should_panic = \"groups must be non-zero\"]\n    fn conv_options_groups_zero() {\n        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);\n    }\n\n    #[test]\n    #[should_panic = \"stride must be non-zero\"]\n    fn conv_transpose_options_stride_zero() {\n        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);\n    }\n\n    #[test]\n    #[should_panic = \"dilation must be non-zero\"]\n    fn conv_transpose_options_dilation_zero() {\n        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);\n    }\n\n    #[test]\n    #[should_panic = \"groups must be non-zero\"]\n    fn conv_transpose_options_groups_zero() {\n        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);\n    }\n\n    #[test]\n    #[should_panic = \"stride must be non-zero\"]\n    fn deform_conv_options_stride_zero() {\n        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);\n    }\n\n    #[test]\n    #[should_panic = \"dilation must be non-zero\"]\n    fn deform_conv_options_dilation_zero() {\n        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);\n    }\n\n    #[test]\n    #[should_panic = \"weight groups must be non-zero\"]\n    fn deform_conv_options_weights_groups_zero() {\n        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);\n    }\n\n    #[test]\n    #[should_panic = \"offset groups must be non-zero\"]\n    fn deform_conv_options_offset_groups_zero() {\n        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);\n    }\n\n    #[test]\n    #[should_panic = \"stride must be non-zero\"]\n    fn unfold_options_stride_zero() {\n        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);\n    }\n\n    #[test]\n    #[should_panic = \"dilation must be non-zero\"]\n    fn unfold_options_dilation_zero() {\n        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/conv.rs",
    "content": "#![allow(clippy::single_range_in_vec_init)]\nuse super::{ConvOptions, ConvTransposeOptions};\nuse crate::{Backend, TensorMetadata, tensor::FloatTensor};\nuse burn_std::{MetadataError, Shape, Slice};\n\nuse alloc::{vec, vec::Vec};\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a pooling operation.\npub fn calculate_pool_output_shape<const N: usize>(\n    in_shape: &Shape,\n    kernel_size: &[usize; N],\n    stride: &[usize; N],\n    padding: &[usize; N],\n    dilation: &[usize; N],\n    ceil_mode: bool,\n) -> Result<Shape, MetadataError> {\n    if in_shape.rank() != N + 2 {\n        return Err(MetadataError::RankMismatch {\n            left: in_shape.rank(),\n            right: N + 2,\n        });\n    }\n\n    let mut out_shape = in_shape.clone();\n    // Spatial dims\n    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {\n        *size_i = calculate_pool_output_size(\n            kernel_size[i],\n            stride[i],\n            padding[i],\n            dilation[i],\n            *size_i,\n            ceil_mode,\n        );\n    }\n\n    Ok(out_shape)\n}\n\n/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a convolution.\npub fn calculate_conv_output_shape<const N: usize>(\n    in_shape: &Shape,\n    weight_shape: &Shape,\n    stride: &[usize; N],\n    padding: &[usize; N],\n    dilation: &[usize; N],\n) -> Result<Shape, MetadataError> {\n    if weight_shape.rank() != N + 2 {\n        return Err(MetadataError::RankMismatch {\n            left: weight_shape.rank(),\n            right: N + 2,\n        });\n    }\n\n    if in_shape.rank() != N + 2 {\n        return Err(MetadataError::RankMismatch {\n            left: in_shape.rank(),\n            right: N + 2,\n        });\n    }\n\n    let kernel_size = &weight_shape[2..];\n\n    let mut out_shape = in_shape.clone();\n    // Spatial dims\n    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {\n        *size_i =\n            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_i);\n    }\n    // Output channels\n    out_shape[1] = weight_shape[0];\n\n    Ok(out_shape)\n}\n\n/// Calculate the expected output shape `[batch_size, channels_out, spatial_dims, ..]` for a transposed convolution.\npub fn calculate_conv_transpose_output_shape<const N: usize>(\n    in_shape: &Shape,\n    weight_shape: &Shape,\n    stride: &[usize; N],\n    padding: &[usize; N],\n    padding_out: &[usize; N],\n    dilation: &[usize; N],\n    groups: usize,\n) -> Result<Shape, MetadataError> {\n    if weight_shape.rank() != N + 2 {\n        return Err(MetadataError::RankMismatch {\n            left: weight_shape.rank(),\n            right: N + 2,\n        });\n    }\n\n    if in_shape.rank() != N + 2 {\n        return Err(MetadataError::RankMismatch {\n            left: in_shape.rank(),\n            right: N + 2,\n        });\n    }\n\n    let kernel_size = &weight_shape[2..];\n\n    let mut out_shape = in_shape.clone();\n    // Spatial dims\n    for (i, size_i) in out_shape[2..].iter_mut().enumerate() {\n        *size_i = calculate_conv_transpose_output_size(\n            kernel_size[i],\n            stride[i],\n            padding[i],\n            padding_out[i],\n            dilation[i],\n            *size_i,\n        );\n    }\n    // Output channels\n    out_shape[1] = weight_shape[1] * groups;\n\n    Ok(out_shape)\n}\n\n/// Calculate the expected padding size required when applying a convolution.\npub fn calculate_conv_padding(\n    kernel_size: usize,\n    stride: usize,\n    size_in: usize,\n    size_out: usize,\n) -> usize {\n    let kernel_size = kernel_size as f32;\n    let stride = stride as f32;\n    let size_in = size_in as f32;\n    let size_out = size_out as f32;\n\n    let padding = stride * (size_out - 1.) - size_in + kernel_size;\n    let padding = (padding / 2.).ceil();\n\n    padding as usize\n}\n\n/// Calculate the expected output size when doing a convolution operation.\npub fn calculate_conv_output_size(\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    size_in: usize,\n) -> usize {\n    (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1\n}\n\n/// Calculate the expected output sizes when doing a convolution operation.\npub fn calculate_conv_output_sizes(\n    kernel_size: &[usize],\n    stride: &[usize],\n    padding: &[usize],\n    dilation: &[usize],\n    size_in: &[usize],\n) -> Vec<usize> {\n    size_in\n        .iter()\n        .enumerate()\n        .map(|(i, size_in)| {\n            calculate_conv_output_size(kernel_size[i], stride[i], padding[i], dilation[i], *size_in)\n        })\n        .collect()\n}\n\n/// Calculate the expected output size when doing a transposed convolution operation.\npub fn calculate_conv_transpose_output_size(\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    padding_out: usize,\n    dilation: usize,\n    size_in: usize,\n) -> usize {\n    (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding\n}\n\n/// Calculate the expected output size when doing a pooling operation.\n///\n/// # Arguments\n///\n/// * `kernel_size` - Size of the pooling kernel\n/// * `stride` - Stride of the pooling operation\n/// * `padding` - Padding applied to input\n/// * `dilation` - Dilation of the pooling kernel\n/// * `size_in` - Input size (height or width)\n/// * `ceil_mode` - If true, use ceiling instead of floor for output size calculation.\n///   This allows the last pooling window to go out-of-bounds if needed.\npub fn calculate_pool_output_size(\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    size_in: usize,\n    ceil_mode: bool,\n) -> usize {\n    let numerator = size_in + 2 * padding - dilation * (kernel_size - 1) - 1;\n    if ceil_mode {\n        // Ceiling division: (a + b - 1) / b\n        numerator.div_ceil(stride) + 1\n    } else {\n        // Floor division (default)\n        numerator / stride + 1\n    }\n}\n\n/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `x`.\npub(crate) fn conv1d_x_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<1>,\n) -> FloatTensor<B> {\n    let weight_shape = weight.shape();\n\n    let [_batch_size, _, length_in] = x.shape().dims();\n    let [_batch_size, _channels_out, length_out] = output_grad.shape().dims();\n    let [_, _, kernel_size] = weight_shape.dims();\n\n    let padding_out = calculate_padding_out(\n        kernel_size,\n        options.stride[0],\n        options.padding[0],\n        options.dilation[0],\n        length_in,\n        length_out,\n    );\n\n    B::conv_transpose1d(\n        output_grad,\n        weight,\n        None,\n        ConvTransposeOptions::new(\n            options.stride,\n            options.padding,\n            [padding_out],\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv1d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<1>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv1d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [1D convolution](crate::ops::ModuleOps::conv1d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv1d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _, _length_in] = x.shape().dims();\n    let [_batch_size, channels_out, length_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `x`.\npub(crate) fn conv2d_x_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<2>,\n) -> FloatTensor<B> {\n    let weight_shape = weight.shape();\n\n    let [_batch_size, _channels_in, height_in, width_in] = x.shape().dims();\n    let [_, _, height_out, width_out] = output_grad.shape().dims();\n    let [_channels_out, _, kernel_size_1, kernel_size_2] = weight_shape.dims();\n\n    let padding_1_out = calculate_padding_out(\n        kernel_size_1,\n        options.stride[0],\n        options.padding[0],\n        options.dilation[0],\n        height_in,\n        height_out,\n    );\n    let padding_2_out = calculate_padding_out(\n        kernel_size_2,\n        options.stride[1],\n        options.padding[1],\n        options.dilation[1],\n        width_in,\n        width_out,\n    );\n\n    B::conv_transpose2d(\n        output_grad,\n        weight,\n        None,\n        ConvTransposeOptions::new(\n            options.stride,\n            options.padding,\n            [padding_1_out, padding_2_out],\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv2d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<2>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv2d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [2D convolution](crate::ops::ModuleOps::conv2d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv2d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _, _, _] = x.shape().dims();\n    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(\n        grad,\n        Shape::new([channels_out, batch_size * height_out * width_out]),\n    );\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `x`.\npub(crate) fn conv3d_x_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<3>,\n) -> FloatTensor<B> {\n    let weight_shape = weight.shape();\n\n    let [_batch_size, _channels_in, depth_in, height_in, width_in] = x.shape().dims();\n    let [_, _, depth_out, height_out, width_out] = output_grad.shape().dims();\n    let [\n        _channels_out,\n        _,\n        kernel_size_1,\n        kernel_size_2,\n        kernel_size_3,\n    ] = weight_shape.dims();\n\n    let padding_1_out = calculate_padding_out(\n        kernel_size_1,\n        options.stride[0],\n        options.padding[0],\n        options.dilation[0],\n        depth_in,\n        depth_out,\n    );\n    let padding_2_out = calculate_padding_out(\n        kernel_size_2,\n        options.stride[1],\n        options.padding[1],\n        options.dilation[1],\n        height_in,\n        height_out,\n    );\n    let padding_3_out = calculate_padding_out(\n        kernel_size_3,\n        options.stride[2],\n        options.padding[2],\n        options.dilation[2],\n        width_in,\n        width_out,\n    );\n\n    B::conv_transpose3d(\n        output_grad,\n        weight,\n        None,\n        ConvTransposeOptions::new(\n            options.stride,\n            options.padding,\n            [padding_1_out, padding_2_out, padding_3_out],\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv3d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<3>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv3d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [3D convolution](crate::ops::ModuleOps::conv3d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv3d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _channels_in, _depth_in, _height_in, _width_in] = x.shape().dims();\n    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(\n        grad,\n        Shape::new([\n            channels_out,\n            batch_size * depth_out * height_out * width_out,\n        ]),\n    );\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `x`.\npub(crate) fn conv_transpose1d_x_backward<B: Backend>(\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<1>,\n) -> FloatTensor<B> {\n    B::conv1d(\n        output_grad,\n        weight,\n        None,\n        ConvOptions::new(\n            options.stride,\n            options.padding,\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv_transpose1d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<1>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv_transpose1d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv_transpose1d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv_transpose1d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _channels_in, _] = x.shape().dims();\n    let [_, channels_out, length_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(grad, Shape::new([channels_out, batch_size * length_out]));\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `x`.\npub(crate) fn conv_transpose2d_x_backward<B: Backend>(\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<2>,\n) -> FloatTensor<B> {\n    B::conv2d(\n        output_grad,\n        weight,\n        None,\n        ConvOptions::new(\n            options.stride,\n            options.padding,\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv_transpose2d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<2>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv_transpose2d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv_transpose2d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv_transpose2d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _channels_in, _, _] = x.shape().dims();\n    let [_, channels_out, height_out, width_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(\n        grad,\n        Shape::new([channels_out, batch_size * height_out * width_out]),\n    );\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `x`.\npub(crate) fn conv_transpose3d_x_backward<B: Backend>(\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<3>,\n) -> FloatTensor<B> {\n    B::conv3d(\n        output_grad,\n        weight,\n        None,\n        ConvOptions::new(\n            options.stride,\n            options.padding,\n            options.dilation,\n            options.groups,\n        ),\n    )\n}\n\n/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.\npub(crate) fn conv_transpose3d_weight_backward<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<3>,\n) -> FloatTensor<B> {\n    let weight_dtype = weight.dtype();\n    let weight_shape = weight.shape();\n    let weight_device = B::float_device(&weight);\n\n    match options.groups == 1 {\n        true => conv_transpose3d_weight_grad_no_groups::<B>(x, output_grad, weight_shape, options),\n        false => conv_transpose3d_weight_grad_groups::<B>(\n            x,\n            B::float_zeros(weight_shape, &weight_device, weight_dtype.into()),\n            output_grad,\n            options,\n        ),\n    }\n}\n\n/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `bias`.\npub(crate) fn conv_transpose3d_bias_backward<B: Backend>(\n    x: FloatTensor<B>,\n    bias: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, _channels_in, _, _, _] = x.shape().dims();\n    let [_, channels_out, depth_out, height_out, width_out] = output_grad.shape().dims();\n\n    let grad = B::float_swap_dims(output_grad, 0, 1);\n    let grad = B::float_reshape(\n        grad,\n        Shape::new([\n            channels_out,\n            batch_size * depth_out * height_out * width_out,\n        ]),\n    );\n    let grad = B::float_sum_dim(grad, 1);\n\n    B::float_reshape(grad, bias.shape())\n}\n\n/// Execute a 1D convolution using a 2D convolution.\npub(crate) fn conv1d_from_conv2d<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    bias: Option<FloatTensor<B>>,\n    options: ConvOptions<1>,\n) -> FloatTensor<B> {\n    let [channels_out, _channels_in, kernel_size] = weight.shape().dims();\n    let [batch_size, channels_in, length_in] = x.shape().dims();\n\n    let weight = B::float_reshape(\n        weight,\n        Shape::new([channels_out, channels_in / options.groups, kernel_size, 1]),\n    );\n    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));\n\n    let tensor = B::conv2d(\n        x,\n        weight,\n        bias,\n        ConvOptions::new(\n            [options.stride[0], 1],\n            [options.padding[0], 0],\n            [options.dilation[0], 1],\n            options.groups,\n        ),\n    );\n    let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims();\n    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))\n}\n\n/// Execute a 1D transposed convolution using a 2D transposed convolution.\npub(crate) fn conv_transpose1d_from_conv_transpose2d<B: Backend>(\n    x: FloatTensor<B>,\n    weight: FloatTensor<B>,\n    bias: Option<FloatTensor<B>>,\n    options: ConvTransposeOptions<1>,\n) -> FloatTensor<B> {\n    let [channels_in, channels_out, kernel_size] = weight.shape().dims();\n    let [batch_size, _channels_in, length_in] = x.shape().dims();\n\n    let weight = B::float_reshape(\n        weight,\n        Shape::new([channels_in, channels_out, kernel_size, 1]),\n    );\n    let x = B::float_reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));\n\n    let tensor = B::conv_transpose2d(\n        x,\n        weight,\n        bias,\n        ConvTransposeOptions::new(\n            [options.stride[0], 1],\n            [options.padding[0], 0],\n            [options.padding_out[0], 0],\n            [options.dilation[0], 1],\n            options.groups,\n        ),\n    );\n    let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims();\n    B::float_reshape(tensor, Shape::from([batch_size, channels_out, height_out]))\n}\n\nfn conv1d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvOptions<1>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv1d(\n        x_swapped,\n        output_grad_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    if weight_grad.shape() != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv2d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvOptions<2>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv2d(\n        x_swapped,\n        output_grad_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    if weight_grad.shape() != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n            Slice::from(0..weight_shape[3]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv3d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvOptions<3>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv3d(\n        x_swapped,\n        output_grad_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    if weight_grad.shape() != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n            Slice::from(0..weight_shape[3]),\n            Slice::from(0..weight_shape[4]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv1d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<1>,\n) -> FloatTensor<B> {\n    let [channels_out, increment_ci, kernel_size] = weight_grad.shape().dims();\n    let increment_co = channels_out / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv1d(\n            x,\n            grad,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_co..end_idx_co),\n                Slice::from(0..increment_ci),\n                Slice::from(0..kernel_size),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn conv2d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<2>,\n) -> FloatTensor<B> {\n    let [channels_out, increment_ci, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();\n    let increment_co = channels_out / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv2d(\n            x,\n            grad,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();\n\n        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {\n            let slices = vec![\n                Slice::from(0..increment_co),\n                Slice::from(0..increment_ci),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n            ];\n            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);\n        }\n\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_co..end_idx_co),\n                Slice::from(0..increment_ci),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn conv3d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvOptions<3>,\n) -> FloatTensor<B> {\n    let [\n        channels_out,\n        increment_ci,\n        kernel_size_1,\n        kernel_size_2,\n        kernel_size_3,\n    ] = weight_grad.shape().dims();\n    let increment_co = channels_out / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv3d(\n            x,\n            grad,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        let [\n            _,\n            _,\n            kernel_size_1_tmp,\n            kernel_size_2_tmp,\n            kernel_size_3_tmp,\n        ] = weight_grad_tmp.shape().dims();\n\n        if kernel_size_1_tmp != kernel_size_1\n            || kernel_size_2_tmp != kernel_size_2\n            || kernel_size_3_tmp != kernel_size_3\n        {\n            let slices = vec![\n                Slice::from(0..increment_co),\n                Slice::from(0..increment_ci),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n                Slice::from(0..kernel_size_3),\n            ];\n            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);\n        }\n\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_co..end_idx_co),\n                Slice::from(0..increment_ci),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n                Slice::from(0..kernel_size_3),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn conv_transpose1d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvTransposeOptions<1>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv1d(\n        output_grad_swapped,\n        x_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    let grad_shape = weight_grad.shape();\n    if grad_shape != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv_transpose2d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvTransposeOptions<2>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv2d(\n        output_grad_swapped,\n        x_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    let grad_shape = weight_grad.shape();\n    if grad_shape != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n            Slice::from(0..weight_shape[3]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv_transpose3d_weight_grad_no_groups<B: Backend>(\n    x: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    weight_shape: Shape,\n    options: ConvTransposeOptions<3>,\n) -> FloatTensor<B> {\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n    let weight_grad_swapped = B::conv3d(\n        output_grad_swapped,\n        x_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n    );\n    let mut weight_grad = B::float_swap_dims(weight_grad_swapped, 0, 1);\n\n    let grad_shape = weight_grad.shape();\n    if grad_shape != weight_shape {\n        let slices = vec![\n            Slice::from(0..weight_shape[0]),\n            Slice::from(0..weight_shape[1]),\n            Slice::from(0..weight_shape[2]),\n            Slice::from(0..weight_shape[3]),\n            Slice::from(0..weight_shape[4]),\n        ];\n        weight_grad = B::float_slice(weight_grad, &slices);\n    }\n    weight_grad\n}\n\nfn conv_transpose1d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<1>,\n) -> FloatTensor<B> {\n    let [channels_in, increment_co, kernel_size] = weight_grad.shape().dims();\n    let increment_ci = channels_in / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv1d(\n            grad,\n            x,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        let [_, _, kernel_size_tmp] = weight_grad_tmp.shape().dims();\n\n        if kernel_size_tmp != kernel_size {\n            let slices = vec![\n                Slice::from(0..increment_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size),\n            ];\n            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);\n        }\n\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_ci..end_idx_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn conv_transpose2d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<2>,\n) -> FloatTensor<B> {\n    let [channels_in, increment_co, kernel_size_1, kernel_size_2] = weight_grad.shape().dims();\n    let increment_ci = channels_in / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv2d(\n            grad,\n            x,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        let [_, _, kernel_size_1_tmp, kernel_size_2_tmp] = weight_grad_tmp.shape().dims();\n\n        if kernel_size_1_tmp != kernel_size_1 || kernel_size_2_tmp != kernel_size_2 {\n            let slices = vec![\n                Slice::from(0..increment_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n            ];\n            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);\n        }\n\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_ci..end_idx_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn conv_transpose3d_weight_grad_groups<B: Backend>(\n    x: FloatTensor<B>,\n    mut weight_grad: FloatTensor<B>,\n    output_grad: FloatTensor<B>,\n    options: ConvTransposeOptions<3>,\n) -> FloatTensor<B> {\n    let [\n        channels_in,\n        increment_co,\n        kernel_size_1,\n        kernel_size_2,\n        kernel_size_3,\n    ] = weight_grad.shape().dims();\n    let increment_ci = channels_in / options.groups;\n\n    let x_swapped = B::float_swap_dims(x, 0, 1);\n    let output_grad_swapped = B::float_swap_dims(output_grad, 0, 1);\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let x_slice = vec![Slice::new(\n            start_idx_ci as isize,\n            Some(end_idx_ci as isize),\n            1,\n        )];\n        let x = B::float_slice(x_swapped.clone(), &x_slice);\n        let grad_slice = vec![Slice::new(\n            start_idx_co as isize,\n            Some(end_idx_co as isize),\n            1,\n        )];\n        let grad = B::float_slice(output_grad_swapped.clone(), &grad_slice);\n        let mut weight_grad_tmp = B::conv3d(\n            grad,\n            x,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        );\n        weight_grad_tmp = B::float_swap_dims(weight_grad_tmp, 0, 1);\n        let [\n            _,\n            _,\n            kernel_size_1_tmp,\n            kernel_size_2_tmp,\n            kernel_size_3_tmp,\n        ] = weight_grad_tmp.shape().dims();\n\n        if kernel_size_1_tmp != kernel_size_1\n            || kernel_size_2_tmp != kernel_size_2\n            || kernel_size_3_tmp != kernel_size_3\n        {\n            let slices = vec![\n                Slice::from(0..increment_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n                Slice::from(0..kernel_size_3),\n            ];\n            weight_grad_tmp = B::float_slice(weight_grad_tmp, &slices);\n        }\n        weight_grad = B::float_slice_assign(\n            weight_grad,\n            &[\n                Slice::from(start_idx_ci..end_idx_ci),\n                Slice::from(0..increment_co),\n                Slice::from(0..kernel_size_1),\n                Slice::from(0..kernel_size_2),\n                Slice::from(0..kernel_size_3),\n            ],\n            weight_grad_tmp,\n        );\n    }\n\n    weight_grad\n}\n\nfn calculate_padding_out(\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    size_in: usize,\n    size_out: usize,\n) -> usize {\n    if stride <= 1 {\n        return 0;\n    }\n\n    let out = 1\n        + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil()\n            as usize;\n    i64::max(0, out as i64 - size_out as i64) as usize\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_calculate_output_size_1() {\n        let kernel_size = 3;\n        let stride = 1;\n        let padding = 1;\n        let size_in = 3;\n        let dilation = 1;\n\n        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_out, 3);\n    }\n\n    #[test]\n    fn test_calculate_output_size_2() {\n        let kernel_size = 5;\n        let stride = 2;\n        let padding = 3;\n        let size_in = 27;\n        let dilation = 1;\n\n        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_out, 15);\n    }\n\n    #[test]\n    fn test_calculate_output_size_3() {\n        let kernel_size = 5;\n        let stride = 2;\n        let padding = 3;\n        let size_in = 27;\n        let dilation = 2;\n\n        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_out, 13);\n    }\n\n    #[test]\n    fn test_calculate_same_padding_1() {\n        let kernel_size = 3;\n        let stride = 1;\n        let size_in = 3;\n        let dilation = 1;\n\n        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);\n        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_in, size_out, \"Expected size\");\n    }\n\n    #[test]\n    fn test_calculate_same_padding_2() {\n        let kernel_size = 3;\n        let stride = 2;\n        let size_in = 7;\n        let dilation = 1;\n\n        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_in);\n        let size_out = calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_in, size_out, \"Expected size\");\n    }\n\n    #[test]\n    fn test_calculate_output_padding_1() {\n        let kernel_size = 3;\n        let stride = 2;\n        let size_in = 7;\n        let size_out = 10;\n        let dilation = 1;\n\n        let padding = calculate_conv_padding(kernel_size, stride, size_in, size_out);\n        let size_out_expected =\n            calculate_conv_output_size(kernel_size, stride, padding, dilation, size_in);\n\n        assert_eq!(size_out, size_out_expected, \"Expected size\");\n    }\n\n    #[test]\n    fn test_expect_conv2d_output_shape() {\n        // in channels: 3\n        // out channels: 8\n        // size in: [27, 3]\n        // kernel size: [5, 3]\n        let stride = [2, 1];\n        let padding = [3, 1];\n        let dilation = [2, 1];\n        let shape = calculate_conv_output_shape(\n            &Shape::new([12, 3, 27, 3]),\n            &Shape::new([8, 3, 5, 3]),\n            &stride,\n            &padding,\n            &dilation,\n        )\n        .unwrap();\n        assert_eq!(shape, Shape::new([12, 8, 13, 3]))\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/grid_sample.rs",
    "content": "use crate::{\n    Backend, TensorMetadata,\n    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},\n    tensor::FloatTensor,\n};\nuse alloc::vec;\nuse burn_std::{Shape, Slice};\n\n/// Reference implementation of grid_sample_2d that supports all options.\n///\n/// # Arguments\n///\n/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)\n/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].\n///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right\n/// * `options` - Grid sampling options\n///\n/// # Returns\n///\n/// A tensor with shape (N, C, H_out, W_out)\npub fn float_grid_sample_2d_ref<B: Backend>(\n    tensor: FloatTensor<B>,\n    grid: FloatTensor<B>,\n    options: GridSampleOptions,\n) -> FloatTensor<B> {\n    match options.mode {\n        InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(\n            tensor,\n            grid,\n            options.padding_mode,\n            options.align_corners,\n        ),\n        _ => todo!(\n            \"Default implementation for grid_sample_2d with {:?} unimplemented\",\n            options.mode\n        ),\n    }\n}\n\n/// Bilinear grid sampling implementation.\nfn float_grid_sample_2d_bilinear<B: Backend>(\n    tensor: FloatTensor<B>,\n    grid: FloatTensor<B>,\n    padding_mode: GridSamplePaddingMode,\n    align_corners: bool,\n) -> FloatTensor<B> {\n    let n = tensor.shape()[0];\n    let c = tensor.shape()[1];\n    let h_in = tensor.shape()[2];\n    let w_in = tensor.shape()[3];\n    let h_out = grid.shape()[1];\n    let w_out = grid.shape()[2];\n    let spatial_in = h_in * w_in;\n    let spatial_out = h_out * w_out;\n\n    // Separate x and y coordinates from grid\n    // shape: (N, H_out, W_out, 1)\n    let grid_x_slice = vec![\n        Slice::new(0, Some(n as isize), 1),\n        Slice::new(0, Some(h_out as isize), 1),\n        Slice::new(0, Some(w_out as isize), 1),\n        Slice::new(0, Some(1), 1),\n    ];\n    let grid_y_slice = vec![\n        Slice::new(0, Some(n as isize), 1),\n        Slice::new(0, Some(h_out as isize), 1),\n        Slice::new(0, Some(w_out as isize), 1),\n        Slice::new(1, Some(2), 1),\n    ];\n\n    let grid_x = B::float_slice(grid.clone(), &grid_x_slice);\n    let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));\n    let grid_y = B::float_slice(grid.clone(), &grid_y_slice);\n    let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));\n\n    // Convert normalized grid coordinates [-1, 1] to pixel coordinates\n    let w_in_f = w_in as f64;\n    let h_in_f = h_in as f64;\n\n    let (grid_x, grid_y) = if align_corners {\n        // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2\n        // Maps -1 to 0 and 1 to width - 1\n        let grid_x = B::float_add_scalar(grid_x, 1f32.into());\n        let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).into());\n\n        let grid_y = B::float_add_scalar(grid_y, 1f32.into());\n        let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).into());\n\n        (grid_x, grid_y)\n    } else {\n        // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5\n        // Maps -1 to -0.5 and 1 to width - 0.5\n        let grid_x = B::float_add_scalar(grid_x, 1f32.into());\n        let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).into());\n        let grid_x = B::float_sub_scalar(grid_x, 0.5f32.into());\n\n        let grid_y = B::float_add_scalar(grid_y, 1f32.into());\n        let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).into());\n        let grid_y = B::float_sub_scalar(grid_y, 0.5f32.into());\n\n        (grid_x, grid_y)\n    };\n\n    // Apply padding mode to coordinates\n    let (grid_x, grid_y) = match padding_mode {\n        GridSamplePaddingMode::Border => {\n            // Clamp coordinates to valid range [0, size-1]\n            let grid_x = B::float_clamp(grid_x, 0f32.into(), ((w_in - 1) as f32).into());\n            let grid_y = B::float_clamp(grid_y, 0f32.into(), ((h_in - 1) as f32).into());\n            (grid_x, grid_y)\n        }\n        GridSamplePaddingMode::Reflection => {\n            // Reflect coordinates at boundaries\n            let grid_x = reflect_coordinates::<B>(grid_x, w_in_f, align_corners);\n            let grid_y = reflect_coordinates::<B>(grid_y, h_in_f, align_corners);\n            (grid_x, grid_y)\n        }\n        GridSamplePaddingMode::Zeros => {\n            // Keep coordinates as-is, we'll mask out-of-bounds later\n            (grid_x, grid_y)\n        }\n    };\n\n    // Get floor indices for the four corners\n    let grid_x_floored = B::float_floor(grid_x.clone());\n    let grid_y_floored = B::float_floor(grid_y.clone());\n\n    // Compute interpolation weights (fractional part)\n    let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());\n    let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());\n\n    // Convert to integer indices\n    let x0 = B::float_into_int(grid_x_floored.clone());\n    let y0 = B::float_into_int(grid_y_floored.clone());\n    let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1f32.into()));\n    let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1f32.into()));\n\n    // Create masks for out-of-bounds coordinates (only used for zeros padding)\n    let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {\n        let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.into());\n        let x0_valid = B::bool_and(\n            x0_valid,\n            B::int_lower_elem(x0.clone(), (w_in as i32).into()),\n        );\n        let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.into());\n        let x1_valid = B::bool_and(\n            x1_valid,\n            B::int_lower_elem(x1.clone(), (w_in as i32).into()),\n        );\n        let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.into());\n        let y0_valid = B::bool_and(\n            y0_valid,\n            B::int_lower_elem(y0.clone(), (h_in as i32).into()),\n        );\n        let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.into());\n        let y1_valid = B::bool_and(\n            y1_valid,\n            B::int_lower_elem(y1.clone(), (h_in as i32).into()),\n        );\n\n        (\n            Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),\n            Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),\n            Some(B::bool_and(x1_valid.clone(), y0_valid)),\n            Some(B::bool_and(x1_valid, y1_valid)),\n        )\n    } else {\n        (None, None, None, None)\n    };\n\n    // Clamp indices to valid range for gather\n    let x0_clamped = B::int_clamp(x0, 0.into(), ((w_in - 1) as i32).into());\n    let x1_clamped = B::int_clamp(x1, 0.into(), ((w_in - 1) as i32).into());\n    let y0_clamped = B::int_clamp(y0, 0.into(), ((h_in - 1) as i32).into());\n    let y1_clamped = B::int_clamp(y1, 0.into(), ((h_in - 1) as i32).into());\n\n    // Linear indices: idx = y * W_in + x\n    let w_in_scalar: i32 = w_in as i32;\n    let idx_00 = B::int_add(\n        B::int_mul_scalar(y0_clamped.clone(), w_in_scalar.into()),\n        x0_clamped.clone(),\n    );\n    let idx_01 = B::int_add(\n        B::int_mul_scalar(y1_clamped.clone(), w_in_scalar.into()),\n        x0_clamped,\n    );\n    let idx_10 = B::int_add(\n        B::int_mul_scalar(y0_clamped, w_in_scalar.into()),\n        x1_clamped.clone(),\n    );\n    let idx_11 = B::int_add(\n        B::int_mul_scalar(y1_clamped, w_in_scalar.into()),\n        x1_clamped,\n    );\n\n    // [N, 1, H_out, W_out] -> [N, 1, H_out * W_out]\n    let idx_00 = B::int_reshape(idx_00, Shape::new([n, 1, spatial_out]));\n    let idx_01 = B::int_reshape(idx_01, Shape::new([n, 1, spatial_out]));\n    let idx_10 = B::int_reshape(idx_10, Shape::new([n, 1, spatial_out]));\n    let idx_11 = B::int_reshape(idx_11, Shape::new([n, 1, spatial_out]));\n\n    // [N, 1, spatial] -> [N, C, spatial]\n    let idx_00 = B::int_expand(idx_00, Shape::new([n, c, spatial_out]));\n    let idx_01 = B::int_expand(idx_01, Shape::new([n, c, spatial_out]));\n    let idx_10 = B::int_expand(idx_10, Shape::new([n, c, spatial_out]));\n    let idx_11 = B::int_expand(idx_11, Shape::new([n, c, spatial_out]));\n\n    let tensor_flat = B::float_reshape(tensor, Shape::new([n, c, spatial_in]));\n\n    let sample_00 = B::float_gather(2, tensor_flat.clone(), idx_00);\n    let sample_01 = B::float_gather(2, tensor_flat.clone(), idx_01);\n    let sample_10 = B::float_gather(2, tensor_flat.clone(), idx_10);\n    let sample_11 = B::float_gather(2, tensor_flat, idx_11);\n\n    // Reshape samples to (N, C, H_out, W_out)\n    let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));\n    let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));\n    let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));\n    let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));\n\n    // Apply masks for zeros padding (set out-of-bounds samples to 0)\n    let (sample_00, sample_01, sample_10, sample_11) =\n        if padding_mode == GridSamplePaddingMode::Zeros {\n            let mask_00 = mask_00.unwrap();\n            let mask_01 = mask_01.unwrap();\n            let mask_10 = mask_10.unwrap();\n            let mask_11 = mask_11.unwrap();\n\n            let mask_00_inv = B::bool_not(mask_00);\n            let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));\n            let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));\n            let mask_01_inv = B::bool_not(mask_01);\n            let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));\n            let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));\n            let mask_10_inv = B::bool_not(mask_10);\n            let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));\n            let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));\n            let mask_11_inv = B::bool_not(mask_11);\n            let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));\n            let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));\n\n            (\n                B::float_mask_fill(sample_00, mask_00_inv, 0f32.into()),\n                B::float_mask_fill(sample_01, mask_01_inv, 0f32.into()),\n                B::float_mask_fill(sample_10, mask_10_inv, 0f32.into()),\n                B::float_mask_fill(sample_11, mask_11_inv, 0f32.into()),\n            )\n        } else {\n            (sample_00, sample_01, sample_10, sample_11)\n        };\n\n    // Compute bilinear interpolation weights\n    let one_minus_x = B::float_neg(x_frac.clone());\n    let one_minus_x = B::float_add_scalar(one_minus_x, 1f32.into());\n\n    let one_minus_y = B::float_neg(y_frac.clone());\n    let one_minus_y = B::float_add_scalar(one_minus_y, 1f32.into());\n\n    let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());\n    let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());\n    let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);\n    let weight_11 = B::float_mul(x_frac, y_frac);\n\n    // Bilinear interpolation\n    let result = B::float_mul(sample_00, weight_00);\n    let result = B::float_add(result, B::float_mul(sample_01, weight_01));\n    let result = B::float_add(result, B::float_mul(sample_10, weight_10));\n\n    B::float_add(result, B::float_mul(sample_11, weight_11))\n}\n\n/// Reflect coordinates at boundaries using a triangle wave pattern.\n///\n/// For align_corners=true: reflects within [0, size-1]\n/// For align_corners=false: reflects within [-0.5, size-0.5]\nfn reflect_coordinates<B: Backend>(\n    coords: FloatTensor<B>,\n    size: f64,\n    align_corners: bool,\n) -> FloatTensor<B> {\n    let (min_val, max_val) = if align_corners {\n        (0.0f32, (size - 1.0) as f32)\n    } else {\n        (-0.5f32, (size - 0.5) as f32)\n    };\n\n    let span = max_val - min_val;\n    if span <= 0.0 {\n        // Edge case: size is 1, just return min_val everywhere\n        let zeros = B::float_mul_scalar(coords, 0f32.into());\n        return B::float_add_scalar(zeros, min_val.into());\n    }\n\n    // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val\n    let period = 2.0 * span;\n\n    // x = abs(coord - min_val)\n    let x = B::float_sub_scalar(coords, min_val.into());\n    let x = B::float_abs(x);\n\n    // x_mod = x - floor(x / period) * period\n    let x_div = B::float_div_scalar(x.clone(), period.into());\n    let x_div_floor = B::float_floor(x_div);\n    let x_mod = B::float_sub(x, B::float_mul_scalar(x_div_floor, period.into()));\n\n    // result = span - abs(x_mod - span) + min_val\n    let diff = B::float_sub_scalar(x_mod, span.into());\n    let abs_diff = B::float_abs(diff);\n    let reflected = B::float_sub_scalar(abs_diff, span.into());\n    let reflected = B::float_neg(reflected);\n    B::float_add_scalar(reflected, min_val.into())\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/mod.rs",
    "content": "/// Module with convolution operations.\npub mod conv;\n\n/// Module with attention operations.\npub mod attention;\n\n/// Module with unfold operations.\npub mod unfold;\n\n/// Module with pooling operations.\npub mod pool;\n\n/// Module for grid_sample operations\npub mod grid_sample;\n\nmod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/pool.rs",
    "content": "use crate::tensor::{FloatTensor, IntTensor};\nuse crate::{Backend, TensorMetadata};\nuse burn_std::Shape;\n\nuse super::{MaxPool1dBackward, MaxPool1dWithIndices};\n\npub(crate) fn avg_pool1d_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> FloatTensor<B> {\n    let [batch_size, channels, length] = x.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));\n    let x = B::avg_pool2d(\n        x,\n        [kernel_size, 1],\n        [stride, 1],\n        [padding, 0],\n        count_include_pad,\n        ceil_mode,\n    );\n\n    let [batch_size, channels, length, _] = x.shape().dims();\n\n    B::float_reshape(x, Shape::from([batch_size, channels, length]))\n}\n\npub(crate) fn avg_pool1d_backward_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    grad: FloatTensor<B>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> FloatTensor<B> {\n    let [batch_size, channels, length_in] = x.shape().dims();\n    let [_, _, length_out] = grad.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));\n    let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));\n\n    let grad_x = B::avg_pool2d_backward(\n        x,\n        grad_x,\n        [kernel_size, 1],\n        [stride, 1],\n        [padding, 0],\n        count_include_pad,\n        ceil_mode,\n    );\n\n    B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))\n}\n\npub(crate) fn adaptive_avg_pool1d_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    output_size: usize,\n) -> FloatTensor<B> {\n    let [batch_size, channels, length] = x.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));\n    let x = B::adaptive_avg_pool2d(x, [output_size, 1]);\n\n    let [batch_size, channels, length, _] = x.shape().dims();\n\n    B::float_reshape(x, Shape::from([batch_size, channels, length]))\n}\n\npub(crate) fn adaptive_avg_pool1d_backward_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    grad: FloatTensor<B>,\n) -> FloatTensor<B> {\n    let [batch_size, channels, length_in] = x.shape().dims();\n    let [_, _, length_out] = grad.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));\n    let grad_x = B::float_reshape(grad, Shape::from([batch_size, channels, length_out, 1]));\n\n    let grad_x = B::adaptive_avg_pool2d_backward(x, grad_x);\n\n    B::float_reshape(grad_x, Shape::from([batch_size, channels, length_in]))\n}\n\npub(crate) fn max_pool1d_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    ceil_mode: bool,\n) -> FloatTensor<B> {\n    let [batch_size, channels, length] = x.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length, 1]));\n    let x = B::max_pool2d(\n        x,\n        [kernel_size, 1],\n        [stride, 1],\n        [padding, 0],\n        [dilation, 1],\n        ceil_mode,\n    );\n\n    let [batch_size, channels, length, _] = x.shape().dims();\n\n    B::float_reshape(x, Shape::from([batch_size, channels, length]))\n}\n\npub(crate) fn max_pool1d_with_indices_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    ceil_mode: bool,\n) -> MaxPool1dWithIndices<B> {\n    let [batch_size, channels, length] = x.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, 1, length]));\n    let x = B::max_pool2d_with_indices(\n        x,\n        [1, kernel_size],\n        [1, stride],\n        [0, padding],\n        [1, dilation],\n        ceil_mode,\n    );\n    let [batch_size, channels, _, length] = x.output.shape().dims();\n    let output = B::float_reshape(x.output, Shape::from([batch_size, channels, length]));\n    let indices = B::int_reshape(x.indices, Shape::from([batch_size, channels, length]));\n    MaxPool1dWithIndices::new(output, indices)\n}\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn max_pool1d_with_indices_backward_from_2d<B: Backend>(\n    x: FloatTensor<B>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    ceil_mode: bool,\n    output_grad: FloatTensor<B>,\n    indices: IntTensor<B>,\n) -> MaxPool1dBackward<B> {\n    let [batch_size, channels, length_in] = x.shape().dims();\n    let [_, _, length_out] = output_grad.shape().dims();\n\n    let x = B::float_reshape(x, Shape::from([batch_size, channels, length_in, 1]));\n    let grad_x = B::float_reshape(\n        output_grad,\n        Shape::from([batch_size, channels, length_out, 1]),\n    );\n    let indices = B::int_reshape(indices, Shape::from([batch_size, channels, length_out, 1]));\n\n    let grad_x = B::max_pool2d_with_indices_backward(\n        x,\n        [kernel_size, 1],\n        [stride, 1],\n        [padding, 0],\n        [dilation, 1],\n        ceil_mode,\n        grad_x,\n        indices,\n    )\n    .x_grad;\n\n    MaxPool1dBackward::new(B::float_reshape(\n        grad_x,\n        Shape::from([batch_size, channels, length_in]),\n    ))\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/modules/unfold.rs",
    "content": "use super::{ConvOptions, UnfoldOptions};\nuse crate::tensor::FloatTensor;\nuse crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_std::Shape;\n\n/// Constructs a special weight tensor used for unfolding.\n///\n/// # Notes\n///\n/// The idea behind using convolution for unfolding is to leverage the sliding window mechanism of\n/// convolution. By creating a weight tensor with ones in a particular pattern, we are able to borrow\n/// the convolution operation's mechanism as it moves across the input tensor, picking up the desired\n/// values in the pattern of the unfolding operation.\npub(crate) fn create_unfolding_weight<B: Backend>(\n    in_channels: usize,\n    kernel_size: [usize; 2],\n    device: &B::Device,\n) -> FloatTensor<B> {\n    let shape = Shape::new([\n        in_channels * kernel_size[0] * kernel_size[1],\n        in_channels,\n        kernel_size[0],\n        kernel_size[1],\n    ]);\n\n    let mut strides = [0; 4];\n    let mut current = 1;\n    shape.iter().enumerate().rev().for_each(|(index, val)| {\n        strides[index] = current;\n        current *= val;\n    });\n\n    let num_elements = shape.num_elements();\n\n    let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); num_elements];\n\n    for k in 0..in_channels {\n        for i in 0..kernel_size[0] {\n            for j in 0..kernel_size[1] {\n                let output_channel = k * kernel_size[0] * kernel_size[1] + i * kernel_size[1] + j;\n                let index =\n                    output_channel * strides[0] + k * strides[1] + i * strides[2] + j * strides[3];\n\n                weight[index] = 1.elem();\n            }\n        }\n    }\n\n    B::float_from_data(TensorData::new(weight, shape), device)\n}\n\n/// Compute the unfold4d operation using the conv2d operations.\npub(crate) fn unfold4d_using_conv2d<B: Backend>(\n    x: FloatTensor<B>,\n    kernel_size: [usize; 2],\n    options: UnfoldOptions,\n) -> FloatTensor<B> {\n    let [_batch_size, in_channels, _in_height, _in_width] = x.shape().dims();\n    let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &B::float_device(&x));\n    let unfolded = B::conv2d(\n        x,\n        weight,\n        None,\n        ConvOptions::new(options.stride, options.padding, options.dilation, 1),\n    );\n\n    let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();\n\n    B::float_reshape(\n        unfolded,\n        Shape::new([batch_size, channels_out, out_height * out_width]),\n    )\n}\n\n/// Calculate the number of unfolding windows that can be extracted from a dimension of given size.\npub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {\n    assert!(step_size > 0);\n    let x = dim_size + step_size;\n    if x < window_size {\n        0\n    } else {\n        (x - window_size) / step_size\n    }\n}\n\n/// Calculate the output shape for an unfold operation.\n///\n/// The operation yields a view with all complete windows of size `size` in dimension `dim`;\n/// where windows are advanced by `step` at each index.\n///\n/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n///\n/// # Arguments\n///\n/// * `shape` - The input shape to unfold; of shape ``[pre=..., dim shape, post=...]``\n/// * `dim` - the dimension to unfold.\n/// * `size` - the size of each unfolded window.\n/// * `step` - the step between each window.\n///\n/// # Returns\n///\n/// A shape with ``[pre=..., windows, post=..., size]``.\npub fn calculate_unfold_shape<S: Into<Shape>>(\n    shape: S,\n    dim: usize,\n    size: usize,\n    step: usize,\n) -> Shape {\n    let mut shape = shape.into();\n    let d_shape = shape[dim];\n    let windows = calculate_unfold_windows(d_shape, size, step);\n    shape[dim] = windows;\n    shape.push(size);\n\n    shape\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_calculate_unfold_windows() {\n        assert_eq!(calculate_unfold_windows(2, 5, 1), 0);\n\n        assert_eq!(calculate_unfold_windows(2, 3, 1), 0);\n        assert_eq!(calculate_unfold_windows(3, 3, 1), 1);\n        assert_eq!(calculate_unfold_windows(4, 3, 1), 2);\n        assert_eq!(calculate_unfold_windows(5, 3, 1), 3);\n\n        assert_eq!(calculate_unfold_windows(2, 3, 2), 0);\n        assert_eq!(calculate_unfold_windows(3, 3, 2), 1);\n        assert_eq!(calculate_unfold_windows(4, 3, 2), 1);\n        assert_eq!(calculate_unfold_windows(5, 3, 2), 2);\n    }\n\n    #[test]\n    fn test_calculate_unfold_shape() {\n        assert_eq!(\n            calculate_unfold_shape([2, 6, 6], 1, 3, 2),\n            Shape::new([2, 2, 6, 3])\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/qtensor.rs",
    "content": "use alloc::vec::Vec;\nuse burn_std::{\n    Shape, Slice,\n    quantization::{QuantPropagation, QuantScheme},\n};\n\nuse crate::{\n    Backend, ExecutionError, QTensorPrimitive, TensorData, TensorMetadata, TensorPrimitive,\n};\nuse crate::{\n    Scalar,\n    tensor::{\n        BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor,\n        quantization::{\n            Calibration, QuantizationParametersPrimitive, compute_q_params, compute_range,\n        },\n    },\n};\n\n/// Automatically applies `dequantization -> float operation -> quantization`.\n///\n/// Used for tensor ops that should always return a quantized output.\n#[macro_export]\nmacro_rules! dequant_op_quant {\n    // Binary tensor float op w/ lhs & rhs\n    (\n        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr\n    ) => {{\n        // Heuristic: prioritize lhs scheme\n        let scheme = $t1.scheme().clone();\n\n        let t1_f = <$ty>::dequantize($t1);\n        let t2_f = <$ty>::dequantize($t2);\n        #[allow(clippy::redundant_closure_call)]\n        let out_f = $float_op(t1_f, t2_f);\n\n        <$ty>::quantize_dynamic(out_f, &scheme)\n    }};\n    // Unary tensor float op\n    (\n        ty $ty:ty, float_op $float_op:expr, $tensor:expr\n    ) => {{\n        let scheme = $tensor.scheme().clone();\n\n        let tensor_f = <$ty>::dequantize($tensor);\n        #[allow(clippy::redundant_closure_call)]\n        let out_f = $float_op(tensor_f);\n\n        <$ty>::quantize_dynamic(out_f, &scheme)\n    }};\n}\n\n/// Automatically applies `dequantization -> float operation [-> quantization]`.\n///\n/// The output quantization step is optional.\n/// It is only performed when the input quantization scheme is propagated.\n#[macro_export]\nmacro_rules! dequant_op_flow {\n    // Binary tensor float op w/ lhs & rhs\n    (\n        ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr\n    ) => {{\n        // Heuristic: prioritize lhs scheme\n        let scheme = $t1.scheme().clone();\n        let propagation = $t1.propagation();\n\n        let t1_f = <$ty>::dequantize($t1);\n        let t2_f = <$ty>::dequantize($t2);\n        #[allow(clippy::redundant_closure_call)]\n        let out_f = $float_op(t1_f, t2_f);\n\n        match propagation {\n            QuantPropagation::Propagate => {\n                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))\n            }\n            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),\n        }\n    }};\n    // Unary tensor float op\n    (\n        ty $ty:ty, float_op $float_op:expr, $tensor:expr\n    ) => {{\n        let scheme = $tensor.scheme().clone();\n        let propagation = $tensor.propagation();\n\n        let tensor_f = <$ty>::dequantize($tensor);\n        #[allow(clippy::redundant_closure_call)]\n        let out_f = $float_op(tensor_f);\n\n        match propagation {\n            QuantPropagation::Propagate => {\n                TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme))\n            }\n            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),\n        }\n    }};\n}\n\n/// Operations on quantized tensors.\n///\n/// # Return Type Semantics\n///\n/// The return type of each operation indicates how quantization is handled:\n///\n/// ## [`QuantizedTensor<B>`]\n/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized\n/// representation. Implementations should avoid dequantizing when possible to maintain performance.\n/// For example, shape or layout changes such as expand or transpose preserve quantization.\n///\n/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is\n/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering\n/// the quantization parameters to match the new layout.*\n///\n///\n/// ## [`TensorPrimitive<B>`]\n/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with propagation\n/// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`])\n/// returned in floating-point form ([`TensorPrimitive::Float`]).\n///\n/// This distinction allows for fine-grained control over mixed-precision flows while still operating\n/// through a unified API.\npub trait QTensorOps<B: Backend> {\n    /// Creates a new tensor from the data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The data structure.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given data.\n    fn q_from_data(data: TensorData, device: &Device<B>) -> QuantizedTensor<B>;\n\n    /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.\n    fn quantize(\n        tensor: FloatTensor<B>,\n        scheme: &QuantScheme,\n        qparams: QuantizationParametersPrimitive<B>,\n    ) -> QuantizedTensor<B>;\n\n    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.\n    fn quantize_dynamic(tensor: FloatTensor<B>, scheme: &QuantScheme) -> QuantizedTensor<B> {\n        // Dynamically compute min/max tensor range and qparams before quantizing\n        let (min, max) = compute_range::<B>(scheme, tensor.clone(), &Calibration::MinMax);\n        let qparams = compute_q_params(scheme, min, max);\n        Self::quantize(tensor, scheme, qparams)\n    }\n\n    /// Convert the tensor back to a higher precision data type.\n    fn dequantize(tensor: QuantizedTensor<B>) -> FloatTensor<B>;\n\n    /// Gets the device of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The device of the tensor.\n    fn q_device(tensor: &QuantizedTensor<B>) -> Device<B>;\n\n    /// Moves the tensor to the given device.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `device` - The device to move the tensor to.\n    ///\n    /// # Returns\n    ///\n    /// The tensor on the given device.\n    fn q_to_device(tensor: QuantizedTensor<B>, device: &Device<B>) -> QuantizedTensor<B>;\n\n    /// Reshapes a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reshape.\n    /// * `shape` - The new shape of the tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the new shape.\n    fn q_reshape(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;\n\n    /// Converts the tensor to a data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The data structure with the tensor's data.\n    fn q_into_data(\n        tensor: QuantizedTensor<B>,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;\n\n    /// Detaches a tensor from the computation graph.\n    fn q_detach(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        // Should only be overridden by autodiff backends.\n        tensor\n    }\n\n    /// Sets the `require_grad` flag of a tensor.\n    fn q_set_require_grad(tensor: QuantizedTensor<B>, _require_grad: bool) -> QuantizedTensor<B> {\n        // Should only be overridden by autodiff backends.\n        tensor\n    }\n\n    /// Returns the `require_grad` flag of a tensor.\n    fn q_is_require_grad(_tensor: &QuantizedTensor<B>) -> bool {\n        // Should only be overridden by autodiff backends.\n        false\n    }\n\n    /// Broadcasts the `tensor` to the given `shape`.\n    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;\n\n    /// Transposes a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    fn q_transpose(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        let ndims = tensor.shape().num_dims();\n        Self::q_swap_dims(tensor, ndims - 2, ndims - 1)\n    }\n\n    /// Swaps two dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap.\n    /// * `dim2` - The second dimension to swap.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    fn q_swap_dims(tensor: QuantizedTensor<B>, dim1: usize, dim2: usize) -> QuantizedTensor<B>;\n\n    /// Permutes the dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to permute the dimensions of.\n    /// * `axes` - The new order of the dimensions.\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    fn q_permute(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;\n\n    /// Reverse the order of elements in a tensor along the given axes.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reverse.\n    /// * `axes` - The axes to reverse.\n    ///\n    /// The tensor with the elements reversed.\n    fn q_flip(tensor: QuantizedTensor<B>, axes: &[usize]) -> QuantizedTensor<B>;\n\n    /// Select tensor elements along the given dimension corresponding for the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices to select.\n    ///\n    /// # Returns\n    ///\n    /// The selected elements.\n    fn q_select(\n        tensor: QuantizedTensor<B>,\n        dim: usize,\n        indices: IntTensor<B>,\n    ) -> QuantizedTensor<B>;\n\n    /// Select tensor elements corresponding to the given slices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `slices` - The slices specifying ranges and steps for each dimension.\n    ///\n    /// # Returns\n    ///\n    /// The selected elements in a new tensor.\n    fn q_slice(tensor: QuantizedTensor<B>, slices: &[Slice]) -> QuantizedTensor<B>;\n\n    /// Gather elements from a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to gather from.\n    /// * `tensor` - The tensor to gather from.\n    /// * `indices` - The indices to gather.\n    ///\n    /// # Returns\n    ///\n    /// The gathered elements.\n    fn q_gather(\n        dim: usize,\n        tensor: QuantizedTensor<B>,\n        indices: IntTensor<B>,\n    ) -> QuantizedTensor<B> {\n        // Default implementation. Backends can gather on the quantized values when supported.\n        dequant_op_quant!(\n            ty Self,\n            float_op |tensor| B::float_gather(dim, tensor, indices),\n            tensor\n        )\n    }\n\n    /// Repeat the tensor along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to repeat.\n    /// * `times` - The number of times to repeat the dimension.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given dimension repeated.\n    fn q_repeat_dim(tensor: QuantizedTensor<B>, dim: usize, times: usize) -> QuantizedTensor<B> {\n        dequant_op_quant!(\n            ty Self,\n            float_op |tensor| B::float_repeat_dim(tensor, dim, times),\n            tensor\n        )\n    }\n\n    /// Adds two tensors together.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of adding the two tensors together.\n    fn q_add(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |lhs, rhs| B::float_add(lhs, rhs),\n            lhs,\n            rhs\n        )\n    }\n\n    /// Adds a scalar to a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of adding the scalar to the tensor.\n    fn q_add_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_add_scalar(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Clamps a tensor under a minimum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn q_clamp_min(tensor: QuantizedTensor<B>, min: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_clamp_min(tensor, min),\n            tensor\n        )\n    }\n\n    /// Clamps a tensor over a maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn q_clamp_max(tensor: QuantizedTensor<B>, max: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_clamp_max(tensor, max),\n            tensor\n        )\n    }\n\n    /// Clamps a tensor between a minimum and maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn q_clamp(tensor: QuantizedTensor<B>, min: Scalar, max: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_clamp(tensor, min, max),\n            tensor\n        )\n    }\n\n    /// Subtracts two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of subtracting the two tensors.\n    fn q_sub(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |lhs, rhs| B::float_sub(lhs, rhs),\n            lhs,\n            rhs\n        )\n    }\n\n    /// Subtracts a scalar from a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of subtracting the scalar from the tensor.\n    fn q_sub_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sub_scalar(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Multiplies two tensors together element-wise.\n    fn q_mul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |lhs, rhs| B::float_mul(lhs, rhs),\n            lhs,\n            rhs\n        )\n    }\n\n    /// Multiplies a tensor by a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of multiplying the tensor by the scalar.\n    fn q_mul_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_mul_scalar(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Divides two tensors element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of dividing the two tensors.\n    fn q_div(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |lhs, rhs| B::float_div(lhs, rhs),\n            lhs,\n            rhs\n        )\n    }\n\n    /// Divides a tensor by a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of dividing the tensor by the scalar.\n    fn q_div_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_div_scalar(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Multiplies two tensors together using matrix multiplication.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of multiplying the two tensors together using matrix multiplication.\n    fn q_matmul(lhs: TensorPrimitive<B>, rhs: TensorPrimitive<B>) -> TensorPrimitive<B> {\n        let mut propagation = QuantPropagation::Inhibit;\n        let mut scheme = QuantScheme::default();\n        let lhs = match lhs {\n            TensorPrimitive::Float(lhs) => lhs,\n            TensorPrimitive::QFloat(lhs) => {\n                propagation = lhs.propagation();\n                scheme = *lhs.scheme();\n                Self::dequantize(lhs)\n            }\n        };\n        let rhs = match rhs {\n            TensorPrimitive::Float(rhs) => rhs,\n            TensorPrimitive::QFloat(rhs) => {\n                propagation = rhs.propagation();\n                scheme = *rhs.scheme();\n                Self::dequantize(rhs)\n            }\n        };\n\n        let out_f = B::float_matmul(lhs, rhs);\n        match propagation {\n            QuantPropagation::Propagate => {\n                TensorPrimitive::QFloat(<Self>::quantize_dynamic(out_f, &scheme))\n            }\n            QuantPropagation::Inhibit => TensorPrimitive::Float(out_f),\n        }\n    }\n\n    /// Negates a tensor element-wise.\n    fn q_neg(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_neg(tensor),\n            tensor\n        )\n    }\n\n    /// Calculates the reciprocals element-wise\n    fn q_recip(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_recip(tensor),\n            tensor\n        )\n    }\n\n    /// Sum of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the sum of all elements in `tensor`.\n    fn q_sum(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sum(tensor),\n            tensor\n        )\n    }\n\n    /// Sum of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    /// * `dim` - The dimension along which to sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the sum of all elements in `tensor` along `dim`.\n    fn q_sum_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sum_dim(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Product of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to product.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the product of all elements in `tensor`.\n    fn q_prod(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_prod(tensor),\n            tensor\n        )\n    }\n\n    /// Product of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the product of all elements in `tensor` along `dim`.\n    fn q_prod_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_prod_dim(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Mean of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to mean.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the mean of all elements in `tensor`.\n    fn q_mean(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_mean(tensor),\n            tensor\n        )\n    }\n\n    /// Mean of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to mean.\n    /// * `dim` - The dimension along which to mean.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the mean of all elements in `tensor` along `dim`.\n    fn q_mean_dim(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_mean_dim(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Computes the cumulative sum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative sum of.\n    /// * `dim` - The dimension along which to compute the cumulative sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative sum\n    /// of all elements up to and including that position along the dimension.\n    fn q_cumsum(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cumsum(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Computes the cumulative product of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative product of.\n    /// * `dim` - The dimension along which to compute the cumulative product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative product\n    /// of all elements up to and including that position along the dimension.\n    fn q_cumprod(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cumprod(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Computes the cumulative minimum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative minimum of.\n    /// * `dim` - The dimension along which to compute the cumulative minimum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the minimum\n    /// of all elements up to and including that position along the dimension.\n    fn q_cummin(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cummin(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Computes the cumulative maximum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative maximum of.\n    /// * `dim` - The dimension along which to compute the cumulative maximum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the maximum\n    /// of all elements up to and including that position along the dimension.\n    fn q_cummax(tensor: QuantizedTensor<B>, dim: usize) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cummax(tensor, dim),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with exponential values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to exponentiate.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with exponential values.\n    fn q_exp(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_exp(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with natural logarithm values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the logarithm of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with natural logarithm values.\n    fn q_log(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_log(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with logarithm values of (1 + Xi).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the logarithm of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).\n    fn q_log1p(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_log1p(tensor),\n            tensor\n        )\n    }\n\n    /// Element-wise power with another tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the power of the elements of `rhs`.\n    fn q_powf(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |lhs, rhs| B::float_powf(lhs, rhs),\n            lhs,\n            rhs\n        )\n    }\n\n    /// Element-wise power with an IntTensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side floatTensor.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.\n    fn q_powi(lhs: QuantizedTensor<B>, rhs: IntTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_powi(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Element-wise power with an int scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`.\n    fn q_powi_scalar(lhs: QuantizedTensor<B>, rhs: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_powi_scalar(tensor, rhs),\n            lhs\n        )\n    }\n\n    /// Element-wise power with a float scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to exponentiate.\n    /// * `value` - The exponent.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.\n    fn q_powf_scalar(tensor: QuantizedTensor<B>, value: Scalar) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_powf_scalar(tensor, value),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with square root values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the square root of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with square root values.\n    fn q_sqrt(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sqrt(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with absolute values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take absolute value of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with absolute values.\n    fn q_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        dequant_op_quant!(\n            ty Self,\n            float_op |tensor| B::float_abs(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the cosine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with cosine values.\n    fn q_cos(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cos(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the sine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with sine values.\n    fn q_sin(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sin(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the tangent of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with tangent values.\n    fn q_tan(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_tan(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with hyperbolic cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic cosine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.\n    fn q_cosh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_cosh(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with hyperbolic sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic sine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic sine values.\n    fn q_sinh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_sinh(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with hyperbolic tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic tangent of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.\n    fn q_tanh(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_tanh(tensor),\n            tensor\n        )\n    }\n\n    /// Returns a new tensor with the error function values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the error function of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with error function values.\n    fn q_erf(tensor: QuantizedTensor<B>) -> TensorPrimitive<B> {\n        dequant_op_flow!(\n            ty Self,\n            float_op |tensor| B::float_erf(tensor),\n            tensor\n        )\n    }\n\n    /// Concatenates tensors along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensors` - The tensors to concatenate.\n    /// * `dim` - The dimension along which to concatenate.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the concatenated tensors along `dim`.\n    fn q_cat(tensors: Vec<QuantizedTensor<B>>, dim: usize) -> QuantizedTensor<B> {\n        // Heuristic: prioritize first tensor scheme\n        let scheme = *tensors.first().unwrap().scheme();\n\n        let tensor_f = tensors\n            .into_iter()\n            .map(|tensor| Self::dequantize(tensor))\n            .collect();\n\n        let out_f = B::float_cat(tensor_f, dim);\n\n        Self::quantize_dynamic(out_f, &scheme)\n    }\n\n    /// Gets the indices of the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.\n    fn q_argmax(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {\n        // Default implementation. Backends can sort on the int values since qparams remain the same.\n        let tensor_f = Self::dequantize(tensor);\n        B::float_argmax(tensor_f, dim)\n    }\n\n    /// Gets the indices of the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.\n    fn q_argmin(tensor: QuantizedTensor<B>, dim: usize) -> IntTensor<B> {\n        // Default implementation. Backends can sort on the int values since qparams remain the same.\n        let tensor_f = Self::dequantize(tensor);\n        B::float_argmin(tensor_f, dim)\n    }\n\n    /// Gets the maximum element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum element of `tensor`.\n    fn q_max(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::q_max_dim(tensor, 0)\n    }\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum elements of `tensor` along `dim`.\n    fn q_max_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {\n        let index = B::q_argmax(tensor.clone(), dim);\n\n        B::q_gather(dim, tensor, index)\n    }\n\n    /// Gets the maximum elements of a tensor along an axis and their indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.\n    fn q_max_dim_with_indices(\n        tensor: QuantizedTensor<B>,\n        dim: usize,\n    ) -> (QuantizedTensor<B>, IntTensor<B>) {\n        let index = B::q_argmax(tensor.clone(), dim);\n        let values = B::q_gather(dim, tensor, index.clone());\n\n        (values, index)\n    }\n\n    /// Gets the minimum element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the minimum element of `tensor`.\n    fn q_min(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::q_min_dim(tensor, 0)\n    }\n\n    /// Gets the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the minimum elements of `tensor` along `dim`.\n    fn q_min_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {\n        let index = B::q_argmin(tensor.clone(), dim);\n\n        B::q_gather(dim, tensor, index)\n    }\n\n    /// Gets the minimum elements of a tensor along an axis and their indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.\n    fn q_min_dim_with_indices(\n        tensor: QuantizedTensor<B>,\n        dim: usize,\n    ) -> (QuantizedTensor<B>, IntTensor<B>) {\n        let index = B::q_argmin(tensor.clone(), dim);\n        let values = B::q_gather(dim, tensor, index.clone());\n\n        (values, index)\n    }\n\n    /// Gets the maximum element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum element of `tensor`.\n    fn q_max_abs(tensor: QuantizedTensor<B>) -> QuantizedTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::q_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::q_max_abs_dim(tensor, 0)\n    }\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum elements of `tensor` along `dim`.\n    fn q_max_abs_dim(tensor: QuantizedTensor<B>, dim: usize) -> QuantizedTensor<B> {\n        let index = B::q_argmax(B::q_abs(tensor.clone()), dim);\n\n        B::q_gather(dim, tensor, index)\n    }\n\n    /// Tests if any element in the `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.\n    fn q_any(tensor: QuantizedTensor<B>) -> BoolTensor<B> {\n        let tensor_f = Self::dequantize(tensor);\n        B::float_any(tensor_f)\n    }\n\n    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the\n    /// input evaluates to True, False otherwise.\n    fn q_any_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {\n        let tensor_f = Self::dequantize(tensor);\n        B::float_any_dim(tensor_f, dim)\n    }\n\n    /// Tests if all elements in the `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor\n    /// evaluate to True, False otherwise.\n    fn q_all(tensor: QuantizedTensor<B>) -> BoolTensor<B> {\n        let tensor_f = Self::dequantize(tensor);\n        B::float_all(tensor_f)\n    }\n\n    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn q_all_dim(tensor: QuantizedTensor<B>, dim: usize) -> BoolTensor<B> {\n        let tensor_f = Self::dequantize(tensor);\n        B::float_all_dim(tensor_f, dim)\n    }\n\n    /// Sort the elements of the input `tensor` by value in along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.\n    fn q_sort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> QuantizedTensor<B> {\n        // Default implementation. Backends can sort on the int values since qparams remain the same.\n        dequant_op_quant!(\n            ty Self,\n            float_op |tensor| B::float_sort(tensor, dim, descending),\n            tensor\n        )\n    }\n\n    /// Sort the elements of the input `tensor` by value in along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor and corresponding indices, where\n    /// the elements are sorted by value and the indices map back to the original input tensor.\n    fn q_sort_with_indices(\n        tensor: QuantizedTensor<B>,\n        dim: usize,\n        descending: bool,\n    ) -> (QuantizedTensor<B>, IntTensor<B>) {\n        // Default implementation. Backends can sort on the int values since qparams remain the same.\n        let scheme = *tensor.scheme();\n\n        let tensor_f = Self::dequantize(tensor);\n        let (out_f, indices) = B::float_sort_with_indices(tensor_f, dim, descending);\n\n        (Self::quantize_dynamic(out_f, &scheme), indices)\n    }\n\n    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.\n    fn q_argsort(tensor: QuantizedTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {\n        // Default implementation. Backends can sort on the int values since qparams remain the same.\n        let tensor_f = Self::dequantize(tensor);\n        B::float_argsort(tensor_f, dim, descending)\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/repeat_dim.rs",
    "content": "use crate::{\n    Backend, TensorMetadata,\n    tensor::{BasicOps, TensorKind},\n};\nuse alloc::vec::Vec;\nuse burn_std::Slice;\n\npub(crate) fn repeat_with_slice_assign<B: Backend, K: TensorKind<B> + BasicOps<B>>(\n    tensor: K::Primitive,\n    dim: usize,\n    times: usize,\n) -> K::Primitive {\n    let shape = tensor.shape();\n    let device = K::device(&tensor);\n    let dtype = tensor.dtype();\n\n    let original_dim_length = shape[dim];\n    let shape = shape.repeat(dim, times).unwrap();\n\n    let mut tensor_output = K::empty(shape.clone(), &device, dtype);\n\n    let indices_select_all = shape.iter().map(|d| 0..*d).collect::<Vec<_>>();\n\n    let mut output_index = 0;\n    for _ in 0..times {\n        let mut indices = indices_select_all.clone();\n        indices[dim] = output_index..output_index + original_dim_length;\n        output_index += original_dim_length;\n\n        // Convert ranges to Slice\n        let slices: Vec<Slice> = indices\n            .iter()\n            .map(|r| Slice::new(r.start as isize, Some(r.end as isize), 1))\n            .collect();\n        tensor_output = K::slice_assign(tensor_output, &slices, tensor.clone());\n    }\n\n    tensor_output\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/sort.rs",
    "content": "use core::cmp::Ordering;\n\nuse crate::{\n    Backend, DType, TensorData,\n    element::{ElementConversion, ElementOrdered},\n    tensor::{BasicOps, IntElem, IntTensor},\n};\nuse alloc::{vec, vec::Vec};\nuse burn_std::reader::try_read_sync;\nuse burn_std::{bf16, f16};\n\n/// Macro used to dispatch sort operations based on dtype.\nmacro_rules! sort_dispatch_dtype {\n    ($fn:ident, $data:ident, $($args:expr),*) => {\n        match $data.dtype {\n            DType::F64 => $fn::<B, f64>($data, $($args),*),\n            DType::F32 | DType::Flex32 => $fn::<B, f32>($data, $($args),*),\n            DType::F16 => $fn::<B, f16>($data, $($args),*),\n            DType::BF16 => $fn::<B, bf16>($data, $($args),*),\n            DType::I64 => $fn::<B, i64>($data, $($args),*),\n            DType::I32 => $fn::<B, i32>($data, $($args),*),\n            DType::I16 => $fn::<B, i16>($data, $($args),*),\n            DType::I8 => $fn::<B, i8>($data, $($args),*),\n            DType::U64 => $fn::<B, u64>($data, $($args),*),\n            DType::U32 => $fn::<B, u32>($data, $($args),*),\n            DType::U16 => $fn::<B, u16>($data, $($args),*),\n            DType::U8 => $fn::<B, u8>($data, $($args),*),\n            DType::Bool(_) | DType::QFloat(_) => unimplemented!(\"not supported for sorting operations\"),\n        }\n    };\n}\n\n/// Sort the elements of the input `tensor` by value along a given dimension.\n///\n/// This sort is unstable (i.e., may reorder equal elements).\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor.\n/// * `dim` - The axis along which to sort.\n/// * `descending` - The sorting order.\n///\n/// # Returns\n///\n/// A tensor with the same shape as the input tensor, where the elements are sorted by value.\n///\n/// # Remarks\n///\n/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.\n/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved\n/// by static dispatch. It is not designed for direct usage by users, and not recommended to import\n/// or use this function directly.\npub fn sort<B: Backend, K: BasicOps<B>>(\n    tensor: K::Primitive,\n    dim: usize,\n    descending: bool,\n) -> K::Primitive {\n    let device = K::device(&tensor);\n    let msg = \"Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.\";\n    let data = try_read_sync(K::into_data_async(tensor))\n        .expect(msg)\n        .expect(msg);\n\n    let data = sort_dispatch_dtype!(sort_data, data, dim, descending);\n    K::from_data(data, &device)\n}\n\npub fn sort_data<B: Backend, E: ElementOrdered>(\n    mut data: TensorData,\n    dim: usize,\n    descending: bool,\n) -> TensorData {\n    let dims = data.shape.clone();\n    let data_slice = data.as_mut_slice().unwrap();\n    if dims.len() == 1 {\n        // 1D sort\n        data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending));\n    } else {\n        sort_slice::<B, E>(data_slice, &dims, dim, None, false, descending);\n    }\n\n    data\n}\n\n/// Sort the elements of the input `tensor` by value along a given dimension.\n///\n/// This sort is unstable (i.e., may reorder equal elements).\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor.\n/// * `dim` - The axis along which to sort.\n/// * `descending` - The sorting order.\n///\n/// # Returns\n///\n/// A tensor with the same shape as the input tensor and corresponding indices, where\n/// the elements are sorted by value and the indices map back to the original input tensor.\n///\n/// # Remarks\n///\n/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.\n/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved\n/// by static dispatch. It is not designed for direct usage by users, and not recommended to import\n/// or use this function directly.\npub fn sort_with_indices<B: Backend, K: BasicOps<B>>(\n    tensor: K::Primitive,\n    dim: usize,\n    descending: bool,\n) -> (K::Primitive, IntTensor<B>) {\n    let device = K::device(&tensor);\n    let msg = \"Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.\";\n    let data = try_read_sync(K::into_data_async(tensor))\n        .expect(msg)\n        .expect(msg);\n\n    let (values, indices) = sort_dispatch_dtype!(sort_data_with_indices, data, dim, descending);\n\n    (\n        K::from_data(values, &device),\n        B::int_from_data(indices, &device),\n    )\n}\n\nfn sort_data_with_indices<B: Backend, E: ElementOrdered>(\n    mut data: TensorData,\n    dim: usize,\n    descending: bool,\n) -> (TensorData, TensorData) {\n    let dims = data.shape.clone();\n    let mut indices_data = dim_indices::<B>(&dims, dim);\n    let data_slice = data.as_mut_slice().unwrap();\n    if dims.len() == 1 {\n        // 1D sort\n        indices_data.sort_unstable_by(|&a, &b| {\n            compare(\n                &data_slice[a.elem::<i64>() as usize],\n                &data_slice[b.elem::<i64>() as usize],\n                descending,\n            )\n        });\n\n        // Permute data in-place by the sorted indices\n        let mut indices = indices_data\n            .clone()\n            .iter()\n            .map(|i| i.elem::<i64>() as usize)\n            .collect::<Vec<_>>();\n        for idx in 0..indices.len() {\n            if indices[idx] != idx {\n                let mut current_idx = idx;\n                loop {\n                    let target_idx = indices[current_idx];\n                    indices[current_idx] = current_idx;\n                    if indices[target_idx] == target_idx {\n                        // correct position\n                        break;\n                    }\n\n                    // Permute data by indices\n                    data_slice.swap(current_idx, target_idx);\n                    current_idx = target_idx;\n                }\n            }\n        }\n    } else {\n        sort_slice::<B, E>(\n            data_slice,\n            &dims,\n            dim,\n            Some(&mut indices_data),\n            true,\n            descending,\n        );\n    }\n\n    (data, TensorData::new(indices_data, dims))\n}\n\n/// Returns the indices that sort the elements of the input `tensor` along a given dimension.\n///\n/// This sort is unstable (i.e., may reorder equal elements).\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor.\n/// * `dim` - The axis along which to sort.\n/// * `descending` - The sorting order.\n///\n/// # Returns\n///\n/// A tensor with the same shape as the input tensor the indices map back to the original input tensor.\n///\n/// # Remarks\n///\n/// This is a fallback solution that used only when the backend doesn't have the corresponding implementation.\n/// Ideally, it is supposed to be implemented by the backend and the backend implementation will be resolved\n/// by static dispatch. It is not designed for direct usage by users, and not recommended to import\n/// or use this function directly.\npub fn argsort<B: Backend, K: BasicOps<B>>(\n    tensor: K::Primitive,\n    dim: usize,\n    descending: bool,\n) -> IntTensor<B> {\n    let device = K::device(&tensor);\n    let msg = \"Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.\";\n    let data = try_read_sync(K::into_data_async(tensor))\n        .expect(msg)\n        .expect(msg);\n\n    let data = sort_dispatch_dtype!(argsort_data, data, dim, descending);\n    B::int_from_data(data, &device)\n}\n\nfn argsort_data<B: Backend, E: ElementOrdered>(\n    mut data: TensorData,\n    dim: usize,\n    descending: bool,\n) -> TensorData {\n    let dims = data.shape.clone();\n    let mut indices_data = dim_indices::<B>(&dims, dim);\n    if dims.len() == 1 {\n        // 1D sort\n        let slice = data.as_slice::<E>().unwrap();\n        indices_data.sort_unstable_by(|&a, &b| {\n            compare(\n                &slice[a.elem::<i64>() as usize],\n                &slice[b.elem::<i64>() as usize],\n                descending,\n            )\n        });\n    } else {\n        sort_slice::<B, E>(\n            data.as_mut_slice().unwrap(),\n            &dims,\n            dim,\n            Some(&mut indices_data),\n            false,\n            descending,\n        );\n    }\n\n    TensorData::new(indices_data, dims)\n}\n\n/// Sort the elements by value along a given dimension.\n///\n/// When `indices` are not provided, the `data` is sorted.\n/// Otherwise, the `indices` are sorted based on the value of the elements in `data`,\n/// and if `permute_both` is enabled then the data is also sorted.\n///\n/// This sort is unstable (i.e., may reorder equal elements).\nfn sort_slice<B: Backend, E: ElementOrdered>(\n    data: &mut [E],\n    dims: &[usize],\n    dim: usize,\n    mut indices: Option<&mut [IntElem<B>]>,\n    permute_both: bool,\n    descending: bool,\n) {\n    let ndims = dims.len();\n    let strides = compute_strides(dims);\n    // Dimensions to access elements to sort\n    let mut sort_dims = dims.to_vec();\n    sort_dims[dim] = 1;\n    let strides_out = compute_strides(&sort_dims);\n\n    // Number of groups to sort\n    let num_sorts: usize = dims\n        .iter()\n        .enumerate()\n        .filter(|&(i, _)| i != dim)\n        .map(|(_, d)| d)\n        .product();\n\n    // TODO: run each sort in parallel\n    // run_par!(|| {\n    //     iter_range_par!(0, num_sorts).for_each(|id| {...})\n    for id in 0..num_sorts {\n        let mut index_offset = 0;\n        let mut stride_dim = 0;\n        let mut shape_dim = 0;\n        for d in 0..ndims {\n            let stride_input = strides[d];\n            let stride_output = strides_out[d];\n            let shape_output = sort_dims[d];\n\n            let num_block = id / stride_output % shape_output;\n\n            if d != dim {\n                index_offset += num_block * stride_input;\n            } else {\n                let shape_input = dims[d];\n                stride_dim = stride_input;\n                shape_dim = shape_input;\n                index_offset += num_block;\n            }\n        }\n\n        // For each group, sort the indices based on the element values\n        // NOTE: Sorting methods like `sort_unstable_by` are in-place but we need to sort\n        // different views/groups of the underlying data, so the swap is performed on the elements\n        // of the (flat index, element value) collection.\n        let mut elements = (0..shape_dim)\n            .map(|d| {\n                let flat_index = d * stride_dim + index_offset;\n                let elem = data[flat_index];\n                (d, flat_index, elem)\n            })\n            .collect::<Vec<_>>();\n\n        elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));\n\n        // Permute data in-place by the sorted indices\n        for idx in 0..elements.len() {\n            if elements[idx].0 != idx {\n                let mut current_idx = idx;\n                loop {\n                    let target_idx = elements[current_idx].0;\n                    elements[current_idx].0 = current_idx;\n                    if elements[target_idx].0 == target_idx {\n                        // correct position\n                        break;\n                    }\n\n                    if indices.is_none() || permute_both {\n                        // Permute data by indices\n                        data.swap(elements[current_idx].1, elements[target_idx].1);\n                    }\n\n                    if let Some(ref mut indices_data) = indices {\n                        // Permute data element indices\n                        indices_data.swap(elements[current_idx].1, elements[target_idx].1);\n                    }\n\n                    current_idx = target_idx;\n                }\n            }\n        }\n    }\n}\n\n/// Computes the steps for each dimension when traversing an array.\nfn compute_strides(dims: &[usize]) -> Vec<usize> {\n    let mut strides = vec![0; dims.len()];\n    let mut current = 1;\n\n    dims.iter().enumerate().rev().for_each(|(index, val)| {\n        strides[index] = current;\n        current *= val;\n    });\n\n    strides\n}\n\n/// Generates the indices for each element along the specified dimension.\nfn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {\n    if dims.len() == 1 {\n        (0..dims[dim])\n            .map(|i| (i as i64).elem::<IntElem<B>>())\n            .collect::<Vec<_>>()\n    } else {\n        // Dimension indices tensor\n        let numel_leading_dims: usize = dims[..dim].iter().product();\n        let numel_trailing_dims: usize = dims[dim + 1..].iter().product();\n        (0..dims[dim])\n            .map(|i| [(i as i64).elem::<IntElem<B>>()].repeat(numel_trailing_dims))\n            .collect::<Vec<_>>()\n            .concat()\n            .repeat(numel_leading_dims)\n    }\n}\n\n/// Compare two elements\nfn compare<E: ElementOrdered>(a: &E, b: &E, descending: bool) -> Ordering {\n    if descending { b.cmp(a) } else { a.cmp(b) }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/tensor.rs",
    "content": "use super::cat::cat_with_slice_assign;\nuse super::grid_sample::float_grid_sample_2d_ref;\nuse super::repeat_dim::repeat_with_slice_assign;\nuse super::sort::{argsort, sort, sort_with_indices};\nuse crate::ops::GridSampleOptions;\nuse crate::tensor::{BoolTensor, Device, Float, FloatTensor, IntTensor};\nuse crate::{Backend, Distribution, TensorData};\nuse crate::{ExecutionError, Scalar, TensorMetadata, TensorPrimitive};\nuse alloc::vec::Vec;\nuse burn_std::{FloatDType, Shape, Slice};\n\n/// Operations on float tensors.\npub trait FloatTensorOps<B: Backend> {\n    /// Creates a new tensor from the data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The data structure.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given data.\n    fn float_from_data(data: TensorData, device: &Device<B>) -> FloatTensor<B>;\n\n    /// Creates a new tensor with random values.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `distribution` - The distribution to sample from.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given shape and random values.\n    fn float_random(shape: Shape, distribution: Distribution, device: &Device<B>)\n    -> FloatTensor<B>;\n\n    /// Creates a new tensor with zeros.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given shape and zeros.\n    fn float_zeros(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {\n        Self::float_from_data(TensorData::full_dtype(shape, 0., dtype.into()), device)\n    }\n\n    /// Creates a new tensor with ones.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given shape and ones.\n    fn float_ones(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B> {\n        Self::float_from_data(TensorData::full_dtype(shape, 1., dtype.into()), device)\n    }\n\n    /// Creates a tensor filled with given value.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `fill_value` - The value with which to fill the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor filled with given value\n    fn float_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<B>,\n        dtype: FloatDType,\n    ) -> FloatTensor<B> {\n        Self::float_from_data(\n            TensorData::full_dtype(shape, fill_value, dtype.into()),\n            device,\n        )\n    }\n\n    /// Converts the tensor to a data structure.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The data structure with the tensor's data.\n    fn float_into_data(\n        tensor: FloatTensor<B>,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;\n\n    /// Gets the device of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The device of the tensor.\n    fn float_device(tensor: &FloatTensor<B>) -> Device<B>;\n\n    /// Moves the tensor to the given device.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `device` - The device to move the tensor to.\n    ///\n    /// # Returns\n    ///\n    /// The tensor on the given device.\n    fn float_to_device(tensor: FloatTensor<B>, device: &Device<B>) -> FloatTensor<B>;\n\n    /// Converts float tensor to int tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The int tensor with the same data as the float tensor.\n    fn float_into_int(tensor: FloatTensor<B>) -> IntTensor<B>;\n\n    /// Creates an empty tensor with the given shape.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The empty tensor with the given shape.\n    fn float_empty(shape: Shape, device: &Device<B>, dtype: FloatDType) -> FloatTensor<B>;\n\n    /// Repeat the tensor along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension to repeat.\n    /// * `times` - The number of times to repeat the dimension.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the given dimension repeated.\n    fn float_repeat_dim(tensor: FloatTensor<B>, dim: usize, times: usize) -> FloatTensor<B> {\n        repeat_with_slice_assign::<B, Float>(TensorPrimitive::Float(tensor), dim, times).tensor()\n    }\n\n    /// Adds two tensors together.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of adding the two tensors together.\n    fn float_add(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Adds a scalar to a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of adding the scalar to the tensor.\n    fn float_add_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;\n\n    /// Clamps a tensor under a minimum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn float_clamp_min(tensor: FloatTensor<B>, min: Scalar) -> FloatTensor<B> {\n        // Default implementation\n        let mask = Self::float_lower_elem(tensor.clone(), min);\n        B::float_mask_fill(tensor, mask, min)\n    }\n\n    /// Clamps a tensor over a maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn float_clamp_max(tensor: FloatTensor<B>, max: Scalar) -> FloatTensor<B> {\n        // Default implementation\n        let mask = Self::float_greater_elem(tensor.clone(), max);\n        B::float_mask_fill(tensor, mask, max)\n    }\n\n    /// Clamps a tensor between a minimum and maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// The clamped tensor.\n    fn float_clamp(tensor: FloatTensor<B>, min: Scalar, max: Scalar) -> FloatTensor<B> {\n        // Default implementation\n        Self::float_clamp_min(Self::float_clamp_max(tensor, max), min)\n    }\n\n    /// Subtracts two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of subtracting the two tensors.\n    fn float_sub(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Subtracts a scalar from a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of subtracting the scalar from the tensor.\n    fn float_sub_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;\n\n    /// Multiplies two tensors together element-wise.\n    fn float_mul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Multiplies a tensor by a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of multiplying the tensor by the scalar.\n    fn float_mul_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;\n\n    /// Divides two tensors element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of dividing the two tensors.\n    fn float_div(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Divides a tensor by a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of dividing the tensor by the scalar.\n    fn float_div_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;\n\n    /// Computes the remainder of division between two tensors element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The element-wise remainder when dividing `lhs` by `rhs`.\n    fn float_remainder(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Computes the modulus of a tensor given a scalar.\n    ///\n    /// # Arguments\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The result of applying the modulus of the scalar to the tensor.\n    fn float_remainder_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B>;\n\n    /// Multiplies two tensors together using matrix multiplication.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The result of multiplying the two tensors together using matrix multiplication.\n    fn float_matmul(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Computes the cross product of two tensors along a given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    /// * `dim` - The dimension to compute the cross product along.\n    ///\n    /// # Returns\n    ///\n    /// The cross product of the two tensors.\n    fn float_cross(lhs: FloatTensor<B>, rhs: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Negates a tensor element-wise.\n    fn float_neg(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        Self::float_mul_scalar(tensor, (-1f32).into())\n    }\n\n    /// Calculates the reciprocals element-wise\n    fn float_recip(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Transposes a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    fn float_transpose(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let ndims = tensor.shape().num_dims();\n        Self::float_swap_dims(tensor, ndims - 2, ndims - 1)\n    }\n\n    /// Swaps two dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap.\n    /// * `dim2` - The second dimension to swap.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    fn float_swap_dims(tensor: FloatTensor<B>, dim1: usize, dim2: usize) -> FloatTensor<B>;\n\n    /// Permutes the dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to permute the dimensions of.\n    /// * `axes` - The new order of the dimensions.\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    fn float_permute(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;\n\n    /// Reverse the order of elements in a tensor along the given axes.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reverse.\n    /// * `axes` - The axes to reverse.\n    ///\n    /// The tensor with the elements reversed.\n    fn float_flip(tensor: FloatTensor<B>, axes: &[usize]) -> FloatTensor<B>;\n\n    /// Reshapes a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to reshape.\n    /// * `shape` - The new shape of the tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the new shape.\n    fn float_reshape(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;\n\n    /// Gather elements from a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to gather from.\n    /// * `tensor` - The tensor to gather from.\n    /// * `indices` - The indices to gather.\n    ///\n    /// # Returns\n    ///\n    /// The gathered elements.\n    fn float_gather(dim: usize, tensor: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B>;\n\n    /// Scatter elements into a tensor using sum reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to scatter into.\n    /// * `tensor` - The tensor to scatter into.\n    /// * `indices` - The indices to scatter into.\n    /// * `value` - The value to scatter.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the scattered elements.\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<B>,\n        indices: IntTensor<B>,\n        value: FloatTensor<B>,\n    ) -> FloatTensor<B>;\n\n    /// Select tensor elements along the given dimension corresponding for the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices to select.\n    ///\n    /// # Returns\n    ///\n    /// The selected elements.\n    fn float_select(tensor: FloatTensor<B>, dim: usize, indices: IntTensor<B>) -> FloatTensor<B>;\n\n    /// Assign the selected elements along the given dimension corresponding for the given indices\n    /// to the given value using sum reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `dim` - The dimension to select from.\n    /// * `indices` - The indices to select.\n    /// * `value` - The value to assign.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements assigned to the given value.\n    fn float_select_add(\n        tensor: FloatTensor<B>,\n        dim: usize,\n        indices: IntTensor<B>,\n        value: FloatTensor<B>,\n    ) -> FloatTensor<B>;\n\n    /// Select tensor elements corresponding to the given slices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `slices` - The slices specifying ranges and steps for each dimension.\n    ///\n    /// # Returns\n    ///\n    /// The selected elements in a new tensor.\n    ///\n    /// # Note\n    ///\n    /// Empty slices (where start >= end) are handled at the high-level tensor API and will not\n    /// be passed to this method. Backend implementations do not need to handle empty slices.\n    fn float_slice(tensor: FloatTensor<B>, slices: &[Slice]) -> FloatTensor<B>;\n\n    /// Assign the selected elements corresponding to the given slices to the given value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `ranges` - The ranges to select.\n    /// * `value` - The value to assign.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements assigned to the given value.\n    ///\n    /// # Note\n    ///\n    /// Empty slice assignments (where any slice range produces 0 elements) are handled at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty slice assignments.\n    fn float_slice_assign(\n        tensor: FloatTensor<B>,\n        slices: &[Slice],\n        value: FloatTensor<B>,\n    ) -> FloatTensor<B>;\n\n    /// Update the given tensor with the value tensor where the mask is true.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `mask` - The boolean mask to select with.\n    /// * `value` - The value to assign to the selected elements from the value tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements assigned to the given value.\n    fn float_mask_where(\n        tensor: FloatTensor<B>,\n        mask: BoolTensor<B>,\n        value: FloatTensor<B>,\n    ) -> FloatTensor<B>;\n\n    /// Update the given tensor with the value where the mask is true.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `mask` - The boolean mask to select with.\n    /// * `value` - The value to assign to the selected elements.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the selected elements assigned to the given value.\n    fn float_mask_fill(\n        tensor: FloatTensor<B>,\n        mask: BoolTensor<B>,\n        value: Scalar,\n    ) -> FloatTensor<B>;\n\n    /// Equal comparison of two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_not_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B> {\n        let equal_tensor = B::float_equal(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Equal comparison of a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Element-wise non-equality comparison with a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_not_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B> {\n        let equal_tensor = B::float_equal_elem(lhs, rhs);\n        B::bool_not(equal_tensor)\n    }\n\n    /// Greater than comparison of two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_greater(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;\n\n    /// Greater than comparison of a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_greater_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Greater than or equal comparison of two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_greater_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;\n\n    /// Greater than or equal comparison of a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_greater_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Less than comparison of two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_lower(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;\n\n    /// Less than comparison of a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_lower_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Less than or equal comparison of two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_lower_equal(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> BoolTensor<B>;\n\n    /// Less than or equal comparison of a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the result of the comparison.\n    fn float_lower_equal_elem(lhs: FloatTensor<B>, rhs: Scalar) -> BoolTensor<B>;\n\n    /// Detaches a tensor from the computation graph.\n    fn float_detach(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        // Should only be overridden by autodiff backends.\n        tensor\n    }\n\n    /// Sets the `require_grad` flag of a tensor.\n    fn float_set_require_grad(tensor: FloatTensor<B>, _require_grad: bool) -> FloatTensor<B> {\n        // Should only be overridden by autodiff backends.\n        tensor\n    }\n\n    /// Returns the `require_grad` flag of a tensor.\n    fn float_is_require_grad(_tensor: &FloatTensor<B>) -> bool {\n        // Should only be overridden by autodiff backends.\n        false\n    }\n\n    /// Sum of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the sum of all elements in `tensor`.\n    fn float_sum(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Sum of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    /// * `dim` - The dimension along which to sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the sum of all elements in `tensor` along `dim`.\n    fn float_sum_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Product of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to product.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the product of all elements in `tensor`.\n    fn float_prod(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        // Product of all elements in a tensor\n        B::float_exp(B::float_sum(B::float_log(tensor)))\n    }\n\n    /// Product of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the product of all elements in `tensor` along `dim`.\n    fn float_prod_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {\n        // Product of all elements in a tensor along a dimension\n        B::float_exp(B::float_sum_dim(B::float_log(tensor), dim))\n    }\n\n    /// Mean of all elements in a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to mean.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor with the mean of all elements in `tensor`.\n    fn float_mean(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let num_elems = tensor.shape().num_elements() as f32;\n        B::float_div_scalar(B::float_sum(tensor), num_elems.into())\n    }\n\n    /// Mean of all elements in a tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to mean.\n    /// * `dim` - The dimension along which to mean.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the mean of all elements in `tensor` along `dim`.\n    fn float_mean_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Computes the cumulative sum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative sum of.\n    /// * `dim` - The dimension along which to compute the cumulative sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative sum\n    /// of all elements up to and including that position along the dimension.\n    fn float_cumsum(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Computes the cumulative product of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative product of.\n    /// * `dim` - The dimension along which to compute the cumulative product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the cumulative product\n    /// of all elements up to and including that position along the dimension.\n    fn float_cumprod(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Computes the cumulative minimum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative minimum of.\n    /// * `dim` - The dimension along which to compute the cumulative minimum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the minimum\n    /// of all elements up to and including that position along the dimension.\n    fn float_cummin(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Computes the cumulative maximum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative maximum of.\n    /// * `dim` - The dimension along which to compute the cumulative maximum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the maximum\n    /// of all elements up to and including that position along the dimension.\n    fn float_cummax(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B>;\n\n    /// Converts a tensor to another floating point data type.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to convert.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same values as `tensor` but in the target floating point data type.\n    fn float_cast(tensor: FloatTensor<B>, dtype: FloatDType) -> FloatTensor<B>;\n\n    /// Returns a new tensor with exponential values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to exponentiate.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with exponential values.\n    fn float_exp(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with natural logarithm values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the logarithm of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with natural logarithm values.\n    fn float_log(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with logarithm values of (1 + Xi).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the logarithm of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).\n    fn float_log1p(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Element-wise power with a FloatTensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the power of the elements of `rhs`.\n    fn float_powf(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Element-wise power with an IntTensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side floatTensor.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor.\n    fn float_powi(lhs: FloatTensor<B>, rhs: IntTensor<B>) -> FloatTensor<B> {\n        Self::float_powf(lhs, B::int_into_float(rhs))\n    }\n\n    /// Raises a tensor to the power of an int scalar.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// A number of common exponent cases can be implemented with operations\n    /// which are much cheaper than generic exponentiation.\n    ///\n    /// This (`Backend` impl overridable) operation handles generic optimizations\n    /// for several common integer exponent cases; and then dispatches to\n    /// the (`Backend` impl overridable) [`Self::float_powi_scalar_impl`]\n    /// operation to handle the generic case.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`.\n    fn float_powi_scalar(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {\n        match rhs.elem::<i64>() {\n            0 => Self::float_ones(lhs.shape(), &B::float_device(&lhs), lhs.dtype().into()),\n            1 => lhs,\n            2 => B::float_mul(lhs.clone(), lhs),\n            -1 => Self::float_recip(lhs),\n            -2 => Self::float_recip(B::float_mul(lhs.clone(), lhs)),\n            _ => Self::float_powi_scalar_impl(lhs, rhs),\n        }\n    }\n\n    /// Raises a tensor to the power of an int scalar.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// This is the generic implementation of integer exponentiation\n    /// called by [`Self::float_powi_scalar`] in the fallback case.\n    ///\n    /// As a general rule, this should not be called directly.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left-hand side tensor.\n    /// * `rhs` - The right-hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The elements of `lhs` raised to the value of `rhs`.\n    fn float_powi_scalar_impl(lhs: FloatTensor<B>, rhs: Scalar) -> FloatTensor<B> {\n        // Avoid a recursive loop by deferring directly to float_powf_scalar_impl.\n        Self::float_powf_scalar_impl(lhs, rhs)\n    }\n\n    /// Returns a new tensor with values raised to the power of float `value`.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// This (`Backend` impl overridable) operation dispatches integer exponentiation\n    /// to [`Self::float_powi_scalar`], and the remaining non-integer exponent cases to\n    /// the (`Backend` impl overridable) [`Self::float_powf_scalar_impl`]\n    /// operation to handle the generic case.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to exponentiate.\n    /// * `value` - The exponent.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.\n    fn float_powf_scalar(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B> {\n        if let Some(exp) = value.try_as_integer() {\n            Self::float_powi_scalar(tensor, exp)\n        } else {\n            Self::float_powf_scalar_impl(tensor, value)\n        }\n    }\n\n    /// Returns a new tensor with values raised to the power of float `value`.\n    ///\n    /// # Backend Implementors Note\n    ///\n    /// This is the generic implementation of integer exponentiation\n    /// called by [`Self::float_powf_scalar`] in the fallback case.\n    ///\n    /// This is the minimal required support a `Backend` must implement\n    /// for exponentiation.\n    ///\n    /// As a general rule, this should not be called directly.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to exponentiate.\n    /// * `value` - The exponent.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with values raised to the power of `value`.\n    fn float_powf_scalar_impl(tensor: FloatTensor<B>, value: Scalar) -> FloatTensor<B>;\n\n    /// Returns a new tensor with square root values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the square root of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with square root values.\n    fn float_sqrt(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with absolute values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take absolute value of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with absolute values.\n    fn float_abs(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the cosine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with cosine values.\n    fn float_cos(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the sine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with sine values.\n    fn float_sin(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the tangent of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with tangent values.\n    fn float_tan(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with hyperbolic cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic cosine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic cosine values.\n    fn float_cosh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with hyperbolic sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic sine of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic sine values.\n    fn float_sinh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with hyperbolic tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the hyperbolic tangent of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with hyperbolic tangent values.\n    fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with inverse cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with inverse cosine values.\n    fn float_acos(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with inverse hyperbolic cosine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with inverse hyperbolic cosine values.\n    fn float_acosh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with inverse sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with inverse sine values.\n    fn float_asin(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with inverse hyperbolic sine values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with inverse hyperbolic sine values.\n    fn float_asinh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with the inverse tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with the inverse tangent values.\n    fn float_atan(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with the inverse hyperbolic tangent values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with the inverse hyperbolic tangent values.\n    fn float_atanh(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a tensor with the four-quadrant inverse tangent values of `y` and `x`.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The tensor with y coordinates.\n    /// * `rhs` - The tensor with x coordinates.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the four-quadrant inverse tangent values.\n    fn float_atan2(lhs: FloatTensor<B>, rhs: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with rounded values.\n    ///\n    /// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)\n    /// strategy, with halfway cases rounded to the nearest even integer value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to be rounded.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with rounded values.\n    fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with floored values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to be floored.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with floored values.\n    fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with ceiled values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to be ceiled.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with ceiled values.\n    fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with truncated values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to be truncated.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with truncated values.\n    fn float_trunc(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Returns a new tensor with the error function values.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to take the error function of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` with error function values.\n    fn float_erf(tensor: FloatTensor<B>) -> FloatTensor<B>;\n\n    /// Concatenates tensors along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensors` - The tensors to concatenate.\n    /// * `dim` - The dimension along which to concatenate.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the concatenated tensors along `dim`.\n    ///\n    /// # Note\n    ///\n    /// Empty tensors (where the concatenation dimension has size 0) are filtered out at the\n    /// high-level tensor API and will not be passed to this method. Backend implementations do\n    /// not need to handle empty tensors.\n    fn float_cat(tensors: Vec<FloatTensor<B>>, dim: usize) -> FloatTensor<B> {\n        cat_with_slice_assign::<B, Float>(\n            tensors.into_iter().map(TensorPrimitive::Float).collect(),\n            dim,\n        )\n        .tensor()\n    }\n\n    /// Gets the indices of the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the indices of the maximum elements of `tensor` along `dim`.\n    fn float_argmax(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Gets the indices of the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the indices of the minimum elements of `tensor` along `dim`.\n    fn float_argmin(tensor: FloatTensor<B>, dim: usize) -> IntTensor<B>;\n\n    /// Gets the maximum element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum element of `tensor`.\n    fn float_max(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::float_max_dim(tensor, 0)\n    }\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum elements of `tensor` along `dim`.\n    fn float_max_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {\n        let index = B::float_argmax(tensor.clone(), dim);\n\n        B::float_gather(dim, tensor, index)\n    }\n\n    /// Gets the maximum elements of a tensor along an axis and their indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tuple with the maximum elements of `tensor` along `dim` and their indices.\n    fn float_max_dim_with_indices(\n        tensor: FloatTensor<B>,\n        dim: usize,\n    ) -> (FloatTensor<B>, IntTensor<B>) {\n        let index = B::float_argmax(tensor.clone(), dim);\n        let values = B::float_gather(dim, tensor, index.clone());\n\n        (values, index)\n    }\n\n    /// Gets the minimum element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the minimum element of `tensor`.\n    fn float_min(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::float_min_dim(tensor, 0)\n    }\n\n    /// Gets the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the minimum elements of `tensor` along `dim`.\n    fn float_min_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {\n        let index = B::float_argmin(tensor.clone(), dim);\n\n        B::float_gather(dim, tensor, index)\n    }\n\n    /// Gets the minimum elements of a tensor along an axis and their indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements of.\n    /// * `dim` - The dimension along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tuple with the minimum elements of `tensor` along `dim` and their indices.\n    fn float_min_dim_with_indices(\n        tensor: FloatTensor<B>,\n        dim: usize,\n    ) -> (FloatTensor<B>, IntTensor<B>) {\n        let index = B::float_argmin(tensor.clone(), dim);\n        let values = B::float_gather(dim, tensor, index.clone());\n\n        (values, index)\n    }\n\n    /// Gets the maximum absolute element of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum element of `tensor`.\n    fn float_max_abs(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let shape = tensor.shape();\n        let tensor = B::float_reshape(tensor, Shape::new([shape.num_elements()]));\n\n        B::float_max_abs_dim(tensor, 0)\n    }\n\n    /// Gets the maximum absolute elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements of.\n    /// * `dim` - The dimension along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the maximum elements of `tensor` along `dim`.\n    fn float_max_abs_dim(tensor: FloatTensor<B>, dim: usize) -> FloatTensor<B> {\n        B::float_max_dim(B::float_abs(tensor), dim)\n    }\n\n    /// Tests if any element in the float `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if any element in the tensor is True, False otherwise.\n    fn float_any(tensor: FloatTensor<B>) -> BoolTensor<B> {\n        let bool_tensor = B::float_equal_elem(tensor, 0f32.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::float_sum(B::bool_into_float(bool_tensor));\n        B::float_greater_elem(sum, 0f32.into())\n    }\n\n    /// Tests if any element in the float `tensor` evaluates to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the\n    /// input evaluates to True, False otherwise.\n    fn float_any_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {\n        let bool_tensor = B::float_equal_elem(tensor, 0f32.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);\n        B::float_greater_elem(sum, 0f32.into())\n    }\n\n    /// Tests if all elements in the float `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor\n    /// evaluate to True, False otherwise.\n    fn float_all(tensor: FloatTensor<B>) -> BoolTensor<B> {\n        let num_elems = tensor.shape().num_elements() as f32;\n        let bool_tensor = B::float_equal_elem(tensor, 0f32.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::float_sum(B::bool_into_float(bool_tensor));\n        B::float_equal_elem(sum, num_elems.into())\n    }\n\n    /// Tests if all elements in the float `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same size as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input\n    /// evaluates to True, False otherwise.\n    fn float_all_dim(tensor: FloatTensor<B>, dim: usize) -> BoolTensor<B> {\n        let num_elems = tensor.shape()[dim] as f32;\n        let bool_tensor = B::float_equal_elem(tensor, 0f32.into());\n        let bool_tensor = B::bool_not(bool_tensor);\n        let sum = B::float_sum_dim(B::bool_into_float(bool_tensor), dim);\n        B::float_equal_elem(sum, num_elems.into())\n    }\n\n    /// Returns the signs of the float `tensor`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to extract the signs from.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as `tensor` containing the signs of the elements of `tensor`.\n    fn float_sign(tensor: FloatTensor<B>) -> FloatTensor<B> {\n        let zeros = B::float_zeros(\n            tensor.shape(),\n            &B::float_device(&tensor),\n            tensor.dtype().into(),\n        );\n        let less_than_zero = B::float_lower_elem(tensor.clone(), 0f32.into());\n        let greater_than_zero = B::float_greater_elem(tensor, 0f32.into());\n\n        let mut result = B::float_mask_fill(zeros, less_than_zero, (-1f32).into());\n        result = B::float_mask_fill(result, greater_than_zero, 1f32.into());\n        result\n    }\n\n    /// Broadcasts the float `tensor` to the given `shape`.\n    fn float_expand(tensor: FloatTensor<B>, shape: Shape) -> FloatTensor<B>;\n\n    /// Sort the elements of the input `tensor` by value in along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.\n    fn float_sort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> FloatTensor<B> {\n        sort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending).tensor()\n    }\n\n    /// Sort the elements of the input `tensor` by value in along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor and corresponding indices, where\n    /// the elements are sorted by value and the indices map back to the original input tensor.\n    fn float_sort_with_indices(\n        tensor: FloatTensor<B>,\n        dim: usize,\n        descending: bool,\n    ) -> (FloatTensor<B>, IntTensor<B>) {\n        let (values, indices) =\n            sort_with_indices::<B, Float>(TensorPrimitive::Float(tensor), dim, descending);\n        (values.tensor(), indices)\n    }\n\n    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.\n    fn float_argsort(tensor: FloatTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {\n        argsort::<B, Float>(TensorPrimitive::Float(tensor), dim, descending)\n    }\n\n    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,\n    /// using the given locations in [-1, 1].\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)\n    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].\n    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right\n    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)\n    ///\n    /// # Returns\n    ///\n    /// A tensor with shape (N, C, H_out, W_out)\n    fn float_grid_sample_2d(\n        tensor: FloatTensor<B>,\n        grid: FloatTensor<B>,\n        options: GridSampleOptions,\n    ) -> FloatTensor<B> {\n        float_grid_sample_2d_ref::<B>(tensor, grid, options)\n    }\n\n    /// Unfold windows along a dimension.\n    ///\n    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n    /// * `dim` - the selected dim.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with shape ``[pre=..., windows, size, post=...]``.\n    fn float_unfold(tensor: FloatTensor<B>, dim: usize, size: usize, step: usize)\n    -> FloatTensor<B>;\n\n    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.\n    fn float_is_nan(tensor: FloatTensor<B>) -> BoolTensor<B> {\n        // Check if the input tensor is NaN by comparing it to itself\n        // NaN is the only value that is not equal to itself\n        B::float_not_equal(tensor.clone(), tensor)\n    }\n\n    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where `true` indicates that the value is infinite\n    fn float_is_inf(tensor: FloatTensor<B>) -> BoolTensor<B> {\n        B::float_equal_elem(B::float_abs(tensor), f64::INFINITY.into())\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/ops/transaction.rs",
    "content": "use alloc::vec::Vec;\nuse core::future::Future;\n\nuse crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};\nuse crate::{Backend, ExecutionError, TensorData, TensorPrimitive};\n\nenum Order {\n    Float(usize),\n    QFloat(usize),\n    Int(usize),\n    Bool(usize),\n}\n\n#[derive(Default)]\n/// Contains all tensor primitives that are going to be read.\npub struct TransactionPrimitive<B: Backend> {\n    /// Float tensors.\n    pub read_floats: Vec<FloatTensor<B>>,\n    /// Quantized tensors.\n    pub read_qfloats: Vec<QuantizedTensor<B>>,\n    /// Int tensors.\n    pub read_ints: Vec<IntTensor<B>>,\n    /// Bool tensors.\n    pub read_bools: Vec<BoolTensor<B>>,\n    orders: Vec<Order>,\n}\n\n#[derive(Default)]\n/// Contains all [data](TensorData) related to a [transaction](TransactionPrimitive).\npub struct TransactionPrimitiveData {\n    /// Float tensor data.\n    pub read_floats: Vec<TensorData>,\n    /// Quantized tensor data.\n    pub read_qfloats: Vec<TensorData>,\n    /// Int tensor data.\n    pub read_ints: Vec<TensorData>,\n    /// Bool tensor data.\n    pub read_bools: Vec<TensorData>,\n}\n\n/// Operations that are sync by nature and that can be batch together in transactions to improve\n/// compute utilization with efficient laziness.\npub trait TransactionOps<B: Backend> {\n    /// Executes a [transaction](TransactionPrimitive) and return its\n    /// [data](TransactionPrimitiveData).\n    fn tr_execute(\n        transaction: TransactionPrimitive<B>,\n    ) -> impl Future<Output = Result<TransactionPrimitiveData, ExecutionError>> + Send {\n        async move {\n            let mut floats = Vec::new();\n            let mut qfloats = Vec::new();\n            let mut ints = Vec::new();\n            let mut bools = Vec::new();\n\n            for t in transaction.read_floats {\n                floats.push(B::float_into_data(t).await?);\n            }\n            for t in transaction.read_qfloats {\n                qfloats.push(B::q_into_data(t).await?);\n            }\n            for t in transaction.read_ints {\n                ints.push(B::int_into_data(t).await?);\n            }\n            for t in transaction.read_bools {\n                bools.push(B::bool_into_data(t).await?);\n            }\n\n            Ok(TransactionPrimitiveData {\n                read_floats: floats,\n                read_qfloats: qfloats,\n                read_ints: ints,\n                read_bools: bools,\n            })\n        }\n    }\n}\n\nimpl<B: Backend> TransactionPrimitive<B> {\n    /// Creates a new transaction.\n    pub fn new(\n        read_floats: Vec<FloatTensor<B>>,\n        read_qfloats: Vec<QuantizedTensor<B>>,\n        read_ints: Vec<IntTensor<B>>,\n        read_bools: Vec<BoolTensor<B>>,\n    ) -> Self {\n        Self {\n            read_floats,\n            read_qfloats,\n            read_ints,\n            read_bools,\n            orders: Vec::default(),\n        }\n    }\n    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order\n    /// in which they were [registered](crate::tensor::BasicOps::register_transaction).\n    pub async fn execute_async(mut self) -> Result<Vec<TensorData>, ExecutionError> {\n        let mut orders = Vec::new();\n        core::mem::swap(&mut orders, &mut self.orders);\n        let result = B::tr_execute(self).await?;\n\n        let mut floats: Vec<_> = result.read_floats.into_iter().map(Some).collect();\n        let mut qfloats: Vec<_> = result.read_qfloats.into_iter().map(Some).collect();\n        let mut ints: Vec<_> = result.read_ints.into_iter().map(Some).collect();\n        let mut bools: Vec<_> = result.read_bools.into_iter().map(Some).collect();\n\n        Ok(orders\n            .into_iter()\n            .map(|order| match order {\n                Order::Float(index) => floats.get_mut(index).unwrap().take().unwrap(),\n                Order::QFloat(index) => qfloats.get_mut(index).unwrap().take().unwrap(),\n                Order::Int(index) => ints.get_mut(index).unwrap().take().unwrap(),\n                Order::Bool(index) => bools.get_mut(index).unwrap().take().unwrap(),\n            })\n            .collect::<Vec<_>>())\n    }\n\n    pub(crate) fn register_float(&mut self, tensor: TensorPrimitive<B>) {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                self.orders.push(Order::Float(self.read_floats.len()));\n                self.read_floats.push(tensor);\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                self.orders.push(Order::QFloat(self.read_qfloats.len()));\n                self.read_qfloats.push(tensor);\n            }\n        }\n    }\n\n    pub(crate) fn register_int(&mut self, tensor: IntTensor<B>) {\n        self.orders.push(Order::Int(self.read_ints.len()));\n        self.read_ints.push(tensor);\n    }\n\n    pub(crate) fn register_bool(&mut self, tensor: BoolTensor<B>) {\n        self.orders.push(Order::Bool(self.read_bools.len()));\n        self.read_bools.push(tensor);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/backend/primitive.rs",
    "content": "use crate::Backend;\nuse burn_std::quantization::{QuantAcc, QuantPropagation, QuantScheme};\nuse burn_std::{DType, Shape};\n\n#[derive(Debug, Clone)]\n/// A primitive tensor representation.\npub enum TensorPrimitive<B: Backend> {\n    /// Float tensor primitive.\n    Float(B::FloatTensorPrimitive),\n    /// Quantized float tensor primitive.\n    QFloat(B::QuantizedTensorPrimitive),\n}\n\nimpl<B: Backend> TensorPrimitive<B> {\n    /// Returns the full tensor representation.\n    pub fn tensor(self) -> B::FloatTensorPrimitive {\n        match self {\n            Self::QFloat(tensor) => B::dequantize(tensor),\n            Self::Float(tensor) => tensor,\n        }\n    }\n}\n\nimpl<B: Backend> TensorMetadata for TensorPrimitive<B> {\n    fn dtype(&self) -> DType {\n        match self {\n            TensorPrimitive::Float(tensor) => tensor.dtype(),\n            TensorPrimitive::QFloat(tensor) => tensor.dtype(),\n        }\n    }\n\n    fn shape(&self) -> Shape {\n        match self {\n            TensorPrimitive::Float(tensor) => tensor.shape(),\n            TensorPrimitive::QFloat(tensor) => tensor.shape(),\n        }\n    }\n\n    fn rank(&self) -> usize {\n        match self {\n            TensorPrimitive::Float(tensor) => tensor.rank(),\n            TensorPrimitive::QFloat(tensor) => tensor.rank(),\n        }\n    }\n}\n\n/// Tensor metadata trait for tensor primitive.\npub trait TensorMetadata: Clone + Send + Sync + core::fmt::Debug {\n    /// The dtype of the tensor.\n    fn dtype(&self) -> DType;\n    /// The shape of the tensor.\n    fn shape(&self) -> Shape;\n\n    /// The number of dimensions of the tensor.\n    fn rank(&self) -> usize {\n        self.shape().num_dims()\n    }\n}\n\n/// Quantized tensor primitive.\npub trait QTensorPrimitive {\n    /// Returns the quantization settings for the given tensor.\n    fn scheme(&self) -> &QuantScheme;\n    /// The precision used for the accumulation in various kernels.\n    fn acc_precision(&self) -> QuantAcc {\n        QuantAcc::F32\n    }\n    /// How quantization is propagated during computation.\n    fn propagation(&self) -> QuantPropagation {\n        QuantPropagation::Inhibit\n    }\n\n    /// Returns the default tensor quantization scheme.\n    fn default_scheme() -> QuantScheme {\n        QuantScheme::default()\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/data/compare.rs",
    "content": "use alloc::format;\nuse alloc::string::String;\nuse burn_std::{BoolStore, DType, bf16, f16};\nuse num_traits::{Float, ToPrimitive};\n\nuse super::TensorData;\nuse crate::{Element, ElementOrdered};\n\n/// The tolerance used to compare to floating point numbers.\n///\n/// Generally, two numbers `x` and `y` are approximately equal if\n///\n/// ```text\n/// |x - y| < max(R * (|x + y|), A)\n/// ```\n///\n/// where `R` is the relative tolerance and `A` is the absolute tolerance.\n///\n///\n/// The most common way to initialize this struct is to use `Tolerance::<F>::default()`.\n/// In that case, the relative and absolute tolerances are computed using an heuristic based\n/// on the EPSILON and MIN_POSITIVE values of the given floating point type `F`.\n///\n/// Another common initialization is `Tolerance::<F>::rel_abs(1e-4, 1e-5).set_half_precision_relative(1e-2)`.\n/// This will use a sane default to manage values too close to 0.0 and\n/// use different relative tolerances depending on the floating point precision.\n#[derive(Debug, Clone, Copy)]\npub struct Tolerance<F> {\n    relative: F,\n    absolute: F,\n}\n\nimpl<F: Float> Default for Tolerance<F> {\n    fn default() -> Self {\n        Self::balanced()\n    }\n}\n\nimpl<F: Float> Tolerance<F> {\n    /// Create a tolerance with strict precision setting.\n    pub fn strict() -> Self {\n        Self {\n            relative: F::from(0.00).unwrap(),\n            absolute: F::from(64).unwrap() * F::min_positive_value(),\n        }\n    }\n    /// Create a tolerance with balanced precision setting.\n    pub fn balanced() -> Self {\n        Self {\n            relative: F::from(0.005).unwrap(), // 0.5%\n            absolute: F::from(1e-5).unwrap(),\n        }\n    }\n\n    /// Create a tolerance with permissive precision setting.\n    pub fn permissive() -> Self {\n        Self {\n            relative: F::from(0.01).unwrap(), // 1.0%\n            absolute: F::from(0.01).unwrap(),\n        }\n    }\n    /// When comparing two numbers, this uses both the relative and absolute differences.\n    ///\n    /// That is, `x` and `y` are approximately equal if\n    ///\n    /// ```text\n    /// |x - y| < max(R * (|x + y|), A)\n    /// ```\n    ///\n    /// where `R` is the `relative` tolerance and `A` is the `absolute` tolerance.\n    pub fn rel_abs<FF: ToPrimitive>(relative: FF, absolute: FF) -> Self {\n        let relative = Self::check_relative(relative);\n        let absolute = Self::check_absolute(absolute);\n\n        Self { relative, absolute }\n    }\n\n    /// When comparing two numbers, this uses only the relative difference.\n    ///\n    /// That is, `x` and `y` are approximately equal if\n    ///\n    /// ```text\n    /// |x - y| < R * max(|x|, |y|)\n    /// ```\n    ///\n    /// where `R` is the relative `tolerance`.\n    pub fn relative<FF: ToPrimitive>(tolerance: FF) -> Self {\n        let relative = Self::check_relative(tolerance);\n\n        Self {\n            relative,\n            absolute: F::from(0.0).unwrap(),\n        }\n    }\n\n    /// When comparing two numbers, this uses only the absolute difference.\n    ///\n    /// That is, `x` and `y` are approximately equal if\n    ///\n    /// ```text\n    /// |x - y| < A\n    /// ```\n    ///\n    /// where `A` is the absolute `tolerance`.\n    pub fn absolute<FF: ToPrimitive>(tolerance: FF) -> Self {\n        let absolute = Self::check_absolute(tolerance);\n\n        Self {\n            relative: F::from(0.0).unwrap(),\n            absolute,\n        }\n    }\n\n    /// Change the relative tolerance to the given one.\n    pub fn set_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        self.relative = Self::check_relative(tolerance);\n        self\n    }\n\n    /// Change the relative tolerance to the given one only if `F` is half precision.\n    pub fn set_half_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 2 {\n            self.relative = Self::check_relative(tolerance);\n        }\n        self\n    }\n\n    /// Change the relative tolerance to the given one only if `F` is single precision.\n    pub fn set_single_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 4 {\n            self.relative = Self::check_relative(tolerance);\n        }\n        self\n    }\n\n    /// Change the relative tolerance to the given one only if `F` is double precision.\n    pub fn set_double_precision_relative<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 8 {\n            self.relative = Self::check_relative(tolerance);\n        }\n        self\n    }\n\n    /// Change the absolute tolerance to the given one.\n    pub fn set_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        self.absolute = Self::check_absolute(tolerance);\n        self\n    }\n\n    /// Change the absolute tolerance to the given one only if `F` is half precision.\n    pub fn set_half_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 2 {\n            self.absolute = Self::check_absolute(tolerance);\n        }\n        self\n    }\n\n    /// Change the absolute tolerance to the given one only if `F` is single precision.\n    pub fn set_single_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 4 {\n            self.absolute = Self::check_absolute(tolerance);\n        }\n        self\n    }\n\n    /// Change the absolute tolerance to the given one only if `F` is double precision.\n    pub fn set_double_precision_absolute<FF: ToPrimitive>(mut self, tolerance: FF) -> Self {\n        if core::mem::size_of::<F>() == 8 {\n            self.absolute = Self::check_absolute(tolerance);\n        }\n        self\n    }\n\n    /// Checks if `x` and `y` are approximately equal given the tolerance.\n    pub fn approx_eq(&self, x: F, y: F) -> bool {\n        // See the accepted answer here\n        // https://stackoverflow.com/questions/4915462/how-should-i-do-floating-point-comparison\n\n        // This also handles the case where both a and b are infinity so that we don't need\n        // to manage it in the rest of the function.\n        if x == y {\n            return true;\n        }\n\n        let diff = (x - y).abs();\n        let max = F::max(x.abs(), y.abs());\n\n        diff < self.absolute.max(self.relative * max)\n    }\n\n    fn check_relative<FF: ToPrimitive>(tolerance: FF) -> F {\n        let tolerance = F::from(tolerance).unwrap();\n        assert!(tolerance <= F::one());\n        tolerance\n    }\n\n    fn check_absolute<FF: ToPrimitive>(tolerance: FF) -> F {\n        let tolerance = F::from(tolerance).unwrap();\n        assert!(tolerance >= F::zero());\n        tolerance\n    }\n}\n\nimpl TensorData {\n    /// Asserts the data is equal to another data.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The other data.\n    /// * `strict` - If true, the data types must the be same.\n    ///   Otherwise, the comparison is done in the current data type.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the data is not equal.\n    #[track_caller]\n    pub fn assert_eq(&self, other: &Self, strict: bool) {\n        if strict {\n            assert_eq!(\n                self.dtype, other.dtype,\n                \"Data types differ ({:?} != {:?})\",\n                self.dtype, other.dtype\n            );\n        }\n\n        match self.dtype {\n            DType::F64 => self.assert_eq_elem::<f64>(other),\n            DType::F32 | DType::Flex32 => self.assert_eq_elem::<f32>(other),\n            DType::F16 => self.assert_eq_elem::<f16>(other),\n            DType::BF16 => self.assert_eq_elem::<bf16>(other),\n            DType::I64 => self.assert_eq_elem::<i64>(other),\n            DType::I32 => self.assert_eq_elem::<i32>(other),\n            DType::I16 => self.assert_eq_elem::<i16>(other),\n            DType::I8 => self.assert_eq_elem::<i8>(other),\n            DType::U64 => self.assert_eq_elem::<u64>(other),\n            DType::U32 => self.assert_eq_elem::<u32>(other),\n            DType::U16 => self.assert_eq_elem::<u16>(other),\n            DType::U8 => self.assert_eq_elem::<u8>(other),\n            DType::Bool(BoolStore::Native) => self.assert_eq_elem::<bool>(other),\n            DType::Bool(BoolStore::U8) => self.assert_eq_elem::<u8>(other),\n            DType::Bool(BoolStore::U32) => self.assert_eq_elem::<u32>(other),\n            DType::QFloat(q) => {\n                // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality\n                let q_other = if let DType::QFloat(q_other) = other.dtype {\n                    q_other\n                } else {\n                    panic!(\"Quantized data differs from other not quantized data\")\n                };\n\n                // Data equality mostly depends on input quantization type, but we also check level\n                if q.value == q_other.value && q.level == q_other.level {\n                    self.assert_eq_elem::<i8>(other)\n                } else {\n                    panic!(\"Quantization schemes differ ({q:?} != {q_other:?})\")\n                }\n            }\n        }\n    }\n\n    #[track_caller]\n    fn assert_eq_elem<E: Element>(&self, other: &Self) {\n        let mut message = String::new();\n        if self.shape != other.shape {\n            message += format!(\n                \"\\n  => Shape is different: {:?} != {:?}\",\n                self.shape, other.shape\n            )\n            .as_str();\n        }\n\n        let mut num_diff = 0;\n        let max_num_diff = 5;\n        for (i, (a, b)) in self.iter::<E>().zip(other.iter::<E>()).enumerate() {\n            if !a.eq(&b) {\n                // Only print the first 5 different values.\n                if num_diff < max_num_diff {\n                    message += format!(\"\\n  => Position {i}: {a} != {b}\").as_str();\n                }\n                num_diff += 1;\n            }\n        }\n\n        if num_diff >= max_num_diff {\n            message += format!(\"\\n{} more errors...\", num_diff - max_num_diff).as_str();\n        }\n\n        if !message.is_empty() {\n            panic!(\"Tensors are not eq:{message}\");\n        }\n    }\n\n    /// Asserts the data is approximately equal to another data.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The other data.\n    /// * `tolerance` - The tolerance of the comparison.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the data is not approximately equal.\n    #[track_caller]\n    pub fn assert_approx_eq<F: Float + Element>(&self, other: &Self, tolerance: Tolerance<F>) {\n        let mut message = String::new();\n        if self.shape != other.shape {\n            message += format!(\n                \"\\n  => Shape is different: {:?} != {:?}\",\n                self.shape, other.shape\n            )\n            .as_str();\n        }\n\n        let iter = self.iter::<F>().zip(other.iter::<F>());\n\n        let mut num_diff = 0;\n        let max_num_diff = 5;\n\n        for (i, (a, b)) in iter.enumerate() {\n            //if they are both nan, then they are equally nan\n            let both_nan = a.is_nan() && b.is_nan();\n            //this works for both infinities\n            let both_inf =\n                a.is_infinite() && b.is_infinite() && ((a > F::zero()) == (b > F::zero()));\n\n            if both_nan || both_inf {\n                continue;\n            }\n\n            if !tolerance.approx_eq(F::from(a).unwrap(), F::from(b).unwrap()) {\n                // Only print the first 5 different values.\n                if num_diff < max_num_diff {\n                    let diff_abs = ToPrimitive::to_f64(&(a - b).abs()).unwrap();\n                    let max = F::max(a.abs(), b.abs());\n                    let diff_rel = diff_abs / ToPrimitive::to_f64(&max).unwrap();\n\n                    let tol_rel = ToPrimitive::to_f64(&tolerance.relative).unwrap();\n                    let tol_abs = ToPrimitive::to_f64(&tolerance.absolute).unwrap();\n\n                    message += format!(\n                        \"\\n  => Position {i}: {a} != {b}\\n     diff (rel = {diff_rel:+.2e}, abs = {diff_abs:+.2e}), tol (rel = {tol_rel:+.2e}, abs = {tol_abs:+.2e})\"\n                    )\n                    .as_str();\n                }\n                num_diff += 1;\n            }\n        }\n\n        if num_diff >= max_num_diff {\n            message += format!(\"\\n{} more errors...\", num_diff - 5).as_str();\n        }\n\n        if !message.is_empty() {\n            panic!(\"Tensors are not approx eq:{message}\");\n        }\n    }\n\n    /// Asserts each value is within a given range.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range.\n    ///\n    /// # Panics\n    ///\n    /// If any value is not within the half-open range bounded inclusively below\n    /// and exclusively above (`start..end`).\n    pub fn assert_within_range<E: ElementOrdered>(&self, range: core::ops::Range<E>) {\n        for elem in self.iter::<E>() {\n            if elem.cmp(&range.start).is_lt() || elem.cmp(&range.end).is_ge() {\n                panic!(\"Element ({elem:?}) is not within range {range:?}\");\n            }\n        }\n    }\n\n    /// Asserts each value is within a given inclusive range.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range.\n    ///\n    /// # Panics\n    ///\n    /// If any value is not within the half-open range bounded inclusively (`start..=end`).\n    pub fn assert_within_range_inclusive<E: ElementOrdered>(\n        &self,\n        range: core::ops::RangeInclusive<E>,\n    ) {\n        let start = range.start();\n        let end = range.end();\n\n        for elem in self.iter::<E>() {\n            if elem.cmp(start).is_lt() || elem.cmp(end).is_gt() {\n                panic!(\"Element ({elem:?}) is not within range {range:?}\");\n            }\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn should_assert_appox_eq_limit() {\n        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);\n        let data2 = TensorData::from([[3.03, 5.0, 6.0]]);\n\n        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(3e-2));\n        data1.assert_approx_eq::<f16>(&data2, Tolerance::absolute(3e-2));\n    }\n\n    #[test]\n    #[should_panic]\n    fn should_assert_approx_eq_above_limit() {\n        let data1 = TensorData::from([[3.0, 5.0, 6.0]]);\n        let data2 = TensorData::from([[3.031, 5.0, 6.0]]);\n\n        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));\n    }\n\n    #[test]\n    #[should_panic]\n    fn should_assert_approx_eq_check_shape() {\n        let data1 = TensorData::from([[3.0, 5.0, 6.0, 7.0]]);\n        let data2 = TensorData::from([[3.0, 5.0, 6.0]]);\n\n        data1.assert_approx_eq::<f32>(&data2, Tolerance::absolute(1e-2));\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/data/mod.rs",
    "content": "mod compare;\nmod tensor;\n\npub use compare::*;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-backend/src/data/tensor.rs",
    "content": "use core::f32;\n\nuse alloc::boxed::Box;\nuse alloc::format;\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse bytemuck::{AnyBitPattern, CheckedBitPattern, Zeroable, cast_mut, checked::CheckedCastError};\nuse rand::Rng;\nuse thiserror::Error;\n\nuse crate::Scalar;\nuse crate::distribution::Distribution;\nuse crate::element::{Element, ElementConversion};\nuse burn_std::tensor::DType;\nuse burn_std::{\n    BoolStore, Bytes, QuantLevel, QuantMode, QuantScheme, QuantValue, QuantizedBytes, Shape, bf16,\n    f16,\n};\n\nuse serde::{Deserialize, Serialize};\n\n/// Data structure for tensors.\n#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]\npub struct TensorData {\n    /// The values of the tensor (as bytes).\n    pub bytes: Bytes,\n\n    /// The shape of the tensor.\n    #[serde(with = \"shape_inner\")]\n    pub shape: Shape,\n\n    /// The data type of the tensor.\n    pub dtype: DType,\n}\n\n// For backward compatibility with shape `Vec<usize>`\nmod shape_inner {\n    use burn_std::SmallVec;\n\n    use super::*;\n\n    pub fn serialize<S: serde::Serializer>(\n        shape: &Shape,\n        serializer: S,\n    ) -> Result<S::Ok, S::Error> {\n        shape.as_slice().serialize(serializer)\n    }\n\n    pub fn deserialize<'de, D: serde::Deserializer<'de>>(\n        deserializer: D,\n    ) -> Result<Shape, D::Error> {\n        let dims = SmallVec::<[usize; _]>::deserialize(deserializer)?;\n        Ok(Shape::new_raw(dims))\n    }\n}\n\nimpl TensorData {\n    /// Creates a new tensor data structure.\n    pub fn new<E: Element, S: Into<Shape>>(value: Vec<E>, shape: S) -> Self {\n        // Ensure shape is valid\n        let shape = shape.into();\n        Self::check_data_len(&value, &shape);\n\n        Self {\n            bytes: Bytes::from_elems(value),\n            shape,\n            dtype: E::dtype(),\n        }\n    }\n\n    /// Creates a new quantized tensor data structure.\n    pub fn quantized<E: Element, S: Into<Shape>>(\n        value: Vec<E>,\n        shape: S,\n        scheme: QuantScheme,\n        qparams: &[f32],\n    ) -> Self {\n        let shape = shape.into();\n        Self::check_data_len(&value, &shape);\n\n        let q_bytes = QuantizedBytes::new(value, scheme, qparams);\n\n        Self {\n            bytes: q_bytes.bytes,\n            shape,\n            dtype: DType::QFloat(q_bytes.scheme),\n        }\n    }\n\n    /// Creates a new tensor data structure from raw bytes.\n    pub fn from_bytes<S: Into<Shape>>(bytes: Bytes, shape: S, dtype: DType) -> Self {\n        Self {\n            bytes,\n            shape: shape.into(),\n            dtype,\n        }\n    }\n\n    /// Creates a new tensor data structure from raw bytes stored in a vector.\n    ///\n    /// Prefer [`TensorData::new`] or [`TensorData::quantized`] over this method unless you are\n    /// certain that the bytes representation is valid.\n    pub fn from_bytes_vec<S: Into<Shape>>(bytes: Vec<u8>, shape: S, dtype: DType) -> Self {\n        Self {\n            bytes: Bytes::from_bytes_vec(bytes),\n            shape: shape.into(),\n            dtype,\n        }\n    }\n\n    // Check that the input vector contains a correct number of elements\n    fn check_data_len<E: Element>(data: &[E], shape: &Shape) {\n        let expected_data_len = Self::numel(shape);\n        let num_data = data.len();\n        assert_eq!(\n            expected_data_len, num_data,\n            \"Shape {shape:?} is invalid for input of size {num_data:?}\",\n        );\n    }\n\n    /// Returns the immutable slice view of the tensor data.\n    pub fn as_slice<E: Element>(&self) -> Result<&[E], DataError> {\n        if self.matches_target_dtype::<E>() {\n            match E::dtype() {\n                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying\n                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool\n                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.\n                DType::Bool(BoolStore::Native) => {\n                    let slice = bytemuck::checked::try_cast_slice::<_, u8>(&self.bytes)\n                        .map_err(DataError::CastError)?;\n                    Ok(unsafe { core::mem::transmute::<&[u8], &[E]>(slice) })\n                }\n                _ => bytemuck::checked::try_cast_slice(&self.bytes).map_err(DataError::CastError),\n            }\n        } else {\n            Err(DataError::TypeMismatch(format!(\n                \"Invalid target element type (expected {:?}, got {:?})\",\n                self.dtype,\n                E::dtype()\n            )))\n        }\n    }\n\n    /// Returns the mutable slice view of the tensor data.\n    ///\n    /// # Panics\n    /// If the target element type is different from the stored element type.\n    pub fn as_mut_slice<E: Element>(&mut self) -> Result<&mut [E], DataError> {\n        if self.matches_target_dtype::<E>() {\n            match E::dtype() {\n                // The only way to create a bool `TensorData` with invalid values is by unsafely modifying\n                // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool\n                // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.\n                DType::Bool(BoolStore::Native) => {\n                    let slice = bytemuck::checked::try_cast_slice_mut::<_, u8>(&mut self.bytes)\n                        .map_err(DataError::CastError)?;\n                    Ok(unsafe { core::mem::transmute::<&mut [u8], &mut [E]>(slice) })\n                }\n                _ => bytemuck::checked::try_cast_slice_mut(&mut self.bytes)\n                    .map_err(DataError::CastError),\n            }\n        } else {\n            Err(DataError::TypeMismatch(format!(\n                \"Invalid target element type (expected {:?}, got {:?})\",\n                self.dtype,\n                E::dtype()\n            )))\n        }\n    }\n\n    /// Returns the tensor data as a vector of scalar values.\n    pub fn to_vec<E: Element>(&self) -> Result<Vec<E>, DataError> {\n        Ok(self.as_slice()?.to_vec())\n    }\n\n    /// Returns the tensor data as a vector of scalar values.\n    pub fn into_vec<E: Element>(self) -> Result<Vec<E>, DataError> {\n        // This means we cannot call `into_vec` for QFloat\n        if !self.matches_target_dtype::<E>() {\n            return Err(DataError::TypeMismatch(format!(\n                \"Invalid target element type (expected {:?}, got {:?})\",\n                self.dtype,\n                E::dtype()\n            )));\n        }\n\n        match E::dtype() {\n            // The only way to create a bool `TensorData` with invalid values is by unsafely modifying\n            // the dtype. This should be considered unsafe to begin with, so we unsafely cast bool\n            // to u8 to skip bit validation. Validation iterates through the entire vector, so it's slow.\n            DType::Bool(BoolStore::Native) => {\n                let vec = self.into_vec_unchecked::<u8>()?;\n                Ok(unsafe { core::mem::transmute::<Vec<u8>, Vec<E>>(vec) })\n            }\n            _ => self.into_vec_unchecked(),\n        }\n    }\n\n    /// Returns the tensor data as a vector of scalar values. Does not check dtype.\n    fn into_vec_unchecked<E: Element>(self) -> Result<Vec<E>, DataError> {\n        let mut me = self;\n        me.bytes = match me.bytes.try_into_vec::<E>() {\n            Ok(elems) => return Ok(elems),\n            Err(bytes) => bytes,\n        };\n\n        // The bytes might have been deserialized and allocated with a different align.\n        // In that case, we have to memcopy the data into a new vector, more suitably allocated\n        Ok(bytemuck::checked::try_cast_slice(me.as_bytes())\n            .map_err(DataError::CastError)?\n            .to_vec())\n    }\n\n    fn matches_target_dtype<E: Element>(&self) -> bool {\n        let target_dtype = E::dtype();\n        match self.dtype {\n            DType::Bool(BoolStore::U8) => {\n                matches!(target_dtype, DType::U8 | DType::Bool(BoolStore::U8))\n            }\n            DType::Bool(BoolStore::U32) => {\n                matches!(target_dtype, DType::U32 | DType::Bool(BoolStore::U32))\n            }\n            dtype => dtype == target_dtype,\n        }\n    }\n\n    /// Returns an iterator over the values of the tensor data.\n    pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {\n        if E::dtype() == self.dtype {\n            Box::new(bytemuck::checked::cast_slice(&self.bytes).iter().copied())\n        } else {\n            match self.dtype {\n                DType::I8 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &i8| e.elem::<E>()),\n                ),\n                DType::I16 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &i16| e.elem::<E>()),\n                ),\n                DType::I32 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &i32| e.elem::<E>()),\n                ),\n                DType::I64 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &i64| e.elem::<E>()),\n                ),\n                DType::U8 => Box::new(self.bytes.iter().map(|e| e.elem::<E>())),\n                DType::U16 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &u16| e.elem::<E>()),\n                ),\n                DType::U32 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &u32| e.elem::<E>()),\n                ),\n                DType::U64 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &u64| e.elem::<E>()),\n                ),\n                DType::BF16 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &bf16| e.elem::<E>()),\n                ),\n                DType::F16 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &f16| e.elem::<E>()),\n                ),\n                DType::F32 | DType::Flex32 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &f32| e.elem::<E>()),\n                ),\n                DType::F64 => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &f64| e.elem::<E>()),\n                ),\n                // bool is a byte value equal to either 0 or 1\n                DType::Bool(BoolStore::Native) | DType::Bool(BoolStore::U8) => {\n                    Box::new(self.bytes.iter().map(|e| e.elem::<E>()))\n                }\n                DType::Bool(BoolStore::U32) => Box::new(\n                    bytemuck::checked::cast_slice(&self.bytes)\n                        .iter()\n                        .map(|e: &u32| e.elem::<E>()),\n                ),\n                DType::QFloat(scheme) => match scheme {\n                    QuantScheme {\n                        level: QuantLevel::Tensor | QuantLevel::Block(_),\n                        mode: QuantMode::Symmetric,\n                        value:\n                            QuantValue::Q8F\n                            | QuantValue::Q8S\n                            // Represent sub-byte values as i8\n                            | QuantValue::Q4F\n                            | QuantValue::Q4S\n                            | QuantValue::Q2F\n                            | QuantValue::Q2S,\n                        ..\n                    } => {\n                        // Quantized int8 values\n                        let q_bytes = QuantizedBytes {\n                            bytes: self.bytes.clone(),\n                            scheme,\n                            num_elements: self.num_elements(),\n                        };\n                        let (values, _) = q_bytes.into_vec_i8();\n\n                        Box::new(\n                            values\n                                .iter()\n                                .map(|e: &i8| e.elem::<E>())\n                                .collect::<Vec<_>>()\n                                .into_iter(),\n                        )\n                    }\n                    QuantScheme {\n                        level: QuantLevel::Tensor | QuantLevel::Block(_),\n                        mode: QuantMode::Symmetric,\n                        value:\n                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,\n                        ..\n                    } => {\n                        unimplemented!(\"Not yet implemented for iteration\");\n                    }\n                },\n            }\n        }\n    }\n\n    /// Returns the rank (the number of dimensions).\n    pub fn rank(&self) -> usize {\n        self.shape.len()\n    }\n\n    /// Returns the total number of elements of the tensor data.\n    pub fn num_elements(&self) -> usize {\n        Self::numel(&self.shape)\n    }\n\n    fn numel(shape: &[usize]) -> usize {\n        shape.iter().product()\n    }\n\n    /// Populates the data with random values.\n    pub fn random<E: Element, R: Rng, S: Into<Shape>>(\n        shape: S,\n        distribution: Distribution,\n        rng: &mut R,\n    ) -> Self {\n        let shape = shape.into();\n        let num_elements = Self::numel(&shape);\n        let mut data = Vec::with_capacity(num_elements);\n\n        for _ in 0..num_elements {\n            data.push(E::random(distribution, rng));\n        }\n\n        TensorData::new(data, shape)\n    }\n\n    /// Populates the data with zeros.\n    pub fn zeros<E: Element, S: Into<Shape>>(shape: S) -> TensorData {\n        let shape = shape.into();\n        let num_elements = Self::numel(&shape);\n        let mut data = Vec::<E>::with_capacity(num_elements);\n\n        for _ in 0..num_elements {\n            data.push(0.elem());\n        }\n\n        TensorData::new(data, shape)\n    }\n\n    /// Populates the data with ones.\n    pub fn ones<E: Element, S: Into<Shape>>(shape: S) -> TensorData {\n        let shape = shape.into();\n        let num_elements = Self::numel(&shape);\n        let mut data = Vec::<E>::with_capacity(num_elements);\n\n        for _ in 0..num_elements {\n            data.push(1.elem());\n        }\n\n        TensorData::new(data, shape)\n    }\n\n    /// Populates the data with the given value\n    pub fn full<E: Element, S: Into<Shape>>(shape: S, fill_value: E) -> TensorData {\n        let shape = shape.into();\n        let num_elements = Self::numel(&shape);\n        let mut data = Vec::<E>::with_capacity(num_elements);\n        for _ in 0..num_elements {\n            data.push(fill_value)\n        }\n\n        TensorData::new(data, shape)\n    }\n\n    /// Populates the data with the given value\n    pub fn full_dtype<E: Into<Scalar>, S: Into<Shape>>(\n        shape: S,\n        fill_value: E,\n        dtype: DType,\n    ) -> TensorData {\n        let fill_value = fill_value.into();\n        match dtype {\n            DType::F64 => Self::full::<f64, _>(shape, fill_value.elem()),\n            DType::F32 | DType::Flex32 => Self::full::<f32, _>(shape, fill_value.elem()),\n            DType::F16 => Self::full::<f16, _>(shape, fill_value.elem()),\n            DType::BF16 => Self::full::<bf16, _>(shape, fill_value.elem()),\n            DType::I64 => Self::full::<i64, _>(shape, fill_value.elem()),\n            DType::I32 => Self::full::<i32, _>(shape, fill_value.elem()),\n            DType::I16 => Self::full::<i16, _>(shape, fill_value.elem()),\n            DType::I8 => Self::full::<i8, _>(shape, fill_value.elem()),\n            DType::U64 => Self::full::<u64, _>(shape, fill_value.elem()),\n            DType::U32 => Self::full::<u32, _>(shape, fill_value.elem()),\n            DType::U16 => Self::full::<u16, _>(shape, fill_value.elem()),\n            DType::U8 => Self::full::<u8, _>(shape, fill_value.elem()),\n            DType::Bool(BoolStore::Native) => Self::full::<bool, _>(shape, fill_value.elem()),\n            DType::Bool(BoolStore::U8) => {\n                Self::full::<u8, _>(shape, fill_value.elem()).into_bool_u8()\n            }\n            DType::Bool(BoolStore::U32) => {\n                Self::full::<u32, _>(shape, fill_value.elem()).into_bool_u32()\n            }\n            DType::QFloat(_) => unreachable!(),\n        }\n    }\n\n    // Unchecked, used to overwrite the dtype\n    fn into_bool_u8(mut self) -> Self {\n        self.dtype = DType::Bool(BoolStore::U8);\n        self\n    }\n\n    // Unchecked, used to overwrite the dtype\n    fn into_bool_u32(mut self) -> Self {\n        self.dtype = DType::Bool(BoolStore::U32);\n        self\n    }\n\n    /// Converts the data to a different element type.\n    pub fn convert<E: Element>(self) -> Self {\n        self.convert_dtype(E::dtype())\n    }\n\n    /// Converts the data to a different element type.\n    pub fn convert_dtype(self, dtype: DType) -> Self {\n        if dtype == self.dtype {\n            self\n        } else if dtype.size() == self.dtype.size()\n            && !matches!(\n                self.dtype,\n                DType::Bool(BoolStore::Native) | DType::QFloat(_)\n            )\n            && !matches!(dtype, DType::Bool(BoolStore::Native) | DType::QFloat(_))\n        {\n            match self.dtype {\n                DType::F64 => self.convert_inplace_dtype::<f64>(dtype),\n                DType::F32 | DType::Flex32 => self.convert_inplace_dtype::<f32>(dtype),\n                DType::F16 => self.convert_inplace_dtype::<f16>(dtype),\n                DType::BF16 => self.convert_inplace_dtype::<bf16>(dtype),\n                DType::I64 => self.convert_inplace_dtype::<i64>(dtype),\n                DType::I32 => self.convert_inplace_dtype::<i32>(dtype),\n                DType::I16 => self.convert_inplace_dtype::<i16>(dtype),\n                DType::I8 => self.convert_inplace_dtype::<i8>(dtype),\n                DType::U64 => self.convert_inplace_dtype::<u64>(dtype),\n                DType::U32 => self.convert_inplace_dtype::<u32>(dtype),\n                DType::U16 => self.convert_inplace_dtype::<u16>(dtype),\n                DType::U8 => self.convert_inplace_dtype::<u8>(dtype),\n                DType::Bool(BoolStore::U8) => self.convert_inplace_dtype::<u8>(dtype),\n                DType::Bool(BoolStore::U32) => self.convert_inplace_dtype::<u32>(dtype),\n                DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),\n            }\n        } else {\n            match self.dtype {\n                DType::F64 => self.convert_clone_dtype::<f64>(dtype),\n                DType::F32 | DType::Flex32 => self.convert_clone_dtype::<f32>(dtype),\n                DType::F16 => self.convert_clone_dtype::<f16>(dtype),\n                DType::BF16 => self.convert_clone_dtype::<bf16>(dtype),\n                DType::I64 => self.convert_clone_dtype::<i64>(dtype),\n                DType::I32 => self.convert_clone_dtype::<i32>(dtype),\n                DType::I16 => self.convert_clone_dtype::<i16>(dtype),\n                DType::I8 => self.convert_clone_dtype::<i8>(dtype),\n                DType::U64 => self.convert_clone_dtype::<u64>(dtype),\n                DType::U32 => self.convert_clone_dtype::<u32>(dtype),\n                DType::U16 => self.convert_clone_dtype::<u16>(dtype),\n                DType::U8 => self.convert_clone_dtype::<u8>(dtype),\n                DType::Bool(BoolStore::Native) => self.convert_clone_dtype::<bool>(dtype),\n                DType::Bool(BoolStore::U8) => self.convert_clone_dtype::<u8>(dtype),\n                DType::Bool(BoolStore::U32) => self.convert_clone_dtype::<u32>(dtype),\n                DType::QFloat(_) => unreachable!(),\n            }\n        }\n    }\n\n    fn convert_inplace_dtype<Current: Element + AnyBitPattern>(self, dtype: DType) -> Self {\n        match dtype {\n            DType::F64 => self.convert_inplace::<Current, f64>(),\n            DType::F32 | DType::Flex32 => self.convert_inplace::<Current, f32>(),\n            DType::F16 => self.convert_inplace::<Current, f16>(),\n            DType::BF16 => self.convert_inplace::<Current, bf16>(),\n            DType::I64 => self.convert_inplace::<Current, i64>(),\n            DType::I32 => self.convert_inplace::<Current, i32>(),\n            DType::I16 => self.convert_inplace::<Current, i16>(),\n            DType::I8 => self.convert_inplace::<Current, i8>(),\n            DType::U64 => self.convert_inplace::<Current, u64>(),\n            DType::U32 => self.convert_inplace::<Current, u32>(),\n            DType::U16 => self.convert_inplace::<Current, u16>(),\n            DType::U8 => self.convert_inplace::<Current, u8>(),\n            DType::Bool(BoolStore::U8) => self.convert_inplace::<Current, u8>().into_bool_u8(),\n            DType::Bool(BoolStore::U32) => self.convert_inplace::<Current, u32>().into_bool_u32(),\n            DType::Bool(BoolStore::Native) | DType::QFloat(_) => unreachable!(),\n        }\n    }\n\n    fn convert_inplace<Current: Element + AnyBitPattern, Target: Element + AnyBitPattern>(\n        mut self,\n    ) -> Self {\n        for x in bytemuck::cast_slice_mut::<_, Current>(&mut self.bytes) {\n            let t: Target = x.elem();\n            let x = cast_mut::<_, Target>(x);\n            *x = t;\n        }\n\n        self.dtype = Target::dtype();\n\n        self\n    }\n\n    fn convert_clone_dtype<Current: Element + CheckedBitPattern>(self, dtype: DType) -> Self {\n        match dtype {\n            DType::F64 => self.convert_clone::<Current, f64>(),\n            DType::F32 | DType::Flex32 => self.convert_clone::<Current, f32>(),\n            DType::F16 => self.convert_clone::<Current, f16>(),\n            DType::BF16 => self.convert_clone::<Current, bf16>(),\n            DType::I64 => self.convert_clone::<Current, i64>(),\n            DType::I32 => self.convert_clone::<Current, i32>(),\n            DType::I16 => self.convert_clone::<Current, i16>(),\n            DType::I8 => self.convert_clone::<Current, i8>(),\n            DType::U64 => self.convert_clone::<Current, u64>(),\n            DType::U32 => self.convert_clone::<Current, u32>(),\n            DType::U16 => self.convert_clone::<Current, u16>(),\n            DType::U8 => self.convert_clone::<Current, u8>(),\n            DType::Bool(BoolStore::Native) => self.convert_clone::<Current, bool>(),\n            DType::Bool(BoolStore::U8) => self.convert_clone::<Current, u8>().into_bool_u8(),\n            DType::Bool(BoolStore::U32) => self.convert_clone::<Current, u32>().into_bool_u32(),\n            DType::QFloat(_) => unreachable!(),\n        }\n    }\n\n    fn convert_clone<Current: Element + CheckedBitPattern, Target: Element + Zeroable>(\n        self,\n    ) -> Self {\n        let this = bytemuck::checked::cast_slice::<_, Current>(&self.bytes);\n        let mut out: Vec<Target> = ::alloc::vec![Zeroable::zeroed(); self.num_elements()];\n\n        for (x, out) in this.iter().zip(&mut out) {\n            *out = x.elem();\n        }\n\n        Self::new(out, self.shape)\n    }\n\n    /// Returns the data as a slice of bytes.\n    pub fn as_bytes(&self) -> &[u8] {\n        &self.bytes\n    }\n\n    /// Returns the bytes representation of the data.\n    pub fn into_bytes(self) -> Bytes {\n        self.bytes\n    }\n}\n\nimpl<E: Element, const A: usize> From<[E; A]> for TensorData {\n    fn from(elems: [E; A]) -> Self {\n        TensorData::new(elems.to_vec(), [A])\n    }\n}\n\nimpl<const A: usize> From<[usize; A]> for TensorData {\n    fn from(elems: [usize; A]) -> Self {\n        TensorData::new(elems.iter().map(|&e| e as i64).collect(), [A])\n    }\n}\n\nimpl From<&[usize]> for TensorData {\n    fn from(elems: &[usize]) -> Self {\n        let mut data = Vec::with_capacity(elems.len());\n        for elem in elems.iter() {\n            data.push(*elem as i64);\n        }\n\n        TensorData::new(data, [elems.len()])\n    }\n}\n\nimpl<E: Element> From<&[E]> for TensorData {\n    fn from(elems: &[E]) -> Self {\n        let mut data = Vec::with_capacity(elems.len());\n        for elem in elems.iter() {\n            data.push(*elem);\n        }\n\n        TensorData::new(data, [elems.len()])\n    }\n}\n\nimpl<E: Element, const A: usize, const B: usize> From<[[E; B]; A]> for TensorData {\n    fn from(elems: [[E; B]; A]) -> Self {\n        let mut data = Vec::with_capacity(A * B);\n        for elem in elems.into_iter().take(A) {\n            for elem in elem.into_iter().take(B) {\n                data.push(elem);\n            }\n        }\n\n        TensorData::new(data, [A, B])\n    }\n}\n\nimpl<E: Element, const A: usize, const B: usize, const C: usize> From<[[[E; C]; B]; A]>\n    for TensorData\n{\n    fn from(elems: [[[E; C]; B]; A]) -> Self {\n        let mut data = Vec::with_capacity(A * B * C);\n\n        for elem in elems.into_iter().take(A) {\n            for elem in elem.into_iter().take(B) {\n                for elem in elem.into_iter().take(C) {\n                    data.push(elem);\n                }\n            }\n        }\n\n        TensorData::new(data, [A, B, C])\n    }\n}\n\nimpl<E: Element, const A: usize, const B: usize, const C: usize, const D: usize>\n    From<[[[[E; D]; C]; B]; A]> for TensorData\n{\n    fn from(elems: [[[[E; D]; C]; B]; A]) -> Self {\n        let mut data = Vec::with_capacity(A * B * C * D);\n\n        for elem in elems.into_iter().take(A) {\n            for elem in elem.into_iter().take(B) {\n                for elem in elem.into_iter().take(C) {\n                    for elem in elem.into_iter().take(D) {\n                        data.push(elem);\n                    }\n                }\n            }\n        }\n\n        TensorData::new(data, [A, B, C, D])\n    }\n}\n\nimpl<Elem: Element, const A: usize, const B: usize, const C: usize, const D: usize, const E: usize>\n    From<[[[[[Elem; E]; D]; C]; B]; A]> for TensorData\n{\n    fn from(elems: [[[[[Elem; E]; D]; C]; B]; A]) -> Self {\n        let mut data = Vec::with_capacity(A * B * C * D * E);\n\n        for elem in elems.into_iter().take(A) {\n            for elem in elem.into_iter().take(B) {\n                for elem in elem.into_iter().take(C) {\n                    for elem in elem.into_iter().take(D) {\n                        for elem in elem.into_iter().take(E) {\n                            data.push(elem);\n                        }\n                    }\n                }\n            }\n        }\n\n        TensorData::new(data, [A, B, C, D, E])\n    }\n}\nimpl core::fmt::Display for TensorData {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        let fmt = match self.dtype {\n            DType::F64 => format!(\"{:?}\", self.as_slice::<f64>().unwrap()),\n            DType::F32 | DType::Flex32 => format!(\"{:?}\", self.as_slice::<f32>().unwrap()),\n            DType::F16 => format!(\"{:?}\", self.as_slice::<f16>().unwrap()),\n            DType::BF16 => format!(\"{:?}\", self.as_slice::<bf16>().unwrap()),\n            DType::I64 => format!(\"{:?}\", self.as_slice::<i64>().unwrap()),\n            DType::I32 => format!(\"{:?}\", self.as_slice::<i32>().unwrap()),\n            DType::I16 => format!(\"{:?}\", self.as_slice::<i16>().unwrap()),\n            DType::I8 => format!(\"{:?}\", self.as_slice::<i8>().unwrap()),\n            DType::U64 => format!(\"{:?}\", self.as_slice::<u64>().unwrap()),\n            DType::U32 => format!(\"{:?}\", self.as_slice::<u32>().unwrap()),\n            DType::U16 => format!(\"{:?}\", self.as_slice::<u16>().unwrap()),\n            DType::U8 => format!(\"{:?}\", self.as_slice::<u8>().unwrap()),\n            DType::Bool(BoolStore::Native) => format!(\"{:?}\", self.as_slice::<bool>().unwrap()),\n            DType::Bool(BoolStore::U8) => format!(\"{:?}\", self.as_slice::<u8>().unwrap()),\n            DType::Bool(BoolStore::U32) => format!(\"{:?}\", self.as_slice::<u32>().unwrap()),\n            DType::QFloat(scheme) => match scheme {\n                QuantScheme {\n                    level: QuantLevel::Tensor | QuantLevel::Block(_),\n                    mode: QuantMode::Symmetric,\n                    value:\n                        QuantValue::Q8F\n                        | QuantValue::Q8S\n                        // Display sub-byte values as i8\n                        | QuantValue::Q4F\n                        | QuantValue::Q4S\n                        | QuantValue::Q2F\n                        | QuantValue::Q2S,\n                    ..\n                } => {\n                    format!(\"{:?} {scheme:?}\", self.iter::<i8>().collect::<Vec<_>>())\n                },\n                QuantScheme {\n                        level: QuantLevel::Tensor | QuantLevel::Block(_),\n                        mode: QuantMode::Symmetric,\n                        value:\n                            QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1,\n                        ..\n                    } => {\n                        unimplemented!(\"Can't format yet\");\n                    }\n            },\n        };\n        f.write_str(fmt.as_str())\n    }\n}\n\n/// The things that can go wrong when manipulating tensor data.\n#[derive(Debug, Error)]\npub enum DataError {\n    /// Failed to cast the values to a specified element type.\n    #[error(\"Failed to cast values to the specified element type.\\nError:\\n  {0}\")]\n    CastError(CheckedCastError),\n    /// Invalid target element type.\n    #[error(\"{0}\")]\n    TypeMismatch(String),\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use alloc::vec;\n    use burn_std::shape;\n    use rand::{\n        SeedableRng,\n        rngs::{StdRng, SysRng},\n    };\n\n    #[test]\n    fn should_have_rank() {\n        let shape = [3, 5, 6];\n        let data = TensorData::random::<f32, _, _>(\n            shape,\n            Distribution::Default,\n            &mut StdRng::try_from_rng(&mut SysRng).unwrap(),\n        );\n\n        assert_eq!(data.rank(), 3);\n    }\n\n    #[test]\n    fn into_vec_should_yield_same_value_as_iter() {\n        let shape = [3, 5, 6];\n        let data = TensorData::random::<f32, _, _>(\n            shape,\n            Distribution::Default,\n            &mut StdRng::try_from_rng(&mut SysRng).unwrap(),\n        );\n\n        let expected = data.iter::<f32>().collect::<Vec<f32>>();\n        let actual = data.into_vec::<f32>().unwrap();\n\n        assert_eq!(expected, actual);\n    }\n\n    #[test]\n    #[should_panic]\n    fn into_vec_should_assert_wrong_dtype() {\n        let shape = [3, 5, 6];\n        let data = TensorData::random::<f32, _, _>(\n            shape,\n            Distribution::Default,\n            &mut StdRng::try_from_rng(&mut SysRng).unwrap(),\n        );\n\n        data.into_vec::<i32>().unwrap();\n    }\n\n    #[test]\n    fn should_have_right_num_elements() {\n        let shape = [3, 5, 6];\n        let num_elements: usize = shape.iter().product();\n        let data = TensorData::random::<f32, _, _>(\n            shape,\n            Distribution::Default,\n            &mut StdRng::try_from_rng(&mut SysRng).unwrap(),\n        );\n\n        assert_eq!(num_elements, data.bytes.len() / 4); // f32 stored as u8s\n        assert_eq!(num_elements, data.as_slice::<f32>().unwrap().len());\n    }\n\n    #[test]\n    fn should_have_right_shape() {\n        let data = TensorData::from([[3.0, 5.0, 6.0]]);\n        assert_eq!(data.shape, shape![1, 3]);\n\n        let data = TensorData::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]);\n        assert_eq!(data.shape, shape![2, 3]);\n\n        let data = TensorData::from([3.0, 5.0, 6.0]);\n        assert_eq!(data.shape, shape![3]);\n    }\n\n    #[test]\n    fn should_convert_bytes_correctly() {\n        let mut vector: Vec<f32> = Vec::with_capacity(5);\n        vector.push(2.0);\n        vector.push(3.0);\n        let data1 = TensorData::new(vector, vec![2]);\n\n        let factor = core::mem::size_of::<f32>() / core::mem::size_of::<u8>();\n        assert_eq!(data1.bytes.len(), 2 * factor);\n        assert_eq!(data1.bytes.capacity(), 5 * factor);\n    }\n\n    #[test]\n    fn should_convert_bytes_correctly_inplace() {\n        fn test_precision<E: Element>() {\n            let data = TensorData::new((0..32).collect(), [32]);\n            for (i, val) in data\n                .clone()\n                .convert::<E>()\n                .into_vec::<E>()\n                .unwrap()\n                .into_iter()\n                .enumerate()\n            {\n                assert_eq!(i as u32, val.elem::<u32>())\n            }\n        }\n        test_precision::<f32>();\n        test_precision::<f16>();\n        test_precision::<i64>();\n        test_precision::<i32>();\n    }\n\n    macro_rules! test_dtypes {\n    ($test_name:ident, $($dtype:ty),*) => {\n        $(\n            paste::paste! {\n                #[test]\n                fn [<$test_name _ $dtype:snake>]() {\n                    let full_dtype = TensorData::full_dtype([2, 16], 4, <$dtype>::dtype());\n                    let full = TensorData::full::<$dtype, _>([2, 16], 4.elem());\n                    assert_eq!(full_dtype, full);\n                }\n            }\n        )*\n    };\n}\n\n    test_dtypes!(\n        should_create_with_dtype,\n        bool,\n        i8,\n        i16,\n        i32,\n        i64,\n        u8,\n        u16,\n        u32,\n        u64,\n        f16,\n        bf16,\n        f32,\n        f64\n    );\n\n    #[test]\n    fn should_serialize_deserialize_tensor_data() {\n        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);\n        assert_eq!(\n            data.as_bytes(),\n            [\n                0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192,\n                64\n            ]\n        );\n        let serialized = serde_json::to_string(&data).unwrap();\n        let deserialized: TensorData = serde_json::from_str(&serialized).unwrap();\n        assert_eq!(data, deserialized);\n    }\n\n    #[test]\n    fn should_deserialize_tensor_data_with_shape_inner() {\n        // TensorData `shape` was previously a Vec<usize>.\n        let serialized = r#\"{\n        \"bytes\": [0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64],\n        \"shape\": [2, 3],\n        \"dtype\": \"F32\"\n    }\"#;\n\n        let data: TensorData = serde_json::from_str(serialized).unwrap();\n        assert_eq!(data.shape, shape![2, 3]);\n        assert_eq!(\n            data.as_slice::<f32>().unwrap(),\n            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]\n        );\n    }\n\n    #[test]\n    fn should_serialize_shape_as_flat_array() {\n        // Ensure the new Shape serializes identically to how Vec<usize> used to,\n        // i.e. as a flat JSON array, not as an object like `{\"dims\": [2, 3]}`.\n        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);\n        let serialized = serde_json::to_string(&data).unwrap();\n        let json: serde_json::Value = serde_json::from_str(&serialized).unwrap();\n        assert_eq!(json[\"shape\"], serde_json::json!([2, 3]));\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/distribution.rs",
    "content": "//! Random value distributions used to initialize and populate tensor data.\n\nuse rand::{Rng, RngExt, distr::StandardUniform};\n\nuse super::element::{Element, ElementConversion};\n\n/// Distribution for random value of a tensor.\n#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]\npub enum Distribution {\n    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).\n    #[default]\n    Default,\n\n    /// Bernoulli distribution with the given probability.\n    Bernoulli(f64),\n\n    /// Uniform distribution `[low, high)`.\n    Uniform(f64, f64),\n\n    /// Normal distribution with the given mean and standard deviation.\n    Normal(f64, f64),\n}\n\n/// Distribution sampler for random value of a tensor.\n#[derive(new)]\npub struct DistributionSampler<'a, E, R>\nwhere\n    StandardUniform: rand::distr::Distribution<E>,\n    E: rand::distr::uniform::SampleUniform,\n    R: Rng,\n{\n    kind: DistributionSamplerKind<E>,\n    rng: &'a mut R,\n}\n\n/// Distribution sampler kind for random value of a tensor.\npub enum DistributionSamplerKind<E>\nwhere\n    StandardUniform: rand::distr::Distribution<E>,\n    E: rand::distr::uniform::SampleUniform,\n{\n    /// Standard distribution.\n    Standard(rand::distr::StandardUniform),\n\n    /// Uniform distribution.\n    Uniform(rand::distr::Uniform<E>),\n\n    /// Bernoulli distribution.\n    Bernoulli(rand::distr::Bernoulli),\n\n    /// Normal distribution.\n    Normal(rand_distr::Normal<f64>),\n}\n\nimpl<E, R> DistributionSampler<'_, E, R>\nwhere\n    StandardUniform: rand::distr::Distribution<E>,\n    E: rand::distr::uniform::SampleUniform,\n    E: Element,\n    R: Rng,\n{\n    /// Sames a random value from the distribution.\n    pub fn sample(&mut self) -> E {\n        match &self.kind {\n            DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),\n            DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),\n            DistributionSamplerKind::Bernoulli(distribution) => {\n                if self.rng.sample(distribution) {\n                    1.elem()\n                } else {\n                    0.elem()\n                }\n            }\n            DistributionSamplerKind::Normal(distribution) => self.rng.sample(distribution).elem(),\n        }\n    }\n}\n\nimpl Distribution {\n    /// Creates a new distribution sampler.\n    ///\n    /// # Arguments\n    ///\n    /// * `rng` - The random number generator.\n    ///\n    /// # Returns\n    ///\n    /// The distribution sampler.\n    pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>\n    where\n        R: Rng,\n        E: Element + rand::distr::uniform::SampleUniform,\n        StandardUniform: rand::distr::Distribution<E>,\n    {\n        let kind = match self {\n            Distribution::Default => {\n                DistributionSamplerKind::Standard(rand::distr::StandardUniform {})\n            }\n            Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(\n                rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),\n            ),\n            Distribution::Bernoulli(prob) => {\n                DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())\n            }\n            Distribution::Normal(mean, std) => {\n                DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())\n            }\n        };\n\n        DistributionSampler::new(kind, rng)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_distribution_default() {\n        let dist: Distribution = Default::default();\n\n        assert_eq!(dist, Distribution::Default);\n        assert_eq!(Distribution::default(), Distribution::Default);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/element/base.rs",
    "content": "use core::cmp::Ordering;\nuse rand::Rng;\n\nuse crate::distribution::Distribution;\nuse burn_std::{BoolStore, DType, bf16, f16};\n\n#[cfg(feature = \"cubecl\")]\nuse burn_std::flex32;\n\nuse super::cast::ToElement;\n\n/// Core element trait for tensor values.\n///\n/// This trait defines the minimal set of capabilities required for a type to be\n/// stored and manipulated as a tensor element across all backends.\npub trait Element:\n    ToElement\n    + ElementRandom\n    + ElementConversion\n    + ElementEq\n    + ElementLimits\n    + bytemuck::CheckedBitPattern\n    + bytemuck::NoUninit\n    + bytemuck::Zeroable\n    + core::fmt::Debug\n    + core::fmt::Display\n    + Default\n    + Send\n    + Sync\n    + Copy\n    + 'static\n{\n    /// The dtype of the element.\n    fn dtype() -> DType;\n}\n\n/// Ordered element trait for tensor values.\n///\n/// This trait extends [`Element`] with ordering semantics, enabling comparison\n/// and order-dependent operations in generic Rust implementations.\n///\n/// Backends that implement these operations entirely at the device level do\n/// not rely on this trait. It only constrains the scalar type for generic Rust code.\npub trait ElementOrdered: Element + ElementComparison {}\n\n/// Element conversion trait for tensor.\npub trait ElementConversion {\n    /// Converts an element to another element.\n    ///\n    /// # Arguments\n    ///\n    /// * `elem` - The element to convert.\n    ///\n    /// # Returns\n    ///\n    /// The converted element.\n    fn from_elem<E: ToElement>(elem: E) -> Self;\n\n    /// Converts and returns the converted element.\n    fn elem<E: Element>(self) -> E;\n}\n\n/// Element trait for random value of a tensor.\npub trait ElementRandom {\n    /// Returns a random value for the given distribution.\n    ///\n    /// # Arguments\n    ///\n    /// * `distribution` - The distribution to sample from.\n    /// * `rng` - The random number generator.\n    ///\n    /// # Returns\n    ///\n    /// The random value.\n    fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self;\n}\n\n/// Element trait for equality of a tensor.\npub trait ElementEq {\n    /// Returns whether `self` and `other` are equal.\n    fn eq(&self, other: &Self) -> bool;\n}\n\n/// Element ordering trait.\npub trait ElementComparison {\n    /// Returns and [Ordering] between `self` and `other`.\n    fn cmp(&self, other: &Self) -> Ordering;\n}\n\n/// Element limits trait.\npub trait ElementLimits {\n    /// The minimum representable value\n    const MIN: Self;\n    /// The maximum representable value\n    const MAX: Self;\n}\n\n/// Macro to implement the element trait for a type.\n#[macro_export]\nmacro_rules! make_element {\n    (\n        ty $type:ident,\n        convert $convert:expr,\n        random $random:expr,\n        cmp $cmp:expr,\n        dtype $dtype:expr\n    ) => {\n        make_element!(ty $type, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);\n    };\n    (\n        ty $type:ident,\n        convert $convert:expr,\n        random $random:expr,\n        cmp $cmp:expr,\n        dtype $dtype:expr,\n        min $min:expr,\n        max $max:expr\n    ) => {\n        impl Element for $type {\n            #[inline(always)]\n            fn dtype() -> burn_std::DType {\n                $dtype\n            }\n        }\n        impl ElementEq for $type {\n            fn eq(&self, other: &Self) -> bool {\n                self == other\n            }\n        }\n\n        impl ElementConversion for $type {\n            #[inline(always)]\n            fn from_elem<E: ToElement>(elem: E) -> Self {\n                #[allow(clippy::redundant_closure_call)]\n                $convert(&elem)\n            }\n            #[inline(always)]\n            fn elem<E: Element>(self) -> E {\n                E::from_elem(self)\n            }\n        }\n\n        impl ElementRandom for $type {\n            fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self {\n                #[allow(clippy::redundant_closure_call)]\n                $random(distribution, rng)\n            }\n        }\n\n        impl ElementComparison for $type {\n            fn cmp(&self, other: &Self) -> Ordering {\n                let a = self.elem::<$type>();\n                let b = other.elem::<$type>();\n                #[allow(clippy::redundant_closure_call)]\n                $cmp(&a, &b)\n            }\n        }\n\n        impl ElementLimits for $type {\n            const MIN: Self = $min;\n            const MAX: Self = $max;\n        }\n\n        impl ElementOrdered for $type {}\n\n    };\n}\n\nmake_element!(\n    ty f64,\n    convert ToElement::to_f64,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &f64, b: &f64| a.total_cmp(b),\n    dtype DType::F64\n);\n\nmake_element!(\n    ty f32,\n    convert ToElement::to_f32,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &f32, b: &f32| a.total_cmp(b),\n    dtype DType::F32\n);\n\nmake_element!(\n    ty i64,\n    convert ToElement::to_i64,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &i64, b: &i64| Ord::cmp(a, b),\n    dtype DType::I64\n);\n\nmake_element!(\n    ty u64,\n    convert ToElement::to_u64,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &u64, b: &u64| Ord::cmp(a, b),\n    dtype DType::U64\n);\n\nmake_element!(\n    ty i32,\n    convert ToElement::to_i32,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &i32, b: &i32| Ord::cmp(a, b),\n    dtype DType::I32\n);\n\nmake_element!(\n    ty u32,\n    convert ToElement::to_u32,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &u32, b: &u32| Ord::cmp(a, b),\n    dtype DType::U32\n);\n\nmake_element!(\n    ty i16,\n    convert ToElement::to_i16,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &i16, b: &i16| Ord::cmp(a, b),\n    dtype DType::I16\n);\n\nmake_element!(\n    ty u16,\n    convert ToElement::to_u16,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &u16, b: &u16| Ord::cmp(a, b),\n    dtype DType::U16\n);\n\nmake_element!(\n    ty i8,\n    convert ToElement::to_i8,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &i8, b: &i8| Ord::cmp(a, b),\n    dtype DType::I8\n);\n\nmake_element!(\n    ty u8,\n    convert ToElement::to_u8,\n    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),\n    cmp |a: &u8, b: &u8| Ord::cmp(a, b),\n    dtype DType::U8\n);\n\nmake_element!(\n    ty f16,\n    convert ToElement::to_f16,\n    random |distribution: Distribution, rng: &mut R| {\n        let sample: f32 = distribution.sampler(rng).sample();\n        f16::from_elem(sample)\n    },\n    cmp |a: &f16, b: &f16| a.total_cmp(b),\n    dtype DType::F16\n);\nmake_element!(\n    ty bf16,\n    convert ToElement::to_bf16,\n    random |distribution: Distribution, rng: &mut R| {\n        let sample: f32 = distribution.sampler(rng).sample();\n        bf16::from_elem(sample)\n    },\n    cmp |a: &bf16, b: &bf16| a.total_cmp(b),\n    dtype DType::BF16\n);\n\n#[cfg(feature = \"cubecl\")]\nmake_element!(\n    ty flex32,\n    convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),\n    random |distribution: Distribution, rng: &mut R| {\n        let sample: f32 = distribution.sampler(rng).sample();\n        flex32::from_elem(sample)\n    },\n    cmp |a: &flex32, b: &flex32| a.total_cmp(b),\n    dtype DType::Flex32,\n    min flex32::from_f32(f16::MIN.to_f32_const()),\n    max flex32::from_f32(f16::MAX.to_f32_const())\n);\n\nmake_element!(\n    ty bool,\n    convert ToElement::to_bool,\n    random |distribution: Distribution, rng: &mut R| {\n        let sample: u8 = distribution.sampler(rng).sample();\n        bool::from_elem(sample)\n    },\n    cmp |a: &bool, b: &bool| Ord::cmp(a, b),\n    dtype DType::Bool(BoolStore::Native),\n    min false,\n    max true\n);\n"
  },
  {
    "path": "crates/burn-backend/src/element/cast.rs",
    "content": "use core::mem::size_of;\n\nuse burn_std::{bf16, f16};\n\n/// A generic trait for converting a value to a number.\n/// Adapted from num_traits::ToPrimitive to support [bool].\n///\n/// A value can be represented by the target type when it lies within\n/// the range of scalars supported by the target type.\n/// For example, a negative integer cannot be represented by an unsigned\n/// integer type, and an `i64` with a very high magnitude might not be\n/// convertible to an `i32`.\n/// On the other hand, conversions with possible precision loss or truncation\n/// are admitted, like an `f32` with a decimal part to an integer type, or\n/// even a large `f64` saturating to `f32` infinity.\n///\n/// The methods *panic* when the value cannot be represented by the target type.\npub trait ToElement {\n    /// Converts the value of `self` to an `isize`.\n    #[inline]\n    fn to_isize(&self) -> isize {\n        ToElement::to_isize(&self.to_i64())\n    }\n\n    /// Converts the value of `self` to an `i8`.\n    #[inline]\n    fn to_i8(&self) -> i8 {\n        ToElement::to_i8(&self.to_i64())\n    }\n\n    /// Converts the value of `self` to an `i16`.\n    #[inline]\n    fn to_i16(&self) -> i16 {\n        ToElement::to_i16(&self.to_i64())\n    }\n\n    /// Converts the value of `self` to an `i32`.\n    #[inline]\n    fn to_i32(&self) -> i32 {\n        ToElement::to_i32(&self.to_i64())\n    }\n\n    /// Converts the value of `self` to an `i64`.\n    fn to_i64(&self) -> i64;\n\n    /// Converts the value of `self` to an `i128`.\n    ///\n    /// The default implementation converts through `to_i64()`. Types implementing\n    /// this trait should override this method if they can represent a greater range.\n    #[inline]\n    fn to_i128(&self) -> i128 {\n        i128::from(self.to_i64())\n    }\n\n    /// Converts the value of `self` to a `usize`.\n    #[inline]\n    fn to_usize(&self) -> usize {\n        ToElement::to_usize(&self.to_u64())\n    }\n\n    /// Converts the value of `self` to a `u8`.\n    #[inline]\n    fn to_u8(&self) -> u8 {\n        ToElement::to_u8(&self.to_u64())\n    }\n\n    /// Converts the value of `self` to a `u16`.\n    #[inline]\n    fn to_u16(&self) -> u16 {\n        ToElement::to_u16(&self.to_u64())\n    }\n\n    /// Converts the value of `self` to a `u32`.\n    #[inline]\n    fn to_u32(&self) -> u32 {\n        ToElement::to_u32(&self.to_u64())\n    }\n\n    /// Converts the value of `self` to a `u64`.\n    fn to_u64(&self) -> u64;\n\n    /// Converts the value of `self` to a `u128`.\n    ///\n    /// The default implementation converts through `to_u64()`. Types implementing\n    /// this trait should override this method if they can represent a greater range.\n    #[inline]\n    fn to_u128(&self) -> u128 {\n        u128::from(self.to_u64())\n    }\n\n    /// Converts the value of `self` to an `f16`. Overflows may map to positive\n    /// or negative infinity.\n    #[inline]\n    fn to_f16(&self) -> f16 {\n        f16::from_f32(self.to_f32())\n    }\n\n    /// Converts the value of `self` to an `bf16`. Overflows may map to positive\n    /// or negative infinity.\n    #[inline]\n    fn to_bf16(&self) -> bf16 {\n        bf16::from_f32(self.to_f32())\n    }\n\n    /// Converts the value of `self` to an `f32`. Overflows may map to positive\n    /// or negative infinity.\n    #[inline]\n    fn to_f32(&self) -> f32 {\n        ToElement::to_f32(&self.to_f64())\n    }\n\n    /// Converts the value of `self` to an `f64`. Overflows may map to positive\n    /// or negative infinity.\n    ///\n    /// The default implementation tries to convert through `to_i64()`, and\n    /// failing that through `to_u64()`. Types implementing this trait should\n    /// override this method if they can represent a greater range.\n    #[inline]\n    fn to_f64(&self) -> f64 {\n        ToElement::to_f64(&self.to_u64())\n    }\n\n    /// Converts the value of `self` to a bool.\n    /// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are\n    /// adopted (anything that's not 0 is true).\n    ///\n    /// The default implementation tries to convert through `to_i64()`, and\n    /// failing that through `to_u64()`. Types implementing this trait should\n    /// override this method if they can represent a greater range.\n    #[inline]\n    fn to_bool(&self) -> bool {\n        ToElement::to_bool(&self.to_u64())\n    }\n}\n\nmacro_rules! impl_to_element_int_to_int {\n    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $DstT {\n            let min = $DstT::MIN as $SrcT;\n            let max = $DstT::MAX as $SrcT;\n            if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {\n                *self as $DstT\n            } else {\n                panic!(\n                    \"Element cannot be represented in the target type: {:?}({:?}) => {:?}\",\n                    core::any::type_name::<$SrcT>(),\n                    self,\n                    core::any::type_name::<$DstT>(),\n                )\n            }\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_int_to_uint {\n    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $DstT {\n            let max = $DstT::MAX as $SrcT;\n            if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {\n                *self as $DstT\n            } else {\n                panic!(\n                    \"Element cannot be represented in the target type: {:?}({:?}) => {:?}\",\n                    core::any::type_name::<$SrcT>(),\n                    self,\n                    core::any::type_name::<$DstT>(),\n                )\n            }\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_int {\n    ($T:ident) => {\n        impl ToElement for $T {\n            impl_to_element_int_to_int! { $T:\n                fn to_isize -> isize;\n                fn to_i8 -> i8;\n                fn to_i16 -> i16;\n                fn to_i32 -> i32;\n                fn to_i64 -> i64;\n                fn to_i128 -> i128;\n            }\n\n            impl_to_element_int_to_uint! { $T:\n                fn to_usize -> usize;\n                fn to_u8 -> u8;\n                fn to_u16 -> u16;\n                fn to_u32 -> u32;\n                fn to_u64 -> u64;\n                fn to_u128 -> u128;\n            }\n\n            #[inline]\n            fn to_f32(&self) -> f32 {\n                *self as f32\n            }\n            #[inline]\n            fn to_f64(&self) -> f64 {\n                *self as f64\n            }\n            #[inline]\n            fn to_bool(&self) -> bool {\n                *self != 0\n            }\n        }\n    };\n}\n\nimpl_to_element_int!(isize);\nimpl_to_element_int!(i8);\nimpl_to_element_int!(i16);\nimpl_to_element_int!(i32);\nimpl_to_element_int!(i64);\nimpl_to_element_int!(i128);\n\nmacro_rules! impl_to_element_uint_to_int {\n    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $DstT {\n            let max = $DstT::MAX as $SrcT;\n            if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {\n                *self as $DstT\n            } else {\n                panic!(\n                    \"Element cannot be represented in the target type: {:?}({:?}) => {:?}\",\n                    core::any::type_name::<$SrcT>(),\n                    self,\n                    core::any::type_name::<$DstT>(),\n                )\n            }\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_uint_to_uint {\n    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $DstT {\n            let max = $DstT::MAX as $SrcT;\n            if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {\n                *self as $DstT\n            } else {\n                panic!(\n                    \"Element cannot be represented in the target type: {:?}({:?}) => {:?}\",\n                    core::any::type_name::<$SrcT>(),\n                    self,\n                    core::any::type_name::<$DstT>(),\n                )\n            }\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_uint {\n    ($T:ident) => {\n        impl ToElement for $T {\n            impl_to_element_uint_to_int! { $T:\n                fn to_isize -> isize;\n                fn to_i8 -> i8;\n                fn to_i16 -> i16;\n                fn to_i32 -> i32;\n                fn to_i64 -> i64;\n                fn to_i128 -> i128;\n            }\n\n            impl_to_element_uint_to_uint! { $T:\n                fn to_usize -> usize;\n                fn to_u8 -> u8;\n                fn to_u16 -> u16;\n                fn to_u32 -> u32;\n                fn to_u64 -> u64;\n                fn to_u128 -> u128;\n            }\n\n            #[inline]\n            fn to_f32(&self) -> f32 {\n                *self as f32\n            }\n            #[inline]\n            fn to_f64(&self) -> f64 {\n                *self as f64\n            }\n            #[inline]\n            fn to_bool(&self) -> bool {\n                *self != 0\n            }\n        }\n    };\n}\n\nimpl_to_element_uint!(usize);\nimpl_to_element_uint!(u8);\nimpl_to_element_uint!(u16);\nimpl_to_element_uint!(u32);\nimpl_to_element_uint!(u64);\nimpl_to_element_uint!(u128);\n\nmacro_rules! impl_to_element_float_to_float {\n    ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(\n        #[inline]\n        fn $method(&self) -> $DstT {\n            // We can safely cast all values, whether NaN, +-inf, or finite.\n            // Finite values that are reducing size may saturate to +-inf.\n            *self as $DstT\n        }\n    )*}\n}\n\nmacro_rules! float_to_int_unchecked {\n    // SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating.\n    // We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`.\n    ($float:expr => $int:ty) => {\n        unsafe { $float.to_int_unchecked::<$int>() }\n    };\n}\n\nmacro_rules! impl_to_element_float_to_signed_int {\n    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $i {\n            // Float as int truncates toward zero, so we want to allow values\n            // in the exclusive range `(MIN-1, MAX+1)`.\n            if size_of::<$f>() > size_of::<$i>() {\n                // With a larger size, we can represent the range exactly.\n                const MIN_M1: $f = $i::MIN as $f - 1.0;\n                const MAX_P1: $f = $i::MAX as $f + 1.0;\n                if *self > MIN_M1 && *self < MAX_P1 {\n                    return float_to_int_unchecked!(*self => $i);\n                }\n            } else {\n                // We can't represent `MIN-1` exactly, but there's no fractional part\n                // at this magnitude, so we can just use a `MIN` inclusive boundary.\n                const MIN: $f = $i::MIN as $f;\n                // We can't represent `MAX` exactly, but it will round up to exactly\n                // `MAX+1` (a power of two) when we cast it.\n                const MAX_P1: $f = $i::MAX as $f;\n                if *self >= MIN && *self < MAX_P1 {\n                    return float_to_int_unchecked!(*self => $i);\n                }\n            }\n            panic!(\"Float cannot be represented in the target signed int type\")\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_float_to_unsigned_int {\n    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(\n        #[inline]\n        $(#[$cfg])*\n        fn $method(&self) -> $u {\n            // Float as int truncates toward zero, so we want to allow values\n            // in the exclusive range `(-1, MAX+1)`.\n            if size_of::<$f>() > size_of::<$u>() {\n                // With a larger size, we can represent the range exactly.\n                const MAX_P1: $f = $u::MAX as $f + 1.0;\n                if *self > -1.0 && *self < MAX_P1 {\n                    return float_to_int_unchecked!(*self => $u);\n                }\n            } else {\n                // We can't represent `MAX` exactly, but it will round up to exactly\n                // `MAX+1` (a power of two) when we cast it.\n                // (`u128::MAX as f32` is infinity, but this is still ok.)\n                const MAX_P1: $f = $u::MAX as $f;\n                if *self > -1.0 && *self < MAX_P1 {\n                    return float_to_int_unchecked!(*self => $u);\n                }\n            }\n            panic!(\"Float cannot be represented in the target unsigned int type\")\n        }\n    )*}\n}\n\nmacro_rules! impl_to_element_float {\n    ($T:ident) => {\n        impl ToElement for $T {\n            impl_to_element_float_to_signed_int! { $T:\n                fn to_isize -> isize;\n                fn to_i8 -> i8;\n                fn to_i16 -> i16;\n                fn to_i32 -> i32;\n                fn to_i64 -> i64;\n                fn to_i128 -> i128;\n            }\n\n            impl_to_element_float_to_unsigned_int! { $T:\n                fn to_usize -> usize;\n                fn to_u8 -> u8;\n                fn to_u16 -> u16;\n                fn to_u32 -> u32;\n                fn to_u64 -> u64;\n                fn to_u128 -> u128;\n            }\n\n            impl_to_element_float_to_float! { $T:\n                fn to_f32 -> f32;\n                fn to_f64 -> f64;\n            }\n\n            #[inline]\n            fn to_bool(&self) -> bool {\n                *self != 0.0\n            }\n        }\n    };\n}\n\nimpl_to_element_float!(f32);\nimpl_to_element_float!(f64);\n\nimpl ToElement for f16 {\n    #[inline]\n    fn to_i64(&self) -> i64 {\n        Self::to_f32(*self).to_i64()\n    }\n    #[inline]\n    fn to_u64(&self) -> u64 {\n        Self::to_f32(*self).to_u64()\n    }\n    #[inline]\n    fn to_i8(&self) -> i8 {\n        Self::to_f32(*self).to_i8()\n    }\n    #[inline]\n    fn to_u8(&self) -> u8 {\n        Self::to_f32(*self).to_u8()\n    }\n    #[inline]\n    fn to_i16(&self) -> i16 {\n        Self::to_f32(*self).to_i16()\n    }\n    #[inline]\n    fn to_u16(&self) -> u16 {\n        Self::to_f32(*self).to_u16()\n    }\n    #[inline]\n    fn to_i32(&self) -> i32 {\n        Self::to_f32(*self).to_i32()\n    }\n    #[inline]\n    fn to_u32(&self) -> u32 {\n        Self::to_f32(*self).to_u32()\n    }\n    #[inline]\n    fn to_f16(&self) -> f16 {\n        *self\n    }\n    #[inline]\n    fn to_f32(&self) -> f32 {\n        Self::to_f32(*self)\n    }\n    #[inline]\n    fn to_f64(&self) -> f64 {\n        Self::to_f64(*self)\n    }\n    #[inline]\n    fn to_bool(&self) -> bool {\n        *self != f16::from_f32_const(0.0)\n    }\n}\n\nimpl ToElement for bf16 {\n    #[inline]\n    fn to_i64(&self) -> i64 {\n        Self::to_f32(*self).to_i64()\n    }\n    #[inline]\n    fn to_u64(&self) -> u64 {\n        Self::to_f32(*self).to_u64()\n    }\n    #[inline]\n    fn to_i8(&self) -> i8 {\n        Self::to_f32(*self).to_i8()\n    }\n    #[inline]\n    fn to_u8(&self) -> u8 {\n        Self::to_f32(*self).to_u8()\n    }\n    #[inline]\n    fn to_i16(&self) -> i16 {\n        Self::to_f32(*self).to_i16()\n    }\n    #[inline]\n    fn to_u16(&self) -> u16 {\n        Self::to_f32(*self).to_u16()\n    }\n    #[inline]\n    fn to_i32(&self) -> i32 {\n        Self::to_f32(*self).to_i32()\n    }\n    #[inline]\n    fn to_u32(&self) -> u32 {\n        Self::to_f32(*self).to_u32()\n    }\n    #[inline]\n    fn to_bf16(&self) -> bf16 {\n        *self\n    }\n    #[inline]\n    fn to_f32(&self) -> f32 {\n        Self::to_f32(*self)\n    }\n    #[inline]\n    fn to_f64(&self) -> f64 {\n        Self::to_f64(*self)\n    }\n    #[inline]\n    fn to_bool(&self) -> bool {\n        *self != bf16::from_f32_const(0.0)\n    }\n}\n\n#[cfg(feature = \"cubecl\")]\nimpl ToElement for burn_std::flex32 {\n    #[inline]\n    fn to_i64(&self) -> i64 {\n        Self::to_f32(*self).to_i64()\n    }\n    #[inline]\n    fn to_u64(&self) -> u64 {\n        Self::to_f32(*self).to_u64()\n    }\n    #[inline]\n    fn to_i8(&self) -> i8 {\n        Self::to_f32(*self).to_i8()\n    }\n    #[inline]\n    fn to_u8(&self) -> u8 {\n        Self::to_f32(*self).to_u8()\n    }\n    #[inline]\n    fn to_i16(&self) -> i16 {\n        Self::to_f32(*self).to_i16()\n    }\n    #[inline]\n    fn to_u16(&self) -> u16 {\n        Self::to_f32(*self).to_u16()\n    }\n    #[inline]\n    fn to_i32(&self) -> i32 {\n        Self::to_f32(*self).to_i32()\n    }\n    #[inline]\n    fn to_u32(&self) -> u32 {\n        Self::to_f32(*self).to_u32()\n    }\n    #[inline]\n    fn to_f32(&self) -> f32 {\n        Self::to_f32(*self)\n    }\n    #[inline]\n    fn to_f64(&self) -> f64 {\n        Self::to_f64(*self)\n    }\n    #[inline]\n    fn to_bool(&self) -> bool {\n        *self != burn_std::flex32::from_f32(0.0)\n    }\n}\n\nimpl ToElement for bool {\n    #[inline]\n    fn to_i64(&self) -> i64 {\n        *self as i64\n    }\n    #[inline]\n    fn to_u64(&self) -> u64 {\n        *self as u64\n    }\n    #[inline]\n    fn to_i8(&self) -> i8 {\n        *self as i8\n    }\n    #[inline]\n    fn to_u8(&self) -> u8 {\n        *self as u8\n    }\n    #[inline]\n    fn to_i16(&self) -> i16 {\n        *self as i16\n    }\n    #[inline]\n    fn to_u16(&self) -> u16 {\n        *self as u16\n    }\n    #[inline]\n    fn to_i32(&self) -> i32 {\n        *self as i32\n    }\n    #[inline]\n    fn to_u32(&self) -> u32 {\n        *self as u32\n    }\n    #[inline]\n    fn to_f32(&self) -> f32 {\n        self.to_u8() as f32\n    }\n    #[inline]\n    fn to_f64(&self) -> f64 {\n        self.to_u8() as f64\n    }\n    #[inline]\n    fn to_bool(&self) -> bool {\n        *self\n    }\n}\n\nmod tests {\n    #[allow(unused_imports)]\n    use super::*;\n\n    #[test]\n    fn to_element_float() {\n        let f32_toolarge = 1e39f64;\n        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);\n        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);\n        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);\n        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);\n        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);\n        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);\n        assert!((f64::NAN).to_f32().is_nan());\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_u8_underflow() {\n        let _x = (-1i8).to_u8();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_u16_underflow() {\n        let _x = (-1i8).to_u16();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_u32_underflow() {\n        let _x = (-1i8).to_u32();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_u64_underflow() {\n        let _x = (-1i8).to_u64();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_u128_underflow() {\n        let _x = (-1i8).to_u128();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_signed_to_usize_underflow() {\n        let _x = (-1i8).to_usize();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_unsigned_to_u8_overflow() {\n        let _x = 256.to_u8();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_unsigned_to_u16_overflow() {\n        let _x = 65_536.to_u16();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_unsigned_to_u32_overflow() {\n        let _x = 4_294_967_296u64.to_u32();\n    }\n\n    #[test]\n    #[should_panic]\n    fn to_element_unsigned_to_u64_overflow() {\n        let _x = 18_446_744_073_709_551_616u128.to_u64();\n    }\n\n    #[test]\n    fn to_element_int_to_float() {\n        assert_eq!((-1).to_f32(), -1.0);\n        assert_eq!((-1).to_f64(), -1.0);\n        assert_eq!(255.to_f32(), 255.0);\n        assert_eq!(65_535.to_f64(), 65_535.0);\n    }\n\n    #[test]\n    fn to_element_float_to_int() {\n        assert_eq!((-1.0).to_i8(), -1);\n        assert_eq!(1.0.to_u8(), 1);\n        assert_eq!(1.8.to_u16(), 1);\n        assert_eq!(123.456.to_u32(), 123);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/element/mod.rs",
    "content": "//! Traits and helpers for working with element types and conversions.\n\nmod base;\nmod scalar;\n\n/// Tensor element casting.\npub mod cast;\n\npub use base::*;\npub use scalar::*;\n"
  },
  {
    "path": "crates/burn-backend/src/element/scalar.rs",
    "content": "use burn_std::{DType, bf16, f16};\nuse num_traits::ToPrimitive;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse crate::{Element, ElementConversion};\n\n/// A scalar element.\n#[derive(Clone, Copy, Debug)]\n#[allow(missing_docs)]\npub enum Scalar {\n    Float(f64),\n    Int(i64),\n    UInt(u64),\n    Bool(bool),\n}\n\nimpl Scalar {\n    /// Creates a scalar with the specified data type.\n    ///\n    /// # Note\n    /// [`QFloat`](DType::QFloat) scalars are represented as float for element-wise operations.\n    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {\n        if dtype.is_float() | matches!(dtype, &DType::QFloat(_)) {\n            Self::Float(value.elem())\n        } else if dtype.is_int() {\n            Self::Int(value.elem())\n        } else if dtype.is_uint() {\n            Self::UInt(value.elem())\n        } else if dtype.is_bool() {\n            Self::Bool(value.elem())\n        } else {\n            unimplemented!(\"Scalar not supported for {dtype:?}\")\n        }\n    }\n\n    /// Converts and returns the converted element.\n    pub fn elem<E: Element>(self) -> E {\n        match self {\n            Self::Float(x) => x.elem(),\n            Self::Int(x) => x.elem(),\n            Self::UInt(x) => x.elem(),\n            Self::Bool(x) => x.elem(),\n        }\n    }\n\n    /// Returns the exact integer value, if valid.\n    pub fn try_as_integer(&self) -> Option<Self> {\n        match self {\n            Scalar::Float(x) => (x.floor() == *x).then(|| Self::Int(x.to_i64().unwrap())),\n            Scalar::Int(_) | Scalar::UInt(_) => Some(*self),\n            Scalar::Bool(x) => Some(Scalar::Int(*x as i64)),\n        }\n    }\n}\n\nmacro_rules! impl_from_scalar {\n    ($($ty:ty => $variant:ident),+ $(,)?) => {\n        $(\n            impl From<$ty> for Scalar {\n                fn from(value: $ty) -> Self {\n                    Scalar::$variant(value.elem())\n                }\n            }\n        )+\n    };\n}\n\nimpl_from_scalar! {\n    f64  => Float, f32  => Float, f16  => Float, bf16 => Float,\n    i64  => Int, i32  => Int, i16  => Int, i8 => Int,\n    u64  => UInt, u32  => UInt, u16  => UInt, u8 => UInt, bool => Bool,\n}\n\n// CubeCL requirement\nimpl ToPrimitive for Scalar {\n    fn to_i64(&self) -> Option<i64> {\n        match self {\n            Scalar::Float(x) => x.to_i64(),\n            Scalar::UInt(x) => x.to_i64(),\n            Scalar::Int(x) => Some(*x),\n            Scalar::Bool(x) => Some(*x as i64),\n        }\n    }\n\n    fn to_u64(&self) -> Option<u64> {\n        match self {\n            Scalar::Float(x) => x.to_u64(),\n            Scalar::UInt(x) => Some(*x),\n            Scalar::Int(x) => x.to_u64(),\n            Scalar::Bool(x) => Some(*x as u64),\n        }\n    }\n\n    fn to_f64(&self) -> Option<f64> {\n        match self {\n            Scalar::Float(x) => Some(*x),\n            Scalar::UInt(x) => x.to_f64(),\n            Scalar::Int(x) => x.to_f64(),\n            Scalar::Bool(x) => (*x as u8).to_f64(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! This library provides the core types that define how Burn tensor data is represented, stored, and interpreted.\n\n#[macro_use]\nextern crate derive_new;\n\nextern crate alloc;\n\nmod data;\npub use data::*;\n\npub mod distribution;\npub use distribution::*;\npub mod element;\npub use element::*;\n\n/// [`Backend`] trait and required types.\npub mod backend;\npub use backend::*;\n\n/// Backend tensor primitives and operations.\npub mod tensor;\n\n// Re-exported types\npub use burn_std::reader::*; // Useful so that backends don't have to add `burn_std` as a dependency.\npub use burn_std::{\n    AllocationProperty, BoolDType, BoolStore, Bytes, DType, DeviceHandle, FloatDType, IntDType,\n    bf16, f16, stream_id::StreamId,\n};\n\n/// Shape definition.\npub mod shape {\n    pub use burn_std::shape::*;\n}\npub use shape::*;\n\n/// Slice utilities.\npub mod slice {\n    pub use burn_std::{s, slice::*};\n}\npub use slice::*;\n\n/// Indexing utilities.\npub mod indexing {\n    pub use burn_std::indexing::*;\n}\npub use indexing::*;\n\n/// Quantization data representation.\npub mod quantization {\n    pub use crate::tensor::quantization::*;\n    pub use burn_std::quantization::{\n        BlockSize, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantStore,\n        QuantValue, QuantizedBytes,\n    };\n}\n\n#[cfg(feature = \"cubecl-wgpu\")]\nmod cube_wgpu {\n    use crate::backend::DeviceOps;\n    use cubecl::wgpu::WgpuDevice;\n\n    impl DeviceOps for WgpuDevice {}\n}\n\n#[cfg(feature = \"cubecl-cuda\")]\nmod cube_cuda {\n    use crate::backend::DeviceOps;\n    use cubecl::cuda::CudaDevice;\n\n    impl DeviceOps for CudaDevice {}\n}\n\n#[cfg(feature = \"cubecl-cpu\")]\nmod cube_cpu {\n    use crate::backend::DeviceOps;\n    use cubecl::cpu::CpuDevice;\n\n    impl DeviceOps for CpuDevice {}\n}\n\n#[cfg(feature = \"cubecl-hip\")]\nmod cube_hip {\n    use crate::backend::DeviceOps;\n    use cubecl::hip::AmdDevice;\n\n    impl DeviceOps for AmdDevice {}\n}\n\n/// Convenience macro to link to the `burn-tensor` docs for this crate version.\n///\n/// Usage:\n/// ```rust,ignore\n/// # use burn_backend::doc_tensor;\n/// doc_tensor!();        // Links to `Tensor` struct\n/// doc_tensor!(\"zeros\"); // Links to `Tensor::zeros` method\n/// ```\n#[macro_export]\nmacro_rules! doc_tensor {\n    () => {\n        concat!(\n            \"[`Tensor`](https://docs.rs/burn-tensor/\",\n            env!(\"CARGO_PKG_VERSION\"),\n            \"/burn_tensor/struct.Tensor.html)\"\n        )\n    };\n\n    ($method:literal) => {\n        concat!(\n            \"[`Tensor::\",\n            $method,\n            \"`](\",\n            \"https://docs.rs/burn-tensor/\",\n            env!(\"CARGO_PKG_VERSION\"),\n            \"/burn_tensor/struct.Tensor.html#method.\",\n            $method,\n            \")\"\n        )\n    };\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/alias.rs",
    "content": "use crate::backend::Backend;\n\n// We provide some type aliases to improve the readability of using associated types without\n// having to use the disambiguation syntax.\n\n/// Device type used by the backend.\npub type Device<B> = <B as Backend>::Device;\n\n/// Float element type used by backend.\npub type FloatElem<B> = <B as Backend>::FloatElem;\n/// Integer element type used by backend.\npub type IntElem<B> = <B as Backend>::IntElem;\n/// Boolean element type used by backend.\npub type BoolElem<B> = <B as Backend>::BoolElem;\n\n/// Float tensor primitive type used by the backend.\npub type FloatTensor<B> = <B as Backend>::FloatTensorPrimitive;\n/// Integer tensor primitive type used by the backend.\npub type IntTensor<B> = <B as Backend>::IntTensorPrimitive;\n/// Boolean tensor primitive type used by the backend.\npub type BoolTensor<B> = <B as Backend>::BoolTensorPrimitive;\n/// Quantized tensor primitive type used by the backend.\npub type QuantizedTensor<B> = <B as Backend>::QuantizedTensorPrimitive;\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/container.rs",
    "content": "use alloc::boxed::Box;\nuse core::any::Any;\n\n#[cfg(not(feature = \"std\"))]\nuse alloc::vec::Vec;\n#[cfg(not(feature = \"std\"))]\nuse hashbrown::HashMap;\n\n#[cfg(feature = \"std\")]\nuse std::collections::HashMap;\n\nuse crate::{TensorPrimitive, backend::Backend};\n\n/// Contains tensor of arbitrary dimension.\n#[derive(Debug)]\npub struct TensorContainer<ID> {\n    tensors: HashMap<ID, Box<dyn Any + Send>>,\n}\n\nimpl<ID> Default for TensorContainer<ID>\nwhere\n    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,\n{\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<ID> TensorContainer<ID>\nwhere\n    ID: core::hash::Hash + PartialEq + Eq + core::fmt::Debug,\n{\n    /// Create an empty container.\n    pub fn new() -> Self {\n        Self {\n            tensors: HashMap::new(),\n        }\n    }\n\n    /// Get a tensor with the given ID.\n    pub fn get<B>(&self, id: &ID) -> Option<TensorPrimitive<B>>\n    where\n        B: Backend,\n    {\n        let grad = self.tensors.get(id)?;\n\n        let tensor = grad\n            .downcast_ref::<TensorPrimitive<B>>()\n            // .map(|primitive| Tensor::<B, D>::from_primitive(primitive.clone()))\n            .unwrap();\n\n        Some(tensor.clone())\n    }\n\n    /// Register a new tensor for the given ID.\n    ///\n    /// # Notes\n    ///\n    /// If a tensor is already registered for the given ID, it will be replaced.\n    pub fn register<B>(&mut self, id: ID, value: TensorPrimitive<B>)\n    where\n        B: Backend,\n    {\n        self.tensors.insert(id, Box::new(value));\n    }\n\n    /// Remove a tensor for the given ID and returns it.\n    pub fn remove<B>(&mut self, id: &ID) -> Option<TensorPrimitive<B>>\n    where\n        B: Backend,\n    {\n        self.tensors\n            .remove(id)\n            .map(|item| *item.downcast::<TensorPrimitive<B>>().unwrap())\n        // .map(|primitive| Tensor::from_primitive(*primitive))\n    }\n\n    /// The number of tensors registered.\n    pub fn len(&self) -> usize {\n        self.tensors.len()\n    }\n\n    /// If any tensor is contained.\n    pub fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n\n    /// Get id of every tensor in the container\n    pub fn ids(&self) -> Vec<&ID> {\n        self.tensors.keys().collect()\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/kind.rs",
    "content": "use crate::{Backend, TensorMetadata, TensorPrimitive};\n\n/// A type-level representation of the kind of a float tensor\n#[derive(Clone, Debug)]\npub struct Float;\n\n/// A type-level representation of the kind of a int tensor.\n#[derive(Clone, Debug)]\npub struct Int;\n\n/// A type-level representation of the kind of a bool tensor.\n#[derive(Clone, Debug)]\npub struct Bool;\n\n/// A type-level representation of the kind of a tensor.\n/// Metadata access is lazy.\npub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {\n    /// The primitive type of the tensor.\n    type Primitive: TensorMetadata;\n\n    /// The name of the tensor kind.\n    fn name() -> &'static str;\n}\n\nimpl<B: Backend> TensorKind<B> for Float {\n    type Primitive = TensorPrimitive<B>;\n    fn name() -> &'static str {\n        \"Float\"\n    }\n}\n\nimpl<B: Backend> TensorKind<B> for Int {\n    type Primitive = B::IntTensorPrimitive;\n    fn name() -> &'static str {\n        \"Int\"\n    }\n}\n\nimpl<B: Backend> TensorKind<B> for Bool {\n    type Primitive = B::BoolTensorPrimitive;\n    fn name() -> &'static str {\n        \"Bool\"\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/mod.rs",
    "content": "mod alias;\nmod container;\nmod kind;\nmod ops;\n\npub use alias::*;\npub use container::*;\npub use kind::*;\npub use ops::*;\n\n/// Tensor quantization module.\npub mod quantization;\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/autodiff.rs",
    "content": "use crate::{\n    AutodiffBackend,\n    tensor::{BasicOps, TensorKind},\n};\n\n/// Trait that list all operations that can be applied on all tensors on an autodiff backend.\n///\n/// # Warnings\n///\n/// This is an internal trait, use the public API provided by the\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// struct.\npub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {\n    /// Inner primitive tensor.\n    type InnerKind: BasicOps<B::InnerBackend>;\n\n    /// Returns the inner tensor without the autodiff information.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"inner\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::inner`\")]\n    /// function, which is more high-level and designed for public use.\n    fn inner(\n        tensor: <Self as TensorKind<B>>::Primitive,\n    ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive;\n\n    /// Convert a tensor to the autodiff backend.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"from_inner\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::from_inner`\")]\n    /// function, which is more high-level and designed for public use.\n    fn from_inner(\n        inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive,\n    ) -> <Self as TensorKind<B>>::Primitive;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/base.rs",
    "content": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    Backend, ExecutionError, Scalar, TensorData, TensorMetadata,\n    element::Element,\n    ops::TransactionPrimitive,\n    tensor::{IndexingUpdateOp, IntTensor, TensorKind},\n};\n\n/// Trait that list all operations that can be applied on all tensors.\n///\n/// # Warnings\n///\n/// This is an internal trait, use the public API provided by the\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// struct.\npub trait BasicOps<B: Backend>: TensorKind<B> {\n    /// The type of the tensor elements.\n    type Elem: Element;\n\n    /// Creates an empty tensor with the given shape.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device on which the tensor will be allocated.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The empty tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating empty tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"empty\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::empty`\")]\n    /// function, which is more high-level and designed for public use.\n    fn empty(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;\n\n    /// Creates a tensor filled with zeros.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device on which the tensor will be allocated.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor filled with zeros.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating a tensor filled with zeros, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"zeros\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::zeros`\")]\n    /// function, which is more high-level and designed for public use.\n    fn zeros(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;\n\n    /// Creates a tensor filled with ones.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device on which the tensor will be allocated.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor filled with ones.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating a tensor filled with ones, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"ones\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::ones`\")]\n    /// function, which is more high-level and designed for public use.\n    fn ones(shape: Shape, device: &B::Device, dtype: DType) -> Self::Primitive;\n\n    /// Creates a tensor of the given shape where each element is equal to the provided value.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `fill_value` - The value with which to fill the tensor.\n    /// * `device` - The device on which the tensor will be allocated.\n    /// * `dtype` - The target data type.\n    ///\n    /// # Returns\n    ///\n    /// The tensor filled with the specified value.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating full tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"full\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::full`\")]\n    /// function, which is more high-level and designed for public use.\n    fn full(shape: Shape, fill_value: Scalar, device: &B::Device, dtype: DType) -> Self::Primitive;\n\n    /// Reshapes the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `shape` - The new shape of the tensor.\n    ///\n    /// # Returns\n    ///\n    /// The reshaped tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For reshaping a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"reshape\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::reshape`\")]\n    /// function, which is more high-level and designed for public use.\n    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;\n\n    /// Transposes a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    fn transpose(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Swaps two dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap.\n    /// * `dim2` - The second dimension to swap.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive;\n\n    /// Permutes the dimensions of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to permute the dimensions of.\n    /// * `axes` - The new order of the dimensions.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;\n\n    /// Flips the tensor along the given axes.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to flip.\n    /// * `axes` - The axes to flip the tensor along.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the axes flipped.\n    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive;\n\n    ///  Select tensor elements corresponding to the given slices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `slices` - The slices specifying ranges and steps for each dimension.\n    ///\n    /// # Returns\n    ///\n    /// The selected elements.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For selecting elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"slice\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::slice`\")]\n    /// function, which is more high-level and designed for public use.\n    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive;\n\n    /// Assigns the given value to the tensor elements corresponding to the given slices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `slices` - The slices specifying which elements to assign, including support for steps.\n    /// * `value` - The value to assign.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the assigned values.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For assigning values to elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"slice_assign\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::slice_assign`\")]\n    /// function, which is more high-level and designed for public use.\n    fn slice_assign(\n        tensor: Self::Primitive,\n        slices: &[Slice],\n        value: Self::Primitive,\n    ) -> Self::Primitive;\n\n    /// Select tensor elements along the given dimension corresponding to the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select from.\n    /// * `dim` - The dimension along which to select.\n    /// * `indices` - The indices of the elements to select.\n    ///\n    /// # Returns\n    ///\n    /// The selected tensor elements.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For selecting elements from a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"select\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::select`\")]\n    /// function, which is more high-level and designed for public use.\n    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive;\n\n    /// Assign the selected elements along the given dimension corresponding to the given indices\n    /// from the value tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to assign elements to.\n    /// * `dim` - The axis along which to assign elements.\n    /// * `indices` - The indices of the elements to assign.\n    /// * `values` - The values to assign to the tensor.\n    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is taken from the\n    /// corresponding element of the input tensor at the corresponding index along the specified axis,\n    /// except for the elements at the specified indices, which are taken from the corresponding\n    /// element of the values tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For assigning elements to a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"select_assign\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::select_assign`\")]\n    /// function, which is more high-level and designed for public use.\n    fn select_assign(\n        tensor: Self::Primitive,\n        dim: usize,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive;\n\n    /// Selects elements from a tensor based on a boolean mask.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.\n    /// * `mask` - The boolean mask to use for selecting elements.\n    /// * `source` - The tensor to select elements from when the corresponding element of the mask is false.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensors, where each element is taken from the\n    /// corresponding element of the left hand side tensor if the corresponding element of the mask\n    /// is true, and from the corresponding element of the right hand side tensor otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For selecting elements from a tensor based on a boolean mask, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mask_where\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mask_where`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mask_where(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        source: Self::Primitive,\n    ) -> Self::Primitive;\n\n    /// Fills elements of a tensor based on a boolean mask.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor where will be overwritten with the value\n    ///   when the corresponding element of the mask is true.\n    /// * `mask` - The boolean mask to use for filling elements.\n    /// * `value` - The value to fill elements with when the corresponding element of the mask is true.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensors, where each element is taken from the\n    /// corresponding element unmodified if the corresponding element of the mask is false, and\n    /// filled with the value otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For filling elements of a tensor based on a boolean mask, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mask_fill\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mask_fill`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mask_fill(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        value: Scalar,\n    ) -> Self::Primitive;\n\n    /// Gathers elements from a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to gather elements.\n    /// * `tensor` - The tensor to gather elements from.\n    /// * `indices` - The indices of the elements to gather.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is taken from the\n    /// corresponding element of the input tensor at the corresponding index along the specified axis.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For gathering elements from a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"gather\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::gather`\")]\n    /// function, which is more high-level and designed for public use.\n    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive;\n\n    /// Scatters elements into a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to scatter elements.\n    /// * `tensor` - The tensor to scatter elements into.\n    /// * `indices` - The indices of the elements to scatter.\n    /// * `values` - The values to scatter into the tensor.\n    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is taken from the\n    /// corresponding element of the input tensor at the corresponding index along the specified axis,\n    /// except for the elements at the specified indices, which are taken from the corresponding\n    /// element of the values tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For scattering elements into a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"scatter\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::scatter`\")]\n    /// function, which is more high-level and designed for public use.\n    fn scatter(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive;\n\n    /// Returns the device on which the tensor is allocated.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The device on which the tensor is allocated.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the device of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"device\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::device`\")]\n    /// function, which is more high-level and designed for public use.\n    fn device(tensor: &Self::Primitive) -> B::Device;\n\n    /// Moves the tensor to the given device.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `device` - The device on which the tensor will be moved.\n    ///\n    /// # Returns\n    ///\n    /// The tensor on the given device.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For moving a tensor to a device, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"to_device\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::to_device`\")]\n    /// function, which is more high-level and designed for public use.\n    #[allow(clippy::wrong_self_convention)]\n    fn to_device(tensor: Self::Primitive, device: &B::Device) -> Self::Primitive;\n\n    /// Extracts the data from the tensor asynchronously.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The data of the tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For extracting the data of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"into_data\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::into_data`\")]\n    /// function, which is more high-level and designed for public use.\n    #[allow(clippy::wrong_self_convention)]\n    fn into_data_async(\n        tensor: Self::Primitive,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send;\n\n    /// Read the data from the tensor using a transaction.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive);\n\n    /// Creates a tensor from the given data.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The data of the tensor.\n    /// * `device` - The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// The tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating a tensor from data, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"from_data\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::from_data`\")]\n    /// function, which is more high-level and designed for public use.\n    fn from_data(data: TensorData, device: &B::Device) -> Self::Primitive;\n    /// Creates a tensor from the given data enforcing the given data type.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For creating a tensor from data, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"from_data_dtype\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::from_data_dtype`\")]\n    /// function, which is more high-level and designed for public use.\n    fn from_data_dtype(data: TensorData, device: &B::Device, dtype: DType) -> Self::Primitive;\n\n    /// Repeat the tensor along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    /// * `dim` - The dimension along which the tensor will be repeated.\n    /// * `times` - The number of times the tensor will be repeated.\n    ///\n    /// # Returns\n    ///\n    /// The repeated tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For repeating a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"repeat_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::repeat_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive;\n\n    /// Concatenates the given tensors along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `vectors` - The tensors to concatenate.\n    /// * `dim` - The dimension along which the tensors will be concatenated.\n    ///\n    /// # Returns\n    ///\n    /// The concatenated tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For concatenating tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"cat\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::cat`\")]\n    /// function, which is more high-level and designed for public use.\n    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive;\n\n    /// Equates the given tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor of booleans indicating whether the corresponding elements are equal.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For equating tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"equal\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::equal`\")]\n    /// function, which is more high-level and designed for public use.\n    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise equality between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding elements of the input tensors are equal, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise equality between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"equal_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::equal_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Applies element-wise non-equality comparison between the given tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The tensor of booleans indicating whether the corresponding elements are equal.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For non-equality comparison of tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"not_equal\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::not_equal`\")]\n    /// function, which is more high-level and designed for public use.\n    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise non-equality between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding elements of the input tensors are equal, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise non-equality between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"not_equal_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::not_equal_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Returns the name of the element type.\n    fn elem_type_name() -> &'static str {\n        core::any::type_name::<Self::Elem>()\n    }\n\n    /// Returns the tensor data type.\n    fn dtype(tensor: &Self::Primitive) -> DType {\n        tensor.dtype()\n    }\n\n    /// Tests if any element in the `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if any element in the input tensor evaluates to True, False otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly. Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"any\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::any`\")]\n    /// function, which is more high-level and designed for public use.\n    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Tests if any element in the tensor evaluates to True along a given dimension dim.\n    ///\n    /// # Arguments\n    ///\n    /// * tensor - The tensor to test.\n    /// * dim - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same size as input tensor, except in the dim axis where the size is 1.\n    /// Returns True if any element in the input tensor along the given dimension evaluates to True, False otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly. Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"any_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::any_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;\n\n    /// Tests if all elements in the `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element, True if all elements in the input tensor evaluates to True, False otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly. Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"all\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::all`\")]\n    /// function, which is more high-level and designed for public use.\n    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same size as input `tensor`, except in the `dim` axis where the size is 1.\n    /// Returns True if all elements in the input tensor along the given dimension evaluate to True, False otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly. Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"all_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::all_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive;\n\n    /// Broadcasts the given tensor to the specified shape.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to broadcast.\n    /// * `shape` - The shape to broadcast to.\n    ///\n    /// # Returns\n    ///\n    /// The broadcasted tensor.\n    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive;\n\n    /// Unfold windows along a dimension.\n    ///\n    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// # Warning\n    ///\n    /// For the `ndarray` and `candle` backends; this is not a view but a full copy.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n    /// * `dim` - the dimension to unfold.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with shape ``[pre=..., windows, post=..., size]``.\n    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/bool.rs",
    "content": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, ExecutionError, Scalar, TensorData,\n    element::Element,\n    ops::TransactionPrimitive,\n    tensor::{BasicAutodiffOps, BasicOps, Bool, Device, IndexingUpdateOp, IntTensor, TensorKind},\n};\n\nimpl<B: Backend> BasicOps<B> for Bool {\n    type Elem = B::BoolElem;\n\n    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        if dtype != Self::Elem::dtype() {\n            panic!(\"Expected bool data type, got {dtype:?}\");\n        }\n        B::bool_empty(shape, device)\n    }\n\n    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        if dtype != Self::Elem::dtype() {\n            panic!(\"Expected bool data type, got {dtype:?}\");\n        }\n        B::bool_zeros(shape, device)\n    }\n    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        if dtype != Self::Elem::dtype() {\n            panic!(\"Expected bool data type, got {dtype:?}\");\n        }\n        B::bool_ones(shape, device)\n    }\n\n    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        if dtype != Self::Elem::dtype() {\n            panic!(\"Expected bool data type, got {dtype:?}\");\n        }\n        if fill_value.elem() {\n            B::bool_ones(shape, device)\n        } else {\n            B::bool_zeros(shape, device)\n        }\n    }\n\n    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {\n        tr.register_bool(tensor);\n    }\n\n    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        B::bool_reshape(tensor, shape)\n    }\n\n    fn transpose(tensor: Self::Primitive) -> Self::Primitive {\n        B::bool_transpose(tensor)\n    }\n\n    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {\n        B::bool_swap_dims(tensor, dim1, dim2)\n    }\n\n    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {\n        B::bool_slice(tensor, slices)\n    }\n\n    fn slice_assign(\n        tensor: Self::Primitive,\n        slices: &[Slice],\n        value: Self::Primitive,\n    ) -> Self::Primitive {\n        B::bool_slice_assign(tensor, slices, value)\n    }\n\n    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {\n        B::bool_select(tensor, dim, indices)\n    }\n\n    fn select_assign(\n        tensor: Self::Primitive,\n        dim: usize,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        match update {\n            IndexingUpdateOp::Add => B::bool_select_or(tensor, dim, indices, values),\n        }\n    }\n\n    fn mask_where(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        source: Self::Primitive,\n    ) -> Self::Primitive {\n        B::bool_mask_where(tensor, mask, source)\n    }\n\n    fn mask_fill(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        value: Scalar,\n    ) -> Self::Primitive {\n        B::bool_mask_fill(tensor, mask, value)\n    }\n\n    fn gather(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: B::IntTensorPrimitive,\n    ) -> Self::Primitive {\n        B::bool_gather(dim, tensor, indices)\n    }\n\n    fn scatter(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: B::IntTensorPrimitive,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        match update {\n            IndexingUpdateOp::Add => B::bool_scatter_or(dim, tensor, indices, values),\n        }\n    }\n\n    fn device(tensor: &Self::Primitive) -> Device<B> {\n        B::bool_device(tensor)\n    }\n\n    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {\n        B::bool_to_device(tensor, device)\n    }\n\n    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {\n        B::bool_into_data(tensor).await\n    }\n\n    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {\n        B::bool_from_data(data.convert::<B::BoolElem>(), device)\n    }\n\n    fn from_data_dtype(data: TensorData, device: &Device<B>, _dtype: DType) -> Self::Primitive {\n        // Bool tensors have exactly one representation per backend, so the\n        // requested dtype is irrelevant. Convert to `B::BoolElem` directly.\n        B::bool_from_data(data.convert::<B::BoolElem>(), device)\n    }\n\n    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {\n        B::bool_repeat_dim(tensor, dim, times)\n    }\n\n    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::bool_equal(lhs, rhs)\n    }\n\n    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::bool_not_equal(lhs, rhs)\n    }\n\n    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::bool_equal_elem(lhs, rhs)\n    }\n\n    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::bool_not_equal_elem(lhs, rhs)\n    }\n\n    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {\n        B::bool_cat(vectors, dim)\n    }\n\n    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::bool_any(tensor)\n    }\n\n    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {\n        B::bool_any_dim(tensor, dim)\n    }\n\n    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::bool_all(tensor)\n    }\n\n    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {\n        B::bool_all_dim(tensor, dim)\n    }\n\n    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        B::bool_permute(tensor, axes)\n    }\n\n    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        B::bool_expand(tensor, shape)\n    }\n\n    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        B::bool_flip(tensor, axes)\n    }\n\n    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {\n        B::bool_unfold(tensor, dim, size, step)\n    }\n}\n\nimpl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {\n    type InnerKind = Bool;\n\n    fn inner(\n        tensor: <Self as TensorKind<B>>::Primitive,\n    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {\n        B::bool_inner(tensor)\n    }\n\n    fn from_inner(\n        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,\n    ) -> <Self as TensorKind<B>>::Primitive {\n        B::bool_from_inner(inner)\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/float.rs",
    "content": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData, TensorPrimitive,\n    ops::TransactionPrimitive,\n    tensor::{\n        BasicAutodiffOps, BasicOps, Device, Float, IndexingUpdateOp, IntTensor, Numeric, Ordered,\n        TensorKind,\n    },\n};\n\nmacro_rules! q_bin_ops {\n    ($lhs:ident, $rhs:ident, $op:ident, $q_op:ident) => {\n        match ($lhs, $rhs) {\n            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {\n                TensorPrimitive::Float(B::$op(lhs, rhs))\n            }\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::$q_op(lhs, rhs),\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {\n                TensorPrimitive::Float(B::$op(B::dequantize(lhs), rhs))\n            }\n            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {\n                TensorPrimitive::Float(B::$op(lhs, B::dequantize(rhs)))\n            }\n        }\n    };\n}\n\nimpl<B: Backend> BasicOps<B> for Float {\n    type Elem = B::FloatElem;\n\n    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_empty(shape, device, dtype.into()))\n    }\n\n    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_zeros(shape, device, dtype.into()))\n    }\n    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_ones(shape, device, dtype.into()))\n    }\n\n    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_full(shape, fill_value, device, dtype.into()))\n    }\n\n    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {\n        tr.register_float(tensor);\n    }\n\n    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_reshape(tensor, shape))\n            }\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_reshape(tensor, shape)),\n        }\n    }\n\n    fn transpose(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_transpose(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_transpose(tensor)),\n        }\n    }\n\n    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_swap_dims(tensor, dim1, dim2))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_swap_dims(tensor, dim1, dim2))\n            }\n        }\n    }\n\n    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_slice(tensor, slices))\n            }\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_slice(tensor, slices)),\n        }\n    }\n\n    fn slice_assign(\n        tensor: Self::Primitive,\n        slices: &[Slice],\n        value: Self::Primitive,\n    ) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_slice_assign(\n            tensor.tensor(),\n            slices,\n            value.tensor(),\n        ))\n    }\n\n    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_select(tensor, dim, indices))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_select(tensor, dim, indices))\n            }\n        }\n    }\n\n    fn select_assign(\n        tensor: Self::Primitive,\n        dim: usize,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        // Select assign is ambiguous for QFloat\n        match update {\n            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_select_add(\n                tensor.tensor(),\n                dim,\n                indices,\n                values.tensor(),\n            )),\n        }\n    }\n\n    fn mask_where(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        source: Self::Primitive,\n    ) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor()))\n    }\n\n    fn mask_fill(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        value: Scalar,\n    ) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value))\n    }\n\n    fn gather(dim: usize, tensor: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_gather(dim, tensor, indices))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_gather(dim, tensor, indices))\n            }\n        }\n    }\n\n    fn scatter(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        match update {\n            IndexingUpdateOp::Add => TensorPrimitive::Float(B::float_scatter_add(\n                dim,\n                tensor.tensor(),\n                indices,\n                values.tensor(),\n            )),\n        }\n    }\n\n    fn device(tensor: &Self::Primitive) -> Device<B> {\n        match tensor {\n            TensorPrimitive::Float(tensor) => B::float_device(tensor),\n            TensorPrimitive::QFloat(tensor) => B::q_device(tensor),\n        }\n    }\n\n    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_to_device(tensor, device))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_to_device(tensor, device))\n            }\n        }\n    }\n\n    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {\n        match tensor {\n            TensorPrimitive::Float(tensor) => B::float_into_data(tensor).await,\n            TensorPrimitive::QFloat(tensor) => B::q_into_data(tensor).await,\n        }\n    }\n\n    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {\n        match &data.dtype {\n            DType::QFloat(_scheme) => TensorPrimitive::QFloat(B::q_from_data(data, device)),\n            _ => TensorPrimitive::Float(B::float_from_data(data.convert::<B::FloatElem>(), device)),\n        }\n    }\n\n    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        match dtype {\n            DType::QFloat(_scheme) => {\n                TensorPrimitive::QFloat(B::q_from_data(data.convert_dtype(dtype), device))\n            }\n            _ if dtype.is_float() => {\n                TensorPrimitive::Float(B::float_from_data(data.convert_dtype(dtype), device))\n            }\n            _ => panic!(\"Expected float dtype, got {dtype:?}\"),\n        }\n    }\n\n    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_repeat_dim(tensor, dim, times))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_repeat_dim(tensor, dim, times))\n            }\n        }\n    }\n\n    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {\n        match vectors.first().unwrap() {\n            TensorPrimitive::Float(_) => TensorPrimitive::Float(B::float_cat(\n                vectors.into_iter().map(|tensor| tensor.tensor()).collect(),\n                dim,\n            )),\n            TensorPrimitive::QFloat(_) => TensorPrimitive::QFloat(B::q_cat(\n                vectors\n                    .into_iter()\n                    .map(|tensor| {\n                        if let TensorPrimitive::QFloat(t) = tensor {\n                            t\n                        } else {\n                            panic!(\"Concatenation only works with vector of QFloat\")\n                        }\n                    })\n                    .collect(),\n                dim,\n            )),\n        }\n    }\n\n    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_equal(lhs.tensor(), rhs.tensor())\n    }\n\n    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_not_equal(lhs.tensor(), rhs.tensor())\n    }\n\n    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_equal_elem(lhs.tensor(), rhs)\n    }\n\n    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_not_equal_elem(lhs.tensor(), rhs)\n    }\n\n    fn any(tensor: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_any(tensor.tensor())\n    }\n\n    fn any_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {\n        B::float_any_dim(tensor.tensor(), dim)\n    }\n\n    fn all(tensor: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_all(tensor.tensor())\n    }\n\n    fn all_dim(tensor: Self::Primitive, dim: usize) -> B::BoolTensorPrimitive {\n        B::float_all_dim(tensor.tensor(), dim)\n    }\n\n    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_permute(tensor, axes))\n            }\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_permute(tensor, axes)),\n        }\n    }\n\n    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_expand(tensor.tensor(), shape))\n    }\n\n    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_flip(tensor, axes)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_flip(tensor, axes)),\n        }\n    }\n\n    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_unfold(tensor.tensor(), dim, size, step))\n    }\n}\n\nimpl<B: Backend> Numeric<B> for Float {\n    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        q_bin_ops!(lhs, rhs, float_add, q_add)\n    }\n\n    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        match lhs {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_add_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs),\n        }\n    }\n\n    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        q_bin_ops!(lhs, rhs, float_sub, q_sub)\n    }\n\n    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        match lhs {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs),\n        }\n    }\n\n    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        q_bin_ops!(lhs, rhs, float_div, q_div)\n    }\n\n    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        match lhs {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_div_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs),\n        }\n    }\n    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor()))\n    }\n\n    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs))\n    }\n\n    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        q_bin_ops!(lhs, rhs, float_mul, q_mul)\n    }\n\n    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        match lhs {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs),\n        }\n    }\n    fn neg(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)),\n            TensorPrimitive::QFloat(tensor) => B::q_neg(tensor),\n        }\n    }\n\n    fn sum(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)),\n            TensorPrimitive::QFloat(tensor) => B::q_sum(tensor),\n        }\n    }\n\n    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim),\n        }\n    }\n\n    fn prod(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)),\n            TensorPrimitive::QFloat(tensor) => B::q_prod(tensor),\n        }\n    }\n\n    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_prod_dim(tensor, dim))\n            }\n            TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim),\n        }\n    }\n\n    fn mean(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)),\n            TensorPrimitive::QFloat(tensor) => B::q_mean(tensor),\n        }\n    }\n\n    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_mean_dim(tensor, dim))\n            }\n            TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim),\n        }\n    }\n\n    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumsum(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => B::q_cumsum(tensor, dim),\n        }\n    }\n\n    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cumprod(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => B::q_cumprod(tensor, dim),\n        }\n    }\n\n    fn abs(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_abs(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_abs(tensor)),\n        }\n    }\n\n    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        q_bin_ops!(lhs, rhs, float_powf, q_powf)\n    }\n\n    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        match lhs {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs),\n        }\n    }\n\n    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_random(shape, distribution, device))\n    }\n\n    fn sign(tensor: Self::Primitive) -> Self::Primitive {\n        TensorPrimitive::Float(B::float_sign(tensor.tensor()))\n    }\n\n    /// Applies the matrix multiplication operation.\n    ///\n    /// `C = AB`\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have a compatible shape.\n    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        match (lhs, rhs) {\n            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {\n                TensorPrimitive::Float(B::float_matmul(lhs, rhs))\n            }\n            (lhs, rhs) => B::q_matmul(lhs, rhs),\n        }\n    }\n}\nimpl<B: Backend> Ordered<B> for Float {\n    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_sort(tensor, dim, descending))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_sort(tensor, dim, descending))\n            }\n        }\n    }\n\n    fn sort_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n        descending: bool,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                let (values, indices) = B::float_sort_with_indices(tensor, dim, descending);\n                (TensorPrimitive::Float(values), indices)\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                let (values, indices) = B::q_sort_with_indices(tensor, dim, descending);\n                (TensorPrimitive::QFloat(values), indices)\n            }\n        }\n    }\n\n    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {\n        match tensor {\n            TensorPrimitive::Float(tensor) => B::float_argsort(tensor, dim, descending),\n            TensorPrimitive::QFloat(tensor) => B::q_argsort(tensor, dim, descending),\n        }\n    }\n\n    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),\n        }\n    }\n\n    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummax(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => B::q_cummax(tensor, dim),\n        }\n    }\n\n    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_greater(lhs.tensor(), rhs.tensor())\n    }\n\n    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_greater_elem(lhs.tensor(), rhs)\n    }\n\n    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_greater_equal(lhs.tensor(), rhs.tensor())\n    }\n\n    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_greater_equal_elem(lhs.tensor(), rhs)\n    }\n\n    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_lower(lhs.tensor(), rhs.tensor())\n    }\n\n    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_lower_elem(lhs.tensor(), rhs)\n    }\n\n    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::float_lower_equal(lhs.tensor(), rhs.tensor())\n    }\n\n    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::float_lower_equal_elem(lhs.tensor(), rhs)\n    }\n\n    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {\n        match tensor {\n            TensorPrimitive::Float(tensor) => B::float_argmax(tensor, dim),\n            TensorPrimitive::QFloat(tensor) => B::q_argmax(tensor, dim),\n        }\n    }\n\n    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {\n        match tensor {\n            TensorPrimitive::Float(tensor) => B::float_argmin(tensor, dim),\n            TensorPrimitive::QFloat(tensor) => B::q_argmin(tensor, dim),\n        }\n    }\n\n    fn max(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max(tensor)),\n        }\n    }\n\n    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_dim(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_dim(tensor, dim)),\n        }\n    }\n\n    fn max_dim_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                let (values, indices) = B::float_max_dim_with_indices(tensor, dim);\n                (TensorPrimitive::Float(values), indices)\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                let (values, indices) = B::q_max_dim_with_indices(tensor, dim);\n                (TensorPrimitive::QFloat(values), indices)\n            }\n        }\n    }\n\n    fn min(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min(tensor)),\n        }\n    }\n\n    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_min_dim(tensor, dim)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_min_dim(tensor, dim)),\n        }\n    }\n\n    fn min_dim_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                let (values, indices) = B::float_min_dim_with_indices(tensor, dim);\n                (TensorPrimitive::Float(values), indices)\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                let (values, indices) = B::q_min_dim_with_indices(tensor, dim);\n                (TensorPrimitive::QFloat(values), indices)\n            }\n        }\n    }\n\n    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_clamp(tensor, min, max))\n            }\n            TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max),\n        }\n    }\n\n    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_clamp_min(tensor, min))\n            }\n            TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min),\n        }\n    }\n\n    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_clamp_max(tensor, max))\n            }\n            TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max),\n        }\n    }\n\n    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_max_abs(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_max_abs(tensor)),\n        }\n    }\n\n    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_max_abs_dim(tensor, dim))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_max_abs_dim(tensor, dim))\n            }\n        }\n    }\n}\n\nimpl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {\n    type InnerKind = Float;\n\n    fn inner(\n        tensor: <Self as TensorKind<B>>::Primitive,\n    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),\n        }\n    }\n\n    fn from_inner(\n        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,\n    ) -> <Self as TensorKind<B>>::Primitive {\n        match inner {\n            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),\n            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/int.rs",
    "content": "use alloc::vec::Vec;\nuse burn_std::{DType, Shape, Slice};\n\nuse crate::{\n    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,\n    ops::TransactionPrimitive,\n    tensor::{\n        BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,\n        Ordered, TensorKind,\n    },\n};\n\nimpl<B: Backend> BasicOps<B> for Int {\n    type Elem = B::IntElem;\n\n    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        B::int_empty(shape, device, dtype.into())\n    }\n\n    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        B::int_zeros(shape, device, dtype.into())\n    }\n    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        B::int_ones(shape, device, dtype.into())\n    }\n\n    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        B::int_full(shape, fill_value, device, dtype.into())\n    }\n\n    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {\n        tr.register_int(tensor);\n    }\n\n    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        B::int_reshape(tensor, shape)\n    }\n\n    fn transpose(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_transpose(tensor)\n    }\n\n    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {\n        B::int_swap_dims(tensor, dim1, dim2)\n    }\n\n    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {\n        B::int_slice(tensor, slices)\n    }\n\n    fn slice_assign(\n        tensor: Self::Primitive,\n        slices: &[Slice],\n        value: Self::Primitive,\n    ) -> Self::Primitive {\n        B::int_slice_assign(tensor, slices, value)\n    }\n\n    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {\n        B::int_select(tensor, dim, indices)\n    }\n\n    fn select_assign(\n        tensor: Self::Primitive,\n        dim: usize,\n        indices: IntTensor<B>,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        match update {\n            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),\n        }\n    }\n\n    fn mask_where(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        source: Self::Primitive,\n    ) -> Self::Primitive {\n        B::int_mask_where(tensor, mask, source)\n    }\n\n    fn mask_fill(\n        tensor: Self::Primitive,\n        mask: B::BoolTensorPrimitive,\n        value: Scalar,\n    ) -> Self::Primitive {\n        B::int_mask_fill(tensor, mask, value)\n    }\n\n    fn gather(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: B::IntTensorPrimitive,\n    ) -> Self::Primitive {\n        B::int_gather(dim, tensor, indices)\n    }\n\n    fn scatter(\n        dim: usize,\n        tensor: Self::Primitive,\n        indices: B::IntTensorPrimitive,\n        values: Self::Primitive,\n        update: IndexingUpdateOp,\n    ) -> Self::Primitive {\n        match update {\n            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),\n        }\n    }\n\n    fn device(tensor: &Self::Primitive) -> Device<B> {\n        B::int_device(tensor)\n    }\n\n    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {\n        B::int_to_device(tensor, device)\n    }\n\n    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {\n        B::int_into_data(tensor).await\n    }\n\n    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {\n        B::int_from_data(data.convert::<B::IntElem>(), device)\n    }\n\n    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {\n        if !dtype.is_int() {\n            panic!(\"Expected int dtype, got {dtype:?}\")\n        }\n\n        B::int_from_data(data.convert_dtype(dtype), device)\n    }\n\n    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {\n        B::int_repeat_dim(tensor, dim, times)\n    }\n\n    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {\n        B::int_equal(lhs, rhs)\n    }\n\n    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {\n        B::int_not_equal(lhs, rhs)\n    }\n\n    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_equal_elem(lhs, rhs)\n    }\n\n    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_not_equal_elem(lhs, rhs)\n    }\n\n    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {\n        B::int_cat(vectors, dim)\n    }\n\n    fn any(tensor: Self::Primitive) -> BoolTensor<B> {\n        B::int_any(tensor)\n    }\n\n    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {\n        B::int_any_dim(tensor, dim)\n    }\n\n    fn all(tensor: Self::Primitive) -> BoolTensor<B> {\n        B::int_all(tensor)\n    }\n\n    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {\n        B::int_all_dim(tensor, dim)\n    }\n\n    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        B::int_permute(tensor, axes)\n    }\n\n    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {\n        B::int_expand(tensor, shape)\n    }\n\n    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {\n        B::int_flip(tensor, axes)\n    }\n\n    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {\n        B::int_unfold(tensor, dim, size, step)\n    }\n}\n\nimpl<B: Backend> Numeric<B> for Int {\n    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_add(lhs, rhs)\n    }\n    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_add_scalar(lhs, rhs)\n    }\n    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_sub(lhs, rhs)\n    }\n    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_sub_scalar(lhs, rhs)\n    }\n    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_div(lhs, rhs)\n    }\n    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_div_scalar(lhs, rhs)\n    }\n    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_remainder(lhs, rhs)\n    }\n    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_remainder_scalar(lhs, rhs)\n    }\n    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_mul(lhs, rhs)\n    }\n    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_mul_scalar(lhs, rhs)\n    }\n    fn neg(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_neg(tensor)\n    }\n\n    fn sum(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_sum(tensor)\n    }\n\n    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_sum_dim(tensor, dim)\n    }\n\n    fn prod(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_prod(tensor)\n    }\n\n    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_prod_dim(tensor, dim)\n    }\n\n    fn mean(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_mean(tensor)\n    }\n    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_mean_dim(tensor, dim)\n    }\n    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_cumsum(tensor, dim)\n    }\n    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_cumprod(tensor, dim)\n    }\n\n    fn abs(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_abs(tensor)\n    }\n\n    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_powi(lhs, rhs)\n    }\n\n    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {\n        B::int_powi_scalar(lhs, rhs)\n    }\n\n    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {\n        B::int_random(shape, distribution, device)\n    }\n\n    fn sign(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_sign(tensor)\n    }\n\n    /// Applies the matrix multiplication operation.\n    ///\n    /// `C = AB`\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have a compatible shape.\n    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {\n        B::int_matmul(lhs, rhs)\n    }\n}\n\nimpl<B: Backend> Ordered<B> for Int {\n    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {\n        B::int_sort(tensor, dim, descending)\n    }\n\n    fn sort_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n        descending: bool,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        B::int_sort_with_indices(tensor, dim, descending)\n    }\n\n    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {\n        B::int_argsort(tensor, dim, descending)\n    }\n\n    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_cummin(tensor, dim)\n    }\n\n    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_cummax(tensor, dim)\n    }\n\n    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::int_greater(lhs, rhs)\n    }\n\n    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_greater_elem(lhs, rhs)\n    }\n\n    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::int_greater_equal(lhs, rhs)\n    }\n\n    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_greater_equal_elem(lhs, rhs)\n    }\n\n    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::int_lower(lhs, rhs)\n    }\n\n    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_lower_elem(lhs, rhs)\n    }\n\n    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {\n        B::int_lower_equal(lhs, rhs)\n    }\n\n    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {\n        B::int_lower_equal_elem(lhs, rhs)\n    }\n\n    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {\n        B::int_argmax(tensor, dim)\n    }\n\n    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {\n        B::int_argmin(tensor, dim)\n    }\n\n    fn max(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_max(tensor)\n    }\n\n    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_max_dim(tensor, dim)\n    }\n\n    fn max_dim_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        B::int_max_dim_with_indices(tensor, dim)\n    }\n\n    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_max_abs(tensor)\n    }\n\n    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_max_abs_dim(tensor, dim)\n    }\n\n    fn min(tensor: Self::Primitive) -> Self::Primitive {\n        B::int_min(tensor)\n    }\n\n    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {\n        B::int_min_dim(tensor, dim)\n    }\n\n    fn min_dim_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n    ) -> (Self::Primitive, IntTensor<B>) {\n        B::int_min_dim_with_indices(tensor, dim)\n    }\n\n    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {\n        B::int_clamp(tensor, min, max)\n    }\n\n    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {\n        B::int_clamp_min(tensor, min)\n    }\n\n    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {\n        B::int_clamp_max(tensor, max)\n    }\n}\n\nimpl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {\n    type InnerKind = Int;\n\n    fn inner(\n        tensor: <Self as TensorKind<B>>::Primitive,\n    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {\n        B::int_inner(tensor)\n    }\n\n    fn from_inner(\n        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,\n    ) -> <Self as TensorKind<B>>::Primitive {\n        B::int_from_inner(inner)\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/mod.rs",
    "content": "mod autodiff;\nmod base;\nmod bool;\nmod float;\nmod int;\nmod numeric;\nmod ordered;\n\npub use autodiff::*;\npub use base::*;\npub use numeric::*;\npub use ordered::*;\n\n/// Computation to be used to update the existing values in indexed assignment operations (scatter/select).\n#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\npub enum IndexingUpdateOp {\n    // Assign,\n    /// Performs an addition.\n    Add,\n    // Mul\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/numeric.rs",
    "content": "use burn_std::Shape;\n\nuse crate::{Backend, Distribution, Scalar, element::Element, tensor::BasicOps};\n\n/// Trait that list all operations that can be applied on all numerical tensors.\n///\n/// # Warnings\n///\n/// This is an internal trait, use the public API provided by the\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// struct.\npub trait Numeric<B: Backend>: BasicOps<B>\nwhere\n    Self::Elem: Element,\n{\n    /// Adds two tensors together.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The sum of the two tensors.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For adding tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"add\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::add`\")]\n    /// function, which is more high-level and designed for public use.\n    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Adds a scalar to a tensor element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The sum of the tensor and the scalar.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For adding a scalar to a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"add_scalar\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::add_scalar`\")]\n    /// function, which is more high-level and designed for public use.\n    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Subtracts two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The difference of the two tensors.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For subtracting tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sub\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sub`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Subtracts a scalar from a tensor element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The difference of the tensor and the scalar.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For subtracting a scalar from a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sub_scalar\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sub_scalar`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Divides two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The quotient of the two tensors.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For dividing tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"div\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::div`\")]\n    /// function, which is more high-level and designed for public use.\n    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Divides a tensor by a scalar element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The quotient of the tensor and the scalar.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For dividing a tensor by a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"div_scalar\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::div_scalar`\")]\n    /// function, which is more high-level and designed for public use.\n    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is\n    /// less than that of the divisor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The dividend.\n    /// * `rhs` - The divisor.\n    ///\n    /// # Returns\n    ///\n    /// The modulo of the input tensor with the divisor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For performing the modulo operation, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"remainder\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::remainder`\")]\n    /// function, which is more high-level and designed for public use.\n    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Computes the modulo element-wise. The result is the *signed* remainder of the division and its absolute value is\n    /// less than that of the divisor.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The dividend.\n    /// * `rhs` - The divisor.\n    ///\n    /// # Returns\n    ///\n    /// The modulo of the input tensor with the divisor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For performing the modulo operation, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"remainder_scalar\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::remainder_scalar`\")]\n    /// function, which is more high-level and designed for public use.\n    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Multiplies two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// The product of the two tensors.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For multiplying tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mul\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mul`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Multiplies a tensor by a scalar element-wise.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// The product of the tensor and the scalar.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For multiplying a tensor by a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mul_scalar\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mul_scalar`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Negates a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to negate.\n    ///\n    /// # Returns\n    ///\n    /// The negated tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For negating a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"neg\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::neg`\")]\n    /// function, which is more high-level and designed for public use.\n    fn neg(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Returns the signs of the elements of a tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor.\n    ///\n    /// # Returns\n    ///\n    /// The signs of the elements of the tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the signs of the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sign\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sign`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sign(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Sums all the elements of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    ///\n    /// # Returns\n    ///\n    /// The sum of all the elements of the tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For summing all the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sum\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sum`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sum(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Sums all the elements of the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to sum.\n    /// * `dim` - The dimension along which to sum.\n    ///\n    /// # Returns\n    ///\n    /// The sum of all the elements of the tensor along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For summing all the elements of a tensor along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sum_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sum_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Computes the product of all the elements of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the product of.\n    ///\n    /// # Returns\n    ///\n    /// The product of all the elements of the tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the product of all the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"prod\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::prod`\")]\n    /// function, which is more high-level and designed for public use.\n    fn prod(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Computes the product of all the elements of the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the product of.\n    /// * `dim` - The dimension along which to compute the product.\n    ///\n    /// # Returns\n    ///\n    /// The product of all the elements of the tensor along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the product of all the elements of a tensor along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"prod_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::prod_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Computes the mean of all the elements of the tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the mean of.\n    ///\n    /// # Returns\n    ///\n    /// The mean of all the elements of the tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the mean of all the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mean\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mean`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mean(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Computes the mean of all the elements of the tensor along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the mean of.\n    /// * `dim` - The dimension along which to compute the mean.\n    ///\n    /// # Returns\n    ///\n    /// The mean of all the elements of the tensor along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the mean of all the elements of a tensor along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"mean_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::mean_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Computes the cumulative sum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative sum of.\n    /// * `dim` - The dimension along which to compute the cumulative sum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the cumulative sum\n    /// of all elements up to and including that position along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the cumulative sum of elements along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"cumsum\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::cumsum`\")]\n    /// function, which is more high-level and designed for public use.\n    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Computes the cumulative product of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative product of.\n    /// * `dim` - The dimension along which to compute the cumulative product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the cumulative product\n    /// of all elements up to and including that position along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the cumulative product of elements along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"cumprod\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::cumprod`\")]\n    /// function, which is more high-level and designed for public use.\n    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Calculate absolute value on all elements of a tensor\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to apply abs to.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with absolute values.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For calculating abs of the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"abs\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::abs`\")]\n    /// function, which is more high-level and designed for public use.\n    fn abs(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Element-wise power of a tensor\n    ///\n    /// # Arguments\n    /// * `tensor` - The tensor to apply power to.\n    /// * `power` - The power to apply to the tensor.\n    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n\n    /// Element-wise power of a tensor to a scalar int\n    ///\n    /// # Arguments\n    /// * `tensor` - The tensor to apply power to.\n    /// * `power` - The power to apply to the tensor.\n    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive;\n\n    /// Create a random tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the output tensor.\n    /// * `distribution` - The distribution used to sample.\n    /// * `device` - The device to use.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"random\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::random`\")]\n    /// function, which is more high-level and designed for public use.\n    fn random(shape: Shape, distribution: Distribution, device: &B::Device) -> Self::Primitive;\n\n    /// Applies the matrix multiplication operation.\n    ///\n    /// ```math\n    /// C = AB\n    /// ```\n    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/ops/ordered.rs",
    "content": "use crate::{\n    Backend, Scalar,\n    tensor::{IntTensor, Numeric},\n};\n\n/// Trait that list all operations that can be applied on all numerical tensors\n/// whose elements have a well-defined ordering.\n///\n/// This includes operations such as comparisons, minimum/maximum reductions,\n/// and other order-dependent computations that are not strictly valid for all numerical\n/// types.\n///\n/// # Warnings\n///\n/// This is an internal trait, use the public API provided by the\n#[cfg_attr(doc, doc = crate::doc_tensor!())]\n#[cfg_attr(not(doc), doc = \"`Tensor`\")]\n/// struct.\npub trait Ordered<B: Backend>: Numeric<B> {\n    /// Sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where the elements are sorted by value.\n    ///\n    /// # Remarks\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sort\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sort`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive;\n\n    /// Sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor and corresponding indices, where\n    /// the elements are sorted by value and the indices map back to the original input tensor.\n    ///\n    /// # Remarks\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For sorting the elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"sort_with_indices\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::sort_with_indices`\")]\n    /// function, which is more high-level and designed for public use.\n    fn sort_with_indices(\n        tensor: Self::Primitive,\n        dim: usize,\n        descending: bool,\n    ) -> (Self::Primitive, IntTensor<B>);\n\n    /// Returns the indices that sort the elements of the input `tensor` by value along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor.\n    /// * `dim` - The axis along which to sort.\n    /// * `descending` - The sorting order.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor the indices map back to the original input tensor.\n    ///\n    /// # Remarks\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// Users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"argsort\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::argsort`\")]\n    /// function, which is more high-level and designed for public use.\n    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B>;\n\n    /// Computes the cumulative minimum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative minimum of.\n    /// * `dim` - The dimension along which to compute the cumulative minimum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the minimum\n    /// of all elements up to and including that position along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the cumulative minimum of elements along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"cummin\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::cummin`\")]\n    /// function, which is more high-level and designed for public use.\n    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Computes the cumulative maximum of elements along a dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to compute the cumulative maximum of.\n    /// * `dim` - The dimension along which to compute the cumulative maximum.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the maximum\n    /// of all elements up to and including that position along the specified dimension.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For computing the cumulative maximum of elements along a dimension, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"cummax\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::cummax`\")]\n    /// function, which is more high-level and designed for public use.\n    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Element-wise greater than comparison between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding element of the left hand side tensor is greater than the corresponding element\n    /// of the right hand side tensor, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise greater than comparison between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"greater\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::greater`\")]\n    /// function, which is more high-level and designed for public use.\n    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise greater than comparison between a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensor, where each element is true if the\n    /// corresponding element of the left hand side tensor is greater than the right hand side\n    /// scalar, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise greater than comparison between a tensor and a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"greater_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::greater_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Element-wise greater than or equal comparison between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding element of the left hand side tensor is greater than or equal to the\n    /// corresponding element of the right hand side tensor, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise greater than or equal comparison between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"greater_equal\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::greater_equal`\")]\n    /// function, which is more high-level and designed for public use.\n    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise greater than or equal comparison between a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensor, where each element is true if the\n    /// corresponding element of the left hand side tensor is greater than or equal to the right\n    /// hand side scalar, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"greater_equal_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::greater_equal_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Element-wise less than comparison between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding element of the left hand side tensor is less than the corresponding element of\n    /// the right hand side tensor, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise less than comparison between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"lower\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::lower`\")]\n    /// function, which is more high-level and designed for public use.\n    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise less than comparison between a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensor, where each element is true if the\n    /// corresponding element of the left hand side tensor is less than the right hand side scalar,\n    /// and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise less than comparison between a tensor and a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"lower_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::lower_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Element-wise less than or equal comparison between two tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side tensor.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors, where each element is true if the\n    /// corresponding element of the left hand side tensor is less than or equal to the corresponding\n    /// element of the right hand side tensor, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise less than or equal comparison between two tensors, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"lower_equal\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::lower_equal`\")]\n    /// function, which is more high-level and designed for public use.\n    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive;\n\n    /// Element-wise less than or equal comparison between a tensor and a scalar.\n    ///\n    /// # Arguments\n    ///\n    /// * `lhs` - The left hand side tensor.\n    /// * `rhs` - The right hand side scalar.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensor, where each element is true if the\n    /// corresponding element of the left hand side tensor is less than or equal to the right hand\n    /// side scalar, and false otherwise.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"lower_equal_elem\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::lower_equal_elem`\")]\n    /// function, which is more high-level and designed for public use.\n    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive;\n\n    /// Gets the indices of the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to get the indices of the maximum elements.\n    /// * `tensor` - The tensor to get the indices of the maximum elements from.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the index of the\n    /// maximum element of the input tensor at the corresponding index along the specified axis.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"argmax\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::argmax`\")]\n    /// function, which is more high-level and designed for public use.\n    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;\n\n    /// Gets the indices of the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to get the indices of the minimum elements.\n    /// * `tensor` - The tensor to get the indices of the minimum elements from.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor, where each element is the index of the\n    /// minimum element of the input tensor at the corresponding index along the specified axis.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"argmin\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::argmin`\")]\n    /// function, which is more high-level and designed for public use.\n    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B>;\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A single-element tensor containing the maximum element of the input tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the maximum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"max\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::max`\")]\n    /// function, which is more high-level and designed for public use.\n    fn max(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements from.\n    /// * `dim` - The axis along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.\n    /// Each element is the maximum element of the corresponding input dim.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the maximum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"max_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::max_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements from.\n    /// * `dim` - The axis along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tuple containing the maximum element of the input tensor, and a tensor with the same shape\n    /// as the input tensor, where each element is the index of the maximum element of the input tensor\n    /// at the corresponding index along the specified axis.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the maximum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"max_dim_with_indices\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::max_dim_with_indices`\")]\n    /// function, which is more high-level and designed for public use.\n    fn max_dim_with_indices(tensor: Self::Primitive, dim: usize)\n    -> (Self::Primitive, IntTensor<B>);\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The axis along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A single-element tensor containing the maximum absolute element of the input tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the maximum absolute elements of a tensor, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"max_abs\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::max_abs`\")]\n    /// function, which is more high-level and designed for public use.\n    fn max_abs(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Gets the maximum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the maximum elements from.\n    /// * `dim` - The axis along which to get the maximum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.\n    /// Each element is the maximum absolute element of the corresponding input dim.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the maximum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"max_abs_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::max_abs_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Gets the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements from.\n    ///\n    /// # Returns\n    ///\n    /// A single-element tensor containing the minimum element of the input tensor.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the minimum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"min\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::min`\")]\n    /// function, which is more high-level and designed for public use.\n    fn min(tensor: Self::Primitive) -> Self::Primitive;\n\n    /// Gets the minimum elements of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements from.\n    /// * `dim` - The axis along which to get the minimum elements.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same rank as the input tensor, but the given dim set to a shape of 1.\n    /// Each element is the minimum element of the corresponding input dim.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the minimum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"min_dim\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::min_dim`\")]\n    /// function, which is more high-level and designed for public use.\n    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive;\n\n    /// Gets the minimum elements and indices of a tensor along an axis.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to get the minimum elements from.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensor and corresponding indices, where\n    /// each element is the minimum element of the input tensor at the corresponding index\n    /// along the specified axis.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import\n    /// or use this function directly.\n    ///\n    /// For getting the minimum elements of a tensor along an axis, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"min_dim_with_indices\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::min_dim_with_indices`\")]\n    /// function, which is more high-level and designed for public use.\n    fn min_dim_with_indices(tensor: Self::Primitive, dim: usize)\n    -> (Self::Primitive, IntTensor<B>);\n\n    /// Clamp the tensor between the given min and max values.\n    ///\n    /// # Arguments\n    ///\n    /// * `min` - The minimum value.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped between the given min and max values.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users.\n    ///\n    /// For clamping a tensor between the given min and max values, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"clamp\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::clamp`\")]\n    /// function, which is more high-level and designed for public use.\n    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive;\n\n    /// Clamps a tensor under a minimum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped under the given min value.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users.\n    ///\n    /// For clamping a tensor under a minimum value, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"clamp_min\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::clamp_min`\")]\n    /// function, which is more high-level and designed for public use.\n    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive;\n\n    /// Clamps a tensor over a maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped over the given max value.\n    ///\n    /// # Remarks\n    ///\n    /// This is a low-level function used internally by the library to call different backend functions\n    /// with static dispatch. It is not designed for direct usage by users.\n    ///\n    /// For clamping a tensor over a maximum value, users should prefer the\n    #[cfg_attr(doc, doc = crate::doc_tensor!(\"clamp_max\"))]\n    #[cfg_attr(not(doc), doc = \"`Tensor::clamp_max`\")]\n    /// function, which is more high-level and designed for public use.\n    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive;\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/calibration.rs",
    "content": "/// Calibration method used to compute the quantization range mapping.\npub enum Calibration {\n    /// Computes quantization range mapping based on the min and max values.\n    MinMax,\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/mod.rs",
    "content": "mod calibration;\nmod parameters;\nmod scheme;\n\npub use calibration::*;\npub use parameters::*;\npub use scheme::*;\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/parameters.rs",
    "content": "use crate::Backend;\n\npub use burn_std::quantization::{QParamTensor, QParams};\n\n/// The quantization parameters primitive.\n///\n/// # Remarks\n///\n/// This is a low-level struct used internally by the library to provide the quantization parameters\n/// to the backends. It is not designed for direct usage by users, and not recommended to import\n/// or use this struct directly.\npub struct QuantizationParametersPrimitive<B: Backend> {\n    /// The scaling factor.\n    pub scales: B::FloatTensorPrimitive,\n}\n"
  },
  {
    "path": "crates/burn-backend/src/tensor/quantization/scheme.rs",
    "content": "pub use burn_std::{QPARAM_ALIGN, params_shape};\nuse burn_std::{QuantLevel, QuantMode, QuantScheme, Shape};\n\nuse super::{Calibration, QuantizationParametersPrimitive};\nuse crate::{Backend, TensorMetadata};\n\n/// Compute the quantization range mapping.\npub fn compute_range<B: Backend>(\n    scheme: &QuantScheme,\n    tensor: B::FloatTensorPrimitive,\n    calibration: &Calibration,\n) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) {\n    match calibration {\n        Calibration::MinMax => match scheme.level {\n            QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)),\n            QuantLevel::Block(block_size) => {\n                let block_elems = block_size.num_elements();\n                let shape = tensor.shape();\n                let numel = shape.num_elements();\n\n                assert_eq!(\n                    numel % block_elems,\n                    0,\n                    \"Tensor {shape:?} must be evenly divisible by block size {block_elems}\"\n                );\n\n                let num_blocks = numel / block_elems;\n\n                let params_shape = params_shape(&shape, scheme.level);\n\n                let blocks = B::float_reshape(tensor, Shape::new([num_blocks, block_elems]));\n                let blocks_min =\n                    B::float_reshape(B::float_min_dim(blocks.clone(), 1), params_shape.clone());\n                let blocks_max = B::float_reshape(B::float_max_dim(blocks, 1), params_shape);\n                (blocks_min, blocks_max)\n            }\n        },\n    }\n}\n\n/// Compute the quantization parameters.\npub fn compute_q_params<B: Backend>(\n    scheme: &QuantScheme,\n    min: B::FloatTensorPrimitive,\n    max: B::FloatTensorPrimitive,\n) -> QuantizationParametersPrimitive<B> {\n    match scheme {\n        QuantScheme {\n            level: QuantLevel::Tensor | QuantLevel::Block(_),\n            mode: QuantMode::Symmetric,\n            ..\n        } => {\n            // Quantized range `[a, b]`\n            let (a, b) = scheme.value.range();\n\n            // Compute scale to convert an input value in range `[-alpha, alpha]`\n            let min_abs = B::float_abs(min);\n            let max_abs = B::float_abs(max);\n\n            // `min_abs.max_pair(max_abs)`\n            let mask = B::float_lower(min_abs.clone(), max_abs.clone());\n            let values_range =\n                B::float_mul_scalar(B::float_mask_where(min_abs, mask, max_abs), 2f32.into());\n\n            QuantizationParametersPrimitive {\n                scales: B::float_div_scalar(values_range, (b - a).into()),\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/.cargo/config.toml",
    "content": "[alias]\ntest-cpu = \"test --release --no-default-features --features cpu,std\"\ntest-cuda = \"test --release --no-default-features --features cuda,std\"\ntest-ndarray = \"test --release --no-default-features --features ndarray,std\"\ntest-rocm = \"test --release --no-default-features --features rocm,std\"\ntest-router = \"test --release --no-default-features --features router,std\"\ntest-tch = \"test --release --no-default-features --features tch,std\"\ntest-wgpu = \"test --release --no-default-features --features wgpu,std\"\ntest-vulkan = \"test --release --no-default-features --features vulkan,std\"\ntest-metal = \"test --release --no-default-features --features metal,std\"\n"
  },
  {
    "path": "crates/burn-backend-tests/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Tensor tests for Burn backends\"\ndocumentation = \"https://docs.rs/burn-backend-tests\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-backend-tests\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-backend-tests\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"burn-tensor/default\",\n    \"burn-autodiff/default\",\n    # Backends (default not enabled for CubeCL backends as it activates fusion)\n    \"burn-cpu?/default\",\n    \"burn-ndarray?/default\",\n    \"burn-tch?/default\",\n    # Default\n    \"ndarray\",\n    \"std\",\n]\nstd = [\n    \"burn-tensor/std\",\n    \"burn-autodiff/std\",\n    # Backends\n    \"burn-cpu?/std\",\n    \"burn-ndarray?/std\",\n    \"burn-wgpu?/std\",\n    \"burn-router?/std\",\n    \"burn-cuda?/std\",\n    \"burn-rocm?/std\",\n]\n\ntracing = [\n    \"cubecl?/tracing\",\n    \"burn-tensor/tracing\",\n    \"burn-autodiff/tracing\",\n    # Backends\n    \"burn-cpu?/tracing\",\n    \"burn-ndarray?/tracing\",\n    \"burn-wgpu?/tracing\",\n    \"burn-router?/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-rocm?/tracing\",\n]\n\n# Backends\ncuda = [\"burn-cuda\", \"quantization\", \"cube\"]\nrocm = [\"burn-rocm\", \"quantization\", \"cube\"]\nndarray = [\"burn-ndarray\", \"quantization\"]\ntch = [\"burn-tch\"]\nvulkan = [\"wgpu\", \"burn-wgpu/vulkan\"]\nwebgpu = [\"wgpu\", \"burn-wgpu/webgpu\"]\nmetal = [\"wgpu\", \"burn-wgpu/metal\"]\nwgpu = [\"burn-wgpu\", \"quantization\", \"cube\"]\ncpu = [\"burn-cpu\", \"cube\"]\nrouter = [\"burn-router\", \"ndarray\", \"burn-wgpu\"]\n\nautotune = [\n    \"burn-wgpu?/autotune\",\n    \"burn-cuda?/autotune\",\n    \"burn-rocm?/autotune\",\n    \"burn-cpu?/autotune\",\n]\nautotune-checks = [\n    \"burn-wgpu?/autotune-checks\",\n    \"burn-cuda?/autotune-checks\",\n    \"burn-rocm?/autotune-checks\",\n    \"burn-cpu?/autotune-checks\",\n]\n\n# CubeCL backends\ncube = [\n    \"cubecl\",\n    \"cubek\",\n    \"autotune\",\n    \"burn-fusion\",\n    \"burn-cubecl\",\n    \"burn-ndarray\",\n]\n\n# Test configs\nquantization = []\n\n[dependencies]\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-tensor-testgen = { path = \"../burn-tensor-testgen\", version = \"=0.21.0-pre.2\" }\n\n# Backends\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"export_tests\",\n] }\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-cpu = { path = \"../burn-cpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", optional = true, default-features = false, features = [\n    \"export_tests\",\n] }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\n# To wrap `Fusion<CubeBackend>\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", optional = true, features = [\n    \"fusion\",\n] }\n\nnum-traits = { workspace = true }\nserial_test = { workspace = true }\n\ncubecl = { workspace = true, optional = true }\ncubek = { workspace = true, features = [\"random\"], optional = true }\n"
  },
  {
    "path": "crates/burn-backend-tests/README.md",
    "content": "# Burn Backend Tests\n\nThis crate provides a comprehensive suite of tests for Burn backends, covering:\n\n- Tensor operations: [tests/tensor/](./tests/tensor/)\n- Autodiff: [tests/autodiff/](./tests/autodiff/)\n- (Optional) CubeCL kernels correctness: [tests/cubecl/](./tests/cubecl/)\n\n## Running Tests\n\nThe `TestBackend` is selected via feature flags. Use the provided shorthand commands for\nconvenience:\n\n```sh\n# Cpu\ncargo test-cpu\n# Cuda\ncargo test-cuda\n# Rocm\ncargo test-rocm\n# Wgpu / WebGpu\ncargo test-wgpu\n# Vulkan\ncargo test-vulkan\n# Metal\ncargo test-metal\n# Router\ncargo test-router\n\n# NdArray\ncargo test-ndarray\n# LibTorch\ncargo test-tch\n```\n\nBy default, `cargo test` fail-fast across integration test binaries. When one integration test\nbinary fails, Cargo does not run the remaining test binaries. If you want to run all test binaries\nregardless of failures, pass `--no-fail-fast`, for example:\n\n```sh\ncargo test-cuda --no-fail-fast\n```\n\n## Structure\n\n- `tests/tensor.rs`: Tensor tests\n- `tests/autodiff.rs`: Autodiff tests\n- `tests/fusion.rs`: Fusion backend tests wrapping tensor and autodiff tests\n- `tests/cubecl.rs`: CubeCL kernel tests\n\nEach test module assumes exactly one `FloatElemType`, `IntElemType`, and `TestBackend` in scope.\n\n### Common Modules\n\n- `common/backend.rs`: Backend type definitions\n- `common/tensor.rs`: Reusable tensor test suite, split across float, int and bool tensor kinds\n- `common/autodiff.rs`: Reusable autodiff test suite, with and without checkpointing\n\n### Test Reusability\n\nThis crate uses a pattern of parameterized test modules to run the same tests with different\nconfigurations (backends, dtypes, etc.):\n\n1. **Type aliases define the configuration**: Each test scope declares `FloatElemType`,\n   `IntElemType`, and `TestBackend`\n1. **`#[path = \"...\"]` references shared modules**: Points to test files outside the normal module\n   hierarchy, e.g. `\"common/tensor.rs\"`\n1. **`include!()` imports test code**: Test modules are included multiple times with different type\n   configurations\n1. **`use super::*;`** propagates types down the module tree: Each level re-exports parent types so\n   deeply nested tests have access to the configured types\n\nFor example, `common/tensor.rs` can be included with `FloatElemType = f32` for base tests, then\nincluded again with `FloatElemType = f16` for half-precision tests, running the same test suite\ntwice with different dtypes.\n\n## Adding New Tests\n\nAdd test modules under `tests/tensor/`, `tests/autodiff/`, or `tests/cubecl` respectively. They will\nautomatically run for all required configurations.\n\nFor tensor tests, make sure to add the test to each relevant tensor kind:\n\n- `tensor/bool`: boolean tensor tests\n- `tensor/float`: float tensor tests\n- `tensor/int`: integer tensor tests\n\n**Guidelines:**\n\nImport types with `use super::*;` at the top of each module and use the types defined in\n`common/backend.rs`:\n\n```rust\n/// Collection of types used across tests\npub use burn_autodiff::Autodiff;\npub use burn_tensor::Tensor;\npub type TestBackend = ...;\n\npub type TestTensor<const D: usize> = Tensor<TestBackend, D>;\npub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;\npub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;\n\npub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;\npub type IntElem = burn_tensor::ops::IntElem<TestBackend>;\n\npub type TestAutodiffBackend = Autodiff<TestBackend>;\npub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;\n```\n\nTests will automatically run with default dtypes and any variants (f16, bf16, etc.) based on the\nbackend configuration.\n"
  },
  {
    "path": "crates/burn-backend-tests/cubecl.toml",
    "content": "[profiling]\nlogger = { file = \"target/profiling.log\", level = \"disabled\" }\n\n[autotune]\nlogger = { file = \"target/autotune.log\", level = \"disabled\" }\n\n[compilation]\nlogger = { file = \"target/compilation.log\", level = \"disabled\" }\n\n[memory]\nlogger = { file = \"target/memory.log\", level = \"disabled\" }\n\n[streaming]\nmax_streams = 4\n"
  },
  {
    "path": "crates/burn-backend-tests/src/lib.rs",
    "content": "extern crate alloc;\n\n#[cfg(feature = \"std\")]\npub use burn_tensor_testgen::might_panic;\n\n/// Generate a test module with custom floating element types.\n#[macro_export]\nmacro_rules! test_float_elem_variant {\n    ($modname:ident, $float:ty, $module:literal, [$($feat:literal),* $(,)?]) => {\n        #[cfg(all(test, any($(feature = $feat),*)))]\n        mod $modname {\n            pub type FloatElemType = $float;\n            #[allow(unused)]\n            pub use super::IntElemType;\n\n            mod ty {\n                include!(\"backend.rs\");\n                include!($module);\n            }\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/abs.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance, cast::ToElement};\n\n#[test]\nfn should_diff_abs() {\n    let data_1 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[71.0, 107.0], [71.0, 107.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[84.0, 42.0], [90.0, 54.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_abs_no_nans() {\n    let data_1 = TensorData::from([[6.0, 7.0], [9.0, -10.0]]);\n    let data_2 = TensorData::from([[0.0, -1.0], [3.0, 4.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().abs());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[1.0, 7.0], [1.0, 7.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[0.0, -15.0], [-3.0, -3.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let contains_nan = grad_2.contains_nan();\n    assert!(!contains_nan.into_scalar().to_bool());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/adaptive_avgpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::module::adaptive_avg_pool1d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool1d_simple() {\n    let test = AdaptiveAvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        length: 5,\n        output_size: 3,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[\n            [0.5000, 0.83333, 0.33333, 0.83333, 0.5000],\n            [0.5000, 0.83333, 0.33333, 0.83333, 0.5000],\n        ]],\n        &Default::default(),\n    ));\n}\n\nstruct AdaptiveAvgPool1dTestCase {\n    batch_size: usize,\n    channels: usize,\n    length: usize,\n    output_size: usize,\n}\n\nimpl AdaptiveAvgPool1dTestCase {\n    fn assert_output(self, x_grad: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.length]);\n        let device = Default::default();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = adaptive_avg_pool1d(x.clone(), self.output_size);\n        let grads = output.backward();\n        let x_grad_actual = x.grad(&grads).unwrap();\n\n        x_grad.to_data().assert_approx_eq::<FloatElem>(\n            &x_grad_actual.into_data(),\n            Tolerance::default().set_half_precision_relative(1e-3),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/adaptive_avgpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::module::adaptive_avg_pool2d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool2d_simple() {\n    let test = AdaptiveAvgPool2dTestCase {\n        batch_size: 1,\n        channels: 2,\n        height: 5,\n        width: 3,\n        output_size_1: 3,\n        output_size_2: 2,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[\n            [\n                [0.2500, 0.5000, 0.2500],\n                [0.41667, 0.83333, 0.41667],\n                [0.16667, 0.33333, 0.16667],\n                [0.41667, 0.83333, 0.41667],\n                [0.2500, 0.5000, 0.2500],\n            ],\n            [\n                [0.2500, 0.5000, 0.2500],\n                [0.41667, 0.83333, 0.41667],\n                [0.16667, 0.33333, 0.16667],\n                [0.41667, 0.83333, 0.41667],\n                [0.2500, 0.5000, 0.2500],\n            ],\n        ]],\n        &Default::default(),\n    ));\n}\n\n#[test]\nfn test_avg_pool2d_output_1() {\n    let test = AdaptiveAvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 4,\n        width: 8,\n        output_size_1: 1,\n        output_size_2: 1,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[\n            [\n                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,\n            ],\n            [\n                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,\n            ],\n            [\n                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,\n            ],\n            [\n                0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125, 0.03125,\n            ],\n        ]]],\n        &Default::default(),\n    ));\n}\n\nstruct AdaptiveAvgPool2dTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    output_size_1: usize,\n    output_size_2: usize,\n}\n\nimpl AdaptiveAvgPool2dTestCase {\n    fn assert_output(self, x_grad: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let device = Default::default();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = adaptive_avg_pool2d(x.clone(), [self.output_size_1, self.output_size_2]);\n        let grads = output.backward();\n        let x_grad_actual = x.grad(&grads).unwrap();\n\n        x_grad.to_data().assert_approx_eq::<FloatElem>(\n            &x_grad_actual.into_data(),\n            Tolerance::default().set_half_precision_relative(1e-3),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/add.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_add() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_floats([2.0, 5.0], &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_floats([4.0, 1.0], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone() + tensor_2.clone();\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([1.0, 1.0]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([1.0, 1.0]), false);\n    tensor_3\n        .to_data()\n        .assert_eq(&TensorData::from([6.0, 6.0]), false);\n}\n\n#[test]\nfn should_diff_add_scalar() {\n    let data = TensorData::from([2.0, 10.0]);\n\n    let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();\n    let tensor_out = tensor.clone().add_scalar(5.0);\n    let grads = tensor_out.backward();\n\n    let grad = tensor.grad(&grads).unwrap();\n\n    grad.to_data()\n        .assert_eq(&TensorData::from([1.0, 1.0]), false);\n    tensor_out\n        .into_data()\n        .assert_eq(&TensorData::from([7.0, 15.0]), false);\n}\n\n#[test]\nfn test_add_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().add(tensor_2.clone());\n    let tensor_5 = tensor_4\n        .add(tensor_3)\n        .add_scalar(5.0)\n        .add(tensor_1.clone())\n        .add(tensor_2.clone());\n    let tensor_6 = tensor_1.clone().add(tensor_5);\n\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[3.0, 3.0], [3.0, 3.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/aggregation.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_mean() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.mean().unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[3.5, 9.5], [3.5, 9.5]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[-0.75, -0.75], [3.0, 3.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_sum_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.sum().unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[14.0, 38.0], [14.0, 38.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[-3.0, -3.0], [12.0, 12.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_sum_2() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.clone().sum_dim(1);\n    let tensor_5 = tensor_4.mul(tensor_3);\n\n    let grads = tensor_5.sum().backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[494.0, 722.0], [2990.0, 4370.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[690.0, 690.0], [958.0, 958.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_mean_dim() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.mean_dim(1).unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[4.0, 36.0], [3.0, -17.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[9.0, 9.0], [35.5, 35.5]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_sum_dim() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.sum_dim(1).unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[8.0, 72.0], [6.0, -34.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[18.0, 18.0], [71.0, 71.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/avgpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::module::avg_pool1d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool1d_simple() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size: 3,\n        padding: 0,\n        stride: 1,\n        length: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[0.33333, 0.66667, 1.0000, 1.0000, 0.66667, 0.33333]]],\n        &Default::default(),\n    ));\n}\n\n#[test]\nfn test_avg_pool1d_complex() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        length: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[\n            [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],\n            [0.33333, 0.66667, 0.33333, 0.66667, 0.33333, 0.33333],\n        ]],\n        &Default::default(),\n    ));\n}\n\n#[test]\nfn test_avg_pool1d_complex_dont_count_pad() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        length: 6,\n        count_include_pad: false,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[\n            [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],\n            [0.5000, 0.83333, 0.33333, 0.66667, 0.33333, 0.33333],\n        ]],\n        &Default::default(),\n    ));\n}\n\nstruct AvgPool1dTestCase {\n    batch_size: usize,\n    channels: usize,\n    kernel_size: usize,\n    padding: usize,\n    stride: usize,\n    length: usize,\n    count_include_pad: bool,\n}\n\nimpl AvgPool1dTestCase {\n    fn assert_output(self, x_grad: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.length]);\n        let device = Default::default();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = avg_pool1d(\n            x.clone(),\n            self.kernel_size,\n            self.stride,\n            self.padding,\n            self.count_include_pad,\n            false,\n        );\n        let grads = output.backward();\n        let x_grad_actual = x.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n        x_grad\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/avgpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::module::avg_pool2d;\nuse burn_tensor::{Shape, Tolerance};\n\n#[test]\nfn test_avg_pool2d_simple() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        height: 6,\n        width: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[\n            [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],\n            [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],\n            [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],\n            [0.33333, 0.66667, 1.00000, 1.00000, 0.66667, 0.33333],\n            [0.22222, 0.44444, 0.66667, 0.66667, 0.44444, 0.22222],\n            [0.11111, 0.22222, 0.33333, 0.33333, 0.22222, 0.11111],\n        ]]],\n        &Default::default(),\n    ));\n}\n\n#[test]\nfn test_avg_pool2d_complex() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 2,\n        height: 4,\n        width: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[\n            [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],\n            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],\n            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],\n            [0.33333, 0.33333, 0.33333, 0.33333, 0.33333, 0.33333],\n        ]]],\n        &Default::default(),\n    ));\n}\n\n#[test]\nfn test_avg_pool2d_complex_dont_include_pad() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 2,\n        height: 4,\n        width: 6,\n        count_include_pad: false,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[\n            [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],\n            [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],\n            [0.8750, 0.8750, 0.58333, 0.58333, 0.8750, 0.8750],\n            [0.6250, 0.6250, 0.41667, 0.41667, 0.6250, 0.6250],\n        ]]],\n        &Default::default(),\n    ));\n}\n\nstruct AvgPool2dTestCase {\n    batch_size: usize,\n    channels: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    height: usize,\n    width: usize,\n    count_include_pad: bool,\n}\n\nimpl AvgPool2dTestCase {\n    fn assert_output(self, x_grad: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let device = Default::default();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = avg_pool2d(\n            x.clone(),\n            [self.kernel_size_1, self.kernel_size_2],\n            [self.stride_1, self.stride_2],\n            [self.padding_1, self.padding_2],\n            self.count_include_pad,\n            false,\n        );\n        let grads = output.backward();\n        let x_grad_actual = x.grad(&grads).unwrap();\n\n        x_grad.to_data().assert_approx_eq::<FloatElem>(\n            &x_grad_actual.into_data(),\n            Tolerance::default().set_half_precision_relative(1e-3),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/backward.rs",
    "content": "use super::*;\nuse burn_tensor::{Int, Tensor, TensorData, module::embedding};\n\n#[test]\nfn test_embedding_backward() {\n    let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TensorData::from([[0, 1], [1, 1]]);\n    let x = TensorData::from([\n        [[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],\n        [[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],\n    ]);\n    let device = Default::default();\n    let weights = Tensor::<TestAutodiffBackend, 2>::from_data(weights, &device).require_grad();\n    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(indices, &device);\n    let x = Tensor::<TestAutodiffBackend, 3>::from_data(x, &device).require_grad();\n\n    let output = embedding(weights.clone(), indices);\n    let output = output.matmul(x);\n    let grads = output.backward();\n\n    let grad = weights.grad(&grads).unwrap();\n    grad.to_data()\n        .assert_eq(&TensorData::from([[3., 9., 7.], [21., 35., 27.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/bridge.rs",
    "content": "use super::*;\nuse burn_tensor::{DType, Distribution, Tensor};\n\n#[test]\nfn test_full_precision() {\n    let device = Default::default();\n    let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)\n        .require_grad();\n    let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)\n        .require_grad();\n    let dtype = x1.dtype();\n\n    let x3 = x1.clone().cast(DType::F32);\n    let x4 = x2.clone().cast(DType::F32);\n\n    let x5 = x3.matmul(x4);\n    let x6 = x5.cast(dtype);\n    let x7 = x6 * x1.clone() / x2.clone();\n\n    let grads = x7.backward();\n\n    let x1_grad = x1.grad(&grads);\n    let x2_grad = x2.grad(&grads);\n\n    assert!(x1_grad.is_some());\n    assert!(x2_grad.is_some());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/broadcast.rs",
    "content": "use super::*;\n\n#[test]\nfn mul_broadcast() {\n    test_ops_broadcast_backward(|x, y| x * y);\n}\n\n#[test]\nfn div_broadcast() {\n    test_ops_broadcast_backward(|x, y| x / y);\n}\n\n#[test]\nfn sub_broadcast() {\n    test_ops_broadcast_backward(|x, y| x - y);\n}\n\n#[test]\nfn add_broadcast() {\n    test_ops_broadcast_backward(|x, y| x + y);\n}\n\n#[test]\nfn matmul_broadcast() {\n    test_ops_broadcast_backward(|x, y| x.matmul(y));\n}\n\n#[test]\nfn mask_where_broadcast() {\n    test_ops_broadcast_backward(|x, y| {\n        let cond = y.clone().equal_elem(4);\n        x.mask_where(cond, y)\n    });\n}\n\nfn test_ops_broadcast_backward<F>(func: F)\nwhere\n    F: Fn(TestAutodiffTensor<3>, TestAutodiffTensor<3>) -> TestAutodiffTensor<3>,\n{\n    let device = Default::default();\n    let w = TestAutodiffTensor::zeros([16, 5, 5], &device).require_grad();\n    let x = TestAutodiffTensor::zeros([4, 5, 5], &device).require_grad();\n\n    // Slice isn't a broadcastable operation, so it will fail when the previous backward pass\n    // of an operation that support broadcast doesn't support it during the backward pass.\n    let y = func(w.clone().slice([0..1]), x.clone());\n\n    // Will panic if broadcast isn't supported!\n    let grads = y.backward();\n\n    let w_grad = w.grad(&grads).unwrap();\n    let x_grad = x.grad(&grads).unwrap();\n\n    assert_eq!(w_grad.shape(), w.shape());\n    assert_eq!(x_grad.shape(), x.shape());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cast.rs",
    "content": "// Skip on metal - F64 not supported\n#![cfg(all(feature = \"std\", not(feature = \"metal\")))]\n\nuse super::*;\nuse burn_backend_tests::might_panic;\nuse burn_tensor::{DType, Tensor, TensorData};\n\n#[might_panic(reason = \"Unsupported precision for fusion\")]\n#[test]\nfn cast_keeps_gradient_flow() {\n    let device = Default::default();\n\n    let x = Tensor::<TestAutodiffBackend, 2>::from_data(\n        TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let y = x.clone().cast(DType::F64);\n    let z = y.sum();\n\n    let grads = z.backward();\n    let grad_x = x.grad(&grads).unwrap();\n\n    grad_x\n        .to_data()\n        .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cat.rs",
    "content": "use super::*;\n\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_cat() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let mut tensor_1_list = Vec::new();\n    let mut tensor_2_list = Vec::new();\n\n    for i in 0..2 {\n        tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));\n        tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));\n    }\n\n    let tensor_1_cat = TestAutodiffTensor::cat(tensor_1_list.clone(), 0);\n    let tensor_2_cat = TestAutodiffTensor::cat(tensor_2_list.clone(), 0);\n\n    let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());\n    let grads = tensor_3_cat.backward();\n\n    let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);\n    let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);\n\n    let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);\n    let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);\n\n    grad_1\n        .clone()\n        .slice([0..1])\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&grad_1_slice_1.to_data(), Tolerance::default());\n    grad_1\n        .slice([1..2])\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&grad_1_slice_2.to_data(), Tolerance::default());\n\n    grad_2\n        .clone()\n        .slice([0..1])\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&grad_2_slice_1.to_data(), Tolerance::default());\n    grad_2\n        .slice([1..2])\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&grad_2_slice_2.to_data(), Tolerance::default());\n}\n\n#[test]\nfn should_diff_cat_more_than_1_dim() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::<2>::from_data([[5.0, 4.0], [-1.0, 4.0], [4.0, 1.0]], &device)\n            .require_grad();\n\n    // Concat a tensor [2, 2] with another tensor [3, 2] along dim 0.\n    // The resulting tensor should be [5, 2]\n    let tensor_3 = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 0);\n    assert_eq!(tensor_3.dims(), [5, 2]);\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    assert_eq!(tensor_1.dims(), grad_1.dims());\n    assert_eq!(tensor_2.dims(), grad_2.dims());\n}\n\n#[test]\nfn should_slice_grads_correctly_when_some_inputs_not_tracked() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data([[1.0]], &device).require_grad(); // tracked\n    let tensor_2 = TestAutodiffTensor::<2>::from_data([[10.0, 20.0]], &device); // not tracked\n    let tensor_3 =\n        TestAutodiffTensor::<2>::from_data([[100.0, 200.0, 300.0]], &device).require_grad(); // tracked\n\n    let cat = TestAutodiffTensor::cat(\n        vec![tensor_1.clone(), tensor_2.clone(), tensor_3.clone()],\n        1,\n    );\n\n    // Make gradient per column unique so wrong slicing shows up.\n    let weights = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], &device);\n    let loss = (cat * weights).sum();\n\n    let grads = loss.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_3 = tensor_3.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&burn_tensor::TensorData::from([[1.0]]), false);\n    grad_3\n        .to_data()\n        .assert_eq(&burn_tensor::TensorData::from([[4.0, 5.0, 6.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/ceil.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_ceil() {\n    let data = TensorData::from([\n        [-1.9751, 0.0714, 0.0643, 0.2406],\n        [-1.3172, 0.1252, -0.1119, -0.0127],\n    ]);\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n    let tensor_2 = tensor_1.clone().ceil();\n    let grads = tensor_2.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/checkpoint.rs",
    "content": "use super::*;\nuse burn_tensor::{Bool, Tensor, TensorData};\n\n#[test]\nfn test_autodiff_checkpoint_complicated_computation() {\n    let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);\n    let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);\n    let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);\n    let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);\n    let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n    let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();\n\n    let tensor_5 = compute_bound_eager(tensor_0, tensor_1);\n    let tensor_6 = compute_bound_lazy(tensor_2, tensor_3.clone());\n    let tensor_7 = memory_bound_eager(tensor_3, tensor_4);\n    let tensor_8 = compute_bound_lazy(tensor_6, tensor_7.clone());\n    let tensor_9 = memory_bound_eager_scalar(tensor_7, 11.);\n    let tensor_10 = memory_bound_lazy(tensor_5, tensor_8.clone());\n    let tensor_11 = memory_bound_lazy(tensor_8, tensor_9);\n    let tensor_12 = compute_bound_lazy(tensor_10, tensor_11);\n\n    assert_checkpoint(tensor_12);\n}\n\n#[test]\nfn test_autodiff_checkpoint_with_missing_requirement() {\n    let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);\n    let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device); // does not require_grad\n\n    let tensor_2 = memory_bound_eager(tensor_0, tensor_1);\n    let tensor_3 = memory_bound_eager_scalar(tensor_2.clone(), 11.);\n    let tensor_4 = memory_bound_eager_scalar(tensor_2.clone(), 11.);\n    let tensor_5 = compute_bound_lazy(tensor_3, tensor_4);\n    let tensor_6 = compute_bound_eager_scalar(tensor_5.clone(), 11.);\n    let tensor_7 = memory_bound_eager(tensor_5, tensor_2);\n    let tensor_8 = memory_bound_eager(tensor_6, tensor_7);\n\n    assert_checkpoint(tensor_8);\n}\n\n#[test]\nfn test_autodiff_checkpoint_with_many_duplicates() {\n    let data_0 = TensorData::from([[4.0, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();\n\n    let tensor_1 = memory_bound_eager(tensor_0.clone(), tensor_0.clone());\n    let tensor_2 = compute_bound_eager(tensor_0.clone(), tensor_0.clone());\n    let tensor_3 = memory_bound_lazy(tensor_0.clone(), tensor_0.clone());\n    let tensor_4 = compute_bound_lazy(tensor_0.clone(), tensor_0.clone());\n\n    let tensor_5 = memory_bound_eager(tensor_1.clone(), tensor_0.clone());\n    let tensor_6 = memory_bound_eager(tensor_0.clone(), tensor_5.clone());\n    let tensor_7 = compute_bound_lazy(tensor_3.clone(), tensor_5.clone());\n    let tensor_8 = compute_bound_eager(tensor_4.clone(), tensor_2.clone());\n    let tensor_9 = memory_bound_lazy(tensor_6, tensor_7);\n    let tensor_10 = memory_bound_eager(tensor_0, tensor_9);\n    let tensor_11 = memory_bound_eager_scalar(tensor_10, 9.);\n    let tensor_12 = compute_bound_lazy(tensor_8, tensor_11);\n\n    assert_checkpoint(tensor_12);\n}\n\n#[test]\nfn test_autodiff_checkpoint_with_long_chain_of_eager_memory_bound() {\n    let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);\n    let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);\n    let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);\n    let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);\n    let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n    let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();\n\n    let tensor_5 = memory_bound_eager(tensor_0, tensor_1.clone());\n    let tensor_6 = memory_bound_eager(tensor_5, tensor_2);\n    let tensor_7 = memory_bound_eager(tensor_6, tensor_3);\n    let tensor_8 = memory_bound_eager(tensor_7, tensor_4);\n    let tensor_9 = memory_bound_lazy(tensor_8, tensor_1);\n\n    assert_checkpoint(tensor_9)\n}\n\n#[test]\nfn test_autodiff_checkpoint_half_sub_graph_not_tracked() {\n    let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);\n    let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);\n    let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);\n    let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);\n    let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);\n    let data_5 = TensorData::from([[0.5, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device);\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device);\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n    let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();\n    let tensor_5 = TestAutodiffTensor::from_data(data_5, &device).require_grad();\n\n    let tensor_6 = memory_bound_lazy(tensor_0, tensor_1);\n    let tensor_7 = compute_bound_eager(tensor_6, tensor_2);\n\n    let tensor_8 = memory_bound_eager(tensor_3, tensor_4);\n    let tensor_9 = compute_bound_lazy(tensor_8, tensor_5);\n\n    let tensor_10 = compute_bound_lazy(tensor_7, tensor_9);\n\n    assert_checkpoint(tensor_10);\n}\n\n#[test]\nfn test_autodiff_checkpoint_very_complex() {\n    let data_0 = TensorData::from([[0.0, 7.0], [7.0, 7.0]]);\n    let data_1 = TensorData::from([[0.1, 7.0], [7.0, 7.0]]);\n    let data_2 = TensorData::from([[0.2, 7.0], [7.0, 7.0]]);\n    let data_3 = TensorData::from([[0.3, 7.0], [7.0, 7.0]]);\n    let data_4 = TensorData::from([[0.4, 7.0], [7.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data_0, &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device);\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n    let tensor_4 = TestAutodiffTensor::from_data(data_4, &device).require_grad();\n\n    let tensor_5 = memory_bound_eager_scalar(tensor_0, 8.);\n    let tensor_6 = memory_bound_lazy(tensor_5.clone(), tensor_1.clone());\n    let tensor_7 = compute_bound_lazy(tensor_6.clone(), tensor_6);\n    let tensor_8 = memory_bound_lazy(tensor_1.clone(), tensor_5.clone());\n    let tensor_9 = memory_bound_eager_scalar(tensor_7.clone(), 7.);\n    let tensor_10 = compute_bound_eager(tensor_5, tensor_8);\n    let tensor_11 = memory_bound_eager(tensor_2.clone(), tensor_9);\n    let tensor_12 = memory_bound_lazy(tensor_2.clone(), tensor_2);\n    let tensor_13 = compute_bound_eager(tensor_10.clone(), tensor_11);\n    let tensor_14 = compute_bound_eager_scalar(tensor_3, 8.);\n    let tensor_15 = compute_bound_lazy(tensor_4, tensor_12);\n    let tensor_16 = memory_bound_lazy(tensor_10, tensor_7);\n    let tensor_17 = compute_bound_lazy(tensor_13, tensor_1);\n    let tensor_18 = memory_bound_eager(tensor_15, tensor_16);\n    let tensor_19 = compute_bound_eager(tensor_14, tensor_17);\n    let tensor_20 = memory_bound_lazy(tensor_18, tensor_19);\n    let tensor_21 = memory_bound_eager_scalar(tensor_20, 8.);\n\n    assert_checkpoint(tensor_21)\n}\n\nfn assert_checkpoint<const D: usize>(tensor: TestAutodiffTensor<D>) {\n    // Assert is not explicit here, but the test can fail\n    // - when a tensor is actually required more than n_required, it won't be found and will panic\n    // - when a tensor is actually required less than n_required, the backward states map won't be\n    //   empty and will fail the assertion within the backward code, same for retro_forwards\n    tensor.backward();\n}\n\n// Does not save its state and does not need its parents\nfn memory_bound_eager<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    tensor_b: TestAutodiffTensor<D>,\n) -> TestAutodiffTensor<D> {\n    tensor_a.add(tensor_b)\n}\nfn memory_bound_eager_scalar<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    b: f32,\n) -> TestAutodiffTensor<D> {\n    tensor_a.add_scalar(b)\n}\n\n// Saves its own state and does not need its parents\nfn compute_bound_eager<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    tensor_b: TestAutodiffTensor<D>,\n) -> TestAutodiffTensor<D> {\n    let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());\n    tensor_a.mask_where(mask, tensor_b)\n}\nfn compute_bound_eager_scalar<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    b: f32,\n) -> TestAutodiffTensor<D> {\n    let mask = Tensor::<TestAutodiffBackend, D, Bool>::empty(tensor_a.shape(), &tensor_a.device());\n    tensor_a.mask_fill(mask, b)\n}\n\n// Does not save its state and needs its parents\nfn memory_bound_lazy<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    tensor_b: TestAutodiffTensor<D>,\n) -> TestAutodiffTensor<D> {\n    tensor_a.mul(tensor_b)\n}\n\n// Saves its own state and needs its parents\nfn compute_bound_lazy<const D: usize>(\n    tensor_a: TestAutodiffTensor<D>,\n    tensor_b: TestAutodiffTensor<D>,\n) -> TestAutodiffTensor<D> {\n    tensor_a.matmul(tensor_b)\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/complex.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_full_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.matmul(tensor_1.clone());\n    let tensor_5 = tensor_4.mul(tensor_2.clone());\n\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[593., 463.0], [487.0, 539.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[734.0, 294.0], [1414.0, 242.0]]), false);\n}\n\n#[test]\nfn should_diff_full_complex_2() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.matmul(tensor_1.clone());\n    let tensor_5 = tensor_4.add_scalar(17.0).add(tensor_2.clone());\n\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[166.0, 110.0], [212.0, 156.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[113.0, 141.0], [33.0, 41.0]]), false);\n}\n\n#[test]\nfn should_diff_full_complex_3() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.matmul(tensor_1.clone());\n    let tensor_5 = tensor_4.clone().sub(tensor_2.clone());\n    let tensor_6 = tensor_5.add(tensor_4);\n\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[332.0, 220.0], [424.0, 312.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[223.0, 279.0], [63.0, 79.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv1d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv1d, ops::ConvOptions};\n\n#[test]\nfn test_conv1d_basic() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[14., 24., 24., 18.], [26., 42., 42., 30.]],\n                [[14., 24., 24., 18.], [26., 42., 42., 30.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[30., 44., 36.], [54., 76., 60.]],\n                [[30., 44., 36.], [54., 76., 60.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv1d_different_channels() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[39., 63., 63., 45.], [57., 90., 90., 63.]],\n                [[39., 63., 63., 45.], [57., 90., 90., 63.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[30., 44., 36.], [54., 76., 60.]],\n                [[30., 44., 36.], [54., 76., 60.]],\n                [[30., 44., 36.], [54., 76., 60.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8., 8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv1d_with_padding() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 2,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[24., 24., 24., 24.], [42., 42., 42., 42.]],\n                [[24., 24., 24., 24.], [42., 42., 42., 42.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[44., 44., 44.], [76., 76., 76.]],\n                [[44., 44., 44.], [76., 76., 76.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([12., 12.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv1d_with_stride() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[8., 16., 8., 10.], [14., 28., 14., 16.]],\n                [[8., 16., 8., 10.], [14., 28., 14., 16.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[10., 20., 24.], [18., 36., 40.]],\n                [[10., 20., 24.], [18., 36., 40.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv1d_dilation() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 2,\n        groups: 1,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[6., 8., 8., 10.], [12., 14., 14., 16.]],\n                [[6., 8., 8., 10.], [12., 14., 14., 16.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[8., 22., 14.], [16., 38., 22.]],\n                [[8., 22., 14.], [16., 38., 22.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv1d_groups() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 1,\n        groups: 2,\n        length: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[1., 3., 3., 3.], [7., 12., 12., 9.]],\n                [[1., 3., 3., 3.], [7., 12., 12., 9.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats([[[30., 44., 36.]], [[54., 76., 60.]]], &device),\n        bias: TestTensor::from_floats([8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct Conv1dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size: usize,\n    padding: usize,\n    stride: usize,\n    dilation: usize,\n    groups: usize,\n    length: usize,\n}\n\nstruct Grads {\n    x: TestTensor<3>,\n    weight: TestTensor<3>,\n    bias: TestTensor<1>,\n}\n\nimpl Conv1dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size,\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n\n        let output = conv1d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::default();\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv2d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv2d, ops::ConvOptions};\n\n#[test]\nfn test_conv2d_basic() {\n    let test = Conv2dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [88., 138., 138., 96.],\n                        [150., 234., 234., 162.],\n                        [150., 234., 234., 162.],\n                        [112., 174., 174., 120.],\n                    ],\n                    [\n                        [160., 246., 246., 168.],\n                        [258., 396., 396., 270.],\n                        [258., 396., 396., 270.],\n                        [184., 282., 282., 192.],\n                    ],\n                ],\n                [\n                    [\n                        [88., 138., 138., 96.],\n                        [150., 234., 234., 162.],\n                        [150., 234., 234., 162.],\n                        [112., 174., 174., 120.],\n                    ],\n                    [\n                        [160., 246., 246., 168.],\n                        [258., 396., 396., 270.],\n                        [258., 396., 396., 270.],\n                        [184., 282., 282., 192.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],\n                    [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],\n                ],\n                [\n                    [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],\n                    [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([32., 32.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_channels() {\n    let test = Conv2dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [240., 369., 369., 252.],\n                        [387., 594., 594., 405.],\n                        [387., 594., 594., 405.],\n                        [276., 423., 423., 288.],\n                    ],\n                    [\n                        [348., 531., 531., 360.],\n                        [549., 837., 837., 567.],\n                        [549., 837., 837., 567.],\n                        [384., 585., 585., 396.],\n                    ],\n                ],\n                [\n                    [\n                        [240., 369., 369., 252.],\n                        [387., 594., 594., 405.],\n                        [387., 594., 594., 405.],\n                        [276., 423., 423., 288.],\n                    ],\n                    [\n                        [348., 531., 531., 360.],\n                        [549., 837., 837., 567.],\n                        [549., 837., 837., 567.],\n                        [384., 585., 585., 396.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],\n                    [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],\n                ],\n                [\n                    [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],\n                    [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],\n                ],\n                [\n                    [[378., 516., 396.], [552., 752., 576.], [450., 612., 468.]],\n                    [[666., 900., 684.], [936., 1264., 960.], [738., 996., 756.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([32., 32., 32.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_kernel_size() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [116., 180., 192., 132.],\n                    [198., 306., 324., 222.],\n                    [198., 306., 324., 222.],\n                    [148., 228., 240., 164.],\n                ],\n                [\n                    [212., 324., 336., 228.],\n                    [342., 522., 540., 366.],\n                    [342., 522., 540., 366.],\n                    [244., 372., 384., 260.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [27., 45., 54., 39.],\n                        [52., 84., 96., 68.],\n                        [51., 81., 90., 63.],\n                    ],\n                    [\n                        [123., 189., 198., 135.],\n                        [180., 276., 288., 196.],\n                        [147., 225., 234., 159.],\n                    ],\n                ],\n                [\n                    [\n                        [27., 45., 54., 39.],\n                        [52., 84., 96., 68.],\n                        [51., 81., 90., 63.],\n                    ],\n                    [\n                        [123., 189., 198., 135.],\n                        [180., 276., 288., 196.],\n                        [147., 225., 234., 159.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([12., 12.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_padding() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [138., 138., 138., 138.],\n                    [234., 234., 234., 234.],\n                    [234., 234., 234., 234.],\n                    [174., 174., 174., 174.],\n                ],\n                [\n                    [246., 246., 246., 246.],\n                    [396., 396., 396., 396.],\n                    [396., 396., 396., 396.],\n                    [282., 282., 282., 282.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],\n                    [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],\n                ],\n                [\n                    [[66., 66., 66.], [120., 120., 120.], [114., 114., 114.]],\n                    [[258., 258., 258.], [376., 376., 376.], [306., 306., 306.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([24., 24.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_width() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 5,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [88., 138., 138., 138., 96.],\n                    [150., 234., 234., 234., 162.],\n                    [150., 234., 234., 234., 162.],\n                    [112., 174., 174., 174., 120.],\n                ],\n                [\n                    [160., 246., 246., 246., 168.],\n                    [258., 396., 396., 396., 270.],\n                    [258., 396., 396., 396., 270.],\n                    [184., 282., 282., 282., 192.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],\n                    [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],\n                ],\n                [\n                    [[78., 105., 90.], [144., 190., 160.], [138., 180., 150.]],\n                    [[318., 405., 330.], [464., 590., 480.], [378., 480., 390.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([20., 20.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_stride_2() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 2,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 6,\n        width: 6,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [26., 52., 26., 52., 26., 28.],\n                    [52., 104., 52., 104., 52., 56.],\n                    [26., 52., 26., 52., 26., 28.],\n                    [52., 104., 52., 104., 52., 56.],\n                    [26., 52., 26., 52., 26., 28.],\n                    [32., 64., 32., 64., 32., 34.],\n                ],\n                [\n                    [44., 88., 44., 88., 44., 46.],\n                    [88., 176., 88., 176., 88., 92.],\n                    [44., 88., 44., 88., 44., 46.],\n                    [88., 176., 88., 176., 88., 92.],\n                    [44., 88., 44., 88., 44., 46.],\n                    [50., 100., 50., 100., 50., 52.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],\n                    [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],\n                ],\n                [\n                    [[56., 84., 90.], [84., 126., 135.], [120., 180., 189.]],\n                    [[200., 300., 306.], [300., 450., 459.], [336., 504., 513.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([9., 9.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_stride() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 3,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 8,\n        width: 8,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [50., 78., 78., 78., 78., 78., 78., 54.],\n                    [62., 96., 96., 96., 96., 96., 96., 66.],\n                    [38., 60., 60., 60., 60., 60., 60., 42.],\n                    [50., 78., 78., 78., 78., 78., 78., 54.],\n                    [62., 96., 96., 96., 96., 96., 96., 66.],\n                    [38., 60., 60., 60., 60., 60., 60., 42.],\n                    [50., 78., 78., 78., 78., 78., 78., 54.],\n                    [62., 96., 96., 96., 96., 96., 96., 66.],\n                ],\n                [\n                    [86., 132., 132., 132., 132., 132., 132., 90.],\n                    [98., 150., 150., 150., 150., 150., 150., 102.],\n                    [74., 114., 114., 114., 114., 114., 114., 78.],\n                    [86., 132., 132., 132., 132., 132., 132., 90.],\n                    [98., 150., 150., 150., 150., 150., 150., 102.],\n                    [74., 114., 114., 114., 114., 114., 114., 78.],\n                    [86., 132., 132., 132., 132., 132., 132., 90.],\n                    [98., 150., 150., 150., 150., 150., 150., 102.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],\n                    [\n                        [1330., 1528., 1344.],\n                        [1911., 2196., 1932.],\n                        [2079., 2388., 2100.],\n                    ],\n                ],\n                [\n                    [[434., 504., 448.], [567., 660., 588.], [735., 852., 756.]],\n                    [\n                        [1330., 1528., 1344.],\n                        [1911., 2196., 1932.],\n                        [2079., 2388., 2100.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([24., 24.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_dilation_2() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 2,\n        dilation_2: 2,\n        groups: 1,\n        height: 6,\n        width: 6,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [18., 38., 38., 42., 42., 22.],\n                    [42., 88., 88., 96., 96., 50.],\n                    [42., 88., 88., 96., 96., 50.],\n                    [54., 112., 112., 120., 120., 62.],\n                    [54., 112., 112., 120., 120., 62.],\n                    [30., 62., 62., 66., 66., 34.],\n                ],\n                [\n                    [36., 74., 74., 78., 78., 40.],\n                    [78., 160., 160., 168., 168., 86.],\n                    [78., 160., 160., 168., 168., 86.],\n                    [90., 184., 184., 192., 192., 98.],\n                    [90., 184., 184., 192., 192., 98.],\n                    [48., 98., 98., 102., 102., 52.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],\n                    [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],\n                ],\n                [\n                    [[63., 102., 90.], [192., 280., 228.], [225., 318., 252.]],\n                    [[387., 534., 414.], [624., 856., 660.], [549., 750., 576.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([16., 16.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_different_dilation() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 2,\n        dilation_2: 3,\n        groups: 1,\n        height: 6,\n        width: 6,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [18., 0., 20., 20., 0., 22.],\n                    [42., 0., 46., 46., 0., 50.],\n                    [42., 0., 46., 46., 0., 50.],\n                    [54., 0., 58., 58., 0., 62.],\n                    [54., 0., 58., 58., 0., 62.],\n                    [30., 0., 32., 32., 0., 34.],\n                ],\n                [\n                    [36., 0., 38., 38., 0., 40.],\n                    [78., 0., 82., 82., 0., 86.],\n                    [78., 0., 82., 82., 0., 86.],\n                    [90., 0., 94., 94., 0., 98.],\n                    [90., 0., 94., 94., 0., 98.],\n                    [48., 0., 50., 50., 0., 52.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],\n                    [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],\n                ],\n                [\n                    [[18., 51., 33.], [60., 140., 80.], [72., 159., 87.]],\n                    [[126., 267., 141.], [204., 428., 224.], [180., 375., 195.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_groups() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 5,\n        width: 5,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [0., 1., 3., 3., 2.],\n                    [3., 8., 15., 12., 7.],\n                    [9., 21., 36., 27., 15.],\n                    [9., 20., 33., 24., 13.],\n                    [6., 13., 21., 15., 8.],\n                ],\n                [\n                    [9., 19., 30., 21., 11.],\n                    [21., 44., 69., 48., 25.],\n                    [36., 75., 117., 81., 42.],\n                    [27., 56., 87., 60., 31.],\n                    [15., 31., 48., 33., 17.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[[54., 63., 72.], [99., 108., 117.], [144., 153., 162.]]],\n                [[[279., 288., 297.], [324., 333., 342.], [369., 378., 387.]]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([9., 9.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_groups_stride_2() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 4,\n        channels_out: 4,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 2,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 4,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [4., 8., 4., 5.],\n                    [8., 16., 8., 10.],\n                    [4., 8., 4., 5.],\n                    [7., 14., 7., 8.],\n                ],\n                [\n                    [13., 26., 13., 14.],\n                    [26., 52., 26., 28.],\n                    [13., 26., 13., 14.],\n                    [16., 32., 16., 17.],\n                ],\n                [\n                    [22., 44., 22., 23.],\n                    [44., 88., 44., 46.],\n                    [22., 44., 22., 23.],\n                    [25., 50., 25., 26.],\n                ],\n                [\n                    [31., 62., 31., 32.],\n                    [62., 124., 62., 64.],\n                    [31., 62., 31., 32.],\n                    [34., 68., 34., 35.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[[5., 10., 12.], [10., 20., 24.], [18., 36., 40.]]],\n                [[[21., 42., 44.], [42., 84., 88.], [50., 100., 104.]]],\n                [[[37., 74., 76.], [74., 148., 152.], [82., 164., 168.]]],\n                [[[53., 106., 108.], [106., 212., 216.], [114., 228., 232.]]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4., 4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_groups_different_channels() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 6,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 3,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [9., 20., 24., 13.],\n                    [24., 52., 60., 32.],\n                    [36., 76., 84., 44.],\n                    [21., 44., 48., 25.],\n                ],\n                [\n                    [45., 92., 96., 49.],\n                    [96., 196., 204., 104.],\n                    [108., 220., 228., 116.],\n                    [57., 116., 120., 61.],\n                ],\n                [\n                    [81., 164., 168., 85.],\n                    [168., 340., 348., 176.],\n                    [180., 364., 372., 188.],\n                    [93., 188., 192., 97.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],\n                [[[10., 14., 18.], [26., 30., 34.], [42., 46., 50.]]],\n                [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],\n                [[[74., 78., 82.], [90., 94., 98.], [106., 110., 114.]]],\n                [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],\n                [[[138., 142., 146.], [154., 158., 162.], [170., 174., 178.]]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4., 4., 4., 4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_complex() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 2,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 2,\n        dilation_1: 2,\n        dilation_2: 3,\n        groups: 1,\n        height: 4,\n        width: 5,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [36., 39., 0., 39., 42.],\n                    [81., 87., 0., 87., 93.],\n                    [81., 87., 0., 87., 93.],\n                    [45., 48., 0., 48., 51.],\n                ],\n                [\n                    [54., 57., 0., 57., 60.],\n                    [117., 123., 0., 123., 129.],\n                    [117., 123., 0., 123., 129.],\n                    [63., 66., 0., 66., 69.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[15., 42., 27.], [30., 72., 42.]],\n                    [[75., 162., 87.], [90., 192., 102.]],\n                ],\n                [\n                    [[15., 42., 27.], [30., 72., 42.]],\n                    [[75., 162., 87.], [90., 192., 102.]],\n                ],\n                [\n                    [[15., 42., 27.], [30., 72., 42.]],\n                    [[75., 162., 87.], [90., 192., 102.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8., 8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv2d_groups_stride_2_no_pad() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 4,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 2,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [0., 1., 2., 0.],\n                    [3., 4., 5., 0.],\n                    [6., 7., 8., 0.],\n                    [0., 0., 0., 0.],\n                ],\n                [\n                    [9., 10., 11., 0.],\n                    [12., 13., 14., 0.],\n                    [15., 16., 17., 0.],\n                    [0., 0., 0., 0.],\n                ],\n                [\n                    [18., 19., 20., 0.],\n                    [21., 22., 23., 0.],\n                    [24., 25., 26., 0.],\n                    [0., 0., 0., 0.],\n                ],\n                [\n                    [27., 28., 29., 0.],\n                    [30., 31., 32., 0.],\n                    [33., 34., 35., 0.],\n                    [0., 0., 0., 0.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],\n                    [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],\n                ],\n                [\n                    [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],\n                    [[48., 49., 50.], [52., 53., 54.], [56., 57., 58.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([1., 1.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct Conv2dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    groups: usize,\n    height: usize,\n    width: usize,\n}\n\nstruct Grads {\n    x: TestTensor<4>,\n    weight: TestTensor<4>,\n    bias: TestTensor<1>,\n}\n\nimpl Conv2dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = conv2d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvOptions::new(\n                [self.stride_1, self.stride_2],\n                [self.padding_1, self.padding_2],\n                [self.dilation_1, self.dilation_2],\n                self.groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::rel_abs(0.01, 0.01);\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv3d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv3d, ops::ConvOptions};\n\n#[test]\nfn test_conv3d_basic() {\n    let test = Conv3dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 4,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [\n                            [536., 816., 816., 552.],\n                            [840., 1278., 1278., 864.],\n                            [840., 1278., 1278., 864.],\n                            [584., 888., 888., 600.],\n                        ],\n                        [\n                            [912., 1386., 1386., 936.],\n                            [1422., 2160., 2160., 1458.],\n                            [1422., 2160., 2160., 1458.],\n                            [984., 1494., 1494., 1008.],\n                        ],\n                        [\n                            [912., 1386., 1386., 936.],\n                            [1422., 2160., 2160., 1458.],\n                            [1422., 2160., 2160., 1458.],\n                            [984., 1494., 1494., 1008.],\n                        ],\n                        [\n                            [680., 1032., 1032., 696.],\n                            [1056., 1602., 1602., 1080.],\n                            [1056., 1602., 1602., 1080.],\n                            [728., 1104., 1104., 744.],\n                        ],\n                    ],\n                    [\n                        [\n                            [968., 1464., 1464., 984.],\n                            [1488., 2250., 2250., 1512.],\n                            [1488., 2250., 2250., 1512.],\n                            [1016., 1536., 1536., 1032.],\n                        ],\n                        [\n                            [1560., 2358., 2358., 1584.],\n                            [2394., 3618., 3618., 2430.],\n                            [2394., 3618., 3618., 2430.],\n                            [1632., 2466., 2466., 1656.],\n                        ],\n                        [\n                            [1560., 2358., 2358., 1584.],\n                            [2394., 3618., 3618., 2430.],\n                            [2394., 3618., 3618., 2430.],\n                            [1632., 2466., 2466., 1656.],\n                        ],\n                        [\n                            [1112., 1680., 1680., 1128.],\n                            [1704., 2574., 2574., 1728.],\n                            [1704., 2574., 2574., 1728.],\n                            [1160., 1752., 1752., 1176.],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [536., 816., 816., 552.],\n                            [840., 1278., 1278., 864.],\n                            [840., 1278., 1278., 864.],\n                            [584., 888., 888., 600.],\n                        ],\n                        [\n                            [912., 1386., 1386., 936.],\n                            [1422., 2160., 2160., 1458.],\n                            [1422., 2160., 2160., 1458.],\n                            [984., 1494., 1494., 1008.],\n                        ],\n                        [\n                            [912., 1386., 1386., 936.],\n                            [1422., 2160., 2160., 1458.],\n                            [1422., 2160., 2160., 1458.],\n                            [984., 1494., 1494., 1008.],\n                        ],\n                        [\n                            [680., 1032., 1032., 696.],\n                            [1056., 1602., 1602., 1080.],\n                            [1056., 1602., 1602., 1080.],\n                            [728., 1104., 1104., 744.],\n                        ],\n                    ],\n                    [\n                        [\n                            [968., 1464., 1464., 984.],\n                            [1488., 2250., 2250., 1512.],\n                            [1488., 2250., 2250., 1512.],\n                            [1016., 1536., 1536., 1032.],\n                        ],\n                        [\n                            [1560., 2358., 2358., 1584.],\n                            [2394., 3618., 3618., 2430.],\n                            [2394., 3618., 3618., 2430.],\n                            [1632., 2466., 2466., 1656.],\n                        ],\n                        [\n                            [1560., 2358., 2358., 1584.],\n                            [2394., 3618., 3618., 2430.],\n                            [2394., 3618., 3618., 2430.],\n                            [1632., 2466., 2466., 1656.],\n                        ],\n                        [\n                            [1112., 1680., 1680., 1128.],\n                            [1704., 2574., 2574., 1728.],\n                            [1704., 2574., 2574., 1728.],\n                            [1160., 1752., 1752., 1176.],\n                        ],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [\n                            [4590., 6156., 4644.],\n                            [6264., 8400., 6336.],\n                            [4806., 6444., 4860.],\n                        ],\n                        [\n                            [6696., 8976., 6768.],\n                            [9120., 12224., 9216.],\n                            [6984., 9360., 7056.],\n                        ],\n                        [\n                            [5454., 7308., 5508.],\n                            [7416., 9936., 7488.],\n                            [5670., 7596., 5724.],\n                        ],\n                    ],\n                    [\n                        [\n                            [8046., 10764., 8100.],\n                            [10872., 14544., 10944.],\n                            [8262., 11052., 8316.],\n                        ],\n                        [\n                            [11304., 15120., 11376.],\n                            [15264., 20416., 15360.],\n                            [11592., 15504., 11664.],\n                        ],\n                        [\n                            [8910., 11916., 8964.],\n                            [12024., 16080., 12096.],\n                            [9126., 12204., 9180.],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [4590., 6156., 4644.],\n                            [6264., 8400., 6336.],\n                            [4806., 6444., 4860.],\n                        ],\n                        [\n                            [6696., 8976., 6768.],\n                            [9120., 12224., 9216.],\n                            [6984., 9360., 7056.],\n                        ],\n                        [\n                            [5454., 7308., 5508.],\n                            [7416., 9936., 7488.],\n                            [5670., 7596., 5724.],\n                        ],\n                    ],\n                    [\n                        [\n                            [8046., 10764., 8100.],\n                            [10872., 14544., 10944.],\n                            [8262., 11052., 8316.],\n                        ],\n                        [\n                            [11304., 15120., 11376.],\n                            [15264., 20416., 15360.],\n                            [11592., 15504., 11664.],\n                        ],\n                        [\n                            [8910., 11916., 8964.],\n                            [12024., 16080., 12096.],\n                            [9126., 12204., 9180.],\n                        ],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([128., 128.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv3d_complex() {\n    let test = Conv3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 2,\n        kernel_size_2: 3,\n        kernel_size_3: 4,\n        padding_1: 1,\n        padding_2: 2,\n        padding_3: 3,\n        stride_1: 1,\n        stride_2: 2,\n        stride_3: 3,\n        dilation_1: 2,\n        dilation_2: 3,\n        dilation_3: 4,\n        groups: 1,\n        depth: 5,\n        height: 6,\n        width: 7,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [\n                        [0., 147., 0., 0., 0., 150., 0.],\n                        [0., 159., 0., 0., 0., 162., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 159., 0., 0., 0., 162., 0.],\n                        [0., 171., 0., 0., 0., 174., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 330., 0., 0., 0., 336., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 378., 0., 0., 0., 384., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 330., 0., 0., 0., 336., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 378., 0., 0., 0., 384., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 330., 0., 0., 0., 336., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 354., 0., 0., 0., 360., 0.],\n                        [0., 378., 0., 0., 0., 384., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 183., 0., 0., 0., 186., 0.],\n                        [0., 195., 0., 0., 0., 198., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 195., 0., 0., 0., 198., 0.],\n                        [0., 207., 0., 0., 0., 210., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                ],\n                [\n                    [\n                        [0., 219., 0., 0., 0., 222., 0.],\n                        [0., 231., 0., 0., 0., 234., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 231., 0., 0., 0., 234., 0.],\n                        [0., 243., 0., 0., 0., 246., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 474., 0., 0., 0., 480., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 522., 0., 0., 0., 528., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 474., 0., 0., 0., 480., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 522., 0., 0., 0., 528., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 474., 0., 0., 0., 480., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 498., 0., 0., 0., 504., 0.],\n                        [0., 522., 0., 0., 0., 528., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 255., 0., 0., 0., 258., 0.],\n                        [0., 267., 0., 0., 0., 270., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                        [0., 267., 0., 0., 0., 270., 0.],\n                        [0., 279., 0., 0., 0., 282., 0.],\n                        [0., 0., 0., 0., 0., 0., 0.],\n                    ],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [\n                            [0., 256., 272., 0.],\n                            [0., 624., 656., 0.],\n                            [0., 368., 384., 0.],\n                        ],\n                        [\n                            [0., 424., 440., 0.],\n                            [0., 960., 992., 0.],\n                            [0., 536., 552., 0.],\n                        ],\n                    ],\n                    [\n                        [\n                            [0., 1096., 1112., 0.],\n                            [0., 2304., 2336., 0.],\n                            [0., 1208., 1224., 0.],\n                        ],\n                        [\n                            [0., 1264., 1280., 0.],\n                            [0., 2640., 2672., 0.],\n                            [0., 1376., 1392., 0.],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [0., 256., 272., 0.],\n                            [0., 624., 656., 0.],\n                            [0., 368., 384., 0.],\n                        ],\n                        [\n                            [0., 424., 440., 0.],\n                            [0., 960., 992., 0.],\n                            [0., 536., 552., 0.],\n                        ],\n                    ],\n                    [\n                        [\n                            [0., 1096., 1112., 0.],\n                            [0., 2304., 2336., 0.],\n                            [0., 1208., 1224., 0.],\n                        ],\n                        [\n                            [0., 1264., 1280., 0.],\n                            [0., 2640., 2672., 0.],\n                            [0., 1376., 1392., 0.],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [0., 256., 272., 0.],\n                            [0., 624., 656., 0.],\n                            [0., 368., 384., 0.],\n                        ],\n                        [\n                            [0., 424., 440., 0.],\n                            [0., 960., 992., 0.],\n                            [0., 536., 552., 0.],\n                        ],\n                    ],\n                    [\n                        [\n                            [0., 1096., 1112., 0.],\n                            [0., 2304., 2336., 0.],\n                            [0., 1208., 1224., 0.],\n                        ],\n                        [\n                            [0., 1264., 1280., 0.],\n                            [0., 2640., 2672., 0.],\n                            [0., 1376., 1392., 0.],\n                        ],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([10., 10., 10.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv3d_groups_stride_2_no_pad() {\n    let test = Conv3dTestCase {\n        batch_size: 1,\n        channels_in: 4,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 0,\n        padding_2: 0,\n        padding_3: 0,\n        stride_1: 2,\n        stride_2: 2,\n        stride_3: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 2,\n        depth: 4,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [\n                        [0., 1., 2., 0.],\n                        [3., 4., 5., 0.],\n                        [6., 7., 8., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [9., 10., 11., 0.],\n                        [12., 13., 14., 0.],\n                        [15., 16., 17., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [18., 19., 20., 0.],\n                        [21., 22., 23., 0.],\n                        [24., 25., 26., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                ],\n                [\n                    [\n                        [27., 28., 29., 0.],\n                        [30., 31., 32., 0.],\n                        [33., 34., 35., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [36., 37., 38., 0.],\n                        [39., 40., 41., 0.],\n                        [42., 43., 44., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [45., 46., 47., 0.],\n                        [48., 49., 50., 0.],\n                        [51., 52., 53., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                ],\n                [\n                    [\n                        [54., 55., 56., 0.],\n                        [57., 58., 59., 0.],\n                        [60., 61., 62., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [63., 64., 65., 0.],\n                        [66., 67., 68., 0.],\n                        [69., 70., 71., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [72., 73., 74., 0.],\n                        [75., 76., 77., 0.],\n                        [78., 79., 80., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                ],\n                [\n                    [\n                        [81., 82., 83., 0.],\n                        [84., 85., 86., 0.],\n                        [87., 88., 89., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [90., 91., 92., 0.],\n                        [93., 94., 95., 0.],\n                        [96., 97., 98., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [99., 100., 101., 0.],\n                        [102., 103., 104., 0.],\n                        [105., 106., 107., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                    [\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                        [0., 0., 0., 0.],\n                    ],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.]],\n                        [[16., 17., 18.], [20., 21., 22.], [24., 25., 26.]],\n                        [[32., 33., 34.], [36., 37., 38.], [40., 41., 42.]],\n                    ],\n                    [\n                        [[64., 65., 66.], [68., 69., 70.], [72., 73., 74.]],\n                        [[80., 81., 82.], [84., 85., 86.], [88., 89., 90.]],\n                        [[96., 97., 98.], [100., 101., 102.], [104., 105., 106.]],\n                    ],\n                ],\n                [\n                    [\n                        [[128., 129., 130.], [132., 133., 134.], [136., 137., 138.]],\n                        [[144., 145., 146.], [148., 149., 150.], [152., 153., 154.]],\n                        [[160., 161., 162.], [164., 165., 166.], [168., 169., 170.]],\n                    ],\n                    [\n                        [[192., 193., 194.], [196., 197., 198.], [200., 201., 202.]],\n                        [[208., 209., 210.], [212., 213., 214.], [216., 217., 218.]],\n                        [[224., 225., 226.], [228., 229., 230.], [232., 233., 234.]],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([1., 1.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct Conv3dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    kernel_size_3: usize,\n    padding_1: usize,\n    padding_2: usize,\n    padding_3: usize,\n    stride_1: usize,\n    stride_2: usize,\n    stride_3: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    dilation_3: usize,\n    groups: usize,\n    depth: usize,\n    height: usize,\n    width: usize,\n}\n\nstruct Grads {\n    x: TestTensor<5>,\n    weight: TestTensor<5>,\n    bias: TestTensor<1>,\n}\n\nimpl Conv3dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([\n            self.batch_size,\n            self.channels_in,\n            self.depth,\n            self.height,\n            self.width,\n        ]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n            self.kernel_size_3,\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = conv3d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvOptions::new(\n                [self.stride_1, self.stride_2, self.stride_3],\n                [self.padding_1, self.padding_2, self.padding_3],\n                [self.dilation_1, self.dilation_2, self.dilation_3],\n                self.groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::default();\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv_transpose1d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv_transpose1d, ops::ConvTransposeOptions};\n\n#[test]\nfn test_conv_transpose1d_basic() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: 3,\n        padding: 0,\n        padding_out: 0,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        size: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],\n                [[15.0, 15.0, 15.0, 15.0], [51.0, 51.0, 51.0, 51.0]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[44.0, 44.0, 44.0], [44.0, 44.0, 44.0]],\n                [[76.0, 76.0, 76.0], [76.0, 76.0, 76.0]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([12., 12.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose1d_padding() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: 3,\n        padding: 2,\n        padding_out: 0,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        size: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[7., 12., 8., 3.], [19., 36., 32., 15.]],\n                [[7., 12., 8., 3.], [19., 36., 32., 15.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[26., 22., 18.], [26., 22., 18.]],\n                [[42., 38., 34.], [42., 38., 34.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose1d_stride() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: 3,\n        padding: 0,\n        padding_out: 0,\n        stride: 2,\n        dilation: 1,\n        groups: 1,\n        size: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[44., 44., 44.], [44., 44., 44.]],\n                [[76., 76., 76.], [76., 76., 76.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([18., 18.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose1d_stride_padding_out() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: 3,\n        padding: 0,\n        padding_out: 1,\n        stride: 2,\n        dilation: 1,\n        groups: 1,\n        size: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[44., 44., 44.], [44., 44., 44.]],\n                [[76., 76., 76.], [76., 76., 76.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([20., 20.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose1d_dilation() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: 3,\n        padding: 0,\n        padding_out: 0,\n        stride: 1,\n        dilation: 2,\n        groups: 1,\n        size: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n                [[15., 15., 15., 15.], [51., 51., 51., 51.]],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[44., 44., 44.], [44., 44., 44.]],\n                [[76., 76., 76.], [76., 76., 76.]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([16., 16.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose1d_complex() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 2,\n        channels: [2, 4],\n        kernel_size: 3,\n        padding: 1,\n        padding_out: 1,\n        stride: 2,\n        dilation: 2,\n        groups: 2,\n        size: 8,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],\n                    [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],\n                ],\n                [\n                    [12.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0, 15.0],\n                    [36.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0, 51.0],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[168.0, 184.0, 184.0], [168.0, 184.0, 184.0]],\n                [[280.0, 312.0, 312.0], [280.0, 312.0, 312.0]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([36.0, 36.0, 36.0, 36.0], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct ConvTranspose1dTestCase {\n    batch_size: usize,\n    channels: [usize; 2],\n    kernel_size: usize,\n    padding: usize,\n    padding_out: usize,\n    stride: usize,\n    dilation: usize,\n    groups: usize,\n    size: usize,\n}\n\nstruct Grads {\n    x: TestTensor<3>,\n    weight: TestTensor<3>,\n    bias: TestTensor<1>,\n}\n\nimpl ConvTranspose1dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([self.batch_size, self.channels[0], self.size]);\n        let shape_weight = Shape::new([\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size,\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = conv_transpose1d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvTransposeOptions::new(\n                [self.stride],\n                [self.padding],\n                [self.padding_out],\n                [self.dilation],\n                self.groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), Tolerance::default());\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv_transpose2d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv_transpose2d, ops::ConvTransposeOptions};\n\n#[test]\nfn test_conv_transpose2d_basic() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                    ],\n                    [\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                    ],\n                ],\n                [\n                    [\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                        [153., 153., 153., 153.],\n                    ],\n                    [\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                        [477., 477., 477., 477.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],\n                    [[752., 752., 752.], [752., 752., 752.], [752., 752., 752.]],\n                ],\n                [\n                    [\n                        [1264., 1264., 1264.],\n                        [1264., 1264., 1264.],\n                        [1264., 1264., 1264.],\n                    ],\n                    [\n                        [1264., 1264., 1264.],\n                        [1264., 1264., 1264.],\n                        [1264., 1264., 1264.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([72., 72.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_padding() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [1, 1],\n        kernel_size: [3, 3],\n        padding: [1, 2],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[[\n                [13., 24., 20., 9.],\n                [15., 27., 21., 9.],\n                [15., 27., 21., 9.],\n                [7., 12., 8., 3.],\n            ]]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [[[[63., 57., 51.], [68., 60., 52.], [39., 33., 27.]]]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_stride() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [1, 1],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [2, 3],\n        dilation: [1, 1],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[[\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n            ]]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([108.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_stride_padding_out() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [1, 1],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [1, 2],\n        stride: [2, 3],\n        dilation: [1, 1],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[[\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n            ]]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([140.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_dilation() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [1, 1],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [2, 3],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[[\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n                [36., 36., 36., 36.],\n            ]]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [[[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([80.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_channels() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [2, 3],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        groups: 1,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [351., 351., 351., 351.],\n                    [351., 351., 351., 351.],\n                    [351., 351., 351., 351.],\n                    [351., 351., 351., 351.],\n                ],\n                [\n                    [1080., 1080., 1080., 1080.],\n                    [1080., 1080., 1080., 1080.],\n                    [1080., 1080., 1080., 1080.],\n                    [1080., 1080., 1080., 1080.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],\n                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],\n                    [[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]],\n                ],\n                [\n                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],\n                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],\n                    [[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([36., 36., 36.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_kernel_size() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [1, 1],\n        kernel_size: [3, 5],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        groups: 1,\n        size: [6, 6],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[[\n                [105., 105., 105., 105., 105., 105.],\n                [105., 105., 105., 105., 105., 105.],\n                [105., 105., 105., 105., 105., 105.],\n                [105., 105., 105., 105., 105., 105.],\n                [105., 105., 105., 105., 105., 105.],\n                [105., 105., 105., 105., 105., 105.],\n            ]]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [[[\n                [630., 630., 630., 630., 630.],\n                [630., 630., 630., 630., 630.],\n                [630., 630., 630., 630., 630.],\n            ]]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([80.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_groups() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [2, 2],\n        kernel_size: [3, 3],\n        padding: [0, 0],\n        padding_out: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        groups: 2,\n        size: [4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [36., 36., 36., 36.],\n                    [36., 36., 36., 36.],\n                    [36., 36., 36., 36.],\n                    [36., 36., 36., 36.],\n                ],\n                [\n                    [117., 117., 117., 117.],\n                    [117., 117., 117., 117.],\n                    [117., 117., 117., 117.],\n                    [117., 117., 117., 117.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[[120., 120., 120.], [120., 120., 120.], [120., 120., 120.]]],\n                [[[376., 376., 376.], [376., 376., 376.], [376., 376., 376.]]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([36., 36.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_complex_no_groups() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 2,\n        channels: [2, 3],\n        kernel_size: [3, 5],\n        padding: [1, 2],\n        padding_out: [1, 2],\n        stride: [2, 3],\n        dilation: [2, 3],\n        groups: 1,\n        size: [6, 8],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [600., 735., 735., 735., 735., 735., 735., 735.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                    ],\n                    [\n                        [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                    ],\n                ],\n                [\n                    [\n                        [600., 735., 735., 735., 735., 735., 735., 735.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                        [810., 990., 990., 990., 990., 990., 990., 990.],\n                    ],\n                    [\n                        [1680., 2085., 2085., 2085., 2085., 2085., 2085., 2085.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                        [2430., 3015., 3015., 3015., 3015., 3015., 3015., 3015.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [5320., 6040., 6040., 6040., 6040.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                    ],\n                    [\n                        [5320., 6040., 6040., 6040., 6040.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                    ],\n                    [\n                        [5320., 6040., 6040., 6040., 6040.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                        [6048., 6864., 6864., 6864., 6864.],\n                    ],\n                ],\n                [\n                    [\n                        [8680., 9880., 9880., 9880., 9880.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                    ],\n                    [\n                        [8680., 9880., 9880., 9880., 9880.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                    ],\n                    [\n                        [8680., 9880., 9880., 9880., 9880.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                        [10080., 11472., 11472., 11472., 11472.],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([896., 896., 896.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_complex_no_groups_2() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [4, 2],\n        kernel_size: [2, 3],\n        padding: [1, 2],\n        padding_out: [1, 2],\n        stride: [2, 3],\n        dilation: [1, 2],\n        groups: 1,\n        size: [10, 10],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [30., 42., 42., 42., 42., 42., 42., 42., 42., 42.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [48., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                ],\n                [\n                    [78., 114., 114., 114., 114., 114., 114., 114., 114., 114.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                    [144., 210., 210., 210., 210., 210., 210., 210., 210., 210.],\n                ],\n                [\n                    [126., 186., 186., 186., 186., 186., 186., 186., 186., 186.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                    [240., 354., 354., 354., 354., 354., 354., 354., 354., 354.],\n                ],\n                [\n                    [174., 258., 258., 258., 258., 258., 258., 258., 258., 258.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                    [336., 498., 498., 498., 498., 498., 498., 498., 498., 498.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [[4455., 4905., 4905.], [4500., 4950., 4950.]],\n                    [[4455., 4905., 4905.], [4500., 4950., 4950.]],\n                ],\n                [\n                    [[12555., 13905., 13905.], [13500., 14950., 14950.]],\n                    [[12555., 13905., 13905.], [13500., 14950., 14950.]],\n                ],\n                [\n                    [[20655., 22905., 22905.], [22500., 24950., 24950.]],\n                    [[20655., 22905., 22905.], [22500., 24950., 24950.]],\n                ],\n                [\n                    [[28755., 31905., 31905.], [31500., 34950., 34950.]],\n                    [[28755., 31905., 31905.], [31500., 34950., 34950.]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([570., 570.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose2d_complex_groups() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels: [4, 2],\n        kernel_size: [2, 3],\n        padding: [1, 2],\n        padding_out: [1, 2],\n        stride: [2, 3],\n        dilation: [1, 2],\n        groups: 2,\n        size: [10, 10],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [9., 12., 12., 12., 12., 12., 12., 12., 12., 12.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                    [12., 15., 15., 15., 15., 15., 15., 15., 15., 15.],\n                ],\n                [\n                    [21., 30., 30., 30., 30., 30., 30., 30., 30., 30.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                    [36., 51., 51., 51., 51., 51., 51., 51., 51., 51.],\n                ],\n                [\n                    [33., 48., 48., 48., 48., 48., 48., 48., 48., 48.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                    [60., 87., 87., 87., 87., 87., 87., 87., 87., 87.],\n                ],\n                [\n                    [45., 66., 66., 66., 66., 66., 66., 66., 66., 66.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                    [84., 123., 123., 123., 123., 123., 123., 123., 123., 123.],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[[4455., 4905., 4905.], [4500., 4950., 4950.]]],\n                [[[12555., 13905., 13905.], [13500., 14950., 14950.]]],\n                [[[20655., 22905., 22905.], [22500., 24950., 24950.]]],\n                [[[28755., 31905., 31905.], [31500., 34950., 34950.]]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([570., 570.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct ConvTranspose2dTestCase {\n    batch_size: usize,\n    channels: [usize; 2],\n    kernel_size: [usize; 2],\n    padding: [usize; 2],\n    padding_out: [usize; 2],\n    stride: [usize; 2],\n    dilation: [usize; 2],\n    groups: usize,\n    size: [usize; 2],\n}\n\nstruct Grads {\n    x: TestTensor<4>,\n    weight: TestTensor<4>,\n    bias: TestTensor<1>,\n}\n\nimpl ConvTranspose2dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([\n            self.batch_size,\n            self.channels[0],\n            self.size[0],\n            self.size[1],\n        ]);\n        let shape_weight = Shape::new([\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let output = conv_transpose2d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvTransposeOptions::new(\n                self.stride,\n                self.padding,\n                self.padding_out,\n                self.dilation,\n                self.groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::permissive();\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/conv_transpose3d.rs",
    "content": "use super::*;\nuse burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};\n\n#[test]\nfn test_conv_transpose3d_basic() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 2,\n        channels: [2, 2],\n        kernel_size: [3, 3, 3],\n        padding: [0, 0, 0],\n        padding_out: [0, 0, 0],\n        stride: [1, 1, 1],\n        dilation: [1, 1, 1],\n        groups: 1,\n        size: [4, 4, 4],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                    ],\n                    [\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                        [\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                            [13.250001, 13.250001, 13.250001, 13.250001],\n                        ],\n                    ],\n                    [\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                        [\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                            [40.249992, 40.249992, 40.249992, 40.249992],\n                        ],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                    ],\n                    [\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                        [\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                            [47.750000, 47.750000, 47.750000],\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                    ],\n                    [\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                        [\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                            [79.750000, 79.750000, 79.750000],\n                        ],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([432., 432.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_conv_transpose3d_complex_groups() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels: [4, 2],\n        kernel_size: [2, 3, 4],\n        padding: [1, 2, 3],\n        padding_out: [1, 2, 3],\n        stride: [2, 3, 4],\n        dilation: [1, 2, 3],\n        groups: 2,\n        size: [6, 6, 6],\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [\n                        [1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],\n                        [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],\n                        [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],\n                        [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],\n                        [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],\n                        [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],\n                    ],\n                    [\n                        [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                    ],\n                    [\n                        [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                    ],\n                    [\n                        [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                    ],\n                    [\n                        [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                    ],\n                    [\n                        [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                        [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],\n                    ],\n                ],\n                [\n                    [\n                        [2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],\n                        [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],\n                        [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],\n                        [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],\n                        [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],\n                        [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],\n                    ],\n                    [\n                        [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                    ],\n                    [\n                        [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                    ],\n                    [\n                        [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                    ],\n                    [\n                        [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                    ],\n                    [\n                        [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                        [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],\n                    ],\n                ],\n                [\n                    [\n                        [4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],\n                        [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],\n                        [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],\n                        [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],\n                        [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],\n                        [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],\n                    ],\n                    [\n                        [\n                            7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                        [\n                            11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,\n                        ],\n                    ],\n                ],\n                [\n                    [\n                        [5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],\n                        [\n                            8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,\n                        ],\n                        [\n                            8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,\n                        ],\n                        [\n                            8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,\n                        ],\n                        [\n                            8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,\n                        ],\n                        [\n                            8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,\n                        ],\n                    ],\n                    [\n                        [\n                            10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                    ],\n                    [\n                        [\n                            10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                        [\n                            15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,\n                        ],\n                    ],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [[\n                    [\n                        [18.663193, 22.309027, 22.309027, 22.309027],\n                        [21.875000, 26.145834, 26.145834, 26.145834],\n                        [21.875000, 26.145834, 26.145834, 26.145834],\n                    ],\n                    [\n                        [19.270832, 23.020834, 23.020834, 23.020834],\n                        [22.500000, 26.875002, 26.875002, 26.875002],\n                        [22.500000, 26.875002, 26.875002, 26.875002],\n                    ],\n                ]],\n                [[\n                    [\n                        [49.913193, 59.809029, 59.809029, 59.809029],\n                        [59.375000, 71.145836, 71.145836, 71.145836],\n                        [59.375000, 71.145836, 71.145836, 71.145836],\n                    ],\n                    [\n                        [56.770836, 68.020836, 68.020836, 68.020836],\n                        [67.500000, 80.875000, 80.875000, 80.875000],\n                        [67.500000, 80.875000, 80.875000, 80.875000],\n                    ],\n                ]],\n                [[\n                    [\n                        [81.163193, 97.309029, 97.309029, 97.309029],\n                        [96.875000, 116.145828, 116.145828, 116.145828],\n                        [96.875000, 116.145828, 116.145828, 116.145828],\n                    ],\n                    [\n                        [94.270828, 113.020828, 113.020828, 113.020828],\n                        [112.500000, 134.875000, 134.875000, 134.875000],\n                        [112.500000, 134.875000, 134.875000, 134.875000],\n                    ],\n                ]],\n                [[\n                    [\n                        [112.413200, 134.809021, 134.809021, 134.809021],\n                        [134.375000, 161.145828, 161.145828, 161.145828],\n                        [134.375000, 161.145828, 161.145828, 161.145828],\n                    ],\n                    [\n                        [131.770844, 158.020828, 158.020828, 158.020828],\n                        [157.500000, 188.875000, 188.875000, 188.875000],\n                        [157.500000, 188.875000, 188.875000, 188.875000],\n                    ],\n                ]],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([5346., 5346.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct ConvTranspose3dTestCase {\n    batch_size: usize,\n    channels: [usize; 2],\n    kernel_size: [usize; 3],\n    padding: [usize; 3],\n    padding_out: [usize; 3],\n    stride: [usize; 3],\n    dilation: [usize; 3],\n    groups: usize,\n    size: [usize; 3],\n}\n\nstruct Grads {\n    x: TestTensor<5>,\n    weight: TestTensor<5>,\n    bias: TestTensor<1>,\n}\n\nimpl ConvTranspose3dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let shape_x = Shape::new([\n            self.batch_size,\n            self.channels[0],\n            self.size[0],\n            self.size[1],\n            self.size[2],\n        ]);\n        let shape_weight = Shape::new([\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n            self.kernel_size[2],\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_weight.clone())\n                .into_data(),\n            &device,\n        )\n        .div_scalar(shape_weight.num_elements() as f32)\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_x.clone())\n                .into_data(),\n            &device,\n        )\n        .div_scalar(shape_x.num_elements() as f32)\n        .require_grad();\n        let output = conv_transpose3d(\n            x.clone(),\n            weight.clone(),\n            Some(bias.clone()),\n            ConvTransposeOptions::new(\n                self.stride,\n                self.padding,\n                self.padding_out,\n                self.dilation,\n                self.groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        let tolerance = Tolerance::permissive();\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cross.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[cfg(feature = \"std\")]\nuse burn_backend_tests::might_panic;\n\n#[test]\nfn backward_basic() {\n    let device = Default::default();\n    let a = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n        &device,\n    )\n    .require_grad();\n    let b = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),\n        &device,\n    )\n    .require_grad();\n\n    // Simple cross product; grad is a vector of ones.\n    let c = a.clone().cross(b.clone(), 1);\n    let grads = c.backward();\n\n    let a_grad = a.grad(&grads).unwrap().to_data();\n    let b_grad = b.grad(&grads).unwrap().to_data();\n\n    // For a: b×grad_out, where grad_out = [1,1,1]\n    let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);\n    // For b: grad_out×a\n    let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);\n\n    a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());\n    b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());\n}\n\n#[test]\nfn backward_after_sum() {\n    let device = Default::default();\n    let a = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n        &device,\n    )\n    .require_grad();\n    let b = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),\n        &device,\n    )\n    .require_grad();\n\n    // Sum reduces to scalar, but the gradient should be the same.\n    let c = a.clone().cross(b.clone(), 1).sum();\n    let grads = c.backward();\n\n    let a_grad = a.grad(&grads).unwrap().to_data();\n    let b_grad = b.grad(&grads).unwrap().to_data();\n\n    let expected_a = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 2.0, -1.0]]);\n    let expected_b = TensorData::from([[1.0, -2.0, 1.0], [1.0, -2.0, 1.0]]);\n\n    a_grad.assert_approx_eq::<FloatElem>(&expected_a, Tolerance::default());\n    b_grad.assert_approx_eq::<FloatElem>(&expected_b, Tolerance::default());\n}\n\n#[cfg(feature = \"std\")]\n#[might_panic(reason = \"not implemented: Cross product on non-last dimension\")]\n#[test]\nfn different_dim() {\n    // Also check when the cross is along a different dimension (e.g. dim 0).\n    let device = Default::default();\n    let a_raw = [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]];\n    let b_raw = [[9.0, 6.0, 3.0], [8.0, 5.0, 2.0], [7.0, 4.0, 1.0]];\n\n    let a = TestTensor::<2>::from_data(TensorData::from(a_raw), &device);\n    let b = TestTensor::<2>::from_data(TensorData::from(b_raw), &device);\n    // Cross along dim 0. Some backends (for example CubeCL) may not support\n    // cross on non-last dimensions and will intentionally panic with a\n    // message like \"Cross product on non-last dimension not yet implemented\".\n    // In that case we treat the panic as a skipped test for that backend.\n    let out = a.cross(b.clone(), 0);\n\n    // Manually compute cross of each column vector using raw arrays\n    let expected = [\n        [\n            a_raw[1][0] * b_raw[2][0] - a_raw[2][0] * b_raw[1][0],\n            a_raw[1][1] * b_raw[2][1] - a_raw[2][1] * b_raw[1][1],\n            a_raw[1][2] * b_raw[2][2] - a_raw[2][2] * b_raw[1][2],\n        ],\n        [\n            a_raw[2][0] * b_raw[0][0] - a_raw[0][0] * b_raw[2][0],\n            a_raw[2][1] * b_raw[0][1] - a_raw[0][1] * b_raw[2][1],\n            a_raw[2][2] * b_raw[0][2] - a_raw[0][2] * b_raw[2][2],\n        ],\n        [\n            a_raw[0][0] * b_raw[1][0] - a_raw[1][0] * b_raw[0][0],\n            a_raw[0][1] * b_raw[1][1] - a_raw[1][1] * b_raw[0][1],\n            a_raw[0][2] * b_raw[1][2] - a_raw[1][2] * b_raw[0][2],\n        ],\n    ];\n\n    out.to_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cross_entropy.rs",
    "content": "use super::*;\nuse burn_tensor::{Tensor, TensorData, Tolerance, loss};\n\n#[test]\nfn test_cross_entropy_loss_grad() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n    let data_targets = TensorData::from([[0.8, 0.2], [0.9, 0.1]]);\n\n    let device = Default::default();\n    let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();\n    let tensor_targets =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data_targets, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = loss::cross_entropy_with_logits(tensor_3, tensor_targets);\n\n    let grads = tensor_4.backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::permissive();\n    let expected = TensorData::from([[0.26553, 0.26553], [0.44954, 0.44954]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[-1.34863, 1.34863], [-2.06371, 2.06371]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cummax.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_cummax() {\n    // Simple test to verify cummax gradients work\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 2.0]), &device)\n        .require_grad();\n\n    let output = tensor.clone().cummax(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 2.0, 0.0]\n    let expected = TensorData::from([1.0, 2.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummax_2d() {\n    // Test 2D cummax gradients\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[1.0, 3.0, 2.0], [2.0, 5.0, 4.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let output = tensor.clone().cummax(1);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]\n    let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummax_duplicate_values() {\n    // Test with duplicate maximum values - critical edge case\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 3.0, 3.0, 2.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cummax(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // input:  [1.0, 3.0, 3.0, 2.0]\n    // cummax: [1.0, 3.0, 3.0, 3.0]\n    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]\n    // Position 2 gets grad from itself + position 3\n    let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummax_all_same() {\n    // Test with all same values\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)\n        .require_grad();\n\n    let output = tensor.clone().cummax(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 1.0, 1.0]\n    // Each position matches cummax, so each gets its own gradient\n    let expected = TensorData::from([1.0, 1.0, 1.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummax_increasing() {\n    // Test with increasing sequence\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cummax(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]\n    // Each position is a new maximum\n    let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummax_2d_duplicates() {\n    // Test 2D with duplicate values\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[1.0, 3.0, 3.0, 2.0], [2.0, 5.0, 5.0, 4.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let output = tensor.clone().cummax(1);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]\n    let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cummin.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_cummin() {\n    // Simple test to verify cummin gradients work\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 4.0]), &device)\n        .require_grad();\n\n    let output = tensor.clone().cummin(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 2.0, 0.0]\n    let expected = TensorData::from([1.0, 2.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummin_2d() {\n    // Test 2D cummin gradients\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[3.0, 2.0, 4.0], [5.0, 1.0, 3.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let output = tensor.clone().cummin(1);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]\n    let expected = TensorData::from([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummin_duplicate_values() {\n    // Test with duplicate minimum values - critical edge case\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([3.0, 2.0, 2.0, 4.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cummin(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // input:  [3.0, 2.0, 2.0, 4.0]\n    // cummin: [3.0, 2.0, 2.0, 2.0]\n    // PyTorch reference: [1.0, 1.0, 2.0, 0.0]\n    // Position 2 gets grad from itself + position 3\n    let expected = TensorData::from([1.0, 1.0, 2.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummin_all_same() {\n    // Test with all same values\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 2.0, 2.0]), &device)\n        .require_grad();\n\n    let output = tensor.clone().cummin(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 1.0, 1.0]\n    // Each position matches cummin, so each gets its own gradient\n    let expected = TensorData::from([1.0, 1.0, 1.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummin_decreasing() {\n    // Test with decreasing sequence\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([5.0, 4.0, 3.0, 2.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cummin(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 1.0, 1.0, 1.0]\n    // Each position is a new minimum\n    let expected = TensorData::from([1.0, 1.0, 1.0, 1.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cummin_2d_duplicates() {\n    // Test 2D with duplicate values\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[3.0, 2.0, 2.0, 4.0], [5.0, 1.0, 1.0, 3.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let output = tensor.clone().cummin(1);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]\n    let expected = TensorData::from([[1.0, 1.0, 2.0, 0.0], [1.0, 1.0, 2.0, 0.0]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cumprod.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_cumprod() {\n    // Simple test to verify cumprod gradients work\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0]), &device)\n        .require_grad();\n\n    let output = tensor.clone().cumprod(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [16.0, 10.0, 6.0]\n    let expected = TensorData::from([16.0, 10.0, 6.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cumprod_2d() {\n    // Test 2D cumprod gradients\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n        &device,\n    )\n    .require_grad();\n\n    let output = tensor.clone().cumprod(1);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]\n    let expected = TensorData::from([[9.0, 4.0, 2.0], [36.0, 28.0, 20.0]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n// TODO: The following tests are currently ignored due to a known limitation\n// in the cumprod gradient implementation. The current implementation uses\n// division (grad / input), which produces NaN when the input contains zeros.\n//\n// A proper fix requires implementing a zero-safe algorithm using exclusive\n// cumulative products (similar to PyTorch's cumprod_backward or JAX's\n// associative_scan approach). This is a non-trivial implementation that\n// requires careful handling of cumulative products in both forward and\n// reverse directions.\n//\n// See: https://github.com/tracel-ai/burn/issues/3864\n//\n// References:\n// - PyTorch: https://github.com/pytorch/pytorch (cumprod_backward)\n// - JAX PR #2596: Parallel prefix scan implementation\n// - TensorFlow Issue #3862: tf.cumprod's gradient produces nans given zeros\n\n#[test]\n#[ignore = \"cumprod gradient with zeros not yet implemented - produces NaN due to division by zero\"]\nfn should_diff_cumprod_zero_in_middle() {\n    // Test cumprod with zero in the middle - edge case for division\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 4.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cumprod(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 32.0, 0.0, 0.0]\n    let expected = TensorData::from([1.0, 32.0, 0.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[ignore = \"cumprod gradient with zeros not yet implemented - produces NaN due to division by zero\"]\nfn should_diff_cumprod_zero_at_start() {\n    // Test cumprod with zero at the beginning\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([0.0, 2.0, 3.0, 4.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cumprod(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [33.0, 0.0, 0.0, 0.0]\n    let expected = TensorData::from([33.0, 0.0, 0.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[ignore = \"cumprod gradient with zeros not yet implemented - produces NaN due to division by zero\"]\nfn should_diff_cumprod_zero_at_end() {\n    // Test cumprod with zero at the end\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 3.0, 4.0, 0.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cumprod(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [16.0, 10.0, 6.0, 24.0]\n    let expected = TensorData::from([16.0, 10.0, 6.0, 24.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[ignore = \"cumprod gradient with zeros not yet implemented - produces NaN due to division by zero\"]\nfn should_diff_cumprod_multiple_zeros() {\n    // Test cumprod with multiple zeros\n    let device = Default::default();\n    let tensor =\n        TestAutodiffTensor::<1>::from_data(TensorData::from([2.0, 0.0, 3.0, 0.0, 5.0]), &device)\n            .require_grad();\n\n    let output = tensor.clone().cumprod(0);\n    let grads = output.sum().backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    // PyTorch reference: [1.0, 8.0, 0.0, 0.0, 0.0]\n    let expected = TensorData::from([1.0, 8.0, 0.0, 0.0, 0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/cumsum.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_cumsum_dim0() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.cumsum(0);\n    let tensor_5 = tensor_1.clone().mul(tensor_4);\n    let grads = tensor_5.sum().backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    // Expected gradients computed with PyTorch\n    let expected = TensorData::from([[-14.0, 24.0], [17.0, 6.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[3.0, 10.0], [-1.0, 37.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cumsum_dim1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.cumsum(1);\n    let tensor_5 = tensor_1.clone().mul(tensor_4);\n    let grads = tensor_5.sum().backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    // Expected gradients computed with PyTorch\n    let expected = TensorData::from([[1.0, 69.0], [-13.0, -28.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[18.0, 13.0], [71.0, 58.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_cumsum_complex() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.clone().cumsum(1);\n    let tensor_5 = tensor_4.mul(tensor_3);\n\n    let grads = tensor_5.sum().backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    // Expected gradients computed with PyTorch\n    let expected = TensorData::from([[371.0, 542.0], [2246.0, 3281.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[507.0, 528.0], [704.0, 733.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/deform_conv2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Shape, module::deform_conv2d, ops::DeformConvOptions};\n\n#[test]\nfn test_deform_conv2d_basic() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [0.000, 6.0678, 14.2071, 12.2477],\n                    [11.2292, 33.7937, 50.1555, 44.0561],\n                    [17.9294, 57.2174, 85.1505, 79.1840],\n                    [18.0220, 73.6263, 126.8184, 151.6910],\n                ],\n                [\n                    [0.000, 8.9783, 20.7620, 17.7888],\n                    [16.2326, 48.7386, 71.7961, 62.5845],\n                    [25.3808, 80.5195, 119.0949, 110.0938],\n                    [25.0567, 101.8461, 174.3329, 206.6013],\n                ],\n            ]],\n            &device,\n        ),\n        offset: TestTensor::from_floats(\n            [[\n                [[0.000, 15.0000], [30.000, 45.0000]],\n                [[0.000, 3.7500], [7.5000, 11.2500]],\n                [[62.6667, 78.3333], [94.0000, 109.6667]],\n                [[15.6667, 19.5833], [23.5000, 27.4167]],\n                [[130.6667, 104.1250], [163.3333, 122.2732]],\n                [[32.6667, -492.9583], [40.8333, -787.1620]],\n                [[204.0000, 221.0000], [238.0000, 255.0000]],\n                [[51.0000, 55.2500], [59.5000, 63.7500]],\n                [[282.6667, 300.3333], [318.0000, 335.6667]],\n                [[70.6667, 75.0833], [79.5000, 83.9167]],\n                [[366.6667, 144.3750], [403.3333, 146.4121]],\n                [[91.6667, -1788.9860], [100.8333, -2392.7456]],\n                [[456.0000, 475.0000], [-2718.6250, -2953.2188]],\n                [[114.0000, 118.7500], [37.7361, 37.4063]],\n                [[550.6667, 570.3334], [-3404.5139, -3672.5312]],\n                [[137.6667, 142.5833], [28.6806, 27.5197]],\n                [[650.6667, 27.9584], [-4174.3657, -59.7509]],\n                [[162.6667, -3991.0139], [14.4028, -298.7557]],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [0.7029, 2.8356, 5.1067],\n                        [12.7492, 19.4745, 17.8345],\n                        [22.0687, 25.9156, 14.6394],\n                    ],\n                    [\n                        [3.3696, 12.6134, 19.2671],\n                        [36.7492, 50.5856, 43.5506],\n                        [50.8774, 56.3292, 30.7470],\n                    ],\n                ],\n                [\n                    [\n                        [0.7029, 2.8356, 5.1067],\n                        [12.7492, 19.4745, 17.8345],\n                        [22.0687, 25.9156, 14.6394],\n                    ],\n                    [\n                        [3.3696, 12.6134, 19.2671],\n                        [36.7492, 50.5856, 43.5506],\n                        [50.8774, 56.3292, 30.7470],\n                    ],\n                ],\n                [\n                    [\n                        [0.7029, 2.8356, 5.1067],\n                        [12.7492, 19.4745, 17.8345],\n                        [22.0687, 25.9156, 14.6394],\n                    ],\n                    [\n                        [3.3696, 12.6134, 19.2671],\n                        [36.7492, 50.5856, 43.5506],\n                        [50.8774, 56.3292, 30.7470],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        mask: TestTensor::from_floats(\n            [[\n                [[1303.5000, 1447.8750], [1862.2500, 2006.6250]],\n                [[1571.1666, 1721.9581], [2154.7500, 2305.5417]],\n                [[1857.4999, 1396.7151], [2465.9167, 1753.2246]],\n                [[2315.5000, 2479.1250], [2948.7502, 3112.3750]],\n                [[2645.1665, 2815.2085], [3303.2500, 3473.2917]],\n                [[2993.5000, 1150.0625], [3676.4165, 1300.4055]],\n                [[3531.5000, 3714.3752], [1150.1876, 1148.4744]],\n                [[3923.1665, 4112.4585], [794.3865, 770.0470]],\n                [[4333.5000, 181.4101], [368.3260, 4.2679]],\n            ]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([4., 4., 4.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_deform_conv2d_batched() {\n    let test = Conv2dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [0.000, 3.4604, 8.7539, 6.8080],\n                        [8.4661, 24.0784, 35.4610, 26.4276],\n                        [19.5988, 51.0406, 68.4389, 53.4993],\n                        [17.4698, 47.9106, 67.3808, 56.6063],\n                    ],\n                    [\n                        [0.000, 5.1185, 12.7803, 9.8796],\n                        [12.1957, 34.5728, 50.4616, 37.3777],\n                        [27.4521, 71.1227, 94.5778, 73.4724],\n                        [24.1147, 65.8443, 91.8995, 76.7475],\n                    ],\n                ],\n                [\n                    [\n                        [6.3750, 19.3553, 26.4935, 22.5650],\n                        [17.0026, 57.8088, 85.5580, 78.0746],\n                        [20.7334, 86.5793, 139.4667, 136.4133],\n                        [16.8126, 103.0225, 186.4502, 206.9613],\n                    ],\n                    [\n                        [9.5625, 28.8786, 39.1137, 32.9178],\n                        [25.1984, 85.0747, 124.6941, 112.5691],\n                        [30.0242, 124.2863, 198.6056, 192.4489],\n                        [23.5826, 143.4660, 257.8752, 283.2587],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n\n        offset: TestTensor::from_floats(\n            [\n                [\n                    [[0.000, 7.5000], [15.0000, 22.5000]],\n                    [[0.000, 1.8750], [3.7500, 5.6250]],\n                    [[31.3333, 39.1667], [47.0000, 54.8333]],\n                    [[7.8333, 9.7917], [11.7500, 13.7083]],\n                    [[65.3333, 62.7813], [81.6667, 75.4849]],\n                    [[16.3333, -237.8021], [20.4167, -381.7280]],\n                    [[102.0000, 110.5000], [119.0000, 127.5000]],\n                    [[25.5000, 27.6250], [29.7500, 31.8750]],\n                    [[141.3333, 150.1667], [159.0000, 167.8333]],\n                    [[35.3333, 37.5417], [39.7500, 41.9583]],\n                    [[183.3333, 132.3438], [201.6667, 142.0197]],\n                    [[45.8333, -839.6840], [50.4167, -1133.4155]],\n                    [[228.0000, 237.5000], [-1336.1562, -1452.1173]],\n                    [[57.0000, 59.3750], [40.3090, 41.4141]],\n                    [[275.3333, 285.1667], [-1670.5034, -1802.9244]],\n                    [[68.8333, 71.2917], [44.0451, 44.9841]],\n                    [[325.3333, 174.7396], [-2045.1747, -1090.4585]],\n                    [[81.3333, -1844.0659], [46.8090, -1150.2101]],\n                ],\n                [\n                    [[270.000, 277.5000], [285.0000, 292.5000]],\n                    [[67.5000, 69.3750], [71.2500, 73.1250]],\n                    [[313.3333, 321.1667], [329.0000, 336.8333]],\n                    [[78.3333, 80.2917], [82.2500, 84.2083]],\n                    [[359.3333, 130.1563], [375.6667, 130.6099]],\n                    [[89.8333, -4312.7603], [93.9167, -4893.6035]],\n                    [[408.0000, 416.5000], [425.0000, 433.5000]],\n                    [[102.0000, 104.1250], [106.2500, 108.3750]],\n                    [[459.3333, 468.1667], [477.0000, 485.8333]],\n                    [[114.8333, 117.0417], [119.2500, 121.4583]],\n                    [[513.3334, 97.9688], [531.6667, 93.8947]],\n                    [[128.3333, -6720.3926], [132.9167, -7504.5405]],\n                    [[570.000, 579.5000], [-7971.8438, -8251.0850]],\n                    [[142.5000, 144.8750], [22.4965, 21.8203]],\n                    [[629.3333, 639.1667], [-8948.2334, -9249.6641]],\n                    [[157.3333, 159.7917], [15.7743, 14.8695]],\n                    [[691.3333, 14.6145], [-9992.9453, -70.4040]],\n                    [[172.8333, -9818.5234], [7.4132, -352.0222]],\n                ],\n            ],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [77.7195, 89.8692, 69.0213],\n                        [121.0760, 137.0775, 92.2989],\n                        [100.0212, 106.5561, 61.1851],\n                    ],\n                    [\n                        [112.3862, 131.6470, 103.8793],\n                        [177.0760, 200.1887, 138.2681],\n                        [149.5922, 158.7074, 94.3991],\n                    ],\n                ],\n                [\n                    [\n                        [77.7195, 89.8692, 69.0213],\n                        [121.0760, 137.0775, 92.2989],\n                        [100.0212, 106.5561, 61.1851],\n                    ],\n                    [\n                        [112.3862, 131.6470, 103.8793],\n                        [177.0760, 200.1887, 138.2681],\n                        [149.5922, 158.7074, 94.3991],\n                    ],\n                ],\n                [\n                    [\n                        [77.7195, 89.8692, 69.0213],\n                        [121.0760, 137.0775, 92.2989],\n                        [100.0212, 106.5561, 61.1851],\n                    ],\n                    [\n                        [112.3862, 131.6470, 103.8793],\n                        [177.0760, 200.1887, 138.2681],\n                        [149.5922, 158.7074, 94.3991],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        mask: TestTensor::from_floats(\n            [\n                [\n                    [[1299.7499, 1439.4375], [1849.1249, 1988.8125]],\n                    [[1528.0834, 1673.9791], [2101.8750, 2247.7708]],\n                    [[1771.7500, 1624.9811], [2369.9583, 2099.5039]],\n                    [[2183.7500, 2342.0625], [2806.3750, 2964.6875]],\n                    [[2464.0833, 2628.6042], [3111.1250, 3275.6458]],\n                    [[2759.7500, 1979.2551], [3431.2085, 2390.0286]],\n                    [[3241.7498, 3418.6873], [2415.3589, 2500.8682]],\n                    [[3574.0835, 3757.2292], [2394.3889, 2471.7510]],\n                    [[3921.7500, 2095.5293], [2345.9363, 1199.5048]],\n                ],\n                [\n                    [[5957.2500, 6096.9375], [6506.6250, 6646.3125]],\n                    [[6392.5835, 6538.4790], [6966.3750, 7112.2705]],\n                    [[6843.2500, 2443.8982], [7441.4585, 2550.9199]],\n                    [[7462.2505, 7620.5625], [8084.8745, 8243.1875]],\n                    [[7949.5835, 8114.1045], [8596.6250, 8761.1465]],\n                    [[8452.2500, 1591.6719], [9123.7080, 1589.9454]],\n                    [[9141.2500, 9318.1875], [1414.3584, 1375.1803]],\n                    [[9680.5840, 9863.7285], [949.0560, 897.3544]],\n                    [[10235.2500, 213.4454], [428.2699, 2.4790]],\n                ],\n            ],\n            &device,\n        ),\n        bias: TestTensor::from_floats([8., 8., 8.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_deform_conv2d_different_kernel_size() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [14.558521, 27.249609, 37.382030, 36.039406],\n                    [33.151936, 60.480656, 81.264656, 78.618156],\n                    [57.520061, 108.623283, 153.413559, 170.072998],\n                    [54.706184, 102.596664, 144.367157, 162.643570],\n                ],\n                [\n                    [25.836353, 48.088451, 65.249161, 62.103317],\n                    [56.805233, 102.995605, 136.983124, 131.120911],\n                    [96.105408, 179.790192, 250.550934, 272.668793],\n                    [90.210945, 167.567917, 232.847275, 257.934692],\n                ],\n            ]],\n            &device,\n        ),\n        offset: TestTensor::from_floats(\n            [[\n                [\n                    [0.0e+00, 5.355903e+00, 1.171528e+01],\n                    [3.124999e-01, 8.000000e+00, 1.000000e+01],\n                    [7.500000e-01, 1.400000e+01, 1.600000e+01],\n                    [1.312500e+00, 2.000000e+01, 2.200000e+01],\n                ],\n                [\n                    [0.0e+00, 1.736104e-03, 6.944418e-03],\n                    [1.606250e+01, 2.000000e+00, 2.500000e+00],\n                    [4.425000e+01, 3.500000e+00, 4.000000e+00],\n                    [8.456250e+01, 5.000000e+00, 5.500000e+00],\n                ],\n                [\n                    [6.745834e+01, 7.996479e+01, 9.353048e+01],\n                    [3.166667e+01, 3.377778e+01, 3.588889e+01],\n                    [3.800000e+01, 4.011111e+01, 4.222223e+01],\n                    [4.433333e+01, 4.644444e+01, 4.855556e+01],\n                ],\n                [\n                    [5.277777e-01, 5.955827e-01, 6.670526e-01],\n                    [7.916667e+00, 8.444445e+00, 8.972222e+00],\n                    [9.500000e+00, 1.002778e+01, 1.055556e+01],\n                    [1.108333e+01, 1.161111e+01, 1.213889e+01],\n                ],\n                [\n                    [1.547778e+02, 1.751640e+02, 1.518874e+02],\n                    [6.000000e+01, 6.222223e+01, 4.989969e+01],\n                    [6.666666e+01, 6.888889e+01, 5.432098e+01],\n                    [7.333334e+01, 7.555556e+01, 5.860340e+01],\n                ],\n                [\n                    [2.222223e+00, 2.363040e+00, -3.360339e+01],\n                    [1.500000e+01, 1.555556e+01, -2.277485e+02],\n                    [1.666667e+01, 1.722222e+01, -3.231605e+02],\n                    [1.833333e+01, 1.888889e+01, -4.320448e+02],\n                ],\n                [\n                    [2.641250e+02, 2.021189e+02, 0.0e+00],\n                    [9.100000e+01, 6.481482e+01, 0.0e+00],\n                    [9.800000e+01, 6.863078e+01, 0.0e+00],\n                    [1.050000e+02, 7.230093e+01, 0.0e+00],\n                ],\n                [\n                    [5.250000e+00, -7.268316e+01, 0.0e+00],\n                    [2.275000e+01, -3.346296e+02, 0.0e+00],\n                    [2.450000e+01, -4.611053e+02, 0.0e+00],\n                    [2.625000e+01, -6.017269e+02, 0.0e+00],\n                ],\n                [\n                    [4.400000e+01, 1.197778e+02, 1.222222e+02],\n                    [4.804860e+01, 1.271111e+02, 1.295556e+02],\n                    [5.225000e+01, 1.344444e+02, 1.368889e+02],\n                    [-3.138958e+02, -8.007446e+02, -8.507313e+02],\n                ],\n                [\n                    [3.377778e+02, 2.994445e+01, 3.055556e+01],\n                    [4.848542e+02, 3.177778e+01, 3.238889e+01],\n                    [6.467500e+02, 3.361111e+01, 3.422222e+01],\n                    [4.909653e+02, 2.239892e+01, 2.265992e+01],\n                ],\n                [\n                    [1.533333e+02, 1.558889e+02, 1.584444e+02],\n                    [1.610000e+02, 1.635556e+02, 1.661111e+02],\n                    [1.686667e+02, 1.712222e+02, 1.737778e+02],\n                    [-9.952491e+02, -1.054551e+03, -1.115134e+03],\n                ],\n                [\n                    [3.833333e+01, 3.897222e+01, 3.961111e+01],\n                    [4.025000e+01, 4.088889e+01, 4.152778e+01],\n                    [4.216667e+01, 4.280556e+01, 4.344445e+01],\n                    [2.433767e+01, 2.453511e+01, 2.472810e+01],\n                ],\n                [\n                    [1.920000e+02, 1.946667e+02, 8.907407e+01],\n                    [2.000000e+02, 2.026667e+02, 9.054632e+01],\n                    [2.080000e+02, 2.106667e+02, 9.185186e+01],\n                    [-1.272938e+03, -1.343509e+03, -5.811921e+02],\n                ],\n                [\n                    [4.800000e+01, 4.866667e+01, -7.413704e+02],\n                    [5.000000e+01, 5.066667e+01, -9.788981e+02],\n                    [5.200000e+01, 5.266667e+01, -1.232593e+03],\n                    [2.531250e+01, 2.543518e+01, -6.388311e+02],\n                ],\n                [\n                    [2.333333e+02, 8.772182e+01, 0.0e+00],\n                    [2.416667e+02, 8.827161e+01, 0.0e+00],\n                    [2.500000e+02, 8.864776e+01, 0.0e+00],\n                    [-1.587216e+03, -5.535372e+02, 0.0e+00],\n                ],\n                [\n                    [5.833333e+01, -9.011902e+02, 0.0e+00],\n                    [6.041667e+01, -1.179988e+03, 0.0e+00],\n                    [6.250000e+01, -1.475625e+03, 0.0e+00],\n                    [2.489150e+01, -6.213175e+02, 0.0e+00],\n                ],\n                [\n                    [1.964444e+02, 2.802222e+02, 2.831111e+02],\n                    [2.055625e+02, 2.888889e+02, 2.917778e+02],\n                    [-1.173472e+03, -1.679611e+03, -1.771290e+03],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [1.144889e+03, 7.005556e+01, 7.077778e+01],\n                    [1.469646e+03, 7.222223e+01, 7.294444e+01],\n                    [5.029167e+02, 2.298823e+01, 2.295062e+01],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [3.240000e+02, 3.270000e+02, 3.300000e+02],\n                    [3.330000e+02, 3.360000e+02, 3.390000e+02],\n                    [-1.931469e+03, -2.034961e+03, -2.139958e+03],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [8.100000e+01, 8.175000e+01, 8.250000e+01],\n                    [8.325000e+01, 8.400000e+01, 8.475000e+01],\n                    [1.959376e+01, 1.946614e+01, 1.933334e+01],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [3.733333e+02, 3.764445e+02, 4.480865e+01],\n                    [3.826667e+02, 3.857778e+02, 4.185955e+01],\n                    [-2.313792e+03, -2.431276e+03, -2.392101e+02],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [9.333333e+01, 9.411111e+01, -1.904932e+03],\n                    [9.566667e+01, 9.644444e+01, -2.344715e+03],\n                    [1.429166e+01, 1.406212e+01, -3.417283e+02],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [4.253333e+02, 1.636843e+01, 0.0e+00],\n                    [4.350000e+02, 1.217279e+01, 0.0e+00],\n                    [-2.738517e+03, -4.792887e+01, 0.0e+00],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [1.063333e+02, -2.178747e+03, 0.0e+00],\n                    [1.087500e+02, -2.670679e+03, 0.0e+00],\n                    [6.947917e+00, -1.629574e+02, 0.0e+00],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [1.856041, 7.203409, 12.833395, 11.969448],\n                        [24.236776, 40.125511, 41.396423, 27.642044],\n                        [43.613083, 57.508926, 46.093338, 25.174383],\n                    ],\n                    [\n                        [6.989914, 26.580338, 42.618557, 37.501404],\n                        [75.623192, 116.925674, 113.288368, 72.567764],\n                        [112.724869, 139.826447, 107.653435, 56.799385],\n                    ],\n                ],\n                [\n                    [\n                        [1.856041, 7.203409, 12.833395, 11.969448],\n                        [24.236776, 40.125511, 41.396423, 27.642044],\n                        [43.613083, 57.508926, 46.093338, 25.174383],\n                    ],\n                    [\n                        [6.989914, 26.580338, 42.618557, 37.501404],\n                        [75.623192, 116.925674, 113.288368, 72.567764],\n                        [112.724869, 139.826447, 107.653435, 56.799385],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        mask: TestTensor::from_floats(\n            [[\n                [\n                    [0.0e+00, 2.677941e+00, 5.857617e+00],\n                    [4.015623e+01, 7.759999e+02, 8.492499e+02],\n                    [6.637500e+01, 1.067750e+03, 1.141000e+03],\n                    [9.865628e+01, 1.359500e+03, 1.432750e+03],\n                ],\n                [\n                    [6.745831e+01, 7.688924e+01, 8.684974e+01],\n                    [8.387916e+02, 9.161111e+02, 9.934306e+02],\n                    [1.146750e+03, 1.224069e+03, 1.301389e+03],\n                    [1.454708e+03, 1.532028e+03, 1.609347e+03],\n                ],\n                [\n                    [1.547778e+02, 1.716607e+02, 1.460455e+02],\n                    [9.861667e+02, 1.067556e+03, 8.756536e+02],\n                    [1.310333e+03, 1.391722e+03, 1.110864e+03],\n                    [1.634500e+03, 1.715889e+03, 1.339339e+03],\n                ],\n                [\n                    [2.641250e+02, 1.993876e+02, 0.0e+00],\n                    [1.144875e+03, 8.365740e+02, 0.0e+00],\n                    [1.485250e+03, 1.056253e+03, 0.0e+00],\n                    [1.825625e+03, 1.268859e+03, 0.0e+00],\n                ],\n                [\n                    [3.800000e+02, 1.047861e+03, 1.137389e+03],\n                    [5.276354e+02, 1.404444e+03, 1.493972e+03],\n                    [6.826807e+02, 1.761028e+03, 1.850555e+03],\n                    [5.038855e+02, 1.256341e+03, 1.304936e+03],\n                ],\n                [\n                    [1.123500e+03, 1.217097e+03, 1.310694e+03],\n                    [1.496292e+03, 1.589889e+03, 1.683486e+03],\n                    [1.869083e+03, 1.962681e+03, 2.056278e+03],\n                    [1.146700e+03, 1.190136e+03, 1.232930e+03],\n                ],\n                [\n                    [1.300000e+03, 1.397667e+03, 6.512036e+02],\n                    [1.689000e+03, 1.786667e+03, 8.072734e+02],\n                    [2.078000e+03, 2.175667e+03, 9.552593e+02],\n                    [1.060781e+03, 1.097745e+03, 4.656539e+02],\n                ],\n                [\n                    [1.487833e+03, 5.672195e+02, 0.0e+00],\n                    [1.893042e+03, 6.972655e+02, 0.0e+00],\n                    [2.298250e+03, 8.188910e+02, 0.0e+00],\n                    [9.472098e+02, 3.238781e+02, 0.0e+00],\n                ],\n                [\n                    [1.216444e+03, 1.792806e+03, 1.898611e+03],\n                    [1.536448e+03, 2.214222e+03, 2.320028e+03],\n                    [5.177084e+02, 7.256571e+02, 7.493920e+02],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [1.897500e+03, 2.007375e+03, 2.117250e+03],\n                    [2.335125e+03, 2.445000e+03, 2.554875e+03],\n                    [5.591096e+02, 5.750975e+02, 5.903336e+02],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [2.119333e+03, 2.233278e+03, 2.654414e+02],\n                    [2.573167e+03, 2.687111e+03, 2.907444e+02],\n                    [3.856317e+02, 3.924502e+02, 3.737657e+01],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n                [\n                    [2.352500e+03, 9.009851e+01, 0.0e+00],\n                    [2.822542e+03, 7.854909e+01, 0.0e+00],\n                    [1.785990e+02, 2.930897e+00, 0.0e+00],\n                    [0.0e+00, 0.0e+00, 0.0e+00],\n                ],\n            ]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([12., 12.], &device),\n    };\n    test.assert_grads(grads);\n}\n\n#[test]\nfn test_deform_conv2d_different_padding() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 2,\n        padding_2: 3,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n    let device = Default::default();\n    let grads = Grads {\n        x: TestTensor::from_floats(\n            [[\n                [\n                    [60.633026, 60.906506, 61.179493, 61.451954],\n                    [122.557770, 123.088188, 123.618599, 124.149033],\n                    [126.801132, 127.331535, 127.861938, 128.392365],\n                    [131.044434, 131.574875, 132.105286, 132.635712],\n                ],\n                [\n                    [102.000595, 102.497604, 102.993835, 103.489281],\n                    [198.932983, 199.830597, 200.728210, 201.625870],\n                    [206.113968, 207.011627, 207.909256, 208.806870],\n                    [213.294952, 214.192627, 215.090271, 215.987930],\n                ],\n            ]],\n            &device,\n        ),\n        // => Position 788: 10.421875 != 10.0546875\n        //  diff (rel = +1.79e-2, abs = +3.67e-1), tol (rel = +1.00e-2, abs = +9.77e-4)\n        offset: TestTensor::from_floats(\n            [[\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 0.0, 0.895062, 14.760561, 17.604168, 20.698063, 22.200424, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 0.687500, 9.500000, 10.0, 10.500000, 10.108797, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 1.113426, 13.500000, 14.000000, 14.499999, 13.645835, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 1.613426, 17.500000, 18.000000, 18.500000, 17.108795, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        -12.395836,\n                        -122.399445,\n                        -130.752319,\n                        -139.355469,\n                        -131.526810,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 0.0, 0.154321, 0.017506, 0.020833, 0.024450, -0.387539, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 24.187502, 2.375000, 2.500000, 2.625000, -37.863422, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 48.057869, 3.375000, 3.500000, 3.625000, -66.770836, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        80.02312,\n                        4.375000,\n                        4.500000,\n                        4.625000,\n                        -103.752319,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        113.215271,\n                        5.107495,\n                        5.219907,\n                        5.332031,\n                        -139.725891,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 14.206017, 83.017586, 92.379395, 102.010040, 90.356323, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 6.504737, 35.444443, 35.981483, 36.518517, 29.978970, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 7.668316, 39.740742, 40.277779, 40.814816, 33.071907, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 8.911458, 44.037037, 44.574074, 45.111111, 36.085281, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        -57.523048,\n                        -274.267914,\n                        -289.547089,\n                        -305.095093,\n                        -248.578552,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 9.749230, 0.955354, 0.980994, 1.006945, -13.930464, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        96.046921,\n                        8.861111,\n                        8.995371,\n                        9.129629,\n                        -129.920715,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        147.434769,\n                        9.935185,\n                        10.069445,\n                        10.203704,\n                        -186.718735,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        207.494781,\n                        11.009259,\n                        11.143518,\n                        11.277778,\n                        -252.188889,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        226.050003,\n                        10.153355,\n                        10.252030,\n                        10.350393,\n                        -266.255280,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        44.224964, 159.898483, 176.651901, 193.692688, 146.270813, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        19.050755, 64.870377, 65.444443, 66.018517, 46.553150, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        21.049385, 69.462967, 70.037033, 70.611115, 49.104595, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        23.133059, 74.055557, 74.629631, 75.203705, 51.570988, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        -141.200272,\n                        -445.302155,\n                        -468.381012,\n                        -491.747223,\n                        -341.553131,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        35.665298, 3.505739, 3.556735, 3.608062, -48.756947, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        181.404663,\n                        16.217594,\n                        16.361111,\n                        16.504629,\n                        -238.136124,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        263.888885,\n                        17.365742,\n                        17.509258,\n                        17.652779,\n                        -326.403656,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        355.643341,\n                        18.513889,\n                        18.657408,\n                        18.800926,\n                        -423.941345,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        318.709198,\n                        14.359658,\n                        14.441552,\n                        14.523109,\n                        -369.819580,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [\n                        0.0, 0.0, 88.846703, 237.478439, 261.731201, 286.289917, 182.508713, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 37.688015, 94.722221, 95.333328, 95.944450, 57.441605, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 40.562500, 99.611107, 100.222229, 100.833336, 59.410744, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 43.527519, 104.500000, 105.111107, 105.722221, 61.289349, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        -258.324371,\n                        -618.353943,\n                        -649.340271,\n                        -680.632507,\n                        -397.101013,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        0.0,\n                        76.229431,\n                        7.564093,\n                        7.641718,\n                        7.719699,\n                        -102.792252,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        272.015167,\n                        23.680555,\n                        23.833332,\n                        23.986113,\n                        -351.944214,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        386.062500,\n                        24.902777,\n                        25.055557,\n                        25.208334,\n                        -472.147888,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        509.978149,\n                        26.125000,\n                        26.277777,\n                        26.430555,\n                        -602.219971,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        378.410248,\n                        17.123661,\n                        17.187500,\n                        17.250984,\n                        -436.000732,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0, 157.623291, 331.938538, 365.283356, 398.952606, 205.988480, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 66.495949, 130.925934, 131.574066, 132.222229, 64.435974, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 70.396835, 136.111115, 136.759262, 137.407410, 65.672256, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 74.393753, 141.296295, 141.944458, 142.592606, 66.812523, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        -432.798035,\n                        -827.492065,\n                        -867.978455,\n                        -908.789368,\n                        -425.074158,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        140.150024,\n                        14.043960,\n                        14.152921,\n                        14.262260,\n                        -187.656906,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        386.813873,\n                        32.731483,\n                        32.893517,\n                        33.055557,\n                        -494.779602,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        538.926697,\n                        34.027779,\n                        34.189816,\n                        34.351852,\n                        -653.421875,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        701.505859,\n                        35.324074,\n                        35.486115,\n                        35.648151,\n                        -822.530640,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        416.044586,\n                        18.903570,\n                        18.944647,\n                        18.985338,\n                        -476.728790,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        249.876541, 435.868500, 479.178772, 522.832031, 207.919815, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        105.417015, 170.611115, 171.296295, 171.981476, 64.750000, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        110.441696, 176.092590, 176.777771, 177.462952, 65.156044, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        115.567902, 181.574066, 182.259247, 182.944458, 65.460571, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        -662.743530,\n                        -1056.641846,\n                        -1107.501953,\n                        -1158.704712,\n                        -409.510162,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        227.160507,\n                        22.982454,\n                        23.125793,\n                        23.269531,\n                        -303.112030,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        518.495178,\n                        42.652779,\n                        42.824074,\n                        42.995369,\n                        -657.157410,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        712.252380,\n                        44.023148,\n                        44.194443,\n                        44.365738,\n                        -857.817200,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        917.074036,\n                        45.393517,\n                        45.564812,\n                        45.736115,\n                        -1069.541626,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        416.581482,\n                        18.997831,\n                        19.013102,\n                        19.027966,\n                        -475.031525,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0, 0.0, 151.750259, 210.166672, 210.888885, 211.611099, 57.506927, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 157.929276, 215.944443, 216.666672, 217.388901, 57.052204, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 164.215271, 221.722229, 222.444458, 223.166672, 56.490482, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        -931.783752,\n                        -1285.353760,\n                        -1346.555908,\n                        -1408.119385,\n                        -346.739044,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        0.0,\n                        655.669983,\n                        52.541668,\n                        52.722221,\n                        52.902775,\n                        -824.946777,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        890.972473,\n                        53.986111,\n                        54.166668,\n                        54.347225,\n                        -1067.525024,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        1137.937500,\n                        55.430557,\n                        55.611115,\n                        55.791668,\n                        -1321.765625,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        375.580566,\n                        17.180984,\n                        17.169498,\n                        17.157579,\n                        -425.993713,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0, 213.521454, 256.629639, 257.388885, 258.148132, 41.652927, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 221.015625, 262.703705, 263.462982, 264.222229, 40.176598, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 228.622284, 268.777802, 269.537048, 270.296295, 38.587788, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        -1285.466797,\n                        -1554.254517,\n                        -1627.530640,\n                        -1701.186646,\n                        -228.291397,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        823.380554,\n                        64.157410,\n                        64.347221,\n                        64.537033,\n                        -1028.532715,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        1107.296509,\n                        65.675926,\n                        65.865746,\n                        66.055557,\n                        -1320.097534,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        1403.473022,\n                        67.194450,\n                        67.384262,\n                        67.574074,\n                        -1623.922974,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        288.151398,\n                        13.201796,\n                        13.158524,\n                        13.114797,\n                        -323.577820,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        288.790131, 306.574066, 307.370361, 308.166656, 15.734239, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        297.696838, 312.944427, 313.740723, 314.537048, 13.138914, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        306.721527, 319.314819, 320.111115, 320.907410, 10.425544, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        -1711.543213,\n                        -1844.013062,\n                        -1930.236572,\n                        -2016.858643,\n                        -46.846100,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        1011.358093,\n                        76.643517,\n                        76.842590,\n                        77.041664,\n                        -1255.045654,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        1347.466431,\n                        78.236107,\n                        78.435181,\n                        78.634262,\n                        -1599.175903,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        1696.433350,\n                        79.828705,\n                        80.027779,\n                        80.226852,\n                        -1956.164917,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        146.703568,\n                        6.690874,\n                        6.612756,\n                        6.534196,\n                        -159.277222,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n            ]],\n            &device,\n        ),\n        weight: TestTensor::from_floats(\n            [\n                [\n                    [\n                        [10.341997, 22.988085, 35.634174],\n                        [46.920216, 59.566299, 72.212387],\n                        [80.881615, 92.591522, 104.158524],\n                    ],\n                    [\n                        [29.213360, 68.837769, 108.462166],\n                        [143.825104, 183.449509, 223.073944],\n                        [228.029373, 256.751740, 283.807098],\n                    ],\n                ],\n                [\n                    [\n                        [10.341997, 22.988085, 35.634174],\n                        [46.920216, 59.566299, 72.212387],\n                        [80.881615, 92.591522, 104.158524],\n                    ],\n                    [\n                        [29.213360, 68.837769, 108.462166],\n                        [143.825104, 183.449509, 223.073944],\n                        [228.029373, 256.751740, 283.807098],\n                    ],\n                ],\n            ],\n            &device,\n        ),\n        mask: TestTensor::from_floats(\n            [[\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 0.0, 0.447531, 7.380288, 8.802088, 10.349031, 11.100212, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 44.343754, 584.937439, 639.250000, 693.562439, 683.262756, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 68.390068, 803.437561, 857.750000, 912.062500, 874.698059, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        96.473381,\n                        1021.937500,\n                        1076.250000,\n                        1130.562500,\n                        1062.095947,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        121.302101,\n                        1168.487915,\n                        1218.373779,\n                        1268.134888,\n                        1169.444702,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        0.0, 13.084491, 75.860909, 83.767761, 91.809029, 80.728188, 0.0, 0.0,\n                    ],\n                    [\n                        0.0, 118.950417, 649.486084, 707.821777, 766.157410, 658.076599, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        170.660782,\n                        884.171265,\n                        942.506958,\n                        1000.842651,\n                        837.809326,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        226.707260,\n                        1118.856445,\n                        1177.192261,\n                        1235.527710,\n                        1013.205933,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        234.939651,\n                        1106.213867,\n                        1153.415649,\n                        1200.482666,\n                        966.248901,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [\n                        42.524002, 153.045700, 168.319275, 183.736511, 138.144653, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        207.319611, 718.432800, 780.791626, 843.150391, 619.975037, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        290.277802,\n                        969.303223,\n                        1031.661987,\n                        1094.020752,\n                        784.421631,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        377.871063,\n                        1220.173584,\n                        1282.532471,\n                        1344.891235,\n                        944.233032,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        328.083038,\n                        1025.494995,\n                        1069.130615,\n                        1112.622192,\n                        766.054932,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                ],\n                [\n                    [\n                        0.0, 0.0, 88.238174, 235.055206, 258.194336, 281.486389, 178.858536, 0.0,\n                    ],\n                    [\n                        0.0, 0.0, 305.575500, 789.868042, 856.250061, 922.631897, 572.466064, 0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        421.809021,\n                        1056.923584,\n                        1123.305542,\n                        1189.687500,\n                        719.598816,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        542.976746,\n                        1323.979248,\n                        1390.361206,\n                        1456.743042,\n                        861.797302,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        393.291565,\n                        934.439697,\n                        974.010376,\n                        1013.428101,\n                        586.924011,\n                        0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0, 157.214920, 330.227448, 362.473419, 394.881653, 203.374420, 0.0, 0.0,\n                    ],\n                    [\n                        0.0,\n                        424.340576,\n                        867.495361,\n                        937.900452,\n                        1008.305542,\n                        505.640503,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        578.894897,\n                        1150.736084,\n                        1221.141235,\n                        1291.546265,\n                        630.414001,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        738.682495,\n                        1433.976929,\n                        1504.381958,\n                        1574.787109,\n                        749.954346,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0, 429.912781, 816.507507, 850.771973, 884.873779, 411.152588, 0.0, 0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        249.876541, 434.964233, 477.198730, 519.604675, 206.215576, 0.0, 0.0, 0.0,\n                    ],\n                    [\n                        560.309326,\n                        949.520813,\n                        1023.949097,\n                        1098.377319,\n                        422.458344,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        756.768127,\n                        1248.946777,\n                        1323.375000,\n                        1397.803223,\n                        521.289001,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        958.759216,\n                        1548.372803,\n                        1622.800903,\n                        1697.229248,\n                        614.587402,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        428.833923, 679.269775, 707.346252, 735.250916, 258.169373, 0.0, 0.0, 0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        0.0,\n                        707.671387,\n                        1033.687378,\n                        1112.138916,\n                        1190.590210,\n                        328.295044,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        947.779419,\n                        1349.298584,\n                        1427.750000,\n                        1506.201416,\n                        399.438080,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        0.0,\n                        1193.718872,\n                        1664.909668,\n                        1743.361084,\n                        1821.812500,\n                        464.749847,\n                        0.0,\n                    ],\n                    [\n                        0.0, 0.0, 388.737854, 532.503540, 553.962891, 575.241089, 140.658310, 0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        0.0,\n                        880.797302,\n                        1124.393555,\n                        1206.868042,\n                        1289.342651,\n                        209.627625,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        1169.882812,\n                        1456.189819,\n                        1538.664429,\n                        1621.138916,\n                        247.754730,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0,\n                        1465.098755,\n                        1787.986084,\n                        1870.460571,\n                        1952.935181,\n                        279.751526,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        0.0, 297.330719, 356.362152, 369.893524, 383.234344, 50.974621, 0.0, 0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                [\n                    [\n                        1074.567993,\n                        1219.497681,\n                        1305.995361,\n                        1392.493042,\n                        71.162437,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        1416.214722,\n                        1567.479126,\n                        1653.976929,\n                        1740.474609,\n                        72.689949,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        1764.290771,\n                        1915.460571,\n                        2001.958496,\n                        2088.456055,\n                        67.787628,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ],\n                    [\n                        151.018372, 160.055023, 164.776138, 169.298447, 3.865937, 0.0, 0.0, 0.0,\n                    ],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n            ]],\n            &device,\n        ),\n        bias: TestTensor::from_floats([48., 48.], &device),\n    };\n    test.assert_grads(grads);\n}\n\nstruct Conv2dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    groups: usize,\n    offset_groups: usize,\n    height: usize,\n    width: usize,\n}\n\nstruct Grads {\n    x: TestTensor<4>,\n    offset: TestTensor<4>,\n    weight: TestTensor<4>,\n    mask: TestTensor<4>,\n    bias: TestTensor<1>,\n}\n\nimpl Conv2dTestCase {\n    fn assert_grads(self, expected_grads: Grads) {\n        let out_height =\n            (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)\n                / self.stride_1\n                + 1;\n        let out_width =\n            (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)\n                / self.stride_2\n                + 1;\n\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let shape_offset = Shape::new([\n            self.batch_size,\n            2 * self.offset_groups * self.kernel_size_1 * self.kernel_size_2,\n            out_height,\n            out_width,\n        ]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n        ]);\n        let shape_mask = Shape::new([\n            self.batch_size,\n            self.offset_groups * self.kernel_size_1 * self.kernel_size_2,\n            out_height,\n            out_width,\n        ]);\n        let device = Default::default();\n        let weight = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weight)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let bias = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        )\n        .require_grad();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n        let offset = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_offset.clone())\n                .into_data(),\n            &device,\n        )\n        .div_scalar(shape_offset.num_elements() as f32)\n        .require_grad();\n\n        let mask = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_mask.clone())\n                .into_data(),\n            &device,\n        )\n        .div_scalar(shape_mask.num_elements() as f32)\n        .require_grad();\n\n        let output = deform_conv2d(\n            x.clone(),\n            offset.clone(),\n            weight.clone(),\n            Some(mask.clone()),\n            Some(bias.clone()),\n            DeformConvOptions::new(\n                [self.stride_1, self.stride_2],\n                [self.padding_1, self.padding_2],\n                [self.dilation_1, self.dilation_2],\n                self.groups,\n                self.offset_groups,\n            ),\n        );\n        let grads = output.backward();\n\n        // Assert\n        let x_grad_actual = x.grad(&grads).unwrap();\n        let offset_grad_actual = offset.grad(&grads).unwrap();\n        let weight_grad_actual = weight.grad(&grads).unwrap();\n        let mask_grad_actual = mask.grad(&grads).unwrap();\n        let bias_grad_actual = bias.grad(&grads).unwrap();\n\n        // Relative is set to 5%, which is much higher than typical numerical test tolerances.\n        // This is due to the complexity of the deformable convolution operation.\n        // Unlike regular conv2d, which samples from fixed integer grid positions,\n        // deformable conv2d samples input values at fractional offset locations (learned offsets).\n        // These non-integer positions require bilinear interpolation to estimate the input value.\n        // Gradients computed through all these floating-point operations can compound numerical differences.\n        let tolerance = Tolerance::relative(0.5);\n\n        println!(\"Testing bias\");\n        expected_grads\n            .bias\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&bias_grad_actual.to_data(), tolerance);\n        println!(\"Testing input\");\n        expected_grads\n            .x\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), tolerance);\n        println!(\"Testing offset\");\n        expected_grads\n            .offset\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&offset_grad_actual.to_data(), tolerance);\n        println!(\"Testing mask\");\n        expected_grads\n            .mask\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&mask_grad_actual.to_data(), tolerance);\n        println!(\"Testing weight\");\n        expected_grads\n            .weight\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&weight_grad_actual.to_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/div.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_div() {\n    let data_1 = TensorData::from([1.0, 7.0]);\n    let data_2 = TensorData::from([4.0, 7.0]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().div(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([0.25, 0.14285715]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([-0.0625, -0.14285715]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_div_scalar() {\n    let data = TensorData::from([1.0, 7.0]);\n\n    let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();\n    let tensor_out = tensor.clone().div_scalar(4.0);\n\n    let grads = tensor_out.backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    grad.to_data()\n        .assert_eq(&TensorData::from([0.25, 0.25]), false);\n}\n\n#[test]\nfn test_div_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().div(tensor_2.clone());\n    let tensor_5 = tensor_4.div(tensor_3.clone());\n\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n    let grad_3 = tensor_3.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[0.1250, 0.07142857], [0.25, 0.16666667]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[-0.03125, -0.07142857], [-1.6250, 0.16666667]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n    let expected = TensorData::from([[-0.0625, -0.25], [-1.6250, 0.25]]);\n    grad_3\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_div_complex_2() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.div(tensor_2.clone());\n\n    let grads = tensor_4.backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_absolute(2e-3);\n    let expected = TensorData::from([[2.00, 2.92857146], [1.36666667, 2.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[0.08333334, 0.09591837], [-0.05555558, -0.06714284]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/erf.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_erf() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().erf());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[8.0, 8.0], [8.0, 8.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/exp.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_exp() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().exp());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default();\n    let expected = TensorData::from([[54.5991, 27.4746], [54.5991, 27.4746]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[-5.4598e+01, -9.1188e-04], [2.9556e+01, 8.0342e+01]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/expand.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_expand() {\n    // Python code to generate the test case values\n    // import torch\n    // x1 = torch.tensor([4.0, 7.0, 2.0, 3.0], requires_grad=True)\n    // x2 = torch.tensor([2.0, 4.5, 7.0, 3.0], requires_grad=True)\n    // y = x1.expand(4, 4)\n    // z = (x2 * y).sum()\n    // z.backward()\n    // print(\"x1\", x1.grad)\n    // print(\"x2\", x2.grad)\n\n    let device = Default::default();\n\n    let data_1 = TensorData::from([4.0, 7.0, 2.0, 3.0]);\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();\n\n    let data_2 = TensorData::from([2.0, 4.5, 7.0, 3.0]);\n    let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().expand([4, 4]);\n\n    // Use unsqueeze to make tensor_2 have the same shape as tensor_3\n    let tensor_4 = tensor_2.clone().unsqueeze().mul(tensor_3).sum();\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([8., 18., 28., 12.]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([16., 28., 8., 12.]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/flip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_flip() {\n    let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2\n    let data_2 = TensorData::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<3>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().flip([1, 2]);\n    let tensor_4 = tensor_1.clone().matmul(tensor_3);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2\n    grad_2.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]),\n        tolerance,\n    ); // 1x2x3\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/floor.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_floor() {\n    let data = TensorData::from([\n        [-1.9751, 0.0714, 0.0643, 0.2406],\n        [-1.3172, 0.1252, -0.1119, -0.0127],\n    ]);\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n    let tensor_2 = tensor_1.clone().floor();\n    let grads = tensor_2.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/gather_scatter.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};\n\n#[test]\nfn test_gather_grad() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data(\n        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),\n        &device,\n    )\n    .require_grad();\n    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(\n        TensorData::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]),\n        &device,\n    );\n\n    let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());\n    let tensor_3 = tensor_1.clone().gather(1, indices);\n    let tensor_4 = tensor_2.matmul(tensor_3);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[94., 150., 187.], [242., 305., 304.]]),\n        false,\n    );\n}\n\n#[test]\nfn test_scatter_grad() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data(\n        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),\n        &device,\n    )\n    .require_grad();\n    let values = TestAutodiffTensor::from_data(\n        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n        &device,\n    )\n    .require_grad();\n    let indices = Tensor::<TestAutodiffBackend, 2, Int>::from_data(\n        TensorData::from([[2, 1, 0], [2, 0, 1]]),\n        &device,\n    );\n\n    let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());\n    let tensor_3 = tensor_1\n        .clone()\n        .scatter(1, indices, values.clone(), IndexingUpdateOp::Add);\n    let tensor_4 = tensor_2.matmul(tensor_3);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = values.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[127., 181., 235.], [226., 316., 406.]]),\n        false,\n    );\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[19., 19., 19.], [64., 64., 64.]]), false);\n}\n\n#[test]\nfn test_scatter_add_grad_partial_indices() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::from_data(TensorData::from([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]), &device)\n            .require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_data(TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]), &device)\n            .require_grad();\n    let values =\n        TestAutodiffTensor::from_data(TensorData::from([[4.0, 5.0, 6.0]]), &device).require_grad();\n    let indices =\n        Tensor::<TestAutodiffBackend, 2, Int>::from_data(TensorData::from([[2, 1, 0]]), &device);\n\n    let tensor_3 = tensor_1.clone().mul(tensor_2);\n    let tensor_4 = tensor_3\n        .clone()\n        .scatter(1, indices, values.clone(), IndexingUpdateOp::Add);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = values.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[1., 2., 3., 4., 5., 6.]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[1., 1., 1.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/gelu.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance, activation};\n\n#[test]\nfn should_diff_gelu() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_floats([[0.0, 1.0], [-3.0, 4.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[6.0, -0.5], [9.0, 10.0]], &device).require_grad();\n\n    let x = tensor_1.clone().matmul(activation::gelu(tensor_2.clone()));\n    let x = tensor_1.clone().matmul(x);\n    let grads = x.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::permissive();\n    let expected = TensorData::from([[1.46281, 1.46281], [48.22866, 153.46280]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[-15.0000, -1.98757], [17.0000, 17.0000]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/gradients.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, activation};\n\n#[test]\nfn should_update_tensor_when_grad_replace() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::random([32, 32], Distribution::Default, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::random([32, 32], Distribution::Default, &device);\n\n    let x = tensor_1.clone().matmul(activation::gelu(tensor_2));\n    let mut grads = x.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    let grad_1_updated =\n        TestAutodiffTensor::random([32, 32], Distribution::Default, &device).require_grad();\n    tensor_1.grad_replace(&mut grads, grad_1_updated.clone().inner());\n\n    let grad_1_new = tensor_1.grad(&grads).unwrap();\n\n    assert_ne!(grad_1_new.to_data(), grad_1.into_data());\n    assert_eq!(grad_1_new.into_data(), grad_1_updated.into_data());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/log.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_diff_log() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    let expected = TensorData::from([[60.2652, 72.3130], [60.2652, 72.3130]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[22.8614, 24.5043], [24.5729, 26.8507]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/log1p.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_log1p() {\n    let tensor_1 = TestAutodiffTensor::<2>::from([[0.0, 1.0], [3.0, 4.0]]).require_grad();\n    let tensor_2 = TestAutodiffTensor::from([[6.0, 7.0], [9.0, 10.0]]).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().log1p());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    let expected = TensorData::from([[64.80622101, 75.49362183], [64.80622101, 75.49362183]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[22.92208481, 24.47565651], [24.72780228, 26.86416626]]);\n\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/log_sigmoid.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn should_diff_log_sigmoid() {\n    let data = TensorData::from([[0.8762, -0.1423], [-300., 200.]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n    let tensor_2 = activation::log_sigmoid(tensor_1.clone());\n    let grads = tensor_2.backward();\n\n    let grad = tensor_1.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[0.293966, 0.535515], [1.000000, 0.000000]]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/mask.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Bool, Tensor, TensorData};\n\n#[test]\nfn should_diff_mask_fill() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let mask = TensorData::from([[true, false], [false, true]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device);\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.mask_fill(mask, 2.0);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[7.0, 3.0], [4.0, 2.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[2.0, 1.0], [3.0, 7.0]]), false);\n}\n\n#[test]\nfn should_diff_mask_where() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data([[4.0, 7.0], [2.0, 3.0]], &device).require_grad();\n    let tensor_3 =\n        TestAutodiffTensor::from_data([[8.8, 9.8], [10.8, 11.8]], &device).require_grad();\n    let mask =\n        Tensor::<TestAutodiffBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);\n\n    let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_5 = tensor_4.clone().matmul(tensor_3.clone());\n    let tensor_6 = tensor_5.mask_where(mask, tensor_3.clone());\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n    let grad_3 = tensor_3.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    let expected = TensorData::from([[121.8, 55.0], [110.8, 50.0]]);\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[27.4, 33.4], [95.0, 115.0]]);\n    grad_2\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[15., 18.], [23., 29.]]);\n    grad_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/matmul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_matmul() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);\n    tensor_3\n        .to_data()\n        .assert_eq(&TensorData::from([[18.0, 28.0], [14.0, 23.0]]), false);\n}\n\n#[test]\nfn test_matmul_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_5 = tensor_4.matmul(tensor_3);\n\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[44.0, 20.0], [44.0, 20.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[56.0, 56.0], [16.0, 16.0]]), false);\n}\n\n#[test]\nfn test_matmul_complex_2() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_5 = tensor_4.matmul(tensor_3.clone());\n    let tensor_6 = tensor_1.clone().matmul(tensor_5);\n\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[800.0, 792.0], [360.0, 592.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[264., 264.0], [344.0, 344.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/maxmin.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_max_dim() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.max_dim(1).unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[50.0, 34.0], [40.0, -10.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[8.0, 10.0], [56.0, 15.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_min_dim() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.min_dim(1).unsqueeze());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[-42.0, 38.0], [-34.0, -24.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[10.0, 8.0], [15.0, 56.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_min_dim_3d_dim1() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::<3>::from_floats([[[4., -7.], [2., 3.]]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().mul(tensor_2.clone());\n    let tensor_4 = tensor_3.min_dim(1);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[[0., -7.], [2., 0.]]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[[0., 7.], [-2., -0.]]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/maxpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::max_pool1d;\n\n#[test]\nfn test_max_pool1d_simple() {\n    let kernel_size = 4;\n    let padding = 0;\n    let stride = 1;\n    let dilation = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected =\n        TestAutodiffTensor::<3>::from_floats([[[1., 1., 0., 0., 0., 1.]]], &device);\n\n    let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_with_dilation() {\n    let kernel_size = 4;\n    let padding = 0;\n    let stride = 1;\n    let dilation = 2;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,\n            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,\n            0.4610, 0.5365, 0.6880,\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<3>::from_floats(\n        [[[\n            0., 0., 1., 0., 0., 3., 0., 1., 2., 1., 0., 0., 2., 0., 0., 0., 4., 4., 0., 0., 0., 0.,\n            0., 0., 1.,\n        ]]],\n        &device,\n    );\n\n    let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_complex() {\n    let kernel_size = 4;\n    let padding = 0;\n    let stride = 1;\n    let dilation = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,\n            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,\n            0.4610, 0.5365, 0.6880,\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<3>::from_floats(\n        [[[\n            0., 0., 0., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,\n            1., 1., 1.,\n        ]]],\n        &device,\n    );\n\n    let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_complex_with_padding() {\n    let kernel_size = 4;\n    let padding = 2;\n    let stride = 1;\n    let dilation = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            0.5388, 0.0676, 0.7122, 0.8316, 0.0653, 0.9154, 0.1536, 0.9089, 0.8016, 0.7518, 0.2073,\n            0.0501, 0.8811, 0.5604, 0.5075, 0.4384, 0.9963, 0.9698, 0.4988, 0.2609, 0.3391, 0.2230,\n            0.4610, 0.5365, 0.6880,\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<3>::from_floats(\n        [[[\n            1., 0., 1., 2., 0., 4., 0., 2., 1., 0., 0., 0., 4., 0., 0., 0., 4., 1., 1., 0., 0., 0.,\n            1., 1., 3.,\n        ]]],\n        &device,\n    );\n\n    let output = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/maxpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::max_pool2d;\n\n#[test]\nfn test_max_pool2d_simple_1() {\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 0;\n    let padding_2 = 0;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            [0.2479, 0.6386, 0.3166, 0.5742],\n            [0.7065, 0.1940, 0.6305, 0.8959],\n            [0.5416, 0.8602, 0.8129, 0.1662],\n            [0.3358, 0.3059, 0.8293, 0.0990],\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<4>::from_floats(\n        [[[\n            [0.0, 0.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0, 2.0],\n            [0.0, 2.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0, 0.0],\n        ]]],\n        &device,\n    );\n\n    let output = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_simple_2() {\n    let kernel_size_1 = 2;\n    let kernel_size_2 = 2;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            [0.2479, 0.6386, 0.3166, 0.5742],\n            [0.7065, 0.1940, 0.6305, 0.8959],\n            [0.5416, 0.8602, 0.8129, 0.1662],\n            [0.3358, 0.3059, 0.8293, 0.0990],\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<4>::from_floats(\n        [[[\n            [1., 3., 0., 2.],\n            [3., 0., 0., 4.],\n            [1., 4., 0., 1.],\n            [2., 0., 3., 1.],\n        ]]],\n        &device,\n    );\n\n    let output = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_with_dilation() {\n    let kernel_size_1 = 2;\n    let kernel_size_2 = 2;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 2;\n    let dilation_2 = 2;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            [0.2479, 0.6386, 0.3166, 0.5742],\n            [0.7065, 0.1940, 0.6305, 0.8959],\n            [0.5416, 0.8602, 0.8129, 0.1662],\n            [0.3358, 0.3059, 0.8293, 0.0990],\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<4>::from_floats(\n        [[[\n            [0., 0., 0., 0.],\n            [1., 1., 1., 2.],\n            [0., 4., 4., 0.],\n            [0., 1., 2., 0.],\n        ]]],\n        &device,\n    );\n\n    let output = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_complex() {\n    let kernel_size_1 = 4;\n    let kernel_size_2 = 2;\n    let padding_1 = 2;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],\n            [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],\n            [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],\n            [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],\n            [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],\n        ]]],\n        &device,\n    )\n    .require_grad();\n    let x_grad_expected = TestAutodiffTensor::<4>::from_floats(\n        [[[\n            [0., 0., 0., 3., 0.],\n            [4., 0., 2., 1., 0.],\n            [0., 0., 0., 0., 0.],\n            [2., 4., 0., 0., 0.],\n            [0., 0., 0., 0., 2.],\n        ]]],\n        &device,\n    );\n\n    let output = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_ceil_mode() {\n    // Test ceil_mode=true with gradient computation\n    // Using 1x1x6x6 input with kernel 3x3, stride 2x2, padding 0\n    // Floor mode: output 2x2\n    // Ceil mode: output 3x3\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 0;\n    let padding_2 = 0;\n    let stride_1 = 2;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let device = Default::default();\n    // Input (values 1-36):\n    let x = TestAutodiffTensor::from_floats(\n        [[[\n            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],\n            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0],\n            [13.0, 14.0, 15.0, 16.0, 17.0, 18.0],\n            [19.0, 20.0, 21.0, 22.0, 23.0, 24.0],\n            [25.0, 26.0, 27.0, 28.0, 29.0, 30.0],\n            [31.0, 32.0, 33.0, 34.0, 35.0, 36.0],\n        ]]],\n        &device,\n    )\n    .require_grad();\n\n    // Expected gradients for ceil_mode output 3x3:\n    // Output positions and their max value positions:\n    // (0,0): max at (2,2)=15 -> grad[2,2] += 1\n    // (0,1): max at (2,4)=17 -> grad[2,4] += 1\n    // (0,2): max at (2,5)=18 -> grad[2,5] += 1\n    // (1,0): max at (4,2)=27 -> grad[4,2] += 1\n    // (1,1): max at (4,4)=29 -> grad[4,4] += 1\n    // (1,2): max at (4,5)=30 -> grad[4,5] += 1\n    // (2,0): max at (5,2)=33 -> grad[5,2] += 1\n    // (2,1): max at (5,4)=35 -> grad[5,4] += 1\n    // (2,2): max at (5,5)=36 -> grad[5,5] += 1\n    let x_grad_expected = TestAutodiffTensor::<4>::from_floats(\n        [[[\n            [0., 0., 0., 0., 0., 0.],\n            [0., 0., 0., 0., 0., 0.],\n            [0., 0., 1., 0., 1., 1.],\n            [0., 0., 0., 0., 0., 0.],\n            [0., 0., 1., 0., 1., 1.],\n            [0., 0., 1., 0., 1., 1.],\n        ]]],\n        &device,\n    );\n\n    let output = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        true,\n    );\n    let grads = output.backward();\n\n    // Asserts\n    let x_grad_actual = x.grad(&grads).unwrap();\n    x_grad_expected\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&x_grad_actual.to_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/memory_management.rs",
    "content": "use super::*;\nuse burn_tensor::{Tensor, TensorData};\n\n#[test]\nfn test_mm_independent_trees() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3;\n    let tensor_6 = tensor_4 * tensor_5;\n\n    // Second tree\n    let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_11 = tensor_7.clone() * tensor_8.clone();\n    let tensor_12 = tensor_9.clone() * tensor_10.clone();\n    let tensor_13 = tensor_11 * tensor_12;\n\n    let _grads = tensor_6.backward();\n    let grads = tensor_13.backward();\n\n    assert!(tensor_7.grad(&grads).is_some());\n    assert!(tensor_8.grad(&grads).is_some());\n    assert!(tensor_9.grad(&grads).is_some());\n    assert!(tensor_10.grad(&grads).is_some());\n}\n\n#[test]\n#[should_panic]\nfn test_mm_crossover_trees_root_unavailable() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3;\n    let tensor_6 = tensor_4.clone() * tensor_5;\n\n    // Second tree\n    let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_9 = tensor_7.clone() * tensor_8.clone();\n    let tensor_10 = tensor_4 * tensor_9;\n\n    let _grads = tensor_6.backward();\n    let _grads = tensor_10.backward();\n}\n\n#[test]\nfn test_mm_crossover_trees_with_referred_subtree() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3;\n    let tensor_6 = tensor_4.clone() * tensor_5;\n\n    // Second tree\n    let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_9 = tensor_7.clone() * tensor_8.clone();\n    let _tensor_10 = tensor_4 * tensor_9.clone();\n\n    let _grads = tensor_6.backward();\n    let _grads = tensor_9.backward();\n}\n\n#[test]\nfn test_mm_three_crossover_trees_last_still_usable() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3;\n    let tensor_6 = tensor_4 * tensor_5.clone();\n\n    // Third tree\n    let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_11 = tensor_7 * tensor_8;\n    let tensor_12 = tensor_9 * tensor_10;\n    let tensor_13 = tensor_11 * tensor_12.clone();\n\n    // Second tree (in between)\n    let _tensor_14 = tensor_5 * tensor_12;\n\n    let _grads = tensor_6.backward();\n    let _grads = tensor_13.backward();\n}\n\n#[test]\n#[should_panic]\nfn test_mm_three_crossover_trees_middle_one_unavailable() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3;\n    let tensor_6 = tensor_4 * tensor_5.clone();\n\n    // Third tree\n    let tensor_7 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_8 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_9 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_10 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_11 = tensor_7 * tensor_8;\n    let tensor_12 = tensor_9 * tensor_10;\n    let _tensor_13 = tensor_11 * tensor_12.clone();\n\n    // Second tree (in between)\n    let tensor_14 = tensor_5 * tensor_12;\n\n    let _grads = tensor_6.backward();\n    let _grads = tensor_14.backward();\n}\n\n#[test]\nfn test_mm_self_referencing_tree() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // First tree\n    let tensor_0 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data.clone(), &device).require_grad();\n\n    let tensor_3 = tensor_0 * tensor_1;\n    let tensor_5 = tensor_2 * tensor_3.clone();\n    let tensor_6 = tensor_3 * tensor_5;\n\n    let _grads = tensor_6.backward();\n}\n\n#[test]\nfn test_mm_with_non_impacting_detach() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n    let tensor_1 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_2 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone() * tensor_2.clone();\n    let tensor_5 = tensor_4.detach() * tensor_3.clone();\n\n    let grads = tensor_5.backward();\n    assert!(tensor_3.grad(&grads).is_some());\n}\n\n#[test]\nfn test_mm_with_missing_require_grad_after_cleanup() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    let tensor_1 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);\n    let tensor_3 = Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device);\n\n    let tensor_4 = tensor_1.clone() * tensor_2.clone();\n    let tensor_5 = tensor_4 * tensor_3.clone();\n\n    // Trivial backward, just to trigger cleanup\n    Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)\n        .require_grad()\n        .backward();\n\n    let grads = tensor_5.backward();\n    assert!(tensor_1.grad(&grads).is_some());\n    assert!(tensor_2.grad(&grads).is_none());\n    assert!(tensor_3.grad(&grads).is_none());\n}\n\n#[test]\nfn test_mm_with_detach_after_cleanup() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    let tensor_1 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_2 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_3 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n\n    let tensor_4 = tensor_1.clone() * tensor_2.clone();\n    let tensor_5 = tensor_4 * tensor_3.clone().detach();\n\n    // Trivial backward, just to trigger cleanup\n    Tensor::<TestAutodiffBackend, 2>::from_data(data, &device)\n        .require_grad()\n        .backward();\n\n    let grads = tensor_5.backward();\n    assert!(tensor_1.grad(&grads).is_some());\n    assert!(tensor_2.grad(&grads).is_some());\n    assert!(tensor_3.grad(&grads).is_none());\n}\n\n#[test]\n#[should_panic]\nfn test_mm_deletables_propagate_well() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    let tensor_0 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n    let tensor_1 =\n        Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n\n    let tensor_2 = tensor_0 * tensor_1;\n    let tensor_3 = tensor_2.clone().exp();\n    let _tensor_4 = tensor_3.clone().log();\n\n    let _grads = tensor_2.backward();\n\n    // We are testing that after backward on tensor_2, not only the leaf tensor_4 is deleted, but\n    // the intermediate tensor_3 as well\n    let _grads = tensor_3.backward();\n}\n\n#[test]\nfn test_mm_node_explored_once_can_still_be_tagged_as_useful_when_found_again_deeper() {\n    let data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let device = Default::default();\n\n    // The test has 50% chance of starting with leaf tensor_8 instead of tensor_4, which is not informative\n    // By repeating it many times it becomes almost impossible that it passes if it shouldn't\n    for _ in 0..12 {\n        let tensor_0 =\n            Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n        let tensor_1 =\n            Tensor::<TestAutodiffBackend, 2>::from_data(data.clone(), &device).require_grad();\n\n        let tensor_2 = tensor_1.clone().exp();\n        let tensor_3 = tensor_0.exp();\n        let _tensor_4 = tensor_3.clone() * tensor_2.clone();\n        let tensor_5 = tensor_2.exp();\n        let tensor_6 = tensor_5.exp();\n        let tensor_7 = tensor_6.exp();\n        let tensor_8 = tensor_7.exp();\n\n        // tensor_2 should be tagged unknown through the leaf tensor_4, then useful through the leaf tensor_8\n        // which should happen after because tensor_2 is deeper from tensor_8 point of view and we're in breadth first search\n        tensor_3.backward();\n        let grads = tensor_8.backward();\n\n        assert!(tensor_1.grad(&grads).is_some());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/mod.rs",
    "content": "#[allow(unused_imports)] // required for re-included modules\npub use super::*;\n\nmod abs;\nmod adaptive_avgpool1d;\nmod adaptive_avgpool2d;\nmod add;\nmod aggregation;\nmod avgpool1d;\nmod avgpool2d;\nmod backward;\nmod bridge;\nmod broadcast;\nmod cast;\nmod cat;\nmod ceil;\nmod checkpoint;\nmod complex;\nmod conv1d;\nmod conv2d;\nmod conv3d;\nmod conv_transpose1d;\nmod conv_transpose2d;\nmod conv_transpose3d;\nmod cross;\nmod cross_entropy;\nmod cummax;\nmod cummin;\nmod cumprod;\nmod cumsum;\nmod deform_conv2d;\nmod div;\nmod erf;\nmod exp;\nmod expand;\nmod flip;\nmod floor;\nmod gather_scatter;\nmod gelu;\nmod gradients;\nmod log;\nmod log1p;\nmod log_sigmoid;\nmod mask;\nmod matmul;\nmod maxmin;\nmod maxpool1d;\nmod maxpool2d;\nmod memory_management;\nmod mul;\nmod multithread;\nmod nearest_interpolate;\nmod neg;\nmod nonzero;\nmod permute;\nmod pow;\nmod recip;\nmod relu;\nmod remainder;\nmod repeat_dim;\nmod reshape;\nmod round;\nmod select;\nmod sigmoid;\nmod sign;\nmod slice;\nmod slice_assign;\nmod softmax;\nmod sort;\nmod sqrt;\nmod sub;\nmod transpose;\nmod trig;\nmod unfold;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/mul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_mul() {\n    let data_1 = TensorData::from([1.0, 7.0]);\n    let data_2 = TensorData::from([4.0, 7.0]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1.clone(), &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().mul(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let _grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(&data_2, false);\n    tensor_3\n        .into_data()\n        .assert_eq(&TensorData::from([4.0, 49.0]), false);\n}\n\n#[test]\nfn should_diff_mul_scalar() {\n    let data = TensorData::from([2.0, 5.0]);\n\n    let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();\n    let tensor_out = tensor.clone().mul_scalar(4.0);\n\n    let grads = tensor_out.backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    tensor_out\n        .into_data()\n        .assert_eq(&TensorData::from([8.0, 20.0]), false);\n    grad.to_data()\n        .assert_eq(&TensorData::from([4.0, 4.0]), false);\n}\n\n#[test]\nfn test_mul_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().mul(tensor_2.clone());\n    let tensor_5 = tensor_4.mul(tensor_3);\n    let tensor_6 = tensor_1.clone().mul(tensor_5);\n\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[16.0, 196.0], [104.0, -36.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[2.0, 98.0], [338.0, 18.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/multithread.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_behave_the_same_with_multithread() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let with_move = || {\n        let device = Default::default();\n        let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();\n        let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();\n\n        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n        let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());\n        let tensor_5 = tensor_4.matmul(tensor_3);\n\n        // Task 1\n        let tensor_1_cloned = tensor_1.clone();\n        let tensor_2_cloned = tensor_2.clone();\n        let tensor_5_cloned = tensor_5.clone();\n\n        let first_call = move || {\n            let tensor_6_1 = tensor_5_cloned.matmul(tensor_2_cloned);\n            tensor_6_1.matmul(tensor_1_cloned)\n        };\n\n        // Task 2\n        let tensor_1_cloned = tensor_1.clone();\n        let tensor_2_cloned = tensor_2.clone();\n        let tensor_5_cloned = tensor_5;\n\n        let second_call = move || {\n            let tensor_6_2 = tensor_5_cloned.matmul(tensor_1_cloned);\n            tensor_6_2.matmul(tensor_2_cloned)\n        };\n\n        let tensor_7_1_handle = std::thread::spawn(first_call);\n        let tensor_7_2_handle = std::thread::spawn(second_call);\n\n        let tensor_7_1 = tensor_7_1_handle.join().unwrap();\n        let tensor_7_2 = tensor_7_2_handle.join().unwrap();\n        let tensor_8 = tensor_7_1.matmul(tensor_7_2);\n\n        let grads = tensor_8.backward();\n\n        let grad_1 = tensor_1.grad(&grads).unwrap();\n        let grad_2 = tensor_2.grad(&grads).unwrap();\n\n        (grad_1, grad_2)\n    };\n    let without_move = || {\n        let device = Default::default();\n        let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1.clone(), &device).require_grad();\n        let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();\n\n        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n        let tensor_4 = tensor_3.clone().matmul(tensor_2.clone());\n        let tensor_5 = tensor_4.matmul(tensor_3);\n\n        // Task 1\n        let tensor_6_1 = tensor_5.clone().matmul(tensor_2.clone());\n        let tensor_7_1 = tensor_6_1.matmul(tensor_1.clone());\n\n        // Task 2\n        let tensor_6_2 = tensor_5.matmul(tensor_1.clone());\n        let tensor_7_2 = tensor_6_2.matmul(tensor_2.clone());\n\n        let tensor_8 = tensor_7_1.matmul(tensor_7_2);\n\n        let grads = tensor_8.backward();\n\n        let grad_1 = tensor_1.grad(&grads).unwrap();\n        let grad_2 = tensor_2.grad(&grads).unwrap();\n\n        (grad_1, grad_2)\n    };\n\n    let (grad_1, grad_2) = without_move();\n    let (grad_1_moved, grad_2_moved) = with_move();\n\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&grad_1_moved.into_data(), Tolerance::default());\n    grad_2\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&grad_2_moved.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/nearest_interpolate.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::interpolate;\nuse burn_tensor::ops::{InterpolateMode, InterpolateOptions};\n\n#[test]\nfn test_upsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 2,\n        channels: 1,\n        height: 7,\n        width: 5,\n        height_out: 8,\n        width_out: 7,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[\n            [4., 2., 4., 2., 2.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n        ]],\n        [[\n            [4., 2., 4., 2., 2.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n            [2., 1., 2., 1., 1.],\n        ]],\n    ]));\n}\n\n#[test]\nfn test_downsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 8,\n        width: 8,\n        height_out: 4,\n        width_out: 6,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [1., 1., 1., 0., 1., 1., 1., 0.],\n        [0., 0., 0., 0., 0., 0., 0., 0.],\n        [1., 1., 1., 0., 1., 1., 1., 0.],\n        [0., 0., 0., 0., 0., 0., 0., 0.],\n        [1., 1., 1., 0., 1., 1., 1., 0.],\n        [0., 0., 0., 0., 0., 0., 0., 0.],\n        [1., 1., 1., 0., 1., 1., 1., 0.],\n        [0., 0., 0., 0., 0., 0., 0., 0.],\n    ]]]));\n}\n\nstruct InterpolateTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl InterpolateTestCase {\n    fn assert_output(self, x_grad: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let device = Default::default();\n        let x = TestAutodiffTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n            &device,\n        )\n        .require_grad();\n\n        let output = interpolate(\n            x.clone(),\n            [self.height_out, self.width_out],\n            InterpolateOptions::new(InterpolateMode::Nearest),\n        );\n\n        let grads = output.backward();\n        let x_grad_actual = x.grad(&grads).unwrap();\n\n        x_grad\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&x_grad_actual.into_data(), Tolerance::permissive());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/neg.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_neg() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().neg());\n    let tensor_4 = tensor_3.neg();\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[3.0, 3.0], [10.0, 10.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/nonzero.rs",
    "content": "use super::*;\nuse burn_tensor::{Bool, Tensor, TensorData};\n\n#[test]\nfn should_diff_nonzero() {\n    let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([-1.0, 1.0]);\n    let mask = TensorData::from([[false, true], [true, false]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();\n\n    // Multi-dimensional tensor indexing isn't really supported yet so the easiest way to do\n    // this is to flatten the mask and tensor to get proper indexing. Anyway the returned tensor would\n    // have dimensions different from the input, so this is somewhat equivalent.\n    let mask = Tensor::<TestAutodiffBackend, 2, Bool>::from_bool(mask, &device).flatten::<1>(0, 1);\n    let indices = mask.nonzero();\n    let tensor_3 = tensor_1\n        .clone()\n        .flatten::<1>(0, 1)\n        .select(0, indices[0].clone());\n\n    // Vector dot product not supported (only 2D matmuls) so unsqueeze for test purposes\n    let tensor_4 = tensor_2\n        .clone()\n        .unsqueeze_dim::<2>(0)\n        .matmul(tensor_3.unsqueeze_dim(1));\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[0.0, -1.0], [1.0, 0.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([2.0, 3.0]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/permute.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_permute() {\n    let data_1 = TensorData::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2\n    let data_2 = TensorData::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().permute([0, 2, 1]);\n    let tensor_4 = tensor_1.clone().matmul(tensor_3);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([[[7.2, 12.0], [7.2, 12.0]]]), tolerance); // 1x2x2\n    grad_2.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]),\n        tolerance,\n    ); // 1x3x2\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/pow.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_powf_scalar() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().powf_scalar(0.4));\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(2e-3);\n    let expected = TensorData::from([[68.0, 79.0328], [68.0, 79.0328]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[23.5081, 25.2779], [26.0502, 28.6383]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn should_diff_powf() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().powf(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([32.0, 14.0]);\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([11.09035, 95.34960]);\n    grad_2\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([16.0, 49.0]);\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_powf_with_untracked_lhs() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device);\n    let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().powf(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([11.09035, 95.34960]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_powf_with_untracked_rhs() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data([2.0, 7.0], &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data([4.0, 2.0], &device);\n\n    let tensor_3 = tensor_1.clone().powf(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    let expected = TensorData::from([32.0, 14.0]);\n    grad_1\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/recip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_recip() {\n    let data = TensorData::from([2.0, 5.0, 0.4]);\n\n    let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();\n    let tensor_out = tensor.clone().recip();\n\n    let grads = tensor_out.backward();\n    let grad = tensor.grad(&grads).unwrap();\n\n    tensor_out\n        .into_data()\n        .assert_eq(&TensorData::from([0.5, 0.2, 2.5]), false);\n    grad.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([-0.25, -0.04, -6.25]),\n        Tolerance::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/relu.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn should_diff_relu() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, -7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = activation::relu(tensor_3);\n    let tensor_5 = tensor_4.matmul(tensor_2.clone());\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[-47.0, 9.0], [-35.0, 15.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[15.0, 13.0], [-2.0, 39.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/remainder.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_remainder() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(\n        TensorData::from([\n            0.9742, 0.3676, 0.0905, 0.8066, 0.7072, 0.7883, 0.6987, 0.1560, 0.7179, 0.7874, 0.9032,\n            0.1845,\n        ]),\n        &device,\n    )\n    .require_grad();\n    let tensor_2 = TestAutodiffTensor::<1>::from_data(\n        TensorData::from([\n            0.3357, 0.0285, 0.4115, 0.5511, 0.8637, 0.3593, 0.3885, 0.2569, 0.0936, 0.7172, 0.4792,\n            0.4898,\n        ]),\n        &device,\n    )\n    .require_grad();\n    let tensor_3 = tensor_1.clone().remainder(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([\n        -2.0, -12.0, -0.0, -1.0, -0.0, -2.0, -1.0, -0.0, -7.0, -1.0, -1.0, -0.0,\n    ]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/repeat_dim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_repeat() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0], [2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().repeat_dim(1, 3);\n\n    let tensor_3 = tensor_1.matmul(tensor_3);\n    let grads = tensor_3.backward();\n\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[-3.0], [12.0]]), false);\n}\n\n#[test]\nfn should_diff_repeat_multi_dim() {\n    let data_1 = TensorData::from([[1.0, 7.0], [-2.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 2.0], [2.0, 4.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().repeat_dim(1, 3);\n\n    let tensor_3 = tensor_1.matmul(tensor_3);\n    let grads = tensor_3.backward();\n\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[-3.0, -3.0], [12.0, 12.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/reshape.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_reshape() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([4.0, 7.0, 2.0, 3.0]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::<1>::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().reshape([2, 2]);\n    let tensor_4 = tensor_1.clone().matmul(tensor_3);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([3.0, 3.0, 10.0, 10.0]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/round.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_round() {\n    let data = TensorData::from([\n        [-1.9751, 0.0714, 0.0643, 0.2406],\n        [-1.3172, 0.1252, -0.1119, -0.0127],\n    ]);\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();\n    let tensor_2 = tensor_1.clone().round();\n    let grads = tensor_2.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/select.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, Int, Tensor, TensorData};\n\n#[test]\nfn test_select_grad() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),\n        &device,\n    )\n    .require_grad();\n    let indices =\n        Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);\n\n    let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());\n    let tensor_3 = tensor_1.clone().select(0, indices);\n    let tensor_4 = tensor_2.matmul(tensor_3);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n\n    grad_1.into_data().assert_eq(\n        &TensorData::from([[109., 148., 187.], [37., 58., 79.]]),\n        false,\n    );\n}\n\n#[test]\nfn test_select_add_grad() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(\n        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),\n        &device,\n    )\n    .require_grad();\n    let values = TestAutodiffTensor::from_data(\n        TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n        &device,\n    )\n    .require_grad();\n    let indices =\n        Tensor::<TestAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 0]), &device);\n\n    let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());\n    let tensor_3 =\n        tensor_1\n            .clone()\n            .select_assign(0, indices, values.clone(), IndexingUpdateOp::Add);\n    let tensor_4 = tensor_2.matmul(tensor_3);\n\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = values.grad(&grads).unwrap();\n\n    grad_1.into_data().assert_eq(\n        &TensorData::from([[127., 199., 271.], [172., 244., 316.]]),\n        false,\n    );\n    grad_2\n        .into_data()\n        .assert_eq(&TensorData::from([[64., 64., 64.], [19., 19., 19.]]), false);\n}\n\n#[test]\nfn test_select_add_grad_different_shapes() {\n    let device = Default::default();\n\n    let indices: Tensor<TestAutodiffBackend, 1, Int> = Tensor::from_ints([1], &device);\n    let x: Tensor<TestAutodiffBackend, 2> = Tensor::ones([1, 1], &device).require_grad();\n    let y = Tensor::ones([2, 1], &device).require_grad();\n\n    let w = y\n        .clone()\n        .select_assign(0, indices, x.clone(), IndexingUpdateOp::Add);\n    let w = w.matmul(y.clone().transpose());\n\n    let grads = w.backward();\n    let x_grad = x.grad(&grads).unwrap();\n    let y_grad = y.grad(&grads).unwrap();\n\n    x_grad\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0]]), false);\n    y_grad\n        .into_data()\n        .assert_eq(&TensorData::from([[5.0], [5.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/sigmoid.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn should_diff_sigmoid() {\n    let data = TensorData::from([0.8762]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();\n    let tensor_2 = activation::sigmoid(tensor_1.clone());\n    let grads = tensor_2.backward();\n\n    let grad = tensor_1.grad(&grads).unwrap();\n\n    let expected = TensorData::from([0.207549]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn small_neg_val_should_not_cause_grad_overflow() {\n    let data = TensorData::from([-90.0]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();\n    let tensor_2 = activation::sigmoid(tensor_1.clone());\n    let grads = tensor_2.backward();\n\n    let grad = tensor_1.grad(&grads).unwrap();\n\n    let expected = TensorData::from([0.0]);\n    grad.to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/sign.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n/// Example using the sign function with PyTorch:\n// >>> import torch\n// >>> # Create a tensor with requires_grad=True\n// >>> x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)\n// >>> # Forward pass: Apply the sign function\n// >>> y = torch.sign(x)\n// >>> print(\"Forward pass:\")\n// Forward pass:\n// >>> print(\"x:\", x)\n// x: tensor([-2., -1.,  0.,  1.,  2.], requires_grad=True)\n// >>> print(\"y:\", y)\n// y: tensor([-1., -1.,  0.,  1.,  1.], grad_fn=<SignBackward0>)\n// >>> # Compute the loss (just an example)\n// >>> loss = y.sum()\n// >>> # Backward pass: Compute the gradients\n// >>> loss.backward()\n// >>> print(\"\\nBackward pass:\")\n// Backward pass:\n// >>> print(\"x.grad:\", x.grad)\n// x.grad: tensor([0., 0., 0., 0., 0.])\n\n#[test]\nfn should_diff_sign() {\n    let data = TensorData::from([-2.0, -1.0, 0.0, 1.0, 2.0]);\n\n    let device = Default::default();\n    let x = TestAutodiffTensor::<1>::from_data(data, &device).require_grad();\n\n    let y = x.clone().sign();\n\n    let loss = y.clone().sum();\n    let grads = loss.backward();\n    let grad = x.grad(&grads).unwrap();\n\n    y.to_data()\n        .assert_eq(&TensorData::from([-1., -1., 0., 1., 1.]), false);\n    grad.to_data()\n        .assert_eq(&TensorData::from([0., 0., 0., 0., 0.]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/slice.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_matmul_with_slice() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);\n    let tensor_4 = tensor_1.clone().matmul(tensor_3);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[11.0, 5.0], [11.0, 5.0]]), false);\n    grad_2.to_data().assert_eq(\n        &TensorData::from([[3.0, 3.0, 0.0], [10.0, 10.0, 0.0]]),\n        false,\n    );\n}\n\n#[test]\nfn should_diff_matmul_with_slice_stepped() {\n    use burn_tensor::s;\n    let data_1 = TensorData::from([[1.0, 7.0], [100.0, 100.0], [2.0, 3.0], [100.0, 100.0]]);\n    let data_2 = TensorData::from([[4.0, 100.0, 7.0, 100.0], [2.0, 100.0, 3.0, 15.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().slice(s![0..;2, 0..2]); // [[1., 7.], [2., 3.]]\n    let tensor_4 = tensor_2.clone().slice(s![0..2, 0..;2]); // [[4., 7.], [2., 3.]]\n    let tensor_5 = tensor_3.clone().matmul(tensor_4);\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_eq(\n        &TensorData::from([[11., 5.], [0., 0.], [11., 5.], [0., 0.]]),\n        false,\n    );\n    grad_2.to_data().assert_eq(\n        &TensorData::from([[3., 0., 3., 0.], [10., 0., 10., 0.]]),\n        false,\n    );\n}\n\n#[test]\nfn should_panic_on_slice_with_step() {\n    use burn_tensor::s;\n\n    let data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n\n    // This should panic because step is 2\n    let _sliced = tensor.slice(s![.., 0..4; 2]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/slice_assign.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_matmul_with_slice_assign() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_assigned = TensorData::from([[9.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_assigned = TestAutodiffTensor::from_data(data_assigned, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);\n    let tensor_5 = tensor_4.matmul(tensor_1.clone());\n\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[58.0, 38.0], [118.0, 82.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[16.0, 15.0], [24.0, 50.0]]), false);\n}\n\n#[test]\nfn should_diff_matmul_with_slice_assign_complex() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[9.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);\n    let tensor_6 = tensor_5.mul(tensor_3.clone());\n    let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);\n    let tensor_8 = tensor_7.matmul(tensor_1.clone());\n\n    let grads = tensor_8.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n    let grad_3 = tensor_3.grad(&grads).unwrap();\n\n    grad_3\n        .to_data()\n        .assert_eq(&TensorData::from([[32.0]]), false);\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[85.0, 65.0], [118.0, 82.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[88.0, 15.0], [24.0, 50.0]]), false);\n}\n\n#[test]\nfn slice_assign_diff_should_give_same_results_as_cat() {\n    let data_1 = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[5.0, 6.0], [7.0, 8.0]]);\n    let data_3 = TensorData::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);\n\n    let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());\n    let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());\n    let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());\n    let slice_assign_output = slice_assign_output / tensor_3.clone();\n\n    let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);\n    let cat_output = cat_output / tensor_3;\n\n    slice_assign_output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&cat_output.to_data(), Tolerance::default());\n\n    let slice_assign_grads = slice_assign_output.backward();\n    let cat_grads = cat_output.backward();\n\n    let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();\n    let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();\n    let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();\n    let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();\n\n    slice_assign_grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&cat_grad_1.to_data(), Tolerance::default());\n    slice_assign_grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&cat_grad_2.to_data(), Tolerance::default());\n}\n\n#[test]\nfn should_diff_slice_assign_with_step() {\n    use burn_tensor::s;\n    let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);\n    let value_data = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n    let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();\n\n    // Assign with step=2\n    let result = tensor.clone().slice_assign(s![.., 0..4; 2], value.clone());\n    let result = result * 2.0; // Scale to create gradients\n    let grads = result.backward();\n\n    let grad_tensor = tensor.grad(&grads).unwrap();\n    let grad_value = value.grad(&grads).unwrap();\n\n    // The gradient for tensor should be 2.0 everywhere except the assigned positions\n    grad_tensor.to_data().assert_eq(\n        &TensorData::from([[0.0, 2.0, 0.0, 2.0], [0.0, 2.0, 0.0, 2.0]]),\n        false,\n    );\n    // The gradient for value should be 2.0 at all positions\n    grad_value\n        .to_data()\n        .assert_eq(&TensorData::from([[2.0, 2.0], [2.0, 2.0]]), false);\n}\n\n#[test]\nfn should_diff_slice_assign_with_negative_step() {\n    use burn_tensor::s;\n\n    let data = TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);\n    let value_data = TensorData::from([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);\n    let device = Default::default();\n    let tensor = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();\n    let value = TestAutodiffTensor::<2>::from_data(value_data, &device).require_grad();\n\n    // Assign with step=-1 (reverse order, all elements)\n    let result = tensor.clone().slice_assign(s![.., ..;-1], value.clone());\n    let result = result * 2.0; // Scale to create gradients\n    let grads = result.backward();\n\n    let grad_tensor = tensor.grad(&grads).unwrap();\n    let grad_value = value.grad(&grads).unwrap();\n\n    // The gradient for tensor should be 0 since all values were replaced\n    grad_tensor.to_data().assert_eq(\n        &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),\n        false,\n    );\n    // The gradient for value should be 2.0 at all positions\n    grad_value.to_data().assert_eq(\n        &TensorData::from([[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/softmax.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Tensor, TensorData, activation};\n\n#[test]\nfn test_softmax_grad() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n    let device = Default::default();\n    let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());\n\n    let grads = tensor_4.backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);\n\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.5));\n\n    let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(0.05, 0.05));\n}\n\n#[test]\nfn test_log_softmax_grad() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n    let device = Default::default();\n    let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = activation::log_softmax(tensor_3, 1).matmul(tensor_2.clone());\n\n    let grads = tensor_4.backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[-4.3939, -4.3939], [-12.9709, -12.9709]]);\n    // f16 gradients from log-softmax + matmul amplify error, so we increase the tolerance\n    // to account for limited precision and large representable step sizes in this range.\n    let tolerance = Tolerance::permissive().set_half_precision_relative(6e-2);\n\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[30.5984, -47.2267], [55.9631, -56.5914]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_quiet_softmax_grad() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = activation::softmax(tensor_3, 1).matmul(tensor_2.clone());\n\n    let grads = tensor_4.backward();\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[1.179665, 1.179661], [0.005462, 0.005463]]);\n\n    // Precision is quite bad yet on softmax grad especially with half precision.\n    let tolerance = Tolerance::rel_abs(0.5, 0.2);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[0.253469, 0.286237], [0.528630, 2.931664]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/sort.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_sort() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_sort_with_indices() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<2>::from_floats([[1.0, 7.0], [-2.0, -3.0]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[4.0, -7.0], [2.0, 3.0]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let (values, _indices) = tensor_3.sort_with_indices(1);\n    let tensor_4 = tensor_1.clone().mul(values);\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[35.0, 35.0], [-1.0, -8.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[11.0, 7.0], [55.0, 16.0]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_diff_sort_3d_dim1() {\n    let device = Default::default();\n    let tensor_1 =\n        TestAutodiffTensor::<3>::from_floats([[[1.0, 7.0], [-2.0, -3.0]]], &device).require_grad();\n    let tensor_2 =\n        TestAutodiffTensor::from_floats([[[4.0, -7.0], [2.0, 3.0]]], &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());\n    let tensor_4 = tensor_1.clone().mul(tensor_3.sort(1));\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let expected = TensorData::from([[[-1., -8.], [-27., 37.]]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[[-4., -17.], [-17., -42.]]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/sqrt.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_sqrt() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sqrt());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    let expected = TensorData::from([[82.112640, 99.083275], [82.112640, 99.083275]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[30.309311, 33.120457], [34.581974, 38.769463]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/sub.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_diff_sub() {\n    let data_1 = TensorData::from([2.0, 5.0]);\n    let data_2 = TensorData::from([4.0, 1.0]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<1>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().sub(tensor_2.clone());\n    let grads = tensor_3.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([1.0, 1.0]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([-1.0, -1.0]), false);\n\n    tensor_3\n        .into_data()\n        .assert_eq(&TensorData::from([-2.0, 4.0]), false);\n}\n\n#[test]\nfn should_diff_sub_scalar() {\n    let data = TensorData::from([2.0, 10.0]);\n    let tensor = TestAutodiffTensor::<1>::from_data(data, &Default::default()).require_grad();\n    let tensor_out = tensor.clone().sub_scalar(5.0);\n    let grads = tensor_out.backward();\n\n    let grad = tensor.grad(&grads).unwrap();\n\n    grad.to_data()\n        .assert_eq(&TensorData::from([1.0, 1.0]), false);\n    tensor_out\n        .into_data()\n        .assert_eq(&TensorData::from([-3.0, 5.0]), false);\n}\n\n#[test]\nfn test_sub_complex_1() {\n    let data_1 = TensorData::from([[1.0, 7.0], [13.0, -3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n    let data_3 = TensorData::from([[2.0, 2.0], [2.0, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1.clone().sub(tensor_2.clone());\n    let tensor_5 = tensor_4.sub(tensor_3).sub_scalar(5.0);\n    let tensor_6 = tensor_1.clone().sub(tensor_5);\n\n    let grads = tensor_6.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1\n        .to_data()\n        .assert_eq(&TensorData::from([[0.0, 0.0], [0.0, 0.0]]), false);\n    grad_2\n        .to_data()\n        .assert_eq(&TensorData::from([[1.0, 1.0], [1.0, 1.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/transpose.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_transpose() {\n    let data_1 = TensorData::from([[1.0, 7.0], [2.0, 3.0]]);\n    let data_2 = TensorData::from([[4.0, 7.0], [2.0, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().transpose());\n    let tensor_4 = tensor_3.transpose();\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[6.0, 10.0], [6.0, 10.0]]),\n        Tolerance::default(),\n    );\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[3.0, 10.0], [3.0, 10.0]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_swap_dims() {\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<3>::from_floats(\n        [[[0.0, 1.0], [3.0, 4.0]], [[6.0, 7.0], [9.0, 10.0]]],\n        &device,\n    )\n    .require_grad();\n    let tensor_2 = TestAutodiffTensor::from_floats(\n        [[[1.0, 4.0], [2.0, 5.0]], [[7.0, 10.0], [8.0, 11.0]]],\n        &device,\n    )\n    .require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().swap_dims(0, 2));\n    let tensor_4 = tensor_3.matmul(tensor_2.clone().swap_dims(1, 2));\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[[66., 78.], [66., 78.]], [[270., 306.], [270., 306.]]]),\n        Tolerance::default(),\n    );\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[[22., 286.], [28., 316.]], [[172., 652.], [190., 694.]]]),\n        Tolerance::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/trig.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_diff_cos() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cos());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    // Metal has less precise trigonometric functions\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-2);\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[26.8063, -27.7870], [26.8063, -27.7870]]),\n        tolerance,\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[9.222064, -39.123375], [-28.721354, 49.748356]]),\n        tolerance,\n    );\n}\n\n#[test]\nfn should_diff_sin() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sin());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    // Metal has less precise trigonometric functions\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-2);\n\n    let expected = TensorData::from([[8.8500, -4.9790], [8.8500, -4.9790]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[38.668987, 44.194775], [-59.97261, -80.46094]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn should_diff_tanh() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0], [9.0, 10.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tanh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(8e-3);\n    let expected = TensorData::from([[32.0, 32.0], [32.0, 32.0]]);\n    grad_1\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TensorData::from([[8.00092, 8.000153], [8.000003, 7.999995]]);\n    grad_2\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn should_diff_cosh() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().cosh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[7.092221, 16.696301], [7.092221, 16.696301]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[17.489855, 27.484539], [39.409813, 86.910278]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_sinh() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().sinh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[4.894847, 15.887931], [4.894847, 15.887931]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[17.284000, 28.412029], [39.302979, 87.498329]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_tan() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [0.3, 0.8]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().tan());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[2.532602, 1.596607], [2.532602, 1.596607]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[9.028598, 14.489801], [18.038082, 21.151270]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_asin() {\n    let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);\n    let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asin());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.435841, 0.969651], [0.435841, 0.969651]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.475300, 0.668141], [0.701834, 1.100658]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_acos() {\n    let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);\n    let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acos());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[2.077433, 1.543624], [2.077433, 1.543624]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.781337, 0.588496], [0.554804, 0.155979]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_atan() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atan());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[3.444365, 5.349211], [3.444365, 5.349211]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[9.904911, 11.554912], [10.199631, 11.391938]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_asinh() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().asinh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[3.806625, 6.844869], [3.806625, 6.844869]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[11.442373, 14.842072], [14.022551, 17.688538]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_acosh() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[1.5, 2.0], [2.5, 3.0]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().acosh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[10.611752, 15.178907], [10.611752, 15.178907]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[20.112753, 20.247547], [20.402235, 22.487328]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_atanh() {\n    let data_1 = TensorData::from([[0.0, 0.1], [0.3, 0.4]]);\n    let data_2 = TensorData::from([[0.2, 0.3], [0.5, 0.6]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_2.clone().atanh());\n    let tensor_4 = tensor_3.matmul(tensor_2.clone());\n    let grads = tensor_4.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.441838, 1.037115], [0.441838, 1.037115]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.491723, 0.698110], [0.772763, 1.298805]]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn should_diff_atan2() {\n    let data_1 = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n    let data_2 = TensorData::from([[0.5, 1.0], [1.5, 2.0]]);\n    let data_3 = TensorData::from([[1.0, 0.5], [2.0, 1.5]]);\n\n    let device = Default::default();\n    let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();\n    let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();\n    let tensor_3 = TestAutodiffTensor::from_data(data_3, &device).require_grad();\n\n    let tensor_4 = tensor_1\n        .clone()\n        .matmul(tensor_2.clone().atan2(tensor_3.clone()));\n    let tensor_5 = tensor_4.matmul(tensor_2.clone());\n    let grads = tensor_5.backward();\n\n    let grad_1 = tensor_1.grad(&grads).unwrap();\n    let grad_2 = tensor_2.grad(&grads).unwrap();\n    let grad_3 = tensor_3.grad(&grads).unwrap();\n\n    grad_1.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[4.570492, 4.210785], [4.570492, 4.210785]]),\n        Tolerance::default(),\n    );\n\n    grad_2.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[8.208448, 8.808449], [10.357923, 12.157923]]),\n        Tolerance::default(),\n    );\n\n    grad_3.to_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[-1.8, -8.4], [-1.8, -5.6]]),\n        Tolerance::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff/unfold.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn unfold_backward_accumulates_overlaps() {\n    let device = Default::default();\n    let x = TestAutodiffTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0]], &device).require_grad();\n\n    let y = x.clone().unfold::<3, _>(1, 2, 1);\n    let loss = y.sum();\n\n    let grads = loss.backward();\n    let grad_x = x.grad(&grads).unwrap();\n\n    grad_x\n        .to_data()\n        .assert_eq(&TensorData::from([[1., 2., 2., 1.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/autodiff.rs",
    "content": "//! Burn autodiff tests.\n\n#![allow(\n    clippy::single_range_in_vec_init,\n    clippy::duplicate_mod,\n    reason = \"false positive\"\n)]\nextern crate alloc;\n\npub type FloatElemType = f32;\n#[allow(unused)]\npub type IntElemType = i32;\n\n#[path = \"common/backend.rs\"]\nmod backend;\npub use backend::*;\n\n#[allow(clippy::module_inception)]\n#[path = \"common/autodiff.rs\"]\nmod autodiff;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/common/autodiff.rs",
    "content": "// Burn autodiff tests, reusable with element types.\n\npub use super::*;\n\n#[path = \"../autodiff/mod.rs\"]\nmod base;\n\nmod checkpointing {\n    pub use super::*;\n    use burn_autodiff::checkpoint::strategy::BalancedCheckpointing;\n\n    // Override type def\n    pub type TestAutodiffBackend = Autodiff<TestBackend, BalancedCheckpointing>;\n    pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;\n\n    include!(\"../autodiff/mod.rs\");\n}\n\nuse burn_backend_tests::test_float_elem_variant;\n\n// NOTE: this currently doesn't test checkpointing with different dtypes\ntest_float_elem_variant!(\n    f16,\n    burn_tensor::f16,\n    \"../autodiff/mod.rs\",\n    [\"vulkan\", \"cuda\", \"rocm\", \"metal\"]\n);\n\n// TODO: bf16 not yet supported on any backend for full test suite\n// test_float_elem_variant!(\n//     bf16,\n//     burn_tensor::bf16,\n//     \"../autodiff/mod.rs\",\n//     [] // [\"cuda\", \"rocm\"] TODO, [\"vulkan\"] only supports bf16 for matmul, metal/wgpu doesn't support bf16\n// );\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/common/backend.rs",
    "content": "// Re-export\nuse super::FloatElemType;\n\n// Default\n#[cfg(feature = \"ndarray\")]\npub type TestBackend = burn_ndarray::NdArray<FloatElemType>;\n\n#[cfg(feature = \"tch\")]\npub type TestBackend = burn_tch::LibTorch<FloatElemType>;\n\n#[cfg(feature = \"cuda\")]\npub type TestBackend = burn_cuda::Cuda<FloatElemType, super::IntElemType>;\n\n#[cfg(feature = \"rocm\")]\npub type TestBackend = burn_rocm::Rocm<FloatElemType, super::IntElemType>;\n\n#[cfg(feature = \"wgpu\")]\npub type TestBackend = burn_wgpu::Wgpu<FloatElemType, super::IntElemType>;\n\n#[cfg(feature = \"cpu\")]\npub type TestBackend = burn_cpu::Cpu<FloatElemType, super::IntElemType>;\n\n#[cfg(feature = \"router\")]\npub type TestBackend = burn_router::BackendRouter<\n    burn_router::DirectByteChannel<(burn_ndarray::NdArray, burn_wgpu::Wgpu)>,\n>;\n\n/// Collection of types used across tests\n#[allow(unused)]\npub mod prelude {\n    pub use burn_autodiff::Autodiff;\n    pub use burn_tensor::Tensor;\n\n    use super::*;\n    pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n    pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;\n    pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;\n\n    pub type FloatElem = burn_tensor::ops::FloatElem<TestBackend>;\n    pub type IntElem = burn_tensor::ops::IntElem<TestBackend>;\n\n    pub type TestAutodiffBackend = Autodiff<TestBackend>;\n    pub type TestAutodiffTensor<const D: usize> = Tensor<TestAutodiffBackend, D>;\n}\n\n#[allow(unused)]\npub use prelude::*;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/common/tensor.rs",
    "content": "// Burn backend tensor tests, reusable with element types.\n\npub use super::*;\n\n#[path = \"../tensor/clone_invariance.rs\"]\nmod clone_invariance;\n\n#[cfg(feature = \"std\")]\n#[path = \"../tensor/multi_threads.rs\"]\nmod multi_threads;\n\n// Default float dtype\n#[path = \"../tensor/float/mod.rs\"]\nmod float;\n\n// Default integer dtype\n#[path = \"../tensor/int/mod.rs\"]\nmod int;\n\n// Default bool dtype\n#[path = \"../tensor/bool/mod.rs\"]\nmod bool;\n\nuse burn_backend_tests::test_float_elem_variant;\n\ntest_float_elem_variant!(\n    f16,\n    burn_tensor::f16,\n    \"../tensor/float/mod.rs\",\n    [\"vulkan\", \"cuda\", \"rocm\", \"metal\"]\n);\n\n// TODO: bf16 not yet supported on any backend for full test suite\n// test_float_elem_variant!(\n//     bf16,\n//     burn_tensor::bf16,\n//     \"../tensor/float/mod.rs\",\n//     [] // [\"cuda\", \"rocm\"] TODO, [\"vulkan\"] only supports bf16 for matmul, metal/wgpu doesn't support bf16\n// );\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/avg_pool2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{\n    Distribution, Tensor, TensorPrimitive, backend::Backend, module, ops::ModuleOps,\n};\n\n#[test]\nfn avg_pool2d_should_match_reference_backend() {\n    let tensor = Tensor::<TestBackend, 4>::random(\n        [32, 32, 32, 32],\n        Distribution::Default,\n        &Default::default(),\n    );\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());\n    let kernel_size = [3, 4];\n    let stride = [1, 2];\n    let padding = [1, 2];\n    let count_include_pad = true;\n\n    let pooled = module::avg_pool2d(\n        tensor,\n        kernel_size,\n        stride,\n        padding,\n        count_include_pad,\n        false,\n    );\n    let pooled_ref = module::avg_pool2d(\n        tensor_ref,\n        kernel_size,\n        stride,\n        padding,\n        count_include_pad,\n        false,\n    );\n\n    pooled\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());\n}\n\n#[test]\nfn avg_pool2d_backward_should_match_reference_backend() {\n    let device = Default::default();\n\n    TestBackend::seed(&device, 0);\n    ReferenceBackend::seed(&Default::default(), 0);\n\n    let tensor = Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &device);\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());\n    let kernel_size = [3, 3];\n    let stride = [1, 1];\n    let padding = [1, 1];\n    let count_include_pad = true;\n\n    let shape_out = module::avg_pool2d(\n        tensor.clone(),\n        kernel_size,\n        stride,\n        padding,\n        count_include_pad,\n        false,\n    )\n    .shape();\n    let grad_output =\n        Tensor::<TestBackend, 4>::random(shape_out, Distribution::Default, &Default::default());\n    let grad_output_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &Default::default());\n\n    let grad: Tensor<TestBackend, 4> =\n        Tensor::from_primitive(TensorPrimitive::Float(TestBackend::avg_pool2d_backward(\n            tensor.into_primitive().tensor(),\n            grad_output.into_primitive().tensor(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            false,\n        )));\n    let grad_ref: Tensor<ReferenceBackend, 4> = Tensor::from_primitive(TensorPrimitive::Float(\n        ReferenceBackend::avg_pool2d_backward(\n            tensor_ref.into_primitive().tensor(),\n            grad_output_ref.into_primitive().tensor(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            false,\n        ),\n    ));\n\n    grad.into_data()\n        .assert_approx_eq::<FloatElem>(&grad_ref.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/bernoulli.rs",
    "content": "use super::*;\n\nuse serial_test::serial;\n\nuse core::f32;\n\nuse burn_tensor::{Distribution, Shape, Tensor, backend::Backend};\n\nuse cubek::random::{assert_number_of_1_proportional_to_prob, assert_wald_wolfowitz_runs_test};\n\n#[test]\n#[serial]\nfn number_of_1_proportional_to_prob() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let shape: Shape = [40, 40].into();\n    let prob = 0.7;\n\n    let tensor =\n        Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Bernoulli(prob), &device)\n            .into_data();\n\n    let numbers = tensor\n        .as_slice::<<TestBackend as Backend>::FloatElem>()\n        .unwrap();\n\n    assert_number_of_1_proportional_to_prob(numbers, prob as f32);\n}\n\n#[test]\n#[serial]\nfn wald_wolfowitz_runs_test() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let shape = Shape::new([512, 512]);\n    let device = Default::default();\n    let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Bernoulli(0.5), &device);\n\n    let data = tensor.into_data();\n    let numbers = data\n        .as_slice::<<TestBackend as Backend>::FloatElem>()\n        .unwrap();\n\n    // High bound slightly over 1 so 1.0 is included in second bin\n    assert_wald_wolfowitz_runs_test(numbers, 0., 1.1);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/cast.rs",
    "content": "use super::*;\nuse burn_tensor::{Int, Tensor, TensorData};\n\n#[test]\nfn should_cast_int_to_float() {\n    const START: usize = 0;\n    const END: usize = 100;\n\n    let device = Default::default();\n    let tensor = Tensor::<TestBackend, 1, Int>::arange(START as i64..END as i64, &device);\n\n    let data_int = tensor.to_data();\n    let data_int = data_int.as_slice::<i32>().unwrap();\n    let data_float = tensor.float().into_data();\n    let data_float = data_float.as_slice::<f32>().unwrap();\n\n    for i in START..END {\n        assert_eq!(data_int[i], i as i32);\n        assert_eq!(data_float[i], i as f32);\n    }\n}\n\n#[test]\nfn should_cast_bool_to_int() {\n    let device = Default::default();\n\n    let tensor_1 = Tensor::<TestBackend, 2>::from_floats([[1., 0., 3.], [0., 0., 900.]], &device);\n    let tensor_2: Tensor<TestBackend, 2, Int> = tensor_1.clone().greater_elem(0.0).int();\n\n    tensor_2\n        .to_data()\n        .assert_eq(&TensorData::from([[1, 0, 1], [0, 0, 1]]), false);\n}\n\n#[test]\nfn should_cast_bool_to_float() {\n    let device = Default::default();\n\n    let tensor_1 = Tensor::<TestBackend, 2>::from_floats([[1., 0., 3.], [0., 0., 900.]], &device);\n    let tensor_2: Tensor<TestBackend, 2> = tensor_1.clone().greater_elem(0.0).float();\n\n    tensor_2\n        .to_data()\n        .assert_eq(&TensorData::from([[1., 0., 1.], [0., 0., 1.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/cat.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, backend::Backend};\n\n#[test]\nfn cat_should_match_reference_backend_dim0() {\n    test_same_as_reference([6, 256], 2, 0);\n}\n\n#[test]\nfn cat_should_match_reference_backend_dim1() {\n    test_same_as_reference([6, 256], 2, 1);\n}\n\n#[test]\nfn cat_should_support_uneven_launch() {\n    test_same_as_reference([1, 137], 2, 0);\n}\n\nfn test_same_as_reference(shape: [usize; 2], num_tensors: usize, dim: usize) {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let tensors = (0..num_tensors)\n        .map(|_| {\n            Tensor::<TestBackend, 2>::random(shape, Distribution::Default, &Default::default())\n        })\n        .collect::<Vec<_>>();\n    let tensors_ref = tensors\n        .iter()\n        .map(|tensor| {\n            Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default())\n        })\n        .collect::<Vec<_>>();\n\n    let tensor = Tensor::<TestBackend, 2>::cat(tensors, dim);\n    let tensor_ref = Tensor::<ReferenceBackend, 2>::cat(tensors_ref, dim);\n\n    tensor\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&tensor_ref.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/clamp.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor};\n\n#[test]\nfn clamp_should_match_reference() {\n    let input = Tensor::<TestBackend, 4>::random(\n        [1, 5, 32, 32],\n        Distribution::Default,\n        &Default::default(),\n    );\n    let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &Default::default());\n\n    let output = input.clamp(0.3, 0.7);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &input_ref.clamp(0.3, 0.7).into_data(),\n        Tolerance::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/contiguous.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Int, Tensor};\n\n#[test]\npub fn into_contiguous_match_reference_backend_1() {\n    for shape in [\n        [4, 4, 4, 4],\n        [32, 42, 24, 48],\n        [8, 3, 7, 4],\n        [1, 4, 1, 1],\n        [1, 32, 256, 128],\n    ] {\n        let num_elems = shape.iter().product::<usize>() as i64;\n        let tensor: Tensor<TestBackend, 4> =\n            Tensor::<TestBackend, 1, Int>::arange(0..num_elems, &Default::default())\n                .reshape(shape)\n                .float();\n        let tensor_ref =\n            Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());\n\n        for (i, j) in get_combinations(shape.len()) {\n            let view = tensor.clone().swap_dims(i, j);\n            let view_ref = tensor_ref.clone().swap_dims(i, j);\n            let data = view.into_data();\n            let data_ref = view_ref.into_data();\n\n            data_ref.assert_approx_eq::<FloatElem>(&data, Tolerance::default());\n        }\n    }\n}\n\nfn get_combinations(n: usize) -> impl Iterator<Item = (usize, usize)> {\n    // Iterate from 0 up to n\n    (0..n).flat_map(move |i| {\n        // For each i, iterate from i + 1 up to n\n        // This ensures no repeats (i == j) and no duplicates (j, i)\n        (i + 1..n).map(move |j| (i, j))\n    })\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/conv2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::ops::{ConvOptions, ModuleOps};\nuse burn_tensor::{Distribution, Tensor, TensorPrimitive, module};\n\n#[test]\nfn conv2d_should_match_reference_backend() {\n    let test_device = Default::default();\n    let input =\n        Tensor::<TestBackend, 4>::random([6, 16, 32, 32], Distribution::Default, &test_device);\n    let weight =\n        Tensor::<TestBackend, 4>::random([12, 8, 3, 3], Distribution::Default, &test_device);\n    let bias = Tensor::<TestBackend, 1>::random([12], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n\n    let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let options = ConvOptions::new([2, 3], [2, 3], [2, 3], 2);\n\n    let output = module::conv2d(input, weight, Some(bias), options.clone());\n    let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());\n}\n\n#[test]\nfn conv2d_should_match_reference_backend_implicit() {\n    let test_device = Default::default();\n    let input =\n        Tensor::<TestBackend, 4>::random([4, 16, 6, 6], Distribution::Default, &test_device);\n    let weight =\n        Tensor::<TestBackend, 4>::random([16, 16, 3, 3], Distribution::Default, &test_device);\n    let bias = Tensor::<TestBackend, 1>::random([16], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n\n    let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let options = ConvOptions::new([1, 1], [2, 2], [1, 1], 1);\n\n    let output = module::conv2d(input, weight, Some(bias), options.clone());\n    let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options);\n\n    let tolerance = Tolerance::default();\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), tolerance);\n}\n\n/// Regression test for bias loader in new implicit GEMM\n#[test]\nfn conv2d_should_match_reference_backend_bias_regression() {\n    let test_device = Default::default();\n    let input = Tensor::<TestBackend, 4>::random([1, 1, 1, 1], Distribution::Default, &test_device);\n    let weight =\n        Tensor::<TestBackend, 4>::random([32, 1, 3, 3], Distribution::Default, &test_device);\n    let bias = Tensor::<TestBackend, 1>::random([32], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n\n    let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let options = ConvOptions::new([1, 1], [1, 1], [1, 1], 1);\n\n    let output = module::conv2d(input, weight, Some(bias), options.clone()).permute([0, 2, 3, 1]);\n    let output_ref =\n        module::conv2d(input_ref, weight_ref, Some(bias_ref), options).permute([0, 2, 3, 1]);\n\n    let tolerance = Tolerance::default();\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), tolerance);\n}\n\n#[test]\nfn conv2d_weight_backward_should_run() {\n    // https://github.com/tracel-ai/burn/issues/4226#issuecomment-3911335769\n    let device = Default::default();\n    let options = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);\n    let x = Tensor::<TestBackend, 4>::random([1, 1, 1, 672], Distribution::Default, &device);\n    // let x = x.permute([0, 3, 1, 2]);\n\n    let output_grad =\n        Tensor::<TestBackend, 4>::random([1, 168, 1, 1], Distribution::Default, &device);\n    let weight = Tensor::<TestBackend, 4>::random([168, 672, 1, 1], Distribution::Default, &device);\n\n    let ref_device = Default::default();\n    let x_ref = Tensor::<ReferenceBackend, 4>::from_data(x.to_data(), &ref_device);\n    let output_grad_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(output_grad.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);\n\n    // Input shape [672, 1] and strides [672, 672] should be valid\n    let output = TestBackend::conv2d_weight_backward(\n        x.permute([0, 3, 1, 2]).into_primitive().tensor(),\n        weight.into_primitive().tensor(),\n        output_grad.into_primitive().tensor(),\n        options.clone(),\n    );\n\n    // Input shape [672, 1] and strides [672, 672] should be valid\n    let output_ref = ReferenceBackend::conv2d_weight_backward(\n        x_ref.permute([0, 3, 1, 2]).into_primitive().tensor(),\n        weight_ref.into_primitive().tensor(),\n        output_grad_ref.into_primitive().tensor(),\n        options,\n    );\n\n    let tolerance = Tolerance::default();\n    Tensor::<TestBackend, 4>::from_primitive(TensorPrimitive::Float(output))\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &Tensor::<ReferenceBackend, 4>::from_primitive(TensorPrimitive::Float(output_ref))\n                .into_data(),\n            tolerance,\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/conv3d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, module};\n\n#[test]\nfn conv3d_should_match_reference_backend() {\n    let test_device = Default::default();\n    let input =\n        Tensor::<TestBackend, 5>::random([6, 16, 32, 32, 32], Distribution::Default, &test_device);\n    let weight =\n        Tensor::<TestBackend, 5>::random([12, 8, 3, 3, 3], Distribution::Default, &test_device);\n    let bias = Tensor::<TestBackend, 1>::random([12], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n\n    let input_ref = Tensor::<ReferenceBackend, 5>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 5>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let options = burn_tensor::ops::ConvOptions::new([2, 3, 4], [2, 3, 4], [2, 3, 4], 2);\n\n    let output = module::conv3d(input, weight, Some(bias), options.clone());\n    let output_ref = module::conv3d(input_ref, weight_ref, Some(bias_ref), options);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/conv_transpose2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, backend::Backend, module};\n\n#[test]\nfn conv_transpose2d_should_match_reference_backend() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let height = 8;\n    let width = 8;\n    let in_channels = 8;\n    let out_channels = 8;\n    let batch_size = 32;\n    let kernel_size_0 = 3;\n    let kernel_size_1 = 3;\n    let options = burn_tensor::ops::ConvTransposeOptions::new([1, 1], [1, 1], [0, 0], [1, 1], 1);\n\n    let test_device = Default::default();\n    let input = Tensor::<TestBackend, 4>::random(\n        [batch_size, in_channels, height, width],\n        Distribution::Default,\n        &test_device,\n    );\n    let weight = Tensor::<TestBackend, 4>::random(\n        [\n            in_channels,\n            out_channels / options.groups,\n            kernel_size_0,\n            kernel_size_1,\n        ],\n        Distribution::Default,\n        &test_device,\n    );\n    let bias =\n        Tensor::<TestBackend, 1>::random([out_channels], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n    let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let output = module::conv_transpose2d(input, weight, Some(bias), options.clone());\n    let output_ref = module::conv_transpose2d(input_ref, weight_ref, Some(bias_ref), options);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::rel_abs(0.01, 0.02));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/conv_transpose3d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, backend::Backend, module};\n\n#[test]\nfn conv_transpose3d_should_match_reference_backend() {\n    let test_device = Default::default();\n    TestBackend::seed(&test_device, 0);\n\n    let depth = 8;\n    let height = 8;\n    let width = 8;\n    let in_channels = 8;\n    let out_channels = 8;\n    let batch_size = 32;\n    let kernel_size_0 = 3;\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let options =\n        burn_tensor::ops::ConvTransposeOptions::new([1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1], 1);\n\n    let input = Tensor::<TestBackend, 5>::random(\n        [batch_size, in_channels, depth, height, width],\n        Distribution::Default,\n        &test_device,\n    );\n    let weight = Tensor::<TestBackend, 5>::random(\n        [\n            in_channels,\n            out_channels / options.groups,\n            kernel_size_0,\n            kernel_size_1,\n            kernel_size_2,\n        ],\n        Distribution::Default,\n        &test_device,\n    );\n    let bias =\n        Tensor::<TestBackend, 1>::random([out_channels], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n    let input_ref = Tensor::<ReferenceBackend, 5>::from_data(input.to_data(), &ref_device);\n    let weight_ref = Tensor::<ReferenceBackend, 5>::from_data(weight.to_data(), &ref_device);\n    let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);\n\n    let output = module::conv_transpose3d(input, weight, Some(bias), options.clone());\n    let output_ref = module::conv_transpose3d(input_ref, weight_ref, Some(bias_ref), options);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/cross.rs",
    "content": "use super::*;\nuse burn_tensor::Tensor;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_cross_product() {\n    let device = Default::default();\n    // Test with well-known orthogonal vectors for clearer validation\n    let a = Tensor::<TestBackend, 2>::from_data([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], &device);\n    let b = Tensor::<TestBackend, 2>::from_data([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], &device);\n\n    let result = a.cross(b, 1);\n    // For orthogonal unit vectors:\n    // i × j = k\n    // j × k = i\n    let expected = Tensor::<TestBackend, 2>::from_data([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], &device);\n\n    // Use Tolerance for floating-point comparisons\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n\n#[test]\nfn test_cross_product_zeros() {\n    let device = Default::default();\n    // Test cross product with zero vector - should always give zero vector\n    let a = Tensor::<TestBackend, 2>::from_data([[2.0, 3.0, 4.0]], &device);\n    let b = Tensor::<TestBackend, 2>::zeros([1, 3], &device);\n\n    let result = a.cross(b, 1);\n    let expected = Tensor::<TestBackend, 2>::zeros([1, 3], &device);\n\n    // For zeros, we can use exact equality or a very tight tolerance\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n\n#[test]\nfn test_cross_product_batch() {\n    let device = Default::default();\n    // Test typical cross product computations in batch\n    let a = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let b = Tensor::<TestBackend, 2>::from_data([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device);\n\n    let result = a.cross(b, 1);\n    // Cross products:\n    // [1,2,3] × [4,5,6] = [-3,6,-3]\n    // [4,5,6] × [7,8,9] = [-3,6,-3]\n    let expected =\n        Tensor::<TestBackend, 2>::from_data([[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]], &device);\n\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n\n#[test]\n#[should_panic]\nfn test_cross_product_invalid_dimension() {\n    let device = Default::default();\n    let a = Tensor::<TestBackend, 2>::zeros([1, 4], &device);\n    let b = Tensor::<TestBackend, 2>::zeros([1, 4], &device);\n\n    let _ = a.cross(b, 1);\n}\n\n#[test]\nfn test_cross_product_parallel_vectors() {\n    let device = Default::default();\n    // Test cross product of parallel vectors (should be zero)\n    let a = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let b = Tensor::<TestBackend, 2>::from_data([[2.0, 4.0, 6.0]], &device); // b = 2 * a\n\n    let result = a.cross(b, 1);\n    let expected = Tensor::<TestBackend, 2>::zeros([1, 3], &device);\n\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n\n#[test]\nfn test_cross_product_3d_tensor() {\n    let device = Default::default();\n    // Test with 3D tensor (batch of matrices)\n    let a = Tensor::<TestBackend, 3>::from_data(\n        [\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],\n            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],\n        ],\n        &device,\n    );\n\n    let b = Tensor::<TestBackend, 3>::from_data(\n        [\n            [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],\n        ],\n        &device,\n    );\n\n    let result = a.cross(b, 2); // Cross on last dimension\n    let expected = Tensor::<TestBackend, 3>::from_data(\n        [\n            [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],\n            [[-3.0, 6.0, -3.0], [-3.0, 6.0, -3.0]],\n        ],\n        &device,\n    );\n\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n\n// Test to verify that padding doesn't affect results\n#[test]\nfn test_cross_product_with_padding_awareness() {\n    let device = Default::default();\n    // Create tensors that would span multiple 4-element blocks\n    // This tests that the padding doesn't corrupt adjacent data\n    let a = Tensor::<TestBackend, 2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], // Two vectors: [1,2,3] and [4,5,6]\n        ],\n        &device,\n    );\n\n    let b = Tensor::<TestBackend, 2>::from_data(\n        [\n            [7.0, 8.0, 9.0, 10.0, 11.0, 12.0], // Two vectors: [7,8,9] and [10,11,12]\n        ],\n        &device,\n    );\n\n    // Reshape to have proper 3-element vectors in last dimension\n    let a_reshaped = a.reshape([2, 3]);\n    let b_reshaped = b.reshape([2, 3]);\n\n    let result = a_reshaped.cross(b_reshaped, 1);\n\n    // Expected cross products:\n    // [1,2,3] × [7,8,9] = [-6,12,-6]\n    // [4,5,6] × [10,11,12] = [-6,12,-6]\n    let expected =\n        Tensor::<TestBackend, 2>::from_data([[-6.0, 12.0, -6.0], [-6.0, 12.0, -6.0]], &device);\n\n    let tolerance = Tolerance::<FloatElem>::default();\n    result\n        .to_data()\n        .assert_approx_eq(&expected.to_data(), tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/gather.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Int, Shape, Tensor, backend::Backend};\n\n#[test]\nfn gather_should_work_with_multiple_workgroups_dim0() {\n    test_same_as_ref([6, 256], 0);\n}\n\n#[test]\nfn gather_should_work_with_multiple_workgroups_dim1() {\n    test_same_as_ref([6, 256], 1);\n}\n\nfn test_same_as_ref<const D: usize>(shape: [usize; D], dim: usize) {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let max = shape[dim];\n    let shape = Shape::new(shape);\n    let tensor =\n        Tensor::<TestBackend, D>::random(shape.clone(), Distribution::Default, &Default::default());\n    let indices = Tensor::<TestBackend, 1, Int>::from_data(\n        Tensor::<TestBackend, 1>::random(\n            [shape.num_elements()],\n            Distribution::Uniform(0., max as f64),\n            &Default::default(),\n        )\n        .into_data(),\n        &Default::default(),\n    )\n    .reshape(shape);\n    let tensor_ref =\n        Tensor::<ReferenceBackend, D>::from_data(tensor.to_data(), &Default::default());\n    let indices_ref =\n        Tensor::<ReferenceBackend, D, Int>::from_data(indices.to_data(), &Default::default());\n\n    let actual = tensor.gather(dim, indices);\n    let expected = tensor_ref.gather(dim, indices_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/mask_fill.rs",
    "content": "use super::*;\nuse burn_cubecl::kernel::{MaskFillStrategy, mask_fill};\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend};\nuse cubecl::prelude::InputScalar;\n\n#[test]\nfn mask_fill_should_match_reference_backend() {\n    let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();\n    let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();\n    let dtype_ft = <FloatElem as Element>::dtype();\n\n    let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(\n        tensor.into_primitive().tensor(),\n        mask.into_primitive(),\n        InputScalar::new(4.0, dtype_ft),\n        MaskFillStrategy::Readonly,\n        dtype_bool,\n    )));\n    let expected = tensor_ref.mask_fill(mask_ref, 4.0);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[test]\nfn mask_fill_inplace_should_match_reference_backend() {\n    let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();\n    let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();\n    let dtype_ft = <FloatElem as Element>::dtype();\n\n    let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill::<_>(\n        tensor.into_primitive().tensor(),\n        mask.into_primitive(),\n        InputScalar::new(4.0, dtype_ft),\n        MaskFillStrategy::Inplace,\n        dtype_bool,\n    )));\n    let expected = tensor_ref.mask_fill(mask_ref, 4.0);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[allow(clippy::type_complexity)]\nfn inputs_mask_fill() -> (\n    Tensor<TestBackend, 3>,\n    Tensor<TestBackend, 3, Bool>,\n    Tensor<ReferenceBackend, 3>,\n    Tensor<ReferenceBackend, 3, Bool>,\n) {\n    let test_device = Default::default();\n    let tensor = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &test_device);\n    let mask =\n        Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0., 1.), &test_device)\n            .lower_equal_elem(0.5);\n    let ref_device = Default::default();\n    let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &ref_device);\n    let mask_ref = Tensor::<ReferenceBackend, 3, Bool>::from_data(mask.to_data(), &ref_device);\n\n    (tensor, mask, tensor_ref, mask_ref)\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/mask_where.rs",
    "content": "use super::*;\nuse burn_cubecl::kernel::{MaskWhereStrategy, mask_where};\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend};\n\n#[test]\nfn mask_where_should_match_reference_backend() {\n    let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();\n\n    let actual = tensor.mask_where(mask, value);\n    let expected = tensor_ref.mask_where(mask_ref, value_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n#[test]\nfn mask_where_inplace_lhs_should_match_reference_backend() {\n    let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();\n    let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();\n\n    let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where::<_>(\n        tensor.into_primitive().tensor(),\n        mask.into_primitive(),\n        value.into_primitive().tensor(),\n        MaskWhereStrategy::InplaceLhs,\n        dtype_bool,\n    )));\n    let expected = tensor_ref.mask_where(mask_ref, value_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[test]\nfn mask_where_inplace_rhs_should_match_reference_backend() {\n    let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();\n    let dtype_bool = <<TestBackend as Backend>::BoolElem as Element>::dtype();\n\n    let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where::<_>(\n        tensor.into_primitive().tensor(),\n        mask.into_primitive(),\n        value.into_primitive().tensor(),\n        MaskWhereStrategy::InplaceRhs,\n        dtype_bool,\n    )));\n    let expected = tensor_ref.mask_where(mask_ref, value_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[allow(clippy::type_complexity)]\nfn inputs_mask_where() -> (\n    Tensor<TestBackend, 3>,\n    Tensor<TestBackend, 3>,\n    Tensor<TestBackend, 3, Bool>,\n    Tensor<ReferenceBackend, 3>,\n    Tensor<ReferenceBackend, 3>,\n    Tensor<ReferenceBackend, 3, Bool>,\n) {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let tensor = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &device);\n    let value = Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Default, &device);\n    let mask =\n        Tensor::<TestBackend, 3>::random([2, 6, 256], Distribution::Uniform(0., 1.), &device)\n            .lower_equal_elem(0.5);\n\n    let device_ref = Default::default();\n    let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &device_ref);\n    let value_ref = Tensor::<ReferenceBackend, 3>::from_data(value.to_data(), &device_ref);\n    let mask_ref = Tensor::<ReferenceBackend, 3, Bool>::from_data(mask.to_data(), &device_ref);\n    mask.to_data().assert_eq(&mask_ref.to_data(), false);\n\n    (tensor, value, mask, tensor_ref, value_ref, mask_ref)\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/max_pool2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, module};\n\n#[test]\npub fn max_pool2d_should_match_reference_backends() {\n    let tensor = Tensor::<TestBackend, 4>::random(\n        [32, 32, 32, 32],\n        Distribution::Default,\n        &Default::default(),\n    );\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());\n    let kernel_size = [3, 3];\n    let stride = [2, 2];\n    let padding = [1, 1];\n    let dilation = [1, 1];\n\n    let pooled = module::max_pool2d(tensor, kernel_size, stride, padding, dilation, false);\n    let pooled_ref = module::max_pool2d(tensor_ref, kernel_size, stride, padding, dilation, false);\n\n    pooled\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());\n}\n\n#[test]\npub fn max_pool2d_with_indices_should_match_reference_backend() {\n    let tensor = Tensor::<TestBackend, 4>::random(\n        [32, 32, 32, 32],\n        Distribution::Default,\n        &Default::default(),\n    );\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &Default::default());\n    let kernel_size = [3, 3];\n    let stride = [2, 2];\n    let padding = [1, 1];\n    let dilation = [1, 1];\n\n    let (pooled, indices) =\n        module::max_pool2d_with_indices(tensor, kernel_size, stride, padding, dilation, false);\n    let (pooled_ref, indices_ref) =\n        module::max_pool2d_with_indices(tensor_ref, kernel_size, stride, padding, dilation, false);\n\n    pooled\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&pooled_ref.into_data(), Tolerance::default());\n    indices\n        .into_data()\n        .assert_eq(&indices_ref.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/max_pool2d_backward.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor, TensorPrimitive, module, ops::ModuleOps};\n\n#[test]\npub fn max_pool2d_with_indices_backward_should_match_reference_backend() {\n    let test_device = Default::default();\n    let tensor =\n        Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);\n    let grad_output =\n        Tensor::<TestBackend, 4>::random([32, 32, 16, 16], Distribution::Default, &test_device);\n    let ref_device = Default::default();\n    let tensor_ref = Tensor::<ReferenceBackend, 4>::from_data(tensor.to_data(), &ref_device);\n    let grad_output_ref =\n        Tensor::<ReferenceBackend, 4>::from_data(grad_output.to_data(), &ref_device);\n    let kernel_size = [3, 3];\n    let stride = [2, 2];\n    let padding = [1, 1];\n    let dilation = [1, 1];\n\n    let (_, indices) = module::max_pool2d_with_indices(\n        tensor.clone(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        false,\n    );\n    let (_, indices_ref) = module::max_pool2d_with_indices(\n        tensor_ref.clone(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        false,\n    );\n    let grad = TestBackend::max_pool2d_with_indices_backward(\n        tensor.into_primitive().tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        false,\n        grad_output.into_primitive().tensor(),\n        indices.into_primitive(),\n    )\n    .x_grad;\n    let grad_ref = ReferenceBackend::max_pool2d_with_indices_backward(\n        tensor_ref.into_primitive().tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        false,\n        grad_output_ref.into_primitive().tensor(),\n        indices_ref.into_primitive(),\n    )\n    .x_grad;\n\n    Tensor::<TestBackend, 4>::from_primitive(TensorPrimitive::Float(grad))\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &Tensor::<ReferenceBackend, 4>::from_primitive(TensorPrimitive::Float(grad_ref))\n                .into_data(),\n            Tolerance::default(),\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/mod.rs",
    "content": "// #[allow(unused_imports)] // required for re-included modules\npub use super::*;\n\nmod avg_pool2d;\nmod bernoulli;\nmod cast;\nmod cat;\nmod clamp;\nmod contiguous;\nmod conv2d;\nmod conv3d;\nmod conv_transpose2d;\nmod conv_transpose3d;\nmod cross;\nmod gather;\nmod mask_fill;\nmod mask_where;\nmod max_pool2d;\nmod max_pool2d_backward;\nmod normal;\nmod quantization;\nmod reduce;\nmod repeat_dim;\nmod scatter;\nmod select;\nmod select_assign;\nmod slice;\nmod slice_assign;\nmod unary;\nmod uniform;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/normal.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, Shape, Tensor, backend::Backend};\nuse cubek::random::{assert_mean_approx_equal, assert_normal_respects_68_95_99_rule};\nuse serial_test::serial;\n\n#[test]\n#[serial]\nfn empirical_mean_close_to_expectation() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let shape = [100, 100];\n    let mean = 10.;\n    let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Normal(mean, 2.), &device)\n        .into_data();\n    let numbers = tensor.as_slice::<FloatElem>().unwrap();\n\n    assert_mean_approx_equal(numbers, mean as f32);\n}\n\n#[test]\n#[serial]\nfn normal_respects_68_95_99_rule() {\n    // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule\n    let shape: Shape = [1000, 1000].into();\n    let device = Default::default();\n    let mu = 0.;\n    let s = 1.;\n    let tensor =\n        Tensor::<TestBackend, 2>::random(shape.clone(), Distribution::Normal(mu, s), &device)\n            .into_data();\n\n    let numbers = tensor.as_slice::<FloatElem>().unwrap();\n\n    assert_normal_respects_68_95_99_rule(numbers, mu as f32, s as f32);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/quantization.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{\n    Shape, Tensor,\n    backend::Backend,\n    quantization::{QuantLevel, QuantScheme, QuantStore, QuantValue},\n};\n\nfn should_quantize_dequantize_symmetric_arange<S: Into<Shape>>(\n    value: QuantValue,\n    store: QuantStore,\n    shape: S,\n) {\n    let shape = shape.into();\n    assert_eq!(shape.rank(), 2); // 2D tests\n\n    let scheme = QuantScheme::default().with_value(value).with_store(store);\n    let scheme_ref = scheme.clone().with_store(QuantStore::Native);\n\n    let input: Tensor<TestBackend, 2> =\n        Tensor::arange(0..shape.num_elements() as i64, &Default::default())\n            .float()\n            .reshape(shape);\n    let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());\n\n    let output = input.quantize_dynamic(&scheme);\n    let output_ref = input_ref.quantize_dynamic(&scheme_ref);\n\n    output.to_data().assert_eq(&output_ref.to_data(), false);\n\n    let output = output.dequantize();\n    let output_ref = output_ref.dequantize();\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());\n}\n\nfn should_quantize_dequantize_symmetric_per_block_arange<S: Into<Shape>>(\n    value: QuantValue,\n    block_size: usize,\n    store: QuantStore,\n    shape: S,\n) {\n    let scheme = QuantScheme::default()\n        .with_value(value)\n        .with_level(QuantLevel::block([block_size as u8]))\n        .with_store(store);\n    let scheme_ref = scheme.clone().with_store(QuantStore::Native);\n\n    let shape = shape.into();\n    let input: Tensor<TestBackend, 2> =\n        Tensor::arange(0..shape.num_elements() as i64, &Default::default())\n            .float()\n            .reshape(shape);\n    let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());\n\n    let output = input.quantize_dynamic(&scheme);\n    let output_ref = input_ref.quantize_dynamic(&scheme_ref);\n\n    output.to_data().assert_eq(&output_ref.to_data(), false);\n\n    let output = output.dequantize();\n    let output_ref = output_ref.dequantize();\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());\n}\n\nfn should_quantize_dequantize_symmetric_per_block(\n    value: QuantValue,\n    block_size: usize,\n    store: QuantStore,\n) {\n    let scheme = QuantScheme::default()\n        .with_value(value)\n        .with_level(QuantLevel::block([block_size as u8]))\n        .with_store(store);\n    let scheme_ref = scheme.clone().with_store(QuantStore::Native);\n\n    let input = Tensor::<TestBackend, 2>::from_floats(\n        [\n            [\n                -1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5, 0.01, 0.025, 0.03, 0.04, 0.01, 0.025,\n                0.03, 0.04,\n            ],\n            [\n                1.8, 1.0, 0.0, -0.5, 1.8, 1.0, 0.0, -0.5, -0.01, -0.025, -0.03, -0.04, -0.01,\n                -0.025, -0.03, -0.04,\n            ],\n        ],\n        &Default::default(),\n    );\n    let input_ref = Tensor::<ReferenceBackend, 2>::from_data(input.to_data(), &Default::default());\n\n    let output = input.quantize_dynamic(&scheme);\n    let output_ref = input_ref.quantize_dynamic(&scheme_ref);\n\n    output.to_data().assert_eq(&output_ref.to_data(), false);\n\n    let output = output.dequantize();\n    let output_ref = output_ref.dequantize();\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_ref.to_data(), Tolerance::default());\n}\n\nfn supports_native() -> bool {\n    let name = <TestBackend as Backend>::name(&Default::default());\n    // TODO: Proper checks for i8 support.\n    name.contains(\"cuda\")\n        || name.contains(\"rocm\")\n        || name.contains(\"hip\")\n        || name.contains(\"vulkan\")\n        || name.contains(\"spirv\")\n        || name.contains(\"metal\")\n        || name.contains(\"msl\")\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q8s_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q8f_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q8F, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q4s_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q4S, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q4f_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q4F, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q2s_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q2S, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q2f_packed() {\n    should_quantize_dequantize_symmetric_arange(QuantValue::Q2F, QuantStore::PackedU32(0), [8, 16])\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_q8s_packed() {\n    should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::PackedU32(0))\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_q4s_packed() {\n    should_quantize_dequantize_symmetric_per_block(QuantValue::Q4S, 8, QuantStore::PackedU32(0))\n}\n\n#[test]\n#[should_panic = \"Block size must be divisible by 16\"]\nfn should_panic_when_block_size_cannot_store_num_quants() {\n    // num_quants in u32 = 32 bits / 2 bits = 16\n    should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 8, QuantStore::PackedU32(0))\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_q2s_packed() {\n    should_quantize_dequantize_symmetric_per_block(QuantValue::Q2S, 16, QuantStore::PackedU32(0))\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_q8s_native() {\n    if supports_native() {\n        should_quantize_dequantize_symmetric_arange(QuantValue::Q8S, QuantStore::Native, [32, 32])\n    }\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_q8s_native() {\n    if supports_native() {\n        should_quantize_dequantize_symmetric_per_block(QuantValue::Q8S, 8, QuantStore::Native)\n    }\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_arange_q8s_packed() {\n    should_quantize_dequantize_symmetric_per_block_arange(\n        QuantValue::Q8S,\n        32,\n        QuantStore::PackedU32(0),\n        [32, 32],\n    )\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_arange_q8s_native() {\n    if supports_native() {\n        should_quantize_dequantize_symmetric_per_block_arange(\n            QuantValue::Q8S,\n            32,\n            QuantStore::Native,\n            [32, 32],\n        )\n    }\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_128x256_q8s_native() {\n    if supports_native() {\n        should_quantize_dequantize_symmetric_per_block_arange(\n            QuantValue::Q8S,\n            32,\n            QuantStore::Native,\n            [128, 256],\n        )\n    }\n}\n#[test]\nfn should_quantize_dequantize_symmetric_arange_128x256_q8s_packed() {\n    should_quantize_dequantize_symmetric_per_block_arange(\n        QuantValue::Q8S,\n        32,\n        QuantStore::PackedU32(0),\n        [128, 256],\n    )\n}\n\n#[test]\n#[should_panic = \"Can't store in u32\"]\nfn should_panic_when_shape_cannot_store_quants() {\n    let device = Default::default();\n    let scheme = QuantScheme::default();\n\n    let _tensor_1 =\n        Tensor::<TestBackend, 2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)\n            .quantize_dynamic(&scheme);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/reduce.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor};\n\nconst RANK: usize = 4;\nconst SHAPE: [usize; RANK] = [2, 4, 8, 16];\n\n#[test]\nfn reduction_argmax_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    for dim in 0..RANK {\n        tensor\n            .clone()\n            .argmax(dim)\n            .into_data()\n            .assert_eq(&tensor_ref.clone().argmax(dim).into_data(), false);\n    }\n}\n\n#[test]\nfn reduction_argmin_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    for dim in 0..RANK {\n        tensor\n            .clone()\n            .argmin(dim)\n            .into_data()\n            .assert_eq(&tensor_ref.clone().argmin(dim).into_data(), false);\n    }\n}\n\n#[test]\nfn reduction_mean_dim_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    for dim in 0..RANK {\n        tensor\n            .clone()\n            .mean_dim(dim)\n            .into_data()\n            .assert_approx_eq::<FloatElem>(\n                &tensor_ref.clone().mean_dim(dim).into_data(),\n                Tolerance::default(),\n            );\n    }\n}\n\n#[test]\nfn reduction_mean_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    tensor\n        .clone()\n        .mean()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &tensor_ref.clone().mean().into_data(),\n            Tolerance::default(),\n        );\n}\n\n#[test]\nfn reduction_prod_dim_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    for dim in 0..RANK {\n        tensor\n            .clone()\n            .prod_dim(dim)\n            .into_data()\n            .assert_approx_eq::<FloatElem>(\n                &tensor_ref.clone().prod_dim(dim).into_data(),\n                Tolerance::default(),\n            );\n    }\n}\n\n#[test]\nfn reduction_prod_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    tensor\n        .clone()\n        .prod()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &tensor_ref.clone().prod().into_data(),\n            Tolerance::default(),\n        );\n}\n\n#[test]\nfn reduction_sum_dim_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    for dim in 0..RANK {\n        tensor\n            .clone()\n            .sum_dim(dim)\n            .into_data()\n            .assert_approx_eq::<FloatElem>(\n                &tensor_ref.clone().sum_dim(dim).into_data(),\n                Tolerance::default(),\n            );\n    }\n}\n\n#[test]\nfn reduction_sum_should_match_reference_backend() {\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    tensor\n        .clone()\n        .sum()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&tensor_ref.clone().sum().into_data(), Tolerance::default());\n}\n\n#[test]\n#[ignore = \"Impossible to run unless you have tons of VRAM. Also reference backend is broken.\"]\nfn reduction_sum_should_match_reference_backend_64bit() {\n    const SHAPE: [usize; RANK] = [33, 512, 512, 512];\n\n    let tensor =\n        Tensor::<TestBackend, RANK>::random(SHAPE, Distribution::Default, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, RANK>::from_data(tensor.to_data(), &Default::default());\n    let data = tensor.sum().into_data();\n    let data_ref = tensor_ref.sum().into_data();\n    println!(\"result: {:?}\", data.as_slice::<f32>());\n    data.assert_approx_eq::<FloatElem>(&data_ref, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/repeat_dim.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor};\n\n#[test]\nfn repeat_dim_0_few_times() {\n    let tensor =\n        Tensor::<TestBackend, 3>::random([1, 6, 6], Distribution::Default, &Default::default());\n    let dim = 0;\n    let times = 4;\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());\n\n    let actual = tensor.repeat_dim(dim, times);\n    let expected = tensor_ref.repeat_dim(dim, times);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[test]\nfn repeat_dim_1_few_times() {\n    let tensor =\n        Tensor::<TestBackend, 3>::random([6, 1, 6], Distribution::Default, &Default::default());\n    let dim = 1;\n    let times = 4;\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());\n\n    let actual = tensor.repeat_dim(dim, times);\n    let expected = tensor_ref.repeat_dim(dim, times);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[test]\nfn repeat_dim_2_few_times() {\n    let tensor =\n        Tensor::<TestBackend, 3>::random([6, 6, 1], Distribution::Default, &Default::default());\n    let dim = 2;\n    let times = 4;\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());\n\n    let actual = tensor.repeat_dim(dim, times);\n    let expected = tensor_ref.repeat_dim(dim, times);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\n#[test]\nfn repeat_dim_2_many_times() {\n    let tensor =\n        Tensor::<TestBackend, 3>::random([10, 10, 1], Distribution::Default, &Default::default());\n    let dim = 2;\n    let times = 200;\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());\n\n    let actual = tensor.repeat_dim(dim, times);\n    let expected = tensor_ref.repeat_dim(dim, times);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/scatter.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, Int, Tensor, backend::Backend};\nuse burn_tensor::{IndexingUpdateOp, Tolerance};\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_2d_dim0() {\n    same_as_reference_same_shape(0, [256, 32]);\n}\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_2d_dim1() {\n    same_as_reference_same_shape(1, [32, 256]);\n}\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_3d_dim0() {\n    same_as_reference_same_shape(0, [256, 6, 6]);\n}\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_3d_dim1() {\n    same_as_reference_same_shape(1, [6, 256, 6]);\n}\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_3d_dim2() {\n    same_as_reference_same_shape(2, [6, 6, 256]);\n}\n\n#[test]\nfn scatter_should_work_with_multiple_workgroups_diff_shapes() {\n    same_as_reference_diff_shape(1, [32, 128], [32, 1]);\n}\n\nfn same_as_reference_diff_shape<const D: usize>(\n    dim: usize,\n    shape1: [usize; D],\n    shape2: [usize; D],\n) {\n    let test_device = Default::default();\n    TestBackend::seed(&test_device, 0);\n\n    let tensor = Tensor::<TestBackend, D>::random(shape1, Distribution::Default, &test_device);\n    let value = Tensor::<TestBackend, D>::random(shape2, Distribution::Default, &test_device);\n    let indices = Tensor::<TestBackend, 1, Int>::random(\n        [shape2.iter().product::<usize>()],\n        Distribution::Uniform(0., shape2[dim] as f64),\n        &test_device,\n    )\n    .reshape(shape2);\n    let ref_device = Default::default();\n    let tensor_ref = Tensor::<ReferenceBackend, D>::from_data(tensor.to_data(), &ref_device);\n    let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data(), &ref_device);\n    let indices_ref = Tensor::<ReferenceBackend, D, Int>::from_data(indices.to_data(), &ref_device);\n\n    let actual = tensor.scatter(dim, indices, value, IndexingUpdateOp::Add);\n    let expected = tensor_ref.scatter(dim, indices_ref, value_ref, IndexingUpdateOp::Add);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n\nfn same_as_reference_same_shape<const D: usize>(dim: usize, shape: [usize; D]) {\n    same_as_reference_diff_shape(dim, shape, shape);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/select.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Int, Tensor};\n\n#[test]\nfn select_should_work_with_multiple_workgroups() {\n    let tensor =\n        Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());\n    let indices = Tensor::<TestBackend, 1, Int>::arange(0..100, &Default::default());\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default());\n    let indices_ref =\n        Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data(), &Default::default());\n\n    let actual = tensor.select(1, indices);\n    let expected = tensor_ref.select(1, indices_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/select_assign.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, Int, Tensor, backend::Backend};\nuse burn_tensor::{IndexingUpdateOp, Tolerance};\n\n#[test]\nfn select_add_should_work_with_multiple_workgroups_2d_dim0() {\n    select_add_same_as_ref(0, [256, 6]);\n}\n\n#[test]\nfn select_add_should_work_with_multiple_workgroups_2d_dim1() {\n    select_add_same_as_ref(1, [6, 256]);\n}\n\n#[test]\nfn select_add_should_work_with_multiple_workgroups_3d_dim0() {\n    select_add_same_as_ref(0, [256, 6, 6]);\n}\n\n#[test]\nfn select_add_should_work_with_multiple_workgroups_3d_dim1() {\n    select_add_same_as_ref(1, [6, 256, 6]);\n}\n\n#[test]\nfn select_add_should_work_with_multiple_workgroups_3d_dim2() {\n    select_add_same_as_ref(2, [6, 6, 256]);\n}\n\nfn select_add_same_as_ref<const D: usize>(dim: usize, shape: [usize; D]) {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n\n    let tensor =\n        Tensor::<TestBackend, D>::random(shape, Distribution::Default, &Default::default());\n    let value = Tensor::<TestBackend, D>::random(shape, Distribution::Default, &Default::default());\n    let indices = Tensor::<TestBackend, 1, Int>::random(\n        [shape[dim]],\n        Distribution::Uniform(0., shape[dim] as f64),\n        &Default::default(),\n    );\n    let tensor_ref =\n        Tensor::<ReferenceBackend, D>::from_data(tensor.to_data(), &Default::default());\n    let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data(), &Default::default());\n    let indices_ref =\n        Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data(), &Default::default());\n\n    let actual = tensor.select_assign(dim, indices, value, IndexingUpdateOp::Add);\n    let expected = tensor_ref.select_assign(dim, indices_ref, value_ref, IndexingUpdateOp::Add);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/slice.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, Tensor};\n\n#[test]\nfn slice_should_work_with_multiple_workgroups() {\n    let tensor =\n        Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());\n    let indices = [3..5, 45..256];\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default());\n\n    let actual = tensor.slice(indices.clone());\n    let expected = tensor_ref.slice(indices);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/slice_assign.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, Tensor, Tolerance};\n\n#[test]\nfn slice_assign_should_work_with_multiple_workgroups() {\n    let tensor =\n        Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());\n    let value =\n        Tensor::<TestBackend, 2>::random([2, 211], Distribution::Default, &Default::default());\n    let indices = [3..5, 45..256];\n    let tensor_ref =\n        Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data(), &Default::default());\n    let value_ref = Tensor::<ReferenceBackend, 2>::from_data(value.to_data(), &Default::default());\n\n    let actual = tensor.slice_assign(indices.clone(), value);\n    let expected = tensor_ref.slice_assign(indices, value_ref);\n\n    expected\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&actual.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/unary.rs",
    "content": "use super::*;\nuse burn_tensor::Tensor;\n\n#[test]\nfn tanh_should_not_have_numerical_bugs_on_macos() {\n    fn tanh_one_value(input: f32) -> f32 {\n        let tensor = Tensor::<TestBackend, 1>::ones([1], &Default::default()) * input;\n        let output = tensor.tanh().into_primitive();\n        Tensor::<TestBackend, 1>::from_primitive(output)\n            .into_data()\n            .as_slice()\n            .unwrap()[0]\n    }\n\n    let ok = tanh_one_value(43.0); // metal tanh gives 1.0 which is the right answer\n    let zero = tanh_one_value(44.0); // metal tanh gives zero when within 43.67..44.36\n    let nan = tanh_one_value(45.0); // metal tanh gives nan when over 44.36\n    let neg = tanh_one_value(-45.0); //  metal works correctly here\n\n    assert!(!ok.is_nan() && ok == 1.0);\n    assert!(!zero.is_nan() && zero == 1.0);\n    assert!(!nan.is_nan() && nan == 1.0);\n    assert!(!neg.is_nan() && neg == -1.0);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl/uniform.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, Int, Shape, Tensor, backend::Backend};\nuse burn_tensor::{ElementConversion, Tolerance};\n\nuse serial_test::serial;\n\nuse cubek::random::{assert_at_least_one_value_per_bin, assert_wald_wolfowitz_runs_test};\n\n#[test]\n#[serial]\nfn values_all_within_interval_default() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = [24, 24];\n\n    let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Default, &device);\n    tensor\n        .to_data()\n        .assert_within_range::<FloatElem>(0.elem()..1.elem());\n}\n\n#[test]\n#[serial]\nfn values_all_within_interval_uniform() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = [24, 24];\n\n    let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Uniform(5., 17.), &device);\n    tensor\n        .to_data()\n        .assert_within_range::<FloatElem>(5.elem()..17.elem());\n}\n\n#[test]\n#[serial]\nfn at_least_one_value_per_bin_uniform() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = [64, 64];\n\n    let tensor = Tensor::<TestBackend, 2>::random(shape, Distribution::Uniform(-5., 10.), &device)\n        .into_data();\n    let numbers = tensor.as_slice::<FloatElem>().unwrap();\n\n    assert_at_least_one_value_per_bin(numbers, 3, -5., 10.);\n}\n\n#[test]\n#[serial]\nfn runs_test() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = Shape::new([512, 512]);\n    let tensor =\n        Tensor::<TestBackend, 2>::random(shape, Distribution::Default, &device).into_data();\n\n    let numbers = tensor.as_slice::<FloatElem>().unwrap();\n\n    assert_wald_wolfowitz_runs_test(numbers, 0., 1.);\n}\n\n#[test]\n#[serial]\nfn int_values_all_within_interval_uniform() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = Shape::new([20, 20]);\n    let tensor: Tensor<TestBackend, 2, Int> = Tensor::random(shape, Distribution::Default, &device);\n\n    let data_float = tensor.float().into_data();\n\n    data_float.assert_within_range(0..255);\n}\n\n#[test]\n#[serial]\nfn at_least_one_value_per_bin_int_uniform() {\n    let device = Default::default();\n    TestBackend::seed(&device, 0);\n    let shape = Shape::new([64, 64]);\n\n    let tensor: Tensor<TestBackend, 2, Int> =\n        Tensor::random(shape, Distribution::Uniform(-10.0, 10.0), &device);\n\n    let data_float = tensor.float().into_data();\n\n    let numbers = data_float.as_slice::<FloatElem>().unwrap();\n\n    assert_at_least_one_value_per_bin(numbers, 10, -10., 10.);\n}\n\n#[test]\nfn should_not_fail_on_non_float_autotune() {\n    let device = Default::default();\n    let tensor_1 = Tensor::<TestBackend, 2>::from_floats([[1., 2., 3.], [3., 4., 5.]], &device);\n\n    // Autotune of all (reduce) on lower_equal_elem's output calls uniform distribution\n    tensor_1.lower_equal_elem(1.0).all();\n}\n\n#[test]\n#[serial]\nfn test_seed_reproducibility() {\n    let device = Default::default();\n    TestBackend::seed(&device, 42);\n    let t1 = TestTensor::<1>::random([5], Distribution::Default, &device);\n    TestBackend::seed(&device, 42);\n    let t2 = TestTensor::<1>::random([5], Distribution::Default, &device);\n\n    t1.into_data()\n        .assert_approx_eq::<FloatElem>(&t2.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/cubecl.rs",
    "content": "//! CubeCL kernel tests.\n\n#[cfg(feature = \"cube\")]\n#[path = \".\"]\nmod cube {\n    type FloatElemType = f32;\n    type IntElemType = i32;\n\n    mod backend {\n        include!(\"common/backend.rs\");\n        pub type ReferenceBackend = burn_ndarray::NdArray<FloatElemType>;\n    }\n    pub use backend::*;\n\n    #[path = \"cubecl/mod.rs\"]\n    mod kernel;\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/fused_ops/mod.rs",
    "content": "mod reduce_broadcasted;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/fused_ops/reduce_broadcasted.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance, backend::Backend};\n\n#[test]\nfn test_reduce_broadcasted_1() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_read = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_write = TestTensorInt::<1>::arange(0..4, &device)\n        .reshape([4, 1])\n        .float();\n\n    // Forces previous tensors to be materialized.\n    TestBackend::sync(&device).unwrap();\n\n    let x = tensor + fused_on_read.clone();\n    let x = x.sum_dim(1);\n\n    let x = x + fused_on_write;\n\n    // Broadcast\n    let end = x + fused_on_read;\n    let actual = end.into_data();\n    let expected = TensorData::from([\n        [56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0],\n        [193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0],\n        [330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 337.0],\n        [467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0],\n    ]);\n    actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_reduce_broadcasted_2() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_read = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_write = TestTensorInt::<1>::arange(16..48, &device)\n        .reshape([4, 8])\n        .float();\n    // Second fuse on read\n    let y = TestTensorInt::<1>::arange(32..64, &device)\n        .reshape([4, 8])\n        .float();\n\n    // Forces previous tensors to be materialized.\n    TestBackend::sync(&device).unwrap();\n\n    let x = tensor + fused_on_read.clone();\n    let x = x.sum_dim(1);\n    let x = x + fused_on_write;\n    let x = x.mean_dim(1);\n\n    let end = x + y;\n    TestBackend::sync(&device).unwrap();\n\n    let actual = end.into_data();\n    let expected = TensorData::from([\n        [107.5, 108.5, 109.5, 110.5, 111.5, 112.5, 113.5, 114.5],\n        [251.5, 252.5, 253.5, 254.5, 255.5, 256.5, 257.5, 258.5],\n        [395.5, 396.5, 397.5, 398.5, 399.5, 400.5, 401.5, 402.5],\n        [539.5, 540.5, 541.5, 542.5, 543.5, 544.5, 545.5, 546.5],\n    ]);\n    actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_reduce_broadcasted_3() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_read = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_write = TestTensorInt::<1>::arange(0..4, &device)\n        .reshape([4, 1])\n        .float();\n    // Second fuse on read\n    let y = TestTensorInt::<1>::arange(32..64, &device)\n        .reshape([4, 8])\n        .float();\n\n    // Forces previous tensors to be materialized.\n    TestBackend::sync(&device).unwrap();\n\n    let x = tensor + fused_on_read.clone();\n    let x = x.sum_dim(1);\n\n    let x = x + fused_on_write;\n\n    // Broadcast\n    let x = x + fused_on_read;\n    // Second reduce\n    let x = x.mean_dim(1);\n\n    let end = x + y;\n    let actual = end.into_data();\n    let expected = TensorData::from([\n        [91.5, 92.5, 93.5, 94.5, 95.5, 96.5, 97.5, 98.5],\n        [236.5, 237.5, 238.5, 239.5, 240.5, 241.5, 242.5, 243.5],\n        [381.5, 382.5, 383.5, 384.5, 385.5, 386.5, 387.5, 388.5],\n        [526.5, 527.5, 528.5, 529.5, 530.5, 531.5, 532.5, 533.5],\n    ]);\n    actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_reduce_broadcasted_4_reused_partial() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_read = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n    let fused_on_write = TestTensorInt::<1>::arange(0..4, &device)\n        .reshape([4, 1])\n        .float();\n    let y = TestTensorInt::<1>::arange(32..64, &device)\n        .reshape([4, 8])\n        .float();\n\n    // Forces previous tensors to be materialized.\n    TestBackend::sync(&device).unwrap();\n\n    // In fusion we have to create a global buffer to keep the intermediate data for now.\n    let x_previous = tensor + fused_on_read;\n    let x = x_previous.clone().sum_dim(1);\n\n    let x = x * fused_on_write;\n\n    // Broadcast\n    let x = x + x_previous;\n    // Second reduce\n    let x = x.mean_dim(1);\n\n    // Second fuse on read\n    let end = x + y;\n    let actual = end.into_data();\n    let expected = TensorData::from([\n        [39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0],\n        [247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0],\n        [711.0, 712.0, 713.0, 714.0, 715.0, 716.0, 717.0, 718.0],\n        [\n            1431.0, 1432.0, 1433.0, 1434.0, 1435.0, 1436.0, 1437.0, 1438.0,\n        ],\n    ]);\n    actual.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/fusion.rs",
    "content": "//! Burn tensor and autodiff tests for CubeCL backends with fusion enabled.\n\n#![allow(\n    clippy::single_range_in_vec_init,\n    clippy::duplicate_mod,\n    reason = \"false positive\"\n)]\nextern crate alloc;\n\n#[cfg(feature = \"cube\")]\n#[path = \".\"]\nmod fusion {\n    pub type FloatElemType = f32;\n    pub type IntElemType = i32;\n\n    #[path = \"common/backend.rs\"]\n    mod backend;\n    pub use backend::prelude::*;\n\n    // NOTE:\n    // We re-include the tensor and autodiff test suites after overriding `TestBackend`\n    // with `Fusion<TestBackend>`. This intentionally duplicates module names and test\n    // logic to execute the same tests under fusion.\n    pub type TestBackend = burn_fusion::Fusion<backend::TestBackend>;\n    pub type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n    pub type TestTensorInt<const D: usize> = Tensor<TestBackend, D, burn_tensor::Int>;\n    pub type TestTensorBool<const D: usize> = Tensor<TestBackend, D, burn_tensor::Bool>;\n\n    // Tensor tests\n    mod tensor {\n        include!(\"common/tensor.rs\");\n    }\n\n    // Autodiff tests\n    mod autodiff {\n        include!(\"common/autodiff.rs\");\n    }\n\n    // Fusion tests\n    include!(\"fused_ops/mod.rs\");\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod ops;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/all.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_all() {\n    let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorBool::<2>::from([[true, true, true], [true, true, true]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_all_dim() {\n    let tensor = TestTensorBool::<2>::from([[false, true, false], [true, true, true]]);\n    let data_actual = tensor.all_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_all_with_bool_from_lower_equal() {\n    let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6;\n    let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-6;\n\n    let ge = tensor1.lower_equal(tensor2);\n    let all = ge.clone().all();\n\n    TensorData::from([true]).assert_eq(&all.clone().into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/any.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_any() {\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_any_dim() {\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/argwhere_nonzero.rs",
    "content": "use super::*;\nuse alloc::vec::Vec;\nuse burn_tensor::{Shape, TensorData};\n\n#[test]\nfn test_argwhere_1d() {\n    let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);\n    let output = tensor.argwhere();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1], [3], [4]]), false);\n}\n\n#[test]\nfn test_argwhere_2d() {\n    let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);\n    let output = tensor.argwhere();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 1], [2, 0], [2, 1]]), false);\n}\n\n#[test]\nfn test_argwhere_3d() {\n    let tensor = TestTensorBool::<3>::from([\n        [[false, false, false], [false, true, false]],\n        [[true, false, true], [true, true, false]],\n    ]);\n    let output = tensor.argwhere();\n\n    output.into_data().assert_eq(\n        &TensorData::from([[0, 1, 1], [1, 0, 0], [1, 0, 2], [1, 1, 0], [1, 1, 1]]),\n        false,\n    );\n}\n\n#[test]\nfn test_nonzero_1d() {\n    let tensor = TestTensorBool::<1>::from([false, true, false, true, true]);\n    let data_actual = tensor\n        .nonzero()\n        .into_iter()\n        .map(|t| t.into_data())\n        .collect::<Vec<_>>();\n\n    assert_eq!(data_actual.len(), 1);\n    data_actual[0].assert_eq(&TensorData::from([1, 3, 4]), false);\n}\n\n#[test]\nfn test_nonzero_2d() {\n    // 2-D tensor\n    let tensor = TestTensorBool::<2>::from([[false, false], [false, true], [true, true]]);\n    let data_actual = tensor\n        .nonzero()\n        .into_iter()\n        .map(|t| t.into_data())\n        .collect::<Vec<_>>();\n    let data_expected = [TensorData::from([1, 2, 2]), TensorData::from([1, 0, 1])];\n\n    assert_eq!(data_actual.len(), 2);\n    for (idx, actual) in data_actual.iter().enumerate() {\n        actual.assert_eq(&data_expected[idx], false)\n    }\n}\n\n#[test]\nfn test_nonzero_3d() {\n    // 3-D tensor\n    let tensor = TestTensorBool::<3>::from([\n        [[false, false, false], [false, true, false]],\n        [[true, false, true], [true, true, false]],\n    ]);\n    let data_actual = tensor\n        .nonzero()\n        .into_iter()\n        .map(|t| t.into_data())\n        .collect::<Vec<_>>();\n    let data_expected = [\n        TensorData::from([0, 1, 1, 1, 1]),\n        TensorData::from([1, 0, 0, 1, 1]),\n        TensorData::from([1, 0, 2, 0, 1]),\n    ];\n\n    assert_eq!(data_actual.len(), 3);\n    for (idx, actual) in data_actual.iter().enumerate() {\n        actual.assert_eq(&data_expected[idx], false)\n    }\n}\n\n#[test]\nfn test_nonzero_empty() {\n    let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);\n    let output = tensor.nonzero();\n\n    assert_eq!(output.len(), 0);\n}\n\n#[test]\nfn test_argwhere_empty() {\n    let tensor = TestTensorBool::<1>::from([false, false, false, false, false]);\n    let output = tensor.argwhere();\n\n    assert_eq!(output.shape(), Shape::new([0, 1]));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/cat.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_cat_ops_bool() {\n    let device = Default::default();\n    let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);\n    let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);\n\n    let output = Tensor::cat(vec![tensor_1, tensor_2], 0);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[false, true, true], [true, true, false]]),\n        false,\n    );\n}\n\n#[test]\nfn should_support_cat_with_empty_tensor_bool() {\n    let device = Default::default();\n    let tensor_1 = TestTensorBool::<2>::from_data([[true, false, true]], &device);\n    let tensor_2: TestTensorBool<2> = TestTensorBool::empty([1, 0], &device);\n\n    let output = Tensor::cat(vec![tensor_1, tensor_2], 1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[true, false, true]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/comparison.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_bool_equal() {\n    let data_1 = TensorData::from([[false, true, true], [true, false, true]]);\n    let data_2 = TensorData::from([[false, false, true], [false, true, true]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device);\n\n    let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, true], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn should_support_bool_not_equal() {\n    let data_1 = TensorData::from([[false, true, true], [true, false, true]]);\n    let data_2 = TensorData::from([[false, false, true], [false, true, true]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorBool::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorBool::<2>::from_data(data_2, &device);\n\n    let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.not_equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, false], [true, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn should_support_bool_not() {\n    let data_1 = TensorData::from([[false, true, true], [true, true, false]]);\n    let tensor_1 = TestTensorBool::<2>::from_data(data_1, &Default::default());\n\n    let data_actual_cloned = tensor_1.clone().bool_not();\n    let data_actual_inplace = tensor_1.bool_not();\n\n    let data_expected = TensorData::from([[true, false, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_bool_equal_elem() {\n    let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]);\n\n    let data_actual_cloned = tensor_1.clone().equal_elem(false);\n    let data_actual_inplace = tensor_1.equal_elem(false);\n\n    let data_expected = TensorData::from([[false, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_bool_not_equal_elem() {\n    let tensor_1 = TestTensorBool::<2>::from([[true, false, true], [false, true, false]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal_elem(true);\n    let data_actual_inplace = tensor_1.not_equal_elem(true);\n\n    let data_expected = TensorData::from([[false, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/create_like.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_zeros_like() {\n    let tensor = TestTensorBool::<3>::from([\n        [[false, true, false], [true, true, true]],\n        [[false, false, false], [true, true, false]],\n    ]);\n\n    let tensor = tensor.zeros_like();\n    let expected = TensorData::from([\n        [[false, false, false], [false, false, false]],\n        [[false, false, false], [false, false, false]],\n    ]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_ones_like() {\n    let tensor = TestTensorBool::<3>::from([\n        [[false, true, false], [true, true, true]],\n        [[false, false, false], [true, true, false]],\n    ]);\n\n    let tensor = tensor.ones_like();\n    let expected = TensorData::from([\n        [[true, true, true], [true, true, true]],\n        [[true, true, true], [true, true, true]],\n    ]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/expand.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn expand_2d_bool() {\n    let tensor = TestTensorBool::<1>::from([false, true, false]);\n    let expanded_tensor = tensor.expand([3, 3]);\n\n    let expected_data = TensorData::from([\n        [false, true, false],\n        [false, true, false],\n        [false, true, false],\n    ]);\n\n    expanded_tensor.into_data().assert_eq(&expected_data, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/flip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn flip_bool() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .greater_elem(10);\n\n    let flipped = tensor.clone().flip([0, 2]);\n\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10)\n    let data_expected = TensorData::from([\n        [\n            [true, true, true, true],\n            [true, true, true, true],\n            [true, true, true, true],\n        ],\n        [\n            [false, false, false, false],\n            [false, false, false, false],\n            [true, false, false, false],\n        ],\n    ]);\n\n    flipped.into_data().assert_eq(&data_expected, false);\n\n    // Test with no flip\n    let flipped = tensor.clone().flip([]);\n    tensor.into_data().assert_eq(&flipped.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/full.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_tensor_full() {\n    let device = Default::default();\n    let bool_tensor = TestTensorBool::<2>::full([2, 2], true, &device);\n    bool_tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[true, true], [true, true]]), false);\n\n    let bool_tensor = TestTensorBool::<2>::full([2, 2], false, &device);\n    bool_tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[false, false], [false, false]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/gather_scatter.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_scatter_1d_bool() {\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, false], &device);\n    let values = TestTensorBool::from_data([false, true, true], &device);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &device);\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([true, false, true]), false);\n}\n\n#[test]\nfn should_gather_1d_dim0_bool() {\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, false], &device);\n    let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([false, false, true, false, false]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/init.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_bool_empty() {\n    let shape = [2, 2];\n    let tensor = TestTensorBool::<2>::empty(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into())\n}\n\n#[test]\nfn should_support_bool_zeros() {\n    let shape = [2, 2];\n    let tensor = TestTensorBool::<2>::zeros(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[false, false], [false, false]]), false);\n}\n\n#[test]\nfn should_support_bool_ones() {\n    let shape = [2, 2];\n    let tensor = TestTensorBool::<2>::ones(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[true, true], [true, true]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/logical.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_bool_and() {\n    let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);\n    let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);\n    let data_actual = tensor1.bool_and(tensor2).into_data();\n    let data_expected = TensorData::from([[false, true, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_bool_or() {\n    let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);\n    let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);\n    let data_actual = tensor1.bool_or(tensor2).into_data();\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_bool_xor() {\n    let tensor1 = TestTensorBool::<2>::from([[false, true, false], [true, false, true]]);\n    let tensor2 = TestTensorBool::<2>::from([[true, true, false], [false, false, true]]);\n    let data_actual = tensor1.bool_xor(tensor2).into_data();\n    let data_expected = TensorData::from([[true, false, false], [true, false, false]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_bool_or_vec() {\n    let device = Default::default();\n    let tensor1 = TestTensorBool::<1>::full([256], 0, &device);\n    let tensor2 = TestTensorBool::<1>::full([256], 1, &device);\n    let data_actual = tensor1.bool_or(tensor2).into_data();\n    let data_expected = TensorData::from([true; 256]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_bool_and_vec() {\n    let device = Default::default();\n    let tensor1 = TestTensorBool::<1>::full([256], 0, &device);\n    let tensor2 = TestTensorBool::<1>::full([256], 1, &device);\n    let data_actual = tensor1.bool_and(tensor2).into_data();\n    let data_expected = TensorData::from([false; 256]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/mask.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_bool_mask_where_ops() {\n    let device = Default::default();\n    let tensor = TestTensorBool::<2>::from_data([[true, false], [false, false]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n    let value =\n        TestTensorBool::<2>::from_data(TensorData::from([[false, true], [true, false]]), &device);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([[false, false], [false, false]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_bool_mask_fill_ops() {\n    let device = Default::default();\n    let tensor = TestTensorBool::<2>::from_data([[false, true], [false, false]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n\n    let output = tensor.mask_fill(mask, true);\n    let expected = TensorData::from([[true, true], [false, true]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod all;\nmod any;\nmod argwhere_nonzero;\nmod cat;\nmod comparison;\nmod create_like;\nmod expand;\nmod flip;\nmod full;\nmod gather_scatter;\nmod init;\nmod logical;\nmod mask;\nmod movedim;\nmod permute;\nmod repeat;\nmod repeat_dim;\nmod reshape;\nmod select;\nmod stack;\nmod take;\nmod transpose;\nmod tri_mask;\nmod unfold;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/movedim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn movedim_bool() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .greater_elem(10);\n\n    let permuted = tensor.clone().movedim(0, 2);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).gt(10)\n    let expected = TensorData::from([\n        [[false, true], [false, true], [false, true], [false, true]],\n        [[false, true], [false, true], [false, true], [false, true]],\n        [[false, true], [false, true], [false, true], [true, true]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().movedim(0, -1);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().movedim(0, 0);\n    permuted.into_data().assert_eq(&tensor.into_data(), false);\n}\n\n#[test]\nfn vec_input_bool() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .greater_elem(10);\n\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);\n    // from pytorch\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).gt(10)\n    let expected = TensorData::from([\n        [[false, false, false, false], [true, true, true, true]],\n        [[false, false, false, false], [true, true, true, true]],\n        [[false, false, false, true], [true, true, true, true]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axes\n    let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axes\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);\n    permuted.into_data().assert_eq(&tensor.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/permute.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn permute_bool() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .greater_elem(10);\n\n    let permuted = tensor.clone().permute([2, 1, 0]);\n\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0).gt(10)\n    let expected = TensorData::from([\n        [[false, true], [false, true], [false, true]],\n        [[false, true], [false, true], [false, true]],\n        [[false, true], [false, true], [false, true]],\n        [[false, true], [false, true], [true, true]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().permute([-1, 1, 0]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().permute([0, 1, 2]);\n    permuted.into_data().assert_eq(&tensor.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/repeat.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_bool_repeat_ops_one_dimension() {\n    let data = TensorData::from([[true, false, false]]);\n    let tensor = TestTensorBool::<2>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[4, 1, 1]);\n    let expected = TensorData::from([\n        [true, false, false],\n        [true, false, false],\n        [true, false, false],\n        [true, false, false],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_bool_repeat_on_many_dimension() {\n    let data = TensorData::from([\n        [[false, true], [true, false]],\n        [[true, true], [false, false]],\n    ]);\n    let tensor = TestTensorBool::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[2, 3, 2]);\n    let expected = TensorData::from([\n        [\n            [false, true, false, true],\n            [true, false, true, false],\n            [false, true, false, true],\n            [true, false, true, false],\n            [false, true, false, true],\n            [true, false, true, false],\n        ],\n        [\n            [true, true, true, true],\n            [false, false, false, false],\n            [true, true, true, true],\n            [false, false, false, false],\n            [true, true, true, true],\n            [false, false, false, false],\n        ],\n        [\n            [false, true, false, true],\n            [true, false, true, false],\n            [false, true, false, true],\n            [true, false, true, false],\n            [false, true, false, true],\n            [true, false, true, false],\n        ],\n        [\n            [true, true, true, true],\n            [false, false, false, false],\n            [true, true, true, true],\n            [false, false, false, false],\n            [true, true, true, true],\n            [false, false, false, false],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/repeat_dim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_bool_repeat_ops() {\n    let data = TensorData::from([[true, false, false]]);\n    let tensor = TestTensorBool::<2>::from_data(data, &Default::default());\n\n    let output = tensor.repeat_dim(0, 4);\n    let expected = TensorData::from([\n        [true, false, false],\n        [true, false, false],\n        [true, false, false],\n        [true, false, false],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_bool_repeat_on_dims_larger_than_1() {\n    let data = TensorData::from([\n        [[false, true], [true, false]],\n        [[true, true], [false, false]],\n    ]);\n    let tensor = TestTensorBool::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat_dim(1, 2);\n    let expected = TensorData::from([\n        [[false, true], [true, false], [false, true], [true, false]],\n        [[true, true], [false, false], [true, true], [false, false]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/reshape.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_reshape_bool() {\n    let data = TensorData::from([false, true, false]);\n    let tensor = TestTensorBool::<1>::from_data(data, &Default::default());\n\n    let output = tensor.clone().reshape([1, 3]);\n    let expected = TensorData::from([[false, true, false]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/select.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_select_bool_tensor_1d() {\n    // Test that select works for boolean tensors\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);\n    let indices = TestTensorInt::from_data([0, 2, 1, 0], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([true, true, false, true]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_bool_tensor_2d() {\n    // Test that select works for boolean 2D tensors\n    let device = Default::default();\n    let tensor =\n        TestTensorBool::<2>::from_data([[true, false, true], [false, true, false]], &device);\n    let indices = TestTensorInt::from_data([1, 0], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([[false, true, false], [true, false, true]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_bool_tensor() {\n    // Test that select_add works for boolean tensors\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);\n    let values = TestTensorBool::<1>::from_data([false, true], &device);\n    let indices = TestTensorInt::from_data([0, 2], &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    // Note: select_add uses sum reduction, so:\n    // index 0: true OR false = true\n    // index 2: true OR true = true\n    // index 1: false (unchanged)\n    let expected = TensorData::from([true, false, true]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_bool_overlapping_indices() {\n    // Test accumulation behavior with overlapping indices\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([false, true], &device);\n    let indices = TestTensorInt::from_data([0, 0], &device);\n    let values = TestTensorBool::<1>::from_data([true, false], &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    // Index 0: false OR true OR false = true\n    let expected = TensorData::from([true, true]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_bool_false_to_true_case() {\n    // Test false OR true = true\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([false], &device);\n    let indices = TestTensorInt::from_data([0], &device);\n    let values = TestTensorBool::<1>::from_data([true], &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([true]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_bool_true_or_true_accumulation() {\n    // Test multiple true accumulations\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false], &device);\n    let indices = TestTensorInt::from_data([0, 0, 0], &device);\n    let values = TestTensorBool::<1>::from_data([true, true, true], &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([true, false]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_match_default_implementation_behavior() {\n    // Verify optimized implementation matches original default logic\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);\n    let indices = TestTensorInt::from_data([0, 1, 0], &device);\n    let values = TestTensorBool::<1>::from_data([false, true, true], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    // Manual default implementation logic\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\nfn should_select_add_bool_overlapping_indices_vs_default() {\n    // Test overlapping indices against default implementation\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([false, true], &device);\n    let indices = TestTensorInt::from_data([0, 0], &device);\n    let values = TestTensorBool::<1>::from_data([true, false], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\nfn should_select_add_bool_true_or_true_accumulation_vs_default() {\n    // Test multiple true accumulations against default implementation\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false], &device);\n    let indices = TestTensorInt::from_data([0, 0, 0], &device);\n    let values = TestTensorBool::<1>::from_data([true, true, true], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\nfn should_select_add_bool_false_to_true_case_vs_default() {\n    // Test false OR true case against default implementation\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([false], &device);\n    let indices = TestTensorInt::from_data([0], &device);\n    let values = TestTensorBool::<1>::from_data([true], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\nfn should_select_add_bool_tensor_vs_default() {\n    // Test existing basic case against default implementation\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);\n    let indices = TestTensorInt::from_data([0, 2], &device);\n    let values = TestTensorBool::<1>::from_data([false, false], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\n#[should_panic(expected = \"Tensors are not eq\")]\nfn should_fail_if_replacement_semantics_were_used() {\n    // Test that framework uses accumulation, not replacement\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true], &device);\n    let indices = TestTensorInt::from_data([0], &device);\n    let values = TestTensorBool::<1>::from_data([false], &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let replacement_expected = TensorData::from([false]);\n\n    output.into_data().assert_eq(&replacement_expected, false);\n}\n\n#[test]\n#[should_panic(expected = \"Tensors are not eq\")]\nfn should_fail_if_replacement_semantics_were_used_vs_default() {\n    // Test that default implementation also uses accumulation, not replacement\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true], &device);\n    let indices = TestTensorInt::from_data([0], &device);\n    let values = TestTensorBool::<1>::from_data([false], &device);\n\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n    let replacement_expected = TensorData::from([false]);\n\n    default_result\n        .into_data()\n        .assert_eq(&replacement_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/stack.rs",
    "content": "use super::*;\nuse alloc::vec;\nuse burn_tensor::{Tensor, TensorData};\n\n#[test]\nfn should_support_stack_ops_bool() {\n    let device = Default::default();\n    let tensor_1 = TestTensorBool::<2>::from_data([[false, true, true]], &device);\n    let tensor_2 = TestTensorBool::<2>::from_data([[true, true, false]], &device);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[false, true, true]], [[true, true, false]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/take.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_take_bool_tensor() {\n    // Test take with boolean tensors\n    let device = Default::default();\n    let tensor = TestTensorBool::<2>::from_data([[true, false], [false, true]], &device);\n    let indices = TestTensorInt::<1>::from_data([1, 0], &device);\n\n    let output = tensor.take::<1, 2>(0, indices);\n    let expected = TensorData::from([[false, true], [true, false]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_bool_tensor_with_2d_indices() {\n    // Test take with boolean tensors - output will be 3D\n    let device = Default::default();\n    let tensor = TestTensorBool::<2>::from_data(\n        [\n            [true, false, true],\n            [false, true, false],\n            [true, true, false],\n        ],\n        &device,\n    );\n\n    // 2D indices - shape [2, 2]\n    let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device);\n    let output = tensor.take::<2, 3>(0, indices);\n\n    // Expected: shape [2, 2, 3]\n    let expected = TensorData::from([\n        [[true, false, true], [true, true, false]],\n        [[false, true, false], [true, false, true]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/transpose.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_transpose_bool() {\n    let tensor = TestTensorBool::<3>::from_data(\n        [\n            [[false, true, false], [false, false, false]],\n            [[false, false, true], [false, false, true]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.transpose();\n    let expected = TensorData::from([\n        [[false, false], [true, false], [false, false]],\n        [[false, false], [false, false], [true, true]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_swap_dims_bool() {\n    let tensor = TestTensorBool::<3>::from_data(\n        [\n            [[false, true, false], [false, false, false]],\n            [[false, false, true], [false, false, true]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.swap_dims(0, 2);\n    let expected = TensorData::from([\n        [[false, false], [false, false]],\n        [[true, false], [false, false]],\n        [[false, true], [false, true]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/tri_mask.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn square_diag() {\n    let device = Default::default();\n    let data_expected = TensorData::from([\n        [false, true, true],\n        [true, false, true],\n        [true, true, false],\n    ]);\n    let tensor = TestTensorBool::<2>::diag_mask([3, 3], 0, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn square_diag_offset() {\n    let device = Default::default();\n    let data_expected =\n        TensorData::from([[true, false, true], [true, true, false], [true, true, true]]);\n    let tensor = TestTensorBool::<2>::diag_mask([3, 3], 1, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn square_tri_upper() {\n    let device = Default::default();\n    let data_expected = TensorData::from([\n        [false, false, false],\n        [true, false, false],\n        [true, true, false],\n    ]);\n    let tensor = TestTensorBool::<2>::triu_mask([3, 3], 0, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn square_tri_upper_offset() {\n    let device = Default::default();\n    let data_expected = TensorData::from([\n        [true, false, false],\n        [true, true, false],\n        [true, true, true],\n    ]);\n    let tensor = TestTensorBool::<2>::triu_mask([3, 3], 1, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn square_tri_lower() {\n    let device = Default::default();\n\n    let data_expected = TensorData::from([\n        [false, true, true],\n        [false, false, true],\n        [false, false, false],\n    ]);\n    let tensor = TestTensorBool::<2>::tril_mask([3, 3], 0, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn square_tri_lower_offset() {\n    let device = Default::default();\n\n    let data_expected = TensorData::from([\n        [true, true, true],\n        [false, true, true],\n        [false, false, true],\n    ]);\n    let tensor = TestTensorBool::<2>::tril_mask([3, 3], -1, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n\n#[test]\nfn rect_diag() {\n    let device = Default::default();\n    let data_expected = TensorData::from([\n        [false, true, true, true],\n        [true, false, true, true],\n        [true, true, false, true],\n    ]);\n    let tensor = TestTensorBool::<2>::diag_mask([3, 4], 0, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n\n    let data_expected = TensorData::from([\n        [false, true, true],\n        [true, false, true],\n        [true, true, false],\n        [true, true, true],\n    ]);\n    let tensor = TestTensorBool::<2>::diag_mask([4, 3], 0, &device);\n    tensor.into_data().assert_eq(&data_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/bool/ops/unfold.rs",
    "content": "use super::*;\nuse burn_tensor::Distribution;\nuse burn_tensor::s;\n\n#[test]\nfn test_unfold_bool() {\n    let device = Default::default();\n\n    let input =\n        TestTensor::<3>::random([2, 6, 6], Distribution::Default, &device).greater_elem(0.5);\n\n    let dim = 1;\n    let size = 3;\n    let step = 2;\n    let actual: TestTensorBool<4> = input.clone().unfold(dim, size, step);\n\n    let expected = TestTensorBool::<4>::empty([2, 2, 6, 3], &device)\n        .slice_assign(\n            s![.., 0, .., ..],\n            input\n                .clone()\n                .slice(s![.., 0..3, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        )\n        .slice_assign(\n            s![.., 1, .., ..],\n            input\n                .clone()\n                .slice(s![.., 2..5, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        );\n\n    actual.to_data().assert_eq(&expected.to_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/clone_invariance.rs",
    "content": "/// This module tests whether basic tensor operations remain invariant when performed on clones,\n/// meaning that cloning input tensors won't affect the results.\n///\n/// Those are relevant tests because backends may employ unsafe optimizations to reuse tensor data\n/// and use different kernels in such cases. We ensure that the results are consistent regardless\n/// of the approach and that the input tensors are not modified when cloned.\nuse super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::activation::{\n    gelu, log_sigmoid, log_softmax, mish, relu, sigmoid, silu, softmax, softplus, tanh,\n};\nuse burn_tensor::{Distribution, IndexingUpdateOp, TensorData};\n\npub trait CloneInvarianceTest<const D: usize> {\n    type Args;\n\n    fn args(&self) -> Self::Args;\n\n    fn run(&self, args: &Self::Args, inplace: bool) -> TensorData;\n\n    fn check(&self) {\n        let args = self.args();\n        let out = self.run(&args, false);\n        let out_inplace = self.run(&args, true);\n\n        out.assert_approx_eq::<FloatElem>(&out_inplace, Tolerance::default());\n    }\n}\n\nmacro_rules! clone_invariance_test {\n    (unary: $name:ident, ops_float: $ops:expr) => {\n        #[test]\n        #[allow(non_snake_case)]\n        fn $name() {\n            struct $name;\n\n            impl CloneInvarianceTest<2> for $name {\n                type Args = TensorData;\n\n                fn args(&self) -> Self::Args {\n                    TestTensor::<2>::random([32, 32], Distribution::Default, &Default::default())\n                        .into_data()\n                        .convert::<f32>()\n                }\n\n                fn run(&self, args: &Self::Args, inplace: bool) -> TensorData {\n                    let lhs = TestTensor::from_data(args.clone(), &Default::default());\n\n                    if inplace {\n                        $ops(lhs).into_data().convert::<f32>()\n                    } else {\n                        let out = $ops(lhs.clone()).into_data().convert::<f32>();\n                        lhs.into_data()\n                            .assert_approx_eq::<FloatElem>(args, Tolerance::default());\n                        out\n                    }\n                }\n            }\n\n            CloneInvarianceTest::<2>::check(&$name);\n        }\n    };\n\n    (binary: $name:ident, ops_float: $ops:expr) => {\n        #[test]\n        #[allow(non_snake_case)]\n        fn $name() {\n            struct $name;\n\n            impl CloneInvarianceTest<2> for $name {\n                type Args = (TensorData, TensorData);\n\n                fn args(&self) -> Self::Args {\n                    let device = Default::default();\n                    (\n                        TestTensor::<2>::ones([32, 32], &device)\n                            .into_data()\n                            .convert::<f32>(),\n                        // Avoid div by zero.\n                        TestTensor::<2>::ones([32, 32], &device)\n                            .into_data()\n                            .convert::<f32>(),\n                    )\n                }\n\n                fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData {\n                    let device = Default::default();\n                    let lhs = TestTensor::from_data(lhs_arg.clone(), &device);\n                    let rhs = TestTensor::from_data(rhs_arg.clone(), &device);\n\n                    if inplace {\n                        $ops(lhs, rhs).into_data().convert::<f32>()\n                    } else {\n                        let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::<f32>();\n\n                        lhs.into_data()\n                            .assert_approx_eq::<FloatElem>(lhs_arg, Tolerance::default());\n                        rhs.into_data()\n                            .assert_approx_eq::<FloatElem>(rhs_arg, Tolerance::default());\n\n                        out\n                    }\n                }\n            }\n\n            CloneInvarianceTest::<2>::check(&$name);\n        }\n    };\n\n    (unary: $name:ident, ops_int: $ops:expr) => {\n        #[test]\n        #[allow(non_snake_case)]\n        fn $name() {\n            struct $name;\n\n            impl CloneInvarianceTest<2> for $name {\n                type Args = TensorData;\n\n                fn args(&self) -> Self::Args {\n                    TestTensor::<2>::random(\n                        [32, 32],\n                        Distribution::Uniform(0.0, 50.0),\n                        &Default::default(),\n                    )\n                    .into_data()\n                    .convert::<i32>()\n                }\n\n                fn run(&self, args: &Self::Args, inplace: bool) -> TensorData {\n                    let lhs = TestTensorInt::from_data(args.clone(), &Default::default());\n\n                    if inplace {\n                        $ops(lhs).into_data().convert::<f32>()\n                    } else {\n                        let out = $ops(lhs.clone()).into_data().convert::<f32>();\n                        lhs.into_data()\n                            .convert::<i32>()\n                            .assert_approx_eq::<FloatElem>(args, Tolerance::default());\n                        out\n                    }\n                }\n            }\n\n            CloneInvarianceTest::<2>::check(&$name);\n        }\n    };\n\n    (binary: $name:ident, ops_int: $ops:expr) => {\n        #[test]\n        #[allow(non_snake_case)]\n        fn $name() {\n            struct $name;\n\n            impl CloneInvarianceTest<2> for $name {\n                type Args = (TensorData, TensorData);\n\n                fn args(&self) -> Self::Args {\n                    let device = Default::default();\n                    (\n                        TestTensor::<2>::random([32, 32], Distribution::Uniform(0., 50.), &device)\n                            .into_data()\n                            .convert::<i32>(),\n                        // Avoid div by zero.\n                        TestTensor::<2>::random([32, 32], Distribution::Uniform(1., 51.), &device)\n                            .into_data()\n                            .convert::<i32>(),\n                    )\n                }\n\n                fn run(&self, (lhs_arg, rhs_arg): &Self::Args, inplace: bool) -> TensorData {\n                    let device = Default::default();\n                    let lhs = TestTensorInt::from_data(lhs_arg.clone(), &device);\n                    let rhs = TestTensorInt::from_data(rhs_arg.clone(), &device);\n\n                    if inplace {\n                        $ops(lhs, rhs).into_data().convert::<f32>()\n                    } else {\n                        let out = $ops(lhs.clone(), rhs.clone()).into_data().convert::<f32>();\n\n                        lhs.into_data()\n                            .convert::<i32>()\n                            .assert_approx_eq::<FloatElem>(lhs_arg, Tolerance::default());\n                        rhs.into_data()\n                            .convert::<i32>()\n                            .assert_approx_eq::<FloatElem>(rhs_arg, Tolerance::default());\n\n                        out\n                    }\n                }\n            }\n\n            CloneInvarianceTest::<2>::check(&$name);\n        }\n    };\n}\n\nmod float {\n    use super::*;\n\n    // Unary ops\n    clone_invariance_test!(\n        unary: AddScalar,\n        ops_float: |tensor: TestTensor<2>| tensor.add_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: SubScalar,\n        ops_float: |tensor: TestTensor<2>| tensor.sub_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: DivScalar,\n        ops_float: |tensor: TestTensor<2>| tensor.div_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: MulScalar,\n        ops_float: |tensor: TestTensor<2>| tensor.mul_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: PowScalar,\n        ops_float: |tensor: TestTensor<2>| tensor.powf_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: Square,\n        ops_float: |tensor: TestTensor<2>| tensor.square()\n    );\n    clone_invariance_test!(\n        unary: Sqrt,\n        ops_float: |tensor: TestTensor<2>| tensor.sqrt()\n    );\n    clone_invariance_test!(\n        unary: Exp,\n        ops_float: |tensor: TestTensor<2>| tensor.exp()\n    );\n    clone_invariance_test!(\n        unary: Neg,\n        ops_float: |tensor: TestTensor<2>| tensor.neg()\n    );\n    clone_invariance_test!(\n        unary: MeanDim,\n        ops_float: |tensor: TestTensor<2>| tensor.mean_dim(1)\n    );\n    clone_invariance_test!(\n        unary: SumDim,\n        ops_float: |tensor: TestTensor<2>| tensor.sum_dim(1)\n    );\n    clone_invariance_test!(\n        unary: Sum,\n        ops_float: |tensor: TestTensor<2>| tensor.sum().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Mean,\n        ops_float: |tensor: TestTensor<2>| tensor.mean().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Clamp,\n        ops_float: |tensor: TestTensor<2>| tensor.clamp(-2., 2.)\n    );\n    clone_invariance_test!(\n        unary: ClampMin,\n        ops_float: |tensor: TestTensor<2>| tensor.clamp_min(-2.)\n    );\n    clone_invariance_test!(\n        unary: ClampMax,\n        ops_float: |tensor: TestTensor<2>| tensor.clamp_max(2.)\n    );\n    clone_invariance_test!(\n        unary: Abs,\n        ops_float: |tensor: TestTensor<2>| tensor.abs()\n    );\n    clone_invariance_test!(\n        unary: Cos,\n        ops_float: |tensor: TestTensor<2>| tensor.cos()\n    );\n    clone_invariance_test!(\n        unary: Sin,\n        ops_float: |tensor: TestTensor<2>| tensor.sin()\n    );\n    clone_invariance_test!(\n        unary: Tan,\n        ops_float: |tensor: TestTensor<2>| tensor.tan()\n    );\n    clone_invariance_test!(\n        unary: Log,\n        ops_float: |tensor: TestTensor<2>| tensor.log()\n    );\n    clone_invariance_test!(\n        unary: Log1P,\n        ops_float: |tensor: TestTensor<2>| tensor.log1p()\n    );\n    clone_invariance_test!(\n        unary: SwapDims,\n        ops_float: |tensor: TestTensor<2>| tensor.swap_dims(0, 1)\n    );\n    clone_invariance_test!(\n        unary: Transpose,\n        ops_float: |tensor: TestTensor<2>| tensor.transpose()\n    );\n    clone_invariance_test!(\n        unary: Slice,\n        ops_float: |tensor: TestTensor<2>| tensor.slice([0..12, 12..24])\n    );\n    clone_invariance_test!(\n        unary: Erf,\n        ops_float: |tensor: TestTensor<2>| tensor.erf()\n    );\n    clone_invariance_test!(\n        unary: EqualElem,\n        ops_float: |tensor: TestTensor<2>| tensor.equal_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: NotEqualElem,\n        ops_float: |tensor: TestTensor<2>| tensor.not_equal_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: GreaterElem,\n        ops_float: |tensor: TestTensor<2>| tensor.greater_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: GreaterEqualElem,\n        ops_float: |tensor: TestTensor<2>| tensor.greater_equal_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: LowerElem,\n        ops_float: |tensor: TestTensor<2>| tensor.lower_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: LowerEqualElem,\n        ops_float: |tensor: TestTensor<2>| tensor.lower_equal_elem(0.5)\n    );\n    clone_invariance_test!(\n        unary: Argmax,\n        ops_float: |tensor: TestTensor<2>| tensor.argmax(0)\n    );\n    clone_invariance_test!(\n        unary: Argmin,\n        ops_float: |tensor: TestTensor<2>| tensor.argmin(0)\n    );\n    clone_invariance_test!(\n        unary: Max,\n        ops_float: |tensor: TestTensor<2>| tensor.max().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Min,\n        ops_float: |tensor: TestTensor<2>| tensor.min().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: MaxDim,\n        ops_float: |tensor: TestTensor<2>| tensor.max_dim(1)\n    );\n    clone_invariance_test!(\n        unary: MaxDimWithIndices,\n        ops_float: |tensor: TestTensor<2>| tensor.max_dim_with_indices(1).0\n    );\n    clone_invariance_test!(\n        unary: MinDimWithIndices,\n        ops_float: |tensor: TestTensor<2>| tensor.min_dim_with_indices(1).0\n    );\n    clone_invariance_test!(\n        unary: MinDim,\n        ops_float: |tensor: TestTensor<2>| tensor.min_dim(1)\n    );\n    clone_invariance_test!(\n        unary: Repeat,\n        ops_float: |tensor: TestTensor<2>| {\n            tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])\n        }\n    );\n    clone_invariance_test!(\n        unary: Reshape,\n        ops_float: |tensor: TestTensor<2>| {\n            let shape = tensor.shape();\n            let new_shape = [shape.num_elements(), 1];\n            tensor.reshape(new_shape)\n        }\n    );\n    clone_invariance_test!(\n        unary: Gatter,\n        ops_float: |tensor: TestTensor<2>| {\n            let shape = tensor.shape();\n            let indices = TestTensorInt::ones(shape, &Default::default());\n            tensor.gather(0, indices)\n        }\n    );\n    clone_invariance_test!(\n        unary: Select,\n        ops_float: |tensor: TestTensor<2>| {\n            let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());\n            tensor.select(0, indices)\n        }\n    );\n    clone_invariance_test!(\n        unary: MaskFill,\n        ops_float: |tensor: TestTensor<2>| {\n            let mask = tensor.clone().greater_elem(0.5);\n            tensor.mask_fill(mask, 77.0)\n        }\n    );\n\n    // Activation\n    clone_invariance_test!(\n        unary: Softmax,\n        ops_float: |tensor: TestTensor<2>| softmax(tensor, 1)\n    );\n    clone_invariance_test!(\n        unary: LogSoftmax,\n        ops_float: |tensor: TestTensor<2>| log_softmax(tensor, 1)\n    );\n    clone_invariance_test!(\n        unary: Sigmoid,\n        ops_float: |tensor: TestTensor<2>| sigmoid(tensor)\n    );\n    clone_invariance_test!(\n        unary: LogSigmoid,\n        ops_float: |tensor: TestTensor<2>| log_sigmoid(tensor)\n    );\n    clone_invariance_test!(\n        unary: Relu,\n        ops_float: |tensor: TestTensor<2>| relu(tensor)\n    );\n    clone_invariance_test!(\n        unary: Gelu,\n        ops_float: |tensor: TestTensor<2>| gelu(tensor)\n    );\n    clone_invariance_test!(\n        unary: Mish,\n        ops_float: |tensor: TestTensor<2>| mish(tensor)\n    );\n    clone_invariance_test!(\n        unary: Silu,\n        ops_float: |tensor: TestTensor<2>| silu(tensor)\n    );\n    clone_invariance_test!(\n        unary: Softplus,\n        ops_float: |tensor: TestTensor<2>| softplus(tensor, 1.0)\n    );\n    clone_invariance_test!(\n        unary: Tanh,\n        ops_float: |tensor: TestTensor<2>| tanh(tensor)\n    );\n\n    // Binary ops\n    clone_invariance_test!(\n        binary: Add,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.add(rhs)\n    );\n    clone_invariance_test!(\n        binary: Sub,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.sub(rhs)\n    );\n    clone_invariance_test!(\n        binary: Div,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.div(rhs)\n    );\n    clone_invariance_test!(\n        binary: Mul,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.mul(rhs)\n    );\n    clone_invariance_test!(\n        binary: Matmul,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.matmul(rhs)\n    );\n    clone_invariance_test!(\n        binary: Equal,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Greater,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater(rhs)\n    );\n    clone_invariance_test!(\n        binary: GreaterEqual,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.greater_equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Lower,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower(rhs)\n    );\n    clone_invariance_test!(\n        binary: LowerEqual,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| lhs.lower_equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Cat,\n        ops_float: |lhs: TestTensor<2>, rhs: TestTensor<2>| {\n            let lhs = lhs.reshape([1usize, 32, 32]);\n            let rhs = rhs.reshape([1usize, 32, 32]);\n\n            TestTensor::cat(vec![lhs, rhs], 0).reshape([64, 32])\n        }\n    );\n    clone_invariance_test!(\n        binary: Scatter,\n        ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {\n            let shape = tensor.shape();\n            let indices = TestTensorInt::ones(shape, &Default::default());\n            tensor.scatter(0, indices, values, IndexingUpdateOp::Add)\n        }\n    );\n    clone_invariance_test!(\n        binary: SliceAssign,\n        ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {\n            tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12]))\n        }\n    );\n    clone_invariance_test!(\n        binary: MaskWhere,\n        ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {\n            let mask = tensor.clone().greater_elem(0.5);\n            tensor.mask_where(mask, values)\n        }\n    );\n    clone_invariance_test!(\n        binary: SelectAssign,\n        ops_float: |tensor: TestTensor<2>, values: TestTensor<2>| {\n            let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());\n            let values = values.select(0, indices.clone());\n            tensor.select_assign(0, indices, values, IndexingUpdateOp::Add)\n        }\n    );\n}\n\nmod int {\n    use super::*;\n\n    // Unary ops\n    clone_invariance_test!(\n        unary: AddScalar,\n        ops_int: |tensor: TestTensorInt<2>| tensor.add_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: SubScalar,\n        ops_int: |tensor: TestTensorInt<2>| tensor.sub_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: DivScalar,\n        ops_int: |tensor: TestTensorInt<2>| tensor.div_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: MulScalar,\n        ops_int: |tensor: TestTensorInt<2>| tensor.mul_scalar(2.0)\n    );\n    clone_invariance_test!(\n        unary: Neg,\n        ops_int: |tensor: TestTensorInt<2>| tensor.neg()\n    );\n    clone_invariance_test!(\n        unary: MeanDim,\n        ops_int: |tensor: TestTensorInt<2>| tensor.mean_dim(1)\n    );\n    clone_invariance_test!(\n        unary: SumDim,\n        ops_int: |tensor: TestTensorInt<2>| tensor.sum_dim(1)\n    );\n    clone_invariance_test!(\n        unary: Sum,\n        ops_int: |tensor: TestTensorInt<2>| tensor.sum().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Mean,\n        ops_int: |tensor: TestTensorInt<2>| tensor.mean().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Clamp,\n        ops_int: |tensor: TestTensorInt<2>| tensor.clamp(-2., 2.)\n    );\n    clone_invariance_test!(\n        unary: ClampMin,\n        ops_int: |tensor: TestTensorInt<2>| tensor.clamp_min(-2.)\n    );\n    clone_invariance_test!(\n        unary: ClampMax,\n        ops_int: |tensor: TestTensorInt<2>| tensor.clamp_max(2.)\n    );\n    clone_invariance_test!(\n        unary: Abs,\n        ops_int: |tensor: TestTensorInt<2>| tensor.abs()\n    );\n    clone_invariance_test!(\n        unary: SwapDims,\n        ops_int: |tensor: TestTensorInt<2>| tensor.swap_dims(0, 1)\n    );\n    clone_invariance_test!(\n        unary: Transpose,\n        ops_int: |tensor: TestTensorInt<2>| tensor.transpose()\n    );\n    clone_invariance_test!(\n        unary: Slice,\n        ops_int: |tensor: TestTensorInt<2>| tensor.slice([0..12, 12..24])\n    );\n    clone_invariance_test!(\n        unary: EqualElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.equal_elem(25)\n    );\n    clone_invariance_test!(\n        unary: NotEqualElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.not_equal_elem(25)\n    );\n    clone_invariance_test!(\n        unary: GreaterElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.greater_elem(25)\n    );\n    clone_invariance_test!(\n        unary: GreaterEqualElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.greater_equal_elem(25)\n    );\n    clone_invariance_test!(\n        unary: LowerElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.lower_elem(25)\n    );\n    clone_invariance_test!(\n        unary: LowerEqualElem,\n        ops_int: |tensor: TestTensorInt<2>| tensor.lower_equal_elem(25)\n    );\n    clone_invariance_test!(\n        unary: Argmax,\n        ops_int: |tensor: TestTensorInt<2>| tensor.argmax(0)\n    );\n    clone_invariance_test!(\n        unary: Argmin,\n        ops_int: |tensor: TestTensorInt<2>| tensor.argmin(0)\n    );\n    clone_invariance_test!(\n        unary: Max,\n        ops_int: |tensor: TestTensorInt<2>| tensor.max().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: Min,\n        ops_int: |tensor: TestTensorInt<2>| tensor.min().unsqueeze::<2>()\n    );\n    clone_invariance_test!(\n        unary: MaxDim,\n        ops_int: |tensor: TestTensorInt<2>| tensor.max_dim(1)\n    );\n    clone_invariance_test!(\n        unary: MaxDimWithIndices,\n        ops_int: |tensor: TestTensorInt<2>| tensor.max_dim_with_indices(1).0\n    );\n    clone_invariance_test!(\n        unary: MinDimWithIndices,\n        ops_int: |tensor: TestTensorInt<2>| tensor.min_dim_with_indices(1).0\n    );\n    clone_invariance_test!(\n        unary: MinDim,\n        ops_int: |tensor: TestTensorInt<2>| tensor.min_dim(1)\n    );\n    clone_invariance_test!(\n        unary: Repeat,\n        ops_int: |tensor: TestTensorInt<2>| {\n            tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])\n        }\n    );\n    clone_invariance_test!(\n        unary: Reshape,\n        ops_int: |tensor: TestTensorInt<2>| {\n            let shape = tensor.shape();\n            let new_shape = [shape.num_elements(), 1];\n            tensor.reshape(new_shape)\n        }\n    );\n    clone_invariance_test!(\n        unary: Gatter,\n        ops_int: |tensor: TestTensorInt<2>| {\n            let shape = tensor.shape();\n            let indices = TestTensorInt::ones(shape, &Default::default());\n            tensor.gather(0, indices)\n        }\n    );\n    clone_invariance_test!(\n        unary: Select,\n        ops_int: |tensor: TestTensorInt<2>| {\n            let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());\n            tensor.select(0, indices)\n        }\n    );\n    clone_invariance_test!(\n        unary: MaskFill,\n        ops_int: |tensor: TestTensorInt<2>| {\n            let mask = tensor.clone().greater_elem(0.5);\n            tensor.mask_fill(mask, 77.0)\n        }\n    );\n\n    // Binary ops\n    clone_invariance_test!(\n        binary: Add,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.add(rhs)\n    );\n    clone_invariance_test!(\n        binary: Sub,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.sub(rhs)\n    );\n    clone_invariance_test!(\n        binary: Div,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.div(rhs)\n    );\n    clone_invariance_test!(\n        binary: Mul,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.mul(rhs)\n    );\n    clone_invariance_test!(\n        binary: Equal,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: NotEqual,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.not_equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Greater,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater(rhs)\n    );\n    clone_invariance_test!(\n        binary: GreaterEqual,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.greater_equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Lower,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower(rhs)\n    );\n    clone_invariance_test!(\n        binary: LowerEqual,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| lhs.lower_equal(rhs)\n    );\n    clone_invariance_test!(\n        binary: Cat,\n        ops_int: |lhs: TestTensorInt<2>, rhs: TestTensorInt<2>| {\n            let lhs = lhs.reshape([1usize, 32, 32]);\n            let rhs = rhs.reshape([1usize, 32, 32]);\n\n            TestTensorInt::cat(vec![lhs, rhs], 0).reshape([64, 32])\n        }\n    );\n    clone_invariance_test!(\n        binary: Scatter,\n        ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {\n            let shape = tensor.shape();\n            let indices = TestTensorInt::ones(shape, &Default::default());\n            tensor.scatter(0, indices, values, IndexingUpdateOp::Add)\n        }\n    );\n    clone_invariance_test!(\n        binary: SliceAssign,\n        ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {\n            tensor.slice_assign([0..12, 12..24], values.slice([12..24, 0..12]))\n        }\n    );\n    clone_invariance_test!(\n        binary: MaskWhere,\n        ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {\n            let mask = tensor.clone().greater_elem(0.5);\n            tensor.mask_where(mask, values)\n        }\n    );\n    clone_invariance_test!(\n        binary: SelectAssign,\n        ops_int: |tensor: TestTensorInt<2>, values: TestTensorInt<2>| {\n            let indices = TestTensorInt::from_ints([1, 2, 0, 5], &Default::default());\n            let values = values.select(0, indices.clone());\n            tensor.select_assign(0, indices, values, IndexingUpdateOp::Add)\n        }\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/celu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_celu_d2() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [-3.0, 0.5]]);\n\n    let output = activation::celu(tensor, 1.0);\n    // celu(1, 1) = 1\n    // celu(7, 1) = 7\n    // celu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213\n    // celu(0.5, 1) = 0.5\n    let expected = TensorData::from([[1.0, 7.0], [-0.950213, 0.5]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_celu_with_alpha() {\n    let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]);\n\n    let output = activation::celu(tensor, 2.0);\n    // celu(0, 2) = 0\n    // celu(-1, 2) = 2 * (exp(-0.5) - 1) = -0.786939\n    // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241\n    let expected = TensorData::from([0.0, -0.786939, -1.264241]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/elu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_elu() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::elu(tensor, 1.0);\n    // elu(1, 1) = 1, elu(7, 1) = 7, elu(13, 1) = 13\n    // elu(-3, 1) = 1 * (exp(-3) - 1) = -0.950213\n    let expected = TensorData::from([[1.0, 7.0], [13.0, -0.950213]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_elu_alpha() {\n    let tensor = TestTensor::<1>::from([0.0, -1.0, -2.0]);\n\n    let output = activation::elu(tensor, 2.0);\n    // elu(0, 2) = 2*(exp(0)-1) = 0\n    // elu(-1, 2) = 2*(exp(-1)-1) = 2*(-0.632121) = -1.264241\n    // elu(-2, 2) = 2*(exp(-2)-1) = 2*(-0.864665) = -1.729329\n    let expected = TensorData::from([0.0, -1.264241, -1.729329]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/gelu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_gelu() {\n    let tensor = TestTensor::<2>::from([[\n        0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,\n    ]]);\n    let output = activation::gelu(tensor);\n    let expected = TensorData::from([[\n        0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,\n    ]]);\n\n    // Low precision to allow approximation implementation using tanh\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_absolute(2e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/glu.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_glu_d3() {\n    let tensor = TestTensor::<3>::from([[\n        [\n            -0.5710, -1.3416, 1.9128, -0.8257, -0.1331, -1.4804, -0.6281, -0.6115,\n        ],\n        [\n            0.0267, -1.3834, 0.2752, 0.7844, -0.3549, -0.4274, 0.3290, -0.5459,\n        ],\n        [\n            -1.6347, -2.0908, 1.8801, 0.3541, 0.2237, 1.0377, 2.4850, 0.3490,\n        ],\n    ]]);\n\n    let output = activation::glu(tensor, 2);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[\n            [-0.2665, -0.2487, 0.6656, -0.2904],\n            [0.0110, -0.5461, 0.1601, 0.2877],\n            [-0.9084, -1.5439, 1.7355, 0.2077],\n        ]]),\n        Default::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/hard_sigmoid.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_hard_sigmoid() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::hard_sigmoid(tensor, 0.2, 0.5);\n    let expected = TensorData::from([[0.7, 1.0], [1.0, 0.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_hard_sigmoid_overflow() {\n    let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);\n\n    let output = activation::hard_sigmoid(tensor, 0.2, 0.5);\n    let expected = TensorData::from([1.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/leaky_relu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_leaky_relu_d2() {\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);\n\n    let output = activation::leaky_relu(tensor, 0.01);\n\n    // Account for conversion errors if `FloatType != f32`\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([[0.0, -0.01, 2.0], [3.0, -0.04, 5.0]]),\n        Tolerance::default(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/log_sigmoid.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{ElementConversion, TensorData, activation};\n\n#[test]\nfn test_log_sigmoid() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::log_sigmoid(tensor);\n    let expected = TensorData::from([[-3.132617e-1, -9.114665e-4], [-2.260327e-6, -3.0485873]]);\n\n    let tolerance = Tolerance::rel_abs(0.01, 0.0001);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_log_sigmoid_numerical_stability() {\n    let tensor = TestTensor::<1>::from([300.0, -300.0]);\n\n    let output = activation::log_sigmoid(tensor);\n\n    // For large negative values, the previous implementation −log(1 + exp(−x)) would give -inf\n    let expected = TensorData::from([0.0, -300.0]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);\n    let output = activation::log_sigmoid(tensor);\n    let expected = TensorData::from([0.elem(), FloatElem::MIN]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/mish.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_mish() {\n    let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);\n\n    let output = activation::mish(tensor);\n    let expected = TensorData::from([\n        [-0.19709, -0.30056, -0.11714],\n        [-0.24132, 0.58235, -0.08877],\n    ]);\n\n    // Metal has less precise trigonometric functions (tanh inside mish)\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-2);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/mod.rs",
    "content": "use super::*;\n\nmod celu;\nmod elu;\nmod gelu;\nmod glu;\nmod hard_sigmoid;\nmod leaky_relu;\nmod log_sigmoid;\nmod mish;\nmod prelu;\nmod quiet_softmax;\nmod relu;\nmod selu;\nmod sigmoid;\nmod silu;\nmod softmax;\nmod softmin;\nmod softplus;\nmod softsign;\nmod tanh_activation;\nmod thresholded_relu;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/prelu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_prelu_2_dimension() {\n    let data = [\n        [-1.1, 0.0, 1.2, 0.25, -5.4],\n        [-4.567, 0.56, -1.55, 99.9, 0.0],\n    ];\n    let tensor = TestTensor::<2>::from(data);\n    let output = activation::prelu(tensor, TestTensor::from([0.5, 0.25, 0.0, -0.8, -0.4]));\n    let expected = TensorData::from([\n        [-0.5500, 0.0000, 1.2000, 0.2500, 2.1600],\n        [-2.2835, 0.5600, -0.0000, 99.9000, -0.0000],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n#[test]\nfn test_prelu_2_dimension_scalar_weight() {\n    let data = [\n        [-1.1, 0.0, 1.2, 0.25, -5.4],\n        [-4.567, 0.56, -1.55, 99.9, 0.0],\n    ];\n    let tensor = TestTensor::<2>::from(data);\n    let output = activation::prelu(tensor, TestTensor::from([-0.8]));\n    let expected = TensorData::from([\n        [0.8800, -0.0000, 1.2000, 0.2500, 4.3200],\n        [3.6536, 0.5600, 1.2400, 99.9000, -0.0000],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_prelu_positives() {\n    // Check that positives are untouched\n    let data = [[\n        0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,\n    ]];\n    let tensor = TestTensor::<2>::from(data);\n    let output = activation::prelu(tensor, TestTensor::from([0.25]));\n    let expected = TensorData::from(data);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_prelu_zero_weight() {\n    // test that with weight 0 it behaves as relu\n    let data = [-1.1, 0.0, 1.2, 0.25, -5.4];\n    let tensor = TestTensor::<1>::from(data);\n    let output = activation::prelu(tensor, TestTensor::from([0.0]));\n    let expected = TensorData::from([0.0, 0.0, 1.2, 0.25, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_prelu_some_weight() {\n    // test that with some non zero weight it works like leaky relu\n    let data = [-1.1, 0.0, 1.2, 0.25, -5.4];\n    let tensor = TestTensor::<1>::from(data);\n    let output = activation::prelu(tensor, TestTensor::from([0.5]));\n    let expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[should_panic]\nfn test_prelu_single_dim_multi_weight() {\n    // should panic because the data has only 1 channel\n    let data = [-1.1, 2.0, 1.2, 0.25, -5.4];\n    let tensor = TestTensor::<1>::from(data);\n    let data_actual =\n        activation::prelu(tensor, TestTensor::from([0.5, -0.25, 0.0, 0.5, -1.0])).into_data();\n    let data_expected = TensorData::from([-0.550, 0.0, 1.20, 0.250, -2.70]);\n    data_expected.assert_approx_eq::<FloatElem>(&data_actual, Tolerance::default());\n}\n\n#[test]\n#[should_panic]\nfn test_prelu_multi_dim_wrong_weights() {\n    let data = [\n        [-1.1, 0.0, 1.2, 0.25, -5.4],\n        [-4.567, 0.56, -1.55, 99.9, 0.0],\n    ];\n    let tensor = TestTensor::<2>::from(data);\n    let _ = activation::prelu(tensor, TestTensor::from([-0.8, 0.1]));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/quiet_softmax.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_quiet_softmax_d2() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::quiet_softmax(tensor, 1);\n    let expected = TensorData::from([[2.47e-03, 9.975e-01], [1.0, 1.1254e-07]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/relu.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_relu_d2() {\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, -4.0, 5.0]]);\n\n    let output = activation::relu(tensor);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 5.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/selu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_selu() {\n    // selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0\n    // alpha = 1.6733, gamma = 1.0507\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, -1.0], [2.0, -2.0, 0.5]]);\n\n    let output = activation::selu(tensor);\n\n    // Expected values computed from the formula:\n    // selu(0.0)  = 1.0507 * 1.6733 * (exp(0) - 1) = 0.0\n    // selu(1.0)  = 1.0507 * 1.0 = 1.0507\n    // selu(-1.0) = 1.0507 * 1.6733 * (exp(-1) - 1) = 1.7581 * (0.3679 - 1) = -1.1113\n    // selu(2.0)  = 1.0507 * 2.0 = 2.1014\n    // selu(-2.0) = 1.0507 * 1.6733 * (exp(-2) - 1) = 1.7581 * (0.1353 - 1) = -1.5202\n    // selu(0.5)  = 1.0507 * 0.5 = 0.5254\n    let expected = TensorData::from([[0.0, 1.0507, -1.1113], [2.1014, -1.5202, 0.5254]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_selu_zero() {\n    let tensor = TestTensor::<1>::from([0.0]);\n\n    let output = activation::selu(tensor);\n    let expected = TensorData::from([0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/sigmoid.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_sigmoid() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::sigmoid(tensor);\n    let expected = TensorData::from([[0.731059, 0.999089], [0.999998, 0.047426]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_sigmoid_overflow() {\n    let tensor = TestTensor::<1>::from([FloatElem::MAX, FloatElem::MIN]);\n\n    let output = activation::sigmoid(tensor);\n    let expected = TensorData::from([1.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/silu.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_silu() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let output = activation::silu(tensor);\n    let expected = TensorData::from([[0.73106, 1.76159], [2.85772, 3.92806]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/softmax.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_softmax_d2() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::softmax(tensor, 1);\n    let expected = TensorData::from([[2.472623e-03, 9.975274e-01], [1.0, 1.125352e-07]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/softmin.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_softmin_d2() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::softmin(tensor, 1);\n    let expected = TensorData::from([[9.975274e-01, 2.472623e-03], [1.125352e-07, 1.0000]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/softplus.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_softplus_d2() {\n    let tensor = TestTensor::<2>::from([[-0.4240, -0.9574, -0.2215], [-0.5767, 0.7218, -0.1620]]);\n\n    let output = activation::softplus(tensor.clone(), 1.0);\n    let expected = TensorData::from([\n        [0.503453, 0.324898, 0.588517],\n        [0.445806, 1.117805, 0.615424],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let output = activation::softplus(tensor, 2.0);\n    let expected = TensorData::from([\n        [0.178232, 0.068737, 0.247990],\n        [0.137132, 0.827771, 0.272106],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/softsign.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_softsign() {\n    let tensor = TestTensor::<2>::from([[1.0, 7.0], [13.0, -3.0]]);\n\n    let output = activation::softsign(tensor);\n    let expected = TensorData::from([[0.5, 0.875], [0.928571, -0.75]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_softsign_zero() {\n    let tensor = TestTensor::<1>::from([0.0]);\n\n    let output = activation::softsign(tensor);\n    let expected = TensorData::from([0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/tanh_activation.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_tanh() {\n    let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]);\n\n    let output = activation::tanh(tensor);\n    let expected = TensorData::from([[0.761594, 0.964028], [0.995055, 0.999329]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/activation/thresholded_relu.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, activation};\n\n#[test]\nfn test_thresholded_relu_d2() {\n    // alpha = 1.0 (ONNX default): x if x > 1.0, else 0\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 1.0, 0.5]]);\n\n    let output = activation::thresholded_relu(tensor, 1.0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.0]]), false);\n}\n\n#[test]\nfn test_thresholded_relu_d2_alpha() {\n    // alpha = 0.5: x if x > 0.5, else 0\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 0.5, 0.6]]);\n\n    let output = activation::thresholded_relu(tensor, 0.5);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 0.0, 2.0], [3.0, 0.0, 0.6]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/grid/affine_grid.rs",
    "content": "use super::*;\nuse burn_tensor::grid::affine_grid_2d;\n\nfn create_identity_transform(batch_size: usize) -> TestTensor<3> {\n    // Identity affine transform (batch_size, 2, 3)\n    TestTensor::<3>::from([[[1f32, 0., 0.], [0., 1., 0.]]]).expand([batch_size, 2, 3])\n}\n\n#[test]\nfn test_affine_grid_identity() {\n    let batch_size = 1;\n    let channels = 1;\n    let height = 2;\n    let width = 2;\n\n    let transform = create_identity_transform(batch_size);\n\n    let output = affine_grid_2d(transform, [batch_size, channels, height, width]);\n\n    // Expected normalized coords:\n    // [-1, -1], [ 1,-1]\n    // [-1,  1], [ 1, 1]\n    let expected = TestTensor::<4>::from([[\n        [[-1f32, -1f32], [1f32, -1f32]],\n        [[-1f32, 1f32], [1f32, 1f32]],\n    ]]);\n\n    output.into_data().assert_eq(&expected.into_data(), false);\n}\n\n#[test]\nfn test_affine_grid_scaling() {\n    let batch_size = 1;\n    let channels = 1;\n    let height = 2;\n    let width = 2;\n\n    let scale = 2.0f32;\n    let transform = TestTensor::<3>::from([[[scale, 0., 0.], [0., scale, 0.]]]);\n\n    let output = affine_grid_2d(transform, [batch_size, channels, height, width]);\n\n    // Expect scaled coordinates from normalized grid, so coords * 2\n    let expected = TestTensor::<4>::from([[\n        [[-2f32, -2f32], [2f32, -2f32]],\n        [[-2f32, 2f32], [2f32, 2f32]],\n    ]]);\n\n    output.into_data().assert_eq(&expected.into_data(), false);\n}\n\n#[test]\nfn test_affine_grid_translation() {\n    let batch_size = 1;\n    let channels = 1;\n    let height = 2;\n    let width = 2;\n\n    // Translate by 0.5 in x and -0.5 in y (normalized coords)\n    let tx = 0.5f32;\n    let ty = -0.5f32;\n\n    let transform = TestTensor::<3>::from([[[1.0, 0.0, tx], [0.0, 1.0, ty]]]);\n\n    let output = affine_grid_2d(transform, [batch_size, channels, height, width]);\n\n    // Expected coordinates:\n    // Original normalized coords are [-1,1] in x and y\n    // After translation, each coordinate shifts by tx and ty\n    // So points become:\n    // [-1 + 0.5, -1 - 0.5] = [-0.5, -1.5]\n    // [ 1 + 0.5, -1 - 0.5] = [1.5, -1.5]\n    // [-1 + 0.5,  1 - 0.5] = [-0.5, 0.5]\n    // [ 1 + 0.5,  1 - 0.5] = [1.5, 0.5]\n\n    let expected = TestTensor::<4>::from([[\n        [[-0.5f32, -1.5f32], [1.5f32, -1.5f32]],\n        [[-0.5f32, 0.5f32], [1.5f32, 0.5f32]],\n    ]]);\n\n    output.into_data().assert_eq(&expected.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/grid/meshgrid.rs",
    "content": "use super::*;\nuse burn_tensor::BasicOps;\nuse burn_tensor::Tensor;\nuse burn_tensor::TensorData;\nuse burn_tensor::backend::Backend;\nuse burn_tensor::grid::{\n    GridIndexing, GridOptions, GridSparsity, IndexPos, meshgrid, meshgrid_stack,\n};\n\nfn assert_tensors_equal<const N: usize, B: Backend, K>(\n    actual: &[Tensor<B, N, K>; N],\n    expected: &[Tensor<B, N, K>; N],\n) where\n    K: BasicOps<B>,\n{\n    for (a, e) in actual.iter().zip(expected.iter()) {\n        a.clone()\n            .into_data()\n            .assert_eq(&e.clone().into_data(), true);\n    }\n}\n\n#[test]\nfn test_meshgrid() {\n    let x = TestTensor::<1>::from([1, 2, 3, 4]);\n    let y = TestTensor::<1>::from([5, 6]);\n    let z = TestTensor::<1>::from([7, 8]);\n\n    let grid_shape = [x.dims()[0], y.dims()[0], z.dims()[0]];\n\n    // 3D, Dense, Matrix\n    assert_tensors_equal(\n        &meshgrid(&[x.clone(), y.clone(), z.clone()], GridOptions::default()),\n        &[\n            x.clone().reshape([4, 1, 1]).expand(grid_shape),\n            y.clone().reshape([1, 2, 1]).expand(grid_shape),\n            z.clone().reshape([1, 1, 2]).expand(grid_shape),\n        ],\n    );\n    assert_tensors_equal(\n        &meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Dense),\n        &[\n            x.clone().reshape([4, 1, 1]).expand(grid_shape),\n            y.clone().reshape([1, 2, 1]).expand(grid_shape),\n            z.clone().reshape([1, 1, 2]).expand(grid_shape),\n        ],\n    );\n    assert_tensors_equal(\n        &meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Matrix),\n        &[\n            x.clone().reshape([4, 1, 1]).expand(grid_shape),\n            y.clone().reshape([1, 2, 1]).expand(grid_shape),\n            z.clone().reshape([1, 1, 2]).expand(grid_shape),\n        ],\n    );\n\n    // 3D, Sparse, Matrix\n    assert_tensors_equal(\n        &meshgrid(\n            &[x.clone(), y.clone(), z.clone()],\n            GridOptions {\n                indexing: GridIndexing::Matrix,\n                sparsity: GridSparsity::Sparse,\n            },\n        ),\n        &[\n            x.clone().reshape([4, 1, 1]),\n            y.clone().reshape([1, 2, 1]),\n            z.clone().reshape([1, 1, 2]),\n        ],\n    );\n    assert_tensors_equal(\n        &meshgrid(&[x.clone(), y.clone(), z.clone()], GridSparsity::Sparse),\n        &[\n            x.clone().reshape([4, 1, 1]),\n            y.clone().reshape([1, 2, 1]),\n            z.clone().reshape([1, 1, 2]),\n        ],\n    );\n\n    // 3D, Dense, Cartesian\n    assert_tensors_equal(\n        &meshgrid(&[x.clone(), y.clone(), z.clone()], GridIndexing::Cartesian),\n        &[\n            x.clone()\n                .reshape([4, 1, 1])\n                .expand(grid_shape)\n                .swap_dims(0, 1),\n            y.clone()\n                .reshape([1, 2, 1])\n                .expand(grid_shape)\n                .swap_dims(0, 1),\n            z.clone()\n                .reshape([1, 1, 2])\n                .expand(grid_shape)\n                .swap_dims(0, 1),\n        ],\n    );\n\n    // 3D, Sparse, Cartesian\n    assert_tensors_equal(\n        &meshgrid(\n            &[x.clone(), y.clone(), z.clone()],\n            GridOptions::new(GridIndexing::Cartesian, GridSparsity::Sparse),\n        ),\n        &[\n            x.clone().reshape([4, 1, 1]).swap_dims(0, 1),\n            y.clone().reshape([1, 2, 1]).swap_dims(0, 1),\n            z.clone().reshape([1, 1, 2]).swap_dims(0, 1),\n        ],\n    );\n    assert_tensors_equal(\n        &meshgrid(\n            &[x.clone(), y.clone(), z.clone()],\n            GridOptions {\n                indexing: GridIndexing::Cartesian,\n                sparsity: GridSparsity::Sparse,\n            },\n        ),\n        &[\n            x.clone().reshape([4, 1, 1]).swap_dims(0, 1),\n            y.clone().reshape([1, 2, 1]).swap_dims(0, 1),\n            z.clone().reshape([1, 1, 2]).swap_dims(0, 1),\n        ],\n    );\n}\n\n#[test]\nfn test_meshgrid_stack() {\n    let tensors = [\n        TestTensor::from([0.5, 1.0, 2.5]),\n        TestTensor::from([0.5, 1.0]),\n    ];\n\n    let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::First);\n    result.to_data().assert_eq(\n        &TensorData::from([\n            [[0.5, 0.5], [1.0, 1.0], [2.5, 2.5]],\n            [[0.5, 1.0], [0.5, 1.0], [0.5, 1.0]],\n        ]),\n        false,\n    );\n\n    let result: Tensor<_, 3> = meshgrid_stack(&tensors, IndexPos::Last);\n    result.to_data().assert_eq(\n        &TensorData::from([\n            [[0.5, 0.5], [0.5, 1.0]],\n            [[1.0, 0.5], [1.0, 1.0]],\n            [[2.5, 0.5], [2.5, 1.0]],\n        ]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/grid/mod.rs",
    "content": "use super::*;\n\npub(crate) mod affine_grid;\npub(crate) mod meshgrid;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/cosine_similarity.rs",
    "content": "use super::*;\nuse burn_tensor::{ElementConversion, Tolerance};\nuse burn_tensor::{TensorData, linalg};\n\n#[test]\nfn test_cosine_similarity_basic() {\n    // Create test tensors\n    let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [0.5, 1.5, 2.5]]);\n    let x2 = TestTensor::<2>::from([[1.5, 2.5, 3.5], [0.7, 1.7, 2.7]]);\n\n    // Test cosine similarity along dimension 1\n    let expected = TensorData::from([[0.99983203], [0.99987257]]);\n    linalg::cosine_similarity(x1.clone(), x2.clone(), 1, None)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    // Test with explicit epsilon\n    linalg::cosine_similarity(x1.clone(), x2.clone(), 1, Some(1e-8.elem::<FloatElem>()))\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_cosine_similarity_orthogonal() {\n    // Create orthogonal vectors\n    let x1 = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);\n    let x2 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);\n\n    // Orthogonal vectors should have cosine similarity of 0\n    let expected = TensorData::from([[0.0], [0.0]]);\n    linalg::cosine_similarity(x1, x2, 1, None)\n        .into_data()\n        .assert_eq(&expected, false);\n}\n\n#[test]\nfn test_cosine_similarity_parallel() {\n    // Create parallel vectors\n    let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n    let x2 = TestTensor::<2>::from([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]]);\n\n    // Parallel vectors should have cosine similarity of 1\n    let expected = TensorData::from([[1.0], [1.0]]);\n    linalg::cosine_similarity(x1, x2, 1, None)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_cosine_similarity_opposite() {\n    // Create opposite direction vectors\n    let x1 = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n    let x2 = TestTensor::<2>::from([[-1.0, -2.0, -3.0], [-4.0, -5.0, -6.0]]);\n\n    // Opposite vectors should have cosine similarity of -1\n    let expected = TensorData::from([[-1.0], [-1.0]]);\n    linalg::cosine_similarity(x1, x2, 1, None)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_cosine_similarity_different_dimension() {\n    // Test with a 3D tensor\n    let x1 = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);\n    let x2 = TestTensor::<3>::from([[[2.0, 3.0], [4.0, 5.0]], [[6.0, 7.0], [8.0, 9.0]]]);\n\n    // Test along dimension 2\n    let expected = TensorData::from([[[0.9959688], [0.9958376]], [[0.9955946], [0.9955169]]]);\n\n    // sensitive to rounding in dot/norm; loosen f16 tolerance\n    let tolerance = Tolerance::default().set_half_precision_relative(7e-3);\n\n    linalg::cosine_similarity(x1.clone(), x2.clone(), 2, None)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    // Test with negative dimension (-1 is the last dimension, which is 2 in this case)\n    linalg::cosine_similarity(x1.clone(), x2.clone(), -1, None)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_cosine_similarity_near_zero() {\n    // Test with near-zero vectors\n    let x1 = TestTensor::<2>::from([[1e-10, 2e-10, 3e-10], [4e-10, 5e-10, 6e-10]]);\n    let x2 = TestTensor::<2>::from([[2e-10, 4e-10, 6e-10], [8e-10, 10e-10, 12e-10]]);\n\n    // Update the expected values based on the actual implementation behavior\n    let expected = TensorData::from([[0.0028], [0.0154]]);\n\n    // Smaller values result in NaN on metal f16\n    let epsilon = Some(FloatElem::from_elem(1e-2));\n    let tolerance = Tolerance::absolute(0.2);\n\n    linalg::cosine_similarity(x1, x2, 1, epsilon)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/diag.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, linalg::diag};\n\n#[test]\nfn test_diag_2d_square() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    let expected = TensorData::from([1.0, 4.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_2d_tall() {\n    let device = Default::default();\n    // 4x2 matrix (tall) - min(4,2) = 2 diagonal elements\n    let tensor =\n        TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Result should have shape [2] with values [1.0, 4.0]\n    let expected = TensorData::from([1.0, 4.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_2d_wide() {\n    let device = Default::default();\n    // 2x4 matrix (wide) - min(2,4) = 2 diagonal elements\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Result should have shape [2] with values [1.0, 6.0]\n    let expected = TensorData::from([1.0, 6.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_3d_batch_square() {\n    let device = Default::default();\n    // Batch of 2 matrices, each 2x2\n    let tensor = TestTensor::<3>::from_data(\n        [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],\n        &device,\n    );\n    let result = diag::<_, 3, 2, _>(tensor);\n    // Result should have shape [2, 2]\n    let expected = TensorData::from([[1.0, 4.0], [5.0, 8.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_3d_batch_tall() {\n    let device = Default::default();\n    // Batch of 2 matrices, each 3x2 (tall)\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],\n            [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],\n        ],\n        &device,\n    );\n    let result = diag::<_, 3, 2, _>(tensor);\n    // Result should have shape [2, 2] - min(3,2) = 2 diagonal elements each\n    let expected = TensorData::from([[1.0, 4.0], [7.0, 10.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_3d_batch_wide() {\n    let device = Default::default();\n    // Batch of 2 matrices, each 2x3 (wide)\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],\n            [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],\n        ],\n        &device,\n    );\n    let result = diag::<_, 3, 2, _>(tensor);\n    // Result should have shape [2, 2] - min(2,3) = 2 diagonal elements each\n    let expected = TensorData::from([[1.0, 5.0], [7.0, 11.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_4d_batch_channel_square() {\n    let device = Default::default();\n    // [batch=2, channel=2, rows=2, cols=2]\n    let tensor = TestTensor::<4>::from_data(\n        [\n            [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],\n            [[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]],\n        ],\n        &device,\n    );\n    let result = diag::<_, 4, 3, _>(tensor);\n    // Result should have shape [2, 2, 2]\n    let expected = TensorData::from([[[1.0, 4.0], [5.0, 8.0]], [[9.0, 12.0], [13.0, 16.0]]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_4d_batch_channel_tall() {\n    let device = Default::default();\n    // [batch=2, channel=1, rows=3, cols=2]\n    let tensor = TestTensor::<4>::from_data(\n        [\n            [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]],\n            [[[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]],\n        ],\n        &device,\n    );\n    let result = diag::<_, 4, 3, _>(tensor);\n    // Result should have shape [2, 1, 2] - min(3,2) = 2 diagonal elements each\n    let expected = TensorData::from([[[1.0, 4.0]], [[7.0, 10.0]]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_4d_batch_channel_wide() {\n    let device = Default::default();\n    // [batch=1, channel=2, rows=2, cols=4]\n    let tensor = TestTensor::<4>::from_data(\n        [[\n            [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],\n            [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],\n        ]],\n        &device,\n    );\n    let result = diag::<_, 4, 3, _>(tensor);\n    // Result should have shape [1, 2, 2] - min(2,4) = 2 diagonal elements each\n    let expected = TensorData::from([[[1.0, 6.0], [9.0, 14.0]]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_1x1() {\n    let device = Default::default();\n    // Single element matrix\n    let tensor = TestTensor::<2>::from_data([[5.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Should return [5.0] with shape [1]\n    let expected = TensorData::from([5.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_single_row() {\n    let device = Default::default();\n    // Single row matrix\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // min(1,3) = 1, should return [1.0] with shape [1]\n    let expected = TensorData::from([1.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_single_column() {\n    let device = Default::default();\n    // Single column matrix\n    let tensor = TestTensor::<2>::from_data([[1.0], [2.0], [3.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // min(3,1) = 1, should return [1.0] with shape [1]\n    let expected = TensorData::from([1.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_zeros() {\n    let device = Default::default();\n    // Matrix with zeros on diagonal\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 0.0]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Should extract diagonal: [0.0, 0.0]\n    let expected = TensorData::from([0.0, 0.0]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_batch_single_element() {\n    let device = Default::default();\n    // Batch with single element matrices\n    let tensor = TestTensor::<3>::from_data([[[5.0]], [[7.0]]], &device);\n    let result = diag::<_, 3, 2, _>(tensor);\n    // Should return [[5.0], [7.0]] with shape [2, 1]\n    let expected = TensorData::from([[5.0], [7.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_batch_mixed_zeros() {\n    let device = Default::default();\n    // Batch with mixed zero and non-zero diagonal elements\n    let tensor = TestTensor::<3>::from_data(\n        [[[1.0, 2.0], [3.0, 0.0]], [[0.0, 5.0], [6.0, 7.0]]],\n        &device,\n    );\n    let result = diag::<_, 3, 2, _>(tensor);\n    // Should return [[1.0, 0.0], [0.0, 7.0]] with shape [2, 2]\n    let expected = TensorData::from([[1.0, 0.0], [0.0, 7.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_int_tensor() {\n    let device = Default::default();\n    // Test with integer tensor\n    let tensor = TestTensorInt::<2>::from_data([[1, 2], [3, 4]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Result should have shape [2] with values [1, 4]\n    let expected = TensorData::from([1, 4]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_diag_int_3x3() {\n    let device = Default::default();\n    // Test with 3x3 integer matrix\n    let tensor = TestTensorInt::<2>::from_data([[1, 2, 3], [4, 5, 6], [7, 8, 9]], &device);\n    let result = diag::<_, 2, 1, _>(tensor);\n    // Result should have shape [3] with values [1, 5, 9]\n    let expected = TensorData::from([1, 5, 9]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn test_diag_1d_should_panic() {\n    let device = Default::default();\n    // 1D tensor should panic - diagonal requires at least 2 dimensions\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);\n    let _result = diag::<_, 1, 0, _>(tensor);\n}\n\n#[test]\n#[should_panic]\nfn test_diag_wrong_output_rank_should_panic() {\n    let device = Default::default();\n    // Providing wrong output rank should panic\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let _result = diag::<_, 2, 2, _>(tensor); // Should be 2,1 not 2,2\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/lu_decomposition.rs",
    "content": "use super::*;\nuse burn_tensor::{\n    Distribution, Shape, TensorData, Tolerance, cast::ToElement, linalg::lu_decomposition, s,\n};\n\n#[test]\nfn test_lu_2x2_decomposition() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[4.0, 3.0], [6.0, 3.0]], &device);\n    let (result, _permutations) = lu_decomposition(tensor);\n    let expected = TensorData::from([[6.0, 3.0], [2.0 / 3.0, 1.0]]);\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_lu_3x3_decomposition() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [[0.0, 5.0, 22.0 / 3.0], [4.0, 2.0, 1.0], [2.0, 7.0, 9.0]],\n        &device,\n    );\n    let (result, permutations) = lu_decomposition(tensor);\n    let expected = TestTensor::<2>::from_data(\n        [\n            [4.0, 2.0, 1.0],\n            [0.5, 6.0, 8.5],\n            [0.0, 0.8333333, 0.25000048],\n        ],\n        &device,\n    );\n    let expected_permutations = TensorData::from([1, 2, 0]);\n    permutations\n        .into_data()\n        .assert_eq(&expected_permutations, false);\n\n    let tolerance = Tolerance::default().set_half_precision_absolute(5e-3);\n    result\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);\n}\n\n#[test]\n#[should_panic]\nfn test_lu_singular_matrix() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [2.0, 4.0]], &device);\n    let _result = lu_decomposition(tensor);\n}\n\n#[test]\n#[should_panic]\nfn test_lu_non_square_matrix() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let _result = lu_decomposition(tensor);\n}\n\n#[test]\nfn test_lu_1x1_element_matrix() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[5.0]], &device);\n    let (result, _permutations) = lu_decomposition(tensor);\n    let expected = TensorData::from([[5.0]]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_lu_identity_matrix() {\n    let device = Default::default();\n\n    let tensor = TestTensor::<2>::eye(4, &device);\n    let (result, _permutations) = lu_decomposition(tensor);\n    let expected = TestTensor::<2>::eye(4, &device);\n    result.into_data().assert_eq(&expected.into_data(), true);\n}\n\n#[test]\nfn test_lu_50x50_random_matrix() {\n    let device = Default::default();\n    let size = 50;\n    let distribution = Distribution::Uniform(0.0, 1.0);\n    let tensor = TestTensor::<2>::random(Shape::new([size, size]), distribution, &device);\n    let (result, permutations) = lu_decomposition(tensor.clone());\n    // Reconstruct the original matrix from L and U\n    let mut l = TestTensor::<2>::eye(size, &device);\n    let mut u = TestTensor::<2>::zeros(Shape::new([size, size]), &device);\n\n    for i in 0..size {\n        for j in 0..size {\n            if i > j {\n                l = l.slice_assign(s![i, j], result.clone().slice(s![i, j]));\n            } else {\n                u = u.slice_assign(s![i, j], result.clone().slice(s![i, j]));\n            }\n        }\n    }\n    // Construct the permutation matrix P from the permutation vector\n    let mut p = TestTensor::<2>::zeros(Shape::new([size, size]), &device);\n    for i in 0..size {\n        let perm_index = permutations.clone().slice(s![i]).into_scalar().to_usize();\n        p = p.slice_assign(\n            s![perm_index, i],\n            TestTensor::<2>::from_data([[1.0]], &device),\n        );\n    }\n\n    // Verify that P * L * U reconstructs the original matrix\n    let reconstructed = p.matmul(l).matmul(u);\n    reconstructed\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&tensor.into_data(), Tolerance::permissive());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/matvec.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance, linalg};\n\n#[test]\nfn test_matvec_basic_float() {\n    let device = Default::default();\n    let matrix = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);\n    let vector = TestTensor::<1>::from_floats([5.0, 6.0], &device);\n\n    let result = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);\n    let expected = TensorData::from([17.0, 39.0]);\n\n    result\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_matvec_basic_int() {\n    let device = Default::default();\n    let matrix = TestTensorInt::<2>::from_ints([[2, 0, -1], [1, 3, 2]], &device);\n    let vector = TestTensorInt::<1>::from_ints([3, -2, 4], &device);\n\n    let result = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);\n    let expected = TensorData::from([2, 5]);\n\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_matvec_batched() {\n    let device = Default::default();\n    let matrix = TestTensor::<3>::from_floats(\n        [\n            [[1.0, 0.0, 2.0], [3.0, 1.0, -1.0]],\n            [[-2.0, 1.0, 0.0], [0.5, -1.5, 2.0]],\n        ],\n        &device,\n    );\n    let vector = TestTensor::<2>::from_floats([[1.0, -1.0, 0.5], [2.0, 0.0, -1.0]], &device);\n\n    let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);\n    let expected = TensorData::from([[2.0, 1.5], [-4.0, -1.0]]);\n\n    result\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_matvec_vector_broadcasts_over_batches() {\n    let device = Default::default();\n    let matrix = TestTensor::<3>::from_floats(\n        [\n            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],\n            [[-1.0, 0.0, 2.0], [3.0, 1.0, -2.0]],\n        ],\n        &device,\n    );\n    let vector = TestTensor::<2>::from_floats([[1.0, 0.0, -1.0]], &device);\n\n    let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);\n    let expected = TensorData::from([[-2.0, -2.0], [-3.0, 5.0]]);\n\n    result\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_matvec_matrix_broadcasts_over_vector_batches() {\n    let device = Default::default();\n    let matrix = TestTensor::<3>::from_floats([[[1.0, 0.0, 2.0], [3.0, -1.0, 1.0]]], &device);\n    let vector = TestTensor::<2>::from_floats([[2.0, 1.0, 0.0], [1.0, -1.0, 3.0]], &device);\n\n    let result = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);\n    let expected = TensorData::from([[2.0, 5.0], [7.0, 7.0]]);\n\n    result\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[should_panic]\nfn test_matvec_invalid_inner_dim_panics() {\n    let device = Default::default();\n    let matrix = TestTensor::<2>::zeros([2, 3], &device);\n    let vector = TestTensor::<1>::zeros([4], &device);\n\n    let _ = linalg::matvec::<TestBackend, 2, 1, _>(matrix, vector);\n}\n\n#[test]\n#[should_panic]\nfn test_matvec_mismatched_batches_panics() {\n    let device = Default::default();\n    let matrix = TestTensor::<3>::zeros([2, 3, 4], &device);\n    let vector = TestTensor::<2>::zeros([3, 4], &device);\n\n    let _ = linalg::matvec::<TestBackend, 3, 2, _>(matrix, vector);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/mod.rs",
    "content": "use super::*;\n\npub(crate) mod cosine_similarity;\npub(crate) mod diag;\npub(crate) mod lu_decomposition;\npub(crate) mod matvec;\npub(crate) mod outer;\npub(crate) mod trace;\npub(crate) mod vector_norm;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/outer.rs",
    "content": "use super::*;\nuse burn_tensor::{ElementConversion, Tolerance};\nuse burn_tensor::{TensorData, linalg};\n\n// ---------- Vector (D=1, R=2) tests ----------\n\n#[test]\nfn test_outer_basic() {\n    let u = TestTensor::<1>::from([1.0, 2.0, 3.0]);\n    let v = TestTensor::<1>::from([4.0, 5.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::from([[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_shapes_only() {\n    let device = Default::default();\n    let u = TestTensor::<1>::zeros([3], &device);\n    let v = TestTensor::<1>::zeros([5], &device);\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v);\n    assert_eq!(out.shape().dims(), [3, 5]);\n}\n\n#[test]\nfn test_outer_asymmetry_and_shapes() {\n    let u = TestTensor::<1>::from([1.0, 2.0]);\n    let v = TestTensor::<1>::from([3.0, 4.0, 5.0]);\n\n    let uv = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone());\n    let vu = linalg::outer::<TestBackend, 1, 2, _>(v, u);\n\n    assert_eq!(uv.shape().dims(), [2, 3]);\n    assert_eq!(vu.shape().dims(), [3, 2]);\n}\n\n#[test]\nfn test_outer_zero_left() {\n    let device = Default::default();\n    let u = TestTensor::<1>::zeros([3], &device);\n    let v = TestTensor::<1>::from([7.0, 8.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::zeros::<FloatElem, _>([3, 2]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_zero_right() {\n    let device = Default::default();\n    let u = TestTensor::<1>::from([1.0, -2.0, 3.0]);\n    let v = TestTensor::<1>::zeros([4], &device);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::zeros::<FloatElem, _>([3, 4]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_signs() {\n    let u = TestTensor::<1>::from([-1.0, 2.0]);\n    let v = TestTensor::<1>::from([3.0, -4.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::from([[-3.0, 4.0], [6.0, -8.0]]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_integer_inputs() {\n    let u = TestTensorInt::<1>::from([1, 2, 3]);\n    let v = TestTensorInt::<1>::from([4, 5]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::from([[4, 5], [8, 10], [12, 15]]);\n\n    out.assert_eq(&expected, false);\n}\n\n#[test]\nfn test_outer_equivalence_to_matmul() {\n    let u = TestTensor::<1>::from([1.0, 2.0, 3.0]);\n    let v = TestTensor::<1>::from([4.0, 5.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone()).into_data();\n\n    let u2 = u.reshape([3, 1]);\n    let v2 = v.reshape([1, 2]);\n    let out_matmul = u2.matmul(v2).into_data();\n\n    out.assert_approx_eq::<FloatElem>(&out_matmul, Tolerance::default());\n}\n\n#[test]\nfn test_outer_vector_identity_right_mult() {\n    let u = TestTensor::<1>::from([2.0, -1.0]);\n    let v = TestTensor::<1>::from([3.0, 4.0]);\n    let w = TestTensor::<1>::from([5.0, 6.0]);\n\n    let uv = linalg::outer::<TestBackend, 1, 2, _>(u.clone(), v.clone());\n    let left = uv.matmul(w.clone().reshape([2, 1])).reshape([2]);\n\n    let v_dot_w = v.dot(w);\n    let right = u * v_dot_w;\n\n    left.into_data()\n        .assert_approx_eq::<FloatElem>(&right.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_outer_length_one_vectors() {\n    let u = TestTensor::<1>::from([3.0]);\n    let v = TestTensor::<1>::from([4.0, 5.0, 6.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::from([[12.0, 15.0, 18.0]]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_large_values() {\n    let big = 1.0e10;\n    let u = TestTensor::<1>::from([big, -big]);\n    let v = TestTensor::<1>::from([big, big]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n    let expected = TensorData::from([[big * big, big * big], [-big * big, -big * big]]);\n\n    let tol = Tolerance::relative(1e-6).set_half_precision_relative(1e-3);\n    out.assert_approx_eq::<FloatElem>(&expected, tol);\n}\n\n#[test]\nfn test_outer_nan_propagation() {\n    let u = TestTensor::<1>::from([f32::NAN, 2.0]);\n    let v = TestTensor::<1>::from([3.0, 4.0]);\n\n    let out = linalg::outer::<TestBackend, 1, 2, _>(u, v).into_data();\n\n    let s: &[FloatElem] = out\n        .as_slice::<FloatElem>()\n        .expect(\"outer nan_propagation: as_slice failed\");\n\n    assert!(s[0].is_nan());\n    assert!(s[1].is_nan());\n    assert_eq!(s[2], 6.0f32.elem::<FloatElem>());\n    assert_eq!(s[3], 8.0f32.elem::<FloatElem>());\n}\n\n// ---------- Batched (D=2, R=3) tests ----------\n\n#[test]\nfn test_outer_batched_basic() {\n    let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n    let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]);\n    let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();\n\n    let expected = TensorData::from([\n        [[5.0, 6.0, 7.0], [10.0, 12.0, 14.0]],\n        [[24.0, 27.0, 30.0], [32.0, 36.0, 40.0]],\n    ]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_batched_shapes() {\n    let device = Default::default();\n    let x = TestTensor::<2>::zeros([3, 4], &device);\n    let y = TestTensor::<2>::zeros([3, 5], &device);\n    let out = linalg::outer::<TestBackend, 2, 3, _>(x, y);\n    assert_eq!(out.shape().dims(), [3, 4, 5]);\n}\n\n#[test]\nfn test_outer_batched_zero_left() {\n    let device = Default::default();\n    let x = TestTensor::<2>::zeros([2, 3], &device);\n    let y = TestTensor::<2>::from([[7.0, 8.0], [9.0, 10.0]]);\n    let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();\n\n    let expected = TensorData::zeros::<FloatElem, _>([2, 3, 2]);\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_batched_zero_right() {\n    let device = Default::default();\n    let x = TestTensor::<2>::from([[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]]);\n    let y = TestTensor::<2>::zeros([2, 4], &device);\n    let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();\n\n    let expected = TensorData::zeros::<FloatElem, _>([2, 3, 4]);\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_batched_signs() {\n    let x = TestTensor::<2>::from([[-1.0, 2.0], [3.0, -4.0]]);\n    let y = TestTensor::<2>::from([[3.0, -4.0], [-5.0, 6.0]]);\n    let out = linalg::outer::<TestBackend, 2, 3, _>(x, y).into_data();\n\n    let expected = TensorData::from([[[-3.0, 4.0], [6.0, -8.0]], [[-15.0, 18.0], [20.0, -24.0]]]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_outer_batched_equivalence_to_per_sample_outer() {\n    let x = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n    let y = TestTensor::<2>::from([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]);\n    let batched = linalg::outer::<TestBackend, 2, 3, _>(x.clone(), y.clone());\n\n    for b in 0..2 {\n        let idx = TestTensorInt::<1>::from([b]);\n\n        let xb2d = x.clone().select(0, idx.clone()); // (1, m)\n        let yb2d = y.clone().select(0, idx); // (1, n)\n\n        let dims_x: [usize; 2] = xb2d.shape().dims();\n        let dims_y: [usize; 2] = yb2d.shape().dims();\n        let (m, n) = (dims_x[1], dims_y[1]);\n\n        let per = linalg::outer::<TestBackend, 1, 2, _>(xb2d.reshape([m]), yb2d.reshape([n]));\n\n        let bat3d = batched.clone().select(0, TestTensorInt::<1>::from([b])); // (m, n)\n\n        let per_len = per.shape().num_elements();\n        let per_flat = per.reshape([per_len]).into_data();\n\n        let bat_len = bat3d.shape().num_elements();\n        let bat_flat = bat3d.reshape([bat_len]).into_data();\n\n        bat_flat.assert_approx_eq::<FloatElem>(&per_flat, Tolerance::default());\n    }\n}\n\n#[test]\n#[should_panic]\nfn test_outer_batched_mismatched_batches_panics() {\n    let device = Default::default();\n    let x = TestTensor::<2>::zeros([2, 3], &device);\n    let y = TestTensor::<2>::zeros([3, 4], &device);\n    let _ = linalg::outer::<TestBackend, 2, 3, _>(x, y);\n}\n\n#[test]\nfn test_outer_dim() {\n    let u = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n    let v = TestTensor::<2>::from([[4.0, 5.0], [5.0, 6.0]]);\n\n    let out = linalg::outer_dim::<TestBackend, 2, 3, _, _>(u, v, 0).into_data();\n    let expected = TensorData::from([[[4.0, 10.0], [5.0, 12.0]], [[12.0, 20.0], [15.0, 24.0]]]);\n\n    out.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/trace.rs",
    "content": "use super::*;\nuse burn_tensor::linalg::trace;\n\n#[test]\nfn test_trace_2d_square() {\n    let device = Default::default();\n    let tensor =\n        TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([15.0], &device); // 1 + 5 + 9 = 15\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_2d_rectangular_wide() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([7.0], &device); // 1 + 6 = 7\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_2d_rectangular_tall() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([5.0], &device); // 1 + 4 = 5\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_3d_batch() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(\n        [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],\n        &device,\n    );\n\n    let result = trace::<_, 3, 2>(tensor);\n    let expected = TestTensor::<2>::from_data([[5.0], [13.0]], &device); // [1+4=5, 5+8=13]\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_4d_batch() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_data(\n        [[\n            // Batch 0, Channel 0\n            [[1.0, 2.0], [3.0, 4.0]],\n            // Batch 0, Channel 1\n            [[5.0, 6.0], [7.0, 8.0]],\n        ]],\n        &device,\n    );\n\n    let result = trace::<_, 4, 3>(tensor);\n    let expected = TestTensor::<3>::from_data([[[5.0], [13.0]]], &device);\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_single_element() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[42.0]], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([42.0], &device);\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_zeros() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::zeros([3, 3], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([0.0], &device);\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\nfn test_trace_negative_values() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[-1.0, 2.0], [3.0, -4.0]], &device);\n    let result = trace::<_, 2, 1>(tensor);\n    let expected = TestTensor::<1>::from_data([-5.0], &device); // -1 + (-4) = -5\n\n    assert_eq!(result.to_data(), expected.to_data());\n}\n\n#[test]\n#[should_panic]\nfn test_trace_1d_should_panic() {\n    let device = Default::default();\n    // 1D tensor should panic - trace requires at least 2 dimensions\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);\n    let _result = trace::<_, 1, 0>(tensor);\n}\n\n#[test]\n#[should_panic]\nfn test_trace_wrong_output_rank_should_panic() {\n    let device = Default::default();\n    // Providing wrong output rank should panic\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let _result = trace::<_, 2, 2>(tensor); // Should be 2,1 not 2,2\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/linalg/vector_norm.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse burn_tensor::backend::Backend;\nuse burn_tensor::linalg;\n\n#[test]\nfn test_max_min_abs() {\n    let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);\n\n    let expected = TestTensor::<2>::from([[3., 4.]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::max_abs_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    let expected = TestTensor::<2>::from([[1., 2.]]).into_data();\n    linalg::vector_norm(x.clone(), -f64::INFINITY, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::min_abs_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    let expected = TestTensor::<2>::from([[2.], [4.]]).into_data();\n    linalg::vector_norm(x.clone(), f64::INFINITY, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::max_abs_norm(x.clone(), 1)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    let expected = TestTensor::<2>::from([[1.], [3.]]).into_data();\n    linalg::vector_norm(x.clone(), -f64::INFINITY, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::vector_norm(x.clone(), f64::NEG_INFINITY, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::min_abs_norm(x, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    // Test with integer tensor\n    let z = TestTensorInt::<2>::from([[1, 2], [3, 4]]);\n\n    linalg::max_abs_norm(z.clone(), 0)\n        .into_data()\n        .assert_eq(&TestTensorInt::<2>::from([[3, 4]]).into_data(), true);\n    linalg::max_abs_norm(z.clone(), 1)\n        .into_data()\n        .assert_eq(&TestTensorInt::<2>::from([[2], [4]]).into_data(), true);\n\n    linalg::min_abs_norm(z.clone(), 0)\n        .into_data()\n        .assert_eq(&TestTensorInt::<2>::from([[1, 2]]).into_data(), true);\n    linalg::min_abs_norm(z, 1)\n        .into_data()\n        .assert_eq(&TestTensorInt::<2>::from([[1], [3]]).into_data(), true);\n}\n\n#[test]\nfn test_l0_norm() {\n    let x = TestTensor::<2>::from([[1.0, -2.0, 0.], [0.0, 0., 4.]]);\n\n    let expected = TestTensor::<2>::from([[1., 1., 1.]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L0, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l0_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    let expected = TestTensor::<2>::from([[2.], [1.]]).into_data();\n    linalg::vector_norm(x.clone(), 0.0, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l0_norm(x.clone(), 1)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    // Test with integer tensor\n    let z = TestTensorInt::<2>::from([[1, -2, 0], [0, 0, 4]]);\n\n    linalg::l0_norm(z.clone(), 0)\n        .into_data()\n        .assert_eq(&TestTensor::<2>::from([[1, 1, 1]]).int().into_data(), true);\n    linalg::l0_norm(z.clone(), 1)\n        .into_data()\n        .assert_eq(&TestTensor::<2>::from([[2], [1]]).int().into_data(), true);\n}\n\n#[test]\nfn test_l1_norm() {\n    let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);\n\n    let expected = TestTensor::<2>::from([[4.0, 6.0]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L1, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l1_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    let expected = TestTensor::<2>::from([[3.0], [7.0]]).into_data();\n    linalg::vector_norm(x.clone(), 1.0, 1)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l1_norm(x.clone(), 1)\n        .into_data()\n        .assert_eq(&expected, true);\n}\n\n#[test]\nfn test_lp_norm() {\n    let x = TestTensor::<2>::from([[1., -2., 0.], [0., 3., 4.]]);\n    let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3);\n\n    fn lp_norm_naive<B: Backend, const D: usize>(\n        x: Tensor<B, D>,\n        p: f64,\n        dim: usize,\n    ) -> Tensor<B, D> {\n        x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)\n    }\n\n    // Arbitrary P\n    let expected = TestTensor::<2>::from([[1.0, 3.2710664, 4.0]]).into_data();\n    linalg::vector_norm(x.clone(), 3, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::lp_norm(x.clone(), 3., 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    // L0\n    let expected = TestTensor::<2>::from([[1., 2., 1.]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L0, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l0_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::lp_norm(x.clone(), 0.0, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    // L1\n    let expected = TestTensor::<2>::from([[1.0, 5.0, 4.0]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L1, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::l1_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    lp_norm_naive(x.clone(), 1.0, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::lp_norm(x.clone(), 1.0, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n\n    // L2\n    let expected = TestTensor::<2>::from([[1.0, 3.6055512, 4.0]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L2, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::l2_norm(x.clone(), 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    lp_norm_naive(x.clone(), 2.0, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::lp_norm(x.clone(), 2.0, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    // LInf\n    let expected = TestTensor::<2>::from([[1.0, 3.0, 4.0]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::LInf, 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::max_abs_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::lp_norm(x.clone(), f64::INFINITY, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    // LNegInf\n    let expected = TestTensor::<2>::from([[0.0, 2.0, 0.0]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::LNegInf, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::min_abs_norm(x.clone(), 0)\n        .into_data()\n        .assert_eq(&expected, true);\n    linalg::lp_norm(x.clone(), f64::NEG_INFINITY, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_l2_norm() {\n    let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);\n    let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3);\n\n    let expected = TestTensor::<2>::from([[3.16227766, 4.47213595]]).into_data();\n    linalg::vector_norm(x.clone(), linalg::Norm::L2, 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::l2_norm(x.clone(), 0)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n\n    let expected = TestTensor::<2>::from([[2.23606798], [5.0]]).into_data();\n    linalg::vector_norm(x.clone(), 2.0, 1)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n    linalg::l2_norm(x.clone(), 1)\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_normalize() {\n    let x = TestTensor::<2>::from([[1., 2.], [3., 4.]]);\n\n    let expected = TensorData::from([[1. / 4., 2. / 6.], [3. / 4., 4. / 6.]]);\n    let output = linalg::vector_normalize(x.clone(), 1.0, 0, 0.25).into_data();\n    output.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let expected = TensorData::from([[1. / 5., 2. / 6.], [3. / 5., 4. / 6.]]);\n    let output = linalg::vector_normalize(x.clone(), 1.0, 0, 5.0).into_data();\n    output.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/mod.rs",
    "content": "#[allow(unused_imports)]\npub use super::*; // re-export test types\n\nmod activation;\nmod grid;\nmod linalg;\nmod module;\nmod ops;\nmod primitive;\nmod stats;\n\n#[cfg(feature = \"quantization\")]\nmod quantization;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::adaptive_avg_pool1d;\n\n#[test]\nfn test_adaptive_avg_pool1d_simple() {\n    let test = AdaptiveAvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        length: 8,\n        length_out: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0.5, 2.5, 4.5, 6.5],\n        [8.5, 10.5, 12.5, 14.5],\n    ]]));\n}\n\n#[test]\nfn test_adaptive_avg_pool1d_dyn_filter_size() {\n    let test = AdaptiveAvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        length: 7,\n        length_out: 3,\n    };\n\n    test.assert_output(TestTensor::from([[[1.0, 3.0, 5.0], [8.0, 10.0, 12.0]]]));\n}\n\n#[test]\nfn test_adaptive_avg_pool1d_bigger_output() {\n    let test = AdaptiveAvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        length: 4,\n        length_out: 8,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0],\n        [4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0, 7.0],\n    ]]));\n}\n\nstruct AdaptiveAvgPool1dTestCase {\n    batch_size: usize,\n    channels: usize,\n    length: usize,\n    length_out: usize,\n}\n\nimpl AdaptiveAvgPool1dTestCase {\n    fn assert_output(self, y: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.length]);\n        let device = Default::default();\n        let x = TestTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        );\n        let output = adaptive_avg_pool1d(x, self.length_out);\n\n        y.into_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/adaptive_avgpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::adaptive_avg_pool2d;\n\n#[test]\nfn test_adaptive_avg_pool2d_simple() {\n    let test = AdaptiveAvgPool2dTestCase {\n        batch_size: 1,\n        channels: 2,\n        height: 8,\n        width: 6,\n        height_out: 4,\n        width_out: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [3.5000, 4.5000, 6.5000, 7.5000],\n            [15.5000, 16.5000, 18.5000, 19.5000],\n            [27.5000, 28.5000, 30.5000, 31.5000],\n            [39.5000, 40.5000, 42.5000, 43.5000],\n        ],\n        [\n            [51.5000, 52.5000, 54.5000, 55.5000],\n            [63.5000, 64.5000, 66.5000, 67.5000],\n            [75.5000, 76.5000, 78.5000, 79.5000],\n            [87.5000, 88.5000, 90.5000, 91.5000],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_adaptive_avg_pool2d_dyn_filter_size() {\n    let test = AdaptiveAvgPool2dTestCase {\n        batch_size: 1,\n        channels: 2,\n        height: 5,\n        width: 7,\n        height_out: 3,\n        width_out: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [[5.0000, 8.0000], [15.5000, 18.5000], [26.0000, 29.0000]],\n        [[40.0000, 43.0000], [50.5000, 53.5000], [61.0000, 64.0000]],\n    ]]));\n}\n\n#[test]\nfn test_adaptive_avg_pool2d_bigger_output() {\n    let test = AdaptiveAvgPool2dTestCase {\n        batch_size: 1,\n        channels: 2,\n        height: 4,\n        width: 3,\n        height_out: 5,\n        width_out: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [0.0000, 0.5000, 1.5000, 2.0000],\n            [1.5000, 2.0000, 3.0000, 3.5000],\n            [4.5000, 5.0000, 6.0000, 6.5000],\n            [7.5000, 8.0000, 9.0000, 9.5000],\n            [9.0000, 9.5000, 10.5000, 11.0000],\n        ],\n        [\n            [12.0000, 12.5000, 13.5000, 14.0000],\n            [13.5000, 14.0000, 15.0000, 15.5000],\n            [16.5000, 17.0000, 18.0000, 18.5000],\n            [19.5000, 20.0000, 21.0000, 21.5000],\n            [21.0000, 21.5000, 22.5000, 23.0000],\n        ],\n    ]]));\n}\n\nstruct AdaptiveAvgPool2dTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl AdaptiveAvgPool2dTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = adaptive_avg_pool2d(x, [self.height_out, self.width_out]);\n\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/attention.rs",
    "content": "use super::*;\nuse burn_tensor::Distribution;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::attention;\nuse burn_tensor::module::attention_fallback;\nuse burn_tensor::ops::AttentionModuleOptions;\n\n#[test]\nfn test_attention_no_mask() {\n    // Skip on metal with f16 - flash attention returns zeros\n    // Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4325\n    #[cfg(feature = \"metal\")]\n    if core::any::TypeId::of::<FloatElemType>() == core::any::TypeId::of::<burn_tensor::f16>() {\n        return;\n    }\n\n    let num_batches = 1;\n    let num_heads = 1;\n    let seq_q = 128;\n    let seq_kv = 128;\n    let head_dim = 64;\n    let val_dim = 64;\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_q, head_dim],\n        Distribution::Uniform(0., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_kv, head_dim],\n        Distribution::Uniform(0., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_kv, val_dim],\n        Distribution::Uniform(0., 1.),\n        &Default::default(),\n    );\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        None,\n        Default::default(),\n    );\n\n    let expected =\n        attention_fallback::<TestBackend>(query, key, value, None, None, Default::default());\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n#[test]\nfn test_attention_custom_scale() {\n    let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n\n    let options = AttentionModuleOptions {\n        scale: Some(0.1),\n        ..Default::default()\n    };\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        None,\n        options,\n    );\n\n    let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n#[test]\nfn test_attention_attn_bias() {\n    let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let bias = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, seq_len],\n        Distribution::Uniform(-0.5, 0.5),\n        &Default::default(),\n    );\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        Some(bias.clone()),\n        Default::default(),\n    );\n\n    let expected =\n        attention_fallback::<TestBackend>(query, key, value, None, Some(bias), Default::default());\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n#[test]\nfn test_attention_softcap() {\n    let [num_batches, num_heads, seq_len, head_dim] = [1, 2, 16, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n\n    let options = AttentionModuleOptions {\n        softcap: Some(50.0),\n        ..Default::default()\n    };\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        None,\n        options,\n    );\n\n    let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n#[test]\nfn test_attention_is_causal() {\n    let [num_batches, num_heads, seq_len, head_dim] = [2, 4, 16, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n\n    let options = AttentionModuleOptions {\n        is_causal: true,\n        ..Default::default()\n    };\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        None,\n        options,\n    );\n\n    let expected = attention_fallback::<TestBackend>(query, key, value, None, None, options);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n/// Cross-attention: seq_q != seq_k, with causal masking and additive bias.\n#[test]\nfn test_attention_cross_attention_with_bias() {\n    let [num_batches, num_heads, seq_q, seq_k, head_dim] = [2, 2, 8, 24, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_q, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_k, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_k, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let bias = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_q, seq_k],\n        Distribution::Uniform(-0.5, 0.5),\n        &Default::default(),\n    );\n\n    let options = AttentionModuleOptions {\n        is_causal: true,\n        ..Default::default()\n    };\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        None,\n        Some(bias.clone()),\n        options,\n    );\n\n    let expected = attention_fallback::<TestBackend>(query, key, value, None, Some(bias), options);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n\n/// Regression: softcap must be applied before -inf masking.\n/// With causal masking, position 0 can only attend to itself, so output[0] == value[0].\n/// If softcap were applied after masking, tanh(-inf/softcap) = -softcap (finite),\n/// and the masked position would leak into the output.\n#[test]\nfn test_attention_softcap_preserves_causal_mask() {\n    let [num_batches, num_heads, seq_len, head_dim] = [1, 1, 4, 8];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n\n    let options = AttentionModuleOptions {\n        softcap: Some(20.0),\n        is_causal: true,\n        ..Default::default()\n    };\n\n    let output = attention_fallback::<TestBackend>(query, key, value.clone(), None, None, options);\n\n    // With causal masking, position 0 can only attend to itself (softmax = [1, 0, 0, 0]).\n    // So output[..., 0, :] must equal value[..., 0, :].\n    let output_row0 = output.slice([0..1, 0..1, 0..1, 0..head_dim]);\n    let value_row0 = value.slice([0..1, 0..1, 0..1, 0..head_dim]);\n\n    output_row0\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&value_row0.into_data(), Tolerance::relative(1e-5));\n}\n\n/// Combined: mask + bias + custom scale + softcap together.\n#[test]\nfn test_attention_all_options() {\n    let [num_batches, num_heads, seq_len, head_dim] = [2, 2, 16, 32];\n\n    let query = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let key = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let value = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, head_dim],\n        Distribution::Uniform(-1., 1.),\n        &Default::default(),\n    );\n    let bias = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, seq_len],\n        Distribution::Uniform(-0.5, 0.5),\n        &Default::default(),\n    );\n    // Create a random bool mask by thresholding a uniform float tensor\n    let mask = TestTensor::<4>::random(\n        [num_batches, num_heads, seq_len, seq_len],\n        Distribution::Uniform(0., 1.),\n        &Default::default(),\n    )\n    .greater_elem(0.7);\n\n    let options = AttentionModuleOptions {\n        scale: Some(0.05),\n        softcap: Some(30.0),\n        is_causal: true,\n    };\n\n    let output = attention(\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        Some(mask.clone()),\n        Some(bias.clone()),\n        options,\n    );\n\n    let expected =\n        attention_fallback::<TestBackend>(query, key, value, Some(mask), Some(bias), options);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected.into_data(),\n        Tolerance::rel_abs(1e-2, 1e-3).set_half_precision_relative(1e-1),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/avgpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::avg_pool1d;\n\n#[test]\nfn test_avg_pool1d_simple() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size: 3,\n        padding: 0,\n        stride: 1,\n        length: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from([[[1., 2., 3., 4.]]]));\n}\n\n#[test]\nfn test_avg_pool1d_complex() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        length: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0.33333, 2.0000, 4.0000],\n        [4.33333, 8.0000, 10.0000],\n    ]]));\n}\n\n#[test]\nfn test_avg_pool1d_complex_dont_count_pad() {\n    let test = AvgPool1dTestCase {\n        batch_size: 1,\n        channels: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        length: 6,\n        count_include_pad: false,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0.5000, 2.0000, 4.0000],\n        [6.5000, 8.0000, 10.0000],\n    ]]));\n}\n\nstruct AvgPool1dTestCase {\n    batch_size: usize,\n    channels: usize,\n    kernel_size: usize,\n    padding: usize,\n    stride: usize,\n    length: usize,\n    count_include_pad: bool,\n}\n\nimpl AvgPool1dTestCase {\n    fn assert_output(self, y: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.length]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n        );\n        let output = avg_pool1d(\n            x,\n            self.kernel_size,\n            self.stride,\n            self.padding,\n            self.count_include_pad,\n            false,\n        );\n\n        y.to_data().assert_approx_eq::<FloatElem>(\n            &output.into_data(),\n            Tolerance::default().set_half_precision_relative(1e-3),\n        );\n    }\n}\n\n#[test]\nfn test_avg_pool1d_ceil_mode() {\n    // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride\n    // Input: 1x1x6 (values 0-5), kernel: 3, stride: 2, padding: 0\n    // Floor mode: output = (6-3)/2+1 = 2 elements\n    // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements\n    let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]);\n\n    // With ceil_mode=false (floor): output is 2 elements\n    // Window 0: avg(0,1,2) = 1\n    // Window 1: avg(2,3,4) = 3\n    let y_floor = TestTensor::<3>::from([[[1.0, 3.0]]]);\n\n    let output_floor = avg_pool1d(\n        x.clone(),\n        3,    // kernel_size\n        2,    // stride\n        0,    // padding\n        true, // count_include_pad\n        false,\n    );\n\n    y_floor.to_data().assert_approx_eq::<FloatElem>(\n        &output_floor.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n\n    // With ceil_mode=true: output is 3 elements\n    // Window 0: avg(0,1,2) = 1\n    // Window 1: avg(2,3,4) = 3\n    // Window 2: avg(4,5) = 4.5 (partial window, count_include_pad=false divides by 2)\n    let y_ceil = TestTensor::<3>::from([[[1.0, 3.0, 4.5]]]);\n\n    let output_ceil = avg_pool1d(\n        x, 3,     // kernel_size\n        2,     // stride\n        0,     // padding\n        false, // count_include_pad=false to get correct average for partial window\n        true,\n    );\n\n    y_ceil.to_data().assert_approx_eq::<FloatElem>(\n        &output_ceil.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n}\n\n#[test]\nfn test_avg_pool1d_ceil_mode_count_include_pad() {\n    // Test count_include_pad=true + ceil_mode=true interaction for 1D\n    // When ceil_mode creates windows that extend beyond the padded input:\n    // - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions)\n    //\n    // Input: 1x1x6, kernel 3, stride 2, padding 1, ceil_mode=true\n    // Output is 4 elements\n    let x = TestTensor::from([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]);\n\n    // Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true:\n    // Window 0: positions -1,0,1 -> values 0,0,1 (0 is padding) / 3 = 0.333\n    // Window 1: positions 1,2,3 -> values 1,2,3 / 3 = 2.0\n    // Window 2: positions 3,4,5 -> values 3,4,5 / 3 = 4.0\n    // Window 3: positions 5,6,7 -> only 5 is valid, 6 is padding, 7 is ceil_mode extension\n    //           value 5 / 2 (only 2 positions within padded bounds) = 2.5\n    let expected = TestTensor::<3>::from([[[0.3333, 2.0, 4.0, 2.5]]]);\n\n    let output = avg_pool1d(\n        x, 3,    // kernel_size\n        2,    // stride\n        1,    // padding\n        true, // count_include_pad=true\n        true, // ceil_mode=true\n    );\n\n    expected.to_data().assert_approx_eq::<FloatElem>(\n        &output.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-2),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/avgpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::avg_pool2d;\n\n#[test]\nfn test_avg_pool2d_simple() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        height: 6,\n        width: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [7., 8., 9., 10.],\n        [13., 14., 15., 16.],\n        [19., 20., 21., 22.],\n        [25., 26., 27., 28.],\n    ]]]));\n}\n\n#[test]\nfn test_avg_pool2d_complex() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 2,\n        height: 4,\n        width: 6,\n        count_include_pad: true,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [1.1667, 3.0000, 4.3333, 2.5000],\n        [3.2500, 7.5000, 9.5000, 5.2500],\n        [6.2500, 13.5000, 15.5000, 8.2500],\n        [5.1667, 11.0000, 12.3333, 6.5000],\n    ]]]));\n}\n\n#[test]\nfn test_avg_pool2d_complex_dont_include_pad() {\n    let test = AvgPool2dTestCase {\n        batch_size: 1,\n        channels: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 1,\n        stride_2: 2,\n        height: 4,\n        width: 6,\n        count_include_pad: false,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [3.5000, 4.5000, 6.5000, 7.5000],\n        [6.5000, 7.5000, 9.5000, 10.5000],\n        [12.5000, 13.5000, 15.5000, 16.5000],\n        [15.5000, 16.5000, 18.5000, 19.5000],\n    ]]]));\n}\n\nstruct AvgPool2dTestCase {\n    batch_size: usize,\n    channels: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    height: usize,\n    width: usize,\n    count_include_pad: bool,\n}\n\nimpl AvgPool2dTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = avg_pool2d(\n            x,\n            [self.kernel_size_1, self.kernel_size_2],\n            [self.stride_1, self.stride_2],\n            [self.padding_1, self.padding_2],\n            self.count_include_pad,\n            false,\n        );\n\n        y.to_data().assert_approx_eq::<FloatElem>(\n            &output.into_data(),\n            Tolerance::default().set_half_precision_relative(1e-3),\n        );\n    }\n}\n\n#[test]\nfn test_avg_pool2d_ceil_mode() {\n    // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride\n    // Input: 1x1x6x6 (values 0-35), kernel: 3x3, stride: 2x2, padding: 0x0\n    // Floor mode: output = (6-3)/2+1 = 2 x 2\n    // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 x 3\n    let x = TestTensor::from([[[\n        [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],\n        [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],\n        [18.0, 19.0, 20.0, 21.0, 22.0, 23.0],\n        [24.0, 25.0, 26.0, 27.0, 28.0, 29.0],\n        [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],\n    ]]]);\n\n    // With ceil_mode=false (floor): output is 2x2\n    // Window (0,0): avg(0,1,2,6,7,8,12,13,14) = avg(63) = 7\n    // Window (0,1): avg(2,3,4,8,9,10,14,15,16) = avg(81) = 9\n    // Window (1,0): avg(12,13,14,18,19,20,24,25,26) = avg(171) = 19\n    // Window (1,1): avg(14,15,16,20,21,22,26,27,28) = avg(189) = 21\n    let y_floor = TestTensor::<4>::from([[[[7.0, 9.0], [19.0, 21.0]]]]);\n\n    let output_floor = avg_pool2d(\n        x.clone(),\n        [3, 3],\n        [2, 2],\n        [0, 0],\n        true, // count_include_pad\n        false,\n    );\n\n    y_floor.to_data().assert_approx_eq::<FloatElem>(\n        &output_floor.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n\n    // With ceil_mode=true: output is 3x3\n    // The extra windows at the edge include partial/padded regions\n    // When count_include_pad=false, only actual values are averaged\n    // Window (0,2): positions (0:3, 4:6) -> values 4,5,10,11,16,17 -> avg = 10.5\n    // Window (1,2): positions (2:5, 4:6) -> values 16,17,22,23,28,29 -> avg = 22.5\n    // Window (2,0): positions (4:6, 0:3) -> values 24,25,26,30,31,32 -> avg = 28\n    // Window (2,1): positions (4:6, 2:5) -> values 26,27,28,32,33,34 -> avg = 30\n    // Window (2,2): positions (4:6, 4:6) -> values 28,29,34,35 -> avg = 31.5\n    let y_ceil =\n        TestTensor::<4>::from([[[[7.0, 9.0, 10.5], [19.0, 21.0, 22.5], [28.0, 30.0, 31.5]]]]);\n\n    let output_ceil = avg_pool2d(\n        x,\n        [3, 3],\n        [2, 2],\n        [0, 0],\n        false, // count_include_pad=false to avoid dividing by full kernel size\n        true,\n    );\n\n    y_ceil.to_data().assert_approx_eq::<FloatElem>(\n        &output_ceil.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n}\n\n#[test]\nfn test_avg_pool2d_ceil_mode_count_include_pad() {\n    // Test count_include_pad=true + ceil_mode=true interaction\n    // When ceil_mode creates windows that extend beyond the padded input:\n    // - count_include_pad=true should count positions within padded bounds (not ceil_mode extensions)\n    //\n    // For input 6x6, kernel 3, stride 2, padding 1, ceil_mode=true:\n    // - Output is 4x4\n    // - Corner (3,3) window covers positions beyond even the user padding\n    // - Expected: 35/4 = 8.75 (divides by count of positions within padded bounds)\n\n    let x = TestTensor::from([[[\n        [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],\n        [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],\n        [18.0, 19.0, 20.0, 21.0, 22.0, 23.0],\n        [24.0, 25.0, 26.0, 27.0, 28.0, 29.0],\n        [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],\n    ]]]);\n\n    // Expected PyTorch output with padding=1, ceil_mode=true, count_include_pad=true\n    // Note: corner (3,3) = 8.75 = 35/4, not 35/9\n    let expected = TestTensor::<4>::from([[[\n        [1.5556, 3.3333, 4.6667, 2.6667],\n        [8.3333, 14.0000, 16.0000, 8.5000],\n        [16.3333, 26.0000, 28.0000, 14.5000],\n        [10.1667, 16.0000, 17.0000, 8.7500],\n    ]]]);\n\n    let output = avg_pool2d(\n        x,\n        [3, 3],\n        [2, 2],\n        [1, 1],\n        true, // count_include_pad=true\n        true, // ceil_mode=true\n    );\n\n    expected.to_data().assert_approx_eq::<FloatElem>(\n        &output.into_data(),\n        Tolerance::default().set_half_precision_relative(1e-2),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/bicubic_interpolate.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::interpolate;\nuse burn_tensor::ops::{InterpolateMode, InterpolateOptions};\n\n#[test]\nfn test_upsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 2,\n        channels: 1,\n        height: 7,\n        width: 5,\n        height_out: 8,\n        width_out: 7,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[\n            [0.0000, 0.5741, 1.3704, 2.0000, 2.6296, 3.4259, 4.0000],\n            [4.0015, 4.5755, 5.3718, 6.0015, 6.6311, 7.4274, 8.0015],\n            [8.3528, 8.9268, 9.7231, 10.3528, 10.9824, 11.7787, 12.3528],\n            [\n                12.7697, 13.3438, 14.1400, 14.7697, 15.3993, 16.1956, 16.7697,\n            ],\n            [\n                17.2303, 17.8044, 18.6007, 19.2303, 19.8600, 20.6562, 21.2303,\n            ],\n            [\n                21.6472, 22.2213, 23.0176, 23.6472, 24.2769, 25.0731, 25.6472,\n            ],\n            [\n                25.9986, 26.5726, 27.3689, 27.9986, 28.6282, 29.4245, 29.9986,\n            ],\n            [\n                30.0000, 30.5741, 31.3704, 32.0000, 32.6296, 33.4259, 34.0000,\n            ],\n        ]],\n        [[\n            [\n                35.0000, 35.5741, 36.3704, 37.0000, 37.6296, 38.4259, 39.0000,\n            ],\n            [\n                39.0015, 39.5755, 40.3718, 41.0015, 41.6311, 42.4274, 43.0015,\n            ],\n            [\n                43.3528, 43.9269, 44.7231, 45.3528, 45.9824, 46.7787, 47.3528,\n            ],\n            [\n                47.7697, 48.3438, 49.1400, 49.7697, 50.3993, 51.1956, 51.7697,\n            ],\n            [\n                52.2303, 52.8044, 53.6007, 54.2303, 54.8600, 55.6562, 56.2303,\n            ],\n            [\n                56.6472, 57.2213, 58.0176, 58.6472, 59.2769, 60.0731, 60.6472,\n            ],\n            [\n                60.9986, 61.5726, 62.3689, 62.9986, 63.6282, 64.4245, 64.9986,\n            ],\n            [\n                65.0000, 65.5741, 66.3704, 67.0000, 67.6296, 68.4259, 69.0000,\n            ],\n        ]],\n    ]));\n}\n\n#[test]\nfn test_downsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 45,\n        width: 14,\n        height_out: 4,\n        width_out: 6,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0.0000, 2.5760, 5.2480, 7.7520, 10.4240, 13.0000],\n        [204.8148, 207.3908, 210.0628, 212.5668, 215.2388, 217.8148],\n        [411.1852, 413.7612, 416.4331, 418.9371, 421.6091, 424.1852],\n        [616.0000, 618.576, 621.2479, 623.7519, 626.4239, 629.0000],\n    ]]]));\n}\n\n#[test]\nfn test_1d_bicubic() {\n    // Initialize the model without weights (because the exported file does not contain them)\n    let device = Default::default();\n\n    // Run the model\n    let input = TestTensor::<3>::from_floats(\n        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],\n        &device,\n    );\n\n    let input = input.unsqueeze_dim(2);\n\n    let output = interpolate(\n        input,\n        [1, 9],\n        InterpolateOptions::new(InterpolateMode::Bicubic),\n    );\n\n    assert_eq!(output.dims(), [1, 1, 1, 9]);\n\n    // assert output data does not contain NaN\n    assert!(\n        !output\n            .clone()\n            .to_data()\n            .as_slice::<FloatElem>()\n            .unwrap()\n            .iter()\n            .any(|&x| x.is_nan()),\n        \"interpolate output contains NaN\"\n    );\n\n    TestTensor::<4>::from([[[[\n        1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794,\n        -1.3986,\n    ]]]])\n    .to_data()\n    .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\nstruct InterpolateTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl InterpolateTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        self.assert_output_with_align_corners(y, true);\n    }\n\n    fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = interpolate(\n            x,\n            [self.height_out, self.width_out],\n            InterpolateOptions::new(InterpolateMode::Bicubic).with_align_corners(align_corners),\n        );\n\n        let tolerance = Tolerance::permissive();\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n\n#[test]\nfn test_upsample_half_pixel() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 4,\n        width: 4,\n        height_out: 8,\n        width_out: 8,\n    };\n\n    test.assert_output_with_align_corners(\n        TestTensor::from([[[\n            [\n                -0.5273, -0.2305, 0.2461, 0.875, 1.2812, 1.9102, 2.3867, 2.6836,\n            ],\n            [\n                0.6602, 0.957, 1.4336, 2.0625, 2.4688, 3.0977, 3.5742, 3.8711,\n            ],\n            [\n                2.5664, 2.8633, 3.3398, 3.9688, 4.375, 5.0039, 5.4805, 5.7773,\n            ],\n            [5.082, 5.3789, 5.8555, 6.4844, 6.8906, 7.5195, 7.9961, 8.293],\n            [6.707, 7.0039, 7.4805, 8.1094, 8.5156, 9.1445, 9.6211, 9.918],\n            [\n                9.2227, 9.5195, 9.9961, 10.625, 11.0312, 11.6602, 12.1367, 12.4336,\n            ],\n            [\n                11.1289, 11.4258, 11.9023, 12.5312, 12.9375, 13.5664, 14.043, 14.3398,\n            ],\n            [\n                12.3164, 12.6133, 13.0898, 13.7188, 14.125, 14.7539, 15.2305, 15.5273,\n            ],\n        ]]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/bilinear_interpolate.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::interpolate;\nuse burn_tensor::ops::{InterpolateMode, InterpolateOptions};\nuse burn_tensor::{DType, Shape};\n\n#[test]\nfn test_upsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 2,\n        channels: 1,\n        height: 7,\n        width: 5,\n        height_out: 8,\n        width_out: 7,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[\n            [0.0000, 0.6667, 1.3333, 2.0000, 2.6667, 3.3333, 4.0000],\n            [4.2857, 4.9524, 5.6190, 6.2857, 6.9524, 7.6190, 8.2857],\n            [8.5714, 9.2381, 9.9048, 10.5714, 11.2381, 11.9048, 12.5714],\n            [\n                12.8571, 13.5238, 14.1905, 14.8571, 15.5238, 16.1905, 16.8571,\n            ],\n            [\n                17.1429, 17.8095, 18.4762, 19.1429, 19.8095, 20.4762, 21.1429,\n            ],\n            [\n                21.4286, 22.0952, 22.7619, 23.4286, 24.0952, 24.7619, 25.4286,\n            ],\n            [\n                25.7143, 26.3810, 27.0476, 27.7143, 28.3810, 29.0476, 29.7143,\n            ],\n            [\n                30.0000, 30.6667, 31.3333, 32.0000, 32.6667, 33.3333, 34.0000,\n            ],\n        ]],\n        [[\n            [\n                35.0000, 35.6667, 36.3333, 37.0000, 37.6667, 38.3333, 39.0000,\n            ],\n            [\n                39.2857, 39.9524, 40.6190, 41.2857, 41.9524, 42.6190, 43.2857,\n            ],\n            [\n                43.5714, 44.2381, 44.9048, 45.5714, 46.2381, 46.9048, 47.5714,\n            ],\n            [\n                47.8571, 48.5238, 49.1905, 49.8571, 50.5238, 51.1905, 51.8571,\n            ],\n            [\n                52.1429, 52.8095, 53.4762, 54.1429, 54.8095, 55.4762, 56.1429,\n            ],\n            [\n                56.4286, 57.0952, 57.7619, 58.4286, 59.0952, 59.7619, 60.4286,\n            ],\n            [\n                60.7143, 61.3810, 62.0476, 62.7143, 63.3810, 64.0476, 64.7143,\n            ],\n            [\n                65.0000, 65.6667, 66.3333, 67.0000, 67.6667, 68.3333, 69.0000,\n            ],\n        ]],\n    ]));\n}\n\n#[test]\nfn test_downsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 45,\n        width: 14,\n        height_out: 4,\n        width_out: 6,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0.0, 2.6, 5.2, 7.8, 10.4, 13.],\n        [205.3333, 207.9333, 210.5333, 213.1333, 215.7333, 218.3333],\n        [410.6667, 413.2667, 415.8667, 418.4667, 421.0667, 423.6667],\n        [616., 618.6, 621.2, 623.8, 626.4, 629.],\n    ]]]));\n}\n\n#[test]\nfn test_1d_bilinear() {\n    // Initialize the model without weights (because the exported file does not contain them)\n    let device = Default::default();\n\n    // Run the model\n    let input = TestTensor::<3>::from_floats(\n        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],\n        &device,\n    );\n\n    let input = input.unsqueeze_dim(2);\n\n    let output = interpolate(\n        input,\n        [1, 9],\n        InterpolateOptions::new(InterpolateMode::Bilinear),\n    );\n\n    assert_eq!(output.dims(), [1, 1, 1, 9]);\n\n    // assert output data does not contain NaN\n    assert!(\n        !output\n            .clone()\n            .to_data()\n            .as_slice::<FloatElem>()\n            .unwrap()\n            .iter()\n            .any(|&x| x.is_nan()),\n        \"interpolate output contains NaN\"\n    );\n\n    TestTensor::<4>::from([[[[\n        1.541f32,\n        0.39450002,\n        -0.76475,\n        -1.943125,\n        -0.80520004,\n        0.36178753,\n        -0.671275,\n        -1.2022874,\n        -1.3986,\n    ]]]])\n    .to_data()\n    .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_interpolate_coord_float_precision_boundary() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 28,\n        width: 4,\n        height_out: 24,\n        width_out: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0.0, 3.0],\n        [4.6956, 7.6956],\n        [9.3913, 12.3913],\n        [14.0869, 17.0869],\n        [18.7826, 21.7826],\n        [23.4782, 26.4782],\n        [28.1739, 31.1739],\n        [32.8695, 35.8695],\n        [37.5652, 40.5652],\n        [42.2608, 45.2608],\n        [46.9565, 49.9565],\n        [51.6521, 54.6521],\n        [56.3478, 59.3478],\n        [61.0434, 64.0434],\n        [65.7391, 68.7391],\n        [70.4347, 73.4347],\n        [75.1304, 78.1304],\n        [79.8260, 82.8260],\n        [84.5217, 87.5217],\n        [89.2173, 92.2173],\n        [93.9130, 96.9130],\n        [98.6086, 101.6086],\n        [103.3043, 106.3043],\n        [108.0, 111.0],\n    ]]]));\n}\n\n#[test]\nfn should_interpolate_cast() {\n    let device = Default::default();\n    let shape_x = Shape::new([1, 1, 4, 4]);\n    let x = TestTensor::from(\n        TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n            .reshape::<4, _>(shape_x)\n            .into_data(),\n    )\n    .cast(DType::F32); // ok for f32 backends, casts dtype for f16 tests\n    let output = interpolate(\n        x,\n        [8, 8],\n        InterpolateOptions::new(InterpolateMode::Bilinear),\n    );\n\n    let expected = TestTensor::<4>::from([[[\n        [0.0, 0.42857, 0.8571, 1.2857, 1.7142, 2.1428, 2.5714, 3.0],\n        [1.7142, 2.1428, 2.5714, 3.0, 3.4285, 3.8571, 4.2857, 4.7142],\n        [3.4285, 3.8571, 4.2857, 4.7142, 5.1428, 5.5714, 6.0, 6.4285],\n        [5.1428, 5.5714, 6.0, 6.4285, 6.8571, 7.2857, 7.7142, 8.1428],\n        [6.8571, 7.2857, 7.7142, 8.1428, 8.5714, 9.0, 9.4285, 9.8571],\n        [\n            8.5714, 9.0, 9.4285, 9.8571, 10.2857, 10.7142, 11.1428, 11.5714,\n        ],\n        [\n            10.2857, 10.7142, 11.1428, 11.5714, 12.0, 12.4285, 12.8571, 13.2857,\n        ],\n        [\n            12.0, 12.4285, 12.8571, 13.2857, 13.7142, 14.1428, 14.5714, 15.0,\n        ],\n    ]]]);\n\n    let tolerance = Tolerance::permissive();\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);\n}\n\nstruct InterpolateTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl InterpolateTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        self.assert_output_with_align_corners(y, true);\n    }\n\n    fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = interpolate(\n            x,\n            [self.height_out, self.width_out],\n            InterpolateOptions::new(InterpolateMode::Bilinear).with_align_corners(align_corners),\n        );\n\n        let tolerance = Tolerance::permissive();\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n\n#[test]\nfn test_upsample_half_pixel() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 4,\n        width: 4,\n        height_out: 8,\n        width_out: 8,\n    };\n\n    test.assert_output_with_align_corners(\n        TestTensor::from([[[\n            [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.0],\n            [1.0, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75, 4.0],\n            [3.0, 3.25, 3.75, 4.25, 4.75, 5.25, 5.75, 6.0],\n            [5.0, 5.25, 5.75, 6.25, 6.75, 7.25, 7.75, 8.0],\n            [7.0, 7.25, 7.75, 8.25, 8.75, 9.25, 9.75, 10.0],\n            [9.0, 9.25, 9.75, 10.25, 10.75, 11.25, 11.75, 12.0],\n            [11.0, 11.25, 11.75, 12.25, 12.75, 13.25, 13.75, 14.0],\n            [12.0, 12.25, 12.75, 13.25, 13.75, 14.25, 14.75, 15.0],\n        ]]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv1d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::conv1d;\nuse burn_tensor::ops::ConvOptions;\n\n#[test]\nfn test_conv1d_simple() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[43., 67., 82., 49.], [104., 176., 227., 158.]],\n        [[139., 187., 202., 113.], [392., 584., 635., 414.]],\n    ]));\n}\n\n#[test]\nfn test_conv1d_dilation() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 2,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[62., 38.], [159., 111.]],\n        [[158., 102.], [447., 367.]],\n    ]));\n}\n\n#[test]\nfn test_conv1d_groups() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        stride: 1,\n        dilation: 1,\n        groups: 2,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[2., 5., 8., 3.], [42., 63., 75., 47.]],\n        [[26., 29., 32., 11.], [114., 159., 171., 103.]],\n    ]));\n}\n\n#[test]\nfn test_conv1d_complex() {\n    let test = Conv1dTestCase {\n        batch_size: 2,\n        channels_in: 3,\n        channels_out: 4,\n        kernel_size: 3,\n        padding: 1,\n        stride: 2,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [\n            [[171., 294.], [415., 781.], [659., 1268.], [903., 1755.]],\n            [[495., 726.], [1387., 2185.], [2279., 3644.], [3171., 5103.]],\n        ],\n        &Default::default(),\n    ));\n}\n\nstruct Conv1dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size: usize,\n    padding: usize,\n    stride: usize,\n    dilation: usize,\n    groups: usize,\n    length: usize,\n}\n\nimpl Conv1dTestCase {\n    fn assert_output(self, y: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size,\n        ]);\n        let device = Default::default();\n        let weight = TestTensor::from_data(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_weight)\n                .into_data(),\n            &device,\n        );\n        let bias = TestTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        );\n        let x = TestTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        );\n        let output = conv1d(\n            x,\n            weight,\n            Some(bias),\n            ConvOptions::new([self.stride], [self.padding], [self.dilation], self.groups),\n        );\n\n        let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(1e-3);\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv2d.rs",
    "content": "use super::*;\nuse alloc::{vec, vec::Vec};\nuse burn_tensor::Shape;\nuse burn_tensor::activation::gelu;\nuse burn_tensor::module::conv2d;\nuse burn_tensor::ops::ConvOptions;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn test_conv2d_simple() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [1196., 1796., 1916., 1264.],\n            [1881., 2793., 2946., 1923.],\n            [2313., 3405., 3558., 2307.],\n            [1424., 2072., 2156., 1380.],\n        ],\n        [\n            [2709., 4173., 4509., 3065.],\n            [4582., 7006., 7483., 5056.],\n            [5878., 8914., 9391., 6304.],\n            [4089., 6177., 6477., 4333.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_simple_implicit() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 16,\n        kernel_size_1: 4,\n        kernel_size_2: 4,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 5,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [666., 916., 1030., 774.],\n            [1124., 1500., 1620., 1190.],\n            [1604., 2100., 2220., 1610.],\n            [990., 1264., 1330., 936.],\n        ],\n        [\n            [1531., 2165., 2471., 1927.],\n            [2757., 3805., 4181., 3207.],\n            [4197., 5685., 6061., 4587.],\n            [3295., 4433., 4691., 3529.],\n        ],\n        [\n            [2396., 3414., 3912., 3080.],\n            [4390., 6110., 6742., 5224.],\n            [6790., 9270., 9902., 7564.],\n            [5600., 7602., 8052., 6122.],\n        ],\n        [\n            [3261., 4663., 5353., 4233.],\n            [6023., 8415., 9303., 7241.],\n            [9383., 12855., 13743., 10541.],\n            [7905., 10771., 11413., 8715.],\n        ],\n        [\n            [4126., 5912., 6794., 5386.],\n            [7656., 10720., 11864., 9258.],\n            [11976., 16440., 17584., 13518.],\n            [10210., 13940., 14774., 11308.],\n        ],\n        [\n            [4991., 7161., 8235., 6539.],\n            [9289., 13025., 14425., 11275.],\n            [14569., 20025., 21425., 16495.],\n            [12515., 17109., 18135., 13901.],\n        ],\n        [\n            [5856., 8410., 9676., 7692.],\n            [10922., 15330., 16986., 13292.],\n            [17162., 23610., 25266., 19472.],\n            [14820., 20278., 21496., 16494.],\n        ],\n        [\n            [6721., 9659., 11117., 8845.],\n            [12555., 17635., 19547., 15309.],\n            [19755., 27195., 29107., 22449.],\n            [17125., 23447., 24857., 19087.],\n        ],\n        [\n            [7586., 10908., 12558., 9998.],\n            [14188., 19940., 22108., 17326.],\n            [22348., 30780., 32948., 25426.],\n            [19430., 26616., 28218., 21680.],\n        ],\n        [\n            [8451., 12157., 13999., 11151.],\n            [15821., 22245., 24669., 19343.],\n            [24941., 34365., 36789., 28403.],\n            [21735., 29785., 31579., 24273.],\n        ],\n        [\n            [9316., 13406., 15440., 12304.],\n            [17454., 24550., 27230., 21360.],\n            [27534., 37950., 40630., 31380.],\n            [24040., 32954., 34940., 26866.],\n        ],\n        [\n            [10181., 14655., 16881., 13457.],\n            [19087., 26855., 29791., 23377.],\n            [30127., 41535., 44471., 34357.],\n            [26345., 36123., 38301., 29459.],\n        ],\n        [\n            [11046., 15904., 18322., 14610.],\n            [20720., 29160., 32352., 25394.],\n            [32720., 45120., 48312., 37334.],\n            [28650., 39292., 41662., 32052.],\n        ],\n        [\n            [11911., 17153., 19763., 15763.],\n            [22353., 31465., 34913., 27411.],\n            [35313., 48705., 52153., 40311.],\n            [30955., 42461., 45023., 34645.],\n        ],\n        [\n            [12776., 18402., 21204., 16916.],\n            [23986., 33770., 37474., 29428.],\n            [37906., 52290., 55994., 43288.],\n            [33260., 45630., 48384., 37238.],\n        ],\n        [\n            [13641., 19651., 22645., 18069.],\n            [25619., 36075., 40035., 31445.],\n            [40499., 55875., 59835., 46265.],\n            [35565., 48799., 51745., 39831.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_implicit_padded_in_channels() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 16,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [4521., 6753., 7014., 4635.],\n            [6858., 10197., 10548., 6939.],\n            [7830., 11601., 11952., 7839.],\n            [5007., 7383., 7590., 4953.],\n        ],\n        [\n            [10516., 15988., 16735., 11278.],\n            [16822., 25507., 26587., 17875.],\n            [19738., 29827., 30907., 20719.],\n            [13594., 20506., 21199., 14188.],\n        ],\n        [\n            [16511., 25223., 26456., 17921.],\n            [26786., 40817., 42626., 28811.],\n            [31646., 48053., 49862., 33599.],\n            [22181., 33629., 34808., 23423.],\n        ],\n        [\n            [22506., 34458., 36177., 24564.],\n            [36750., 56127., 58665., 39747.],\n            [43554., 66279., 68817., 46479.],\n            [30768., 46752., 48417., 32658.],\n        ],\n        [\n            [28501., 43693., 45898., 31207.],\n            [46714., 71437., 74704., 50683.],\n            [55462., 84505., 87772., 59359.],\n            [39355., 59875., 62026., 41893.],\n        ],\n        [\n            [34496., 52928., 55619., 37850.],\n            [56678., 86747., 90743., 61619.],\n            [67370., 102731., 106727., 72239.],\n            [47942., 72998., 75635., 51128.],\n        ],\n        [\n            [40491., 62163., 65340., 44493.],\n            [66642., 102057., 106782., 72555.],\n            [79278., 120957., 125682., 85119.],\n            [56529., 86121., 89244., 60363.],\n        ],\n        [\n            [46486., 71398., 75061., 51136.],\n            [76606., 117367., 122821., 83491.],\n            [91186., 139183., 144637., 97999.],\n            [65116., 99244., 102853., 69598.],\n        ],\n        [\n            [52481., 80633., 84782., 57779.],\n            [86570., 132677., 138860., 94427.],\n            [103094., 157409., 163592., 110879.],\n            [73703., 112367., 116462., 78833.],\n        ],\n        [\n            [58476., 89868., 94503., 64422.],\n            [96534., 147987., 154899., 105363.],\n            [115002., 175635., 182547., 123759.],\n            [82290., 125490., 130071., 88068.],\n        ],\n        [\n            [64471., 99103., 104224., 71065.],\n            [106498., 163297., 170938., 116299.],\n            [126910., 193861., 201502., 136639.],\n            [90877., 138613., 143680., 97303.],\n        ],\n        [\n            [70466., 108338., 113945., 77708.],\n            [116462., 178607., 186977., 127235.],\n            [138818., 212087., 220457., 149519.],\n            [99464., 151736., 157289., 106538.],\n        ],\n        [\n            [76461., 117573., 123666., 84351.],\n            [126426., 193917., 203016., 138171.],\n            [150726., 230313., 239412., 162399.],\n            [108051., 164859., 170898., 115773.],\n        ],\n        [\n            [82456., 126808., 133387., 90994.],\n            [136390., 209227., 219055., 149107.],\n            [162634., 248539., 258367., 175279.],\n            [116638., 177982., 184507., 125008.],\n        ],\n        [\n            [88451., 136043., 143108., 97637.],\n            [146354., 224537., 235094., 160043.],\n            [174542., 266765., 277322., 188159.],\n            [125225., 191105., 198116., 134243.],\n        ],\n        [\n            [94446., 145278., 152829., 104280.],\n            [156318., 239847., 251133., 170979.],\n            [186450., 284991., 296277., 201039.],\n            [133812., 204228., 211725., 143478.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_groups_channels_out() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 16,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [73., 121., 154., 103.],\n            [171., 258., 294., 186.],\n            [279., 402., 438., 270.],\n            [139., 187., 202., 113.],\n        ],\n        [\n            [164., 284., 371., 266.],\n            [415., 664., 781., 538.],\n            [739., 1132., 1249., 838.],\n            [518., 782., 851., 564.],\n        ],\n        [\n            [255., 447., 588., 429.],\n            [659., 1070., 1268., 890.],\n            [1199., 1862., 2060., 1406.],\n            [897., 1377., 1500., 1015.],\n        ],\n        [\n            [346., 610., 805., 592.],\n            [903., 1476., 1755., 1242.],\n            [1659., 2592., 2871., 1974.],\n            [1276., 1972., 2149., 1466.],\n        ],\n        [\n            [437., 773., 1022., 755.],\n            [1147., 1882., 2242., 1594.],\n            [2119., 3322., 3682., 2542.],\n            [1655., 2567., 2798., 1917.],\n        ],\n        [\n            [528., 936., 1239., 918.],\n            [1391., 2288., 2729., 1946.],\n            [2579., 4052., 4493., 3110.],\n            [2034., 3162., 3447., 2368.],\n        ],\n        [\n            [619., 1099., 1456., 1081.],\n            [1635., 2694., 3216., 2298.],\n            [3039., 4782., 5304., 3678.],\n            [2413., 3757., 4096., 2819.],\n        ],\n        [\n            [710., 1262., 1673., 1244.],\n            [1879., 3100., 3703., 2650.],\n            [3499., 5512., 6115., 4246.],\n            [2792., 4352., 4745., 3270.],\n        ],\n        [\n            [5793., 8865., 9330., 6335.],\n            [9467., 14450., 15134., 10250.],\n            [11303., 17186., 17870., 12062.],\n            [7971., 12099., 12546., 8457.],\n        ],\n        [\n            [6460., 9892., 10411., 7074.],\n            [10575., 16152., 16917., 11466.],\n            [12627., 19212., 19977., 13494.],\n            [8926., 13558., 14059., 9484.],\n        ],\n        [\n            [7127., 10919., 11492., 7813.],\n            [11683., 17854., 18700., 12682.],\n            [13951., 21238., 22084., 14926.],\n            [9881., 15017., 15572., 10511.],\n        ],\n        [\n            [7794., 11946., 12573., 8552.],\n            [12791., 19556., 20483., 13898.],\n            [15275., 23264., 24191., 16358.],\n            [10836., 16476., 17085., 11538.],\n        ],\n        [\n            [8461., 12973., 13654., 9291.],\n            [13899., 21258., 22266., 15114.],\n            [16599., 25290., 26298., 17790.],\n            [11791., 17935., 18598., 12565.],\n        ],\n        [\n            [9128., 14000., 14735., 10030.],\n            [15007., 22960., 24049., 16330.],\n            [17923., 27316., 28405., 19222.],\n            [12746., 19394., 20111., 13592.],\n        ],\n        [\n            [9795., 15027., 15816., 10769.],\n            [16115., 24662., 25832., 17546.],\n            [19247., 29342., 30512., 20654.],\n            [13701., 20853., 21624., 14619.],\n        ],\n        [\n            [10462., 16054., 16897., 11508.],\n            [17223., 26364., 27615., 18762.],\n            [20571., 31368., 32619., 22086.],\n            [14656., 22312., 23137., 15646.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_groups() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 5,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]],\n        [\n            [3724., 3841., 3958.],\n            [4309., 4426., 4543.],\n            [4894., 5011., 5128.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_groups_multiple_channels() {\n    let test = Conv2dTestCase {\n        batch_size: 1,\n        channels_in: 4,\n        channels_out: 4,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 5,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [4035., 4188., 4341.],\n            [4800., 4953., 5106.],\n            [5565., 5718., 5871.],\n        ],\n        [\n            [10030., 10507., 10984.],\n            [12415., 12892., 13369.],\n            [14800., 15277., 15754.],\n        ],\n        [\n            [56075., 56876., 57677.],\n            [60080., 60881., 61682.],\n            [64085., 64886., 65687.],\n        ],\n        [\n            [78270., 79395., 80520.],\n            [83895., 85020., 86145.],\n            [89520., 90645., 91770.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv2d_complex() {\n    let test = Conv2dTestCase {\n        batch_size: 2,\n        channels_in: 3,\n        channels_out: 4,\n        kernel_size_1: 3,\n        kernel_size_2: 2,\n        padding_1: 1,\n        padding_2: 2,\n        stride_1: 2,\n        stride_2: 3,\n        dilation_1: 1,\n        dilation_2: 2,\n        groups: 1,\n        height: 4,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::from([\n        [\n            [[1845., 3789., 1926.], [3210., 6465., 3228.]],\n            [[4276., 9082., 4789.], [8071., 16834., 8737.]],\n            [[6707., 14375., 7652.], [12932., 27203., 14246.]],\n            [[9138., 19668., 10515.], [17793., 37572., 19755.]],\n        ],\n        [\n            [[5445., 10629., 5166.], [8070., 15645., 7548.]],\n            [[14356., 28882., 14509.], [22651., 45454., 22777.]],\n            [[23267., 47135., 23852.], [37232., 75263., 38006.]],\n            [[32178., 65388., 33195.], [51813., 105072., 53235.]],\n        ],\n    ]));\n}\n\nstruct Conv2dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    groups: usize,\n    height: usize,\n    width: usize,\n}\n\nimpl Conv2dTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n        ]);\n        let device = Default::default();\n        let weight = TestTensor::from(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weight)\n                .into_data(),\n        );\n        let bias = TestTensor::from(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n        );\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = conv2d(\n            x,\n            weight,\n            Some(bias),\n            ConvOptions::new(\n                [self.stride_1, self.stride_2],\n                [self.padding_1, self.padding_2],\n                [self.dilation_1, self.dilation_2],\n                self.groups,\n            ),\n        );\n\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n\n#[rustfmt::skip] // param values are too long\n    fn conv2d_weight() -> TensorData {\n        TensorData::new(\n            vec![0.048065186, -0.3059082, -0.10345459, -0.34643555, -0.20788574, -0.021072388, 0.13745117, -0.05102539, 0.024536133, -0.16479492, -0.19519043, 0.27270508, 0.17700195, -0.33764648, -0.08239746, -0.27929688, 0.17321777, -0.1315918, 0.04574585, -0.17980957, -0.33569336, 0.27612305, 0.30004883, -0.28979492, -0.17297363, -0.021759033, -0.27148438, 0.005657196, 0.29956055, -0.06958008, -0.29345703, -0.14440918, 0.10827637, -0.13305664, -0.20239258, 0.24890137, -0.1541748, -0.20019531, -0.2854004, 0.17016602, 0.07861328, -0.09075928, 0.30908203, -0.00013422966, 0.29589844, 0.15258789, -0.25708008, 0.20422363, -0.2529297, 0.07891846, -0.19506836, 0.23571777, 0.27124023, 0.17370605, -0.16992188, -0.23522949, 0.14648438, -0.09576416, -0.18310547, 0.21044922, -0.08911133, -0.2541504, -0.2775879, -0.2064209, -0.16271973, -0.048919678, -0.03555298, -0.11639404, 0.09661865, -0.10241699, 0.08929443, 0.2866211],\n            [8, 1, 3, 3],\n        )\n    }\n\n#[test]\nfn test_conv2d_binary_broadcasted() {\n    let device = Default::default();\n    let x = TestTensor::<4>::full([1, 1, 28, 28], -0.42421296, &device);\n\n    // conv2d -> batchnorm -> activation\n    let weight = TestTensor::from_data(conv2d_weight(), &device);\n    let bias = TestTensor::from([\n        0.082336426,\n        -0.049591064,\n        0.0031795502,\n        0.00095653534,\n        0.02357483,\n        0.005569458,\n        0.07525635,\n        0.056396484,\n    ]);\n\n    // channels: [1, 8], kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], groups: 1, padding: [0, 0]\n    let opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 1);\n    let x = conv2d(x, weight, Some(bias), opt);\n\n    // simulate batchnorm binary ops with broadcasted params\n    let gamma = TestTensor::<1>::from([\n        1.0048828, 0.9902344, 1.0185547, 0.97558594, 1.0097656, 0.97802734, 1.0009766, 1.0146484,\n    ]);\n    let beta = TestTensor::<1>::from([\n        0.026290894,\n        0.0007505417,\n        0.006134033,\n        0.02418518,\n        0.07373047,\n        0.020507813,\n        0.01902771,\n        0.02003479,\n    ]);\n    let mean = TestTensor::<1>::from([\n        0.029159546,\n        -0.08673096,\n        -0.03894043,\n        -0.01108551,\n        0.032440186,\n        0.03237915,\n        0.013839722,\n        0.04397583,\n    ])\n    .reshape([1, 8, 1, 1]);\n    let var = TestTensor::<1>::from([\n        0.67089844, 0.29956055, 0.5209961, 0.1862793, 0.30419922, 0.21313477, 0.7504883, 0.26342773,\n    ])\n    .reshape([1, 8, 1, 1]);\n\n    let std = var.add_scalar(1e-5).sqrt();\n    let x = x.sub(mean);\n    let x = x.div(std);\n    let x = x.mul(gamma.reshape([1, 8, 1, 1]));\n    let x = x.add(beta.reshape([1, 8, 1, 1]));\n\n    let x = gelu(x);\n\n    let expected: Vec<f32> = [\n        0.36432067f32,\n        0.34909567,\n        0.30684796,\n        0.13217466,\n        -0.018471397,\n        -0.1389876,\n        0.39402074,\n        0.12394252,\n    ]\n    .iter()\n    .flat_map(|&v| core::iter::repeat_n(v, 676))\n    .collect();\n    let expected = TensorData::new(expected, [1, 8, 26, 26]);\n\n    x.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_absolute(1e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv3d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::conv3d;\nuse burn_tensor::ops::ConvOptions;\n\n#[test]\nfn test_conv3d_simple() {\n    let test = Conv3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 4,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [29980.0, 44860.0, 45640.0, 30324.0],\n                [45072.0, 67380.0, 68496.0, 45468.0],\n                [48096.0, 71844.0, 72960.0, 48396.0],\n                [31780.0, 47428.0, 48136.0, 31900.0],\n            ],\n            [\n                [47292.0, 70548.0, 71556.0, 47400.0],\n                [70335.0, 104823.0, 106254.0, 70317.0],\n                [74223.0, 110547.0, 111978.0, 74061.0],\n                [48552.0, 72240.0, 73140.0, 48324.0],\n            ],\n            [\n                [58236.0, 86676.0, 87684.0, 57960.0],\n                [85887.0, 127719.0, 129150.0, 85293.0],\n                [89775.0, 133443.0, 134874.0, 89037.0],\n                [58344.0, 86640.0, 87540.0, 57732.0],\n            ],\n            [\n                [36148.0, 53620.0, 54184.0, 35692.0],\n                [52740.0, 78144.0, 78936.0, 51936.0],\n                [54900.0, 81312.0, 82104.0, 54000.0],\n                [35260.0, 52156.0, 52648.0, 34580.0],\n            ],\n        ],\n        [\n            [\n                [66701.0, 100589.0, 102665.0, 68773.0],\n                [102745.0, 154861.0, 157921.0, 105733.0],\n                [110953.0, 167101.0, 170161.0, 113845.0],\n                [75413.0, 113525.0, 115529.0, 77261.0],\n            ],\n            [\n                [112741.0, 169693.0, 172645.0, 115441.0],\n                [172396.0, 259372.0, 263719.0, 176266.0],\n                [184060.0, 276760.0, 281107.0, 187786.0],\n                [124369.0, 186937.0, 189781.0, 126733.0],\n            ],\n            [\n                [144421.0, 216925.0, 219877.0, 146737.0],\n                [219052.0, 328924.0, 333271.0, 222346.0],\n                [230716.0, 346312.0, 350659.0, 233866.0],\n                [154897.0, 232441.0, 235285.0, 156877.0],\n            ],\n            [\n                [100517.0, 150821.0, 152681.0, 101789.0],\n                [151885.0, 227833.0, 230569.0, 153673.0],\n                [159229.0, 238777.0, 241513.0, 160921.0],\n                [106541.0, 159725.0, 161513.0, 107589.0],\n            ],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv3d_groups() {\n    let test = Conv3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 0,\n        padding_2: 0,\n        padding_3: 0,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 2,\n        depth: 5,\n        height: 5,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [15219., 15570., 15921.],\n                [16974., 17325., 17676.],\n                [18729., 19080., 19431.],\n            ],\n            [\n                [23994., 24345., 24696.],\n                [25749., 26100., 26451.],\n                [27504., 27855., 28206.],\n            ],\n            [\n                [32769., 33120., 33471.],\n                [34524., 34875., 35226.],\n                [36279., 36630., 36981.],\n            ],\n        ],\n        [\n            [\n                [172819., 173899., 174979.],\n                [178219., 179299., 180379.],\n                [183619., 184699., 185779.],\n            ],\n            [\n                [199819., 200899., 201979.],\n                [205219., 206299., 207379.],\n                [210619., 211699., 212779.],\n            ],\n            [\n                [226819., 227899., 228979.],\n                [232219., 233299., 234379.],\n                [237619., 238699., 239779.],\n            ],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv3d_complex() {\n    let test = Conv3dTestCase {\n        batch_size: 2,\n        channels_in: 3,\n        channels_out: 4,\n        kernel_size_1: 4,\n        kernel_size_2: 3,\n        kernel_size_3: 2,\n        padding_1: 1,\n        padding_2: 2,\n        padding_3: 3,\n        stride_1: 2,\n        stride_2: 3,\n        stride_3: 4,\n        dilation_1: 1,\n        dilation_2: 2,\n        dilation_3: 3,\n        groups: 1,\n        depth: 4,\n        height: 5,\n        width: 6,\n    };\n\n    test.assert_output(TestTensor::from([\n        [\n            [\n                [[149148., 299070., 149850.], [147636., 295758., 148050.]],\n                [[150660., 301014., 150282.], [147420., 294246., 146754.]],\n            ],\n            [\n                [[351325., 709903., 358507.], [357589., 722143., 364483.]],\n                [[391717., 789607., 397819.], [396253., 798391., 402067.]],\n            ],\n            [\n                [[553502., 1120736., 567164.], [567542., 1148528., 580916.]],\n                [[632774., 1278200., 645356.], [645086., 1302536., 657380.]],\n            ],\n            [\n                [[755679., 1531569., 775821.], [777495., 1574913., 797349.]],\n                [[873831., 1766793., 892893.], [893919., 1806681., 912693.]],\n            ],\n        ],\n        [\n            [\n                [[408348., 810990., 402570.], [393876., 781758., 387810.]],\n                [[370980., 735174., 364122.], [354780., 702486., 347634.]],\n            ],\n            [\n                [\n                    [1077085., 2154943., 1077787.],\n                    [1070389., 2141263., 1070803.],\n                ],\n                [\n                    [1078597., 2156887., 1078219.],\n                    [1070173., 2139751., 1069507.],\n                ],\n            ],\n            [\n                [\n                    [1745822., 3498896., 1753004.],\n                    [1746902., 3500768., 1753796.],\n                ],\n                [\n                    [1786214., 3578600., 1792316.],\n                    [1785566., 3577016., 1791380.],\n                ],\n            ],\n            [\n                [\n                    [2414559., 4842849., 2428221.],\n                    [2423415., 4860273., 2436789.],\n                ],\n                [\n                    [2493831., 5000313., 2506413.],\n                    [2500959., 5014281., 2513253.],\n                ],\n            ],\n        ],\n    ]));\n}\n\nstruct Conv3dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    kernel_size_3: usize,\n    padding_1: usize,\n    padding_2: usize,\n    padding_3: usize,\n    stride_1: usize,\n    stride_2: usize,\n    stride_3: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    dilation_3: usize,\n    groups: usize,\n    depth: usize,\n    height: usize,\n    width: usize,\n}\n\nimpl Conv3dTestCase {\n    fn assert_output(self, y: TestTensor<5>) {\n        let shape_x = Shape::new([\n            self.batch_size,\n            self.channels_in,\n            self.depth,\n            self.height,\n            self.width,\n        ]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n            self.kernel_size_3,\n        ]);\n        let device = Default::default();\n        let weight = TestTensor::from(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_weight)\n                .into_data(),\n        );\n        let bias = TestTensor::from(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n        );\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_x)\n                .into_data(),\n        );\n        let output = conv3d(\n            x,\n            weight,\n            Some(bias),\n            ConvOptions::new(\n                [self.stride_1, self.stride_2, self.stride_3],\n                [self.padding_1, self.padding_2, self.padding_3],\n                [self.dilation_1, self.dilation_2, self.dilation_3],\n                self.groups,\n            ),\n        );\n\n        let tolerance = Tolerance::relative(1e-5).set_half_precision_relative(2e-3);\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv_transpose1d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::conv_transpose1d;\nuse burn_tensor::ops::ConvTransposeOptions;\n\n#[test]\nfn test_conv_transpose1d_diff_channels() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        padding_out: 0,\n        stride: 1,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [270., 453., 516., 387.],\n        [352., 589., 679., 505.],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose1d_stride() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        padding_out: 1,\n        stride: 2,\n        dilation: 1,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [28., 62., 36., 78., 44., 94., 52., 62.],\n        [41., 93., 55., 121., 69., 149., 83., 93.],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose1d_dilation() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        padding_out: 0,\n        stride: 1,\n        dilation: 2,\n        groups: 1,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [30., 64., 78., 76., 94., 52.],\n        [49., 101., 127., 113., 143., 77.],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose1d_groups() {\n    let test = ConvTranspose1dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size: 3,\n        padding: 1,\n        padding_out: 0,\n        stride: 1,\n        dilation: 1,\n        groups: 2,\n        length: 4,\n    };\n\n    test.assert_output(TestTensor::from_floats(\n        [[[0., 1., 4., 7.], [32., 59., 71., 59.]]],\n        &Default::default(),\n    ));\n}\n\nstruct ConvTranspose1dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size: usize,\n    padding: usize,\n    padding_out: usize,\n    stride: usize,\n    dilation: usize,\n    groups: usize,\n    length: usize,\n}\n\nimpl ConvTranspose1dTestCase {\n    fn assert_output(self, y: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.length]);\n        let shape_weights = Shape::new([\n            self.channels_in,\n            self.channels_out / self.groups,\n            self.kernel_size,\n        ]);\n        let device = Default::default();\n        let weights = TestTensor::from_data(\n            TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_weights)\n                .into_data(),\n            &device,\n        );\n        let bias = TestTensor::from_data(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n            &device,\n        );\n        let x = TestTensor::from_data(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<3, _>(shape_x)\n                .into_data(),\n            &device,\n        );\n        let output = conv_transpose1d(\n            x,\n            weights,\n            Some(bias),\n            ConvTransposeOptions::new(\n                [self.stride],\n                [self.padding],\n                [self.padding_out],\n                [self.dilation],\n                self.groups,\n            ),\n        );\n\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv_transpose2d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::conv_transpose2d;\nuse burn_tensor::ops::ConvTransposeOptions;\n\n#[test]\nfn test_conv_transpose2d_simple_1() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]]));\n}\n\n#[test]\nfn test_conv_transpose2d_simple_2() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [9855., 15207., 15738., 10797.],\n            [16290., 25119., 25956., 17793.],\n            [18486., 28467., 29304., 20061.],\n            [13593., 20913., 21498., 14703.],\n        ],\n        [\n            [11854., 18286., 18979., 13012.],\n            [19612., 30223., 31303., 21439.],\n            [22456., 34543., 35623., 24355.],\n            [16456., 25288., 26035., 17782.],\n        ],\n        [\n            [13853., 21365., 22220., 15227.],\n            [22934., 35327., 36650., 25085.],\n            [26426., 40619., 41942., 28649.],\n            [19319., 29663., 30572., 20861.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose2d_simple_3() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 1,\n        kernel_size_1: 2,\n        kernel_size_2: 2,\n        padding_1: 0,\n        padding_2: 0,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0.0, 0.0, 1.0],\n        [0.0, 4.0, 6.0],\n        [4.0, 12.0, 9.0],\n    ]]]));\n}\n\n#[test]\nfn test_conv_transpose2d_stride_2() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 1,\n        kernel_size_1: 2,\n        kernel_size_2: 2,\n        padding_1: 0,\n        padding_2: 0,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 2,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0.0, 0.0, 0.0, 1.0],\n        [0.0, 0.0, 2.0, 3.0],\n        [0.0, 2.0, 0.0, 3.0],\n        [4.0, 6.0, 6.0, 9.0],\n    ]]]));\n}\n\n#[test]\nfn test_conv_transpose2d_dilation_2() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_out_1: 1,\n        padding_out_2: 1,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 2,\n        dilation_2: 2,\n        groups: 1,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [126., 116., 136., 124., 146.],\n            [108., 88., 114., 92., 120.],\n            [156., 140., 166., 148., 176.],\n            [126., 100., 132., 104., 138.],\n            [186., 164., 196., 172., 206.],\n        ],\n        [\n            [217., 189., 227., 197., 237.],\n            [163., 125., 169., 129., 175.],\n            [247., 213., 257., 221., 267.],\n            [181., 137., 187., 141., 193.],\n            [277., 237., 287., 245., 297.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose2d_stride2_out_padding() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_out_1: 1,\n        padding_out_2: 1,\n        stride_1: 2,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [352., 728., 378., 780., 404., 832., 430., 452.],\n            [784., 1616., 836., 1720., 888., 1824., 940., 992.],\n            [456., 936., 482., 988., 508., 1040., 534., 564.],\n            [992., 2032., 1044., 2136., 1096., 2240., 1148., 1216.],\n            [560., 1144., 586., 1196., 612., 1248., 638., 676.],\n            [1200., 2448., 1252., 2552., 1304., 2656., 1356., 1440.],\n            [664., 1352., 690., 1404., 716., 1456., 742., 788.],\n            [784., 1598., 816., 1662., 848., 1726., 880., 926.],\n        ],\n        [\n            [497., 1035., 541., 1123., 585., 1211., 629., 651.],\n            [1145., 2373., 1233., 2549., 1321., 2725., 1409., 1461.],\n            [673., 1387., 717., 1475., 761., 1563., 805., 835.],\n            [1497., 3077., 1585., 3253., 1673., 3429., 1761., 1829.],\n            [849., 1739., 893., 1827., 937., 1915., 981., 1019.],\n            [1849., 3781., 1937., 3957., 2025., 4133., 2113., 2197.],\n            [1025., 2091., 1069., 2179., 1113., 2267., 1157., 1203.],\n            [1145., 2337., 1195., 2437., 1245., 2537., 1295., 1341.],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose2d_groups_2() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [[5., 11.], [23., 29.]],\n        [[236., 258.], [302., 324.]],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose2d_groups_different_channels() {\n    let test = ConvTranspose2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 6,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        groups: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [0.0000e+00, 0.0000e+00, 1.0000e+00, 2.0000e+00],\n            [0.0000e+00, 5.0000e+00, 1.1000e+01, 1.1000e+01],\n            [6.0000e+00, 2.3000e+01, 2.9000e+01, 2.3000e+01],\n            [1.2000e+01, 3.2000e+01, 3.7000e+01, 2.4000e+01],\n        ],\n        [\n            [1.0000e+00, 1.0000e+01, 1.1000e+01, 1.2000e+01],\n            [1.9000e+01, 6.0000e+01, 6.6000e+01, 4.8000e+01],\n            [2.5000e+01, 7.8000e+01, 8.4000e+01, 6.0000e+01],\n            [3.1000e+01, 7.8000e+01, 8.3000e+01, 5.2000e+01],\n        ],\n        [\n            [2.0000e+00, 2.0000e+01, 2.1000e+01, 2.2000e+01],\n            [3.8000e+01, 1.1500e+02, 1.2100e+02, 8.5000e+01],\n            [4.4000e+01, 1.3300e+02, 1.3900e+02, 9.7000e+01],\n            [5.0000e+01, 1.2400e+02, 1.2900e+02, 8.0000e+01],\n        ],\n        [\n            [1.1100e+02, 2.5000e+02, 2.5900e+02, 1.4800e+02],\n            [2.8500e+02, 6.3400e+02, 6.5600e+02, 3.6600e+02],\n            [3.1500e+02, 7.0000e+02, 7.2200e+02, 4.0200e+02],\n            [2.0100e+02, 4.3800e+02, 4.5100e+02, 2.4800e+02],\n        ],\n        [\n            [1.4800e+02, 3.3200e+02, 3.4100e+02, 1.9400e+02],\n            [3.7600e+02, 8.3300e+02, 8.5500e+02, 4.7500e+02],\n            [4.0600e+02, 8.9900e+02, 9.2100e+02, 5.1100e+02],\n            [2.5600e+02, 5.5600e+02, 5.6900e+02, 3.1200e+02],\n        ],\n        [\n            [1.8500e+02, 4.1400e+02, 4.2300e+02, 2.4000e+02],\n            [4.6700e+02, 1.0320e+03, 1.0540e+03, 5.8400e+02],\n            [4.9700e+02, 1.0980e+03, 1.1200e+03, 6.2000e+02],\n            [3.1100e+02, 6.7400e+02, 6.8700e+02, 3.7600e+02],\n        ],\n    ]]));\n}\n\nstruct ConvTranspose2dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    padding_out_1: usize,\n    padding_out_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    groups: usize,\n    height: usize,\n    width: usize,\n}\n\nimpl ConvTranspose2dTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let shape_weights = Shape::new([\n            self.channels_in,\n            self.channels_out / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n        ]);\n        let device = Default::default();\n        let weights = TestTensor::from(\n            TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weights)\n                .into_data(),\n        );\n        let bias = TestTensor::from(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n        );\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = conv_transpose2d(\n            x,\n            weights,\n            Some(bias),\n            ConvTransposeOptions::new(\n                [self.stride_1, self.stride_2],\n                [self.padding_1, self.padding_2],\n                [self.padding_out_1, self.padding_out_2],\n                [self.dilation_1, self.dilation_2],\n                self.groups,\n            ),\n        );\n\n        y.into_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::rel_abs(1e-1, 0.01));\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/conv_transpose3d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::conv_transpose3d;\nuse burn_tensor::ops::ConvTransposeOptions;\n\n#[test]\nfn test_conv_transpose3d_simple_1() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 1,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        padding_out_3: 0,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [[96., 124.], [180., 208.]],\n        [[348., 376.], [432., 460.]],\n    ]]]));\n}\n#[test]\nfn test_conv_transpose3d_simple_2() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        padding_out_3: 0,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 4,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [238452., 360588., 363756., 244488.],\n                [367929., 556353., 561186., 377163.],\n                [380745., 575685., 580518., 390123.],\n                [261192., 394896., 398172., 267564.],\n            ],\n            [\n                [394083., 595827., 600822., 403749.],\n                [607635., 918648., 926262., 622404.],\n                [627831., 949104., 956718., 642816.],\n                [430353., 650529., 655686., 440523.],\n            ],\n            [\n                [447075., 675747., 680742., 457317.],\n                [688419., 1040472., 1048086., 704052.],\n                [708615., 1070928., 1078542., 724464.],\n                [485073., 733041., 738198., 495819.],\n            ],\n            [\n                [328656., 496632., 500124., 335892.],\n                [505611., 763983., 769302., 516645.],\n                [519723., 785259., 790578., 530901.],\n                [355428., 536988., 540588., 363000.],\n            ],\n        ],\n        [\n            [\n                [286729., 433489., 437629., 294061.],\n                [442288., 668620., 674911., 453466.],\n                [458992., 693784., 700075., 470314.],\n                [314653., 475573., 479821., 322321.],\n            ],\n            [\n                [474274., 716842., 723295., 485884.],\n                [730837., 1104544., 1114345., 748522.],\n                [756865., 1143748., 1153549., 774766.],\n                [518320., 783208., 789823., 530434.],\n            ],\n            [\n                [542818., 820090., 826543., 555004.],\n                [834949., 1261360., 1271161., 853498.],\n                [860977., 1300564., 1310365., 879742.],\n                [588592., 889048., 895663., 601282.],\n            ],\n            [\n                [397669., 600637., 605101., 406201.],\n                [611074., 922906., 929683., 624052.],\n                [629074., 950014., 956791., 642196.],\n                [429625., 648769., 653341., 438493.],\n            ],\n        ],\n        [\n            [\n                [335006., 506390., 511502., 343634.],\n                [516647., 780887., 788636., 529769.],\n                [537239., 811883., 819632., 550505.],\n                [368114., 556250., 561470., 377078.],\n            ],\n            [\n                [554465., 837857., 845768., 568019.],\n                [854039., 1290440., 1302428., 874640.],\n                [885899., 1338392., 1350380., 906716.],\n                [606287., 915887., 923960., 620345.],\n            ],\n            [\n                [638561., 964433., 972344., 652691.],\n                [981479., 1482248., 1494236., 1002944.],\n                [1013339., 1530200., 1542188., 1035020.],\n                [692111., 1045055., 1053128., 706745.],\n            ],\n            [\n                [466682., 704642., 710078., 476510.],\n                [716537., 1081829., 1090064., 731459.],\n                [738425., 1114769., 1123004., 753491.],\n                [503822., 760550., 766094., 513986.],\n            ],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose3d_stride_2() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 1,\n        channels_out: 1,\n        kernel_size_1: 2,\n        kernel_size_2: 2,\n        kernel_size_3: 2,\n        padding_1: 0,\n        padding_2: 0,\n        padding_3: 0,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        padding_out_3: 0,\n        stride_1: 2,\n        stride_2: 2,\n        stride_3: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [\n            [0., 0., 0., 1.],\n            [0., 0., 2., 3.],\n            [0., 2., 0., 3.],\n            [4., 6., 6., 9.],\n        ],\n        [\n            [0., 0., 4., 5.],\n            [0., 0., 6., 7.],\n            [8., 10., 12., 15.],\n            [12., 14., 18., 21.],\n        ],\n        [\n            [0., 4., 0., 5.],\n            [8., 12., 10., 15.],\n            [0., 6., 0., 7.],\n            [12., 18., 14., 21.],\n        ],\n        [\n            [16., 20., 20., 25.],\n            [24., 28., 30., 35.],\n            [24., 30., 28., 35.],\n            [36., 42., 42., 49.],\n        ],\n    ]]]));\n}\n\n#[test]\nfn test_conv_transpose3d_dilation_2() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        padding_out_1: 1,\n        padding_out_2: 1,\n        padding_out_3: 1,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 2,\n        dilation_2: 2,\n        dilation_3: 2,\n        groups: 1,\n        depth: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [810., 776., 832., 796., 854.],\n                [756., 712., 774., 728., 792.],\n                [876., 836., 898., 856., 920.],\n                [810., 760., 828., 776., 846.],\n                [942., 896., 964., 916., 986.],\n            ],\n            [\n                [720., 660., 734., 672., 748.],\n                [606., 536., 616., 544., 626.],\n                [762., 696., 776., 708., 790.],\n                [636., 560., 646., 568., 656.],\n                [804., 732., 818., 744., 832.],\n            ],\n            [\n                [1008., 956., 1030., 976., 1052.],\n                [918., 856., 936., 872., 954.],\n                [1074., 1016., 1096., 1036., 1118.],\n                [972., 904., 990., 920., 1008.],\n                [1140., 1076., 1162., 1096., 1184.],\n            ],\n            [\n                [846., 768., 860., 780., 874.],\n                [696., 608., 706., 616., 716.],\n                [888., 804., 902., 816., 916.],\n                [726., 632., 736., 640., 746.],\n                [930., 840., 944., 852., 958.],\n            ],\n            [\n                [1206., 1136., 1228., 1156., 1250.],\n                [1080., 1000., 1098., 1016., 1116.],\n                [1272., 1196., 1294., 1216., 1316.],\n                [1134., 1048., 1152., 1064., 1170.],\n                [1338., 1256., 1360., 1276., 1382.],\n            ],\n        ],\n        [\n            [\n                [1405., 1317., 1427., 1337., 1449.],\n                [1243., 1145., 1261., 1161., 1279.],\n                [1471., 1377., 1493., 1397., 1515.],\n                [1297., 1193., 1315., 1209., 1333.],\n                [1537., 1437., 1559., 1457., 1581.],\n            ],\n            [\n                [1099., 985., 1113., 997., 1127.],\n                [877., 753., 887., 761., 897.],\n                [1141., 1021., 1155., 1033., 1169.],\n                [907., 777., 917., 785., 927.],\n                [1183., 1057., 1197., 1069., 1211.],\n            ],\n            [\n                [1603., 1497., 1625., 1517., 1647.],\n                [1405., 1289., 1423., 1305., 1441.],\n                [1669., 1557., 1691., 1577., 1713.],\n                [1459., 1337., 1477., 1353., 1495.],\n                [1735., 1617., 1757., 1637., 1779.],\n            ],\n            [\n                [1225., 1093., 1239., 1105., 1253.],\n                [967., 825., 977., 833., 987.],\n                [1267., 1129., 1281., 1141., 1295.],\n                [997., 849., 1007., 857., 1017.],\n                [1309., 1165., 1323., 1177., 1337.],\n            ],\n            [\n                [1801., 1677., 1823., 1697., 1845.],\n                [1567., 1433., 1585., 1449., 1603.],\n                [1867., 1737., 1889., 1757., 1911.],\n                [1621., 1481., 1639., 1497., 1657.],\n                [1933., 1797., 1955., 1817., 1977.],\n            ],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose3d_stride2_out_padding() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        padding_out_1: 1,\n        padding_out_2: 1,\n        padding_out_3: 1,\n        stride_1: 2,\n        stride_2: 2,\n        stride_3: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 1,\n        depth: 2,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [2144., 4366., 2224., 4526., 2304., 4686., 2384., 2422.],\n                [4584., 9324., 4744., 9644., 4904., 9964., 5064., 5148.],\n                [2464., 5006., 2544., 5166., 2624., 5326., 2704., 2750.],\n                [5224., 10604., 5384., 10924., 5544., 11244., 5704., 5804.],\n                [2784., 5646., 2864., 5806., 2944., 5966., 3024., 3078.],\n                [5864., 11884., 6024., 12204., 6184., 12524., 6344., 6460.],\n                [3104., 6286., 3184., 6446., 3264., 6606., 3344., 3406.],\n                [3272., 6628., 3358., 6800., 3444., 6972., 3530., 3592.],\n            ],\n            [\n                [5280., 10716., 5440., 11036., 5600., 11356., 5760., 5868.],\n                [\n                    11152., 22616., 11472., 23256., 11792., 23896., 12112., 12344.,\n                ],\n                [5920., 11996., 6080., 12316., 6240., 12636., 6400., 6524.],\n                [\n                    12432., 25176., 12752., 25816., 13072., 26456., 13392., 13656.,\n                ],\n                [6560., 13276., 6720., 13596., 6880., 13916., 7040., 7180.],\n                [\n                    13712., 27736., 14032., 28376., 14352., 29016., 14672., 14968.,\n                ],\n                [7200., 14556., 7360., 14876., 7520., 15196., 7680., 7836.],\n                [7632., 15432., 7804., 15776., 7976., 16120., 8148., 8304.],\n            ],\n            [\n                [3424., 6926., 3504., 7086., 3584., 7246., 3664., 3734.],\n                [7144., 14444., 7304., 14764., 7464., 15084., 7624., 7772.],\n                [3744., 7566., 3824., 7726., 3904., 7886., 3984., 4062.],\n                [7784., 15724., 7944., 16044., 8104., 16364., 8264., 8428.],\n                [4064., 8206., 4144., 8366., 4224., 8526., 4304., 4390.],\n                [8424., 17004., 8584., 17324., 8744., 17644., 8904., 9084.],\n                [4384., 8846., 4464., 9006., 4544., 9166., 4624., 4718.],\n                [4648., 9380., 4734., 9552., 4820., 9724., 4906., 5000.],\n            ],\n            [\n                [4000., 8096., 4098., 8292., 4196., 8488., 4294., 4364.],\n                [8368., 16928., 8564., 17320., 8760., 17712., 8956., 9104.],\n                [4392., 8880., 4490., 9076., 4588., 9272., 4686., 4764.],\n                [9152., 18496., 9348., 18888., 9544., 19280., 9740., 9904.],\n                [4784., 9664., 4882., 9860., 4980., 10056., 5078., 5164.],\n                [\n                    9936., 20064., 10132., 20456., 10328., 20848., 10524., 10704.,\n                ],\n                [5176., 10448., 5274., 10644., 5372., 10840., 5470., 5564.],\n                [5440., 10982., 5544., 11190., 5648., 11398., 5752., 5846.],\n            ],\n        ],\n        [\n            [\n                [3009., 6149., 3143., 6417., 3277., 6685., 3411., 3449.],\n                [6529., 13321., 6797., 13857., 7065., 14393., 7333., 7417.],\n                [3545., 7221., 3679., 7489., 3813., 7757., 3947., 3993.],\n                [7601., 15465., 7869., 16001., 8137., 16537., 8405., 8505.],\n                [4081., 8293., 4215., 8561., 4349., 8829., 4483., 4537.],\n                [8673., 17609., 8941., 18145., 9209., 18681., 9477., 9593.],\n                [4617., 9365., 4751., 9633., 4885., 9901., 5019., 5081.],\n                [4785., 9707., 4925., 9987., 5065., 10267., 5205., 5267.],\n            ],\n            [\n                [7873., 16009., 8141., 16545., 8409., 17081., 8677., 8785.],\n                [\n                    16769., 34065., 17305., 35137., 17841., 36209., 18377., 18609.,\n                ],\n                [8945., 18153., 9213., 18689., 9481., 19225., 9749., 9873.],\n                [\n                    18913., 38353., 19449., 39425., 19985., 40497., 20521., 20785.,\n                ],\n                [\n                    10017., 20297., 10285., 20833., 10553., 21369., 10821., 10961.,\n                ],\n                [\n                    21057., 42641., 21593., 43713., 22129., 44785., 22665., 22961.,\n                ],\n                [\n                    11089., 22441., 11357., 22977., 11625., 23513., 11893., 12049.,\n                ],\n                [\n                    11521., 23317., 11801., 23877., 12081., 24437., 12361., 12517.,\n                ],\n            ],\n            [\n                [5153., 10437., 5287., 10705., 5421., 10973., 5555., 5625.],\n                [\n                    10817., 21897., 11085., 22433., 11353., 22969., 11621., 11769.,\n                ],\n                [5689., 11509., 5823., 11777., 5957., 12045., 6091., 6169.],\n                [\n                    11889., 24041., 12157., 24577., 12425., 25113., 12693., 12857.,\n                ],\n                [6225., 12581., 6359., 12849., 6493., 13117., 6627., 6713.],\n                [\n                    12961., 26185., 13229., 26721., 13497., 27257., 13765., 13945.,\n                ],\n                [6761., 13653., 6895., 13921., 7029., 14189., 7163., 7257.],\n                [7025., 14187., 7165., 14467., 7305., 14747., 7445., 7539.],\n            ],\n            [\n                [5729., 11607., 5881., 11911., 6033., 12215., 6185., 6255.],\n                [\n                    12041., 24381., 12345., 24989., 12649., 25597., 12953., 13101.,\n                ],\n                [6337., 12823., 6489., 13127., 6641., 13431., 6793., 6871.],\n                [\n                    13257., 26813., 13561., 27421., 13865., 28029., 14169., 14333.,\n                ],\n                [6945., 14039., 7097., 14343., 7249., 14647., 7401., 7487.],\n                [\n                    14473., 29245., 14777., 29853., 15081., 30461., 15385., 15565.,\n                ],\n                [7553., 15255., 7705., 15559., 7857., 15863., 8009., 8103.],\n                [7817., 15789., 7975., 16105., 8133., 16421., 8291., 8385.],\n            ],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose3d_groups_2() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 2,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 1,\n        padding_2: 1,\n        padding_3: 1,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        padding_out_3: 0,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 2,\n        depth: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [[[96., 124.], [180., 208.]], [[348., 376.], [432., 460.]]],\n        [\n            [[2997., 3089.], [3273., 3365.]],\n            [[3825., 3917.], [4101., 4193.]],\n        ],\n    ]]));\n}\n\n#[test]\nfn test_conv_transpose3d_groups_different_channels() {\n    let test = ConvTranspose3dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 6,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        kernel_size_3: 3,\n        padding_1: 0,\n        padding_2: 0,\n        padding_3: 0,\n        padding_out_1: 0,\n        padding_out_2: 0,\n        padding_out_3: 0,\n        stride_1: 1,\n        stride_2: 1,\n        stride_3: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        dilation_3: 1,\n        groups: 2,\n        depth: 2,\n        height: 2,\n        width: 2,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [\n            [\n                [0., 0., 1., 2.],\n                [0., 5., 11., 11.],\n                [6., 23., 29., 23.],\n                [12., 32., 37., 24.],\n            ],\n            [\n                [0., 13., 23., 21.],\n                [30., 96., 124., 86.],\n                [66., 180., 208., 134.],\n                [66., 161., 179., 107.],\n            ],\n            [\n                [36., 103., 113., 75.],\n                [138., 348., 376., 230.],\n                [174., 432., 460., 278.],\n                [138., 323., 341., 197.],\n            ],\n            [\n                [72., 166., 175., 100.],\n                [192., 433., 455., 255.],\n                [222., 499., 521., 291.],\n                [144., 318., 331., 182.],\n            ],\n        ],\n        [\n            [\n                [1., 28., 29., 30.],\n                [55., 168., 174., 120.],\n                [61., 186., 192., 132.],\n                [67., 168., 173., 106.],\n            ],\n            [\n                [109., 284., 294., 184.],\n                [355., 853., 881., 519.],\n                [391., 937., 965., 567.],\n                [283., 648., 666., 378.],\n            ],\n            [\n                [145., 374., 384., 238.],\n                [463., 1105., 1133., 663.],\n                [499., 1189., 1217., 711.],\n                [355., 810., 828., 468.],\n            ],\n            [\n                [181., 410., 419., 236.],\n                [463., 1028., 1050., 580.],\n                [493., 1094., 1116., 616.],\n                [307., 670., 683., 372.],\n            ],\n        ],\n        [\n            [\n                [2., 56., 57., 58.],\n                [110., 331., 337., 229.],\n                [116., 349., 355., 241.],\n                [122., 304., 309., 188.],\n            ],\n            [\n                [218., 555., 565., 347.],\n                [680., 1610., 1638., 952.],\n                [716., 1694., 1722., 1000.],\n                [500., 1135., 1153., 649.],\n            ],\n            [\n                [254., 645., 655., 401.],\n                [788., 1862., 1890., 1096.],\n                [824., 1946., 1974., 1144.],\n                [572., 1297., 1315., 739.],\n            ],\n            [\n                [290., 654., 663., 372.],\n                [734., 1623., 1645., 905.],\n                [764., 1689., 1711., 941.],\n                [470., 1022., 1035., 562.],\n            ],\n        ],\n        [\n            [\n                [651., 1388., 1405., 750.],\n                [1485., 3150., 3188., 1690.],\n                [1539., 3264., 3302., 1750.],\n                [873., 1840., 1861., 982.],\n            ],\n            [\n                [1695., 3578., 3620., 1910.],\n                [3789., 7967., 8059., 4233.],\n                [3921., 8243., 8335., 4377.],\n                [2181., 4566., 4616., 2416.],\n            ],\n            [\n                [1875., 3956., 3998., 2108.],\n                [4185., 8795., 8887., 4665.],\n                [4317., 9071., 9163., 4809.],\n                [2397., 5016., 5066., 2650.],\n            ],\n            [\n                [1191., 2490., 2515., 1316.],\n                [2613., 5450., 5504., 2870.],\n                [2691., 5612., 5666., 2954.],\n                [1473., 3062., 3091., 1608.],\n            ],\n        ],\n        [\n            [\n                [868., 1848., 1865., 994.],\n                [1972., 4177., 4215., 2231.],\n                [2026., 4291., 4329., 2291.],\n                [1144., 2408., 2429., 1280.],\n            ],\n            [\n                [2236., 4713., 4755., 2505.],\n                [4978., 10452., 10544., 5530.],\n                [5110., 10728., 10820., 5674.],\n                [2830., 5917., 5967., 3119.],\n            ],\n            [\n                [2416., 5091., 5133., 2703.],\n                [5374., 11280., 11372., 5962.],\n                [5506., 11556., 11648., 6106.],\n                [3046., 6367., 6417., 3353.],\n            ],\n            [\n                [1516., 3166., 3191., 1668.],\n                [3316., 6909., 6963., 3627.],\n                [3394., 7071., 7125., 3711.],\n                [1852., 3846., 3875., 2014.],\n            ],\n        ],\n        [\n            [\n                [1085., 2308., 2325., 1238.],\n                [2459., 5204., 5242., 2772.],\n                [2513., 5318., 5356., 2832.],\n                [1415., 2976., 2997., 1578.],\n            ],\n            [\n                [2777., 5848., 5890., 3100.],\n                [6167., 12937., 13029., 6827.],\n                [6299., 13213., 13305., 6971.],\n                [3479., 7268., 7318., 3822.],\n            ],\n            [\n                [2957., 6226., 6268., 3298.],\n                [6563., 13765., 13857., 7259.],\n                [6695., 14041., 14133., 7403.],\n                [3695., 7718., 7768., 4056.],\n            ],\n            [\n                [1841., 3842., 3867., 2020.],\n                [4019., 8368., 8422., 4384.],\n                [4097., 8530., 8584., 4468.],\n                [2231., 4630., 4659., 2420.],\n            ],\n        ],\n    ]]));\n}\n\nstruct ConvTranspose3dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    kernel_size_3: usize,\n    padding_1: usize,\n    padding_2: usize,\n    padding_3: usize,\n    padding_out_1: usize,\n    padding_out_2: usize,\n    padding_out_3: usize,\n    stride_1: usize,\n    stride_2: usize,\n    stride_3: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    dilation_3: usize,\n    groups: usize,\n    depth: usize,\n    height: usize,\n    width: usize,\n}\n\nimpl ConvTranspose3dTestCase {\n    fn assert_output(self, y: TestTensor<5>) {\n        let shape_x = Shape::new([\n            self.batch_size,\n            self.channels_in,\n            self.depth,\n            self.height,\n            self.width,\n        ]);\n        let shape_weights = Shape::new([\n            self.channels_in,\n            self.channels_out / self.groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n            self.kernel_size_3,\n        ]);\n        let device = Default::default();\n        let weights = TestTensor::from(\n            TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_weights)\n                .into_data(),\n        );\n        let bias = TestTensor::from(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n        );\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<5, _>(shape_x)\n                .into_data(),\n        );\n        let output = conv_transpose3d(\n            x,\n            weights,\n            Some(bias),\n            ConvTransposeOptions::new(\n                [self.stride_1, self.stride_2, self.stride_3],\n                [self.padding_1, self.padding_2, self.padding_3],\n                [self.padding_out_1, self.padding_out_2, self.padding_out_3],\n                [self.dilation_1, self.dilation_2, self.dilation_3],\n                self.groups,\n            ),\n        );\n\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/deform_conv2d.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::deform_conv2d;\nuse burn_tensor::ops::DeformConvOptions;\nuse burn_tensor::{Shape, Tensor};\n\n#[test]\nfn test_deform_conv2d_simple() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 5,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[0.9074, 0.6387], [0.5160, 0.4196]],\n        [[2.4259, 1.8008], [1.5449, 1.3112]],\n        [[3.9444, 2.9629], [2.5738, 2.2027]],\n        [[5.4629, 4.1250], [3.6027, 3.0943]],\n        [[6.9814, 5.2871], [4.6316, 3.9859]],\n    ]]));\n}\n\n#[test]\nfn test_deform_conv2d_batched() {\n    let test = DeformConv2dTestCase {\n        batch_size: 2,\n        channels_in: 3,\n        channels_out: 5,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([\n        [\n            [[0.215466, 0.192846], [0.193407, 0.175496]],\n            [[0.725073, 0.675926], [0.687746, 0.648506]],\n            [[1.234679, 1.159006], [1.182085, 1.121516]],\n            [[1.744286, 1.642086], [1.676423, 1.594526]],\n            [[2.253892, 2.125167], [2.170762, 2.067536]],\n        ],\n        [\n            [[1.652976, 1.136937], [0.984030, 0.718403]],\n            [[4.836801, 3.472453], [3.177263, 2.418021]],\n            [[8.020626, 5.807969], [5.370497, 4.117639]],\n            [[11.204453, 8.143486], [7.563731, 5.817256]],\n            [[14.388277, 10.479003], [9.756965, 7.516875]],\n        ],\n    ]))\n}\n\n#[test]\nfn test_deform_conv2d_weight_groups() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 6,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 3,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[0.101823, 0.065756], [0.046691, 0.036233]],\n        [[0.412523, 0.336674], [0.306863, 0.282386]],\n        [[1.307585, 1.024152], [0.902454, 0.800008]],\n        [[1.840507, 1.458072], [1.299371, 1.158781]],\n        [[3.402235, 2.634555], [2.305198, 2.014265]],\n        [[4.157379, 3.231476], [2.838861, 2.485659]],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_offset_groups() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 3,\n        channels_out: 6,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 3,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[1.0794, 0.7676], [0.7209, 0.5337]],\n        [[2.7059, 2.0216], [1.9740, 1.5419]],\n        [[4.3325, 3.2755], [3.2271, 2.5501]],\n        [[5.9590, 4.5295], [4.4802, 3.5582]],\n        [[7.5855, 5.7835], [5.7333, 4.5664]],\n        [[9.2120, 7.0375], [6.9864, 5.5746]],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_different_kernel_size() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 4,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[1.0669], [0.6329]],\n        [[2.9741], [2.0383]],\n        [[4.8812], [3.4437]],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_different_padding_size() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 2,\n        padding_2: 3,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [\n            [\n                0.199779, 0.376176, 0.528501, 0.605256, 0.384365, 0.198675, 0.048145, 0.000000,\n            ],\n            [\n                0.287923, 0.551719, 0.777562, 0.890479, 0.580469, 0.304325, 0.079554, 0.000000,\n            ],\n            [\n                0.372947, 0.721405, 1.013668, 1.151988, 0.756444, 0.393098, 0.101582, 0.000000,\n            ],\n            [\n                0.132138, 0.324872, 0.495372, 0.584617, 0.453122, 0.250084, 0.075703, 0.000000,\n            ],\n            [\n                0.059332, 0.160658, 0.244789, 0.297057, 0.239464, 0.132701, 0.047114, 0.000000,\n            ],\n            [\n                0.014338, 0.051338, 0.078303, 0.094190, 0.081278, 0.041954, 0.014506, 0.000000,\n            ],\n        ],\n        [\n            [\n                0.766652, 1.164805, 1.521938, 1.711110, 1.230500, 0.807579, 0.450423, 0.333333,\n            ],\n            [\n                0.981162, 1.601005, 2.152534, 2.440920, 1.745547, 1.091843, 0.536749, 0.333333,\n            ],\n            [\n                1.196386, 2.044845, 2.785330, 3.152243, 2.242613, 1.351308, 0.604905, 0.333333,\n            ],\n            [\n                0.669465, 1.178133, 1.644096, 1.902188, 1.573183, 1.033924, 0.553577, 0.333333,\n            ],\n            [\n                0.495048, 0.786124, 1.039796, 1.204721, 1.052342, 0.743887, 0.483380, 0.333333,\n            ],\n            [\n                0.378767, 0.498209, 0.592867, 0.654230, 0.615487, 0.488202, 0.390890, 0.333333,\n            ],\n        ],\n        [\n            [\n                1.333524, 1.953435, 2.515375, 2.816964, 2.076636, 1.416483, 0.852701, 0.666667,\n            ],\n            [\n                1.674402, 2.650291, 3.527507, 3.991360, 2.910625, 1.879361, 0.993943, 0.666667,\n            ],\n            [\n                2.019825, 3.368286, 4.556992, 5.152499, 3.728782, 2.309520, 1.108229, 0.666667,\n            ],\n            [\n                1.206791, 2.031395, 2.792820, 3.219759, 2.693245, 1.817763, 1.031452, 0.666667,\n            ],\n            [\n                0.930765, 1.411590, 1.834802, 2.112385, 1.865221, 1.355072, 0.919646, 0.666667,\n            ],\n            [\n                0.743195, 0.945081, 1.107431, 1.214270, 1.149695, 0.934451, 0.767274, 0.666667,\n            ],\n        ],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_different_stride() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 2,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[1.0647], [0.5783]],\n        [[2.9289], [1.8829]],\n        [[4.7931], [3.1875]],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_different_dilation() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 2,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 5,\n        width: 5,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [[0.6162], [0.7611], [0.4666]],\n        [[1.8578], [2.2684], [1.6208]],\n        [[3.0994], [3.7757], [2.7749]],\n    ]]))\n}\n\n#[test]\nfn test_deform_conv2d_different_width() {\n    let test = DeformConv2dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        channels_out: 3,\n        kernel_size_1: 3,\n        kernel_size_2: 3,\n        padding_1: 0,\n        padding_2: 0,\n        stride_1: 1,\n        stride_2: 1,\n        dilation_1: 1,\n        dilation_2: 1,\n        weight_groups: 1,\n        offset_groups: 1,\n        height: 6,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::<4>::from([[\n        [\n            [0.8909, 0.6016],\n            [1.0697, 0.7186],\n            [1.2618, 0.8433],\n            [0.6424, 0.5032],\n        ],\n        [\n            [2.4670, 1.8168],\n            [2.9529, 2.1497],\n            [3.4805, 2.5090],\n            [2.0925, 1.7411],\n        ],\n        [\n            [4.0432, 3.0321],\n            [4.8362, 3.5809],\n            [5.6992, 4.1746],\n            [3.5425, 2.9790],\n        ],\n    ]]))\n}\n\nstruct DeformConv2dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    channels_out: usize,\n    kernel_size_1: usize,\n    kernel_size_2: usize,\n    padding_1: usize,\n    padding_2: usize,\n    stride_1: usize,\n    stride_2: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    weight_groups: usize,\n    offset_groups: usize,\n    height: usize,\n    width: usize,\n}\n\nimpl DeformConv2dTestCase {\n    fn assert_output(self, y: Tensor<TestBackend, 4>) {\n        let out_height =\n            (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)\n                / self.stride_1\n                + 1;\n        let out_width =\n            (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)\n                / self.stride_2\n                + 1;\n\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let shape_weight = Shape::new([\n            self.channels_out,\n            self.channels_in / self.weight_groups,\n            self.kernel_size_1,\n            self.kernel_size_2,\n        ]);\n        let shape_offset = Shape::new([\n            self.batch_size,\n            self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2,\n            out_height,\n            out_width,\n        ]);\n        let shape_mask = Shape::new([\n            self.batch_size,\n            self.kernel_size_1 * self.kernel_size_2 * self.offset_groups,\n            out_height,\n            out_width,\n        ]);\n        let device = Default::default();\n        let weight = TestTensor::<4>::from(\n            TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_weight.clone())\n                .into_data(),\n        )\n        .div_scalar(shape_weight.num_elements() as f32);\n        let bias = TestTensor::<1>::from(\n            TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),\n        )\n        .div_scalar(self.channels_out as f32);\n        let x = TestTensor::<4>::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_x.clone())\n                .into_data(),\n        )\n        .div_scalar(shape_x.num_elements() as f32);\n        let offset = TestTensor::<4>::from(\n            TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_offset.clone())\n                .into_data(),\n        )\n        .div_scalar(shape_offset.num_elements() as f32);\n        let mask = TestTensor::<4>::from(\n            TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)\n                .reshape::<4, _>(shape_mask.clone())\n                .into_data(),\n        )\n        .div_scalar(shape_mask.num_elements() as f32);\n\n        let output = deform_conv2d(\n            x,\n            offset,\n            weight,\n            Some(mask),\n            Some(bias),\n            DeformConvOptions::new(\n                [self.stride_1, self.stride_2],\n                [self.padding_1, self.padding_2],\n                [self.dilation_1, self.dilation_2],\n                self.weight_groups,\n                self.offset_groups,\n            ),\n        );\n\n        let tolerance = Tolerance::permissive();\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/forward.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, module::embedding};\n\n#[test]\nfn test_embedding_forward() {\n    let weights = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TensorData::from([[0, 1], [1, 1]]);\n    let weights = TestTensor::<2>::from(weights);\n    let indices = TestTensorInt::<2>::from(indices);\n\n    let output = embedding(weights, indices);\n    let expected = TensorData::from([\n        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n        [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/lanczos3_interpolate.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::interpolate;\nuse burn_tensor::ops::{InterpolateMode, InterpolateOptions};\n\n#[test]\nfn test_upsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 2,\n        channels: 1,\n        height: 7,\n        width: 5,\n        height_out: 8,\n        width_out: 7,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[\n            [-0.0000, 0.5685, 1.3918, 2.0000, 2.6082, 3.4315, 4.0000],\n            [4.0822, 4.6507, 5.4740, 6.0822, 6.6904, 7.5137, 8.0822],\n            [8.7971, 9.3656, 10.1889, 10.7971, 11.4053, 12.2286, 12.7971],\n            [\n                12.8964, 13.4649, 14.2882, 14.8964, 15.5046, 16.3279, 16.8964,\n            ],\n            [\n                17.1036, 17.6721, 18.4954, 19.1036, 19.7118, 20.5351, 21.1036,\n            ],\n            [\n                21.2029, 21.7715, 22.5947, 23.2029, 23.8112, 24.6344, 25.2029,\n            ],\n            [\n                25.9178, 26.4863, 27.3096, 27.9178, 28.5260, 29.3493, 29.9178,\n            ],\n            [\n                30.0000, 30.5685, 31.3918, 32.0000, 32.6082, 33.4315, 34.0000,\n            ],\n        ]],\n        [[\n            [\n                35.0000, 35.5685, 36.3918, 37.0000, 37.6082, 38.4315, 39.0000,\n            ],\n            [\n                39.0822, 39.6507, 40.4740, 41.0822, 41.6904, 42.5137, 43.0822,\n            ],\n            [\n                43.7971, 44.3656, 45.1888, 45.7971, 46.4053, 47.2286, 47.7971,\n            ],\n            [\n                47.8964, 48.4649, 49.2882, 49.8964, 50.5046, 51.3279, 51.8964,\n            ],\n            [\n                52.1036, 52.6721, 53.4954, 54.1036, 54.7118, 55.5351, 56.1036,\n            ],\n            [\n                56.2029, 56.7715, 57.5947, 58.2029, 58.8112, 59.6344, 60.2029,\n            ],\n            [\n                60.9178, 61.4863, 62.3096, 62.9178, 63.5260, 64.3493, 64.9178,\n            ],\n            [\n                65.0000, 65.5685, 66.3918, 67.0000, 67.6082, 68.4315, 69.0000,\n            ],\n        ]],\n    ]));\n}\n\n#[test]\nfn test_downsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 45,\n        width: 14,\n        height_out: 4,\n        width_out: 6,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [-0.0000, 2.6107, 5.1803, 7.8197, 10.3893, 13.0000],\n        [205.5606, 208.1713, 210.7408, 213.3802, 215.9498, 218.5606],\n        [410.4395, 413.0502, 415.6198, 418.2592, 420.8287, 423.4395],\n        [616.0000, 618.6107, 621.1803, 623.8197, 626.3893, 629.0000],\n    ]]]));\n}\n\n#[test]\nfn test_upsample_2x() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 4,\n        width: 4,\n        height_out: 8,\n        width_out: 8,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [\n            -0.0000, 0.2972, 0.8164, 1.3131, 1.6869, 2.1836, 2.7028, 3.0000,\n        ],\n        [\n            1.1889, 1.4861, 2.0053, 2.5020, 2.8758, 3.3725, 3.8917, 4.1889,\n        ],\n        [\n            3.2658, 3.5630, 4.0822, 4.5789, 4.9527, 5.4493, 5.9685, 6.2658,\n        ],\n        [\n            5.2524, 5.5496, 6.0689, 6.5655, 6.9393, 7.4360, 7.9552, 8.2524,\n        ],\n        [\n            6.7476, 7.0448, 7.5640, 8.0607, 8.4345, 8.9311, 9.4504, 9.7476,\n        ],\n        [\n            8.7342, 9.0315, 9.5507, 10.0473, 10.4211, 10.9178, 11.4370, 11.7342,\n        ],\n        [\n            10.8111, 11.1083, 11.6275, 12.1242, 12.4980, 12.9947, 13.5139, 13.8111,\n        ],\n        [\n            12.0000, 12.2972, 12.8164, 13.3131, 13.6869, 14.1836, 14.7028, 15.0000,\n        ],\n    ]]]));\n}\n\n#[test]\nfn test_upsample_half_pixel() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 4,\n        width: 4,\n        height_out: 8,\n        width_out: 8,\n    };\n\n    test.assert_output_with_align_corners(\n        TestTensor::from([[[\n            [\n                -0.4626, -0.2276, 0.3055, 0.9087, 1.3512, 1.9543, 2.4875, 2.7225,\n            ],\n            [\n                0.4773, 0.7123, 1.2454, 1.8486, 2.2911, 2.8942, 3.4274, 3.6623,\n            ],\n            [\n                2.6099, 2.8449, 3.3780, 3.9812, 4.4237, 5.0268, 5.5600, 5.7949,\n            ],\n            [\n                5.0224, 5.2574, 5.7906, 6.3937, 6.8362, 7.4394, 7.9725, 8.2075,\n            ],\n            [\n                6.7925, 7.0275, 7.5606, 8.1638, 8.6063, 9.2094, 9.7426, 9.9776,\n            ],\n            [\n                9.2051, 9.4400, 9.9732, 10.5763, 11.0188, 11.6220, 12.1551, 12.3901,\n            ],\n            [\n                11.3377, 11.5726, 12.1058, 12.7089, 13.1514, 13.7546, 14.2877, 14.5227,\n            ],\n            [\n                12.2775, 12.5125, 13.0457, 13.6488, 14.0913, 14.6945, 15.2276, 15.4626,\n            ],\n        ]]]),\n        false,\n    );\n}\n\n#[test]\nfn test_1d_lanczos3() {\n    let device = Default::default();\n\n    let input = TestTensor::<3>::from_floats(\n        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],\n        &device,\n    );\n\n    let input = input.unsqueeze_dim(2);\n\n    let output = interpolate(\n        input,\n        [1, 9],\n        InterpolateOptions::new(InterpolateMode::Lanczos3),\n    );\n    assert_eq!(output.dims(), [1, 1, 1, 9]);\n\n    assert!(\n        !output\n            .clone()\n            .to_data()\n            .as_slice::<FloatElem>()\n            .unwrap()\n            .iter()\n            .any(|&x| x.is_nan()),\n        \"interpolate output contains NaN\"\n    );\n\n    TestTensor::<4>::from([[[[\n        1.5410, 0.7266, -1.1387, -2.2672, -0.7894, 0.6408, -0.4967, -1.4650, -1.3986,\n    ]]]])\n    .to_data()\n    .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::permissive());\n}\n\nstruct InterpolateTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl InterpolateTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        self.assert_output_with_align_corners(y, true);\n    }\n\n    fn assert_output_with_align_corners(self, y: TestTensor<4>, align_corners: bool) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n        let output = interpolate(\n            x,\n            [self.height_out, self.width_out],\n            InterpolateOptions::new(InterpolateMode::Lanczos3).with_align_corners(align_corners),\n        );\n\n        let tolerance = Tolerance::permissive();\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/linear.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::linear;\n\n#[test]\nfn test_linear_1d() {\n    let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let x = TestTensor::<1>::from([1.0, 2.0]);\n    let output = linear(x, weight, None);\n\n    let expected = TensorData::from([7.0, 10.0]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));\n}\n\n#[test]\nfn test_linear_1d_one_element_output() {\n    let weight = TestTensor::<2>::from([[3.0], [4.0]]);\n\n    let x = TestTensor::<1>::from([1.0, 2.0]);\n    let output = linear(x, weight, None);\n\n    let expected = TensorData::from([11.0]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));\n}\n\n#[test]\nfn test_linear_forward_no_bias() {\n    let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);\n\n    let output = linear(x, weight, None);\n\n    let expected = TensorData::from([[[7.0, 10.0], [15.0, 22.0]], [[-7.0, -10.0], [-15.0, -22.0]]]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));\n}\n\n#[test]\nfn test_linear_forward_with_bias() {\n    let weight = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n    let bias = Some(TestTensor::<1>::from([1.0, -1.0]));\n\n    let x = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[-1.0, -2.0], [-3.0, -4.0]]]);\n\n    let output = linear(x, weight, bias);\n\n    let expected = TensorData::from([[[8.0, 9.0], [16.0, 21.0]], [[-6.0, -11.0], [-14.0, -23.0]]]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(1e-5));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/maxpool1d.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::{max_pool1d, max_pool1d_with_indices};\n\n#[test]\nfn test_max_pool1d_simple() {\n    let kernel_size = 3;\n    let padding = 0;\n    let stride = 1;\n    let dilation = 1;\n\n    let x = TestTensor::from([[\n        [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],\n        [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],\n    ]]);\n    let y = TestTensor::<3>::from([[\n        [0.9861, 0.5474, 0.4477, 0.8221],\n        [0.949, 0.949, 0.949, 0.789],\n    ]]);\n\n    let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_different_padding_stride_kernel() {\n    let kernel_size = 3;\n    let padding = 1;\n    let stride = 2;\n    let dilation = 1;\n\n    let x = TestTensor::from([[[0.6309, 0.6112, 0.6998, 0.4708]]]);\n    let y = TestTensor::<3>::from([[[0.6309, 0.6998]]]);\n\n    let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_with_neg() {\n    let kernel_size = 3;\n    let padding = 1;\n    let stride = 1;\n    let dilation = 1;\n\n    let x = TestTensor::from([[[-0.6309, -0.6112, -0.6998, -0.4708]]]);\n    let y = TestTensor::<3>::from([[[-0.6112, -0.6112, -0.4708, -0.4708]]]);\n\n    let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_with_dilation() {\n    let kernel_size = 2;\n    let padding = 1;\n    let stride = 1;\n    let dilation = 2;\n\n    let x = TestTensor::from([[\n        [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],\n        [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],\n    ]]);\n    let y = TestTensor::<3>::from([[\n        [0.5474, 0.9861, 0.5474, 0.4477, 0.8221, 0.3548],\n        [0.5474, 0.9490, 0.7890, 0.9490, 0.7890, 0.5537],\n    ]]);\n\n    let output = max_pool1d(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool1d_with_indices() {\n    let kernel_size = 2;\n    let padding = 0;\n    let stride = 1;\n    let dilation = 1;\n\n    let x = TestTensor::from([[[0.2479, 0.6386, 0.3166, 0.5742]]]);\n    let indices = TensorData::from([[[1, 1, 3]]]);\n    let y = TestTensor::<3>::from([[[0.6386, 0.6386, 0.5742]]]);\n\n    let (output, output_indices) =\n        max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices.into_data().assert_eq(&indices, false);\n}\n\n#[test]\nfn test_max_pool1d_complex() {\n    let kernel_size = 4;\n    let padding = 2;\n    let stride = 1;\n    let dilation = 1;\n\n    let x = TestTensor::from([[[0.5388, 0.0676, 0.7122, 0.8316, 0.0653]]]);\n    let indices = TensorData::from([[[0, 2, 3, 3, 3, 3]]]);\n    let y = TestTensor::<3>::from([[[0.5388, 0.7122, 0.8316, 0.8316, 0.8316, 0.8316]]]);\n\n    let (output, output_indices) =\n        max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, false);\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices.into_data().assert_eq(&indices, false);\n}\n\n#[test]\nfn test_max_pool1d_ceil_mode() {\n    // Test ceil_mode=true produces larger output when input doesn't divide evenly by stride\n    // Input: 1x1x6, kernel: 3, stride: 2, padding: 0\n    // Floor mode: output = (6-3)/2+1 = 2 elements\n    // Ceil mode: output = ceil((6-3)/2)+1 = ceil(1.5)+1 = 3 elements\n    let kernel_size = 3;\n    let padding = 0;\n    let stride = 2;\n    let dilation = 1;\n\n    let x = TestTensor::from([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]]);\n\n    // With ceil_mode=false (floor): output is 2 elements\n    // Window 0: positions [0:3] -> max(1,2,3) = 3\n    // Window 1: positions [2:5] -> max(3,4,5) = 5\n    let y_floor = TestTensor::<3>::from([[[3.0, 5.0]]]);\n\n    let output_floor = max_pool1d(x.clone(), kernel_size, stride, padding, dilation, false);\n\n    y_floor\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output_floor.into_data(), Tolerance::default());\n\n    // With ceil_mode=true: output is 3 elements\n    // Window 0: positions [0:3] -> max(1,2,3) = 3\n    // Window 1: positions [2:5] -> max(3,4,5) = 5\n    // Window 2: positions [4:7] -> max(5,6) = 6 (partial window)\n    let y_ceil = TestTensor::<3>::from([[[3.0, 5.0, 6.0]]]);\n\n    let output_ceil = max_pool1d(x, kernel_size, stride, padding, dilation, true);\n\n    y_ceil\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output_ceil.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/maxpool2d.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::{max_pool2d, max_pool2d_with_indices};\n\n#[test]\nfn test_max_pool2d_simple() {\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let x = TestTensor::from([\n        [\n            [\n                [0.9861, 0.5474, 0.4477, 0.0732, 0.3548, 0.8221],\n                [0.8148, 0.5474, 0.9490, 0.7890, 0.5537, 0.5689],\n                [0.5986, 0.2059, 0.4897, 0.6136, 0.2965, 0.6182],\n                [0.1485, 0.9540, 0.4023, 0.6176, 0.7111, 0.3392],\n                [0.3703, 0.0472, 0.2771, 0.1868, 0.8855, 0.5605],\n                [0.5063, 0.1638, 0.9432, 0.7836, 0.8696, 0.1068],\n            ],\n            [\n                [0.8872, 0.0137, 0.1652, 0.5505, 0.6127, 0.6473],\n                [0.1128, 0.0888, 0.1152, 0.5456, 0.6199, 0.7947],\n                [0.5911, 0.7781, 0.7256, 0.6578, 0.0989, 0.9149],\n                [0.5879, 0.5189, 0.6561, 0.0578, 0.7025, 0.6426],\n                [0.9590, 0.0325, 0.6455, 0.6248, 0.2009, 0.1544],\n                [0.7339, 0.1369, 0.6598, 0.5528, 0.6775, 0.1572],\n            ],\n        ],\n        [\n            [\n                [0.6853, 0.6439, 0.4639, 0.5573, 0.2723, 0.5910],\n                [0.5419, 0.7729, 0.6743, 0.8956, 0.2997, 0.9546],\n                [0.0334, 0.2178, 0.6917, 0.4958, 0.3357, 0.6584],\n                [0.7358, 0.9074, 0.2462, 0.5159, 0.6420, 0.2441],\n                [0.7602, 0.6297, 0.6073, 0.5937, 0.8037, 0.4881],\n                [0.8859, 0.0974, 0.3954, 0.6763, 0.1078, 0.7467],\n            ],\n            [\n                [0.2991, 0.5012, 0.8024, 0.7653, 0.9378, 0.7952],\n                [0.7393, 0.2336, 0.9521, 0.2719, 0.8445, 0.0454],\n                [0.6479, 0.9822, 0.7905, 0.0318, 0.2474, 0.0628],\n                [0.9955, 0.7591, 0.4140, 0.3215, 0.4349, 0.1527],\n                [0.8064, 0.0164, 0.4002, 0.2024, 0.6128, 0.5827],\n                [0.5368, 0.7895, 0.8727, 0.7793, 0.0910, 0.3421],\n            ],\n        ],\n    ]);\n    let y = TestTensor::<4>::from([\n        [\n            [\n                [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],\n                [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],\n                [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],\n                [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],\n                [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],\n                [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],\n            ],\n            [\n                [0.8872, 0.8872, 0.5505, 0.6199, 0.7947, 0.7947],\n                [0.8872, 0.8872, 0.7781, 0.7256, 0.9149, 0.9149],\n                [0.7781, 0.7781, 0.7781, 0.7256, 0.9149, 0.9149],\n                [0.9590, 0.9590, 0.7781, 0.7256, 0.9149, 0.9149],\n                [0.9590, 0.9590, 0.6598, 0.7025, 0.7025, 0.7025],\n                [0.9590, 0.9590, 0.6598, 0.6775, 0.6775, 0.6775],\n            ],\n        ],\n        [\n            [\n                [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],\n                [0.7729, 0.7729, 0.8956, 0.8956, 0.9546, 0.9546],\n                [0.9074, 0.9074, 0.9074, 0.8956, 0.9546, 0.9546],\n                [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],\n                [0.9074, 0.9074, 0.9074, 0.8037, 0.8037, 0.8037],\n                [0.8859, 0.8859, 0.6763, 0.8037, 0.8037, 0.8037],\n            ],\n            [\n                [0.7393, 0.9521, 0.9521, 0.9521, 0.9378, 0.9378],\n                [0.9822, 0.9822, 0.9822, 0.9521, 0.9378, 0.9378],\n                [0.9955, 0.9955, 0.9822, 0.9521, 0.8445, 0.8445],\n                [0.9955, 0.9955, 0.9822, 0.7905, 0.6128, 0.6128],\n                [0.9955, 0.9955, 0.8727, 0.8727, 0.7793, 0.6128],\n                [0.8064, 0.8727, 0.8727, 0.8727, 0.7793, 0.6128],\n            ],\n        ],\n    ]);\n\n    let output = max_pool2d(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_different_padding_stride_kernel() {\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 1;\n    let padding_1 = 1;\n    let padding_2 = 0;\n    let stride_1 = 1;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let x = TestTensor::from([[[\n        [0.6309, 0.6112, 0.6998],\n        [0.4708, 0.9161, 0.5402],\n        [0.4577, 0.7397, 0.9870],\n        [0.6380, 0.4352, 0.5884],\n        [0.6277, 0.5139, 0.4525],\n        [0.9333, 0.9846, 0.5006],\n    ]]]);\n    let y = TestTensor::<4>::from([[[\n        [0.6309, 0.6998],\n        [0.6309, 0.9870],\n        [0.6380, 0.9870],\n        [0.6380, 0.9870],\n        [0.9333, 0.5884],\n        [0.9333, 0.5006],\n    ]]]);\n\n    let output = max_pool2d(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_with_neg() {\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let x = TestTensor::from([[[\n        [0.6309, 0.6112, 0.6998],\n        [0.4708, 0.9161, 0.5402],\n        [0.4577, 0.7397, 0.9870],\n        [0.6380, 0.4352, 0.5884],\n        [0.6277, 0.5139, 0.4525],\n        [0.9333, 0.9846, 0.5006],\n    ]]])\n    .neg();\n    let y = TestTensor::<4>::from([[[\n        [-0.4708, -0.4708, -0.5402],\n        [-0.4577, -0.4577, -0.5402],\n        [-0.4352, -0.4352, -0.4352],\n        [-0.4352, -0.4352, -0.4352],\n        [-0.4352, -0.4352, -0.4352],\n        [-0.5139, -0.4525, -0.4525],\n    ]]]);\n\n    let output = max_pool2d(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_with_dilation() {\n    let kernel_size_1 = 2;\n    let kernel_size_2 = 2;\n    let padding_1 = 0;\n    let padding_2 = 0;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 2;\n    let dilation_2 = 2;\n\n    let x = TestTensor::from([[[\n        [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],\n        [0.9861, 0.9861, 0.9490, 0.9490, 0.8221, 0.8221],\n        [0.9540, 0.9540, 0.9540, 0.9490, 0.7890, 0.7111],\n        [0.9540, 0.9540, 0.9540, 0.8855, 0.8855, 0.8855],\n        [0.9540, 0.9540, 0.9540, 0.9432, 0.8855, 0.8855],\n        [0.5063, 0.9432, 0.9432, 0.9432, 0.8855, 0.8855],\n    ]]]);\n    let y = TestTensor::<4>::from([[[\n        [0.9861, 0.9861, 0.9540, 0.9490],\n        [0.9861, 0.9861, 0.9540, 0.9490],\n        [0.9540, 0.9540, 0.9540, 0.9490],\n        [0.9540, 0.9540, 0.9540, 0.9432],\n    ]]]);\n\n    let output = max_pool2d(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_with_indices() {\n    let kernel_size_1 = 2;\n    let kernel_size_2 = 2;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 1;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let x = TestTensor::from([[[\n        [0.2479, 0.6386, 0.3166, 0.5742],\n        [0.7065, 0.1940, 0.6305, 0.8959],\n        [0.5416, 0.8602, 0.8129, 0.1662],\n        [0.3358, 0.3059, 0.8293, 0.0990],\n    ]]]);\n    let indices = TensorData::from([[[\n        [0, 1, 1, 3, 3],\n        [4, 4, 1, 7, 7],\n        [4, 9, 9, 7, 7],\n        [8, 9, 9, 14, 11],\n        [12, 12, 14, 14, 15],\n    ]]]);\n    let y = TestTensor::<4>::from([[[\n        [0.2479, 0.6386, 0.6386, 0.5742, 0.5742],\n        [0.7065, 0.7065, 0.6386, 0.8959, 0.8959],\n        [0.7065, 0.8602, 0.8602, 0.8959, 0.8959],\n        [0.5416, 0.8602, 0.8602, 0.8293, 0.1662],\n        [0.3358, 0.3358, 0.8293, 0.8293, 0.0990],\n    ]]]);\n\n    let (output, output_indices) = max_pool2d_with_indices(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices.into_data().assert_eq(&indices, false);\n}\n\n#[test]\nfn test_max_pool2d_complex() {\n    let kernel_size_1 = 4;\n    let kernel_size_2 = 2;\n    let padding_1 = 2;\n    let padding_2 = 1;\n    let stride_1 = 1;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    let x = TestTensor::from([[[\n        [0.5388, 0.0676, 0.7122, 0.8316, 0.0653],\n        [0.9154, 0.1536, 0.9089, 0.8016, 0.7518],\n        [0.2073, 0.0501, 0.8811, 0.5604, 0.5075],\n        [0.4384, 0.9963, 0.9698, 0.4988, 0.2609],\n        [0.3391, 0.2230, 0.4610, 0.5365, 0.6880],\n    ]]]);\n    let indices = TensorData::from([[[\n        [5, 7, 3],\n        [5, 7, 3],\n        [5, 16, 3],\n        [5, 16, 8],\n        [15, 16, 24],\n        [15, 16, 24],\n    ]]]);\n    let y = TestTensor::<4>::from([[[\n        [0.9154, 0.9089, 0.8316],\n        [0.9154, 0.9089, 0.8316],\n        [0.9154, 0.9963, 0.8316],\n        [0.9154, 0.9963, 0.8016],\n        [0.4384, 0.9963, 0.688],\n        [0.4384, 0.9963, 0.688],\n    ]]]);\n    let (output, output_indices) = max_pool2d_with_indices(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y.to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices.into_data().assert_eq(&indices, false);\n}\n\n#[test]\nfn test_max_pool2d_ceil_mode() {\n    // Test ceil_mode=true which produces larger output when input doesn't divide evenly by stride\n    // Using 1x1x6x6 with kernel 3x3, stride 2x2, padding 0:\n    // Floor mode: output = (6+0-1*(3-1)-1)/2+1 = 3/2+1 = 2 x 2\n    // Ceil mode: output = ceil(3/2)+1 = 2+1 = 3 x 3\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 0;\n    let padding_2 = 0;\n    let stride_1 = 2;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    // Input (values 1-36 arranged row by row):\n    // col:    0  1  2  3  4  5\n    // row 0:  1  2  3  4  5  6\n    // row 1:  7  8  9  10 11 12\n    // row 2: 13 14 15 16 17 18\n    // row 3: 19 20 21 22 23 24\n    // row 4: 25 26 27 28 29 30\n    // row 5: 31 32 33 34 35 36\n    let x = TestTensor::from([[[\n        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],\n        [7.0, 8.0, 9.0, 10.0, 11.0, 12.0],\n        [13.0, 14.0, 15.0, 16.0, 17.0, 18.0],\n        [19.0, 20.0, 21.0, 22.0, 23.0, 24.0],\n        [25.0, 26.0, 27.0, 28.0, 29.0, 30.0],\n        [31.0, 32.0, 33.0, 34.0, 35.0, 36.0],\n    ]]]);\n\n    // With ceil_mode=false (floor): output is 2x2\n    // (0,0): rows 0-2, cols 0-2 -> max(1,2,3,7,8,9,13,14,15) = 15\n    // (0,1): rows 0-2, cols 2-4 -> max(3,4,5,9,10,11,15,16,17) = 17\n    // (1,0): rows 2-4, cols 0-2 -> max(13,14,15,19,20,21,25,26,27) = 27\n    // (1,1): rows 2-4, cols 2-4 -> max(15,16,17,21,22,23,27,28,29) = 29\n    let y_floor = TestTensor::<4>::from([[[[15.0, 17.0], [27.0, 29.0]]]]);\n\n    let output_floor = max_pool2d(\n        x.clone(),\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        false,\n    );\n\n    y_floor\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output_floor.into_data(), Tolerance::default());\n\n    // With ceil_mode=true: output is 3x3\n    // Extra windows at edges use only available input values (padded with -inf for max pooling)\n    // (0,0): rows 0-2, cols 0-2 -> max = 15\n    // (0,1): rows 0-2, cols 2-4 -> max = 17\n    // (0,2): rows 0-2, cols 4-5 -> max(5,6,11,12,17,18) = 18\n    // (1,0): rows 2-4, cols 0-2 -> max = 27\n    // (1,1): rows 2-4, cols 2-4 -> max = 29\n    // (1,2): rows 2-4, cols 4-5 -> max(17,18,23,24,29,30) = 30\n    // (2,0): rows 4-5, cols 0-2 -> max(25,26,27,31,32,33) = 33\n    // (2,1): rows 4-5, cols 2-4 -> max(27,28,29,33,34,35) = 35\n    // (2,2): rows 4-5, cols 4-5 -> max(29,30,35,36) = 36\n    let y_ceil =\n        TestTensor::<4>::from([[[[15.0, 17.0, 18.0], [27.0, 29.0, 30.0], [33.0, 35.0, 36.0]]]]);\n\n    let output_ceil = max_pool2d(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        true,\n    );\n\n    y_ceil\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output_ceil.into_data(), Tolerance::default());\n}\n\n#[test]\nfn test_max_pool2d_ceil_mode_with_indices() {\n    // Test ceil_mode=true with indices to verify correct index calculation\n    // when pooling windows extend beyond original input bounds\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 0;\n    let padding_2 = 0;\n    let stride_1 = 2;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    // Input 6x6 (indices 0-35 in row-major order):\n    // row 0:  0  1  2  3  4  5\n    // row 1:  6  7  8  9 10 11\n    // row 2: 12 13 14 15 16 17\n    // row 3: 18 19 20 21 22 23\n    // row 4: 24 25 26 27 28 29\n    // row 5: 30 31 32 33 34 35\n    let x = TestTensor::from([[[\n        [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],\n        [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],\n        [18.0, 19.0, 20.0, 21.0, 22.0, 23.0],\n        [24.0, 25.0, 26.0, 27.0, 28.0, 29.0],\n        [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],\n    ]]]);\n\n    // With ceil_mode=true: output is 3x3\n    // (0,0): rows 0-2, cols 0-2 -> max at index 14\n    // (0,1): rows 0-2, cols 2-4 -> max at index 16\n    // (0,2): rows 0-2, cols 4-5 -> max at index 17\n    // (1,0): rows 2-4, cols 0-2 -> max at index 26\n    // (1,1): rows 2-4, cols 2-4 -> max at index 28\n    // (1,2): rows 2-4, cols 4-5 -> max at index 29\n    // (2,0): rows 4-5, cols 0-2 -> max at index 32\n    // (2,1): rows 4-5, cols 2-4 -> max at index 34\n    // (2,2): rows 4-5, cols 4-5 -> max at index 35\n    let expected_values =\n        TestTensor::<4>::from([[[[14.0, 16.0, 17.0], [26.0, 28.0, 29.0], [32.0, 34.0, 35.0]]]]);\n    let expected_indices = TensorData::from([[[[14i64, 16, 17], [26, 28, 29], [32, 34, 35]]]]);\n\n    let (output, output_indices) = max_pool2d_with_indices(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        true,\n    );\n\n    expected_values\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices\n        .into_data()\n        .assert_eq(&expected_indices, false);\n}\n\n#[test]\nfn test_max_pool2d_ceil_mode_with_indices_and_padding() {\n    // Test ceil_mode=true with padding and indices to verify correct index calculation\n    // This exercises the case where both user padding and ceil_mode extra padding apply\n    let kernel_size_1 = 3;\n    let kernel_size_2 = 3;\n    let padding_1 = 1;\n    let padding_2 = 1;\n    let stride_1 = 2;\n    let stride_2 = 2;\n    let dilation_1 = 1;\n    let dilation_2 = 1;\n\n    // Input 5x5 (indices 0-24 in row-major order):\n    // row 0:  0  1  2  3  4\n    // row 1:  5  6  7  8  9\n    // row 2: 10 11 12 13 14\n    // row 3: 15 16 17 18 19\n    // row 4: 20 21 22 23 24\n    let x = TestTensor::from([[[\n        [0.0, 1.0, 2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0, 8.0, 9.0],\n        [10.0, 11.0, 12.0, 13.0, 14.0],\n        [15.0, 16.0, 17.0, 18.0, 19.0],\n        [20.0, 21.0, 22.0, 23.0, 24.0],\n    ]]]);\n\n    // With padding=1, ceil_mode=true:\n    // Effective input is 7x7 (5 + 2*1)\n    // Output size: ceil((5 + 2*1 - 3) / 2) + 1 = ceil(4/2) + 1 = 3\n    //\n    // Windows (with -inf padding at boundaries):\n    // (0,0): rows -1 to 1, cols -1 to 1 -> valid: (0,0) to (1,1), max at (1,1)=6\n    // (0,1): rows -1 to 1, cols 1 to 3 -> max at (1,3)=8\n    // (0,2): rows -1 to 1, cols 3 to 5 -> max at (1,4)=9\n    // (1,0): rows 1 to 3, cols -1 to 1 -> max at (3,1)=16\n    // (1,1): rows 1 to 3, cols 1 to 3 -> max at (3,3)=18\n    // (1,2): rows 1 to 3, cols 3 to 5 -> max at (3,4)=19\n    // (2,0): rows 3 to 5, cols -1 to 1 -> max at (4,1)=21\n    // (2,1): rows 3 to 5, cols 1 to 3 -> max at (4,3)=23\n    // (2,2): rows 3 to 5, cols 3 to 5 -> max at (4,4)=24\n    let expected_values =\n        TestTensor::<4>::from([[[[6.0, 8.0, 9.0], [16.0, 18.0, 19.0], [21.0, 23.0, 24.0]]]]);\n    let expected_indices = TensorData::from([[[[6i64, 8, 9], [16, 18, 19], [21, 23, 24]]]]);\n\n    let (output, output_indices) = max_pool2d_with_indices(\n        x,\n        [kernel_size_1, kernel_size_2],\n        [stride_1, stride_2],\n        [padding_1, padding_2],\n        [dilation_1, dilation_2],\n        true,\n    );\n\n    expected_values\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    output_indices\n        .into_data()\n        .assert_eq(&expected_indices, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/mod.rs",
    "content": "use super::*;\n\nmod adaptive_avgpool1d;\nmod adaptive_avgpool2d;\nmod attention;\nmod avgpool1d;\nmod avgpool2d;\nmod bicubic_interpolate;\nmod bilinear_interpolate;\nmod conv1d;\nmod conv2d;\nmod conv3d;\nmod conv_transpose1d;\nmod conv_transpose2d;\nmod conv_transpose3d;\nmod deform_conv2d;\nmod forward;\nmod lanczos3_interpolate;\nmod linear;\nmod maxpool1d;\nmod maxpool2d;\nmod nearest_interpolate;\nmod unfold4d;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/nearest_interpolate.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::interpolate;\nuse burn_tensor::ops::{InterpolateMode, InterpolateOptions};\n\n#[test]\nfn test_upsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 2,\n        channels: 1,\n        height: 7,\n        width: 5,\n        height_out: 8,\n        width_out: 7,\n    };\n\n    test.assert_output(TestTensor::from([\n        [[\n            [0., 0., 1., 2., 2., 3., 4.],\n            [0., 0., 1., 2., 2., 3., 4.],\n            [5., 5., 6., 7., 7., 8., 9.],\n            [10., 10., 11., 12., 12., 13., 14.],\n            [15., 15., 16., 17., 17., 18., 19.],\n            [20., 20., 21., 22., 22., 23., 24.],\n            [25., 25., 26., 27., 27., 28., 29.],\n            [30., 30., 31., 32., 32., 33., 34.],\n        ]],\n        [[\n            [35., 35., 36., 37., 37., 38., 39.],\n            [35., 35., 36., 37., 37., 38., 39.],\n            [40., 40., 41., 42., 42., 43., 44.],\n            [45., 45., 46., 47., 47., 48., 49.],\n            [50., 50., 51., 52., 52., 53., 54.],\n            [55., 55., 56., 57., 57., 58., 59.],\n            [60., 60., 61., 62., 62., 63., 64.],\n            [65., 65., 66., 67., 67., 68., 69.],\n        ]],\n    ]));\n}\n\n#[test]\nfn test_downsample_interpolation() {\n    let test = InterpolateTestCase {\n        batch_size: 1,\n        channels: 1,\n        height: 45,\n        width: 14,\n        height_out: 4,\n        width_out: 6,\n    };\n\n    test.assert_output(TestTensor::from([[[\n        [0., 2., 4., 7., 9., 11.],\n        [154., 156., 158., 161., 163., 165.],\n        [308., 310., 312., 315., 317., 319.],\n        [462., 464., 466., 469., 471., 473.],\n    ]]]));\n}\n\n#[test]\nfn test_1d_nearest() {\n    // Initialize the model without weights (because the exported file does not contain them)\n    let device = Default::default();\n\n    // Run the model\n    let input = TestTensor::<3>::from_floats(\n        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],\n        &device,\n    );\n\n    let input = input.unsqueeze_dim(2);\n\n    let output = interpolate(\n        input,\n        [1, 9],\n        InterpolateOptions::new(InterpolateMode::Nearest),\n    );\n    assert_eq!(output.dims(), [1, 1, 1, 9]);\n\n    // assert output data does not contain NaN\n    assert!(\n        !output\n            .clone()\n            .to_data()\n            .as_slice::<FloatElem>()\n            .unwrap()\n            .iter()\n            .any(|&x| x.is_nan()),\n        \"interpolate output contains NaN\"\n    );\n\n    TestTensor::<4>::from([[[[\n        1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986,\n    ]]]])\n    .to_data()\n    .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n}\n\nstruct InterpolateTestCase {\n    batch_size: usize,\n    channels: usize,\n    height: usize,\n    width: usize,\n    height_out: usize,\n    width_out: usize,\n}\n\nimpl InterpolateTestCase {\n    fn assert_output(self, y: TestTensor<4>) {\n        let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())\n                .reshape::<4, _>(shape_x)\n                .into_data()\n                .convert::<f32>(),\n        );\n        let output = interpolate(\n            x,\n            [self.height_out, self.width_out],\n            InterpolateOptions::new(InterpolateMode::Nearest),\n        );\n\n        y.to_data()\n            .assert_approx_eq::<FloatElem>(&output.into_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/module/unfold4d.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\nuse burn_tensor::Tolerance;\nuse burn_tensor::module::unfold4d;\nuse burn_tensor::ops::UnfoldOptions;\n\n#[test]\nfn test_unfold4d_shape() {\n    let test = Unfold4dTestCase {\n        batch_size: 2,\n        channels_in: 5,\n        kernel_size: [2, 3],\n        padding: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        height: 3,\n        width: 4,\n    };\n\n    test.assert_shape([2, 30, 4]);\n}\n\n#[test]\nfn test_unfold4d_simple() {\n    let test = Unfold4dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        kernel_size: [2, 2],\n        padding: [0, 0],\n        stride: [1, 1],\n        dilation: [1, 1],\n        height: 4,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0., 1., 2., 4., 5., 6., 8., 9., 10.],\n        [1., 2., 3., 5., 6., 7., 9., 10., 11.],\n        [4., 5., 6., 8., 9., 10., 12., 13., 14.],\n        [5., 6., 7., 9., 10., 11., 13., 14., 15.],\n        [16., 17., 18., 20., 21., 22., 24., 25., 26.],\n        [17., 18., 19., 21., 22., 23., 25., 26., 27.],\n        [20., 21., 22., 24., 25., 26., 28., 29., 30.],\n        [21., 22., 23., 25., 26., 27., 29., 30., 31.],\n    ]]));\n}\n\n#[test]\nfn test_unfold4d_complex() {\n    let test = Unfold4dTestCase {\n        batch_size: 1,\n        channels_in: 2,\n        kernel_size: [2, 3],\n        padding: [0, 1],\n        stride: [1, 2],\n        dilation: [1, 2],\n        height: 3,\n        width: 4,\n    };\n\n    test.assert_output(TestTensor::from([[\n        [0., 0.],\n        [1., 5.],\n        [3., 7.],\n        [0., 0.],\n        [5., 9.],\n        [7., 11.],\n        [0., 0.],\n        [13., 17.],\n        [15., 19.],\n        [0., 0.],\n        [17., 21.],\n        [19., 23.],\n    ]]));\n}\n\nstruct Unfold4dTestCase {\n    batch_size: usize,\n    channels_in: usize,\n    kernel_size: [usize; 2],\n    padding: [usize; 2],\n    stride: [usize; 2],\n    dilation: [usize; 2],\n    height: usize,\n    width: usize,\n}\n\nimpl Unfold4dTestCase {\n    fn assert_shape(self, expected_shape: [usize; 3]) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())\n                .reshape::<4, _>(shape_x)\n                .into_data()\n                .convert::<f32>(),\n        );\n\n        let output = unfold4d(\n            x,\n            self.kernel_size,\n            UnfoldOptions::new(self.stride, self.padding, self.dilation),\n        );\n\n        assert_eq!(\n            output.shape().as_slice(),\n            expected_shape,\n            \"Expected shape doesn't match the actual shape\"\n        );\n    }\n\n    fn assert_output(self, expected: TestTensor<3>) {\n        let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);\n        let x = TestTensor::from(\n            TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())\n                .reshape::<4, _>(shape_x)\n                .into_data(),\n        );\n\n        let output = unfold4d(\n            x,\n            self.kernel_size,\n            UnfoldOptions::new(self.stride, self.padding, self.dilation),\n        );\n\n        let tolerance = Tolerance::default()\n            .set_half_precision_relative(2e-3)\n            .set_half_precision_absolute(2e-3);\n        output\n            .into_data()\n            .assert_approx_eq::<FloatElem>(&expected.into_data(), tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/abs.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_abs_ops_float() {\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);\n\n    let output = tensor.abs();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/add.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, backend::Backend};\n\n#[test]\nfn test_add_d2() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output.into_data().assert_eq(\n        &TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_add_broadcast() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0]]);\n    let tensor_2 = TestTensor::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output.into_data().assert_eq(\n        &TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_add_different_strides_rhs() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1;\n\n    let output = tensor_1 + tensor_2.transpose();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false);\n}\n\n#[test]\nfn test_add_different_strides_lhs() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = TestTensor::from([[4.0, 5.0], [6.0, 7.0]]) * 1;\n\n    let output = tensor_1.transpose() + tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[4.0, 7.0], [7.0, 10.0]]), false);\n}\n\n#[test]\nfn test_add_different_strides_broadcast() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = TestTensor::from([[4.0, 5.0]]) * 1;\n\n    let output = tensor_1.transpose() + tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[4.0, 7.0], [5.0, 8.0]]), false);\n}\n\n#[test]\nfn should_support_add_scalar_ops() {\n    let scalar = 2.0;\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor + scalar;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]), false);\n}\n\n#[test]\nfn add_maybe_fused_not_contiguous() {\n    let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor2 = TestTensorInt::arange(8..16, &Default::default()).float();\n    let tensor1 = tensor1.reshape([2, 4]);\n    let tensor2 = tensor2.reshape([4, 2]);\n    let tensor2 = tensor2.swap_dims(0, 1);\n\n    TestBackend::sync(&tensor2.device()).unwrap();\n\n    let output = tensor1 + tensor2;\n\n    output.into_data().assert_eq(\n        &TensorData::from([[8.0, 11.0, 14.0, 17.0], [13.0, 16.0, 19.0, 22.0]]),\n        false,\n    );\n}\n\n#[test]\nfn add_maybe_fused_not_contiguous_broadcasted() {\n    let tensor1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor2 = TestTensorInt::arange(8..10, &Default::default()).float();\n    let tensor1 = tensor1.reshape([2, 4]);\n    let tensor2 = tensor2.reshape([1, 2]);\n    let tensor2 = tensor2.swap_dims(0, 1);\n\n    TestBackend::sync(&tensor2.device()).unwrap();\n\n    let output = tensor2 + tensor1;\n\n    output.into_data().assert_eq(\n        &TensorData::from([[8.0, 9.0, 10.0, 11.0], [13.0, 14.0, 15.0, 16.0]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/aggregation.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse burn_tensor::backend::Backend;\n\n#[test]\nfn test_should_mean() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.mean();\n    let expected = TensorData::from([15.0 / 6.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_should_sum() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sum();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([15.0]), false);\n}\n\n#[test]\nfn test_should_sum_dim_maybe_fused() {\n    let tensor = TestTensor::<2>::from([[5.0], [-12.0]]);\n    let tensor1 = TestTensor::<2>::from([[2.0, 3.0], [-1.0, -5.0]]);\n    let ones = TestTensor::<2>::ones([2, 2], &Default::default());\n    let _x = ones.clone() * tensor;\n    let y = ones * tensor1;\n\n    let output = y.clone().sum_dim(1);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[5.0], [-6.0]]), false);\n\n    // Negative Indexing.\n    let output = y.clone().sum_dim(-1);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[5.0], [-6.0]]), false);\n}\n\n#[test]\nfn test_should_mean_last_dim() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.clone().mean_dim(1);\n    let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    // Negative Indexing.\n    let output = tensor.clone().mean_dim(-1);\n    let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_should_sum_last_dim() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sum_dim(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0], [12.0]]), false);\n}\n\n#[test]\nfn test_should_sum_first_dim() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);\n\n    let output = tensor.sum_dim(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[7.0, 3.0, 5.0]]), false);\n}\n\n#[test]\nfn test_should_mean_first_dim() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);\n\n    let output = tensor.mean_dim(0);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_should_sum_mid_dim_3d_non_contiguous_1() {\n    let tensor = TestTensor::<3>::from([\n        [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],\n        [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],\n    ]);\n\n    let output = tensor.swap_dims(0, 2).sum_dim(1);\n\n    output.into_data().assert_eq(\n        &TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]),\n        false,\n    );\n}\n\n#[test]\nfn test_should_sum_mid_dim_3d_non_contiguous_2() {\n    let tensor = TestTensor::<3>::from([\n        [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],\n        [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],\n    ]);\n\n    let output = tensor.swap_dims(0, 1).sum_dim(1);\n\n    output.into_data().assert_eq(\n        &TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]),\n        false,\n    );\n}\n\n#[test]\nfn test_prod_float() {\n    let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor.prod();\n\n    // 2 * 1 * 2 * 3 * 4 * 5 = 240 but we need to check the precision because of the float\n    let expected = TensorData::from([240.0]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor_with_zero.prod();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([0.0]), false);\n}\n\n#[test]\nfn test_prod_dim_float() {\n    let tensor = TestTensor::<2>::from([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor.prod_dim(1);\n    let expected = TensorData::from([[4.0], [60.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let tensor_with_zero = TestTensor::<2>::from([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor_with_zero.prod_dim(1);\n    let expected = TensorData::from([[0.0], [60.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_sum_dim_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let output = tensor.clone().sum_dim(1);\n    let expected = TensorData::from([[3.], [12.]]);\n\n    output.into_data().assert_eq(&expected, false);\n\n    let output = tensor.sum_dim(0);\n    let expected = TensorData::from([[3., 5., 7.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dims_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    tensor\n        .clone()\n        .sum_dims(&[1])\n        .to_data()\n        .assert_eq(&TensorData::from([[3.], [12.]]), false);\n\n    tensor\n        .clone()\n        .sum_dims(&[-1])\n        .to_data()\n        .assert_eq(&TensorData::from([[3.], [12.]]), false);\n\n    tensor\n        .clone()\n        .sum_dims(&[0, 1])\n        .to_data()\n        .assert_eq(&TensorData::from([[15.]]), false);\n}\n\n#[test]\nfn test_sum_and_squeeze_dims() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],\n            [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],\n        ],\n        &Default::default(),\n    );\n\n    tensor\n        .sum_dims_squeeze::<1, _>(&[0, 1])\n        .to_data()\n        .assert_eq(&TensorData::from([20., 16., 21.]), false);\n}\n\n#[test]\nfn test_sum_dim_1_reshape_maybe_fused() {\n    let tensor = TestTensorInt::arange(0..9, &Default::default()).float();\n    TestBackend::sync(&tensor.device()).unwrap();\n\n    let output = tensor.reshape([3, 3]) + 2;\n    let output = output.sum_dim(1);\n    let expected = TensorData::from([[9.0], [18.0], [27.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_1_swap_dims_maybe_fused() {\n    let tensor = TestTensorInt::arange(0..9, &Default::default()).float();\n    let tensor = tensor.reshape([3, 3]);\n    TestBackend::sync(&tensor.device()).unwrap();\n\n    let output = tensor.swap_dims(0, 1) + 2;\n    let output = output.sum_dim(1);\n    let expected = TensorData::from([[15.0], [18.0], [21.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_2_reshape_maybe_fused_broadcast() {\n    let tensor = TestTensorInt::arange(0..9, &Default::default()).float();\n    TestBackend::sync(&tensor.device()).unwrap();\n\n    let output = tensor.reshape([1, 3, 3]) + 2;\n    let output = output.sum_dim(2);\n    let expected = TensorData::from([[[9.0], [18.0], [27.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_2_maybe_fused_on_write() {\n    let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor_2 = TestTensorInt::arange(10..12, &Default::default()).float();\n    let tensor_1 = tensor_1.reshape([1, 2, 4]);\n    let tensor_2 = tensor_2.reshape([1, 2, 1]);\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let output = (tensor_1 + tensor_2.clone()).sum_dim(2) + tensor_2;\n    TestBackend::sync(&output.device()).unwrap();\n    let expected = TensorData::from([[[56.0], [77.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_3_maybe_fused_on_read_not_contiguous() {\n    let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();\n\n    let tensor_1 = tensor_1.reshape([4, 2, 1]);\n    let tensor_1 = tensor_1.swap_dims(0, 2);\n\n    let tensor_2 = tensor_2.reshape([1, 4, 2]);\n    let tensor_2 = tensor_2.swap_dims(1, 2);\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let output = (tensor_1 + tensor_2).sum_dim(2);\n    let expected = TensorData::from([[[88.0], [96.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_4_maybe_fused_on_read_not_contiguous_mixed() {\n    let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();\n    let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float();\n\n    let tensor_1 = tensor_1.reshape([4, 2, 1]);\n    let tensor_3 = tensor_3.reshape([1, 2, 4]);\n    let tensor_1 = tensor_1.swap_dims(0, 2);\n\n    let tensor_2 = tensor_2.reshape([1, 4, 2]);\n    let tensor_2 = tensor_2.swap_dims(1, 2);\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(2);\n    let expected = TensorData::from([[[222.0], [246.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_5_maybe_fused_on_read_not_contiguous_mixed() {\n    let tensor_1 = TestTensorInt::arange(0..8, &Default::default()).float();\n    let tensor_2 = TestTensorInt::arange(16..24, &Default::default()).float();\n    let tensor_3 = TestTensorInt::arange(32..40, &Default::default()).float();\n\n    let tensor_1 = tensor_1.reshape([4, 2, 1]);\n    let tensor_3 = tensor_3.reshape([1, 2, 4]);\n    let tensor_1 = tensor_1.swap_dims(0, 2);\n\n    let tensor_2 = tensor_2.reshape([1, 4, 2]);\n    let tensor_2 = tensor_2.swap_dims(1, 2);\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let output = (tensor_3 + tensor_1 + tensor_2).sum_dim(1);\n    let expected = TensorData::from([[[102.0, 112.0, 122.0, 132.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_6_maybe_fused_on_read_not_contiguous_broadcasted() {\n    let tensor_1 = TestTensorInt::arange(0..32, &Default::default()).float();\n    let tensor_2 = TestTensorInt::arange(0..8, &Default::default()).float();\n\n    let tensor_1 = tensor_1.reshape([4, 2, 2, 2]);\n    let tensor_1 = tensor_1.swap_dims(3, 2);\n    let tensor_1 = tensor_1.swap_dims(1, 2);\n\n    let tensor_2 = tensor_2.reshape([1, 2, 2, 2]);\n\n    TestBackend::sync(&tensor_1.device()).unwrap();\n    let sum = tensor_2.clone().sum_dim(0);\n    let sum = sum.sum_dim(1);\n    let sum = sum.sum_dim(2);\n\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let _tmp = sum.clone() + 2;\n    let output = (tensor_1 + tensor_2 + sum).sum_dim(1);\n    let expected = TensorData::from([\n        [[[29.0, 43.0], [41.0, 55.0]]],\n        [[[45.0, 59.0], [57.0, 71.0]]],\n        [[[61.0, 75.0], [73.0, 87.0]]],\n        [[[77.0, 91.0], [89.0, 103.0]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sum_dim_7_maybe_fused_on_read_reshaped() {\n    let tensor_1 = TestTensorInt::arange(0..16, &Default::default()).float();\n\n    let tensor_1 = tensor_1.reshape([4, 4]);\n\n    TestBackend::sync(&tensor_1.device()).unwrap();\n\n    let reshaped = tensor_1.reshape([1, 4, 4]);\n    let tmp = reshaped + 5.0;\n    let output = tmp.sum_dim(2);\n    let expected = TensorData::from([[[26.0], [42.0], [58.0], [74.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_mean_dim_fused_on_read_on_write() {\n    // https://github.com/tracel-ai/burn/issues/3987\n    let device = Default::default();\n    let x = TestTensor::ones([128, 32, 1], &device);\n\n    let weight = TestTensor::ones([1, 32, 1], &device);\n    let options = burn_tensor::ops::ConvOptions::new([1], [0], [1], 1);\n    let x = burn_tensor::module::conv1d(x, weight, None, options);\n    let global = x.clone().powi_scalar(2).sum_dim(2).add_scalar(1e-5).sqrt();\n    let norm = global.clone().div(global.mean_dim(1));\n    let x = x.clone().mul(norm).add(x);\n\n    let out = x.sum();\n\n    out.into_data()\n        .assert_eq(&TensorData::from([8192.0]), false);\n}\n\n#[test]\nfn test_mean_dim_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let output = tensor.clone().mean_dim(1);\n    let expected = TensorData::from([[1.], [4.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    let output = tensor.mean_dim(0);\n    let expected = TensorData::from([[1.5, 2.5, 3.5]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_mean_dims_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    tensor\n        .clone()\n        .mean_dims(&[1])\n        .to_data()\n        .assert_eq(&TensorData::from([[1.], [4.]]), false);\n\n    tensor\n        .clone()\n        .mean_dims(&[-1])\n        .to_data()\n        .assert_eq(&TensorData::from([[1.], [4.]]), false);\n\n    tensor\n        .clone()\n        .mean_dims(&[0, 1])\n        .to_data()\n        .assert_eq(&TensorData::from([[2.5]]), false);\n}\n\n#[test]\nfn test_multiple_reduce_dims_permuted() {\n    // Regression test for https://github.com/tracel-ai/burn/issues/4461\n    let tensor = TestTensorInt::arange(0..2 * 2 * 256, &Default::default())\n        .float()\n        .reshape([2, 2, 256]);\n\n    let output = tensor\n        .permute([1, 2, 0])\n        .mean_dim(0)\n        .mean_dim(1)\n        .squeeze_dims::<1>(&[0, 1]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([255.5, 767.5]), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/all.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_all() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_all_dim() {\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let data_actual = tensor.all_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/any.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_any() {\n    // test float tensor\n    let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n\n    // test int tensor\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [0, 0, 0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n\n    // test bool tensor\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_any_dim() {\n    let tensor = TestTensor::<2>::from([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n\n    // test int tensor\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n\n    // test bool tensor\n    let tensor = TestTensorBool::<2>::from([[false, false, false], [true, true, false]]);\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/arg.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_argmax_2d_dim0() {\n    let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.argmax(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 0, 1]]), false);\n}\n\n#[test]\nfn test_argmin_2d_dim0() {\n    let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);\n\n    let output = tensor.argmin(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 0]]), false);\n}\n\n#[test]\nfn test_argmax_2d_dim1() {\n    let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.argmax(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1], [2]]), false);\n}\n\n#[test]\nfn test_argmin_2d_dim1() {\n    let tensor = TestTensor::<2>::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);\n\n    let output = tensor.argmin(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2], [1]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/cast.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{DType, TensorData};\n\n#[test]\nfn cast_float_to_bool() {\n    let tensor1 = TestTensor::<2>::from([[0.0, 43.0, 0.0], [2.0, -4.2, 31.33]]);\n    let data_actual = tensor1.bool().into_data();\n    let data_expected = TensorData::from([[false, true, false], [true, true, true]]);\n    data_actual.assert_eq(&data_expected, false);\n}\n\n#[test]\nfn cast_float_to_int() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]).int();\n    let expected = TensorData::from([[1, 2, 3], [4, 5, 6]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn cast_int_to_float_tensor() {\n    let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]).float();\n\n    let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn cast_bool_to_float_tensor() {\n    let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).float();\n\n    let expected = TensorData::from([[1., 0., 1.], [0., 0., 1.]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn cast_float_precision() {\n    let data = TensorData::from([[1.0, 2.0, 3.0], [4.4, 5.5, 6.6]]);\n    let tensor = TestTensor::<2>::from(data.clone());\n\n    let output = tensor.cast(DType::F32);\n\n    assert_eq!(output.dtype(), DType::F32);\n    // Use precision 2 for parameterized tests in f16 and bf16\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/cat.rs",
    "content": "use super::*;\nuse alloc::vec::Vec;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{DType, TensorData};\n\n#[test]\nfn should_support_cat_ops_2d_dim0() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_ops_2d_dim1() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);\n    let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_ops_3d() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);\n    let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_dimensions_are_not_the_same() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device);\n\n    TestTensor::cat(vec![tensor_1, tensor_2], 0).into_data();\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_list_of_vectors_is_empty() {\n    let tensor: Vec<TestTensor<2>> = vec![];\n    TestTensor::cat(tensor, 0).into_data();\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_cat_exceeds_dimension() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);\n    let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);\n\n    TestTensor::cat(vec![tensor_1, tensor_2], 3).into_data();\n}\n\n#[test]\nfn should_support_cat_ops_cast_dtype() {\n    let device = Default::default();\n    // ok for f32 backends, casts dtype for f16 tests\n    let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device)\n        .cast(DType::F32);\n    let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device).cast(DType::F32);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_with_empty_tensor() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let tensor_2: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor with size 0 on dim 1\n\n    // Concatenating with an empty tensor should just return the non-empty tensor\n    let output = TestTensor::cat(vec![tensor_1.clone(), tensor_2], 1);\n    let expected = TensorData::from([[1.0, 2.0, 3.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_with_empty_tensor_first() {\n    let device = Default::default();\n    let tensor_1: TestTensor<2> = TestTensor::empty([1, 0], &device); // Empty tensor\n    let tensor_2 = TestTensor::<2>::from_data([[4.0, 5.0, 6.0]], &device);\n\n    // Empty tensor first, then non-empty\n    let output = TestTensor::cat(vec![tensor_1, tensor_2.clone()], 1);\n    let expected = TensorData::from([[4.0, 5.0, 6.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_with_multiple_empty_tensors() {\n    let device = Default::default();\n    let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device);\n    let tensor_2 = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let tensor_3: TestTensor<2> = TestTensor::empty([2, 0], &device);\n    let tensor_4 = TestTensor::<2>::from_data([[5.0], [6.0]], &device);\n\n    // Mix of empty and non-empty tensors\n    let output = TestTensor::cat(vec![tensor_1, tensor_2, tensor_3, tensor_4], 1);\n    let expected = TensorData::from([[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_cat_all_empty_tensors() {\n    let device = Default::default();\n    let tensor_1: TestTensor<2> = TestTensor::empty([2, 0], &device);\n    let tensor_2: TestTensor<2> = TestTensor::empty([2, 0], &device);\n\n    // All empty tensors should produce an empty tensor\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);\n\n    assert_eq!(output.shape().as_slice(), [2, 0]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/ceil.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_ceil_ops() {\n    let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.ceil();\n    let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/chunk.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_chunk_evenly_divisible() {\n    let tensors = TestTensorInt::arange(0..12, &Default::default())\n        .float()\n        .chunk(6, 0);\n    assert_eq!(tensors.len(), 6);\n\n    let expected = [\n        TensorData::from([0, 1]),\n        TensorData::from([2, 3]),\n        TensorData::from([4, 5]),\n        TensorData::from([6, 7]),\n        TensorData::from([8, 9]),\n        TensorData::from([10, 11]),\n    ];\n\n    for (index, tensor) in tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_chunk_not_evenly_divisible() {\n    let tensors = TestTensorInt::arange(0..11, &Default::default())\n        .float()\n        .chunk(6, 0);\n    assert_eq!(tensors.len(), 6);\n\n    let expected = [\n        TensorData::from([0, 1]),\n        TensorData::from([2, 3]),\n        TensorData::from([4, 5]),\n        TensorData::from([6, 7]),\n        TensorData::from([8, 9]),\n        TensorData::from([10]),\n    ];\n\n    for (index, tensor) in tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_chunk_not_evenly_divisible_remains_several() {\n    let tensors = TestTensorInt::arange(0..100, &Default::default())\n        .float()\n        .chunk(8, 0);\n    assert_eq!(tensors.len(), 8);\n\n    let expected = [13, 13, 13, 13, 13, 13, 13, 9];\n\n    for (index, tensor) in tensors.iter().enumerate() {\n        assert_eq!(tensor.shape()[0], expected[index]);\n    }\n}\n\n#[test]\nfn test_chunk_not_divisible() {\n    let tensors = TestTensorInt::arange(0..6, &Default::default())\n        .float()\n        .chunk(7, 0);\n    assert_eq!(tensors.len(), 6);\n\n    let expected = [\n        TensorData::from([0]),\n        TensorData::from([1]),\n        TensorData::from([2]),\n        TensorData::from([3]),\n        TensorData::from([4]),\n        TensorData::from([5]),\n    ];\n\n    for (index, tensor) in tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\n#[should_panic]\nfn test_invalid_dim() {\n    let _tensors = TestTensorInt::arange(0..12, &Default::default()).chunk(6, 1);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/clamp.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn clamp_min() {\n    let device = Default::default();\n    // test float tensor\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &device);\n\n    let output = tensor.clamp_min(2.0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]), false);\n\n    // test int tensor\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor = TestTensorInt::<2>::from_data(data, &device);\n    let output = tensor.clamp_min(2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2, 2, 2], [3, 4, 5]]), false);\n}\n\n#[test]\nfn clamp_max() {\n    let device = Default::default();\n    // test float tensor\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &device);\n\n    let output = tensor.clamp_max(2.0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]), false);\n\n    // test int tensor\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor = TestTensorInt::<2>::from_data(data, &device);\n    let output = tensor.clamp_max(4);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 2], [3, 4, 4]]), false);\n}\n\n#[test]\nfn clamp_min_max() {\n    let device = Default::default();\n    // test float tensor\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &device);\n    let output = tensor.clamp(1.0, 4.0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]), false);\n\n    // test int tensor\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor = TestTensorInt::<2>::from_data(data, &device);\n    let output = tensor.clamp(1, 4);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 1, 2], [3, 4, 4]]), false);\n}\n\n#[test]\nfn clamp_min_max_vec_should_compile() {\n    let input = TestTensor::<2>::ones([2, 4], &Default::default());\n    let output = input.clamp(0., 0.5);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/close.rs",
    "content": "use super::*;\nuse burn_tensor::{DEFAULT_ATOL, DEFAULT_RTOL, TensorData};\n\n#[test]\nfn test_is_close() {\n    let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;\n\n    let data_actual = tensor1\n        .clone()\n        .is_close(tensor2.clone(), None, None)\n        .into_data();\n    let defaults_expected = TensorData::from([[true, true, true], [true, true, false]]);\n    defaults_expected.assert_eq(&data_actual, false);\n\n    // Using the defaults.\n    let data_actual = tensor1\n        .is_close(tensor2, Some(DEFAULT_RTOL), Some(DEFAULT_ATOL))\n        .into_data();\n    defaults_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_all_close() {\n    let tensor1 = TestTensor::<2>::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 3.0]]) + 1e-9;\n    assert!(!tensor1.clone().all_close(tensor2.clone(), None, None));\n\n    let tensor2 = TestTensor::from([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]) + 1e-9;\n    assert!(tensor1.all_close(tensor2, None, None));\n\n    // non finite values\n    let inf_plus = TestTensor::<2>::from([[f32::INFINITY]]);\n    let one = TestTensor::<2>::from([[1.]]);\n    let inf_minus = TestTensor::<2>::from([[-f32::INFINITY]]);\n    assert!(!inf_plus.clone().all_close(inf_minus.clone(), None, None));\n    assert!(!one.clone().all_close(inf_minus.clone(), None, None));\n    assert!(!one.all_close(inf_plus.clone(), None, None));\n    assert!(inf_plus.clone().all_close(inf_plus, None, None));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/comparison.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_equal_inf() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0], [f32::INFINITY, 4.0, f32::NEG_INFINITY]]);\n    let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_not_equal_inf() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, f32::INFINITY, 5.0]]);\n    let data_2 = TensorData::from([[1.0, 1.0, 1.0], [f32::INFINITY, 3.0, f32::NEG_INFINITY]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.not_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, true], [true, true, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_equal() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_not_equal() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.not_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, true], [true, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_equal_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().equal_elem(2);\n    let data_actual_inplace = tensor_1.equal_elem(2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_not_equal_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal_elem(2);\n    let data_actual_inplace = tensor_1.not_equal_elem(2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn greater_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_elem(4);\n    let data_actual_inplace = tensor_1.greater_elem(4);\n\n    let data_expected = TensorData::from([[false, false, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_equal_elem(4.0);\n    let data_actual_inplace = tensor_1.greater_equal_elem(4.0);\n\n    let data_expected = TensorData::from([[false, false, false], [false, true, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater(tensor_2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater_equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_elem(4.0);\n    let data_actual_inplace = tensor_1.lower_elem(4.0);\n\n    let data_expected = TensorData::from([[true, true, true], [true, false, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_equal_elem() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_equal_elem(4.0);\n    let data_actual_inplace = tensor_1.lower_equal_elem(4.0);\n\n    let data_expected = TensorData::from([[true, true, true], [true, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_equal() {\n    let tensor_1 = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = TestTensor::<2>::from([[1.0, 1.0, 1.0], [4.0, 3.0, 50.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_broadcast() {\n    // Test broadcasting with shape [1, 4] vs [4, 4]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]);\n    let data_2 = TensorData::from([\n        [0.5, 1.5, 2.5, 3.5],\n        [1.5, 2.5, 3.5, 4.5],\n        [2.5, 3.5, 4.5, 5.5],\n        [3.5, 4.5, 5.5, 6.5],\n    ]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.greater(tensor_2);\n\n    let expected = TensorData::from([\n        [true, true, true, true],\n        [false, false, false, false],\n        [false, false, false, false],\n        [false, false, false, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal_broadcast() {\n    // Test broadcasting with shape [4, 1] vs [1, 4]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1.0], [2.0], [3.0], [4.0]]);\n    let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0]]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.greater_equal(tensor_2);\n\n    let expected = TensorData::from([\n        [true, false, false, false],\n        [true, true, false, false],\n        [true, true, true, false],\n        [true, true, true, true],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_lower_broadcast() {\n    // Test broadcasting mimicking CLIP pattern: [1, 5] vs [5, 1]\n    let device = Default::default();\n    let data_1 = TensorData::from([[0.0, 1.0, -1.0, 2.0, -2.0]]);\n    let data_2 = TensorData::from([[0.5], [1.5], [-0.5], [-1.5], [2.5]]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.lower(tensor_2);\n\n    let expected = TensorData::from([\n        [true, false, true, false, true],\n        [true, true, true, false, true],\n        [false, false, true, false, true],\n        [false, false, false, false, true],\n        [true, true, true, true, true],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_lower_equal_broadcast() {\n    // Test broadcasting with shape [1, 1] vs [2, 4]\n    let device = Default::default();\n    let data_1 = TensorData::from([[2.5]]);\n    let data_2 = TensorData::from([[1.0, 2.0, 3.0, 4.0], [2.0, 2.5, 3.0, 3.5]]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.lower_equal(tensor_2);\n\n    let expected = TensorData::from([[false, false, true, true], [false, true, true, true]]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_equal_broadcast() {\n    // Test broadcasting with different ranks\n    let device = Default::default();\n    let data_1 = TensorData::from([[2.0], [3.0], [4.0]]);\n    let data_2 = TensorData::from([[2.0, 3.0, 4.0, 2.0]]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.equal(tensor_2);\n\n    let expected = TensorData::from([\n        [true, false, false, true],\n        [false, true, false, false],\n        [false, false, true, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_not_equal_broadcast() {\n    // Test broadcasting with shape [3, 1] vs [1, 3]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1.0], [2.0], [3.0]]);\n    let data_2 = TensorData::from([[1.0, 2.0, 3.0]]);\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.not_equal(tensor_2);\n\n    let expected = TensorData::from([\n        [false, true, true],\n        [true, false, true],\n        [true, true, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/create_like.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Distribution, TensorData};\n\n#[test]\nfn should_support_zeros_like() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let tensor = tensor.zeros_like();\n    let expected = TensorData::from([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.]]]);\n\n    tensor\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_ones_like() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let tensor = tensor.ones_like();\n    let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);\n\n    tensor\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_randoms_like() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let tensor = tensor.random_like(Distribution::Uniform(0.99999, 1.));\n    let expected = TensorData::from([[[1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.]]]);\n\n    tensor\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/cross.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[cfg(feature = \"std\")]\nuse burn_backend_tests::might_panic;\n\n#[test]\nfn test_cross_3d_last_dim() {\n    let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]);\n    let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]);\n\n    let output = tensor_1.cross(tensor_2, -1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cross_3d_non_contiguous_last_dim() {\n    let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0], [2.0, -1.0, 4.0]]);\n    let tensor_2 = TestTensor::from([[4.0, 3.0], [-2.0, 5.0], [1.0, -2.0]]);\n\n    let output = tensor_1.cross(tensor_2.permute([1, 0]), -1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[-7.0, -21.0, -14.0], [-18.0, 16.0, 13.0]]),\n        false,\n    );\n}\n\n#[cfg(feature = \"std\")]\n#[might_panic(reason = \"not implemented: Cross product on non-last dimension\")]\n#[test]\nfn test_cross_3d_dim0() {\n    let tensor_1 = TestTensor::<2>::from([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]);\n    let tensor_2 = TestTensor::from([[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]);\n\n    let output = tensor_1.cross(tensor_2, 0);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[0.0, 0.0], [-1.0, 0.0], [0.0, -1.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cross_3d_broadcast() {\n    let tensor_1 = TestTensor::<2>::from([[1.0, 3.0, -5.0]]);\n    let tensor_2 = TestTensor::from([[4.0, -2.0, 1.0], [3.0, 5.0, -2.0]]);\n\n    let output = tensor_1.cross(tensor_2, -1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[-7.0, -21.0, -14.0], [19.0, -13.0, -4.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cross_4d_last_dim() {\n    let tensor_1 = TestTensor::<3>::from([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]);\n    let tensor_2 = TestTensor::from([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]);\n\n    let output = tensor_1.cross(tensor_2, -1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]]),\n        false,\n    );\n}\n\n// Helper to compute expected cross product for 2-D (N × 3) tensors.\nfn manual_cross(a: &[[f32; 3]], b: &[[f32; 3]]) -> Vec<[f32; 3]> {\n    a.iter()\n        .zip(b.iter())\n        .map(|(x, y)| {\n            [\n                x[1] * y[2] - x[2] * y[1],\n                x[2] * y[0] - x[0] * y[2],\n                x[0] * y[1] - x[1] * y[0],\n            ]\n        })\n        .collect()\n}\n\n#[test]\nfn forward_matches_manual_cross() {\n    let a_raw = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];\n    let b_raw = [[7.0, 8.0, 9.0], [1.0, 0.0, -1.0]];\n    let a = TestTensor::<2>::from(a_raw);\n    let b = TestTensor::<2>::from(b_raw);\n\n    let out = a.cross(b.clone(), 1);\n    let expected_vec = manual_cross(&a_raw, &b_raw);\n    let expected: [[f32; 3]; 2] = [expected_vec[0], expected_vec[1]];\n\n    out.into_data()\n        .assert_eq(&TensorData::from(expected), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/cumulative.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_cumsum_float_dim_0() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let output = tensor.cumsum(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]), false);\n}\n\n#[test]\nfn test_cumsum_float_dim_1() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let output = tensor.cumsum(1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cumsum_non_contiguous() {\n    let tensor = TestTensor::<2>::from([[1., 2.], [3., 4.]]).swap_dims(0, 1);\n\n    let output = tensor.cumsum(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1., 4.], [2., 6.]]), false);\n}\n\n#[test]\nfn test_cumsum_float_3d() {\n    let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);\n\n    let output = tensor.cumsum(2);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[1.0, 3.0], [3.0, 7.0]], [[5.0, 11.0], [7.0, 15.0]]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cumprod_float_dim_0() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let output = tensor.cumprod(0);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cumprod_float_dim_1() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let output = tensor.cumprod(1);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cumprod_float_3d() {\n    let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]);\n\n    let output = tensor.cumprod(2);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[1.0, 2.0], [3.0, 12.0]], [[5.0, 30.0], [7.0, 56.0]]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cummin_float_dim_0() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);\n\n    let output = tensor.cummin(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [2.0, 1.0, 1.0]]), false);\n}\n\n#[test]\nfn test_cummin_float_dim_1() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);\n\n    let output = tensor.cummin(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0, 1.0, 1.0], [2.0, 2.0, 1.0]]), false);\n}\n\n#[test]\nfn test_cummin_float_3d() {\n    let tensor = TestTensor::<3>::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 6.0], [7.0, 8.0]]]);\n\n    let output = tensor.cummin(2);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 5.0], [7.0, 7.0]]]),\n        false,\n    );\n}\n\n#[test]\nfn test_cummax_float_dim_0() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]);\n\n    let output = tensor.cummax(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [3.0, 5.0, 4.0]]), false);\n}\n\n#[test]\nfn test_cummax_float_dim_1() {\n    let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 2.0]]);\n\n    let output = tensor.cummax(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0, 3.0, 4.0], [1.0, 5.0, 5.0]]), false);\n}\n\n#[test]\nfn test_cummax_float_3d() {\n    let tensor = TestTensor::<3>::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 2.0], [6.0, 1.0]]]);\n\n    let output = tensor.cummax(2);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[1.0, 3.0], [2.0, 4.0]], [[5.0, 5.0], [6.0, 6.0]]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/div.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_div_ops() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 / tensor_2;\n    let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_div_broadcast() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);\n    let data_2 = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 / tensor_2;\n\n    output.into_data().assert_eq(\n        &TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]),\n        false,\n    );\n}\n\n#[test]\nfn should_support_div_scalar_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let scalar = 2.0;\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(data, &device);\n\n    let output = tensor / scalar;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/dot.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_float() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<1>::from_data([1.0, 2.0, 3.0], &device);\n    let tensor_2 = TestTensor::<1>::from_data([0.0, -1.0, 4.0], &device);\n\n    let output = tensor_1.dot(tensor_2);\n    let expected = TensorData::from([10.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<1>::from_data([1, 2, 3], &device);\n    let tensor_2 = TestTensor::<1>::from_data([0, -1, 4], &device);\n\n    let output = tensor_1.dot(tensor_2);\n    let expected = TensorData::from([10]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn test_panics_for_different_sizes() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<1>::from_data([1, 2], &device);\n    let tensor_2 = TestTensor::<1>::from_data([1, 2, 3], &device);\n    let _output = tensor_1.dot(tensor_2);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/erf.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_erf_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.erf();\n    let expected = TensorData::from([[0.0000, 0.8427, 0.99532], [0.99998, 1.0000, 1.0000]]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_absolute(2e-3),\n    );\n}\n\n#[test]\nfn should_support_erf_ops_with_negative_number() {\n    let data = TensorData::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.erf();\n    let expected = TensorData::from([\n        [-0.06312324, -0.048490416, -0.10016122],\n        [0.99998, 1.0000, 1.0000],\n    ]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_absolute(3e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/exp.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_exp_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.exp();\n    let expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/expand.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn expand_2d() {\n    let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default());\n    let output = tensor.expand([3, 3]);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]),\n        false,\n    );\n\n    let tensor = TestTensor::<1>::from_floats([4.0, 7.0, 2.0, 3.0], &Default::default());\n    let output = tensor.expand([2, 4]);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]),\n        false,\n    );\n}\n\n#[test]\nfn expand_3d() {\n    let tensor = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default());\n    let output = tensor.expand([3, 2, 2]);\n    let expected = TensorData::from([\n        [[1.0, 2.0], [3.0, 4.0]],\n        [[1.0, 2.0], [3.0, 4.0]],\n        [[1.0, 2.0], [3.0, 4.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn expand_higher_dimensions() {\n    let tensor = TestTensor::<2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &Default::default());\n    let output = tensor.expand([2, 3, 4]);\n    let expected = TensorData::from([\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n        ],\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn expand_sum_3d() {\n    let tensor = TestTensor::<2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &Default::default());\n    let output = tensor.expand([3, 2, 2]).sum_dim(0);\n    let expected = TensorData::from([[[3.0, 6.0], [9.0, 12.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn broadcast_single() {\n    let tensor = TestTensor::<1>::from_floats([1.0], &Default::default());\n    let output = tensor.expand([2, 3]);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), false);\n}\n\n#[test]\n#[should_panic]\nfn should_fail_expand_incompatible_shapes() {\n    let tensor = TestTensor::<1>::from_floats([1.0, 2.0, 3.0], &Default::default());\n    let _expanded_tensor = tensor.expand([2, 2]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/finite.rs",
    "content": "use super::*;\n\n#[test]\nfn is_finite() {\n    let all_finite = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let all_finite_expected = TestTensorBool::<2>::from([[true, true, true], [true, true, true]]);\n\n    let with_inf_nan = TestTensor::<2>::from([\n        [0.0, f32::INFINITY, f32::NAN],\n        [f32::NEG_INFINITY, f32::NAN, 5.0],\n    ]);\n    let with_inf_nan_expected =\n        TestTensorBool::<2>::from([[true, false, false], [false, false, true]]);\n\n    all_finite_expected\n        .into_data()\n        .assert_eq(&all_finite.is_finite().into_data(), false);\n\n    with_inf_nan\n        .is_finite()\n        .into_data()\n        .assert_eq(&with_inf_nan_expected.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/flatten.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\n\n/// Test if the function can successfully flatten a 4D tensor to a 1D tensor.\n#[test]\nfn should_flatten_to_1d() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let flattened_tensor: TestTensor<1> = tensor.flatten(0, 3);\n    let expected_shape = Shape::new([120]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully flatten the middle dimensions of a 4D tensor.\n#[test]\nfn should_flatten_middle() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let flattened_tensor: TestTensor<3> = tensor.flatten(1, 2);\n    let expected_shape = Shape::new([2, 12, 5]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully flatten the first dimensions of a 4D tensor.\n#[test]\nfn should_flatten_begin() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let flattened_tensor: TestTensor<2> = tensor.flatten(0, 2);\n    let expected_shape = Shape::new([24, 5]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully flatten the last dimensions of a 4D tensor.\n#[test]\nfn should_flatten_end() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let flattened_tensor: TestTensor<2> = tensor.flatten(1, 3);\n    let expected_shape = Shape::new([2, 60]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can flatten negative indices.\n#[test]\nfn should_flatten_end_negative_indices() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let flattened_tensor: TestTensor<2> = tensor.flatten(-3, -1);\n    let expected_shape = Shape::new([2, 60]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n\n/// Test if the function panics when the start dimension is greater than the end dimension.\n#[test]\n#[should_panic]\nfn should_flatten_panic() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let _flattened_tensor: TestTensor<2> = tensor.flatten(2, 0);\n}\n\n#[test]\n#[should_panic]\nfn not_enough_destination_dimension() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 5, 15]), &Default::default());\n    let flattened_tensor: TestTensor<1> = tensor.flatten(1, 2);\n    let expected_shape = Shape::new([75]);\n    assert_eq!(flattened_tensor.shape(), expected_shape);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/flip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn flip_float() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .float();\n\n    let flipped = tensor.clone().flip([0, 2]);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).float()\n    let expected = TensorData::from([\n        [\n            [15., 14., 13., 12.],\n            [19., 18., 17., 16.],\n            [23., 22., 21., 20.],\n        ],\n        [[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]],\n    ]);\n\n    flipped.into_data().assert_eq(&expected, false);\n\n    // Test with no flip\n    let flipped = tensor.clone().flip([]);\n    tensor.into_data().assert_eq(&flipped.into_data(), false);\n}\n\n#[test]\n#[should_panic]\nfn flip_duplicated_axes() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with a duplicated axis\n    let _ = tensor.clone().flip([0, 0, 1]);\n}\n\n#[test]\n#[should_panic]\nfn flip_out_of_bound_axis() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with an out of bound axis\n    let _ = tensor.clone().flip([3, 0, 1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/floor.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_floor_ops() {\n    let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.floor();\n    let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/fmod.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{ElementConversion, TensorData};\n\n#[allow(unused_imports)] // f16\nuse num_traits::Float;\n\n#[test]\nfn should_support_fmod_ops() {\n    let dividend = TensorData::from([[5.3, -5.3], [7.5, -7.5]]);\n    let divisor = TensorData::from([[2.0, 2.0], [3.0, 3.0]]);\n\n    let dividend_tensor = TestTensor::<2>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<2>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([[1.3, -1.3], [1.5, -1.5]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_fmod_scalar() {\n    let data = TensorData::from([5.3, -5.3, 7.5, -7.5]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.fmod_scalar(2.0);\n    let expected = TensorData::from([1.3, -1.3, 1.5, -1.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_positive_dividend_positive_divisor() {\n    let dividend = TensorData::from([10.0, 7.5, 3.8, 1.2]);\n    let divisor = TensorData::from([3.0, 2.0, 1.5, 0.7]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([1.0, 1.5, 0.8, 0.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_negative_dividend() {\n    let dividend = TensorData::from([-10.0, -7.5, -3.8, -1.2]);\n    let divisor = TensorData::from([3.0, 2.0, 1.5, 0.7]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([-1.0, -1.5, -0.8, -0.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_mixed_signs() {\n    let dividend = TensorData::from([5.3, -5.3, 5.3, -5.3]);\n    let divisor = TensorData::from([2.0, 2.0, -2.0, -2.0]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    // fmod result has same sign as dividend\n    let expected = TensorData::from([1.3, -1.3, 1.3, -1.3]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_infinity_dividend() {\n    // If x is ±∞ and y is not NaN, NaN is returned\n    let dividend = TensorData::from([\n        f32::INFINITY,\n        f32::NEG_INFINITY,\n        f32::INFINITY,\n        f32::NEG_INFINITY,\n    ]);\n    let divisor = TensorData::from([2.0, 3.0, -2.0, -3.0]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let data = output.into_data();\n    let values = data.as_slice::<FloatElem>().unwrap();\n\n    // All results should be NaN\n    assert!(values[0].is_nan(), \"fmod(inf, 2.0) should be NaN\");\n    assert!(values[1].is_nan(), \"fmod(-inf, 3.0) should be NaN\");\n    assert!(values[2].is_nan(), \"fmod(inf, -2.0) should be NaN\");\n    assert!(values[3].is_nan(), \"fmod(-inf, -3.0) should be NaN\");\n}\n\n#[test]\nfn should_handle_zero_divisor() {\n    // If y is ±0 and x is not NaN, NaN should be returned\n    let dividend = TensorData::from([5.3, -5.3, 0.0, 1.0]);\n    let divisor = TensorData::from([0.0, -0.0, 0.0, -0.0]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let data = output.into_data();\n    let values = data.as_slice::<FloatElem>().unwrap();\n\n    // All results should be NaN\n    assert!(values[0].is_nan(), \"fmod(5.3, 0.0) should be NaN\");\n    assert!(values[1].is_nan(), \"fmod(-5.3, -0.0) should be NaN\");\n    assert!(values[2].is_nan(), \"fmod(0.0, 0.0) should be NaN\");\n    assert!(values[3].is_nan(), \"fmod(1.0, -0.0) should be NaN\");\n}\n\n#[test]\nfn should_handle_infinity_divisor() {\n    // If y is ±∞ and x is finite, x is returned\n    let dividend = TensorData::from([5.3, -5.3, 0.0, -0.0]);\n    let divisor = TensorData::from([\n        f32::INFINITY,\n        f32::NEG_INFINITY,\n        f32::INFINITY,\n        f32::NEG_INFINITY,\n    ]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([5.3, -5.3, 0.0, -0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_nan_arguments() {\n    // If either argument is NaN, NaN is returned\n    let dividend = TensorData::from([f32::NAN, 5.3, f32::NAN, 0.0]);\n    let divisor = TensorData::from([2.0, f32::NAN, f32::NAN, 3.0]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let data = output.into_data();\n    let values = data.as_slice::<FloatElem>().unwrap();\n\n    assert!(values[0].is_nan(), \"fmod(NaN, 2.0) should be NaN\");\n    assert!(values[1].is_nan(), \"fmod(5.3, NaN) should be NaN\");\n    assert!(values[2].is_nan(), \"fmod(NaN, NaN) should be NaN\");\n    assert!(!values[3].is_nan(), \"fmod(0.0, 3.0) should be 0.0\");\n}\n\n#[test]\nfn should_handle_negative_zero() {\n    // If x is -0 and y is greater than zero, either +0 or -0 may be returned\n    let dividend = TensorData::from([-0.0_f32]);\n    let divisor = TensorData::from([2.0_f32]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let data = output.into_data();\n    let values = data.as_slice::<FloatElem>().unwrap();\n\n    // Result should be zero (either +0 or -0 is acceptable)\n    assert_eq!(\n        values[0],\n        0.0f32.elem::<FloatElem>(),\n        \"fmod(-0, 2.0) should be zero\"\n    );\n}\n\n#[test]\nfn should_support_fmod_broadcasting_2d() {\n    // Test broadcasting: 1x2 with 3x2\n    let dividend = TensorData::from([[5.3, -5.3]]); // Shape: 1x2\n    let divisor = TensorData::from([[2.0, 2.0], [3.0, 3.0], [1.5, 1.5]]); // Shape: 3x2\n\n    let dividend_tensor = TestTensor::<2>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<2>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([\n        [1.3, -1.3], // 5.3 % 2.0, -5.3 % 2.0\n        [2.3, -2.3], // 5.3 % 3.0, -5.3 % 3.0\n        [0.8, -0.8], // 5.3 % 1.5, -5.3 % 1.5\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_fmod_broadcasting_3d() {\n    // Test broadcasting: 1x1x3 with 2x1x3\n    let dividend = TensorData::from([[[5.0, -7.0, 8.0]]]); // Shape: 1x1x3\n    let divisor = TensorData::from([[[3.0, 3.0, 3.0]], [[4.0, 4.0, 4.0]]]); // Shape: 2x1x3\n\n    let dividend_tensor = TestTensor::<3>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<3>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([\n        [[2.0, -1.0, 2.0]], // 5.0 % 3.0, -7.0 % 3.0, 8.0 % 3.0\n        [[1.0, -3.0, 0.0]], // 5.0 % 4.0, -7.0 % 4.0, 8.0 % 4.0\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_fmod_scalar_broadcasting() {\n    // Test scalar operation with different shapes\n    let data = TensorData::from([[5.3, -5.3, 7.5], [-7.5, 10.0, -10.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.fmod_scalar(3.0);\n    let expected = TensorData::from([[2.3, -2.3, 1.5], [-1.5, 1.0, -1.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_edge_case_values() {\n    // Test various edge cases\n    let dividend = TensorData::from([0.0, -0.0, 1e-10, -1e-10, 10.0, -10.0]);\n    let divisor = TensorData::from([1.0, 1.0, 1.0, 1.0, 3.0, 3.0]);\n\n    let dividend_tensor = TestTensor::<1>::from_data(dividend, &Default::default());\n    let divisor_tensor = TestTensor::<1>::from_data(divisor, &Default::default());\n\n    let output = dividend_tensor.fmod(divisor_tensor);\n    let expected = TensorData::from([0.0, -0.0, 1e-10, -1e-10, 1.0, -1.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_special_scalar_cases() {\n    // Test scalar operations with special values\n    let data = TensorData::from([5.3, -5.3, 0.0, -0.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    // Test with infinity divisor\n    let output_inf = tensor.clone().fmod_scalar(f32::INFINITY);\n    let expected_inf = TensorData::from([5.3, -5.3, 0.0, -0.0]);\n    output_inf\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected_inf, Tolerance::default());\n\n    // Test with very small divisor\n    // Doesn't work if the test divisor is subnormal\n    if FloatElem::MIN_POSITIVE > 1e-5f32.elem::<FloatElem>() {\n        return;\n    }\n\n    let output_small = tensor.clone().fmod_scalar(1e-5);\n    let data = output_small.into_data();\n    let values = data.as_slice::<FloatElem>().unwrap();\n\n    // let expected = TensorData::from([0.0, 0.0, 0.0, 0.0]);\n\n    // Results should be very small remainders\n    assert!(values[0].abs() < 1e-5f32.elem::<FloatElem>());\n    assert!(values[1].abs() < 1e-5f32.elem::<FloatElem>());\n    assert_eq!(values[2], 0.0f32.elem::<FloatElem>());\n    assert_eq!(values[3], 0.0f32.elem::<FloatElem>());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/full.rs",
    "content": "use super::*;\nuse burn_tensor::{DType, TensorData};\n\n#[test]\nfn test_data_full() {\n    let tensor = TensorData::full([2, 3], 2.0);\n\n    tensor.assert_eq(&TensorData::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), false);\n}\n\n#[test]\nfn test_tensor_full() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::full([2, 3], 2.1, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]), false);\n}\n\n#[test]\nfn test_tensor_full_options() {\n    let tensor = TestTensor::<2>::full([2, 3], 2.1, (&Default::default(), DType::F32));\n    assert_eq!(tensor.dtype(), DType::F32);\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/gather_scatter.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_gather_1d_dim0() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats([0.0, 1.0, 2.0], &device);\n    let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]), false);\n}\n\n#[test]\nfn should_gather_2d_dim0() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]], &device);\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]), false);\n}\n\n#[test]\nfn should_gather_2d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]], &device);\n\n    let output = tensor.gather(1, indices);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]),\n        false,\n    );\n}\n\n#[test]\nfn should_gather_3d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &device,\n    );\n    let indices =\n        TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device);\n\n    let output = tensor.gather(1, indices);\n    let expected = TensorData::from([\n        [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]],\n        [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_gather_2d_only_1dim() {\n    let device = Default::default();\n    let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::<2>::from_ints([[1, 2]], &device).reshape([2, 1]);\n\n    let output = tensor.gather(1, indices);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0], [5.0]]), false);\n}\n\n#[test]\nfn should_scatter_add_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats([0.0, 0.0, 0.0], &device);\n    let values = TestTensor::from_floats([5.0, 4.0, 3.0], &device);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &device);\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([4.0, 5.0, 3.0]), false);\n}\n\n#[test]\nfn should_scatter_add_2d_dim0() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device);\n    let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]], &device);\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]), false);\n}\n\n#[test]\nfn should_scatter_add_2d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device);\n    let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]], &device);\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]), false);\n}\n\n#[test]\nfn should_scatter_add_3d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &device,\n    );\n    let values = TestTensor::from_floats(\n        [\n            [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],\n            [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],\n        ],\n        &device,\n    );\n    let indices =\n        TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]], &device);\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([\n        [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]],\n        [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_scatter_add_2d_dim1_diff_shape() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device);\n    let values = TestTensor::from_floats([[1.0], [4.0]], &device);\n    let indices = TestTensorInt::from_ints([[1], [2]], &device);\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]), false);\n}\n\n#[test]\n#[should_panic]\nfn scatter_should_panic_on_mismatch_of_shapes() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats([0.0, 0.0, 0.0], &device);\n    let values = TestTensor::from_floats([5.0, 4.0], &device);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &device);\n\n    tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/grid_sample.rs",
    "content": "use super::*;\nuse burn_tensor::{\n    TensorData, Tolerance,\n    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},\n};\n\n/// Tests grid_sample_2d with default options (align_corners=false, zeros padding).\n///\n/// For a 3x3 input with grid coordinates:\n/// - (0.0, 0.0) maps to pixel (1.0, 1.0) -> center pixel = 4.0\n/// - (-1.0, 0.25) maps to pixel (-0.5, 1.375) -> partially out of bounds\n/// - (1.0, 1.0) maps to pixel (2.5, 2.5) -> corner, partially out of bounds\n/// - (0.2, -0.8) maps to pixel (1.3, 0.3) -> interpolates around center-top\n#[test]\nfn should_grid_sample_2d_default() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_floats(\n        [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]],\n        &device,\n    );\n    let grid = TestTensor::<4>::from_floats(\n        [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]],\n        &device,\n    );\n\n    let output = tensor.grid_sample_2d(grid, InterpolateMode::Bilinear);\n\n    // Expected values computed with PyTorch grid_sample(align_corners=False, padding_mode='zeros')\n    let expected = TensorData::from([[[[4.0, 2.0625], [2.0, 1.04]]]]);\n    output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n/// Tests grid_sample_2d with align_corners=true and border padding.\n///\n/// This is the original Burn semantics before the API change.\n#[test]\nfn should_grid_sample_2d_align_corners_border() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_floats(\n        [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]],\n        &device,\n    );\n    let grid = TestTensor::<4>::from_floats(\n        [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]],\n        &device,\n    );\n\n    let options = GridSampleOptions::new(InterpolateMode::Bilinear)\n        .with_padding_mode(GridSamplePaddingMode::Border)\n        .with_align_corners(true);\n    let output = tensor.grid_sample_2d(grid, options);\n\n    // Expected values computed with PyTorch grid_sample(align_corners=True, padding_mode='border')\n    let expected = TensorData::from([[[[4.0, 3.75], [8.0, 1.8]]]]);\n    output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n/// Tests out-of-bounds grid coordinates with zeros padding.\n/// Grid coordinate (0.0, -2.0) maps to pixel (1.0, -2.5) which is completely out of bounds.\n#[test]\nfn should_pad_zeros_grid_sample_2d() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_floats(\n        [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]],\n        &device,\n    );\n    let grid = TestTensor::<4>::from_floats([[[[0.0, -2.0]]]], &device);\n\n    let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());\n\n    // With zeros padding, out-of-bounds samples return 0\n    let expected = TensorData::from([[[[0.0]]]]);\n    output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n/// Tests out-of-bounds grid coordinates with border padding.\n#[test]\nfn should_pad_border_grid_sample_2d() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_floats(\n        [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]],\n        &device,\n    );\n    let grid = TestTensor::<4>::from_floats([[[[0.0, -2.0]]]], &device);\n\n    let options = GridSampleOptions::new(InterpolateMode::Bilinear)\n        .with_padding_mode(GridSamplePaddingMode::Border);\n    let output = tensor.grid_sample_2d(grid, options);\n\n    // With border padding, out-of-bounds coordinates are clamped to border\n    // Grid (0.0, -2.0) with align_corners=false: pixel (1.0, -2.5) -> clamped to (1.0, 0.0) = 1.0\n    let expected = TensorData::from([[[[1.0]]]]);\n    output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n/// Tests bilinear interpolation with reflection padding.\n#[test]\nfn should_pad_reflection_grid_sample_2d() {\n    let device = Default::default();\n    let tensor = TestTensor::<4>::from_floats(\n        [[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]],\n        &device,\n    );\n    let grid = TestTensor::<4>::from_floats(\n        [[[[0.0, 0.0], [-1.0, 0.25]], [[1.0, 1.0], [0.2, -0.8]]]],\n        &device,\n    );\n\n    let options = GridSampleOptions::new(InterpolateMode::Bilinear)\n        .with_padding_mode(GridSamplePaddingMode::Reflection);\n    let output = tensor.grid_sample_2d(grid, options);\n\n    // Expected values computed with PyTorch F.grid_sample(mode='bilinear', padding_mode='reflection', align_corners=False)\n    let expected = TensorData::from([[[[4.0, 4.125], [8.0, 1.3]]]]);\n    output\n        .to_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/inf.rs",
    "content": "use super::*;\n\n#[test]\nfn is_inf() {\n    let no_inf = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let no_inf_expected = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);\n\n    let with_inf =\n        TestTensor::<2>::from([[0.0, f32::INFINITY, 2.0], [f32::NEG_INFINITY, 4.0, 5.0]]);\n    let with_inf_expected = TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);\n\n    no_inf\n        .is_inf()\n        .into_data()\n        .assert_eq(&no_inf_expected.into_data(), false);\n\n    with_inf\n        .is_inf()\n        .into_data()\n        .assert_eq(&with_inf_expected.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/init.rs",
    "content": "use super::*;\nuse burn_tensor::{DType, TensorData};\n\n#[test]\nfn should_support_float_empty() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::empty(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into())\n}\n\n#[test]\nfn should_support_float_empty_options() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::empty(shape, (&Default::default(), DType::F32));\n    assert_eq!(tensor.shape(), shape.into())\n}\n\n#[test]\nfn should_support_float_zeros() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::zeros(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[0., 0.], [0., 0.]]), false);\n}\n\n#[test]\nfn should_support_float_zeros_options() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::zeros(shape, (&Default::default(), DType::F32));\n    assert_eq!(tensor.shape(), shape.into());\n    assert_eq!(tensor.dtype(), DType::F32);\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[0., 0.], [0., 0.]]), false);\n}\n\n#[test]\nfn should_support_float_ones() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::ones(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);\n}\n\n#[test]\nfn should_support_float_ones_options() {\n    let shape = [2, 2];\n    let tensor = TestTensor::<2>::ones(shape, (&Default::default(), DType::F32));\n    assert_eq!(tensor.shape(), shape.into());\n    assert_eq!(tensor.dtype(), DType::F32);\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[1., 1.], [1., 1.]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/iter_dim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_1d_iter_last_item() {\n    let data = [1, 2, 3, 4];\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_ints(data, &device);\n    tensor\n        .iter_dim(0)\n        .last()\n        .unwrap()\n        .into_data()\n        .assert_eq(&TensorData::from([4]), false);\n}\n\n#[test]\n#[should_panic]\nfn test_too_high_dimension() {\n    TestTensor::<1>::zeros([10], &Default::default()).iter_dim(1);\n}\n\n#[test]\nfn test_transposed() {\n    let data = [\n        [1., 2., 3., 1., 2.],\n        [4., 5., 6., 1., 2.],\n        [7., 8., 9., 1., 2.],\n    ];\n    let tensor = TestTensor::<2>::from_floats(data, &Default::default());\n    let lhs = tensor.clone().slice([1..2, 0..5]);\n    let rhs = tensor.transpose().iter_dim(1).nth(1).unwrap();\n    assert_eq!(\n        lhs.into_data().as_slice::<FloatElem>().unwrap(),\n        rhs.into_data().as_slice::<FloatElem>().unwrap()\n    );\n}\n\n#[test]\nfn test_2d_iter_dim() {\n    let tensor =\n        TestTensor::<2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &Default::default());\n\n    let mut iter = tensor.iter_dim(0);\n\n    let iter1 = iter.next().unwrap();\n    iter1\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0, 4.9, 2.0]]), false);\n\n    let iter2 = iter.next().unwrap();\n    iter2\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0, 1.9, 3.0]]), false);\n\n    assert!(iter.next().is_none());\n}\n\n#[test]\nfn test_2d_iter_dim1() {\n    let tensor =\n        TestTensor::<2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &Default::default());\n\n    let mut iter = tensor.iter_dim(1);\n\n    let iter1 = iter.next().unwrap();\n    iter1\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0], [2.0]]), false);\n\n    let iter2 = iter.next().unwrap();\n    iter2\n        .into_data()\n        .assert_eq(&TensorData::from([[4.9], [1.9]]), false);\n\n    let iter3 = iter.next().unwrap();\n    iter3\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0], [3.0]]), false);\n\n    assert!(iter.next().is_none());\n}\n\n#[test]\nfn test_3d_iter_dim() {\n    let tensor = TestTensor::<3>::from([[\n        [1., 2., 3., 1., 2.],\n        [4., 5., 6., 1., 2.],\n        [7., 8., 9., 1., 2.],\n    ]]);\n\n    let mut iter = tensor.clone().iter_dim(0);\n\n    let iter1 = iter.next().unwrap();\n    iter1.into_data().assert_eq(&tensor.into_data(), true);\n\n    assert!(iter.next().is_none());\n}\n\n#[test]\nfn test_3d_iter_dim1() {\n    let tensor = TestTensor::<3>::from([[\n        [1., 2., 3., 1., 2.],\n        [4., 5., 6., 1., 2.],\n        [7., 8., 9., 1., 2.],\n    ]]);\n\n    let mut iter = tensor.iter_dim(1);\n\n    let iter1 = iter.next().unwrap();\n    iter1\n        .into_data()\n        .assert_eq(&TensorData::from([[[1., 2., 3., 1., 2.]]]), false);\n\n    let iter2 = iter.next().unwrap();\n    iter2\n        .into_data()\n        .assert_eq(&TensorData::from([[[4., 5., 6., 1., 2.]]]), false);\n\n    let iter3 = iter.next().unwrap();\n    iter3\n        .into_data()\n        .assert_eq(&TensorData::from([[[7., 8., 9., 1., 2.]]]), false);\n\n    assert!(iter.next().is_none());\n}\n\n#[test]\nfn test_3d_iter_dim2() {\n    let tensor = TestTensor::<3>::from([[\n        [1., 2., 3., 1., 2.],\n        [4., 5., 6., 1., 2.],\n        [7., 8., 9., 1., 2.],\n    ]]);\n\n    let mut iter = tensor.iter_dim(2);\n\n    let iter1 = iter.next().unwrap();\n    iter1\n        .into_data()\n        .assert_eq(&TensorData::from([[[1.], [4.], [7.]]]), false);\n\n    let iter2 = iter.next().unwrap();\n    iter2\n        .into_data()\n        .assert_eq(&TensorData::from([[[2.], [5.], [8.]]]), false);\n\n    let iter3 = iter.next().unwrap();\n    iter3\n        .into_data()\n        .assert_eq(&TensorData::from([[[3.], [6.], [9.]]]), false);\n\n    let iter4 = iter.next().unwrap();\n    iter4\n        .into_data()\n        .assert_eq(&TensorData::from([[[1.], [1.], [1.]]]), false);\n\n    let iter5 = iter.next().unwrap();\n    iter5\n        .into_data()\n        .assert_eq(&TensorData::from([[[2.], [2.], [2.]]]), false);\n\n    assert!(iter.next().is_none());\n}\n\n#[test]\nfn test_iteration_over_low_dim() {\n    let data = [[\n        [1., 2., 3., 1., 2.],\n        [4., 5., 6., 1., 2.],\n        [7., 8., 9., 1., 2.],\n    ]];\n\n    let tensor = TestTensor::<3>::from_floats(data, &Default::default());\n\n    let lhs = tensor.iter_dim(2).nth(1).unwrap();\n    let rhs = TestTensor::<1>::from([2., 5., 8.]);\n    assert_eq!(\n        lhs.into_data().as_slice::<FloatElem>().unwrap(),\n        rhs.into_data().as_slice::<FloatElem>().unwrap()\n    );\n}\n\n#[test]\nfn test_iter_dim_double_end() {\n    let input = TestTensorInt::<1>::arange(0..(4 * 6 * 3), &Default::default()).reshape([4, 6, 3]);\n    let mut iter = input.iter_dim(1);\n\n    let ele0 = TensorData::from([[[0, 1, 2]], [[18, 19, 20]], [[36, 37, 38]], [[54, 55, 56]]]);\n    let ele1 = TensorData::from([[[3, 4, 5]], [[21, 22, 23]], [[39, 40, 41]], [[57, 58, 59]]]);\n    let ele2 = TensorData::from([[[6, 7, 8]], [[24, 25, 26]], [[42, 43, 44]], [[60, 61, 62]]]);\n    let ele3 = TensorData::from([\n        [[9, 10, 11]],\n        [[27, 28, 29]],\n        [[45, 46, 47]],\n        [[63, 64, 65]],\n    ]);\n    let ele4 = TensorData::from([\n        [[12, 13, 14]],\n        [[30, 31, 32]],\n        [[48, 49, 50]],\n        [[66, 67, 68]],\n    ]);\n    let ele5 = TensorData::from([\n        [[15, 16, 17]],\n        [[33, 34, 35]],\n        [[51, 52, 53]],\n        [[69, 70, 71]],\n    ]);\n\n    iter.next().unwrap().into_data().assert_eq(&ele0, false);\n    iter.next_back()\n        .unwrap()\n        .into_data()\n        .assert_eq(&ele5, false);\n    iter.next_back()\n        .unwrap()\n        .into_data()\n        .assert_eq(&ele4, false);\n    iter.next().unwrap().into_data().assert_eq(&ele1, false);\n    iter.next().unwrap().into_data().assert_eq(&ele2, false);\n    iter.next().unwrap().into_data().assert_eq(&ele3, false);\n    assert!(iter.next().is_none());\n    assert!(iter.next_back().is_none());\n}\n\n#[test]\nfn test_iter_dim_single_element() {\n    let input = TestTensorInt::<1>::arange(0..(4 * 3), &Default::default()).reshape([4, 1, 3]);\n\n    let mut iter = input.clone().iter_dim(1);\n    iter.next()\n        .unwrap()\n        .into_data()\n        .assert_eq(&input.clone().into_data(), false);\n    assert!(iter.next_back().is_none());\n    assert!(iter.next().is_none());\n\n    let mut iter = input.clone().iter_dim(1);\n    iter.next_back()\n        .unwrap()\n        .into_data()\n        .assert_eq(&input.clone().into_data(), false);\n    assert!(iter.next().is_none());\n    assert!(iter.next_back().is_none());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/log.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_log_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.log();\n    let expected = TensorData::from([\n        [-f32::INFINITY, 0.0, core::f32::consts::LN_2],\n        [1.09861, 1.38629, 1.60944],\n    ]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/log1p.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_exp_log1p() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.log1p();\n    let expected = TensorData::from([\n        [0.0, core::f32::consts::LN_2, 1.09861],\n        [1.38629, 1.60944, 1.79176],\n    ]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::default().set_half_precision_relative(1e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/mask.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_mask_fill_swap_dims() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::arange(0..16, &device).float();\n    let tensor_1 = tensor_1.reshape([2, 2, 4]);\n    let tensor_1 = tensor_1.swap_dims(0, 2);\n\n    let mask = tensor_1.clone().lower_equal_elem(5.0);\n    let output = tensor_1.clone().mask_fill(mask, -5.0);\n\n    let expected = TensorData::from([\n        [[-5.0, 8.0], [-5.0, 12.0]],\n        [[-5.0, 9.0], [-5.0, 13.0]],\n        [[-5.0, 10.0], [6.0, 14.0]],\n        [[-5.0, 11.0], [7.0, 15.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mask_where_ops() {\n    let device = Default::default();\n    let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n    let value = TestTensor::<2>::from_data(TensorData::from([[1.8, 2.8], [3.8, 4.8]]), &device);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mask_where_broadcast() {\n    let device = Default::default();\n    // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times\n    let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]);\n    let mask = TestTensorBool::<3>::from_bool(\n        TensorData::from([\n            [[true, false], [false, true]],\n            [[false, true], [true, false]],\n            [[false, false], [false, false]],\n            [[true, true], [true, true]],\n        ]),\n        &device,\n    );\n    let value = TestTensor::<3>::ones([4, 2, 2], &device);\n\n    let output = tensor.float().mask_where(mask, value);\n    let expected = TensorData::from([\n        [[1., 3.], [4., 1.]],\n        [[2., 1.], [1., 5.]],\n        [[2., 3.], [4., 5.]],\n        [[1., 1.], [1., 1.]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mask_where_broadcast_value_small() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(2..4, &device).float();\n    let mask = TestTensorBool::<1>::from_bool(TensorData::from([true, false]), &device);\n    let value = TestTensor::<1>::ones([1], &device);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([1., 3.]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_handle_mask_where_nans() {\n    let device = Default::default();\n    let tensor = TestTensor::from_data(\n        [\n            [f32::NAN, f32::NAN, f32::NAN],\n            [f32::NAN, f32::NAN, f32::NAN],\n            [f32::NAN, f32::NAN, f32::NAN],\n        ],\n        &device,\n    );\n    let mask = TestTensorBool::<2>::from_bool(\n        TensorData::from([\n            [true, true, true],\n            [true, true, false],\n            [false, false, false],\n        ]),\n        &device,\n    );\n    let value = TestTensor::<2>::from_data(\n        TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]),\n        &device,\n    );\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([\n        [0.9, 0.8, 0.7],\n        [0.6, 0.5, f32::NAN],\n        [f32::NAN, f32::NAN, f32::NAN],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_mask_fill_ops() {\n    let device = Default::default();\n    let tensor = TestTensor::from_data([[1.0, 7.0], [2.0, 3.0]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n\n    let output = tensor.mask_fill(mask, 2.0);\n    let expected = TensorData::from([[2.0, 7.0], [2.0, 2.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mask_fill_broadcasted() {\n    let device = Default::default();\n    let tensor = TestTensor::zeros([1, 4, 2, 2], &device);\n    let mask = TestTensorBool::<4>::from_bool(\n        TensorData::from([[[[true, false], [false, true]]]]),\n        &device,\n    );\n\n    let output = tensor.mask_fill(mask, 2.0);\n    let expected = TensorData::from([[\n        [[2., 0.], [0., 2.]],\n        [[2., 0.], [0., 2.]],\n        [[2., 0.], [0., 2.]],\n        [[2., 0.], [0., 2.]],\n    ]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn float_mask_fill_infinite() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [f32::NEG_INFINITY, f32::NEG_INFINITY],\n            [f32::NEG_INFINITY, f32::NEG_INFINITY],\n        ],\n        &device,\n    );\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n\n    let output = tensor.mask_fill(mask, 10.0f32);\n    let expected = TensorData::from([[10f32, f32::NEG_INFINITY], [f32::NEG_INFINITY, 10f32]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/matmul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::{ElementConversion, Tolerance, backend::Backend};\n\n#[test]\nfn test_float_matmul_d2() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 7.0], [2.0, 3.0], [1.0, 5.0]], &device);\n    let tensor_2 = TestTensor::from_floats([[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_d3() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_floats([[[1.0, 7.0], [2.0, 3.0]]], &device);\n    let tensor_2 = TestTensor::from_floats([[[4.0, 7.0], [2.0, 3.0]]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_broadcast_1() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_floats([[[1.0, 7.0], [2.0, 3.0]]], &device);\n    let tensor_2 = TestTensor::from_floats(\n        [[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]],\n        &device,\n    );\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_broadcast_4d() {\n    let device = Default::default();\n    // [2, 1, 2, 2]\n    let tensor_1 = TestTensor::<4>::from_floats(\n        [[[[1.0, 7.0], [2.0, 3.0]]], [[[2.0, 5.0], [6.0, 3.0]]]],\n        &device,\n    );\n    // [1, 2, 2, 2]\n    let tensor_2 = TestTensor::from_floats(\n        [[[[9.0, 8.0], [1.0, 4.0]], [[2.0, 7.0], [3.0, 5.0]]]],\n        &device,\n    );\n\n    // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2]\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [[[16.0, 36.0], [21.0, 28.0]], [[23.0, 42.0], [13.0, 29.0]]],\n        [[[23.0, 36.0], [57.0, 60.0]], [[19.0, 39.0], [21.0, 57.0]]],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_simple_1() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats([[5.0, 14.0], [14.0, 50.0]], &device);\n    let tensor_2 = TestTensor::from_floats([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[15.0, 34.0, 53.0], [42.0, 106.0, 170.0]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_4_3() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats(\n        [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],\n        &device,\n    );\n    let tensor_2 = TestTensor::from_floats(\n        [[0., 1., 2.], [4., 5., 6.], [8., 9., 10.], [12., 13., 14.]],\n        &device,\n    );\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[56., 62., 68.], [152., 174., 196.], [248., 286., 324.]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_batch_vec_mat() {\n    let device = Default::default();\n\n    // [..., B, 1, K] = [3, 1, 2]\n    let tensor_1 =\n        TestTensor::<3>::from_floats([[[1.0, 7.0]], [[2.0, 3.0]], [[1.0, 5.0]]], &device);\n\n    // [..., 1, K, N] = [1, 2, 3]\n    let tensor_2 = TestTensor::<3>::from_floats([[[4.0, 7.0, 5.0], [2.0, 3.0, 5.0]]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    // [..., B, 1, N] = [3, 1, 3]\n    let expected = TensorData::from([\n        [[18.0, 28.0, 40.0]],\n        [[14.0, 23.0, 25.0]],\n        [[14.0, 22.0, 30.0]],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_trivial() {\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..16, &device)\n        .reshape([4, 4])\n        .float();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1);\n\n    tensor_3.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([\n            [56., 62., 68., 74.],\n            [152., 174., 196., 218.],\n            [248., 286., 324., 362.],\n            [344., 398., 452., 506.],\n        ]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn test_float_matmul_trivial_transposed() {\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..16, &device)\n        .reshape([4, 4])\n        .float();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());\n\n    tensor_3.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([\n            [14., 38., 62., 86.],\n            [38., 126., 214., 302.],\n            [62., 214., 366., 518.],\n            [86., 302., 518., 734.],\n        ]),\n        Tolerance::default(),\n    );\n}\n\n/// Regression test for batch bug in fused matmul\n#[test]\nfn test_float_matmul_vecmat_transposed_fused() {\n    let device = Default::default();\n\n    let batch1 = 1;\n    let batch2 = 2;\n    let batch = batch1 * batch2;\n    let seq_length = 3;\n    let d_model = 32;\n\n    // Guard int arange limits\n    #[allow(clippy::unnecessary_cast)]\n    if (IntElem::MAX as i64) < seq_length * d_model * batch {\n        return;\n    }\n    if FloatElem::MAX.elem::<f64>() < 269493.0 {\n        return;\n    }\n\n    let weight: TestTensor<4> = TestTensorInt::arange(0..d_model * batch, &device)\n        .reshape([batch1, batch2, 1, d_model])\n        .float();\n    let signal: TestTensor<4> = TestTensorInt::arange(0..seq_length * d_model * batch, &device)\n        .reshape([batch1, batch2, seq_length, d_model])\n        .float();\n\n    TestBackend::sync(&device).unwrap();\n    let weight = weight.transpose();\n    let out = signal.matmul(weight) + 5;\n    let expected = TensorData::from([[\n        [[10421.0], [26293.0], [42165.0]],\n        [[172213.0], [220853.0], [269493.0]],\n    ]]);\n    expected.assert_approx_eq(&out.into_data(), Tolerance::<f32>::strict());\n}\n\n#[test]\nfn test_float_matmul_4_8() {\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..32, &device)\n        .reshape([4, 8])\n        .float();\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());\n\n    tensor_3.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([\n            [140., 364., 588., 812.],\n            [364., 1100., 1836., 2572.],\n            [588., 1836., 3084., 4332.],\n            [812., 2572., 4332., 6092.],\n        ]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn test_float_matmul_simple_2() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 2.0, 3.0, 4.0]], &device);\n    let tensor_2 = TestTensor::from_floats([[3.0], [4.0], [5.0], [6.0]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[50.0]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_float_matmul_simple_3() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats(\n        [[3., 3., 3.], [4., 4., 4.], [5., 5., 5.], [6., 6., 6.]],\n        &device,\n    );\n    let tensor_2 = TestTensor::from_floats(\n        [[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]],\n        &device,\n    );\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [9., 18., 27., 36.],\n        [12., 24., 36., 48.],\n        [15., 30., 45., 60.],\n        [18., 36., 54., 72.],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn float_should_panic_when_inner_dimensions_are_not_equal() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_floats([[3., 3.], [4., 4.], [5., 5.], [6., 6.]], &device);\n    let tensor_2 = TestTensor::from_floats(\n        [[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]],\n        &device,\n    );\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [9., 18., 27., 36.],\n        [12., 24., 36., 48.],\n        [15., 30., 45., 60.],\n        [18., 36., 54., 72.],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/maxmin.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_max_dim_2d() {\n    let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    f.clone()\n        .max_dim(0)\n        .into_data()\n        .assert_eq(&TensorData::from([[3., 4., 5.]]), false);\n\n    f.clone()\n        .max_dim(1)\n        .into_data()\n        .assert_eq(&TensorData::from([[2.], [5.]]), false);\n\n    // Negative Index\n    f.clone()\n        .max_dim(-1)\n        .into_data()\n        .assert_eq(&TensorData::from([[2.], [5.]]), false);\n\n    // Regression Test: https://github.com/tracel-ai/burn/issues/3139\n    let z = f.clone().int();\n    z.clone()\n        .max_dim(0)\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 4, 5]]), false);\n    z.clone()\n        .max_dim(1)\n        .into_data()\n        .assert_eq(&TensorData::from([[2], [5]]), false);\n}\n\n#[test]\nfn test_max_dims_2d() {\n    let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    f.clone()\n        .max_dims(&[0])\n        .into_data()\n        .assert_eq(&TensorData::from([[3., 4., 5.]]), false);\n\n    f.clone()\n        .max_dims(&[-2])\n        .into_data()\n        .assert_eq(&TensorData::from([[3., 4., 5.]]), false);\n\n    f.clone()\n        .max_dims(&[0, 1])\n        .into_data()\n        .assert_eq(&TensorData::from([[5.]]), false);\n}\n\n#[test]\nfn test_max_dim_with_indices_2d_with_dim_0th() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    // Positive, Negative Index\n    for idx in [0, -2] {\n        let (output, index) = tensor.clone().max_dim_with_indices(idx);\n\n        let output_expected = TensorData::from([[3., 4., 5.]]);\n        let index_expected = TensorData::from([[1, 1, 1]]);\n\n        output.into_data().assert_eq(&output_expected, false);\n        index.into_data().assert_eq(&index_expected, false);\n    }\n}\n\n#[test]\nfn test_max_dim_with_indices_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let (output, index) = tensor.max_dim_with_indices(1);\n\n    let output_expected = TensorData::from([[2.], [5.]]);\n    let index_expected = TensorData::from([[2], [2]]);\n\n    output.into_data().assert_eq(&output_expected, false);\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_max_dim_2d_with_0th_dim() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let output = tensor.max_dim(0);\n    let expected = TensorData::from([[3., 4., 5.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_max_pair() {\n    let a = TestTensor::<1>::from_floats([1.0, 2.0, 3.0, 4.0], &Default::default());\n    let b = TestTensor::from_floats([2.0, 1.0, 4.0, 5.0], &Default::default());\n\n    let output = a.max_pair(b);\n    let expected = TensorData::from([2.0, 2.0, 4.0, 5.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_min_dim_2d() {\n    let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    f.clone()\n        .min_dim(0)\n        .into_data()\n        .assert_eq(&TensorData::from([[0., 1., 2.]]), false);\n\n    f.clone()\n        .min_dim(1)\n        .into_data()\n        .assert_eq(&TensorData::from([[0.], [3.]]), false);\n\n    // Negative Index\n    f.clone()\n        .min_dim(-1)\n        .into_data()\n        .assert_eq(&TensorData::from([[0.], [3.]]), false);\n\n    // Regression Test: https://github.com/tracel-ai/burn/issues/3139\n    let z = f.int();\n    z.clone()\n        .min_dim(0)\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 2]]), false);\n    z.clone()\n        .min_dim(1)\n        .into_data()\n        .assert_eq(&TensorData::from([[0], [3]]), false);\n}\n\n#[test]\nfn test_min_dims_2d() {\n    let f = TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    f.clone()\n        .min_dims(&[0])\n        .into_data()\n        .assert_eq(&TensorData::from([[0., 1., 2.]]), false);\n\n    f.clone()\n        .min_dims(&[-2])\n        .into_data()\n        .assert_eq(&TensorData::from([[0., 1., 2.]]), false);\n\n    f.clone()\n        .min_dims(&[0, 1])\n        .into_data()\n        .assert_eq(&TensorData::from([[0.]]), false);\n}\n\n#[test]\nfn test_min_dim_with_indices_2d() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let (output, index) = tensor.min_dim_with_indices(1);\n\n    let output_expected = TensorData::from([[0.], [3.]]);\n    let index_expected = TensorData::from([[0], [0]]);\n\n    output.into_data().assert_eq(&output_expected, false);\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_min_dim_2d_with_0th_dim() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    let output = tensor.min_dim(0);\n    let expected = TensorData::from([[0., 1., 2.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_min_dim_with_indices_2d_with_0th_dim() {\n    let tensor =\n        TestTensor::<2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &Default::default());\n\n    // Positive, Negative Index\n    for idx in [0, -2] {\n        let (output, index) = tensor.clone().min_dim_with_indices(idx);\n\n        let output_expected = TensorData::from([[0., 1., 2.]]);\n        let index_expected = TensorData::from([[0, 0, 0]]);\n\n        output.into_data().assert_eq(&output_expected, false);\n        index.into_data().assert_eq(&index_expected, false);\n    }\n}\n\n#[test]\nfn test_min_pair() {\n    let a = TestTensor::<1>::from_floats([1.0, 2.0, 3.0, 4.0], &Default::default());\n    let b = TestTensor::from_floats([2.0, 1.0, 4.0, 5.0], &Default::default());\n\n    let output = a.min_pair(b);\n    let expected = TensorData::from([1.0, 1.0, 3.0, 4.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_max_abs() {\n    let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default());\n\n    let output = tensor.max_abs();\n    let expected = TensorData::from([6.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_max_abs_dim_2d_dim_0() {\n    let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default());\n\n    let output = tensor.clone().max_abs_dim(0);\n    let expected = TensorData::from([[5., 6., 2.]]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Negative Index\n    let output = tensor.clone().max_abs_dim(-2);\n    let expected = TensorData::from([[5., 6., 2.]]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_max_abs_dims_2d() {\n    let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default());\n\n    tensor\n        .clone()\n        .max_abs_dims(&[0])\n        .into_data()\n        .assert_eq(&TensorData::from([[5., 6., 2.]]), false);\n\n    tensor\n        .clone()\n        .max_abs_dims(&[-2])\n        .into_data()\n        .assert_eq(&TensorData::from([[5., 6., 2.]]), false);\n\n    tensor\n        .clone()\n        .max_abs_dims(&[0, 1])\n        .into_data()\n        .assert_eq(&TensorData::from([[6.]]), false);\n}\n\n#[test]\nfn test_max_abs_dim_2d_dim_1() {\n    let tensor = TestTensor::<2>::from_floats([[0., 1., -2.], [-5., 6., 1.]], &Default::default());\n\n    let output = tensor.max_abs_dim(1);\n    let expected = TensorData::from([[2.], [6.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/mod.rs",
    "content": "use super::*;\n\nmod abs;\nmod add;\nmod aggregation;\nmod all;\nmod any;\nmod arg;\nmod cast;\nmod cat;\nmod ceil;\nmod chunk;\nmod clamp;\nmod close;\nmod comparison;\nmod create_like;\nmod cross;\nmod cumulative;\nmod div;\nmod dot;\nmod erf;\nmod exp;\nmod expand;\nmod finite;\nmod flatten;\nmod flip;\nmod floor;\nmod fmod;\nmod full;\nmod gather_scatter;\nmod grid_sample;\nmod inf;\nmod init;\nmod iter_dim;\nmod log;\nmod log1p;\nmod mask;\nmod matmul;\nmod maxmin;\nmod movedim;\nmod mul;\nmod nan;\nmod narrow;\nmod neg;\nmod one_hot;\nmod padding;\nmod permute;\nmod powf;\nmod powf_scalar;\nmod prod;\nmod random;\nmod recip;\nmod remainder;\nmod repeat;\nmod repeat_dim;\nmod reshape;\nmod round;\nmod select;\nmod sign;\nmod slice;\nmod slice_assign;\nmod sort_argsort;\nmod split;\nmod sqrt;\nmod square;\nmod squeeze;\nmod stack;\nmod sub;\nmod take;\nmod topk;\nmod transaction;\nmod transpose;\nmod tri;\nmod trig;\nmod trunc;\nmod unfold;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/movedim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn movedim_float() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .float();\n\n    let permuted = tensor.clone().movedim(0, 2);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()\n    let expected = TensorData::from([\n        [[0., 12.], [1., 13.], [2., 14.], [3., 15.]],\n        [[4., 16.], [5., 17.], [6., 18.], [7., 19.]],\n        [[8., 20.], [9., 21.], [10., 22.], [11., 23.]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().movedim(0, -1);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().movedim(0, 0);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\nfn vec_input_float() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .float();\n\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);\n    // from pytorch\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0]).float()\n    let expected = TensorData::from([\n        [[0., 1., 2., 3.], [12., 13., 14., 15.]],\n        [[4., 5., 6., 7.], [16., 17., 18., 19.]],\n        [[8., 9., 10., 11.], [20., 21., 22., 23.]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axes\n    let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axes\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\nfn different_input_types() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .float();\n\n    let permuted = tensor.clone().movedim(0_usize, 2_i32);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2).float()\n    let expected = TensorData::from([\n        [[0., 12.], [1., 13.], [2., 14.], [3., 15.]],\n        [[4., 16.], [5., 17.], [6., 18.], [7., 19.]],\n        [[8., 20.], [9., 21.], [10., 22.], [11., 23.]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().movedim(0_usize, -1);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().movedim(0_i32, 0_usize);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\n#[should_panic]\nfn edge_different_sizes() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with a repeated axis\n    let _ = tensor.clone().movedim(vec![0, 1], vec![0]);\n}\n\n#[test]\n#[should_panic]\nfn edge_out_of_bound_axis() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with an out of bound axis\n    let _ = tensor.clone().movedim(0, 100);\n}\n\n#[test]\n#[should_panic]\nfn edge_vec_is_not_a_set() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with a repeated axis\n    let _ = tensor.clone().movedim(vec![0, 1, 1, 1, 1], vec![0, 0, 1]);\n}\n\n#[test]\n#[should_panic]\nfn edge_out_of_bound_axis_vec() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with an out of bound axis\n    let _ = tensor.clone().movedim(vec![0, 100], vec![0, 1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/mul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_mul_ops() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data_2 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_mul_broadcast() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);\n    let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_mul_broadcast_2_dims() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device).reshape([3, 1]);\n    let tensor_2 = TestTensor::<1>::from_data([3.0, 4.0, 5.0], &device).reshape([1, 3]);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mul_scalar_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let scalar = 2.0;\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor * scalar;\n    let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/nan.rs",
    "content": "use super::*;\nuse burn_tensor::cast::ToElement;\n\n#[test]\nfn is_nan() {\n    let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let no_nan_expected = TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);\n\n    let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]);\n    let with_nan_expected = TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);\n\n    assert_eq!(no_nan_expected.into_data(), no_nan.is_nan().into_data());\n\n    assert_eq!(with_nan_expected.into_data(), with_nan.is_nan().into_data());\n}\n\n#[test]\nfn contains_nan() {\n    let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    assert!(!no_nan.contains_nan().into_scalar().to_bool());\n\n    let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]);\n    assert!(with_nan.contains_nan().into_scalar().to_bool());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/narrow.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Shape, TensorData};\n\n#[test]\nfn test_narrow_1() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let output = tensor.clone().narrow(0, 0, 2);\n    let expected = TensorData::from([[1., 2., 3.], [4., 5., 6.]]);\n\n    assert_eq!(output.shape(), Shape::from([2, 3]));\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_narrow_2() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let output = tensor.clone().narrow(1, 1, 2);\n    let expected = TensorData::from([[2., 3.], [5., 6.], [8., 9.]]);\n    assert_eq!(output.shape(), Shape::from([3, 2]));\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_narrow_3() {\n    let device = &Default::default();\n    let shape = Shape::new([8, 8]);\n    let tensor = TestTensorInt::arange(0..shape.num_elements() as i64, device)\n        .reshape::<2, _>(shape)\n        .float();\n\n    let output = tensor.clone().narrow(0, 3, 4);\n    let expected = TensorData::from([\n        [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],\n        [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],\n        [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],\n        [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0],\n    ]);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_dim() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let _output = tensor.narrow(2, 0, 2);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_start() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let _output = tensor.narrow(0, 3, 2);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_zero_length() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let _output = tensor.narrow(0, 1, 0);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_length() {\n    let tensor = TestTensor::<2>::from_data(\n        TensorData::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]),\n        &Default::default(),\n    );\n\n    let _output = tensor.narrow(0, 0, 4);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/neg.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_neg_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.neg();\n    let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::<f32>();\n\n    // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32\n    assert_eq!(\n        output\n            .into_data()\n            .convert::<f32>()\n            .as_slice::<f32>()\n            .unwrap(),\n        expected.as_slice::<f32>().unwrap()\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/one_hot.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn float_should_support_one_hot() {\n    let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]);\n    let one_hot_tensor: TestTensor<2> = tensor.one_hot(5);\n    let expected = TensorData::from([\n        [1.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n    ]);\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn float_should_support_one_hot_index() {\n    let tensor = TestTensor::<1>::from([2.0]);\n    let one_hot_tensor: TestTensor<2> = tensor.one_hot::<2>(10);\n    let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]);\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn float_one_hot_should_panic_when_index_exceeds_number_of_classes() {\n    let tensor = TestTensor::<1>::from([5.0]);\n    let _result: TestTensor<2> = tensor.one_hot(5);\n}\n\n#[test]\n#[should_panic]\nfn float_one_hot_should_panic_when_number_of_classes_is_zero() {\n    let tensor = TestTensor::<1>::from([0.0]);\n    let _result: TestTensor<2> = tensor.one_hot(0);\n}\n\n#[test]\nfn one_hot_fill_with_negative_axis_and_indices() {\n    let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]);\n    let expected = TensorData::from([\n        [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]],\n        [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]],\n    ]);\n\n    let one_hot_tensor: TestTensor<3> = tensor.one_hot_fill(3, 5.0, 0.0, -1);\n\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn one_hot_fill_with_negative_indices() {\n    let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]);\n    let expected = TensorData::from([\n        [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n    ]);\n\n    let one_hot_tensor: TestTensor<2> = tensor.one_hot_fill(10, 3.0, 1.0, 1);\n\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[should_panic]\n#[test]\nfn one_hot_fill_should_panic_when_axis_out_range_of_rank() {\n    let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]);\n\n    let _one_hot_tensor: TestTensor<3> = tensor.one_hot_fill(2, 5.0, 0.0, 3);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/padding.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, ops::PadMode};\n\n#[test]\nfn padding_constant_2d_test() {\n    let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];\n    let tensor = TestTensor::<2>::from(unpadded_floats);\n\n    let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);\n\n    let expected = TensorData::from([\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1],\n        [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_4d_test() {\n    let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];\n    let tensor = TestTensor::<4>::from(unpadded_floats);\n\n    let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);\n\n    let expected = TensorData::from([[[\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 0.0, 1.0, 1.1, 1.1],\n        [1.1, 1.1, 2.0, 3.0, 1.1, 1.1],\n        [1.1, 1.1, 4.0, 5.0, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n    ]]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_asymmetric_test() {\n    let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];\n    let tensor = TestTensor::<4>::from(unpadded_floats);\n\n    let padded_tensor = tensor.pad((2, 1, 4, 3), 1.1);\n\n    let expected = TensorData::from([[[\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 0.0, 1.0, 1.1],\n        [1.1, 1.1, 2.0, 3.0, 1.1],\n        [1.1, 1.1, 4.0, 5.0, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1],\n    ]]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_2d_test() {\n    // Test reflect padding on a 2D tensor\n    // Input: [[1, 2, 3], [4, 5, 6]]\n    // With padding (1, 1, 1, 1):\n    // - Top: reflect row 1 -> [4, 5, 6]\n    // - Bottom: reflect row 0 -> [1, 2, 3]\n    // - Left: reflect col 1\n    // - Right: reflect col 1\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect);\n\n    // Expected: reflect excludes the edge value\n    // Before padding height: [[1,2,3], [4,5,6]]\n    // After top pad (reflect row at index 1): [[4,5,6], [1,2,3], [4,5,6]]\n    // After bottom pad (reflect row at index 1 from end): [[4,5,6], [1,2,3], [4,5,6], [1,2,3]]\n    // Then pad width similarly\n    let expected = TensorData::from([\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [2.0, 1.0, 2.0, 3.0, 2.0],\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [2.0, 1.0, 2.0, 3.0, 2.0],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_width_only_test() {\n    // Test reflect padding on width dimension only\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]);\n\n    let padded_tensor = tensor.pad((2, 2, 0, 0), PadMode::Reflect);\n\n    // Input: [1, 2, 3, 4]\n    // Reflect left 2: take indices [1, 2] = [2, 3], flip = [3, 2]\n    // Reflect right 2: take indices [1, 2] from end = [2, 3], flip = [3, 2]\n    // Result: [3, 2, 1, 2, 3, 4, 3, 2]\n    let expected = TensorData::from([[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_4d_test() {\n    // Test reflect padding on 4D tensor (common for images: NCHW)\n    let tensor = TestTensor::<4>::from([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]);\n\n    let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Reflect);\n\n    let expected = TensorData::from([[[\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [2.0, 1.0, 2.0, 3.0, 2.0],\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [8.0, 7.0, 8.0, 9.0, 8.0],\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n    ]]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_edge_2d_test() {\n    // Test edge padding on a 2D tensor\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge);\n\n    // Edge padding replicates the boundary values\n    let expected = TensorData::from([\n        [1.0, 1.0, 2.0, 3.0, 3.0],\n        [1.0, 1.0, 2.0, 3.0, 3.0],\n        [4.0, 4.0, 5.0, 6.0, 6.0],\n        [4.0, 4.0, 5.0, 6.0, 6.0],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_edge_width_only_test() {\n    // Test edge padding on width dimension only\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]);\n\n    let padded_tensor = tensor.pad((2, 3, 0, 0), PadMode::Edge);\n\n    // Input: [1, 2, 3, 4]\n    // Left 2: [1, 1]\n    // Right 3: [4, 4, 4]\n    // Result: [1, 1, 1, 2, 3, 4, 4, 4, 4]\n    let expected = TensorData::from([[1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 4.0]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_edge_4d_test() {\n    // Test edge padding on 4D tensor\n    let tensor = TestTensor::<4>::from([[[[1.0, 2.0], [3.0, 4.0]]]]);\n\n    let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::Edge);\n\n    let expected = TensorData::from([[[\n        [1.0, 1.0, 2.0, 2.0],\n        [1.0, 1.0, 2.0, 2.0],\n        [3.0, 3.0, 4.0, 4.0],\n        [3.0, 3.0, 4.0, 4.0],\n    ]]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_default_test() {\n    // Test default PadMode (Constant with 0.0)\n    let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let padded_tensor = tensor.pad((1, 1, 1, 1), PadMode::default());\n\n    let expected = TensorData::from([\n        [0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 2.0, 0.0],\n        [0.0, 3.0, 4.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_max_valid_test() {\n    // Test reflect padding at maximum valid size (dim_size - 1)\n    // For a 4-element dimension, max valid padding is 3\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0, 4.0]]);\n\n    // Padding of 3 on left is valid for width=4 (3 < 4)\n    let padded_tensor = tensor.pad((3, 3, 0, 0), PadMode::Reflect);\n\n    // Input: [1, 2, 3, 4]\n    // Reflect left 3: take indices [1, 2, 3] = [2, 3, 4], flip = [4, 3, 2]\n    // Reflect right 3: take indices [0, 1, 2] = [1, 2, 3], flip = [3, 2, 1]\n    // Result: [4, 3, 2, 1, 2, 3, 4, 3, 2, 1]\n    let expected = TensorData::from([[4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 1.0]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_asymmetric_test() {\n    // Test asymmetric reflect padding\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);\n\n    // Asymmetric padding: left=2, right=1, top=1, bottom=2\n    let padded_tensor = tensor.pad((2, 1, 1, 2), PadMode::Reflect);\n\n    let expected = TensorData::from([\n        [6.0, 5.0, 4.0, 5.0, 6.0, 5.0],\n        [3.0, 2.0, 1.0, 2.0, 3.0, 2.0],\n        [6.0, 5.0, 4.0, 5.0, 6.0, 5.0],\n        [9.0, 8.0, 7.0, 8.0, 9.0, 8.0],\n        [6.0, 5.0, 4.0, 5.0, 6.0, 5.0],\n        [3.0, 2.0, 1.0, 2.0, 3.0, 2.0],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic(expected = \"Reflect padding\")]\nfn padding_reflect_exceeds_dimension_test() {\n    // Test that reflect padding panics when padding >= dim_size\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0]]);\n\n    // Padding of 3 on width=3 should panic (3 >= 3, need padding < dim_size)\n    let _ = tensor.pad((3, 0, 0, 0), PadMode::Reflect);\n}\n\n#[test]\nfn padding_edge_asymmetric_test() {\n    // Test asymmetric edge padding\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    // Asymmetric padding: left=2, right=1, top=3, bottom=1\n    let padded_tensor = tensor.pad((2, 1, 3, 1), PadMode::Edge);\n\n    let expected = TensorData::from([\n        [1.0, 1.0, 1.0, 2.0, 3.0, 3.0],\n        [1.0, 1.0, 1.0, 2.0, 3.0, 3.0],\n        [1.0, 1.0, 1.0, 2.0, 3.0, 3.0],\n        [1.0, 1.0, 1.0, 2.0, 3.0, 3.0],\n        [4.0, 4.0, 4.0, 5.0, 6.0, 6.0],\n        [4.0, 4.0, 4.0, 5.0, 6.0, 6.0],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_zero_padding_test() {\n    // Test that zero padding returns the original tensor unchanged\n    let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let padded_constant = tensor.clone().pad((0, 0, 0, 0), PadMode::Constant(5.0));\n    let padded_reflect = tensor.clone().pad((0, 0, 0, 0), PadMode::Reflect);\n    let padded_edge = tensor.clone().pad((0, 0, 0, 0), PadMode::Edge);\n\n    let expected = TensorData::from([[1.0, 2.0], [3.0, 4.0]]);\n    padded_constant.into_data().assert_eq(&expected, false);\n    padded_reflect.into_data().assert_eq(&expected, false);\n    padded_edge.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_empty_tensor_constant_test() {\n    // Test constant padding on an empty tensor (zero-sized dimension)\n    // This should work - creates a tensor filled with the constant value\n    let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default());\n\n    // Padding an empty height dimension with constant should create a tensor of just padding\n    let padded = tensor.pad((0, 0, 2, 2), 1.0);\n\n    // Result should be 4x3 (0 + 2 + 2 = 4 rows)\n    assert_eq!(padded.dims(), [4, 3]);\n\n    let expected = TensorData::from([\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic(expected = \"edge padding\")]\nfn padding_empty_tensor_edge_panics_test() {\n    // Test that edge padding panics on empty tensor\n    let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default());\n\n    // Edge padding on zero-sized dimension should panic\n    let _ = tensor.pad((0, 0, 1, 1), PadMode::Edge);\n}\n\n#[test]\n#[should_panic(expected = \"Reflect padding\")]\nfn padding_empty_tensor_reflect_panics_test() {\n    // Test that reflect padding panics on empty tensor\n    let tensor: TestTensor<2> = TestTensor::empty([0, 3], &Default::default());\n\n    // Reflect padding on zero-sized dimension should panic\n    let _ = tensor.pad((0, 0, 1, 1), PadMode::Reflect);\n}\n\n// --- Tests for N-dimensional padding using (before, after) pairs ---\n\n#[test]\nfn padding_constant_pairs_2d_test() {\n    // Same as padding_constant_2d_test but using the new pairs API\n    let tensor = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    // [(row_before, row_after), (col_before, col_after)]\n    let padded_tensor = tensor.pad([(2, 2), (2, 2)], 1.1);\n\n    let expected = TensorData::from([\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1],\n        [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n        [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],\n    ]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_single_dim_test() {\n    // Pad only the last dimension\n    let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let padded_tensor = tensor.pad([(1, 1)], 0.0);\n\n    let expected = TensorData::from([[0.0, 1.0, 2.0, 0.0], [0.0, 3.0, 4.0, 0.0]]);\n    padded_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_all_dims_4d_test() {\n    // Pad all 4 dimensions of a 4D tensor (batch, channel, height, width)\n    // Input: shape [1, 1, 2, 2]\n    let tensor = TestTensor::<4>::from([[[[1.0, 2.0], [3.0, 4.0]]]]);\n\n    // Pad: batch(1,1), channel(1,1), height(0,0), width(0,0)\n    let padded = tensor.pad([(1, 1), (1, 1), (0, 0), (0, 0)], 0.0);\n\n    // Shape should be [3, 3, 2, 2]\n    assert_eq!(padded.dims(), [3, 3, 2, 2]);\n\n    let expected = TensorData::from([\n        [\n            [[0.0, 0.0], [0.0, 0.0]],\n            [[0.0, 0.0], [0.0, 0.0]],\n            [[0.0, 0.0], [0.0, 0.0]],\n        ],\n        [\n            [[0.0, 0.0], [0.0, 0.0]],\n            [[1.0, 2.0], [3.0, 4.0]],\n            [[0.0, 0.0], [0.0, 0.0]],\n        ],\n        [\n            [[0.0, 0.0], [0.0, 0.0]],\n            [[0.0, 0.0], [0.0, 0.0]],\n            [[0.0, 0.0], [0.0, 0.0]],\n        ],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_constant_batch_dim_only_test() {\n    // Pad only the batch dimension of a 3D tensor [N, H, W]\n    let tensor = TestTensor::<3>::from([[[1.0, 2.0], [3.0, 4.0]]]);\n\n    // 3 pairs for 3 dims: batch(1,1), height(0,0), width(0,0)\n    let padded = tensor.pad([(1, 1), (0, 0), (0, 0)], -1.0);\n\n    assert_eq!(padded.dims(), [3, 2, 2]);\n\n    let expected = TensorData::from([\n        [[-1.0, -1.0], [-1.0, -1.0]],\n        [[1.0, 2.0], [3.0, 4.0]],\n        [[-1.0, -1.0], [-1.0, -1.0]],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_pairs_test() {\n    // Reflect padding using pairs API\n    let tensor = TestTensor::<2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);\n\n    let padded = tensor.pad([(1, 1), (1, 1)], PadMode::Reflect);\n\n    let expected = TensorData::from([\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [2.0, 1.0, 2.0, 3.0, 2.0],\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n        [8.0, 7.0, 8.0, 9.0, 8.0],\n        [5.0, 4.0, 5.0, 6.0, 5.0],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_edge_pairs_test() {\n    // Edge padding using pairs API\n    let tensor = TestTensor::<2>::from([[1.0, 2.0], [3.0, 4.0]]);\n\n    let padded = tensor.pad([(1, 1), (1, 1)], PadMode::Edge);\n\n    let expected = TensorData::from([\n        [1.0, 1.0, 2.0, 2.0],\n        [1.0, 1.0, 2.0, 2.0],\n        [3.0, 3.0, 4.0, 4.0],\n        [3.0, 3.0, 4.0, 4.0],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_reflect_batch_dim_3d_test() {\n    // Reflect pad the batch dimension of a 3D tensor [N, H, W]\n    // Input shape: [3, 1, 2] - 3 batches, 1 row, 2 cols\n    let tensor = TestTensor::<3>::from([[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]);\n\n    // Pad batch dim with reflect(1, 1), no spatial padding\n    let padded = tensor.pad([(1, 1), (0, 0), (0, 0)], PadMode::Reflect);\n\n    assert_eq!(padded.dims(), [5, 1, 2]);\n\n    // Reflect on batch: [3,4] [1,2] [3,4] [5,6] [3,4]\n    let expected = TensorData::from([\n        [[3.0, 4.0]],\n        [[1.0, 2.0]],\n        [[3.0, 4.0]],\n        [[5.0, 6.0]],\n        [[3.0, 4.0]],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn padding_edge_batch_dim_3d_test() {\n    // Edge pad the batch dimension of a 3D tensor\n    let tensor = TestTensor::<3>::from([[[1.0, 2.0]], [[3.0, 4.0]]]);\n\n    let padded = tensor.pad([(2, 1), (0, 0), (0, 0)], PadMode::Edge);\n\n    assert_eq!(padded.dims(), [5, 1, 2]);\n\n    let expected = TensorData::from([\n        [[1.0, 2.0]],\n        [[1.0, 2.0]],\n        [[1.0, 2.0]],\n        [[3.0, 4.0]],\n        [[3.0, 4.0]],\n    ]);\n    padded.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic(expected = \"Padding has\")]\nfn padding_too_many_pairs_panics_test() {\n    let tensor = TestTensor::<2>::from([[1.0, 2.0]]);\n\n    // 3 pairs for a 2D tensor should panic\n    let _ = tensor.pad([(1, 1), (1, 1), (1, 1)], 0.0);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/permute.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn permute_float_a() {\n    let tensor = TestTensor::<1>::from([\n        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,\n    ])\n    .reshape([2, 2, 4]);\n\n    let permuted = tensor.clone().permute([2, 1, 0]);\n\n    let expected = TensorData::from([\n        [[0., 8.], [4., 12.]],\n        [[1., 9.], [5., 13.]],\n        [[2., 10.], [6., 14.]],\n        [[3., 11.], [7., 15.]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().permute([-1, 1, 0]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().permute([0, 1, 2]);\n    permuted.into_data().assert_eq(&tensor.into_data(), false);\n}\n\n#[test]\nfn permute_float() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device)\n        .reshape([2, 3, 4])\n        .float();\n\n    let permuted = tensor.clone().permute([2, 1, 0]);\n\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0).float()\n    let expected = TensorData::from([\n        [[0., 12.], [4., 16.], [8., 20.]],\n        [[1., 13.], [5., 17.], [9., 21.]],\n        [[2., 14.], [6., 18.], [10., 22.]],\n        [[3., 15.], [7., 19.], [11., 23.]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().permute([-1, 1, 0]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().permute([0, 1, 2]);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/powf.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_powf_ops() {\n    let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n    let pow = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);\n    let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default());\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_power() {\n    let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n    let pow = TensorData::from([[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]);\n    let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default());\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_values_with_even_power() {\n    let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n    let pow = TensorData::from([[2.0, 2.0, 4.0], [4.0, 4.0, 2.0]]);\n    let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default());\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1.0, 1.0, 16.0], [81.0, 256.0, 25.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_values_with_odd_power() {\n    let data = TensorData::from([[1.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n    let pow = TensorData::from([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]);\n    let tensor_pow = TestTensor::<2>::from_data(pow, &Default::default());\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_powf_broadcasted() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<1>::from_floats([2.0, 3.0, 4.0], &device);\n    let tensor_2 = TestTensor::from_floats([1.0], &device);\n\n    // Broadcast rhs\n    let output = tensor_1.clone().powf(tensor_2.clone());\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&tensor_1.to_data(), Tolerance::default());\n\n    // Broadcast lhs\n    let output = tensor_2.powf(tensor_1);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([1.0, 1.0, 1.0]), Tolerance::default());\n}\n\nfn outer(a: TestTensor<1>, b: TestTensor<1>) -> TestTensor<2> {\n    a.unsqueeze_dim::<2>(1) * b.unsqueeze_dim::<2>(0)\n}\n\n#[test]\nfn should_support_powf_scalar_tensor() {\n    let device = Default::default();\n    let head_dim = 64;\n    let seq_len = 1024;\n    let base = 10000;\n\n    let channel_range = TestTensorInt::arange_step(0..head_dim as i64, 2, &device).float();\n    let base = TestTensor::<1>::from_data([base as f32], &device);\n    let inv_freq = base.powf(-channel_range / head_dim as f32);\n\n    let t = TestTensorInt::arange(0..seq_len as i64, &device).float();\n\n    let freqs = outer(t, inv_freq);\n\n    let _cos = freqs.clone().cos();\n    let _sin = freqs.sin();\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/powf_scalar.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_powf_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.powf_scalar(0.71);\n    let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.1815, 2.67586, 3.13522]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_power() {\n    let data = TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.powf_scalar(-0.33);\n    let expected = TensorData::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_values_with_even_power() {\n    let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.powf_scalar(4.0);\n    let expected = TensorData::from([[0.0, 1.0, 16.0], [81.0, 256.0, 625.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_neg_values_with_odd_power() {\n    let data = TensorData::from([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.powf_scalar(3.0);\n    let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -125.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/prod.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_prod_float() {\n    let tensor_1 = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor_1.prod();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([-600.0]), false);\n}\n\n#[test]\nfn test_prod_dim_2d() {\n    let f = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    f.clone()\n        .prod_dim(1)\n        .into_data()\n        .assert_eq(&TensorData::from([[-10.0], [60.0]]), false);\n\n    f.clone()\n        .prod_dim(-1)\n        .into_data()\n        .assert_eq(&TensorData::from([[-10.0], [60.0]]), false);\n}\n\n#[test]\nfn test_prod_dims_2d() {\n    let f = TestTensor::<2>::from([[-5.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    f.clone()\n        .prod_dims(&[1])\n        .into_data()\n        .assert_eq(&TensorData::from([[-10.0], [60.0]]), false);\n\n    f.clone()\n        .prod_dims(&[-1])\n        .into_data()\n        .assert_eq(&TensorData::from([[-10.0], [60.0]]), false);\n\n    f.clone()\n        .prod_dims(&[0, 1])\n        .into_data()\n        .assert_eq(&TensorData::from([[-600.0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/random.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, ElementConversion, TensorData, Tolerance, backend::Backend};\n\n#[test]\nfn rand_default() {\n    let tensor = TestTensor::<1>::random([20], Distribution::Default, &Default::default());\n\n    // check that the tensor is within the range of [0..1) (1 is exclusive)\n    // the conversion can ceil the value if `FloatElem` is less precise than f32\n    let low = 0.elem::<FloatElem>();\n    let high = 1.elem::<FloatElem>();\n    if FloatElem::EPSILON.elem::<f32>() > f32::EPSILON {\n        tensor.into_data().assert_within_range_inclusive(low..=high);\n    } else {\n        tensor.into_data().assert_within_range(low..high);\n    }\n}\n\n#[test]\nfn rand_uniform() {\n    let tensor = TestTensor::<1>::random([20], Distribution::Uniform(4., 5.), &Default::default());\n    let low = 4.elem::<FloatElem>();\n    let high = 5.elem::<FloatElem>();\n\n    if FloatElem::EPSILON.elem::<f32>() > f32::EPSILON {\n        tensor.into_data().assert_within_range_inclusive(low..=high);\n    } else {\n        tensor.into_data().assert_within_range(low..high);\n    }\n}\n\n#[test]\nfn rand_bernoulli() {\n    let tensor = TestTensor::<1>::random([20], Distribution::Bernoulli(1.), &Default::default());\n\n    tensor.into_data().assert_eq(\n        &TensorData::new::<FloatElem, _>(vec![1.elem(); 20], [20]),\n        true,\n    );\n}\n\n#[test]\n#[ignore] // TODO: mark serial for backends that handle the same devices (e.g. fusion)?\nfn test_seed_reproducibility() {\n    let device = Default::default();\n    TestBackend::seed(&device, 42);\n    let t1 = TestTensor::<1>::random([5], Distribution::Default, &device);\n    TestBackend::seed(&device, 42);\n    let t2 = TestTensor::<1>::random([5], Distribution::Default, &device);\n\n    t1.into_data()\n        .assert_approx_eq::<FloatElem>(&t2.into_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/recip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_recip_ops() {\n    let data = TensorData::from([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.recip();\n    let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/remainder.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n/// From https://pytorch.org/docs/stable/generated/torch.remainder.html\n#[test]\nfn should_support_remainder_basic() {\n    let device = Default::default();\n    let lhs =\n        TestTensor::<1>::from_data(TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]), &device);\n    let rhs = TestTensor::<1>::from_data(TensorData::from([2.0, 3.0, 1.0, 2.0, 1.0, 3.0]), &device);\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([1.0, 1.0, -0.0, 1.0, 0.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_remainder_basic_scalar() {\n    let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.remainder_scalar(2.0);\n    let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_remainder_float() {\n    let device = Default::default();\n    let lhs = TestTensor::<1>::from_data(TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]), &device);\n    let rhs = TestTensor::<1>::from_data(\n        TensorData::from([1.4233, 2.7313, 0.2641, 1.9651, 0.5897]),\n        &device,\n    );\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([1.0, 2.0, 0.0949, 0.0698, 0.2824]);\n\n    // Metal has less precise remainder function\n    let tolerance = Tolerance::default()\n        .set_half_precision_relative(1e-2)\n        .set_half_precision_absolute(2e-3);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n/// Also from https://pytorch.org/docs/stable/generated/torch.remainder.html\n#[test]\nfn should_support_remainder_float_scalar() {\n    let data = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.clone().remainder_scalar(-1.5);\n    let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_be_zero() {\n    let device = Default::default();\n    let lhs = TestTensor::<1>::from_data(TensorData::from([0.0, 0.0, 0.0]), &device);\n    let rhs = TestTensor::<1>::from_data(TensorData::from([3.5, -2.1, 1e-4]), &device);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([0.0, 0.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_be_zero_scalar() {\n    let data = TensorData::from([0.0, 0.0, 0.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.clone().remainder_scalar(3.5);\n    let expected = TensorData::from([0.0, 0.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_have_no_remainder() {\n    let device = Default::default();\n    let lhs = TestTensor::<1>::from_data(\n        // Previous values failed on some vulkan backends (driver bug?)\n        // TensorData::from([-1.4843, 1.1350, -2.1563, 1.0862, 0.5, 3.6587]),\n        TensorData::from([-1.0, 1.5, -2.0, 2.5, 0.5, 4.0]),\n        &device,\n    );\n    let rhs = TestTensor::<1>::from_data(\n        // TensorData::from([1.4843, 1.1350, 2.1563, 1.0862, 0.5, 3.6587]),\n        TensorData::from([1.0, 1.5, 2.0, 2.5, 0.5, 4.0]),\n        &device,\n    );\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([-0., 0., -0., 0., 0., 0.]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_have_no_remainder_scalar() {\n    let data = TensorData::from([-4.0, 4.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.remainder_scalar(4.0);\n    let expected = TensorData::from([-0.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_be_negative() {\n    let device = Default::default();\n\n    let lhs = TestTensor::<1>::from_data(TensorData::from([-7.0, -3.0, 2.0, 6.0]), &device);\n    let rhs = TestTensor::<1>::from_data(TensorData::from([-2.5, -2.1, -1.5, -3.25]), &device);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([-2.0, -0.9, -1.0, -0.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_be_negative_scalar() {\n    let data = TensorData::from([-7.0, -3.0, 2.0, 6.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.clone().remainder_scalar(-2.5);\n    let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_fp_dividends() {\n    let data = TensorData::from([-7.5, -2.5, 2.5, 7.5]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.remainder_scalar(3.0);\n    let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    // for tensor.remainder case, tests above have already covered float point dividend cases\n}\n\n#[test]\nfn should_support_large_divisor() {\n    let device = Default::default();\n\n    let lhs = TestTensor::<1>::from_data(\n        TensorData::from([-1.0, 1.0, -1.5, 1.5, -1.0, 1.0, -1.5, 1.5]),\n        &device,\n    );\n    let rhs = TestTensor::<1>::from_data(\n        TensorData::from([10.0, 10.0, 10.0, 10.0, -10.0, -10.0, -10.0, -10.0]),\n        &device,\n    );\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([9.0, 1.0, 8.5, 1.5, -1.0, -9.0, -1.5, -8.5]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_large_divisor_scalar() {\n    let data = TensorData::from([-1.0, 1.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.remainder_scalar(10.0);\n    let expected = TensorData::from([9.0, 1.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_remainder_op() {\n    let device = Default::default();\n    let lhs =\n        TestTensor::<1>::from_data(TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]), &device);\n    let rhs = TestTensor::<1>::from_data(TensorData::from([2.0, 3.0, 1.0, 2.0, 1.0, 3.0]), &device);\n\n    let output = lhs % rhs;\n    let expected = TensorData::from([1.0, 1.0, -0.0, 1.0, 0.0, 0.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_remainder_scalar_op() {\n    let data = TensorData::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor % 2.0;\n    let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/repeat.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_repeat_ops_one_dimension() {\n    let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[4, 1, 1]);\n    let expected = TensorData::from([\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_float_repeat_repeating_on_many_dimensions() {\n    let data = TensorData::from([\n        [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],\n        [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],\n        [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],\n        [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],\n    ]);\n    let tensor = TestTensor::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[2, 3, 2]);\n    let expected = TensorData::from([\n        [\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n        ],\n        [\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n        ],\n        [\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n        ],\n        [\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n        ],\n        [\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n        ],\n        [\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n        ],\n        [\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n        ],\n        [\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_repeat_0_times_empty() {\n    let tensor = TestTensor::<3>::ones([2, 3, 4], &Default::default());\n\n    let output = tensor.repeat(&[1, 0, 2]);\n\n    assert_eq!(output.shape(), [2, 0, 8].into());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/repeat_dim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_repeat_ops() {\n    let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    let output = tensor.repeat_dim(0, 4);\n    let expected = TensorData::from([\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n        [0.0f32, 1.0f32, 2.0f32],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_float_repeat_on_dims_larger_than_1() {\n    let data = TensorData::from([\n        [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],\n        [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],\n        [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],\n        [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],\n    ]);\n    let tensor = TestTensor::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat_dim(2, 2);\n    let expected = TensorData::from([\n        [\n            [1.0f32, 2.0f32, 1.0f32, 2.0f32],\n            [3.0f32, 4.0f32, 3.0f32, 4.0f32],\n        ],\n        [\n            [5.0f32, 6.0f32, 5.0f32, 6.0f32],\n            [7.0f32, 8.0f32, 7.0f32, 8.0f32],\n        ],\n        [\n            [9.0f32, 10.0f32, 9.0f32, 10.0f32],\n            [11.0f32, 12.0f32, 11.0f32, 12.0f32],\n        ],\n        [\n            [13.0f32, 14.0f32, 13.0f32, 14.0f32],\n            [15.0f32, 16.0f32, 15.0f32, 16.0f32],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn repeat_dim_swap_dims_1() {\n    let tensor = TestTensorInt::arange(0..16, &Default::default()).float();\n\n    let tensor = tensor.reshape([4, 1, 4]);\n    let tensor = tensor.swap_dims(0, 2);\n    let output = tensor.repeat_dim(1, 4);\n\n    let expected = TensorData::from([\n        [\n            [0.0, 4.0, 8.0, 12.0],\n            [0.0, 4.0, 8.0, 12.0],\n            [0.0, 4.0, 8.0, 12.0],\n            [0.0, 4.0, 8.0, 12.0],\n        ],\n        [\n            [1.0, 5.0, 9.0, 13.0],\n            [1.0, 5.0, 9.0, 13.0],\n            [1.0, 5.0, 9.0, 13.0],\n            [1.0, 5.0, 9.0, 13.0],\n        ],\n        [\n            [2.0, 6.0, 10.0, 14.0],\n            [2.0, 6.0, 10.0, 14.0],\n            [2.0, 6.0, 10.0, 14.0],\n            [2.0, 6.0, 10.0, 14.0],\n        ],\n        [\n            [3.0, 7.0, 11.0, 15.0],\n            [3.0, 7.0, 11.0, 15.0],\n            [3.0, 7.0, 11.0, 15.0],\n            [3.0, 7.0, 11.0, 15.0],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn repeat_dim_swap_dims_2() {\n    let tensor = TestTensorInt::arange(0..16, &Default::default()).float();\n\n    let tensor = tensor.reshape([2, 2, 1, 4]);\n    let tensor = tensor.swap_dims(0, 1);\n    let output = tensor.repeat_dim(2, 4);\n\n    let expected = TensorData::from([\n        [\n            [\n                [0.0, 1.0, 2.0, 3.0],\n                [0.0, 1.0, 2.0, 3.0],\n                [0.0, 1.0, 2.0, 3.0],\n                [0.0, 1.0, 2.0, 3.0],\n            ],\n            [\n                [8.0, 9.0, 10.0, 11.0],\n                [8.0, 9.0, 10.0, 11.0],\n                [8.0, 9.0, 10.0, 11.0],\n                [8.0, 9.0, 10.0, 11.0],\n            ],\n        ],\n        [\n            [\n                [4.0, 5.0, 6.0, 7.0],\n                [4.0, 5.0, 6.0, 7.0],\n                [4.0, 5.0, 6.0, 7.0],\n                [4.0, 5.0, 6.0, 7.0],\n            ],\n            [\n                [12.0, 13.0, 14.0, 15.0],\n                [12.0, 13.0, 14.0, 15.0],\n                [12.0, 13.0, 14.0, 15.0],\n                [12.0, 13.0, 14.0, 15.0],\n            ],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn repeat_dim_swap_dims_3() {\n    let tensor = TestTensorInt::arange(0..16, &Default::default()).float();\n\n    let tensor = tensor.reshape([1, 2, 2, 4]);\n    let tensor = tensor.swap_dims(0, 2);\n    let tensor = tensor.swap_dims(1, 3);\n    let output = tensor.repeat_dim(2, 4);\n\n    let expected = TensorData::from([\n        [\n            [[0.0, 8.0], [0.0, 8.0], [0.0, 8.0], [0.0, 8.0]],\n            [[1.0, 9.0], [1.0, 9.0], [1.0, 9.0], [1.0, 9.0]],\n            [[2.0, 10.0], [2.0, 10.0], [2.0, 10.0], [2.0, 10.0]],\n            [[3.0, 11.0], [3.0, 11.0], [3.0, 11.0], [3.0, 11.0]],\n        ],\n        [\n            [[4.0, 12.0], [4.0, 12.0], [4.0, 12.0], [4.0, 12.0]],\n            [[5.0, 13.0], [5.0, 13.0], [5.0, 13.0], [5.0, 13.0]],\n            [[6.0, 14.0], [6.0, 14.0], [6.0, 14.0], [6.0, 14.0]],\n            [[7.0, 15.0], [7.0, 15.0], [7.0, 15.0], [7.0, 15.0]],\n        ],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_repeat_dim_0_times_empty() {\n    let tensor = TestTensor::<3>::ones([2, 3, 4], &Default::default());\n\n    let output = tensor.repeat_dim(2, 0);\n\n    assert_eq!(output.shape(), [2, 3, 0].into());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/reshape.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_rank() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n    assert_eq!(tensor.rank(), 1);\n\n    let data = TensorData::from([[0.0, 1.0, 2.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n    assert_eq!(tensor.rank(), 2);\n}\n\n#[test]\nfn should_support_reshape_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.clone().reshape([1, 3]);\n    let expected = TensorData::from([[0.0, 1.0, 2.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_reshape_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.clone().reshape([6]);\n    let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_dim_infererence() {\n    let data = TensorData::from([\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0],\n        [9.0, 10.0, 11.0],\n    ]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    // Infer the dimension via -1\n    let reshaped = tensor.clone().reshape([2, -1]);\n    assert_eq!(reshaped.shape(), [2, 6].into());\n\n    // Infer the dimension via 0 (keep from the source) and -1 (infer)\n    let reshaped = reshaped.reshape([0, 2, -1]);\n    assert_eq!(reshaped.shape(), [2, 2, 3].into());\n\n    // This is effectively as if we did a flatten\n    let reshaped = tensor.clone().reshape([-1]);\n    assert_eq!(reshaped.shape(), [12].into());\n\n    // Keeping the first dimension the same (using 0)\n    let reshaped = tensor.clone().reshape([0, 3]);\n    assert_eq!(reshaped.shape(), [4, 3].into());\n}\n\n#[test]\nfn should_not_corrupt_after_slice() {\n    let zeros = TestTensor::<1>::zeros([2], &Default::default());\n    zeros.clone().slice([1..2]).reshape([1]).exp();\n\n    // May lead to zeroes being equal to [0.0, 1.0]\n    zeros.into_data().assert_eq(\n        &TestTensor::<1>::zeros([2], &Default::default()).to_data(),\n        true,\n    );\n}\n\n#[test]\n#[should_panic]\nfn multiple_neg_ones() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n    let _data_actual = tensor.reshape([-1, -1]).into_data();\n}\n\n#[test]\n#[should_panic]\nfn neg_value() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n    let _data_actual = tensor.reshape([-2, -1]).into_data();\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/round.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_round_ops() {\n    let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.round();\n    let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_round_ties_even() {\n    let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.round();\n    let expected = TensorData::from([2., 2., 4., 4., 6., 6.]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/select.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_select_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_2d_dim0_same_num_dim() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_data([1, 0], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_2d_dim0_more_num_dim() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_data([1, 0, 1, 1], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([\n        [3.0, 4.0, 5.0],\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [3.0, 4.0, 5.0],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_2d_dim0_vec() {\n    let device = Default::default();\n    let tensor =\n        TestTensor::<2>::from_data([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]], &device);\n    let indices = TestTensorInt::from_data([1, 0, 3, 2], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([[2.0, 3.0], [0.0, 1.0], [6.0, 7.0], [4.0, 5.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_2d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.select(1, indices);\n    let expected = TensorData::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device);\n    let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([3.0, 12.0, 3.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_1d_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([7, 8, 9], &device);\n    let values = TestTensorInt::from_data([5, 4, 3, 2, 1], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([10, 19, 10]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_2d_dim0() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 0]), &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_2d_dim1() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 0, 2]), &device);\n\n    let output = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([[2.0, 2.0, 5.0], [8.0, 8.0, 11.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_3d_dim1_vec() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],\n            [[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0], [-7.0, -8.0]],\n        ],\n        &device,\n    );\n    let indices = TestTensorInt::from_data([1, 0, 3, 2], &device);\n\n    let output = tensor.select(1, indices);\n    let expected = TensorData::from([\n        [[3.0, 4.0], [1.0, 2.0], [7.0, 8.0], [5.0, 6.0]],\n        [[-3.0, -4.0], [-1.0, -2.0], [-7.0, -8.0], [-5.0, -6.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn should_select_panic_invalid_dimension() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device);\n\n    tensor.select(10, indices);\n}\n\n#[test]\nfn should_match_default_implementation_behavior() {\n    // Verify optimized implementation matches original default logic\n    let device = Default::default();\n    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);\n    let indices = TestTensorInt::from_data([0, 1, 0], &device);\n    let values = TestTensorBool::<1>::from_data([false, true, true], &device);\n\n    let optimized_result =\n        tensor\n            .clone()\n            .select_assign(0, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n\n    // Manual default implementation logic\n    let int_tensor = tensor.int();\n    let int_values = values.int();\n    let assigned = int_tensor.select_assign(0, indices, int_values, IndexingUpdateOp::Add);\n    let default_result = assigned.greater_elem(0);\n\n    optimized_result\n        .into_data()\n        .assert_eq(&default_result.into_data(), false);\n}\n\n#[test]\nfn should_select_with_negative_dim_2d() {\n    // Test using negative dimension indexing on 2D tensor\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::from_data([1, 0, 2], &device);\n\n    // Using -1 should refer to the last dimension (dim 1)\n    let output_neg = tensor.clone().select(-1, indices.clone());\n    let output_pos = tensor.select(1, indices);\n\n    // Both should produce the same result\n    output_neg\n        .into_data()\n        .assert_eq(&output_pos.into_data(), false);\n}\n\n#[test]\nfn should_select_add_with_negative_dim_2d() {\n    // Test select_add with negative dimension on 2D tensor\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let values = TestTensor::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let indices = TestTensorInt::from_data([0, 2], &device);\n\n    // Using -1 should refer to the last dimension (dim 1)\n    let output_neg =\n        tensor\n            .clone()\n            .select_assign(-1, indices.clone(), values.clone(), IndexingUpdateOp::Add);\n    let output_pos = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add);\n\n    output_neg\n        .into_data()\n        .assert_eq(&output_pos.into_data(), false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_select_negative_dim_out_of_bounds() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let indices = TestTensorInt::from_data([0, 1], &device);\n\n    // This should panic because -3 is out of bounds for a 2D tensor\n    tensor.select(-3, indices);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_select_add_negative_dim_out_of_bounds() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let values = TestTensor::from_data([[5.0], [6.0]], &device);\n    let indices = TestTensorInt::from_data([0], &device);\n\n    // This should panic because -3 is out of bounds for a 2D tensor\n    tensor.select_assign(-3, indices, values, IndexingUpdateOp::Add);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/sign.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_sign_ops_float() {\n    let tensor = TestTensor::<2>::from([[-0.2, -1.0, 2.0], [3.0, 0.0, -5.0]]);\n\n    let output = tensor.sign();\n    let expected = TensorData::from([[-1.0, -1.0, 1.0], [1.0, 0.0, -1.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/slice.rs",
    "content": "use super::*;\nuse burn_tensor::{ElementConversion, Slice, TensorData, s};\n\n#[test]\nfn should_support_slice_dim_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());\n\n    // Test with range (negative index)\n    let output = tensor.clone().slice_dim(0, -2..);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([1.0, 2.0]), false);\n\n    // Test with Slice directly\n    let slice = Slice::new(1, None, 1); // equivalent to 1..\n    let output = tensor.slice_dim(0, slice);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([1.0, 2.0]), false);\n}\n\n#[test]\n#[should_panic(expected = \"The provided dimension exceeds the tensor dimensions\")]\nfn should_panic_when_slice_dim_1d_bad_dim() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());\n\n    let _output = tensor.slice_dim(1, 1..);\n}\n\n#[test]\nfn should_support_slice_dim_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    let output = tensor.slice_dim(1, 1..);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0, 2.0], [4.0, 5.0]]), false);\n}\n\n#[test]\nfn should_support_slice_dim_with_step() {\n    let data = TensorData::from([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    // Test 1: Slice dimension 1 with step=2 using s! macro\n    let output = tensor.clone().slice_dim(1, s![0..4;2]);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 2.0], [4.0, 6.0]]), false);\n\n    // Test 2: Slice dimension 1 with step=2 using Slice directly\n    let slice = Slice::new(0, Some(4), 2);\n    let output = tensor.slice_dim(1, slice);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0.0, 2.0], [4.0, 6.0]]), false);\n}\n\n#[test]\nfn should_support_slice_dim_with_negative_step() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    // Slice dimension 1 with negative step (reverse columns)\n    let output = tensor.slice_dim(1, s![..;-1]);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2.0, 1.0, 0.0], [5.0, 4.0, 3.0]]), false);\n}\n\n#[test]\nfn should_support_full_sliceing_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());\n\n    let output = tensor.slice([0..3]);\n\n    output.into_data().assert_eq(&data, false);\n}\n\n#[test]\nfn should_support_full_sliceing_vec() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    let slices: Vec<Slice> = vec![(0..2).into()];\n\n    let output = tensor.clone().slice(&slices);\n    output.into_data().assert_eq(&data, false);\n\n    let output = tensor.slice([0..2, 0..3]);\n    output.into_data().assert_eq(&data, false);\n}\n\n#[test]\nfn should_support_partial_sliceing_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.slice([1..3]);\n    let expected = TensorData::from([1.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_full_sliceing_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data.clone(), &Default::default());\n\n    let output = tensor.clone().slice([0..2]);\n    output.into_data().assert_eq(&data, false);\n\n    let output = tensor.slice([0..2, 0..3]);\n    output.into_data().assert_eq(&data, false);\n}\n\n#[test]\nfn should_support_partial_sliceing_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.slice([0..2, 0..2]);\n    let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_range_first_dim() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.slice(0..1);\n    let expected = TensorData::from([[0.0, 1.0, 2.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_partial_sliceing_3d() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.slice([1..2, 1..2, 0..2]);\n    let expected = TensorData::from([[[9.0, 10.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_partial_sliceing_3d_non_contiguous() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.transpose().slice([1..2, 1..2, 0..2]);\n    let expected = TensorData::from([[[7.0, 10.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.slice_fill([0..2], -1.0);\n    let expected = TensorData::from([-1.0, -1.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_vec() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let slices: Vec<Slice> = vec![(0..2).into()];\n\n    let output = tensor.slice_fill(&slices, -1.0);\n    let expected = TensorData::from([-1.0, -1.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_cast_f32() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device).cast(burn_tensor::DType::F32);\n\n    tensor\n        .slice_fill(s![0..2], 1.0)\n        .into_data()\n        .assert_eq(&TensorData::from([1.0, 1.0, 2.0]), false);\n}\n\n// Skip on metal - F64 not supported\n#[cfg(not(feature = \"metal\"))]\n#[test]\nfn should_support_slice_fill_cast_f64() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device).cast(burn_tensor::DType::F64);\n\n    tensor\n        .slice_fill(s![0..2], 1.0)\n        .into_data()\n        .assert_eq(&TensorData::from([1.0, 1.0, 2.0]), false);\n}\n\n#[test]\nfn should_support_slice_fill_1d_neg() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n\n    let output = tensor.slice_fill([-1..], -1.0);\n    let expected = TensorData::from([0.0, 1.0, -1.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(data, &device);\n\n    let output = tensor.slice_fill([1..2, 0..2], -1.0);\n    let expected = TensorData::from([[0.0, 1.0, 2.0], [-1.0, -1.0, 5.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_with_positive_step() {\n    let device = Default::default();\n\n    // Test 1D tensor with step\n    let tensor = TestTensor::<1>::zeros([10], &device);\n    let output = tensor.slice_fill(s![0..10;2], 5.0);\n    let expected = TensorData::from([5.0, 0.0, 5.0, 0.0, 5.0, 0.0, 5.0, 0.0, 5.0, 0.0]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test 2D tensor with step on first dimension\n    let tensor = TestTensor::<2>::zeros([4, 4], &device);\n    let output = tensor.slice_fill(s![0..4;2, ..], 3.0);\n    let expected = TensorData::from([\n        [3.0, 3.0, 3.0, 3.0],\n        [0.0, 0.0, 0.0, 0.0],\n        [3.0, 3.0, 3.0, 3.0],\n        [0.0, 0.0, 0.0, 0.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test 2D tensor with step on second dimension\n    let tensor = TestTensor::<2>::zeros([3, 6], &device);\n    let output = tensor.slice_fill(s![.., 0..6;3], 2.0);\n    let expected = TensorData::from([\n        [2.0, 0.0, 0.0, 2.0, 0.0, 0.0],\n        [2.0, 0.0, 0.0, 2.0, 0.0, 0.0],\n        [2.0, 0.0, 0.0, 2.0, 0.0, 0.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_with_negative_step() {\n    let device = Default::default();\n\n    // Test 1D tensor with negative step (reverse fill)\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0], &device);\n    let output = tensor.slice_fill(s![0..5;-1], 10.0);\n    // Should reverse the indices [4,3,2,1,0] and fill them with 10.0\n    let expected = TensorData::from([10.0, 10.0, 10.0, 10.0, 10.0]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test 2D tensor with negative step\n    let tensor =\n        TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], &device);\n    let output = tensor.slice_fill(s![.., 0..3;-2], -1.0);\n    // Should fill columns in reverse order with step 2: indices 2, 0\n    let expected = TensorData::from([[-1.0, 2.0, -1.0], [-1.0, 5.0, -1.0], [-1.0, 8.0, -1.0]]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_fill_with_mixed_steps() {\n    let device = Default::default();\n\n    // Test 2D tensor with mixed positive and negative steps\n    let tensor = TestTensor::<2>::zeros([4, 6], &device);\n    let output = tensor.slice_fill(s![0..4;2, 0..6;-3], 7.0);\n    // Step 2 on dim 0 selects rows 0, 2\n    // Step -3 on dim 1 with range 0..6 reverses and takes every 3rd: indices [5, 2]\n    let expected = TensorData::from([\n        [0.0, 0.0, 7.0, 0.0, 0.0, 7.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 7.0, 0.0, 0.0, 7.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test 3D tensor with steps\n    let tensor = TestTensor::<3>::zeros([2, 4, 4], &device);\n    let output = tensor.slice_fill(s![.., 0..4;2, 0..4;-2], 1.0);\n    // Step 2 on dim 1 selects rows 0, 2\n    // Step -2 on dim 2 with range 0..4 reverses and takes every 2nd: indices [3, 1]\n    let expected_slice = [\n        [0.0, 1.0, 0.0, 1.0],\n        [0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 1.0],\n        [0.0, 0.0, 0.0, 0.0],\n    ];\n    let expected = TensorData::from([expected_slice, expected_slice]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn clamp_when_slice_exceeds_dimension() {\n    let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]);\n    let data = tensor.to_data();\n\n    let output = tensor.slice([0..4]);\n    output.into_data().assert_eq(&data, true);\n}\n\n#[test]\nfn negative_dimensions() {\n    let tensor = TestTensor::<2>::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data = tensor.to_data();\n\n    // Clamping to the tensor dimensions\n    let output = tensor.clone().slice([0..4, 0..4]);\n    output.into_data().assert_eq(&data, true);\n\n    // Negative dimensions\n    let output = tensor.clone().slice([0..1, 0..1]);\n    let data = TensorData::from([[0.elem::<FloatElem>()]]);\n    output.into_data().assert_eq(&data, true);\n\n    let output = tensor.slice(s![0..-1, 0..-2]);\n    output.into_data().assert_eq(&data, true);\n}\n\n#[test]\nfn missing_dimensions() {\n    let tensor = TestTensor::<2>::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data = tensor.to_data();\n\n    // Clamping to the tensor dimensions\n    let output = tensor.clone().slice([0..4, 0..4]);\n    output.into_data().assert_eq(&data, true);\n\n    // Negative dimensions\n    let data = TensorData::from([[0.elem::<FloatElem>()]]);\n    let output = tensor.clone().slice(s![0..-1, 0..-2]);\n    output.into_data().assert_eq(&data, true);\n\n    // Missing dimensions\n    let output = tensor.clone().slice(s![0..1, ..]);\n    let data = TensorData::from([[0.0f32, 1.0, 2.0]]);\n    output.into_data().assert_eq(&data, false);\n\n    let output = tensor.clone().slice(s![.., 0..2]);\n    let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]);\n    output.into_data().assert_eq(&data, false);\n\n    let output = tensor.clone().slice([.., ..]);\n    let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    output.into_data().assert_eq(&data, false);\n}\n\n#[test]\nfn should_slice_aggregation_result() {\n    let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]).mean();\n\n    let output = tensor.clone().slice([(0..1)]);\n    output.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_slice_with_too_many_dimensions() {\n    let tensor = TestTensor::<1>::from([0.0, 1.0, 2.0]);\n\n    let _output = tensor.slice([0..1, 0..1]);\n}\n\n#[test]\nfn should_support_descending_slice_as_empty() {\n    // Like PyTorch, x[3:1] should return an empty tensor, not panic\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.slice(s![2..1]);\n\n    // Should produce an empty tensor with shape [0]\n    assert_eq!(output.dims(), [0]);\n}\n\n#[test]\nfn should_support_empty_slice() {\n    // ONNX models can have empty slices where start == end\n    // This should produce a tensor with size 0 in that dimension\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.slice([1..1]);\n\n    // Should produce an empty tensor with shape [0]\n    assert_eq!(output.dims(), [0]);\n}\n\n#[test]\nfn should_support_empty_slice_2d() {\n    // Test empty slice on 2D tensor\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    // Empty slice on first dimension\n    let output = tensor.clone().slice([1..1, 0..3]);\n    assert_eq!(output.dims(), [0, 3]);\n\n    // Empty slice on second dimension\n    let output = tensor.slice([0..2, 2..2]);\n    assert_eq!(output.dims(), [2, 0]);\n}\n\n#[test]\nfn test_slice_with_positive_step() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n\n    // Test step=2 along first dimension\n    let sliced = tensor.clone().slice([s![0..3;2]]);\n    let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0], [9.0, 10.0, 11.0, 12.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=2 along second dimension\n    let sliced = tensor.clone().slice(s![.., 0..4;2]);\n    let expected = TensorData::from([[1.0, 3.0], [5.0, 7.0], [9.0, 11.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=2 along both dimensions\n    let sliced = tensor.clone().slice(s![0..3;2, 0..4;2]);\n    let expected = TensorData::from([[1.0, 3.0], [9.0, 11.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_with_negative_step() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n\n    // Test step=-1 along first dimension (reverse rows)\n    let sliced = tensor.clone().slice([s![0..3;-1]]);\n    let expected = TensorData::from([\n        [9.0, 10.0, 11.0, 12.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [1.0, 2.0, 3.0, 4.0],\n    ]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=-1 along second dimension (reverse columns)\n    let sliced = tensor.clone().slice(s![.., 0..4;-1]);\n    let expected = TensorData::from([\n        [4.0, 3.0, 2.0, 1.0],\n        [8.0, 7.0, 6.0, 5.0],\n        [12.0, 11.0, 10.0, 9.0],\n    ]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=-2 along first dimension\n    let sliced = tensor.clone().slice([s![0..3;-2]]);\n    let expected = TensorData::from([[9.0, 10.0, 11.0, 12.0], [1.0, 2.0, 3.0, 4.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=-2 along second dimension\n    let sliced = tensor.clone().slice(s![.., 0..4;-2]);\n    let expected = TensorData::from([[4.0, 2.0], [8.0, 6.0], [12.0, 10.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_with_mixed_steps() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n\n    // Test positive step along first dimension, negative along second\n    let sliced = tensor.clone().slice(s![0..3;2, 0..4;-1]);\n    let expected = TensorData::from([[4.0, 3.0, 2.0, 1.0], [12.0, 11.0, 10.0, 9.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test negative step along first dimension, positive along second\n    let sliced = tensor.clone().slice(s![0..3;-1, 0..4;2]);\n    let expected = TensorData::from([[9.0, 11.0], [5.0, 7.0], [1.0, 3.0]]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_with_steps_1d() {\n    let device = Default::default();\n    let tensor =\n        TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], &device);\n\n    // Test positive step\n    let sliced = tensor.clone().slice([s![0..10;2]]);\n    let expected = TensorData::from([1.0, 3.0, 5.0, 7.0, 9.0]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test negative step\n    let sliced = tensor.clone().slice([s![0..10;-1]]);\n    let expected = TensorData::from([10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test negative step with partial range\n    let sliced = tensor.clone().slice([s![2..8;-2]]);\n    let expected = TensorData::from([8.0, 6.0, 4.0]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_with_steps_3d() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0], [3.0, 4.0]],\n            [[5.0, 6.0], [7.0, 8.0]],\n            [[9.0, 10.0], [11.0, 12.0]],\n            [[13.0, 14.0], [15.0, 16.0]],\n        ],\n        &device,\n    );\n\n    // Test step=2 along first dimension\n    let sliced = tensor.clone().slice(s![0..4;2, .., ..]);\n    let expected = TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[9.0, 10.0], [11.0, 12.0]]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=-1 along all dimensions\n    let sliced = tensor.clone().slice(s![0..4;-1, 0..2;-1, 0..2;-1]);\n    let expected = TensorData::from([\n        [[16.0, 15.0], [14.0, 13.0]],\n        [[12.0, 11.0], [10.0, 9.0]],\n        [[8.0, 7.0], [6.0, 5.0]],\n        [[4.0, 3.0], [2.0, 1.0]],\n    ]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/slice_assign.rs",
    "content": "use super::*;\nuse burn_tensor::{Slice, TensorData, s};\n\n#[test]\nfn should_support_slice_assign_1d() {\n    let data = TensorData::from([0.0, 1.0, 2.0]);\n    let data_assigned = TensorData::from([10.0, 5.0]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data(data, &device);\n    let tensor_assigned = TestTensor::<1>::from_data(data_assigned, &device);\n\n    let output = tensor.slice_assign([0..2], tensor_assigned);\n    let expected = TensorData::from([10.0, 5.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_2d() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data_assigned = TensorData::from([[10.0, 5.0]]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(data, &device);\n    let tensor_assigned = TestTensor::<2>::from_data(data_assigned, &device);\n\n    let output = tensor.slice_assign([1..2, 0..2], tensor_assigned);\n    let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_vec() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data_assigned = TensorData::from([[10.0, 5.0]]);\n\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(data, &device);\n    let tensor_assigned = TestTensor::<2>::from_data(data_assigned, &device);\n\n    let slices: Vec<Slice> = vec![1..2, 0..2].into_iter().map(Slice::from).collect();\n\n    let output = tensor.slice_assign(&slices, tensor_assigned);\n    let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn slice_assign_now_supports_non_unit_step() {\n    let device = Default::default();\n    // Create tensors where the shapes match for stepped slicing\n    let tensor = TestTensor::<2>::ones([4, 4], &device);\n    // With step=2 on first dim, we select indices 0 and 2, so we need a [2, 4] values tensor\n    let values = TestTensor::<2>::zeros([2, 4], &device);\n\n    // This now works because slice_assign supports steps != 1\n    // We use s! macro to create a slice with step=2\n    let result = tensor.slice_assign(s![0..3;2, ..], values);\n\n    // Verify the result: rows 0 and 2 should be zeros, rows 1 and 3 should be ones\n    let expected = TensorData::from([\n        [0.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0, 1.0],\n    ]);\n    result.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_with_positive_step_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device);\n    let values = TestTensor::<1>::from_data([10.0, 20.0, 30.0], &device);\n\n    // Assign to indices 0, 2, 4 (step=2)\n    let output = tensor.slice_assign([s![0..6;2]], values);\n    let expected = TensorData::from([10.0, 2.0, 20.0, 4.0, 30.0, 6.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_with_positive_step_2d() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n            [13.0, 14.0, 15.0, 16.0],\n        ],\n        &device,\n    );\n\n    // Assign to rows 0, 2 (step=2)\n    let values = TestTensor::<2>::from_data(\n        [[100.0, 101.0, 102.0, 103.0], [200.0, 201.0, 202.0, 203.0]],\n        &device,\n    );\n    let output = tensor.clone().slice_assign([s![0..4;2]], values);\n    let expected = TensorData::from([\n        [100.0, 101.0, 102.0, 103.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [200.0, 201.0, 202.0, 203.0],\n        [13.0, 14.0, 15.0, 16.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Assign to columns 0, 2 (step=2)\n    let values = TestTensor::<2>::from_data(\n        [\n            [100.0, 200.0],\n            [101.0, 201.0],\n            [102.0, 202.0],\n            [103.0, 203.0],\n        ],\n        &device,\n    );\n    let output = tensor.clone().slice_assign(s![.., 0..4;2], values);\n    let expected = TensorData::from([\n        [100.0, 2.0, 200.0, 4.0],\n        [101.0, 6.0, 201.0, 8.0],\n        [102.0, 10.0, 202.0, 12.0],\n        [103.0, 14.0, 203.0, 16.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Assign with step=2 on both dimensions\n    let values = TestTensor::<2>::from_data([[100.0, 200.0], [300.0, 400.0]], &device);\n    let output = tensor.slice_assign(s![0..4;2, 0..4;2], values);\n    let expected = TensorData::from([\n        [100.0, 2.0, 200.0, 4.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [300.0, 10.0, 400.0, 12.0],\n        [13.0, 14.0, 15.0, 16.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_with_negative_step_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device);\n    let values = TestTensor::<1>::from_data([60.0, 50.0, 40.0, 30.0, 20.0, 10.0], &device);\n\n    // Assign in reverse order (step=-1)\n    let output = tensor.slice_assign([s![0..6;-1]], values);\n    let expected = TensorData::from([10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_with_negative_step_2d() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n\n    // Assign to rows in reverse order (step=-1)\n    let values = TestTensor::<2>::from_data(\n        [\n            [30.0, 31.0, 32.0, 33.0],\n            [20.0, 21.0, 22.0, 23.0],\n            [10.0, 11.0, 12.0, 13.0],\n        ],\n        &device,\n    );\n    let output = tensor.clone().slice_assign([s![0..3;-1]], values);\n    let expected = TensorData::from([\n        [10.0, 11.0, 12.0, 13.0],\n        [20.0, 21.0, 22.0, 23.0],\n        [30.0, 31.0, 32.0, 33.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Assign to columns in reverse order (step=-1)\n    let values = TestTensor::<2>::from_data(\n        [\n            [40.0, 30.0, 20.0, 10.0],\n            [80.0, 70.0, 60.0, 50.0],\n            [120.0, 110.0, 100.0, 90.0],\n        ],\n        &device,\n    );\n    let output = tensor.clone().slice_assign(s![.., 0..4;-1], values);\n    let expected = TensorData::from([\n        [10.0, 20.0, 30.0, 40.0],\n        [50.0, 60.0, 70.0, 80.0],\n        [90.0, 100.0, 110.0, 120.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_with_mixed_steps() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n            [13.0, 14.0, 15.0, 16.0],\n        ],\n        &device,\n    );\n\n    // Positive step along rows, negative along columns\n    let values = TestTensor::<2>::from_data(\n        [[100.0, 101.0, 102.0, 103.0], [200.0, 201.0, 202.0, 203.0]],\n        &device,\n    );\n    let output = tensor.clone().slice_assign(s![0..4;2, 0..4;-1], values);\n    let expected = TensorData::from([\n        [103.0, 102.0, 101.0, 100.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [203.0, 202.0, 201.0, 200.0],\n        [13.0, 14.0, 15.0, 16.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Negative step along rows, positive along columns\n    let values = TestTensor::<2>::from_data(\n        [\n            [100.0, 200.0],\n            [101.0, 201.0],\n            [102.0, 202.0],\n            [103.0, 203.0],\n        ],\n        &device,\n    );\n    let output = tensor.slice_assign(s![0..4;-1, 0..4;2], values);\n    let expected = TensorData::from([\n        [103.0, 2.0, 203.0, 4.0],\n        [102.0, 6.0, 202.0, 8.0],\n        [101.0, 10.0, 201.0, 12.0],\n        [100.0, 14.0, 200.0, 16.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_3d_with_steps() {\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0], [3.0, 4.0]],\n            [[5.0, 6.0], [7.0, 8.0]],\n            [[9.0, 10.0], [11.0, 12.0]],\n            [[13.0, 14.0], [15.0, 16.0]],\n        ],\n        &device,\n    );\n\n    // Test step=2 along first dimension\n    let values = TestTensor::<3>::from_data(\n        [\n            [[100.0, 101.0], [102.0, 103.0]],\n            [[200.0, 201.0], [202.0, 203.0]],\n        ],\n        &device,\n    );\n    let output = tensor.clone().slice_assign(s![0..4;2, .., ..], values);\n    let expected = TensorData::from([\n        [[100.0, 101.0], [102.0, 103.0]],\n        [[5.0, 6.0], [7.0, 8.0]],\n        [[200.0, 201.0], [202.0, 203.0]],\n        [[13.0, 14.0], [15.0, 16.0]],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test step=-1 along all dimensions\n    let values = TestTensor::<3>::from_data(\n        [\n            [[400.0, 399.0], [398.0, 397.0]],\n            [[396.0, 395.0], [394.0, 393.0]],\n            [[392.0, 391.0], [390.0, 389.0]],\n            [[388.0, 387.0], [386.0, 385.0]],\n        ],\n        &device,\n    );\n    let output = tensor.slice_assign(s![0..4;-1, 0..2;-1, 0..2;-1], values);\n    let expected = TensorData::from([\n        [[385.0, 386.0], [387.0, 388.0]],\n        [[389.0, 390.0], [391.0, 392.0]],\n        [[393.0, 394.0], [395.0, 396.0]],\n        [[397.0, 398.0], [399.0, 400.0]],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_partial_with_steps() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0, 5.0],\n            [6.0, 7.0, 8.0, 9.0, 10.0],\n            [11.0, 12.0, 13.0, 14.0, 15.0],\n            [16.0, 17.0, 18.0, 19.0, 20.0],\n            [21.0, 22.0, 23.0, 24.0, 25.0],\n        ],\n        &device,\n    );\n\n    // Assign to a subset with step=2\n    let values = TestTensor::<2>::from_data([[100.0, 200.0], [300.0, 400.0]], &device);\n    let output = tensor.slice_assign(s![1..4;2, 1..4;2], values);\n    let expected = TensorData::from([\n        [1.0, 2.0, 3.0, 4.0, 5.0],\n        [6.0, 100.0, 8.0, 200.0, 10.0],\n        [11.0, 12.0, 13.0, 14.0, 15.0],\n        [16.0, 300.0, 18.0, 400.0, 20.0],\n        [21.0, 22.0, 23.0, 24.0, 25.0],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_empty_range() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let values: TestTensor<2> = TestTensor::empty([2, 0], &device);\n\n    // Empty slice assignment (start == end) should be a no-op\n    let output = tensor.clone().slice_assign([0..2, 1..1], values);\n    let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_empty_range_1d() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([1.0, 2.0, 3.0, 4.0, 5.0], &device);\n    let values: TestTensor<1> = TestTensor::empty([0], &device);\n\n    // Empty slice assignment should return tensor unchanged\n    let output = tensor.clone().slice_assign([2..2], values);\n    let expected = TensorData::from([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_single_dim_slice() {\n    let device = Default::default();\n    let x = TestTensor::<3>::ones([2, 3, 1], &device);\n    let values = TestTensor::<3>::zeros([1, 3, 1], &device);\n\n    let output = x.slice_assign(s![1], values);\n\n    output.into_data().assert_eq(\n        &TensorData::from([[[1.0], [1.0], [1.0]], [[0.0], [0.0], [0.0]]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/sort_argsort.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_sort_1d_float() {\n    let tensor = TestTensor::<1>::from([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let values = tensor.sort(0);\n\n    let values_expected = TensorData::from([\n        -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 199.412,\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n}\n\n#[test]\nfn test_argsort_1d_float() {\n    let tensor = TestTensor::<1>::from([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let indices = tensor.argsort(0);\n\n    let indices_expected = TensorData::from([12, 6, 2, 3, 0, 5, 10, 1, 4, 7, 11, 9, 8]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_with_indices_descending_float() {\n    // 1D\n    let tensor = TestTensor::<1>::from([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 199.412, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let (values, indices) = tensor.sort_descending_with_indices(0);\n\n    let values_expected = TensorData::from([\n        199.412, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1,\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // 2D\n    let tensor = TestTensor::<3>::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, 4.], [0.99, 3., -8.1]],\n    ]);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.sort_descending_with_indices(1);\n\n    let values_expected = TensorData::from([\n        [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]],\n        [[0.99, 3., 4.], [-0.3, 2.3, -8.1]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_float() {\n    let tensor = TestTensor::<3>::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, 4.], [0.99, 3., -8.1]],\n    ]);\n\n    // Sort along dim=0\n    let values = tensor.clone().sort(0);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]],\n        [[-0.3, 2.3, 4.], [0.99, 3., 0.94]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    // Sort along dim=1\n    let values = tensor.clone().sort(1);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, -8.1], [0.99, 3., 4.]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    // Sort along dim=2\n    let values = tensor.sort(2);\n\n    let values_expected = TensorData::from([\n        [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]],\n        [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n}\n\n#[test]\nfn test_sort_with_indices_float() {\n    let tensor = TestTensor::<3>::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, 4.], [0.99, 3., -8.1]],\n    ]);\n\n    // Sort along dim=0\n    let (values, indices) = tensor.clone().sort_with_indices(0);\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]],\n        [[-0.3, 2.3, 4.], [0.99, 3., 0.94]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.clone().sort_with_indices(1);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, -8.1], [0.99, 3., 4.]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let (values, indices) = tensor.sort_with_indices(2);\n\n    let values_expected = TensorData::from([\n        [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]],\n        [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]],\n    ]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_argsort_float() {\n    let tensor = TestTensor::<3>::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, 4.], [0.99, 3., -8.1]],\n    ]);\n\n    // Sort along dim=0\n    let indices = tensor.clone().argsort(0);\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let indices = tensor.clone().argsort(1);\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let indices = tensor.argsort(2);\n\n    let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_float_nan() {\n    let tensor = TestTensor::<2>::from([[-0.5, f32::NAN], [0., 0.94], [-0.3, f32::NAN]]);\n\n    // Sort along dim=0\n    let values = tensor.sort(0);\n\n    let values_expected = TensorData::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n}\n\n#[test]\nfn test_sort_descending_1d() {\n    let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]);\n\n    // Sort along dim=0\n    let values = tensor.sort_descending(0);\n\n    let values_expected = TensorData::from([5., 4., 3., 2., 1.]);\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/split.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_split_evenly_divisible() {\n    let device = Default::default();\n    let tensors =\n        TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]], &device);\n\n    let split_tensors = tensors.split(2, 0);\n    assert_eq!(split_tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([[0, 1], [2, 3]]),\n        TensorData::from([[4, 5], [6, 7]]),\n        TensorData::from([[8, 9], [10, 11]]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_not_evenly_divisible() {\n    let device = Default::default();\n    let tensors = TestTensor::<2>::from_data([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], &device);\n\n    let split_tensors = tensors.split(2, 0);\n    assert_eq!(split_tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([[0, 1], [2, 3]]),\n        TensorData::from([[4, 5], [6, 7]]),\n        TensorData::from([[8, 9]]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_along_dim1() {\n    let device = Default::default();\n    let tensors = TestTensor::<2>::from_data([[0, 1, 2], [3, 4, 5]], &device);\n\n    let split_tensors = tensors.split(2, 1);\n    assert_eq!(split_tensors.len(), 2);\n\n    let expected = [\n        TensorData::from([[0, 1], [3, 4]]),\n        TensorData::from([[2], [5]]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_split_size_larger_than_tensor_size() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);\n\n    let split_tensors = tensors.split(10, 0);\n    assert_eq!(split_tensors.len(), 1);\n\n    let expected = [TensorData::from([0, 1, 2, 3, 4])];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_with_zero_split_size_zero_tensor_size() {\n    let device = Default::default();\n    let empty_array: [i32; 0] = [];\n    let tensors = TestTensor::<1>::from_data(empty_array, &device);\n\n    let split_tensors = tensors.split(0, 0);\n    assert_eq!(split_tensors.len(), 0);\n}\n\n#[test]\nfn test_split_zero_sized_tensor() {\n    let device = Default::default();\n    let empty_array: [i32; 0] = [];\n    let tensors = TestTensor::<1>::from_data(empty_array, &device);\n\n    let split_tensors = tensors.split(1, 0);\n    assert_eq!(split_tensors.len(), 0);\n}\n\n#[test]\n#[should_panic(\n    expected = \"split_size must be greater than 0 unless the tensor size along the dimension is 0.\"\n)]\nfn test_split_with_zero_split_size_non_zero_tensor() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4], &device);\n\n    let _split_tensors = tensors.split(0, 0);\n}\n\n#[test]\n#[should_panic(expected = \"Given dimension is greater than or equal to the tensor rank.\")]\nfn test_split_invalid_dim() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2], &device);\n\n    let _split_tensors = tensors.split(1, 2);\n}\n\n#[test]\nfn test_split_3d_tensor_along_dim0() {\n    let device = Default::default();\n    let tensors = TestTensor::<3>::from_data(\n        [\n            [[0, 1], [2, 3]],\n            [[4, 5], [6, 7]],\n            [[8, 9], [10, 11]],\n            [[12, 13], [14, 15]],\n        ],\n        &device,\n    );\n\n    let split_tensors = tensors.split(2, 0);\n    assert_eq!(split_tensors.len(), 2);\n\n    let expected = [\n        TensorData::from([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]),\n        TensorData::from([[[8, 9], [10, 11]], [[12, 13], [14, 15]]]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_3d_tensor_along_dim1() {\n    let device = Default::default();\n    let tensors = TestTensor::<3>::from_data(\n        [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]],\n        &device,\n    );\n\n    let split_tensors = tensors.split(2, 1);\n    assert_eq!(split_tensors.len(), 2);\n\n    let expected = [\n        TensorData::from([[[0, 1], [2, 3]], [[6, 7], [8, 9]]]),\n        TensorData::from([[[4, 5]], [[10, 11]]]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\nfn test_split_with_sizes() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device);\n\n    let split_tensors = tensors.split_with_sizes(vec![2, 3, 1], 0);\n    assert_eq!(split_tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([0, 1]),\n        TensorData::from([2, 3, 4]),\n        TensorData::from([5]),\n    ];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n\n#[test]\n#[should_panic(\n    expected = \"The sum of split_sizes must equal the tensor size along the specified dimension.\"\n)]\nfn test_split_with_sizes_invalid_sum() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2, 3, 4, 5], &device);\n\n    let _split_tensors = tensors.split_with_sizes(vec![2, 2, 1], 0);\n}\n\n#[test]\nfn test_split_with_sizes_zero_length() {\n    let device = Default::default();\n    let tensors = TestTensor::<1>::from_data([0, 1, 2], &device);\n\n    let split_tensors = tensors.split_with_sizes(vec![0, 1, 2], 0);\n    assert_eq!(split_tensors.len(), 2);\n\n    let expected = [TensorData::from([0]), TensorData::from([1, 2])];\n\n    for (index, tensor) in split_tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/sqrt.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse core::f32::consts::SQRT_2;\n\n#[test]\nfn should_support_sqrt_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.sqrt();\n    let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::relative(1e-4).set_half_precision_relative(1e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/square.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_sqrt_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.square();\n    let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]);\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &expected,\n        Tolerance::relative(1e-4).set_half_precision_relative(1e-3),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/squeeze.rs",
    "content": "use super::*;\nuse burn_tensor::Shape;\n\n/// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor.\n#[test]\nfn should_squeeze_dim() {\n    let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default());\n    let squeezed_tensor: TestTensor<2> = tensor.squeeze_dim(1);\n    let expected_shape = Shape::new([2, 4]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n#[test]\nfn should_squeeze() {\n    let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default());\n    let squeezed_tensor: TestTensor<2> = tensor.squeeze();\n    let expected_shape = Shape::new([2, 4]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor.\n#[test]\nfn should_squeeze_first() {\n    let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 4, 5]), &Default::default());\n    let squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(0);\n    let expected_shape = Shape::new([3, 4, 5]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n/// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor.\n#[test]\nfn should_squeeze_last() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 1]), &Default::default());\n    let squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(3);\n    let expected_shape = Shape::new([2, 3, 4]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n/// Test if the function panics when the squeezed dimension is not of size 1.\n#[test]\n#[should_panic]\nfn should_squeeze_panic() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let _squeezed_tensor: TestTensor<3> = tensor.squeeze_dim(2);\n}\n\n/// Test if the function works with an empty slice\n#[test]\nfn should_squeeze_dims_with_empty_slice() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 3]), &Default::default());\n    let squeezed_tensor: TestTensor<1> = tensor.squeeze_dims(&[]);\n    let expected_shape = Shape::new([3]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n#[test]\nfn should_squeeze_all_dims() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 3, 1]), &Default::default());\n    let squeezed_tensor: TestTensor<1> = tensor.squeeze();\n    let expected_shape = Shape::new([3]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function works with positive indices\n#[test]\nfn should_squeeze_dims_with_positive_indices() {\n    let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());\n    let squeezed_tensor: TestTensor<2> = tensor.squeeze_dims(&[0, 2]);\n    let expected_shape = Shape::new([3, 5]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function works with negative indices\n#[test]\nfn should_squeeze_dims_with_negative_indices() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 1, 3, 1]), &Default::default());\n    let squeezed_tensor: TestTensor<2> = tensor.squeeze_dims(&[-3, -1]);\n    let expected_shape = Shape::new([2, 3]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n/// Test to make sure the function panics if a non-singleton dimension is squeezed\n#[test]\n#[should_panic]\nfn should_squeeze_dims_work_if_non_singleton() {\n    let tensor = TestTensor::<3>::ones(Shape::new([2, 3, 4]), &Default::default());\n    let squeezed_tensor: TestTensor<3> = tensor.squeeze_dims(&[1]);\n    let expected_shape = Shape::new([2, 3, 4]);\n    assert_eq!(squeezed_tensor.shape(), expected_shape);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_squeeze_consumes_all_singleton() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 3, 1]), &Default::default());\n    let _squeezed_tensor: TestTensor<2> = tensor.squeeze(); // output rank should be 1\n}\n\n/// Test to make sure the function panics if too many dimensions are requested to be squeezed\n#[test]\n#[should_panic]\nfn should_squeeze_dims_panic_on_too_many_dimensions() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default());\n    let _: TestTensor<1> = tensor.squeeze_dims(&[0, 1, 2]);\n}\n\n/// Test to make sure function panics if dimensions are mismatched\n#[test]\n#[should_panic]\nfn should_squeeze_dims_dimension_mismatch_panic() {\n    let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());\n    let _: TestTensor<3> = tensor.squeeze_dims(&[0, 2]);\n}\n\n/// Test if the function can successfully unsqueeze the size 1 dimension at the specified position of a 3D tensor.\n#[test]\nfn should_unsqueeze_dim() {\n    let tensor = TestTensor::<3>::ones(Shape::new([2, 4, 1]), &Default::default());\n    let unsqueezed_tensor: TestTensor<4> = tensor.unsqueeze_dim(1);\n    let expected_shape = Shape::new([2, 1, 4, 1]);\n    assert_eq!(unsqueezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully unsqueeze the first size 1 dimension of a 4D tensor.\n#[test]\nfn should_unsqueeze_dim_first() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(0);\n    let expected_shape = Shape::new([1, 2, 3, 4, 5]);\n    assert_eq!(unsqueezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function can successfully unsqueeze the last size 1 dimension of a 4D tensor.\n#[test]\nfn should_unsqueeze_dim_last() {\n    let tensor = TestTensor::<4>::ones(Shape::new([5, 4, 3, 2]), &Default::default());\n    let unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(4);\n    let expected_shape = Shape::new([5, 4, 3, 2, 1]);\n    assert_eq!(unsqueezed_tensor.shape(), expected_shape);\n}\n\n/// Test if the function panics when the unsqueezed dimension is out of bounds.\n#[test]\n#[should_panic]\nfn should_unsqueeze_dim_panic() {\n    let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());\n    let _unsqueezed_tensor: TestTensor<5> = tensor.unsqueeze_dim(5);\n}\n\n#[test]\nfn should_unsqueeze_dims_support_dim_inference() {\n    let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());\n    let output_tensor = input_tensor.unsqueeze_dims::<5>(&[1, -2]);\n    let expected_shape = Shape::new([3, 1, 4, 1, 5]);\n    assert_eq!(output_tensor.shape(), expected_shape);\n}\n\n#[test]\nfn should_unsqueeze_dims_handle_first_last() {\n    let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());\n    let output_tensor = input_tensor.unsqueeze_dims::<5>(&[0, 4]);\n    let expected_shape = Shape::new([1, 3, 4, 5, 1]);\n    assert_eq!(output_tensor.shape(), expected_shape);\n}\n\n#[test]\nfn should_unsqueeze_dims_work_with_single_dim() {\n    //bruh, just call unsqueeze_dim\n    let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());\n    let output_tensor: TestTensor<4> = input_tensor.unsqueeze_dims(&[1]);\n    let expected_shape = Shape::new([3, 1, 4, 5]);\n    assert_eq!(output_tensor.shape(), expected_shape);\n}\n\n#[test]\nfn should_unsqueeze_dims_multiple_trailing_negatives() {\n    let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());\n    let output_tensor: TestTensor<6> = input_tensor.unsqueeze_dims(&[0, -1, -1]);\n    let expected_shape = Shape::new([1, 3, 4, 5, 1, 1]);\n    assert_eq!(output_tensor.shape(), expected_shape);\n}\n\n#[test]\n#[should_panic]\nfn should_unsqueeze_dims_panic() {\n    let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());\n    let _output_tensor: TestTensor<5> = input_tensor.unsqueeze_dims(&[0, -6]);\n}\n\n#[test]\n#[should_panic]\nfn squeeze_all_singleton_not_supported() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default());\n    let _ = tensor.squeeze::<0>();\n}\n\n#[test]\n#[should_panic]\nfn squeeze_dim_singleton_not_supported() {\n    let tensor = TestTensor::<1>::ones(Shape::new([1]), &Default::default());\n    let _ = tensor.squeeze_dim::<0>(0);\n}\n\n#[test]\n#[should_panic]\nfn squeeze_dims_all_singleton_not_supported() {\n    let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default());\n    let _ = tensor.squeeze_dims::<0>(&[0, 1, 2]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/stack.rs",
    "content": "use super::*;\nuse alloc::{vec, vec::Vec};\nuse burn_tensor::{Tensor, TensorData};\n\n#[test]\nfn should_support_stack_ops_2d_dim0() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_stack_ops_2d_dim1() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0, 6.0]], &device);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 1);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_stack_ops_3d() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);\n    let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]], [[4.1, 5.1, 6.1]]], &device);\n\n    let output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([\n        [[[1.0000, 2.0000, 3.0000]], [[1.1000, 2.1000, 3.1000]]],\n        [[[4.0000, 5.0000, 6.0000]], [[4.1000, 5.1000, 6.1000]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_dimensions_are_not_the_same() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], &device);\n    let tensor_2 = TestTensor::from_data([[4.0, 5.0]], &device);\n\n    let _output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_list_of_vectors_is_empty() {\n    let tensors: Vec<TestTensor<2>> = vec![];\n    let _output = Tensor::stack::<3>(tensors, 0);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_stack_exceeds_dimension() {\n    let device = Default::default();\n    let tensor_1 = TestTensor::<3>::from_data([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]], &device);\n    let tensor_2 = TestTensor::from_data([[[4.0, 5.0, 6.0]]], &device);\n\n    let _output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 3);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/sub.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_sub_ops() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data_2 = TensorData::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sub_broadcast() {\n    let data_1 = TensorData::from([[0.0, 1.0, 2.0]]);\n    let data_2 = TensorData::from([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n    let device = Default::default();\n    let tensor_1 = TestTensor::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensor::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_sub_scalar_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let scalar = 2.0;\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor - scalar;\n    let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/take.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_take_1d() {\n    // Test that take works with 1D indices\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([0.0, 1.0, 2.0], &device);\n    let indices = TestTensorInt::<1>::from_data([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.take::<1, 1>(0, indices);\n    let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_2d_dim0() {\n    // Test take on 2D tensor along dimension 0\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([1, 0, 1, 1], &device);\n\n    let output = tensor.take::<1, 2>(0, indices);\n    let expected = TensorData::from([\n        [3.0, 4.0, 5.0],\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [3.0, 4.0, 5.0],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_2d_dim1() {\n    // Test take on 2D tensor along dimension 1\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([2, 0, 1], &device);\n\n    let output = tensor.take::<1, 2>(1, indices);\n    let expected = TensorData::from([[2.0, 0.0, 1.0], [5.0, 3.0, 4.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn take_and_select_should_be_equivalent() {\n    // Verify that take and select produce identical results\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n    let indices = TestTensorInt::<1>::from_data([2, 0, 1, 1], &device);\n\n    let result_take = tensor.clone().take::<1, 2>(0, indices.clone());\n    let result_select = tensor.select(0, indices);\n\n    let take_data = result_take.into_data();\n    let select_data = result_select.into_data();\n\n    take_data.assert_eq(&select_data, false);\n}\n\n#[test]\nfn should_take_with_2d_indices() {\n    // Test take with 2D indices - output will be 3D with shape [2, 2, 4]\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [5.0, 6.0, 7.0, 8.0],\n            [9.0, 10.0, 11.0, 12.0],\n        ],\n        &device,\n    );\n\n    // 2D indices to select along dimension 0 - shape [2, 2]\n    let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device);\n    let output = tensor.take::<2, 3>(0, indices);\n\n    // Expected: shape [2, 2, 4] - indices shape replaces dim 0\n    let expected = TensorData::from([\n        [[1.0, 2.0, 3.0, 4.0], [9.0, 10.0, 11.0, 12.0]],\n        [[5.0, 6.0, 7.0, 8.0], [1.0, 2.0, 3.0, 4.0]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_with_2d_indices_dim1() {\n    // Test take with 2D indices along dimension 1 - output will be 3D with shape [2, 2, 2]\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], &device);\n\n    // 2D indices to select along dimension 1 - shape [2, 2]\n    let indices = TestTensorInt::<2>::from_data([[0, 3], [2, 1]], &device);\n    let output = tensor.take::<2, 3>(1, indices);\n\n    // Expected: shape [2, 2, 2] - indices shape replaces dim 1\n    let expected = TensorData::from([[[1.0, 4.0], [3.0, 2.0]], [[5.0, 8.0], [7.0, 6.0]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_3d_tensor() {\n    // Test take with 3D tensor - output will be 4D with shape [2, 2, 2, 2]\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(\n        [\n            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],\n            [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],\n        ],\n        &device,\n    );\n\n    // 2D indices to select along dimension 1 - shape [2, 2]\n    let indices = TestTensorInt::<2>::from_data([[0, 2], [1, 0]], &device);\n    let output = tensor.take::<2, 4>(1, indices);\n\n    // Expected: shape [2, 2, 2, 2] - indices shape replaces dim 1\n    let expected = TensorData::from([\n        [[[1.0, 2.0], [5.0, 6.0]], [[3.0, 4.0], [1.0, 2.0]]],\n        [[[7.0, 8.0], [11.0, 12.0]], [[9.0, 10.0], [7.0, 8.0]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_with_3d_indices() {\n    // Test take with 3D indices - output will be 4D\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n\n    // 3D indices to select along dimension 1 - shape [2, 2, 2]\n    let indices = TestTensorInt::<3>::from_data([[[0, 2], [1, 0]], [[2, 1], [0, 2]]], &device);\n    let output = tensor.take::<3, 4>(1, indices);\n\n    // Expected: shape [2, 2, 2, 2] - indices shape replaces dim 1\n    let expected = TensorData::from([\n        [[[1.0, 3.0], [2.0, 1.0]], [[3.0, 2.0], [1.0, 3.0]]],\n        [[[4.0, 6.0], [5.0, 4.0]], [[6.0, 5.0], [4.0, 6.0]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_take_invalid_dimension() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([1, 0], &device);\n\n    // This should panic because dimension 10 is out of bounds\n    tensor.take::<1, 2>(10, indices);\n}\n\n#[test]\nfn should_take_with_single_index() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([1], &device);\n\n    let output = tensor.take::<1, 2>(0, indices);\n    let expected = TensorData::from([[4.0, 5.0, 6.0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_with_negative_dim_2d() {\n    // Test using negative dimension indexing on 2D tensor\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([2, 0, 1], &device);\n\n    // Using -1 should refer to the last dimension (dim 1)\n    let output_neg = tensor.clone().take::<1, 2>(-1, indices.clone());\n    let output_pos = tensor.take::<1, 2>(1, indices);\n\n    // Both should produce the same result\n    let neg_data = output_neg.into_data();\n    let pos_data = output_pos.into_data();\n    neg_data.assert_eq(&pos_data, false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_take_negative_dim_out_of_bounds() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n    let indices = TestTensorInt::<1>::from_data([0, 1], &device);\n\n    // This should panic because -3 is out of bounds for a 2D tensor\n    tensor.take::<1, 2>(-3, indices);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/topk.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_topk_with_indices_3d() {\n    let tensor =\n        TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);\n\n    let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2);\n\n    let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);\n\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::default());\n\n    let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]);\n\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/transaction.rs",
    "content": "use super::*;\nuse burn_tensor::Transaction;\n\n// https://github.com/tracel-ai/burn/issues/4021\n#[test]\nfn should_support_transaction() {\n    let rows = 261120;\n    let cols = 408;\n\n    let device = Default::default();\n\n    let j = TestTensor::<2>::zeros([rows, cols], &device);\n    let jt = j.clone().transpose();\n\n    let g = jt.matmul(j);\n\n    let g = g.transpose();\n    let expected = g.to_data();\n\n    assert_eq!(g.shape().dims(), [cols, cols]);\n\n    // Fails\n    let [data] = Transaction::default()\n        .register(g)\n        .execute()\n        .try_into()\n        .unwrap();\n\n    // check byte equality\n    assert_eq!(data, expected);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/transpose.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_transpose_ops() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    // Check the .t() alias.\n    let output = tensor.t();\n\n    let expected = TensorData::from([\n        [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]],\n        [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_transpose_maybe_fused_with_one() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n    let ones = TestTensor::<3>::ones([1, 1, 1], &Default::default());\n\n    let output = tensor.transpose();\n    let expected = TensorData::from([\n        [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]],\n        [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]],\n    ]);\n    let expected_ones = TensorData::from([[[1.0]]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n    ones.into_data()\n        .assert_approx_eq::<FloatElem>(&expected_ones, Tolerance::default());\n}\n\n#[test]\nfn should_support_swap_dims_no_op() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.swap_dims(0, 0);\n    let expected = TensorData::from([\n        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n        [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_swap_dims() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.swap_dims(0, 2);\n    let expected = TensorData::from([\n        [[0.0, 6.0], [3.0, 9.0]],\n        [[1.0, 7.0], [4.0, 10.0]],\n        [[2.0, 8.0], [5.0, 11.0]],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_swap_dims_neg_index() {\n    let tensor = TestTensor::<3>::from_floats(\n        [\n            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n        ],\n        &Default::default(),\n    );\n\n    let output = tensor.swap_dims(-3, -1);\n    let expected = TensorData::from([\n        [[0.0, 6.0], [3.0, 9.0]],\n        [[1.0, 7.0], [4.0, 10.0]],\n        [[2.0, 8.0], [5.0, 11.0]],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/tri.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_triu() {\n    let tensor = TestTensor::<2>::from([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]);\n    let output = tensor.triu(0);\n    let expected = TensorData::from([[1., 1., 1.], [0., 1., 1.], [0., 0., 1.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_triu_positive_diagonal() {\n    let tensor = TestTensor::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]);\n\n    let output = tensor.triu(1);\n    let expected = TensorData::from([[0, 1, 1], [0, 0, 1], [0, 0, 0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/trig.rs",
    "content": "#![allow(clippy::approx_constant)]\n\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse core::f32::consts::{FRAC_PI_2, FRAC_PI_3, FRAC_PI_4, FRAC_PI_6, FRAC_PI_8, PI};\n\n#[test]\nfn should_support_cos_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.cos();\n    let expected = TensorData::from([[1.0, 0.54030, -0.41615], [-0.98999, -0.65364, 0.28366]]);\n\n    // Metal has less precise trigonometric functions\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-2);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn should_support_cosh_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.cosh();\n    let expected = TensorData::from([[1.0000, 1.5431, 3.7622], [10.0677, 27.3082, 74.2099]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_sin_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.sin();\n    let expected = TensorData::from([[0.0, 0.841471, 0.909297], [0.141120, -0.756802, -0.958924]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_sinh_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.sinh();\n    let expected = TensorData::from([[0.0000, 1.1752, 3.6269], [10.0179, 27.2899, 74.2032]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_tan_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.tan();\n    let expected = TensorData::from([[0.0, 1.557408, -2.185040], [-0.142547, 1.157821, -3.380515]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_tanh_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.tanh();\n    let expected = TensorData::from([[0.0, 0.761594, 0.964028], [0.995055, 0.999329, 0.999909]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_asin_ops() {\n    let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -1.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.asin();\n    let expected = TensorData::from([[0.0, 0.523599, 0.785398], [-0.523599, -0.785398, -1.570796]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_acos_ops() {\n    let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -1.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.acos();\n    let expected = TensorData::from([\n        [1.570796, 1.047198, 0.785398],\n        [2.094395, 2.356194, 3.141593],\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_atan_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.atan();\n    let expected = TensorData::from([[0.0, 0.785398, 1.107149], [1.249046, 1.325818, 1.373401]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_asinh_ops() {\n    let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.asinh();\n    let expected = TensorData::from([[0.0, 0.881374, 1.443635], [1.818446, 2.094713, 2.312438]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_acosh_ops() {\n    let data = TensorData::from([[1.0, 1.5, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.acosh();\n    let expected = TensorData::from([[0.0, 0.962424, 1.316958], [1.762747, 2.063437, 2.292432]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_atanh_ops() {\n    let data = TensorData::from([[0.0, 0.5, 0.707107], [-0.5, -0.707107, -0.9]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.atanh();\n    let expected = TensorData::from([[0.0, 0.549306, 0.881374], [-0.549306, -0.881374, -1.472219]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_atan2_ops() {\n    let y = TensorData::from([[0.0, 1.0, 1.0], [-1.0, -1.0, 0.0]]);\n    let x = TensorData::from([[1.0, 1.0, 0.0], [1.0, 0.0, -1.0]]);\n\n    let y_tensor = TestTensor::<2>::from_data(y, &Default::default());\n    let x_tensor = TestTensor::<2>::from_data(x, &Default::default());\n\n    let output = y_tensor.atan2(x_tensor);\n    let expected = TensorData::from([[0.0, 0.785398, 1.570796], [-0.785398, -1.570796, 3.141593]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_deg2rad_ops() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats(\n        [\n            0.0, 22.5, 30.0, 45.0, 60.0, 90.0, 135.0, 180.0, 270.0, 360.0,\n        ],\n        &device,\n    );\n\n    let output = tensor.deg2rad();\n    let expected = TensorData::from([\n        0.0f32,\n        FRAC_PI_8,\n        FRAC_PI_6,\n        FRAC_PI_4,\n        FRAC_PI_3,\n        FRAC_PI_2,\n        0.75 * PI,\n        PI,\n        1.5 * PI,\n        2.0 * PI,\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_support_rad2deg_ops() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats(\n        [\n            0.0,\n            FRAC_PI_8,\n            FRAC_PI_6,\n            FRAC_PI_4,\n            FRAC_PI_3,\n            FRAC_PI_2,\n            PI,\n            1.5 * PI,\n            2.0 * PI,\n            -FRAC_PI_3,\n        ],\n        &device,\n    );\n\n    let output = tensor.rad2deg();\n    let expected = TensorData::from([\n        0.0f32, 22.5, 30.0, 45.0, 60.0, 90.0, 180.0, 270.0, 360.0, -60.0,\n    ]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/trunc.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{ElementConversion, TensorData};\n\n#[test]\nfn should_support_trunc_ops() {\n    let data = TensorData::from([[2.3, -1.7, 0.5], [-0.5, 3.9, -4.2]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.trunc();\n    let expected = TensorData::from([[2.0, -1.0, 0.0], [0.0, 3.0, -4.0]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_truncate_positive_values_like_floor() {\n    let data = TensorData::from([1.7, 2.9, 3.1, 4.5]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.trunc();\n    let expected = TensorData::from([1.0, 2.0, 3.0, 4.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_truncate_negative_values_like_ceil() {\n    let data = TensorData::from([-1.7, -2.9, -3.1, -4.5]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.trunc();\n    let expected = TensorData::from([-1.0, -2.0, -3.0, -4.0]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn should_handle_special_cases() {\n    // Test special IEEE 754 cases\n    let data = TensorData::from([0.0, -0.0, f32::INFINITY, f32::NEG_INFINITY, f32::NAN]);\n    let tensor = TestTensor::<1>::from_data(data, &Default::default());\n\n    let output = tensor.trunc();\n    let values = output.into_data().as_slice::<FloatElem>().unwrap().to_vec();\n\n    // Check positive zero\n    assert_eq!(values[0], 0.0f32.elem::<FloatElem>());\n    assert!(values[0].is_sign_positive());\n\n    // Check negative zero is preserved\n    assert_eq!(values[1], 0.0f32.elem::<FloatElem>());\n    assert!(values[1].is_sign_negative());\n\n    // Check infinity is preserved\n    assert!(values[2].is_infinite() && values[2].is_sign_positive());\n    assert!(values[3].is_infinite() && values[3].is_sign_negative());\n\n    // Check NaN is preserved\n    assert!(values[4].is_nan());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/ops/unfold.rs",
    "content": "use super::*;\nuse burn_tensor::Distribution;\nuse burn_tensor::s;\n\n#[test]\nfn test_unfold_float() {\n    let device = Default::default();\n\n    let input = TestTensor::<3>::random([2, 6, 6], Distribution::Default, &device);\n\n    let dim = 1;\n    let size = 3;\n    let step = 2;\n    let actual: TestTensor<4> = input.clone().unfold(dim, size, step);\n\n    let expected = TestTensor::<4>::empty([2, 2, 6, 3], &device)\n        .slice_assign(\n            s![.., 0, .., ..],\n            input\n                .clone()\n                .slice(s![.., 0..3, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        )\n        .slice_assign(\n            s![.., 1, .., ..],\n            input\n                .clone()\n                .slice(s![.., 2..5, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        );\n\n    actual.to_data().assert_eq(&expected.to_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/primitive.rs",
    "content": "use super::*;\nuse burn_tensor::{Element, Shape};\n\n#[test]\nfn should_support_float_dtype() {\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]).into_primitive();\n\n    assert_eq!(\n        burn_tensor::TensorMetadata::shape(&tensor),\n        Shape::new([2, 3])\n    );\n    assert_eq!(\n        burn_tensor::TensorMetadata::dtype(&tensor),\n        FloatElem::dtype() // default float elem type\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/calibration.rs",
    "content": "use super::*;\nuse burn_tensor::{\n    TensorData,\n    ops::QuantizedTensor,\n    quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantValue, compute_range},\n};\n\n// NOTE: The scheme variant fields are not important for calibration, only the \"main\" variant (e.g., per-tensor)\n#[test]\nfn min_max_calibration_range_per_tensor() {\n    let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default());\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let range = compute_range(&scheme, &tensor, &Calibration::MinMax);\n\n    range\n        .min\n        .into_data()\n        .assert_eq(&TensorData::from([-1.8]), false);\n    range\n        .max\n        .into_data()\n        .assert_eq(&TensorData::from([0.5]), false);\n}\n\n#[test]\nfn min_max_calibration_range_per_block() {\n    let tensor = TestTensor::<2>::from_floats(\n        [\n            [-1.8, -1.0, 0.0, 0.5],\n            [1.8, 1.0, 0.0, -0.5],\n            [0.01, 0.02, 0.03, 0.04],\n            [-0.01, -0.02, -0.03, -0.04],\n        ],\n        &Default::default(),\n    );\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([4]));\n\n    let range = compute_range(&scheme, &tensor, &Calibration::MinMax);\n\n    range\n        .min\n        .into_data()\n        .assert_eq(&TensorData::from([[-1.8], [-0.5], [0.01], [-0.04]]), false);\n    range\n        .max\n        .into_data()\n        .assert_eq(&TensorData::from([[0.5], [1.8], [0.04], [-0.01]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/data.rs",
    "content": "use super::*;\nuse alloc::vec;\nuse burn_tensor::quantization::{QTensorPrimitive, QuantLevel, QuantValue};\nuse burn_tensor::{TensorData, ops::QuantizedTensor};\n\n#[test]\nfn should_support_per_tensor_symmetric_int8() {\n    let data = TensorData::quantized(\n        vec![-127i8, -71, 0, 35],\n        [4],\n        QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S),\n        &[0.014_173_228],\n    );\n    let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());\n\n    let q_data = tensor.into_data();\n    q_data.assert_eq(&data, true);\n\n    let tensor = TestTensor::<1>::from_data(q_data.clone(), &Default::default());\n\n    tensor.into_data().assert_eq(&q_data, true);\n}\n\n#[test]\nfn should_support_per_block_symmetric_int8() {\n    let data = TensorData::quantized(\n        vec![\n            -127i8, -71, 0, 35, -127i8, -71, 0, 35, -32, -63, -95, -127, -32, -63, -95, -127,\n        ],\n        [16],\n        QuantizedTensor::<TestBackend>::default_scheme()\n            .with_value(QuantValue::Q8S)\n            .with_level(QuantLevel::block([8])),\n        &[0.014_173_228, 0.000_314_96],\n    );\n    let tensor = TestTensor::<1>::from_data(data.clone(), &Default::default());\n\n    tensor.into_data().assert_eq(&data, true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod calibration;\nmod data;\nmod ops;\nmod scheme;\n\n/// Quantized tensor utilities\npub mod qtensor {\n    use core::marker::PhantomData;\n\n    use burn_tensor::quantization::QuantLevel;\n\n    use burn_tensor::{\n        Tensor, TensorData,\n        backend::Backend,\n        quantization::{QTensorPrimitive, QuantValue},\n    };\n\n    pub struct QTensor<B: Backend, const D: usize> {\n        b: PhantomData<B>,\n    }\n\n    impl<B: Backend, const D: usize> QTensor<B, D> {\n        /// Creates a quantized int8 tensor from the floating point data using the default quantization scheme\n        /// (i.e., per-tensor symmetric quantization).\n        pub fn int8<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {\n            Self::int8_symmetric(floats)\n        }\n\n        /// Creates a quantized int8 tensor from the floating point data using blocks of size 16\n        pub fn int8_block<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {\n            Tensor::from_floats(floats, &Default::default()).quantize_dynamic(\n                &<B::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()\n                    .with_value(QuantValue::Q8S)\n                    .with_level(QuantLevel::block([16])),\n            )\n        }\n\n        /// Creates a quantized int8 tensor from the floating point data using per-tensor symmetric quantization.\n        pub fn int8_symmetric<F: Into<TensorData>>(floats: F) -> Tensor<B, D> {\n            Tensor::from_floats(floats, &Default::default()).quantize_dynamic(\n                &<B::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()\n                    .with_value(QuantValue::Q8S),\n            )\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/abs.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_abs_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);\n\n    let output = tensor.abs();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/add.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_add_d2() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_add_broadcast() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[3.0, 5.0, 7.0], [6.0, 8.0, 10.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_add_different_strides_rhs() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0], [6.0, 7.0]]) * 1;\n\n    let output = tensor_1 + tensor_2.transpose();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[4.0, 7.0], [7.0, 10.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_add_different_strides_lhs() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0], [6.0, 7.0]]) * 1;\n\n    let output = tensor_1.transpose() + tensor_2;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[4.0, 7.0], [7.0, 10.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_add_different_strides_broadcast() {\n    // We need to execute an operation after `from data` to trigger inplace in some backends.\n    // Which is the operation that might be problematic in this case.\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0], [2.0, 3.0]]) * 1;\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0]]) * 1;\n\n    let output = tensor_1.transpose() + tensor_2;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[4.0, 7.0], [5.0, 8.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_support_add_scalar_ops() {\n    let scalar = 2.0;\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor + scalar;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/aggregation.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_should_mean() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.mean();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([15.0 / 6.0]), Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_should_sum() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sum();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([15.0]), Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_should_mean_last_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.mean_dim(1);\n    let expected = TensorData::from([[3.0 / 3.0], [12.0 / 3.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_should_sum_last_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sum_dim(1);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[3.0], [12.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_should_sum_first_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);\n\n    let output = tensor.sum_dim(0);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[7.0, 3.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_should_mean_first_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]]);\n\n    let output = tensor.mean_dim(0);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[7.0 / 2.0, 3.0 / 2.0, 5.0 / 2.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_should_sum_mid_dim_3d_non_contiguous_1() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],\n        [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],\n    ]);\n\n    let output = tensor.swap_dims(0, 2).sum_dim(1);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::new(vec![9.0, 7.0, -1.0, 3.0, 4.0, 5.0], [3, 1, 2]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_should_sum_mid_dim_3d_non_contiguous_2() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[2.0, 4.0, 1.0], [7.0, -5.0, 3.0]],\n        [[3.0, 1.0, 2.0], [4.0, 2.0, 3.0]],\n    ]);\n\n    let output = tensor.swap_dims(0, 1).sum_dim(1);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::new(vec![5.0, 5.0, 3.0, 11.0, -3.0, 6.0], [2, 1, 3]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn test_prod_float() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.prod();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([240.0]), Tolerance::rel_abs(1e-1, 1e-1));\n\n    let tensor_with_zero = QTensor::<TestBackend, 2>::int8([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor_with_zero.prod();\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([0.0]), Tolerance::rel_abs(1e-1, 1e-1));\n}\n\n#[test]\nfn test_prod_dim_float() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[2.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.prod_dim(1);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[4.0], [60.0]]),\n            Tolerance::absolute(1e-1),\n        );\n\n    let tensor_with_zero = QTensor::<TestBackend, 2>::int8([[2.0, 0.0, 2.0], [3.0, 4.0, 5.0]]);\n    let output = tensor_with_zero.prod_dim(1);\n    let expected = TensorData::from([[0.0], [60.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/all.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_all() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([false]);\n    assert_eq!(data_expected, data_actual);\n\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([true]);\n    assert_eq!(data_expected, data_actual);\n}\n\n#[test]\nfn test_all_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 0.0], [1.0, -1.0, 1.0]]);\n    let data_actual = tensor.all_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    assert_eq!(data_expected, data_actual);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/any.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_any() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    assert_eq!(data_expected, data_actual);\n\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    assert_eq!(data_expected, data_actual);\n}\n\n#[test]\nfn test_any_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]);\n\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    assert_eq!(data_expected, data_actual);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/arg.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_argmax_2d_dim0() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.argmax(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 0, 1]]), false);\n}\n\n#[test]\nfn test_argmin_2d_dim0() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);\n\n    let output = tensor.argmin(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 0]]), false);\n}\n\n#[test]\nfn test_argmax_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.argmax(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1], [2]]), false);\n}\n\n#[test]\nfn test_argmin_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);\n\n    let output = tensor.argmin(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2], [1]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cat.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse alloc::vec;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_cat_ops_2d_dim0() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0, 6.0]]);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_cat_ops_2d_dim1() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0, 6.0]]);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 1);\n    let expected = TensorData::from([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_cat_ops_3d() {\n    let tensor_1 = QTensor::<TestBackend, 3>::int8([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]]]);\n    let tensor_2 = QTensor::<TestBackend, 3>::int8([[[4.0, 5.0, 6.0]]]);\n\n    let output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[1.1, 2.1, 3.1]], [[4.0, 5.0, 6.0]]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_dimensions_are_not_the_same() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0]]);\n\n    let _output = TestTensor::cat(vec![tensor_1, tensor_2], 0);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_cat_exceeds_dimension() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0, 6.0]]);\n\n    let _output = TestTensor::cat(vec![tensor_1, tensor_2], 3);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/ceil.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_ceil_ops() {\n    let tensor =\n        QTensor::<TestBackend, 2>::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n\n    let output = tensor.ceil();\n    let expected = TensorData::from([[25., 88., 77.], [60., 44., 96.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(1e-1, 1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/chunk.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse alloc::vec::Vec;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_chunk_evenly_divisible() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let tensors: Vec<TestTensor<1>> = tensor.chunk(3, 0);\n    assert_eq!(tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([0., 1.]),\n        TensorData::from([2., 3.]),\n        TensorData::from([4., 5.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_chunk_not_evenly_divisible() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);\n\n    let tensors: Vec<TestTensor<1>> = tensor.chunk(4, 0);\n    assert_eq!(tensors.len(), 4);\n\n    let expected = [\n        TensorData::from([0., 1.]),\n        TensorData::from([2., 3.]),\n        TensorData::from([4., 5.]),\n        TensorData::from([6.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_chunk_not_divisible() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let tensors: Vec<TestTensor<1>> = tensor.chunk(7, 0);\n    assert_eq!(tensors.len(), 6);\n\n    let expected = [\n        TensorData::from([0.]),\n        TensorData::from([1.]),\n        TensorData::from([2.]),\n        TensorData::from([3.]),\n        TensorData::from([4.]),\n        TensorData::from([5.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_chunk_multi_dimension() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]);\n\n    let tensors: Vec<TestTensor<2>> = tensor.chunk(2, 1);\n    assert_eq!(tensors.len(), 2);\n\n    let expected = [\n        TensorData::from([[0., 1., 2.]]),\n        TensorData::from([[3., 4., 5.]]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\n#[should_panic]\nfn test_invalid_dim() {\n    let _tensors = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).chunk(6, 1);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/clamp.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn clamp_min() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.clamp_min(2.0);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[2.0, 2.0, 2.0], [3.0, 4.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn clamp_max() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.clamp_max(2.0);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn clamp_min_max() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.clamp(1.0, 4.0);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[1.0, 1.0, 2.0], [3.0, 4.0, 4.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cos.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_cos_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.cos();\n    let expected = TensorData::from([[1.0, 0.5403, -0.4161], [-0.9899, -0.6536, 0.2836]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/cosh.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_cosh_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.cosh();\n    let expected = TensorData::from([[1.0000, 1.5431, 3.7622], [10.0677, 27.3082, 74.2100]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/div.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_div_ops() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor_1 / tensor_2;\n    let expected = TensorData::from([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_div_broadcast() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor_1 / tensor_2;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 1.0, 1.0], [0.0, 0.25, 0.4]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_support_div_scalar_ops() {\n    let scalar = 2.0;\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor / scalar;\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/erf.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_erf_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.erf();\n    let expected = TensorData::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_erf_ops_with_negative_number() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.erf();\n    let expected = TensorData::from([\n        [-0.06312324, -0.048490416, -0.10016122],\n        [1.0000, 1.0000, 1.0000],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/exp.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_exp_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.exp();\n    let expected = TensorData::from([[1.0, 2.71830, 7.3891], [20.0855, 54.5981, 148.4132]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/expand.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn expand_2d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0]);\n    let output = tensor.expand([3, 3]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]),\n            Tolerance::absolute(1e-1),\n        );\n\n    // Quantized [4.0, 7.0, 2.0, 3.0]\n    let tensor = QTensor::<TestBackend, 1>::int8([4.0, 7.0, 2.0, 3.0]);\n    let output = tensor.expand([2, 4]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[4.0, 7.0, 2.0, 3.0], [4.0, 7.0, 2.0, 3.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn expand_3d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 2.0], [3.0, 4.0]]);\n\n    let output = tensor.expand([3, 2, 2]);\n    let expected = TensorData::from([\n        [[1.0, 2.0], [3.0, 4.0]],\n        [[1.0, 2.0], [3.0, 4.0]],\n        [[1.0, 2.0], [3.0, 4.0]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn expand_higher_dimensions() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0, 4.0]]);\n\n    let output = tensor.expand([2, 3, 4]);\n    let expected = TensorData::from([\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n        ],\n        [\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n            [1.0, 2.0, 3.0, 4.0],\n        ],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn broadcast_single() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0]);\n\n    let output = tensor.expand([2, 3]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_eq(&TensorData::from([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), false);\n}\n\n#[test]\n#[should_panic]\nfn should_fail_expand_incompatible_shapes() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0]);\n    let _expanded_tensor = tensor.expand([2, 2]);\n}\n\n#[test]\nfn should_all_negative_one() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0]);\n\n    let output = tensor.expand([2, -1]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[1., 2., 3.], [1., 2., 3.]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\n#[should_panic]\nfn should_panic_negative_one_on_non_existing_dim() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0]);\n    let _expanded_tensor = tensor.expand([-1, 3]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/flip.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn flip_float() {\n    let tensor = QTensor::<TestBackend, 3>::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]);\n\n    let flipped = tensor.clone().flip([0, 2]);\n    let expected = TensorData::from([[[5., 4., 3.]], [[2., 1., 0.]]]);\n\n    flipped\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n\n    // Test with no flip\n    let flipped = tensor.clone().flip([]);\n    tensor.into_data().assert_eq(&flipped.into_data(), true);\n}\n\n#[test]\n#[should_panic]\nfn flip_duplicated_axes() {\n    let tensor = QTensor::<TestBackend, 3>::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]);\n\n    // Test with a duplicated axis\n    let _ = tensor.flip([0, 0, 1]);\n}\n\n#[test]\n#[should_panic]\nfn flip_out_of_bound_axis() {\n    let tensor = QTensor::<TestBackend, 3>::int8([[[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]]);\n\n    // Test with an out of bound axis\n    let _ = tensor.clone().flip([3, 0, 1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/floor.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_floor_ops() {\n    let tensor =\n        QTensor::<TestBackend, 2>::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n\n    let output = tensor.floor();\n    let expected = TensorData::from([[24., 87., 76.], [59., 43., 95.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(1e-1, 1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/gather_scatter.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::IndexingUpdateOp;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_gather_1d_dim0() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &Default::default());\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_gather_2d_dim0() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]], &Default::default());\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_gather_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]], &Default::default());\n\n    let output = tensor.gather(1, indices);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_gather_3d_dim1() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n        [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n    ]);\n    let indices = TestTensorInt::from_ints(\n        [[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]],\n        &Default::default(),\n    );\n\n    let output = tensor.gather(1, indices);\n    let expected = TensorData::from([\n        [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]],\n        [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_gather_2d_only_1dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::<2>::from_ints([[1, 2]], &Default::default()).reshape([2, 1]);\n\n    let output = tensor.gather(1, indices);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[1.0], [5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_scatter_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 0.0, 0.0]);\n    let values = QTensor::<TestBackend, 1>::int8([5.0, 4.0, 3.0]);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &Default::default());\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([4.0, 5.0, 3.0]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_scatter_2d_dim0() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);\n    let values = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n    let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]], &Default::default());\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_scatter_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);\n    let values = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);\n    let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]], &Default::default());\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\nfn should_scatter_3d_dim1() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],\n        [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],\n    ]);\n    let values = QTensor::<TestBackend, 3>::int8([\n        [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],\n        [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],\n    ]);\n    let indices = TestTensorInt::from_ints(\n        [[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]],\n        &Default::default(),\n    );\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([\n        [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]],\n        [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]],\n    ]);\n\n    // Set higher tolerance (0.2) due to larger de/quantization errors\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(2e-1));\n}\n\n#[test]\nfn should_scatter_2d_dim1_diff_shape() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);\n    let values = QTensor::<TestBackend, 2>::int8([[1.0], [4.0]]);\n    let indices = TestTensorInt::from_ints([[1], [2]], &Default::default());\n\n    let output = tensor.scatter(1, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &TensorData::from([[0.0, 1.0, 0.0], [0.0, 0.0, 4.0]]),\n            Tolerance::absolute(1e-1),\n        );\n}\n\n#[test]\n#[should_panic]\nfn scatter_should_panic_on_mismatch_of_shapes() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 0.0, 0.0]);\n    let values = QTensor::<TestBackend, 1>::int8([1.0, 4.0]);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &Default::default());\n\n    tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_log_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.log();\n    let expected = TensorData::from([\n        [-f32::INFINITY, 0.0, core::f32::consts::LN_2],\n        [1.0986, 1.3862, 1.6094],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/log1p.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_exp_log1p() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.log1p();\n    let expected = TensorData::from([\n        [0.0, core::f32::consts::LN_2, 1.0986],\n        [1.3862, 1.6094, 1.7917],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/map_comparison.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_equal() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_not_equal() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.not_equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\n#[ignore = \"quantization equality with float element is undefined\"]\nfn test_equal_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().equal_elem(2);\n    let data_actual_inplace = tensor.equal_elem(2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\n#[ignore = \"quantization equality with float element is undefined\"]\nfn test_not_equal_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 2.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().not_equal_elem(2);\n    let data_actual_inplace = tensor.not_equal_elem(2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\n#[ignore = \"quantization equality with float element is undefined\"]\nfn test_greater_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().greater_elem(4);\n    let data_actual_inplace = tensor.greater_elem(4);\n\n    let data_expected = TensorData::from([[false, false, false], [false, false, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_greater_equal_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().greater_equal_elem(4.0);\n    let data_actual_inplace = tensor.greater_equal_elem(4.0);\n\n    let data_expected = TensorData::from([[false, false, false], [false, true, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_greater() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater(tensor_2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, false, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_greater_equal() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_lower_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().lower_elem(4.0);\n    let data_actual_inplace = tensor.lower_elem(4.0);\n\n    let data_expected = TensorData::from([[true, true, true], [true, false, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\n#[ignore = \"quantization equality with float element is undefined\"]\nfn test_lower_equal_elem() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let data_actual_cloned = tensor.clone().lower_equal_elem(4.0);\n    let data_actual_inplace = tensor.lower_equal_elem(4.0);\n\n    let data_expected = TensorData::from([[true, true, true], [true, true, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_lower() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower(tensor_2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n\n#[test]\nfn test_lower_equal() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 1.0], [3.0, 5.0, 4.0]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, true, false]]);\n    assert_eq!(data_expected, data_actual_cloned.into_data());\n    assert_eq!(data_expected, data_actual_inplace.into_data());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mask.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_mask_where_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 7.0], [2.0, 3.0]]);\n    let mask = TestTensorBool::<2>::from_bool(\n        TensorData::from([[true, false], [false, true]]),\n        &Default::default(),\n    );\n    let value = QTensor::<TestBackend, 2>::int8([[1.8, 2.8], [3.8, 4.8]]);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([[1.8, 7.0], [2.0, 4.8]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_mask_fill_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 7.0], [2.0, 3.0]]);\n    let mask = TestTensorBool::<2>::from_bool(\n        TensorData::from([[true, false], [false, true]]),\n        &Default::default(),\n    );\n\n    let output = tensor.mask_fill(mask, 2.0);\n    let expected = TensorData::from([[2.0, 7.0], [2.0, 2.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/maxmin.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_max_dim_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.max_dim(1);\n    let expected = TensorData::from([[2.], [5.]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_max_dim_with_indices_2d_with_dim_0th() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let (output, index) = tensor.max_dim_with_indices(0);\n\n    let output_expected = TensorData::from([[3., 4., 5.]]);\n    let index_expected = TensorData::from([[1, 1, 1]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_expected, Tolerance::rel_abs(2e-2, 1e-2));\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_max_dim_with_indices_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let (output, index) = tensor.max_dim_with_indices(1);\n\n    let output_expected = TensorData::from([[2.], [5.]]);\n    let index_expected = TensorData::from([[2], [2]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_expected, Tolerance::rel_abs(2e-2, 1e-2));\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_min_dim_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.min_dim(1);\n\n    let expected = TensorData::from([[0.], [3.]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_min_dim_with_indices_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let (output, index) = tensor.min_dim_with_indices(1);\n\n    let output_expected = TensorData::from([[0.], [3.]]);\n    let index_expected = TensorData::from([[0], [0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_expected, Tolerance::rel_abs(2e-2, 1e-2));\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_min_dim_2d_with_0th_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.min_dim(0);\n    let expected = TensorData::from([[0., 1., 2.]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_max_dim_2d_with_0th_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.max_dim(0);\n    let expected = TensorData::from([[3., 4., 5.]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_min_dim_with_indices_2d_with_0th_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let (output, index) = tensor.min_dim_with_indices(0);\n\n    let output_expected = TensorData::from([[0., 1., 2.]]);\n    let index_expected = TensorData::from([[0, 0, 0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&output_expected, Tolerance::rel_abs(2e-2, 1e-2));\n    index.into_data().assert_eq(&index_expected, false);\n}\n\n#[test]\nfn test_maximum_pair() {\n    let a = QTensor::<TestBackend, 1>::int8([1.0, 5.0, 3.0, 4.0]);\n    let b = QTensor::<TestBackend, 1>::int8([2.0, 1.0, 4.0, 5.0]);\n\n    let output = a.max_pair(b);\n    let expected = TensorData::from([2.0, 5.0, 4.0, 5.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_minimum_pair() {\n    let a = QTensor::<TestBackend, 1>::int8([1.0, 5.0, 3.0, 4.0]);\n    let b = QTensor::<TestBackend, 1>::int8([2.0, 1.0, 4.0, 5.0]);\n\n    let output = a.min_pair(b);\n    let expected = TensorData::from([1.0, 1.0, 3.0, 4.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mod.rs",
    "content": "pub use super::*;\n\nmod abs;\nmod add;\nmod aggregation;\nmod all;\nmod any;\nmod arg;\nmod cat;\nmod ceil;\nmod chunk;\nmod clamp;\nmod cos;\nmod cosh;\nmod div;\nmod erf;\nmod exp;\nmod expand;\nmod flip;\nmod floor;\nmod gather_scatter;\nmod log;\nmod log1p;\nmod map_comparison;\nmod mask;\nmod maxmin;\nmod mul;\nmod narrow;\nmod neg;\nmod permute;\nmod powf;\nmod powf_scalar;\nmod recip;\nmod remainder;\nmod repeat_dim;\nmod reshape;\nmod round;\nmod select;\nmod sin;\nmod sinh;\nmod slice;\nmod sort_argsort;\nmod split;\nmod sqrt;\nmod stack;\nmod sub;\nmod tan;\nmod tanh;\nmod topk;\nmod transpose;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/mul.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_mul_ops() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = tensor_1.clone();\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(5e-2, 1e-2));\n}\n\n#[test]\nfn test_mul_broadcast() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 4.0, 10.0], [0.0, 7.0, 16.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn test_mul_broadcast_2_dims() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0], [1.0], [2.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[3.0, 4.0, 5.0]]);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0.0, 0.0, 0.0], [3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_mul_scalar_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let scalar = 2.0;\n\n    let output = tensor * scalar;\n    let expected = TensorData::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/narrow.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{Shape, TensorData};\n\n#[test]\nfn test_narrow() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]);\n\n    let output = tensor.clone().narrow(0, 0, 2);\n    let expected = TensorData::from([[1., 2., 3.], [7., 8., 9.]]);\n\n    assert_eq!(output.shape(), Shape::from([2, 3]));\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n\n    let output = tensor.narrow(1, 1, 2);\n    let expected = TensorData::from([[2., 3.], [8., 9.], [14., 15.]]);\n    assert_eq!(output.shape(), Shape::from([3, 2]));\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]);\n\n    let _output = tensor.narrow(2, 0, 2);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_start() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]);\n\n    let _output = tensor.narrow(0, 3, 2);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_zero_length() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]);\n\n    let _output = tensor.narrow(0, 1, 0);\n}\n\n#[test]\n#[should_panic]\nfn test_narrow_invalid_length() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1., 2., 3.], [7., 8., 9.], [13., 14., 15.]]);\n\n    let _output = tensor.narrow(0, 0, 4);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/neg.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_neg_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.neg();\n    let expected = TensorData::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]).convert::<f32>();\n\n    // -0.0 is represented differently than 0.0 so we make sure the values are the same in f32\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/permute.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn permute_float() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,\n    ])\n    .reshape([2, 2, 4]);\n\n    let permuted = tensor.clone().permute([2, 1, 0]);\n\n    let expected = TensorData::from([\n        [[0., 8.], [4., 12.]],\n        [[1., 9.], [5., 13.]],\n        [[2., 10.], [6., 14.]],\n        [[3., 11.], [7., 15.]],\n    ]);\n\n    permuted\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(1e-1, 1e-1));\n\n    // Test with negative axis\n    let permuted = tensor.clone().permute([-1, 1, 0]);\n    permuted\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(1e-1, 1e-1));\n\n    // Test with the same axis\n    let permuted = tensor.clone().permute([0, 1, 2]);\n    permuted\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(\n            &tensor.dequantize().into_data(),\n            Tolerance::rel_abs(1e-4, 1e-4), // dequant error should be the same\n        );\n}\n\n#[test]\n#[should_panic]\nfn edge_repeated_axes() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,\n    ])\n    .reshape([2, 2, 4]);\n\n    // Test with a repeated axis\n    let _ = tensor.permute([0, 0, 1]);\n}\n\n#[test]\n#[should_panic]\nfn edge_out_of_bound_axis() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,\n    ])\n    .reshape([2, 2, 4]);\n\n    // Test with an invalid axis\n    let _ = tensor.permute([3, 0, 1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_powf_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_pow = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_pow = QTensor::<TestBackend, 2>::int8([[-0.95, -0.67, -0.45], [-0.24, -0.5, -0.6]]);\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[1., 1., 0.73204285], [0.76822936, 0.5, 0.38073079]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_values_with_even_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n    let tensor_pow = QTensor::<TestBackend, 2>::int8([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]);\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_values_with_odd_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]]);\n    let tensor_pow = QTensor::<TestBackend, 2>::int8([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]]);\n\n    let output = tensor.powf(tensor_pow);\n    let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/powf_scalar.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_powf_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.powf_scalar(0.71);\n    let expected = TensorData::from([[0.0, 1.0, 1.6358], [2.182, 2.6759, 3.1352]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.powf_scalar(-0.33);\n    let expected = TensorData::from([[1.0, 1.0, 0.79553646], [0.695905, 0.6328783, 0.58794934]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_values_with_even_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]);\n\n    let output = tensor.powf_scalar(2.0);\n    let expected = TensorData::from([[0., 1., 4.], [9., 16., 25.]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n\n#[test]\nfn should_support_neg_values_with_odd_power() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, -1.0, -2.0], [-3.0, -4.0, -4.0]]);\n\n    let output = tensor.powf_scalar(3.0);\n    let expected = TensorData::from([[0.0, -1.0, -8.0], [-27.0, -64.0, -64.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(4e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/recip.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_recip_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.5, 1.0, 2.0], [3.0, -4.0, -5.0]]);\n\n    let output = tensor.recip();\n    let expected = TensorData::from([[2.0, 1.0, 0.5], [0.33333, -0.25, -0.2]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/remainder.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_remainder_basic() {\n    let lhs = QTensor::<TestBackend, 1>::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 2.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([2.0, 3.0, 1.0, 2.0, 1.0, 2.0]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([1., 1., 0., 1., 0., 0.]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[ignore = \"quantization remainder with float element is undefined\"]\nfn should_support_remainder_basic_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);\n\n    let output = tensor.remainder_scalar(2.0);\n    let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_remainder_float() {\n    let lhs = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([1.4233, 2.7313, 0.2641, 1.9651, 0.5897]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([1., 2., 0.0949, 0.0698, 0.2824]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_remainder_float_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let output = tensor.remainder_scalar(-1.5);\n    let expected = TensorData::from([-0.5, -1.0, 0.0, -0.5, -1.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_be_zero() {\n    let lhs = QTensor::<TestBackend, 1>::int8([0.0, 0.0, 0.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([3.5, -2.1, 1.5]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([0.0, 0.0, 0.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_be_zero_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 0.0, 0.0]);\n\n    let output = tensor.remainder_scalar(3.5);\n    let expected = TensorData::from([0.0, 0.0, 0.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_have_no_remainder() {\n    let lhs = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([0.0, 0.0, 0.0, 0.0, 0.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_have_no_remainder_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([4.0, 4.0]);\n\n    let output = tensor.remainder_scalar(4.0);\n    let expected = TensorData::from([0.0, 0.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_be_negative() {\n    let lhs = QTensor::<TestBackend, 1>::int8([-7.0, -3.0, 2.0, 6.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([-2.5, -2.1, -1.5, -3.25]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([-2., -0.9, -1., -0.5]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_be_negative_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([-7.0, -3.0, 2.0, 6.0]);\n\n    let output = tensor.remainder_scalar(-2.5);\n    let expected = TensorData::from([-2.0, -0.50, -0.50, -1.5]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_fp_dividends() {\n    let tensor = QTensor::<TestBackend, 1>::int8([-7.5, -2.5, 2.5, 7.5]);\n\n    let output = tensor.remainder_scalar(3.0);\n    let expected = TensorData::from([1.5, 0.5, 2.5, 1.5]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_large_divisor() {\n    let lhs = QTensor::<TestBackend, 1>::int8([-1.0, 1.0, -1.5, 1.5, -1.0, 1.0, -1.5, 1.5]);\n    let rhs = QTensor::<TestBackend, 1>::int8([10.0, 10.0, 10.0, 10.0, -10.0, -10.0, -10.0, -10.0]);\n\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([9., 1., 8.5, 1.5, -1., -9., -1.5, -8.5]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_large_divisor_scalar() {\n    let tensor = QTensor::<TestBackend, 1>::int8([-1.0, 1.0]);\n\n    let output = tensor.remainder_scalar(10.0);\n    let expected = TensorData::from([9.0, 1.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_remainder_op() {\n    let lhs = QTensor::<TestBackend, 1>::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 2.0]);\n    let rhs = QTensor::<TestBackend, 1>::int8([2.0, 3.0, 1.0, 2.0, 1.0, 2.0]);\n\n    let output = lhs % rhs;\n    let expected = TensorData::from([1., 1., 0., 1., 0., 0.]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[ignore = \"quantization remainder with float element is undefined\"]\nfn should_support_remainder_scalar_op() {\n    let tensor = QTensor::<TestBackend, 1>::int8([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);\n\n    let output = tensor % 2.0;\n    let expected = TensorData::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/repeat_dim.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_support_repeat_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0, 3.0]]);\n\n    let output = tensor.repeat_dim(0, 4);\n    let expected = TensorData::from([\n        [0.0, 1.0, 2.0, 3.0],\n        [0.0, 1.0, 2.0, 3.0],\n        [0.0, 1.0, 2.0, 3.0],\n        [0.0, 1.0, 2.0, 3.0],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::permissive());\n}\n\n#[test]\nfn should_support_repeat_on_dims_larger_than_1() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,\n    ])\n    .reshape([4, 2, 2]);\n\n    let output = tensor.repeat_dim(2, 2);\n    let expected = TensorData::from([\n        [[0., 1., 0., 1.], [2., 3., 2., 3.]],\n        [[4., 5., 4., 5.], [6., 7., 6., 7.]],\n        [[8., 9., 8., 9.], [10., 11., 10., 11.]],\n        [[12., 13., 12., 13.], [14., 15., 14., 15.]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(1e-1, 1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/reshape.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn should_support_reshape_1d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0, 3.0]]);\n\n    let output = tensor.clone().reshape([1, 4]);\n    let expected = TensorData::from([[0.0, 1.0, 2.0, 3.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_reshape_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.clone().reshape([6]);\n    let expected = TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_dim_infererence() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,\n    ])\n    .reshape([4, 3]);\n\n    // Infer the dimension via -1\n    let reshaped = tensor.clone().reshape([2, -1]);\n    assert_eq!(reshaped.shape(), [2, 6].into());\n\n    // Infer the dimension via 0 (keep from the source) and -1 (infer)\n    let reshaped = reshaped.reshape([0, 2, -1]);\n    assert_eq!(reshaped.shape(), [2, 2, 3].into());\n\n    // This is effectively as if we did a flatten\n    let reshaped = tensor.clone().reshape([-1]);\n    assert_eq!(reshaped.shape(), [12].into());\n\n    // Keeping the first dimension the same (using 0)\n    let reshaped = tensor.clone().reshape([0, 3]);\n    assert_eq!(reshaped.shape(), [4, 3].into());\n}\n\n#[test]\nfn should_not_corrupt_after_slice() {\n    let zeros = QTensor::<TestBackend, 1>::int8([0.0, 0.0]);\n    zeros.clone().slice([1..2]).reshape([1]).exp();\n\n    // May lead to zeroes being equal to [0.0, 1.0]\n    zeros.dequantize().into_data().assert_eq(\n        &TestTensor::<1>::zeros([2], &Default::default()).to_data(),\n        true,\n    );\n}\n\n#[test]\n#[should_panic]\nfn multiple_neg_ones() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let _ = tensor.reshape([-1, -1]);\n}\n\n#[test]\n#[should_panic]\nfn neg_value() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let _ = tensor.reshape([-2, -1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/round.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_round_ops() {\n    let tensor =\n        QTensor::<TestBackend, 2>::int8([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);\n\n    let output = tensor.round();\n    let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_round_ties_even() {\n    // NOTE: round ties to even only affects values that are exact halfway from ceil/floor, so quantization\n    // errors can impact this. This basically only guarantees the values for the max value in the range since\n    // it is always represented correctly.\n    let tensor = QTensor::<TestBackend, 1>::int8([5.5]);\n\n    let output = tensor.round();\n    let expected = TensorData::from([6.]);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/select.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::IndexingUpdateOp;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_select_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0]);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default());\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([1.0, 1.0, 0.0, 1.0, 2.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_select_2d_dim0_same_num_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_data([1, 0], &Default::default());\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_select_2d_dim0_more_num_dim() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_data([1, 0, 1, 1], &Default::default());\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([\n        [3.0, 4.0, 5.0],\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [3.0, 4.0, 5.0],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_select_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default());\n\n    let output = tensor.select(1, indices);\n    let expected = TensorData::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_select_assign_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let values = QTensor::<TestBackend, 1>::int8([5.0, 4.0, 3.0, 2.0, 1.0]);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &Default::default());\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([3.0, 12.0, 3.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_select_assign_2d_dim0() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let values = tensor.clone();\n    let indices = TestTensorInt::from_data(TensorData::from([1, 0]), &Default::default());\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([[3.0, 5.0, 7.0], [3.0, 5.0, 7.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_select_assign_2d_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let values = tensor.clone();\n    let indices = TestTensorInt::from_data(TensorData::from([1, 0, 2]), &Default::default());\n\n    let output = tensor.select_assign(1, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([[1.0, 1.0, 4.0], [7.0, 7.0, 10.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[should_panic]\nfn should_select_panic_invalid_dimension() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &Default::default());\n\n    tensor.select(10, indices);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sin.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_sin_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sin();\n    let expected = TensorData::from([[0.0, 0.8414, 0.9092], [0.1411, -0.7568, -0.9589]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sinh.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_sinh_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sinh();\n    let expected = TensorData::from([[0.0000, 1.1752, 3.6269], [10.0179, 27.2899, 74.2032]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(3e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/slice.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::{Tolerance, s};\n\n#[test]\nfn should_support_full_sliceing_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0]);\n    let data = tensor.to_data();\n\n    let output = tensor.slice([0..4]);\n\n    output.into_data().assert_eq(&data, false);\n}\n\n#[test]\nfn should_support_partial_sliceing_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0]);\n\n    let output = tensor.slice([1..3]);\n    let expected = TensorData::from([1.0, 2.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_full_sliceing_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data = tensor.to_data();\n\n    let output = tensor.clone().slice([0..2]);\n    output.into_data().assert_eq(&data, true);\n\n    let output = tensor.slice([0..2, 0..3]);\n    output.into_data().assert_eq(&data, true);\n}\n\n#[test]\nfn should_support_partial_sliceing_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.slice([0..2, 0..2]);\n    let expected = TensorData::from([[0.0, 1.0], [3.0, 4.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_partial_sliceing_3d() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[0., 1., 2., 3.], [4., 5., 6., 7.]],\n        [[8., 9., 10., 11.], [12., 13., 14., 15.]],\n    ]);\n\n    let output = tensor.slice([1..2, 1..2, 0..2]);\n    let expected = TensorData::from([[[12.0, 13.0]]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_partial_sliceing_3d_non_contiguous() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[0., 1., 2., 3.], [4., 5., 6., 7.]],\n        [[8., 9., 10., 11.], [12., 13., 14., 15.]],\n    ]);\n\n    let output = tensor.transpose().slice([1..2, 1..2, 0..2]);\n    let expected = TensorData::from([[[9.0, 13.0]]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn should_support_slice_assign_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let tensor_assigned = QTensor::<TestBackend, 1>::int8([10.0, 5.0]);\n\n    let output = tensor.slice_assign([0..2], tensor_assigned);\n    let expected = TensorData::from([10.0, 5.0, 2.0]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_slice_assign_2d() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_assigned = QTensor::<TestBackend, 2>::int8([[10.0, 5.0]]);\n\n    let output = tensor.slice_assign([1..2, 0..2], tensor_assigned);\n    let expected = TensorData::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn slice_should_not_corrupt_potentially_inplace_operations() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n    let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]);\n\n    let expected = TensorData::from([4., 6., 8.]);\n\n    tensor\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn slice_assign_should_not_corrupt_potentially_inplace_operations() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n    let values = QTensor::<TestBackend, 1>::int8([10., 20., 30.]);\n\n    let tensor_1 = tensor.clone().slice_assign([0..3], values);\n    let tensor_2 = tensor + 2;\n\n    let expected = TensorData::from([10., 20., 30., 4., 5.]);\n\n    tensor_1\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n\n    let expected = TensorData::from([3., 4., 5., 6., 7.]);\n\n    tensor_2\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn clamp_when_slice_exceeds_dimension() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n    let data = tensor.to_data();\n\n    let output = tensor.slice([0..4]);\n    output.into_data().assert_eq(&data, true);\n}\n\n#[test]\nfn negative_dimensions() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data = tensor.to_data();\n\n    // Clamping to the tensor dimensions\n    let output = tensor.clone().slice([0..4, 0..4]);\n    output.into_data().assert_eq(&data, true);\n\n    // Negative dimensions\n    let output = tensor.clone().slice([0..1, 0..1]);\n    let data = TensorData::from([[0.0f32]]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n\n    let output = tensor.slice(s![0..-1, 0..-2]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\nfn missing_dimensions() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let data = tensor.to_data();\n\n    // Clamping to the tensor dimensions\n    let output = tensor.clone().slice([0..4, 0..4]);\n    output.into_data().assert_eq(&data, true);\n\n    // Negative dimensions\n    let data = TensorData::from([[0.0f32]]);\n    let output = tensor.clone().slice(s![0..-1, 0..-2]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n\n    // Missing dimensions\n    let output = tensor.clone().slice(s![0..1, ..]);\n    let data = TensorData::from([[0.0f32, 1.0, 2.0]]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n\n    let output = tensor.clone().slice(s![.., 0..2]);\n    let data = TensorData::from([[0.0f32, 1.0], [3.0, 4.0]]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n\n    let output = tensor.clone().slice([.., ..]);\n    let data = TensorData::from([[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&data, Tolerance::rel_abs(2e-2, 1e-2));\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_slice_with_too_many_dimensions() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0]);\n\n    let _output = tensor.slice([0..1, 0..1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sort_argsort.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_sort_1d_float() {\n    // Quantized [0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1]\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let values = tensor.sort(0);\n\n    let values_expected = TensorData::from([\n        -8.1, -0.3, -0.21, 0., 0.5, 0.94, 0.99, 1.2, 2.1, 2.3, 3., 4., 5.2,\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_argsort_1d_float() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let indices = tensor.argsort(0);\n\n    let indices_expected = TensorData::from([12, 6, 2, 3, 0, 5, 10, 1, 4, 7, 11, 9, 8]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_with_indices_descending_float() {\n    // 1D\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 5.2, 4., 0.99, 3., -8.1,\n    ]);\n\n    // Sort along dim=0\n    let (values, indices) = tensor.sort_descending_with_indices(0);\n\n    let values_expected = TensorData::from([\n        5.2, 4., 3., 2.3, 2.1, 1.2, 0.99, 0.94, 0.5, 0., -0.21, -0.3, -8.1,\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([8, 9, 11, 7, 4, 1, 10, 5, 0, 3, 2, 6, 12]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // 3D\n    // Quantized [-0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1]\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1,\n    ])\n    .reshape([2, 2, 3]);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.sort_descending_with_indices(1);\n\n    let values_expected = TensorData::from([\n        [[0., 2.1, 0.94], [-0.5, 1.2, -0.21]],\n        [[0.99, 3., 4.], [-0.3, 2.3, -8.1]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([[[1, 1, 1], [0, 0, 0]], [[1, 1, 0], [0, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_float() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1,\n    ])\n    .reshape([2, 2, 3]);\n\n    // Sort along dim=0\n    let values = tensor.clone().sort(0);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]],\n        [[-0.3, 2.3, 4.], [0.99, 3., 0.94]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    // Sort along dim=1\n    let values = tensor.clone().sort(1);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, -8.1], [0.99, 3., 4.]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    // Sort along dim=2\n    let values = tensor.sort(2);\n\n    let values_expected = TensorData::from([\n        [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]],\n        [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_sort_with_indices_float() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1,\n    ])\n    .reshape([2, 2, 3]);\n\n    // Sort along dim=0\n    let (values, indices) = tensor.clone().sort_with_indices(0);\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, -8.1]],\n        [[-0.3, 2.3, 4.], [0.99, 3., 0.94]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.clone().sort_with_indices(1);\n\n    let values_expected = TensorData::from([\n        [[-0.5, 1.2, -0.21], [0., 2.1, 0.94]],\n        [[-0.3, 2.3, -8.1], [0.99, 3., 4.]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let (values, indices) = tensor.sort_with_indices(2);\n\n    let values_expected = TensorData::from([\n        [[-0.5, -0.21, 1.2], [0., 0.94, 2.1]],\n        [[-0.3, 2.3, 4.], [-8.1, 0.99, 3.]],\n    ]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_argsort_float() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        -0.5, 1.2, -0.21, 0., 2.1, 0.94, -0.3, 2.3, 4., 0.99, 3., -8.1,\n    ])\n    .reshape([2, 2, 3]);\n\n    // Sort along dim=0\n    let indices = tensor.clone().argsort(0);\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [0, 0, 1]], [[1, 1, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let indices = tensor.clone().argsort(1);\n\n    let indices_expected = TensorData::from([[[0, 0, 0], [1, 1, 1]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let indices = tensor.argsort(2);\n\n    let indices_expected = TensorData::from([[[0, 2, 1], [0, 2, 1]], [[0, 1, 2], [2, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_descending_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    // Sort along dim=0\n    let values = tensor.sort_descending(0);\n\n    let values_expected = TensorData::from([5., 4., 3., 2., 1.]);\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/split.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse alloc::vec;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_split_evenly_divisible() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let tensors = tensor.split(2, 0);\n    assert_eq!(tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([0., 1.]),\n        TensorData::from([2., 3.]),\n        TensorData::from([4., 5.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_split_not_evenly_divisible() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);\n\n    let tensors = tensor.split(2, 0);\n    assert_eq!(tensors.len(), 4);\n\n    let expected = [\n        TensorData::from([0., 1.]),\n        TensorData::from([2., 3.]),\n        TensorData::from([4., 5.]),\n        TensorData::from([6.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_split_along_dim1() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let tensors = tensor.split(2, 1);\n    assert_eq!(tensors.len(), 2);\n\n    let expected = [\n        TensorData::from([[0., 1.], [3., 4.]]),\n        TensorData::from([[2.], [5.]]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\nfn test_split_split_size_larger_than_tensor_size() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let tensors = tensor.split(10, 0);\n    assert_eq!(tensors.len(), 1);\n\n    let expected = [TensorData::from([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\n#[should_panic(\n    expected = \"split_size must be greater than 0 unless the tensor size along the dimension is 0.\"\n)]\nfn test_split_with_zero_split_size_non_zero_tensor() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let _ = tensor.split(0, 0);\n}\n\n#[test]\n#[should_panic(expected = \"Given dimension is greater than or equal to the tensor rank.\")]\nfn test_split_invalid_dim() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let _ = tensor.split(1, 2);\n}\n\n#[test]\nfn test_split_with_sizes() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let tensors = tensor.split_with_sizes(vec![2, 3, 1], 0);\n    assert_eq!(tensors.len(), 3);\n\n    let expected = [\n        TensorData::from([0., 1.]),\n        TensorData::from([2., 3., 4.]),\n        TensorData::from([5.]),\n    ];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n\n#[test]\n#[should_panic(\n    expected = \"The sum of split_sizes must equal the tensor size along the specified dimension.\"\n)]\nfn test_split_with_sizes_invalid_sum() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let _ = tensor.split_with_sizes(vec![2, 2, 1], 0);\n}\n\n#[test]\nfn test_split_with_sizes_zero_length() {\n    let tensor = QTensor::<TestBackend, 1>::int8([0.0, 2.0, 5.0]);\n\n    let tensors = tensor.split_with_sizes(vec![0, 1, 2], 0);\n    assert_eq!(tensors.len(), 2);\n\n    let expected = [TensorData::from([0.]), TensorData::from([2., 5.])];\n\n    for (index, tensor) in tensors.into_iter().enumerate() {\n        tensor\n            .dequantize()\n            .to_data()\n            .assert_approx_eq::<FloatElem>(&expected[index], Tolerance::absolute(1e-1));\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sqrt.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\nuse core::f32::consts::SQRT_2;\n\n#[test]\nfn should_support_sqrt_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.sqrt();\n    let expected = TensorData::from([[0.0, 1.0, SQRT_2], [1.73205, 2.0, 2.2360]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/stack.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse alloc::vec;\nuse burn_tensor::Tensor;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_stack_ops_2d_dim0() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0, 6.0]]);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_stack_ops_2d_dim1() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0, 6.0]]);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 1);\n    let expected = TensorData::from([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_stack_ops_3d() {\n    let tensor_1 = QTensor::<TestBackend, 3>::int8([[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]]);\n    let tensor_2 = QTensor::<TestBackend, 3>::int8([[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]]);\n\n    let output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([\n        [[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]],\n        [[[4.0, 5.0, 6.0]], [[6.0, 5.0, 4.0]]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_dimensions_are_not_the_same() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 5.0]]);\n\n    let _output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_stack_exceeds_dimension() {\n    let tensor_1 = QTensor::<TestBackend, 3>::int8([[[1.0, 2.0, 3.0]], [[3.0, 2.0, 1.0]]]);\n    let tensor_2 = QTensor::<TestBackend, 3>::int8([[[4.0, 5.0, 6.0]]]);\n\n    let _output = Tensor::stack::<4>(vec![tensor_1, tensor_2], 3);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/sub.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_sub_ops() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-6.0, -6.0, -6.0], [-6.0, -6.0, -6.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_sub_broadcast() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-3.0, -3.0, -3.0], [-6.0, -6.0, -6.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_sub_scalar_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n    let scalar = 2.0;\n\n    let output = tensor - scalar;\n    let expected = TensorData::from([[-2.0, -1.0, 0.0], [1.0, 2.0, 3.0]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::rel_abs(2e-2, 1e-2));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tan.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_tan_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.tan();\n    let expected = TensorData::from([[0.0, 1.5574, -2.1850], [-0.1425, 1.1578, -3.3805]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/tanh.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_tanh_ops() {\n    let tensor = QTensor::<TestBackend, 2>::int8([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);\n\n    let output = tensor.tanh();\n    let expected = TensorData::from([[0.0, 0.7615, 0.9640], [0.9950, 0.9993, 0.9999]]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/topk.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_topk_1d() {\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let values = tensor.topk(3, /*dim*/ 0);\n    let expected = TensorData::from([5., 4., 3.]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_topk() {\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[1., 4., 7.], [2., 5., 6.]],\n        [[3., 0., 9.], [8., 2., 7.]],\n    ]);\n\n    let values = tensor.topk(2, /*dim*/ 2);\n    let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn test_topk_with_indices() {\n    // 1D\n    let tensor = QTensor::<TestBackend, 1>::int8([1.0, 2.0, 3.0, 4.0, 5.0]);\n\n    let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);\n\n    let values_expected = TensorData::from([5., 4., 3.]);\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::permissive());\n\n    let indices_expected = TensorData::from([4, 3, 2]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_topk_with_indices_3d() {\n    // 3D\n    let tensor = QTensor::<TestBackend, 3>::int8([\n        [[1., 4., 7.], [2., 5., 6.]],\n        [[3., 0., 9.], [8., 2., 7.]],\n    ]);\n\n    let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2);\n\n    let values_expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);\n\n    values\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&values_expected, Tolerance::absolute(1e-1));\n\n    let indices_expected = TensorData::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]);\n\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/extended/transpose.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn should_support_transpose_ops() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,\n    ])\n    .reshape([2, 2, 3]);\n\n    let output = tensor.transpose();\n    let expected = TensorData::from([\n        [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]],\n        [[6.0, 9.0], [7.0, 10.0], [8.0, 11.0]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n\n#[test]\nfn should_support_swap_dims() {\n    let tensor = QTensor::<TestBackend, 1>::int8([\n        0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,\n    ])\n    .reshape([2, 2, 3]);\n\n    let output = tensor.swap_dims(0, 2);\n    let expected = TensorData::from([\n        [[0.0, 6.0], [3.0, 9.0]],\n        [[1.0, 7.0], [4.0, 10.0]],\n        [[2.0, 8.0], [5.0, 11.0]],\n    ]);\n\n    output\n        .dequantize()\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::absolute(1e-1));\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/matmul.rs",
    "content": "use super::qtensor::*;\nuse super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\n#[ignore]\nfn test_matmul_vectors() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 2.0, 3.0, 6.35]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[12.7], [4.0], [5.0], [1.0]]);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([[42.05]]);\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\n#[ignore]\nfn test_matmul_2d() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]]);\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([[16.7, 27.05, 50.8], [14., 25., 43.4], [10., 17., 30.7]]);\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\nfn test_matmul_2d_aligned() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([\n        [1.0, 2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [9.0, 10.0, 11.0, 12.0],\n    ]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([\n        [2.0, 0.0, 1.0, 0.0],\n        [1.0, 2.0, 0.0, 0.0],\n        [0.0, 1.0, 2.0, 0.0],\n        [1.0, 0.0, 0.0, 1.0],\n    ]);\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([\n        [8.0, 7.0, 7.0, 4.0],\n        [24.0, 19.0, 19.0, 8.0],\n        [40.0, 31.0, 31.0, 12.0],\n    ]);\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\nfn test_matmul_2d_aligned_fused() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([\n        [1.0, 2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0, 8.0],\n        [9.0, 10.0, 11.0, 12.0],\n    ]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([\n        [2.0, 0.0, 1.0, 0.0],\n        [1.0, 2.0, 0.0, 0.0],\n        [0.0, 1.0, 2.0, 0.0],\n        [1.0, 0.0, 0.0, 1.0],\n    ]);\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let tensor_4 = tensor_3 / 2.0;\n\n    let expected = TensorData::from([\n        [4.0, 3.5, 3.5, 2.0],\n        [12.0, 9.5, 9.5, 4.0],\n        [20.0, 15.5, 15.5, 6.0],\n    ]);\n    tensor_4\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\n#[ignore]\nfn test_matmul_3d() {\n    let tensor_1 = QTensor::<TestBackend, 3>::int8([[[1.0, 6.35], [2.0, 3.0]]]);\n    let tensor_2 = QTensor::<TestBackend, 3>::int8([[[12.7, 4.0], [2.0, 3.0]]]);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([[[25.4, 23.05], [31.4, 17.0]]]);\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\n#[ignore]\nfn test_matmul_broadcast_4d() {\n    let tensor_1 =\n        QTensor::<TestBackend, 4>::int8([[[[1.0, 7.0], [2.0, 3.0]]], [[[2.0, 5.0], [6.0, 3.0]]]]);\n    let tensor_2 =\n        QTensor::<TestBackend, 4>::int8([[[[9.0, 8.0], [1.0, 4.0]], [[2.0, 7.0], [3.0, 5.0]]]]);\n\n    // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2]\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [[[16.0, 36.0], [21.0, 28.0]], [[23.0, 42.0], [13.0, 29.0]]],\n        [[[23.0, 36.0], [57.0, 60.0]], [[19.0, 39.0], [21.0, 57.0]]],\n    ]);\n\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\n#[ignore]\nfn test_matmul_broadcast() {\n    let tensor_1 = QTensor::<TestBackend, 3>::int8([[[1.0, 7.0], [2.0, 3.0]]]);\n    let tensor_2 =\n        QTensor::<TestBackend, 3>::int8([[[4.0, 7.0], [2.0, 3.0]], [[2.0, 5.0], [6.0, 3.0]]]);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[[18.0, 28.0], [14.0, 23.0]], [[44.0, 26.0], [22.0, 19.0]]]);\n\n    tensor_3\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::relative(2e-2));\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_inner_dimensions_are_not_equal() {\n    let tensor_1 = QTensor::<TestBackend, 2>::int8([[3., 3.], [4., 4.], [5., 5.], [6., 6.]]);\n    let tensor_2 =\n        QTensor::<TestBackend, 2>::int8([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]]);\n\n    let _ = tensor_1.matmul(tensor_2);\n}\n\n#[test]\nfn test_matmul_lhs_float_rhs_quantized() {\n    // Simulates a typical workflow with linear layers (e.g., transformers), where the rhs\n    // represents the weights. The lhs might be a float if a previous operation did not propagate\n    // the quantization. We still want to perform an efficient matmul with quantized weights.\n    let tensor_1 = TestTensor::<2>::from([\n        [1.0, 6.35, 2.0, 3.0],\n        [2.0, 3.0, 4.0, 5.0],\n        [1.0, 3.0, 5.0, 7.0],\n    ]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8([\n        [4.0, 8.0, 12.7, 1.6],\n        [2.0, 3.0, 6.0, 4.0],\n        [1.0, 5.0, 9.0, 2.5],\n        [3.0, 7.0, 11.0, 0.5],\n    ]);\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([\n        [27.7, 58.05, 101.8, 33.5],\n        [33., 80., 134.4, 27.7],\n        [36., 91., 152.7, 29.6],\n    ]);\n    let output = tensor_3.into_data();\n    output.assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n\n    // Default quantization scheme does not propagate quantization with matmul\n    assert!(output.dtype.is_float());\n}\n\n#[test]\nfn test_matmul_mixed_block_scale() {\n    let tensor_1 = TestTensor::<2>::from([\n        [1.0, 6.35, 2.0, 3.0],\n        [2.0, 3.0, 4.0, 5.0],\n        [1.0, 3.0, 5.0, 7.0],\n    ]);\n    let tensor_2 = QTensor::<TestBackend, 2>::int8_block([\n        [\n            6.110, 4.0, 9.360, 7.850, 0.630, 1.770, 0.430, 7.550, 9.690, 3.560, 2.920, 9.130,\n            3.390, 0.510, 1.620, 1.460,\n        ],\n        [\n            6.140, 8.260, 5.660, 5.610, 7.070, 3.050, 9.890, 5.520, 1.350, 3.810, 5.630, 0.250,\n            0.350, 8.860, 3.610, 6.240,\n        ],\n        [\n            8.810, 4.620, 7.420, 8.110, 2.560, 4.710, 5.730, 8.980, 1.170, 6.090, 4.140, 3.610,\n            4.960, 9.720, 5.710, 1.470,\n        ],\n        [\n            2.260, 9.640, 6.320, 6.980, 9.860, 1.030, 8.340, 1.570, 4.140, 4.760, 4.590, 6.400,\n            5.350, 1.430, 4.960, 1.180,\n        ],\n    ]);\n    let tensor_3 = tensor_1.matmul(tensor_2);\n\n    let expected = TensorData::from([\n        [\n            69.499, 94.611, 79.101, 80.633, 80.225, 33.647, 99.711, 65.272, 33.022, 54.213, 60.721,\n            37.138, 31.582, 80.501, 50.843, 47.564,\n        ],\n        [\n            77.180, 99.460, 96.980, 99.870, 82.010, 36.680, 95.150, 75.430, 48.810, 66.710, 62.240,\n            65.450, 54.420, 73.630, 61.710, 33.420,\n        ],\n        [\n            84.400, 119.360, 107.680, 114.090, 103.660, 41.680, 117.130, 80.0, 48.570, 78.760,\n            72.640, 72.730, 66.690, 85.700, 75.720, 35.790,\n        ],\n    ]);\n    let output = tensor_3.into_data();\n    output.assert_approx_eq::<FloatElem>(&expected, Tolerance::permissive());\n\n    // Default quantization scheme does not propagate quantization with matmul\n    assert!(output.dtype.is_float());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/mod.rs",
    "content": "pub use super::*;\n\nmod matmul;\nmod quantize;\n\n// TODO: re-enable for cubecl backends when inputs are valid for packed U32 storage\n#[cfg(feature = \"ndarray\")]\nmod extended;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/ops/quantize.rs",
    "content": "use super::*;\nuse alloc::{vec, vec::Vec};\nuse burn_tensor::quantization::{\n    QParams, QTensorPrimitive, QuantLevel, QuantScheme, QuantStore, QuantValue,\n    QuantizationParameters, QuantizedBytes,\n};\nuse burn_tensor::{DType, Element, TensorData};\nuse burn_tensor::{Tolerance, ops::QuantizedTensor};\n\nfn get_q_params(data: TensorData) -> QParams<Vec<f32>> {\n    let num_elements = data.num_elements();\n    let scheme = if let DType::QFloat(scheme) = data.dtype {\n        scheme\n    } else {\n        unreachable!()\n    };\n    let q_bytes = QuantizedBytes {\n        bytes: data.into_bytes(),\n        scheme,\n        num_elements,\n    };\n    q_bytes.into_vec_i8().1\n}\n\n#[test]\nfn should_support_quantize_symmetric_int8() {\n    // Strict equality was based on full precision\n    if !matches!(FloatElem::dtype(), DType::F32) {\n        return;\n    }\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device);\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n    let qparams = QuantizationParameters {\n        scales: TestTensor::from_floats([0.014_173_228], &device),\n    };\n\n    let x_q = tensor.clone().quantize(&scheme, qparams);\n\n    let x_q_data = x_q.to_data();\n    let expected = TensorData::quantized(\n        vec![-127i8, -71, 0, 35],\n        [4],\n        scheme.with_store(QuantStore::Native),\n        &[0.014_173_228], // scale\n    );\n\n    // Values equality\n    x_q_data.assert_eq(&expected, false);\n\n    // Quantization parameters check\n    let qparams = get_q_params(x_q_data);\n    let expected = get_q_params(expected);\n    assert_eq!(qparams.scales.len(), 1);\n    // TODO: check scales\n    assert_eq!(qparams.scales, expected.scales);\n\n    // Dequantize\n    let x = x_q.dequantize();\n\n    x.into_data()\n        .assert_approx_eq::<FloatElem>(&tensor.into_data(), Tolerance::rel_abs(1e-1, 1e-2));\n}\n\n#[test]\nfn should_support_quantize_dynamic_int8() {\n    let device = Default::default();\n    // NOTE: we use fully representable values since different backend implementations could differ slightly\n    // due to rounding discrepancies\n    let tensor = TestTensor::<1>::from_floats([5., 0., 4., -12.7], &device);\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let x_q = tensor.quantize_dynamic(&scheme);\n\n    let expected = TensorData::quantized(\n        vec![50i8, 0, 40, -127],\n        [4],\n        scheme.with_store(QuantStore::Native),\n        &[0.1], // scale\n    );\n\n    x_q.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_single_with_transform() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n    let input = TestTensorInt::<1>::arange(0..32, &Default::default()).float();\n\n    let quant = input.quantize_dynamic(&scheme);\n    let result = quant * 10;\n\n    let data = result.into_data();\n    let expected = [\n        0.0, 9.76378, 19.52756, 29.29134, 39.05512, 48.818897, 61.02362, 70.7874, 80.551186,\n        90.31496, 100.07874, 109.84252, 119.60631, 129.37009, 139.13387, 148.89764, 161.10237,\n        170.86615, 180.62991, 190.39369, 200.15749, 209.92126, 219.68504, 229.44882, 239.21262,\n        248.97638, 261.1811, 270.9449, 280.70865, 290.47244, 300.23624, 310.0,\n    ];\n    data.assert_approx_eq::<FloatElem>(&TensorData::from(expected), Tolerance::permissive());\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_arange_16x16() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let input: TestTensor<2> = TestTensorInt::arange(0..256, &Default::default())\n        .float()\n        .div_scalar(256.)\n        .reshape([16, 16]);\n\n    let output = input.clone().quantize_dynamic(&scheme);\n    let output = output.dequantize();\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &input.into_data(),\n        Tolerance::absolute(1e-1).set_relative(1e-2),\n    );\n}\n\n#[test]\nfn should_quantize_dequantize_symmetric_per_block_arange_16x16() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([2, 16]));\n\n    let input: TestTensor<2> = TestTensorInt::arange(0..256, &Default::default())\n        .float()\n        .div_scalar(256.)\n        .reshape([16, 16]);\n\n    let output = input.clone().quantize_dynamic(&scheme);\n    let output = output.dequantize();\n\n    output.into_data().assert_approx_eq::<FloatElem>(\n        &input.into_data(),\n        Tolerance::absolute(1e-1).set_relative(1e-2),\n    );\n}\n\nfn should_quantize_transposed<const D: usize>(tensor: Tensor<TestBackend, D>, scheme: QuantScheme) {\n    let tensor_t = tensor.clone().transpose();\n\n    let output = tensor_t.quantize_dynamic(&scheme).dequantize().transpose();\n\n    tensor.into_data().assert_approx_eq::<FloatElem>(\n        &output.into_data(),\n        Tolerance::absolute(1e-1).set_relative(1e-2),\n    );\n}\n\n#[test]\nfn should_quantize_symmetric_int8_transposed_8x32() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let tensor = TestTensorInt::arange(0..256, &Default::default())\n        .float()\n        .div_scalar(256.)\n        .reshape([8, 32]);\n    should_quantize_transposed(tensor, scheme);\n}\n\n#[test]\nfn should_quantize_symmetric_int8_transposed_48x64() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let tensor = TestTensorInt::arange(0..3072, &Default::default())\n        .float()\n        .div_scalar(3072.)\n        .reshape([48, 64]);\n    should_quantize_transposed(tensor, scheme);\n}\n\n#[test]\nfn should_quantize_symmetric_per_block_int8_transposed_32x64() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([32]));\n\n    let tensor = TestTensorInt::arange(0..2048, &Default::default())\n        .float()\n        .div_scalar(2048.)\n        .reshape([32, 64]);\n    should_quantize_transposed(tensor, scheme);\n}\n\n#[test]\nfn should_quantize_symmetric_int8_permuted_batch_dims() {\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let tensor = TestTensorInt::arange(0..2048, &Default::default())\n        .float()\n        .div_scalar(2048.)\n        .reshape([2, 4, 8, 32]);\n\n    // Permute [0,1,2,3] -> [1,2,0,3]\n    // This rearranges batch dims but keeps packed dim in place\n    let tensor_permuted = tensor.clone().permute([1, 2, 0, 3]);\n\n    let output = tensor_permuted\n        .quantize_dynamic(&scheme)\n        .dequantize()\n        .permute([2, 0, 1, 3]); // reverse permutation\n\n    tensor.into_data().assert_approx_eq::<FloatElem>(\n        &output.into_data(),\n        Tolerance::absolute(1e-1).set_relative(1e-2),\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/quantization/scheme.rs",
    "content": "use super::*;\nuse burn_tensor::Tolerance;\nuse burn_tensor::{\n    Element, TensorData,\n    ops::QuantizedTensor,\n    quantization::{CalibrationRange, QTensorPrimitive, QuantLevel, QuantValue, compute_q_params},\n};\n\n#[test]\nfn per_tensor_symmetric_int8() {\n    let device = Default::default();\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n    let range = CalibrationRange {\n        min: TestTensor::<1>::from_floats([0.5], &device),\n        max: TestTensor::<1>::from_floats([1.8], &device),\n    };\n\n    let qparams = compute_q_params(&scheme, range);\n\n    qparams\n        .scales\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&TensorData::from([0.014_173_23]), Tolerance::default());\n}\n\n#[test]\nfn per_block_symmetric_int8() {\n    let device = Default::default();\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([4]));\n    let range = CalibrationRange {\n        min: TestTensor::<1>::from_floats([-1.8, -0.5, 0.01, -0.04], &device),\n        max: TestTensor::<1>::from_floats([0.5, 1.8, 0.04, -0.01], &device),\n    };\n\n    let qparams = compute_q_params(&scheme, range);\n\n    qparams.scales.into_data().assert_approx_eq::<FloatElem>(\n        &TensorData::from([0.014_173_23, 0.014_173_23, 0.000_314_96, 0.000_314_96]),\n        Tolerance::default(),\n    );\n}\n\n#[test]\nfn quant_scheme_should_inhibit_by_default() {\n    let device = Default::default();\n    let scheme = QuantizedTensor::<TestBackend>::default_scheme().with_value(QuantValue::Q8S);\n\n    let tensor_1 = TestTensor::<2>::from_floats(\n        [[1.0, 6.35, 0., 0.], [2.0, 3.0, 0., 0.], [1.0, 3.0, 0., 0.]],\n        &device,\n    )\n    .quantize_dynamic(&scheme);\n    let _tensor_2 = TestTensor::<2>::from_floats(\n        [\n            [4.0, 8.0, 12.7, 0.],\n            [2.0, 3.0, 6.0, 0.],\n            [0., 0., 0., 0.],\n            [0., 0., 0., 0.],\n        ],\n        &device,\n    )\n    .quantize_dynamic(&scheme);\n\n    // let tensor_3 = tensor_1.clone().matmul(tensor_2);\n    // assert_eq!(tensor_3.to_data().dtype, FloatElem::dtype());\n\n    let tensor_4 = tensor_1.add_scalar(1.);\n    assert_eq!(tensor_4.to_data().dtype, FloatElem::dtype());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/cov.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, Tolerance};\n\n#[test]\nfn test_cov_1() {\n    let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.cov(1, 1);\n    let expected =\n        TensorData::from([[2.48917, -1.73333], [-1.73333, 15.33333]]).convert::<FloatElem>();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_cov_4() {\n    let data = TensorData::from([[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.cov(1, 0);\n    let expected = TensorData::from([[1.86687, -1.30000], [-1.30000, 11.5]]).convert::<FloatElem>();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_cov_2() {\n    let data = TensorData::from([[0.5, 1.8], [0.2, -2.0], [3.0, -4.0], [5.0, 0.0]]);\n    let tensor = TestTensor::<2>::from_data(data, &Default::default());\n\n    let output = tensor.cov(1, 1);\n    let expected = TensorData::from([\n        [0.845, -1.43, -4.55, -3.25],\n        [-1.43, 2.42, 7.7, 5.5],\n        [-4.55, 7.7, 24.5, 17.5],\n        [-3.25, 5.5, 17.5, 12.5],\n    ])\n    .convert::<FloatElem>();\n\n    let tolerance = Tolerance::default().set_half_precision_relative(1e-3);\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, tolerance);\n}\n\n#[test]\nfn test_cov_3() {\n    let data = TensorData::from([\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n    ]);\n    let device = Default::default();\n    let tensor = TestTensor::<3>::from_data(data, &device);\n    let data_actual = tensor.cov(0, 1).into_data();\n    let data_expected = TestTensor::<3>::zeros([4, 4, 4], &device).to_data();\n    data_expected.assert_approx_eq::<FloatElem>(&data_actual, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/display.rs",
    "content": "use super::*;\nuse burn_tensor::backend::Backend;\nuse burn_tensor::{Element, Shape, TensorData};\n\ntype FloatElem = <TestBackend as Backend>::FloatElem;\ntype IntElem = <TestBackend as Backend>::IntElem;\n\n// Floating point values might not match for other precisions\nfn skip_precision_not_f32() -> bool {\n    core::any::TypeId::of::<FloatElem>() != core::any::TypeId::of::<f32>()\n}\n\n#[test]\nfn test_display_2d_int_tensor() {\n    let int_data = TensorData::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);\n    let tensor_int = TestTensorInt::<2>::from_data(int_data, &Default::default());\n\n    let output = format!(\"{}\", tensor_int);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[1, 2, 3],\n [4, 5, 6],\n [7, 8, 9]],\n  shape:  [3, 3],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Int\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor_int.device(),\n        TestBackend::name(&tensor_int.device()),\n        dtype = core::any::type_name::<IntElem>(),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_2d_float_tensor() {\n    if skip_precision_not_f32() {\n        return;\n    }\n\n    let float_data = TensorData::from([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6], [7.7, 8.8, 9.9]]);\n    let tensor_float = TestTensor::<2>::from_data(float_data, &Default::default());\n\n    let output = format!(\"{}\", tensor_float);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[1.1, 2.2, 3.3],\n [4.4, 5.5, 6.6],\n [7.7, 8.8, 9.9]],\n  shape:  [3, 3],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}}\"#,\n        tensor_float.device(),\n        TestBackend::name(&tensor_float.device()),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_2d_bool_tensor() {\n    let bool_data = TensorData::from([\n        [true, false, true],\n        [false, true, false],\n        [false, true, true],\n    ]);\n    let tensor_bool = TestTensorBool::<2>::from_data(bool_data, &Default::default());\n\n    let output = format!(\"{}\", tensor_bool);\n    // TODO: remove once backends no longer rely on generics for default elem types\n    let expected_name = match <TestBackend as Backend>::BoolElem::dtype() {\n        burn_tensor::DType::U8 => burn_tensor::DType::Bool(burn_tensor::BoolStore::U8).name(),\n        burn_tensor::DType::U32 => burn_tensor::DType::Bool(burn_tensor::BoolStore::U32).name(),\n        dtype => dtype.name(),\n    };\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[true, false, true],\n [false, true, false],\n [false, true, true]],\n  shape:  [3, 3],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Bool\",\n  dtype:  {:?},\n}}\"#,\n        tensor_bool.device(),\n        TestBackend::name(&tensor_bool.device()),\n        expected_name,\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_3d_tensor() {\n    let data = TensorData::from([\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],\n        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],\n    ]);\n    let tensor = TestTensorInt::<3>::from_data(data, &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[[1, 2, 3, 4],\n  [5, 6, 7, 8],\n  [9, 10, 11, 12]],\n [[13, 14, 15, 16],\n  [17, 18, 19, 20],\n  [21, 22, 23, 24]]],\n  shape:  [2, 3, 4],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Int\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n        dtype = core::any::type_name::<IntElem>(),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_4d_tensor() {\n    let data = TensorData::from([\n        [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],\n        [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]],\n    ]);\n\n    let tensor = TestTensorInt::<4>::from_data(data, &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[[[1, 2, 3],\n   [4, 5, 6]],\n  [[7, 8, 9],\n   [10, 11, 12]]],\n [[[13, 14, 15],\n   [16, 17, 18]],\n  [[19, 20, 21],\n   [22, 23, 24]]]],\n  shape:  [2, 2, 2, 3],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Int\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n        dtype = core::any::type_name::<IntElem>(),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_tensor_summarize_1() {\n    let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 2, 1000]), &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]],\n [[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]]],\n  shape:  [2, 2, 2, 1000],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n        dtype = FloatElem::dtype().name(),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_tensor_summarize_2() {\n    let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 20, 100]), &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]],\n [[[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0]]]],\n  shape:  [2, 2, 20, 100],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n        dtype = FloatElem::dtype().name(),\n    );\n    assert_eq!(output, expected);\n}\n\n#[test]\nfn test_display_tensor_summarize_3() {\n    let tensor = TestTensor::<4>::zeros(Shape::new([2, 2, 200, 6]), &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]],\n [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],\n  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   ...\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n   [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]]],\n  shape:  [2, 2, 200, 6],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"{dtype}\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n        dtype = FloatElem::dtype().name(),\n    );\n    assert_eq!(output, expected);\n}\n#[test]\nfn test_display_precision() {\n    if skip_precision_not_f32() {\n        return;\n    }\n\n    let tensor = TestTensor::<2>::full([1, 1], 0.123456789, &Default::default());\n\n    let output = format!(\"{}\", tensor);\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[0.12345679]],\n  shape:  [1, 1],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n    );\n    assert_eq!(output, expected);\n\n    // CAN'T DO THIS BECAUSE OF GLOBAL STATE\n    // let print_options = PrintOptions {\n    //     precision: Some(3),\n    //     ..Default::default()\n    // };\n    // set_print_options(print_options);\n\n    let tensor = TestTensor::<2>::full([3, 2], 0.123456789, &Default::default());\n\n    // Set precision to 3\n    let output = format!(\"{:.3}\", tensor);\n\n    let expected = format!(\n        r#\"Tensor {{\n  data:\n[[0.123, 0.123],\n [0.123, 0.123],\n [0.123, 0.123]],\n  shape:  [3, 2],\n  device:  {:?},\n  backend:  {:?},\n  kind:  \"Float\",\n  dtype:  \"f32\",\n}}\"#,\n        tensor.device(),\n        TestBackend::name(&tensor.device()),\n    );\n    assert_eq!(output, expected);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/eye.rs",
    "content": "use super::*;\n\n#[test]\nfn test_eye_float() {\n    let device = Default::default();\n    let tensor = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);\n    let rhs = TestTensor::<2>::eye(3, &device);\n    assert_eq!(tensor.to_data(), rhs.to_data());\n}\n\n#[test]\nfn test_eye_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<2>::from([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);\n    let rhs = TestTensorInt::<2>::eye(3, &device);\n    assert_eq!(tensor.to_data(), rhs.to_data());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/median.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_median_even() {\n    let tensor = TestTensor::<2>::from_data(\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        &Default::default(),\n    );\n\n    let median_actual_1 = tensor.clone().median(1);\n    let median_expected_1 = TensorData::from([[0.2], [0.0]]).convert::<FloatElem>();\n    median_actual_1\n        .into_data()\n        .assert_eq(&median_expected_1, false);\n\n    let median_actual_0 = tensor.median(0);\n    let median_expected_0 = TensorData::from([[0.5, -4.0, 0.2, -2.0]]).convert::<FloatElem>();\n    median_actual_0\n        .into_data()\n        .assert_eq(&median_expected_0, false);\n}\n\n#[test]\nfn test_median_odd() {\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [0.5, 1.8, 0.2, -2.0, 1.0],\n            [3.0, -4.0, 5.0, 0.0, -1.0],\n            [5.0, -5.0, 1.0, 3.0, -2.0],\n        ],\n        &Default::default(),\n    );\n\n    let median_actual_1 = tensor.clone().median(1);\n    let median_expected_1 = TensorData::from([[0.5], [0.0], [1.0]]).convert::<FloatElem>();\n    median_actual_1\n        .into_data()\n        .assert_eq(&median_expected_1, false);\n\n    let median_actual_0 = tensor.median(0);\n    let median_expected_0 = TensorData::from([[3.0, -4.0, 1.0, 0.0, -1.0]]).convert::<FloatElem>();\n    median_actual_0\n        .into_data()\n        .assert_eq(&median_expected_0, false);\n}\n\n#[test]\nfn test_median_with_indices() {\n    let device = Default::default();\n    let tensor = TestTensor::<1>::from_data([3.0, 1.0, 2.0], &device);\n    // median = 2, original index = 2\n    let (values, indices) = tensor.median_with_indices(0);\n    values\n        .into_data()\n        .assert_eq(&TensorData::from([2.0]), false);\n    indices\n        .into_data()\n        .assert_eq(&TensorData::from([2i64]), false);\n\n    let tensor = TestTensor::<2>::from_data([[5.0, 1.0, 3.0], [2.0, 8.0, 4.0]], &device);\n    // Along dim 1:\n    // Row 0: median = 3, original index = 2\n    // Row 1: median = 4, original index = 2\n    let (values, indices) = tensor.median_with_indices(1);\n    values\n        .into_data()\n        .assert_eq(&TensorData::from([[3.0], [4.0]]), false);\n    indices\n        .into_data()\n        .assert_eq(&TensorData::from([[2i64], [2i64]]), false);\n}\n\n#[test]\nfn test_median_all_elements() {\n    let tensor = TestTensor::<2>::from_data(\n        [\n            [0.5, 1.8, 0.2, -2.0, 1.0],\n            [3.0, -4.0, 5.0, 0.0, -1.0],\n            [5.0, -5.0, 1.0, 3.0, -2.0],\n        ],\n        &Default::default(),\n    );\n\n    // Sorted: [-5, -4, -2, -2, -1, 0, 0.2, 0.5, 1, 1, 1.8, 3, 3, 5, 5]\n    let dims = tensor.dims().len();\n    let flattened_tensor: Tensor<_, 1> = tensor.flatten(0, dims - 1);\n    let result = flattened_tensor.median(0);\n    result\n        .into_data()\n        .assert_eq(&TensorData::from([0.5]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod cov;\nmod display;\nmod eye;\nmod median;\nmod var;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/float/stats/var.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_var() {\n    let tensor = TestTensor::<2>::from_data(\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        &Default::default(),\n    );\n\n    let output = tensor.var(1);\n    let expected = TensorData::from([[2.4892], [15.3333]]).convert::<FloatElem>();\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_var_mean() {\n    let tensor = TestTensor::<2>::from_data(\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        &Default::default(),\n    );\n\n    let (var, mean) = tensor.var_mean(1);\n\n    let var_expected = TensorData::from([[2.4892], [15.3333]]).convert::<FloatElem>();\n    let mean_expected = TensorData::from([[0.125], [1.]]).convert::<FloatElem>();\n\n    var.into_data()\n        .assert_approx_eq::<FloatElem>(&var_expected, Tolerance::default());\n    mean.into_data()\n        .assert_approx_eq::<FloatElem>(&mean_expected, Tolerance::default());\n}\n\n#[test]\nfn test_var_bias() {\n    let tensor = TestTensor::<2>::from_data(\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        &Default::default(),\n    );\n\n    let output = tensor.var_bias(1);\n    let expected = TensorData::from([[1.86688], [11.5]]).convert::<FloatElem>();\n\n    output\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_var_mean_bias() {\n    let tensor = TestTensor::<2>::from_data(\n        [[0.5, 1.8, 0.2, -2.0], [3.0, -4.0, 5.0, 0.0]],\n        &Default::default(),\n    );\n\n    let (var, mean) = tensor.var_mean_bias(1);\n\n    let var_expected = TensorData::from([[1.86688], [11.5]]).convert::<FloatElem>();\n    let mean_expected = TensorData::from([[0.125], [1.]]).convert::<FloatElem>();\n\n    var.into_data()\n        .assert_approx_eq::<FloatElem>(&var_expected, Tolerance::default());\n    mean.into_data()\n        .assert_approx_eq::<FloatElem>(&mean_expected, Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod ops;\nmod primitive;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/abs.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_abs_ops_int() {\n    let tensor = TestTensorInt::<2>::from([[0, -1, 2], [3, 4, -5]]);\n\n    let output = tensor.abs();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 2], [3, 4, 5]]), false);\n}\n\n#[test]\nfn should_support_abs_ops_int_signed_min() {\n    let tensor = TestTensorInt::<2>::from([[IntElem::MIN]]);\n\n    let output = tensor.abs();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[IntElem::MIN]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/add.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_add_d2_int() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 11]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[6, 8, 10], [12, 14, 16]]), false);\n}\n\n#[test]\nfn test_add_broadcast_int() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2]]);\n    let tensor_2 = TestTensorInt::from([[3, 4, 5], [6, 7, 8]]);\n\n    let output = tensor_1 + tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 5, 7], [6, 8, 10]]), false);\n}\n\n#[test]\nfn should_support_add_scalar_ops_int() {\n    let scalar = 2;\n    let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let output = tensor + scalar;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2, 3, 4], [5, 6, 7]]), false);\n}\n\n#[test]\nfn scalar_add_not_contiguous() {\n    let tensor = TestTensorInt::<1>::arange(0..32, &Default::default()).float();\n    let tensor = tensor.reshape([1, 4, 4, 2]).permute([0, 3, 1, 2]);\n\n    let tensor = tensor.slice([0..1, 0..2, 0..4, 0..4]);\n    let before = tensor.clone();\n\n    let after = tensor.add_scalar(0.0);\n\n    before\n        .into_data()\n        .assert_approx_eq::<f32>(&after.into_data(), Default::default());\n}\n\n#[test]\nfn scalar_add_not_contiguous_int() {\n    let tensor = TestTensorInt::<1>::arange(0..32, &Default::default());\n    let tensor = tensor.reshape([1, 4, 4, 2]).permute([0, 3, 1, 2]);\n\n    let tensor = tensor.slice([0..1, 0..2, 0..4, 0..4]);\n    let before = tensor.clone();\n\n    let after = tensor.add_scalar(0);\n\n    before.into_data().assert_eq(&after.into_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/aggregation.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_should_mean_int() {\n    let tensor = TestTensorInt::<2>::from([[2, 2, 2], [3, 4, 5]]);\n\n    let output = tensor.mean();\n\n    output.into_data().assert_eq(&TensorData::from([3]), false);\n}\n\n#[test]\nfn test_should_mean_last_dim_int() {\n    let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let output = tensor.mean_dim(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1], [4]]), false);\n}\n\n#[test]\nfn test_should_sum_last_dim_int() {\n    let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let output = tensor.sum_dim(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3], [12]]), false);\n}\n\n#[test]\nfn test_should_sum_int() {\n    let tensor = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let output = tensor.sum();\n\n    output.into_data().assert_eq(&TensorData::from([15]), false);\n}\n\n#[test]\n#[ignore = \"Not implemented for all backends yet\"]\nfn test_prod_int() {\n    let tensor = TestTensorInt::<2>::from([[2, 1, 2], [3, 4, 5]]);\n    let output = tensor.prod();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([240]), false);\n\n    let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]);\n    let output = tensor_with_zero.prod();\n\n    output.into_data().assert_eq(&TensorData::from([0]), false);\n}\n\n#[test]\n#[ignore = \"Not implemented for all backends yet\"]\nfn test_prod_dim_int() {\n    let tensor = TestTensorInt::<2>::from([[2, 1, 2], [3, 4, 5]]);\n    let output = tensor.prod_dim(1);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[4], [60]]), false);\n\n    let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]);\n    let output = tensor_with_zero.prod_dim(1);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0], [60]]), false);\n\n    // Negative Indexing.\n    let tensor_with_zero = TestTensorInt::<2>::from([[2, 0, 2], [3, 4, 5]]);\n    let output = tensor_with_zero.prod_dim(-1);\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0], [60]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/all.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_all() {\n    let tensor = TestTensorInt::<2>::from([[0, 1, 0], [1, -1, 1]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1]]);\n    let data_actual = tensor.all().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_all_dim() {\n    let tensor = TestTensorInt::<2>::from([[0, 1, 0], [1, -1, 1]]);\n    let data_actual = tensor.all_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/any.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_any() {\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([true]);\n    data_expected.assert_eq(&data_actual, false);\n\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [0, 0, 0]]);\n    let data_actual = tensor.any().into_data();\n    let data_expected = TensorData::from([false]);\n    data_expected.assert_eq(&data_actual, false);\n}\n\n#[test]\nfn test_any_dim() {\n    let tensor = TestTensorInt::<2>::from([[0, 0, 0], [1, -1, 0]]);\n    let data_actual = tensor.any_dim(1).into_data();\n    let data_expected = TensorData::from([[false], [true]]);\n    data_expected.assert_eq(&data_actual, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/arange.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::backend::Backend;\n\n#[test]\nfn test_arange() {\n    let device = <TestBackend as Backend>::Device::default();\n\n    let tensor = TestTensorInt::<1>::arange(2..5, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([2, 3, 4]), false);\n\n    // Test arange with negative numbers\n    let tensor = TestTensorInt::<1>::arange(-10..-5, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([-10, -9, -8, -7, -6]), false);\n\n    let tensor = TestTensorInt::<1>::arange(-3..0, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([-3, -2, -1]), false);\n\n    // Test arange with a mix of positive and negative numbers\n    let tensor = TestTensorInt::<1>::arange(-2..3, &device);\n    tensor\n        .clone()\n        .into_data()\n        .assert_eq(&TensorData::from([-2, -1, 0, 1, 2]), false);\n    assert_eq!(tensor.device(), device);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/arange_step.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::backend::Backend;\n\n#[test]\nfn test_arange_step() {\n    let device = <TestBackend as Backend>::Device::default();\n\n    // Test correct sequence of numbers when the range is 0..9 and the step is 1\n    let tensor = TestTensorInt::<1>::arange_step(0..9, 1, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([0, 1, 2, 3, 4, 5, 6, 7, 8]), false);\n\n    // Test correct sequence of numbers when the range is 0..3 and the step is 2\n    let tensor = TestTensorInt::<1>::arange_step(0..3, 2, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([0, 2]), false);\n\n    // Test correct sequence of numbers when the range is 0..2 and the step is 5\n    let tensor = TestTensorInt::<1>::arange_step(0..2, 5, &device);\n    tensor.into_data().assert_eq(&TensorData::from([0]), false);\n\n    // Test correct sequence of numbers when the range includes negative numbers\n    let tensor = TestTensorInt::<1>::arange_step(-3..3, 2, &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([-3, -1, 1]), false);\n\n    let tensor = TestTensorInt::<1>::arange_step(-5..1, 5, &device);\n    tensor\n        .clone()\n        .into_data()\n        .assert_eq(&TensorData::from([-5, 0]), false);\n    assert_eq!(tensor.device(), device);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_when_step_is_zero() {\n    let device = <TestBackend as Backend>::Device::default();\n    // Test that arange_step panics when the step is 0\n    let _tensor = TestTensorInt::<1>::arange_step(0..3, 0, &device);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/arg.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_argmax_2d_dim0_int() {\n    let tensor = TestTensorInt::<2>::from([[10, 11, 2], [3, 4, 5]]);\n\n    let output = tensor.argmax(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 0, 1]]), false);\n}\n\n#[test]\nfn test_argmin_2d_dim0_int() {\n    let tensor = TestTensorInt::<2>::from([[10, 11, 2], [30, 4, 5]]);\n\n    let output = tensor.argmin(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 0]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/bitwise.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_apply_bitwise_and_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);\n\n    let output = tensor_1.bitwise_and(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_and_1d() {\n    let tensor_1 = TestTensorInt::<1>::from([13, 7]);\n    let tensor_2 = TestTensorInt::from([11, 3]);\n\n    let output = tensor_1.bitwise_and(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([9, 3]), false);\n}\n\n#[test]\nfn should_apply_bitwise_and_scalar_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let scalar = 5;\n\n    let output = tensor_1.bitwise_and_scalar(scalar);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_not_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n\n    let output = tensor_1.bitwise_not();\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_or_scalar_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let scalar = 5;\n\n    let output = tensor_1.bitwise_or_scalar(scalar);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_or_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);\n\n    let output = tensor_1.bitwise_or(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_or_1d() {\n    let tensor_1 = TestTensorInt::<1>::from([13, 7]);\n    let tensor_2 = TestTensorInt::from([11, 3]);\n\n    let output = tensor_1.bitwise_or(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([15, 7]), false);\n}\n\n#[test]\nfn should_apply_bitwise_xor_scalar_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let scalar = 5;\n\n    let output = tensor_1.bitwise_xor_scalar(scalar);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_xor_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);\n\n    let output = tensor_1.bitwise_xor(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_xor_1d() {\n    let tensor_1 = TestTensorInt::<1>::from([13, 7]);\n    let tensor_2 = TestTensorInt::from([11, 3]);\n\n    let output = tensor_1.bitwise_xor(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([6, 4]), false);\n}\n\n#[test]\nfn should_apply_bitwise_left_shift_2d() {\n    if (IntElem::MAX as u32) < 512 {\n        return;\n    }\n\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor_1.bitwise_left_shift(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_left_shift_scalar_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let scalar = 2;\n\n    let output = tensor_1.bitwise_left_shift_scalar(scalar);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_right_shift_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor_1.bitwise_right_shift(tensor_2);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false);\n}\n\n#[test]\nfn should_apply_bitwise_right_shift_scalar_2d() {\n    let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);\n    let scalar = 2;\n\n    let output = tensor_1.bitwise_right_shift_scalar(scalar);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/cartesian_grid.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::backend::Backend;\n\n#[test]\nfn test_cartesian_grid() {\n    let device = <TestBackend as Backend>::Device::default();\n\n    // Test a single element tensor\n    let tensor: TestTensorInt<2> = TestTensorInt::<1>::cartesian_grid([1], &device);\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[0]]), false);\n\n    // Test for a 2x2 tensor\n    let tensor: TestTensorInt<3> = TestTensorInt::<2>::cartesian_grid([2, 2], &device);\n    tensor.into_data().assert_eq(\n        &TensorData::from([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]),\n        false,\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/cast.rs",
    "content": "use super::*;\nuse burn_tensor::{DType, TensorData};\n\n#[test]\nfn cast_int_to_bool() {\n    let tensor1 = TestTensorInt::<2>::from([[0, 43, 0], [2, -4, 31]]);\n    let data_actual = tensor1.bool().into_data();\n    let data_expected = TensorData::from([[false, true, false], [true, true, true]]);\n    data_actual.assert_eq(&data_expected, false);\n}\n\n#[test]\nfn cast_bool_to_int_tensor() {\n    let tensor = TestTensorBool::<2>::from([[true, false, true], [false, false, true]]).int();\n\n    let expected = TensorData::from([[1, 0, 1], [0, 0, 1]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn cast_int_precision() {\n    let data = TensorData::from([[1, 2, 3], [4, 5, 6]]);\n    let tensor = TestTensorInt::<2>::from(data.clone());\n\n    let output = tensor.cast(DType::I32);\n\n    assert_eq!(output.dtype(), DType::I32);\n    output.into_data().assert_eq(&data, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/cat.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_cat_ops_int() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);\n    let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device);\n\n    let output = Tensor::cat(vec![tensor_1, tensor_2], 0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3], [4, 5, 6]]), false);\n}\n\n#[test]\nfn should_support_cat_with_empty_tensor_int() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);\n    let tensor_2: TestTensorInt<2> = TestTensorInt::empty([1, 0], &device);\n\n    let output = Tensor::cat(vec![tensor_1, tensor_2], 1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/chunk.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_chunk_multi_dimension() {\n    let tensors =\n        TestTensorInt::<2>::from_data(TensorData::from([[0, 1, 2, 3]]), &Default::default())\n            .chunk(2, 1);\n    assert_eq!(tensors.len(), 2);\n\n    let expected = [TensorData::from([[0, 1]]), TensorData::from([[2, 3]])];\n\n    for (index, tensor) in tensors.iter().enumerate() {\n        tensor.to_data().assert_eq(&expected[index], false);\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/comparison.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_equal() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_not_equal() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.not_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, true], [true, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_equal_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 2, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().equal_elem(2);\n    let data_actual_inplace = tensor_1.equal_elem(2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_not_equal_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 2, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().not_equal_elem(2);\n    let data_actual_inplace = tensor_1.not_equal_elem(2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn greater_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_elem(4);\n    let data_actual_inplace = tensor_1.greater_elem(4);\n\n    let data_expected = TensorData::from([[false, false, false], [false, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_equal_elem(4);\n    let data_actual_inplace = tensor_1.greater_equal_elem(4);\n\n    let data_expected = TensorData::from([[false, false, false], [false, true, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]);\n\n    let data_actual_cloned = tensor_1.clone().greater(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater(tensor_2);\n\n    let data_expected = TensorData::from([[false, false, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]);\n\n    let data_actual_cloned = tensor_1.clone().greater_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.greater_equal(tensor_2);\n\n    let data_expected = TensorData::from([[false, true, true], [false, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_elem(4);\n    let data_actual_inplace = tensor_1.lower_elem(4);\n\n    let data_expected = TensorData::from([[true, true, true], [true, false, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_equal_elem() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_equal_elem(4);\n    let data_actual_inplace = tensor_1.lower_equal_elem(4);\n\n    let data_expected = TensorData::from([[true, true, true], [true, true, false]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]);\n\n    let data_actual_cloned = tensor_1.clone().lower(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower(tensor_2);\n\n    let data_expected = TensorData::from([[true, false, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_lower_equal() {\n    let tensor_1 = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n    let tensor_2 = TestTensorInt::<2>::from([[1, 1, 1], [4, 3, 50]]);\n\n    let data_actual_cloned = tensor_1.clone().lower_equal(tensor_2.clone());\n    let data_actual_inplace = tensor_1.lower_equal(tensor_2);\n\n    let data_expected = TensorData::from([[true, true, false], [true, false, true]]);\n    data_expected.assert_eq(&data_actual_cloned.into_data(), false);\n    data_expected.assert_eq(&data_actual_inplace.into_data(), false);\n}\n\n#[test]\nfn test_greater_broadcast() {\n    // Test broadcasting with shape [1, 4] vs [4, 4]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1, 2, 3, 4]]);\n    let data_2 = TensorData::from([\n        [0.5, 1.5, 2.5, 3.5],\n        [1.5, 2.5, 3.5, 4.5],\n        [2.5, 3.5, 4.5, 5.5],\n        [3.5, 4.5, 5.5, 6.5],\n    ]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.greater(tensor_2);\n\n    let expected = TensorData::from([\n        [true, true, true, true],\n        [false, false, false, false],\n        [false, false, false, false],\n        [false, false, false, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_greater_equal_broadcast() {\n    // Test broadcasting with shape [4, 1] vs [1, 4]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1], [2], [3], [4]]);\n    let data_2 = TensorData::from([[1, 2, 3, 4]]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.greater_equal(tensor_2);\n\n    let expected = TensorData::from([\n        [true, false, false, false],\n        [true, true, false, false],\n        [true, true, true, false],\n        [true, true, true, true],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_equal_broadcast() {\n    // Test broadcasting with different ranks\n    let device = Default::default();\n    let data_1 = TensorData::from([[2], [3], [4]]);\n    let data_2 = TensorData::from([[2, 3, 4, 2]]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.equal(tensor_2);\n\n    let expected = TensorData::from([\n        [true, false, false, true],\n        [false, true, false, false],\n        [false, false, true, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_not_equal_broadcast() {\n    // Test broadcasting with shape [3, 1] vs [1, 3]\n    let device = Default::default();\n    let data_1 = TensorData::from([[1], [2], [3]]);\n    let data_2 = TensorData::from([[1, 2, 3]]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.not_equal(tensor_2);\n\n    let expected = TensorData::from([\n        [false, true, true],\n        [true, false, true],\n        [true, true, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_int_greater_broadcast() {\n    let device = Default::default();\n    let data_1 = TensorData::from([[1i32, 2, 3]]);\n    let data_2 = TensorData::from([[0i32], [2], [4]]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.greater(tensor_2);\n\n    let expected = TensorData::from([\n        [true, true, true],\n        [false, false, true],\n        [false, false, false],\n    ]);\n    expected.assert_eq(&result.into_data(), false);\n}\n\n#[test]\nfn test_int_lower_equal_broadcast() {\n    let device = Default::default();\n    let data_1 = TensorData::from([[2i32], [4]]);\n    let data_2 = TensorData::from([[1i32, 2, 3]]);\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let result = tensor_1.lower_equal(tensor_2);\n\n    let expected = TensorData::from([[false, true, true], [false, false, false]]);\n    expected.assert_eq(&result.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/create_like.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_zeros_like() {\n    let tensor = TestTensorInt::<3>::from([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]);\n\n    let tensor = tensor.zeros_like();\n    let expected = TensorData::from([[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_ones_like() {\n    let tensor = TestTensorInt::<3>::from([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]);\n\n    let tensor = tensor.ones_like();\n    let expected = TensorData::from([[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1]]]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/cumulative.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_cumsum_int_dim_0() {\n    let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor.cumsum(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3], [5, 7, 9]]), false);\n}\n\n#[test]\nfn test_cumsum_int_dim_1() {\n    let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor.cumsum(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 3, 6], [4, 9, 15]]), false);\n}\n\n#[test]\nfn test_cumprod_int_dim_0() {\n    let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor.cumprod(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3], [4, 10, 18]]), false);\n}\n\n#[test]\nfn test_cumprod_int_dim_1() {\n    let tensor = TestTensorInt::<2>::from([[1, 2, 3], [4, 5, 6]]);\n\n    let output = tensor.cumprod(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 6], [4, 20, 120]]), false);\n}\n\n#[test]\nfn test_cummin_int_dim_0() {\n    let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]);\n\n    let output = tensor.cummin(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 1, 4], [2, 1, 1]]), false);\n}\n\n#[test]\nfn test_cummin_int_dim_1() {\n    let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]);\n\n    let output = tensor.cummin(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 1, 1], [2, 2, 1]]), false);\n}\n\n#[test]\nfn test_cummax_int_dim_0() {\n    let tensor = TestTensorInt::<2>::from([[3, 1, 4], [1, 5, 2]]);\n\n    let output = tensor.cummax(0);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 1, 4], [3, 5, 4]]), false);\n}\n\n#[test]\nfn test_cummax_int_dim_1() {\n    let tensor = TestTensorInt::<2>::from([[3, 1, 4], [1, 5, 2]]);\n\n    let output = tensor.cummax(1);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[3, 3, 4], [1, 5, 5]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/div.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_div_ops_int() {\n    let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let data_2 = TensorData::from([[1, 1, 2], [1, 1, 2]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 / tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 1], [3, 4, 2]]), false);\n}\n\n#[test]\nfn test_div_broadcast_int() {\n    let data_1 = TensorData::from([[0, 1, 2]]);\n    let data_2 = TensorData::from([[1, 1, 2], [3, 4, 5]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 / tensor_2;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 1, 1], [0, 0, 0]]), false);\n}\n\n#[test]\nfn should_support_div_scalar_ops_int() {\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let scalar = 2;\n    let tensor = TestTensorInt::<2>::from_data(data, &Default::default());\n\n    let output = tensor / scalar;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 0, 1], [1, 2, 2]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/expand.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn expand_2d_int() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let output = tensor.expand([3, 3]);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3], [1, 2, 3], [1, 2, 3]]), false);\n}\n\n#[test]\nfn should_all_negative_one() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let output = tensor.expand([2, -1]);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 2, 3], [1, 2, 3]]), false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_negative_one_on_non_existing_dim() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let _expanded_tensor = tensor.expand([-1, 3]);\n}\n\n/// Regression test for https://github.com/tracel-ai/burn/issues/2091\n#[test]\nfn inplace_op_after_expand() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let mut output = tensor.expand([2, 3]);\n    output = output + 1;\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([[2, 3, 4], [2, 3, 4]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/flip.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn flip_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    let flipped = tensor.clone().flip([0, 2]);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2))\n    let expected = TensorData::from([\n        [[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]],\n        [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]],\n    ]);\n\n    flipped.into_data().assert_eq(&expected, false);\n\n    // Test with no flip\n    let flipped = tensor.clone().flip([]);\n    assert_eq!(tensor.into_data(), flipped.into_data());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/full.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_tensor_full() {\n    let device = Default::default();\n    let int_tensor = TestTensorInt::<2>::full([2, 2], 2, &device);\n    int_tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[2, 2], [2, 2]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/gather_scatter.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_gather_1d_dim0_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_ints([5, 6, 7], &device);\n    let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.gather(0, indices);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([6, 6, 5, 6, 7]), false);\n}\n\n#[test]\nfn should_gather_indices_broadcasted() {\n    let device = Default::default();\n\n    let batch_size = 3;\n    let fft_size = 4;\n    let shape = [batch_size, fft_size, 2];\n    let x = TestTensorInt::arange(\n        0..shape.iter().product::<usize>() as i64,\n        &Default::default(),\n    )\n    .reshape(shape);\n    let idx = TestTensorInt::<1>::from_ints([0, 2, 1, 3], &device);\n\n    let expected = TestTensorInt::<3>::from([\n        [[0, 1], [4, 5], [2, 3], [6, 7]],\n        [[8, 9], [12, 13], [10, 11], [14, 15]],\n        [[16, 17], [20, 21], [18, 19], [22, 23]],\n    ])\n    .into_data();\n\n    // Case 1: gather dim 2\n    let perm = idx\n        .clone()\n        .reshape([1, 1, fft_size])\n        .repeat_dim(0, batch_size)\n        .repeat_dim(1, 2);\n\n    let input = x.clone().permute([0, 2, 1]);\n    let out = input.gather(2, perm).permute([0, 2, 1]);\n\n    out.into_data().assert_eq(&expected, true);\n\n    // Case 2: gather directly on dim 1\n    let perm = idx.reshape([1, fft_size, 1]).repeat_dim(0, batch_size);\n    let out2 = x.gather(1, perm.repeat_dim(2, 2));\n\n    out2.into_data().assert_eq(&expected, true);\n}\n\n#[test]\nfn should_scatter_add_1d_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_ints([0, 0, 0], &device);\n    let values = TestTensorInt::from_ints([5, 4, 3], &device);\n    let indices = TestTensorInt::from_ints([1, 0, 2], &device);\n\n    let output = tensor.scatter(0, indices, values, IndexingUpdateOp::Add);\n\n    output\n        .into_data()\n        .assert_eq(&TensorData::from([4, 5, 3]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/init.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_int_empty() {\n    let shape = [2, 2];\n    let tensor = TestTensorInt::<2>::empty(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into())\n}\n\n#[test]\nfn should_support_int_zeros() {\n    let shape = [2, 2];\n    let tensor = TestTensorInt::<2>::zeros(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[0, 0], [0, 0]]), false);\n}\n\n#[test]\nfn should_support_int_ones() {\n    let shape = [2, 2];\n    let tensor = TestTensorInt::<2>::ones(shape, &Default::default());\n    assert_eq!(tensor.shape(), shape.into());\n\n    tensor\n        .into_data()\n        .assert_eq(&TensorData::from([[1, 1], [1, 1]]), false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/mask.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_mask_where_broadcast_int() {\n    let device = Default::default();\n    // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times\n    let tensor = TestTensorInt::<1>::arange(2..6, &device).reshape([1, 2, 2]);\n    let mask = TestTensorBool::<3>::from_bool(\n        TensorData::from([\n            [[true, false], [false, true]],\n            [[false, true], [true, false]],\n            [[false, false], [false, false]],\n            [[true, true], [true, true]],\n        ]),\n        &device,\n    );\n    let value = TestTensorInt::<3>::ones([4, 2, 2], &device);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([\n        [[1, 3], [4, 1]],\n        [[2, 1], [1, 5]],\n        [[2, 3], [4, 5]],\n        [[1, 1], [1, 1]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_int_mask_where_ops() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n    let value = TestTensorInt::<2>::from_data(TensorData::from([[8, 9], [10, 11]]), &device);\n\n    let output = tensor.mask_where(mask, value);\n    let expected = TensorData::from([[8, 7], [2, 11]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_int_mask_fill_ops() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<2>::from_data([[1, 7], [2, 3]], &device);\n    let mask =\n        TestTensorBool::<2>::from_bool(TensorData::from([[true, false], [false, true]]), &device);\n\n    let output = tensor.mask_fill(mask, 9);\n    let expected = TensorData::from([[9, 7], [2, 9]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/matmul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_int_matmul_d2() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_ints([[1, 7], [2, 3], [1, 5]], &device);\n    let tensor_2 = TestTensorInt::<2>::from_ints([[4, 7, 5], [2, 3, 5]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[18, 28, 40], [14, 23, 25], [14, 22, 30]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_d3() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device);\n    let tensor_2 = TestTensorInt::<3>::from_ints([[[4, 7], [2, 3]]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[[18, 28], [14, 23]]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_broadcast_1() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<3>::from_ints([[[1, 7], [2, 3]]], &device);\n    let tensor_2 = TestTensorInt::from_ints([[[4, 7], [2, 3]], [[2, 5], [6, 3]]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[[18, 28], [14, 23]], [[44, 26], [22, 19]]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_broadcast_4d() {\n    let device = Default::default();\n    // [2, 1, 2, 2]\n    let tensor_1 = TestTensorInt::<4>::from_ints([[[[1, 7], [2, 3]]], [[[2, 5], [6, 3]]]], &device);\n    // [1, 2, 2, 2]\n    let tensor_2 = TestTensorInt::from_ints([[[[9, 8], [1, 4]], [[2, 7], [3, 5]]]], &device);\n\n    // [2, 1, 2, 2] @ [1, 2, 2, 2] -> [2, 2, 2, 2]\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [[[16, 36], [21, 28]], [[23, 42], [13, 29]]],\n        [[[23, 36], [57, 60]], [[19, 39], [21, 57]]],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_simple_1() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_ints([[5, 14], [14, 25]], &device);\n    let tensor_2 = TestTensorInt::from_ints([[3, 4, 5], [0, 1, 2]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[15, 34, 53], [42, 81, 120]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_4_3() {\n    if (IntElem::MAX as u32) < 324 {\n        return;\n    }\n\n    let device = Default::default();\n    let tensor_1 =\n        TestTensorInt::<2>::from_ints([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], &device);\n    let tensor_2 =\n        TestTensorInt::from_ints([[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[56, 62, 68], [152, 174, 196], [248, 286, 324]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_trivial() {\n    if (IntElem::MAX as u32) < 506 {\n        return;\n    }\n\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]);\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1);\n\n    tensor_3.into_data().assert_eq(\n        &TensorData::from([\n            [56, 62, 68, 74],\n            [152, 174, 196, 218],\n            [248, 286, 324, 362],\n            [344, 398, 452, 506],\n        ]),\n        false,\n    );\n}\n\n#[test]\nfn test_int_matmul_trivial_transposed() {\n    if (IntElem::MAX as u32) < 734 {\n        return;\n    }\n\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..16, &device).reshape([4, 4]);\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());\n\n    tensor_3.into_data().assert_eq(\n        &TensorData::from([\n            [14, 38, 62, 86],\n            [38, 126, 214, 302],\n            [62, 214, 366, 518],\n            [86, 302, 518, 734],\n        ]),\n        false,\n    );\n}\n\n#[test]\nfn test_int_matmul_4_8() {\n    if (IntElem::MAX as u32) < 6092 {\n        return;\n    }\n\n    let device = Default::default();\n\n    let tensor_1 = TestTensorInt::<1>::arange(0..32, &device).reshape([4, 8]);\n\n    let tensor_3 = tensor_1.clone().matmul(tensor_1.transpose());\n\n    tensor_3.into_data().assert_eq(\n        &TensorData::from([\n            [140, 364, 588, 812],\n            [364, 1100, 1836, 2572],\n            [588, 1836, 3084, 4332],\n            [812, 2572, 4332, 6092],\n        ]),\n        false,\n    );\n}\n\n#[test]\nfn test_int_matmul_simple_2() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_ints([[1, 2, 3, 4]], &device);\n    let tensor_2 = TestTensorInt::from_ints([[3], [4], [5], [6]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([[50]]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_int_matmul_simple_3() {\n    let device = Default::default();\n    let tensor_1 =\n        TestTensorInt::<2>::from_ints([[3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]], &device);\n    let tensor_2 = TestTensorInt::from_ints([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [9, 18, 27, 36],\n        [12, 24, 36, 48],\n        [15, 30, 45, 60],\n        [18, 36, 54, 72],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn int_should_panic_when_inner_dimensions_are_not_equal() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_ints([[3, 3], [4, 4], [5, 5], [6, 6]], &device);\n    let tensor_2 = TestTensorInt::from_ints([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], &device);\n\n    let tensor_3 = tensor_1.matmul(tensor_2);\n    let expected = TensorData::from([\n        [9, 18, 27, 36],\n        [12, 24, 36, 48],\n        [15, 30, 45, 60],\n        [18, 36, 54, 72],\n    ]);\n\n    tensor_3.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod abs;\nmod add;\nmod aggregation;\nmod all;\nmod any;\nmod arange;\nmod arange_step;\nmod bitwise;\nmod cartesian_grid;\nmod cast;\nmod cat;\nmod chunk;\nmod comparison;\nmod create_like;\nmod cumulative;\nmod div;\nmod expand;\nmod flip;\nmod full;\nmod gather_scatter;\nmod init;\nmod mask;\nmod matmul;\nmod movedim;\nmod mul;\nmod one_hot;\nmod permute;\nmod random;\nmod remainder;\nmod repeat;\nmod repeat_dim;\nmod reshape;\nmod roll;\nmod select;\nmod sign;\nmod slice;\nmod slice_assign;\nmod sort_argsort;\nmod stack;\nmod sub;\nmod take;\nmod topk;\nmod transpose;\nmod tri;\nmod unfold;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/movedim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn movedim_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    let permuted = tensor.clone().movedim(0, 2);\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim(0, 2)\n    let expected = TensorData::from([\n        [[0, 12], [1, 13], [2, 14], [3, 15]],\n        [[4, 16], [5, 17], [6, 18], [7, 19]],\n        [[8, 20], [9, 21], [10, 22], [11, 23]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().movedim(0, -1);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().movedim(0, 0);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\nfn vec_input_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![1, 0]);\n    // from pytorch\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).movedim([0, 1], [1, 0])\n    let expected = TensorData::from([\n        [[0, 1, 2, 3], [12, 13, 14, 15]],\n        [[4, 5, 6, 7], [16, 17, 18, 19]],\n        [[8, 9, 10, 11], [20, 21, 22, 23]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axes\n    let permuted = tensor.clone().movedim(vec![-3, -2], vec![-2, -3]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axes\n    let permuted = tensor.clone().movedim(vec![0, 1], vec![0, 1]);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/mul.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_mul_ops_int() {\n    let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let data_2 = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0, 1, 4], [9, 16, 25]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_mul_broadcast_int() {\n    let data_1 = TensorData::from([[0, 1, 2]]);\n    let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 * tensor_2;\n    let expected = TensorData::from([[0, 4, 10], [0, 7, 16]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_mul_scalar_ops_int() {\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let scalar = 2;\n    let tensor = TestTensorInt::<2>::from_data(data, &Default::default());\n\n    let output = tensor * scalar;\n    let expected = TensorData::from([[0, 2, 4], [6, 8, 10]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/one_hot.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn int_should_support_one_hot() {\n    let tensor = TestTensorInt::<1>::from([0, 1, 4]);\n    let one_hot_tensor: TestTensorInt<2> = tensor.one_hot(5);\n    let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]);\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn int_one_hot_should_panic_when_index_exceeds_number_of_classes() {\n    let tensor = TestTensorInt::<1>::from([5]);\n    let _result: TestTensorInt<2> = tensor.one_hot(5);\n}\n\n#[test]\n#[should_panic]\nfn int_one_hot_should_panic_when_number_of_classes_is_zero() {\n    let tensor = TestTensorInt::<1>::from([2]);\n    let _result: TestTensorInt<2> = tensor.one_hot(0);\n}\n\n#[test]\nfn one_hot_fill_with_positive_axis_and_indices() {\n    let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]);\n    let expected = TensorData::from([\n        [\n            [1, 1],\n            [3, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 3],\n        ],\n        [\n            [1, 1],\n            [1, 1],\n            [3, 1],\n            [1, 1],\n            [1, 3],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n            [1, 1],\n        ],\n    ]);\n\n    let one_hot_tensor: TestTensorInt<3> = tensor.one_hot_fill(10, 3.0, 1.0, 1);\n\n    one_hot_tensor.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/permute.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn permute_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    let permuted = tensor.clone().permute([2, 1, 0]);\n\n    // from pytorch:\n    // import torch; torch.arange(0, 24).reshape(2, 3, 4).permute(2, 1, 0)\n    let expected = TensorData::from([\n        [[0, 12], [4, 16], [8, 20]],\n        [[1, 13], [5, 17], [9, 21]],\n        [[2, 14], [6, 18], [10, 22]],\n        [[3, 15], [7, 19], [11, 23]],\n    ]);\n\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with negative axis\n    let permuted = tensor.clone().permute([-1, 1, 0]);\n    permuted.into_data().assert_eq(&expected, false);\n\n    // Test with the same axis\n    let permuted = tensor.clone().permute([0, 1, 2]);\n    permuted.into_data().assert_eq(&tensor.into_data(), true);\n}\n\n#[test]\n#[should_panic]\nfn edge_repeated_axes() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with a repeated axis\n    let _ = tensor.clone().permute([0, 0, 1]);\n}\n\n#[test]\n#[should_panic]\nfn edge_out_of_bound_axis() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);\n\n    // Test with a repeated axis\n    let _ = tensor.clone().permute([3, 0, 1]);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/random.rs",
    "content": "use super::*;\nuse burn_tensor::{Distribution, ElementConversion};\n\n#[test]\nfn rand_uniform_int() {\n    let low = 0.;\n    let high = 5.;\n\n    let tensor = TestTensorInt::<1>::random(\n        [100_000],\n        Distribution::Uniform(low, high),\n        &Default::default(),\n    );\n\n    tensor\n        .into_data()\n        .assert_within_range::<IntElem>(low.elem()..high.elem());\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/remainder.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_int_remainder_basic() {\n    let data = TensorData::from([-3, -2, -1, 1, 2, 3]);\n    let device = Default::default();\n    let lhs = TestTensorInt::<1>::from_data(data, &device);\n\n    let rhs = TestTensorInt::from_data(TensorData::from([2, 3, 1, 2, 1, 3]), &device);\n    let output = lhs.remainder(rhs);\n    let expected = TensorData::from([1, 1, -0, 1, 0, 0]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_int_remainder_basic_scalar() {\n    let data = TensorData::from([-3, -2, -1, 1, 2, 3]);\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data(data, &device);\n\n    let output = tensor.remainder_scalar(2);\n    let expected = TensorData::from([1, 0, 1, 1, 0, 1]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/repeat.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_int_repeat_ops_one_dimension() {\n    let data = TensorData::from([[0i32, 1i32, 2i32]]);\n    let tensor = TestTensorInt::<2>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[4, 1, 1]);\n    let expected = TensorData::from([\n        [0i32, 1i32, 2i32],\n        [0i32, 1i32, 2i32],\n        [0i32, 1i32, 2i32],\n        [0i32, 1i32, 2i32],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_int_repeat_on_many_dims() {\n    let data = TensorData::from([\n        [[1i32, 2i32], [3i32, 4i32]],\n        [[5i32, 6i32], [7i32, 8i32]],\n        [[9i32, 10i32], [11i32, 12i32]],\n        [[13i32, 14i32], [15i32, 16i32]],\n    ]);\n    let tensor = TestTensorInt::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat(&[2, 3, 2]);\n\n    let expected = TensorData::from([\n        [\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n        ],\n        [\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n        ],\n        [\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n        ],\n        [\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n        ],\n        [\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n            [1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32],\n        ],\n        [\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n            [5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32],\n        ],\n        [\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n            [9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32],\n        ],\n        [\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n            [13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/repeat_dim.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_int_repeat_ops() {\n    let data = TensorData::from([[0, 1, 2]]);\n    let tensor = TestTensorInt::<2>::from_data(data, &Default::default());\n\n    let output = tensor.repeat_dim(0, 4);\n    let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_int_repeat_on_dims_larger_than_1() {\n    let data = TensorData::from([\n        [[1i32, 2i32], [3i32, 4i32]],\n        [[5i32, 6i32], [7i32, 8i32]],\n        [[9i32, 10i32], [11i32, 12i32]],\n        [[13i32, 14i32], [15i32, 16i32]],\n    ]);\n    let tensor = TestTensorInt::<3>::from_data(data, &Default::default());\n\n    let output = tensor.repeat_dim(2, 3);\n    let expected = TensorData::from([\n        [\n            [1i32, 2i32, 1i32, 2i32, 1i32, 2i32],\n            [3i32, 4i32, 3i32, 4i32, 3i32, 4i32],\n        ],\n        [\n            [5i32, 6i32, 5i32, 6i32, 5i32, 6i32],\n            [7i32, 8i32, 7i32, 8i32, 7i32, 8i32],\n        ],\n        [\n            [9i32, 10i32, 9i32, 10i32, 9i32, 10i32],\n            [11i32, 12i32, 11i32, 12i32, 11i32, 12i32],\n        ],\n        [\n            [13i32, 14i32, 13i32, 14i32, 13i32, 14i32],\n            [15i32, 16i32, 15i32, 16i32, 15i32, 16i32],\n        ],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/reshape.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_reshape_maybe_fused_1() {\n    let tensor = TestTensorInt::arange(0..32, &Default::default());\n    let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default());\n    let tensor1 = tensor.clone().reshape([1, 4, 8]);\n    let output = tensor0 + tensor1;\n\n    let expected = TensorData::from([\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n        [\n            [0, 1, 2, 3, 4, 5, 6, 7],\n            [8, 9, 10, 11, 12, 13, 14, 15],\n            [16, 17, 18, 19, 20, 21, 22, 23],\n            [24, 25, 26, 27, 28, 29, 30, 31],\n        ],\n    ]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_reshape_maybe_fused_2() {\n    let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());\n    let tensor1 = tensor.reshape([2, 2, 1]);\n    let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default());\n    let output = tensor2 + tensor1;\n\n    let expected_tensor1 =\n        TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]);\n    output.into_data().assert_eq(&expected_tensor1, false);\n}\n\n#[test]\nfn should_support_reshape_maybe_fused_3() {\n    let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());\n    let tensor1 = tensor.reshape([2, 2, 1]);\n    let _tensor2 = TestTensorInt::<3>::full([2, 2, 3], 5, &Default::default());\n\n    let expected_tensor1 = TensorData::from([[[0], [2]], [[1], [2]]]);\n    tensor1.into_data().assert_eq(&expected_tensor1, false);\n}\n\n#[test]\nfn should_support_reshape_maybe_fused_4() {\n    let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());\n    let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default());\n    let tensor2 = tensor2.swap_dims(0, 1);\n    let tensor1 = tensor.reshape([2, 2, 1]);\n    let output = tensor2 + tensor1;\n\n    let expected_tensor1 =\n        TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]);\n    output.into_data().assert_eq(&expected_tensor1, false);\n}\n\n#[test]\nfn should_support_reshape_maybe_fused_5() {\n    let tensor = TestTensorInt::<3>::from_data([[[0], [1], [2], [3]]], &Default::default());\n    let tensor1 = tensor.clone().reshape([2, 1, 2]);\n    let tensor2 = TestTensorInt::<3>::full([2, 4, 2], 0, &Default::default());\n    let output = tensor2.clone() + tensor1 + tensor.clone();\n\n    let expected_tensor1 = TensorData::from([\n        [[0, 1], [1, 2], [2, 3], [3, 4]],\n        [[2, 3], [3, 4], [4, 5], [5, 6]],\n    ]);\n    output.into_data().assert_eq(&expected_tensor1, false);\n}\n\n#[test]\nfn should_support_reshape_maybe_fused_6() {\n    let device = Default::default();\n\n    let tensor1 = TestTensorInt::arange(0..32, &device);\n    let tensor1 = tensor1.reshape([2, 4, 4]);\n\n    let tensor2 = TestTensorInt::arange(0..16, &device);\n    let tensor2 = tensor2.reshape([1, 4, 4]);\n\n    let tensor3 = TestTensorInt::arange(0..8, &device);\n    let tensor3 = tensor3.reshape([4, 1, 2]);\n    let tensor3 = tensor3.swap_dims(0, 2);\n\n    let out = tensor1 + tensor2 + tensor3;\n\n    let expected = TensorData::from([\n        [\n            [0, 4, 8, 12],\n            [8, 12, 16, 20],\n            [16, 20, 24, 28],\n            [24, 28, 32, 36],\n        ],\n        [\n            [17, 21, 25, 29],\n            [25, 29, 33, 37],\n            [33, 37, 41, 45],\n            [41, 45, 49, 53],\n        ],\n    ]);\n    out.to_data().assert_eq(&expected, false);\n}\n\n// Skip on metal - cubecl autotune error\n// Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4327\n#[cfg(not(feature = \"metal\"))]\n#[test]\nfn should_support_multiple_reshapes_cloned_tensor() {\n    let device = Default::default();\n\n    let lhs = TestTensorInt::<1>::arange(0..4, &device).reshape([2, 2]);\n    // fusion should preserve correct strides when operating on the same tensor\n    let rhs = lhs.clone();\n\n    let lhs = lhs.reshape([2, 2, 1]);\n    let rhs = rhs.reshape([1, 2, 2]);\n\n    let p = lhs.mul(rhs);\n\n    let s = p.sum_dim(1);\n\n    let out = s.reshape([2, 2]);\n\n    out.into_data()\n        .assert_eq(&TensorData::from([[2, 3], [6, 11]]), false);\n}\n\n#[test]\nfn should_support_reshape_int() {\n    let data = TensorData::from([0, 1, 2]);\n    let tensor = TestTensorInt::<1>::from_data(data, &Default::default());\n\n    let output = tensor.clone().reshape([1, 3]);\n    let expected = TensorData::from([[0, 1, 2]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/roll.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[ignore = \"0 size resources are not yet supported\"]\n#[test]\nfn test_roll_empty() {\n    let device = Default::default();\n    let input = TestTensorInt::<2>::zeros([12, 0], &device);\n\n    let result = input.clone().roll(&[1, 2], &[0, 1]);\n\n    assert_eq!(&*result.shape(), &[12, 0]);\n\n    // TODO: Rolling an empty tensor should return the same empty tensor;\n    // but we have no way to compare tensor references yet.\n}\n\n#[test]\nfn test_roll() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // No-op shift:\n    input\n        .clone()\n        .roll(&[0, 0], &[0, 1])\n        .to_data()\n        .assert_eq(&input.clone().to_data(), false);\n\n    input\n        .clone()\n        .roll(&[1, -1], &[0, 1])\n        .to_data()\n        .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false);\n\n    input\n        .clone()\n        .roll(&[-1, 1], &[1, 0])\n        .to_data()\n        .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false);\n\n    input\n        .clone()\n        .roll(&[2 * 32 + 1, 3 * (-400) - 1], &[0, 1])\n        .to_data()\n        .assert_eq(&TensorData::from([[5, 3, 4], [2, 0, 1]]), false);\n}\n\n#[should_panic]\n#[test]\nfn test_roll_dim_too_big() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // Attempting to roll on a dimension that doesn't exist should panic\n    let _d = input.roll(&[1], &[2]);\n}\n\n#[should_panic]\n#[test]\nfn test_roll_dim_too_small() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // Attempting to roll on a dimension that doesn't exist should panic\n    let _d = input.roll(&[1], &[-3]);\n}\n\n#[should_panic]\n#[test]\nfn test_roll_shift_size_mismatch() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // Attempting to roll with a shift size that doesn't match the number of dimensions should panic\n    let _d = input.roll(&[1, 2], &[0]);\n}\n\n#[test]\nfn test_roll_dim() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    input\n        .clone()\n        .roll_dim(1, 0)\n        .to_data()\n        .assert_eq(&TensorData::from([[3, 4, 5], [0, 1, 2]]), false);\n\n    input\n        .clone()\n        .roll_dim(-1, 1)\n        .to_data()\n        .assert_eq(&TensorData::from([[2, 0, 1], [5, 3, 4]]), false);\n}\n\n#[should_panic]\n#[test]\nfn test_roll_dim_dim_too_big() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // Attempting to roll on a dimension that doesn't exist should panic\n    let _d = input.roll_dim(1, 2);\n}\n\n#[should_panic]\n#[test]\nfn test_roll_dim_dim_too_small() {\n    let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);\n\n    // Attempting to roll on a dimension that doesn't exist should panic\n    let _d = input.roll_dim(1, -3);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/select.rs",
    "content": "use super::*;\nuse burn_tensor::{IndexingUpdateOp, TensorData};\n\n#[test]\nfn should_select_1d_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([5, 6, 7], &device);\n    let indices = TestTensorInt::from_data([1, 1, 0, 1, 2], &device);\n\n    let output = tensor.select(0, indices);\n    let expected = TensorData::from([6, 6, 5, 6, 7]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_select_add_1d_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([7, 8, 9], &device);\n    let values = TestTensorInt::from_data([5, 4, 3, 2, 1], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1, 1, 0, 1, 2]), &device);\n\n    let output = tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n    let expected = TensorData::from([10, 19, 10]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn should_panic_select_add_invalid_num_indices() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([0; 12], &device);\n    let values = TestTensorInt::from_data([1; 12], &device);\n    let indices = TestTensorInt::from_data(TensorData::from([1]), &device);\n\n    tensor.select_assign(0, indices, values, IndexingUpdateOp::Add);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/sign.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_sign_ops_int() {\n    let tensor = TestTensorInt::<2>::from([[-2, -1, 2], [3, 0, -5]]);\n\n    let output = tensor.sign();\n    let expected = TensorData::from([[-1, -1, 1], [1, 0, -1]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/slice.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, s};\n\n#[test]\nfn slice_should_not_corrupt_potentially_inplace_operations() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);\n    let tensor = tensor.clone().slice([0..3]) + tensor.clone().slice([2..5]);\n\n    let expected = TensorData::from([4, 6, 8]);\n\n    tensor.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_int_tensor_with_steps() {\n    let device = Default::default();\n    let tensor =\n        TestTensorInt::<2>::from_data([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], &device);\n\n    // Test step=2 along first dimension\n    let sliced = tensor.clone().slice([s![0..3;2]]);\n    let expected = TensorData::from([[1i32, 2, 3, 4], [9, 10, 11, 12]]);\n    sliced.into_data().assert_eq(&expected, false);\n\n    // Test step=-1 along second dimension\n    let sliced = tensor.clone().slice(s![.., 0..4;-1]);\n    let expected = TensorData::from([[4i32, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]]);\n    sliced.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/slice_assign.rs",
    "content": "use super::*;\nuse burn_tensor::{TensorData, s};\n\n#[test]\nfn slice_assign_should_not_corrupt_potentially_inplace_operations() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device);\n    let values = TestTensorInt::<1>::from_data([10, 20, 30], &device);\n    let tensor_1 = tensor.clone().slice_assign([0..3], values);\n    let tensor_2 = tensor + 2;\n\n    let expected = TensorData::from([10, 20, 30, 4, 5]);\n\n    tensor_1.into_data().assert_eq(&expected, false);\n\n    let expected = TensorData::from([3, 4, 5, 6, 7]);\n\n    tensor_2.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_slice_assign_int_tensor_with_steps() {\n    let device = Default::default();\n    let tensor =\n        TestTensorInt::<2>::from_data([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], &device);\n\n    // Test step=2 along first dimension\n    let values =\n        TestTensorInt::<2>::from_data([[100, 101, 102, 103], [200, 201, 202, 203]], &device);\n    let output = tensor.clone().slice_assign([s![0..3;2]], values);\n    let expected = TensorData::from([[100i32, 101, 102, 103], [5, 6, 7, 8], [200, 201, 202, 203]]);\n    output.into_data().assert_eq(&expected, false);\n\n    // Test step=-1 along second dimension\n    let values = TestTensorInt::<2>::from_data(\n        [[40, 30, 20, 10], [80, 70, 60, 50], [120, 110, 100, 90]],\n        &device,\n    );\n    let output = tensor.slice_assign(s![.., 0..4;-1], values);\n    let expected = TensorData::from([[10i32, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]);\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_slice_assign_empty_range_int() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::from_data([1, 2, 3, 4, 5], &device);\n    let values: TestTensorInt<1> = TestTensorInt::empty([0], &device);\n\n    // Empty slice assignment for int tensor\n    let output = tensor.clone().slice_assign([3..3], values);\n    let expected = TensorData::from([1i32, 2, 3, 4, 5]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/sort_argsort.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_sort_1d_int() {\n    // Skip with u8\n    if (IntElem::MAX as u32) < 1000u32 {\n        return;\n    }\n\n    let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, 2, 8, -10, 42, 1000]);\n\n    // Sort along dim=0\n    let values = tensor.sort(0);\n    let values_expected = TensorData::from([-10, 0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 42, 1000]);\n\n    values.into_data().assert_eq(&values_expected, false);\n}\n\n#[test]\nfn test_argsort_1d_int() {\n    // Skip with u8\n    if (IntElem::MAX as u32) < 1000u32 {\n        return;\n    }\n\n    let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]);\n\n    // Sort along dim=0\n    let indices = tensor.argsort(0);\n    let indices_expected = TensorData::from([10, 7, 0, 3, 6, 1, 4, 5, 2, 9, 8, 11, 12]);\n\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_with_indices_descending_int() {\n    // Skip with u8\n    if (IntElem::MAX as u32) >= 1000u32 {\n        // 1D\n        let tensor = TestTensorInt::<1>::from([1, 4, 7, 2, 5, 6, 3, 0, 9, 8, -10, 42, 1000]);\n\n        // Sort along dim=0\n        let (values, indices) = tensor.sort_descending_with_indices(0);\n\n        let values_expected = TensorData::from([1000, 42, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -10]);\n        values.into_data().assert_eq(&values_expected, false);\n\n        let indices_expected = TensorData::from([12, 11, 8, 9, 2, 5, 4, 1, 6, 3, 0, 7, 10]);\n        indices.into_data().assert_eq(&indices_expected, false);\n    }\n\n    // 2D\n    let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.sort_descending_with_indices(1);\n\n    let values_expected = TensorData::from([[[2, 5, 7], [1, 4, 6]], [[8, 2, 9], [3, 0, 8]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    let indices_expected = TensorData::from([[[1, 1, 0], [0, 0, 1]], [[1, 1, 0], [0, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_int() {\n    let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);\n\n    // Sort along dim=0\n    let values = tensor.clone().sort(0);\n\n    let values_expected = TensorData::from([[[1, 0, 7], [2, 2, 6]], [[3, 4, 9], [8, 5, 8]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    // Sort along dim=1\n    let values = tensor.clone().sort(1);\n\n    let values_expected = TensorData::from([[[1, 4, 6], [2, 5, 7]], [[3, 0, 8], [8, 2, 9]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    // Sort along dim=2\n    let values = tensor.sort(2);\n\n    let values_expected = TensorData::from([[[1, 4, 7], [2, 5, 6]], [[0, 3, 9], [2, 8, 8]]]);\n    values.into_data().assert_eq(&values_expected, false);\n}\n\n#[test]\nfn test_sort_with_indices_int() {\n    let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [7, 2, 8]]]);\n\n    // Sort along dim=0\n    let (values, indices) = tensor.clone().sort_with_indices(0);\n\n    let values_expected = TensorData::from([[[1, 0, 7], [2, 2, 6]], [[3, 4, 9], [7, 5, 8]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    let indices_expected = TensorData::from([[[0, 1, 0], [0, 1, 0]], [[1, 0, 1], [1, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let (values, indices) = tensor.clone().sort_with_indices(1);\n\n    let values_expected = TensorData::from([[[1, 4, 6], [2, 5, 7]], [[3, 0, 8], [7, 2, 9]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    let indices_expected = TensorData::from([[[0, 0, 1], [1, 1, 0]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let (values, indices) = tensor.sort_with_indices(2);\n\n    let values_expected = TensorData::from([[[1, 4, 7], [2, 5, 6]], [[0, 3, 9], [2, 7, 8]]]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    let indices_expected = TensorData::from([[[0, 1, 2], [0, 1, 2]], [[1, 0, 2], [1, 0, 2]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_argsort_int() {\n    let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [7, 2, 8]]]);\n\n    // Sort along dim=0\n    let indices = tensor.clone().argsort(0);\n\n    let indices_expected = TensorData::from([[[0, 1, 0], [0, 1, 0]], [[1, 0, 1], [1, 0, 1]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=1\n    let indices = tensor.clone().argsort(1);\n\n    let indices_expected = TensorData::from([[[0, 0, 1], [1, 1, 0]], [[0, 0, 1], [1, 1, 0]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n\n    // Sort along dim=2\n    let indices = tensor.argsort(2);\n\n    let indices_expected = TensorData::from([[[0, 1, 2], [0, 1, 2]], [[1, 0, 2], [1, 0, 2]]]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n\n#[test]\nfn test_sort_descending_1d() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);\n\n    // Sort along dim=0\n    let values = tensor.sort_descending(0);\n\n    let values_expected = TensorData::from([5, 4, 3, 2, 1]);\n    values.into_data().assert_eq(&values_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/stack.rs",
    "content": "use super::*;\nuse alloc::vec;\nuse burn_tensor::{Tensor, TensorData};\n\n#[test]\nfn should_support_stack_ops_int() {\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data([[1, 2, 3]], &device);\n    let tensor_2 = TestTensorInt::<2>::from_data([[4, 5, 6]], &device);\n\n    let output = Tensor::stack::<3>(vec![tensor_1, tensor_2], 0);\n    let expected = TensorData::from([[[1, 2, 3]], [[4, 5, 6]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_generate_row_major_layout() {\n    let device = Default::default();\n    let tensor = TestTensorInt::<1>::arange(1..25, &device).reshape([4, 6]);\n    let zeros = TestTensorInt::zeros([4, 6], &device);\n    let intersperse =\n        Tensor::stack::<3>([tensor.clone(), zeros.clone()].to_vec(), 2).reshape([4, 12]);\n\n    let expected = TensorData::from([\n        [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0],\n        [7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0],\n        [13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0],\n        [19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0],\n    ]);\n\n    intersperse.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/sub.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_sub_ops_int() {\n    let data_1 = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let data_2 = TensorData::from([[6, 7, 8], [9, 10, 11]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-6, -6, -6], [-6, -6, -6]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_sub_broadcast_int() {\n    let data_1 = TensorData::from([[0, 1, 2]]);\n    let data_2 = TensorData::from([[3, 4, 5], [6, 7, 8]]);\n    let device = Default::default();\n    let tensor_1 = TestTensorInt::<2>::from_data(data_1, &device);\n    let tensor_2 = TestTensorInt::<2>::from_data(data_2, &device);\n\n    let output = tensor_1 - tensor_2;\n    let expected = TensorData::from([[-3, -3, -3], [-6, -6, -6]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_sub_scalar_ops_int() {\n    let data = TensorData::from([[0, 1, 2], [3, 4, 5]]);\n    let scalar = 2;\n    let tensor = TestTensorInt::<2>::from_data(data, &Default::default());\n\n    let output = tensor - scalar;\n    let expected = TensorData::from([[-2, -1, 0], [1, 2, 3]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/take.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_take_int_tensor() {\n    // Test take with integer tensors\n    let device = Default::default();\n    let tensor = TestTensorInt::<2>::from_data([[10, 20, 30], [40, 50, 60]], &device);\n    let indices = TestTensorInt::<1>::from_data([1, 0], &device);\n\n    let output = tensor.take::<1, 2>(0, indices);\n    let expected = TensorData::from([[40, 50, 60], [10, 20, 30]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_take_int_tensor_with_2d_indices() {\n    // Test take with integer tensors - output will be 3D\n    let device = Default::default();\n    let tensor = TestTensorInt::<2>::from_data([[10, 20, 30], [40, 50, 60], [70, 80, 90]], &device);\n\n    // 2D indices - shape [2, 2]\n    let indices = TestTensorInt::<2>::from_data([[0, 2], [2, 1]], &device);\n    let output = tensor.take::<2, 3>(0, indices);\n\n    // Expected: shape [2, 2, 3]\n    let expected = TensorData::from([[[10, 20, 30], [70, 80, 90]], [[70, 80, 90], [40, 50, 60]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/topk.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\nuse burn_tensor::Tolerance;\n\n#[test]\nfn test_topk_1d() {\n    // Int\n    let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);\n\n    let values = tensor.topk(3, /*dim*/ 0);\n    let expected = TensorData::from([5, 4, 3]);\n\n    values.into_data().assert_eq(&expected, false);\n\n    // Float\n    let tensor = TestTensor::<1>::from([1., 2., 3., 4., 5.]);\n\n    let values = tensor.topk(3, /*dim*/ 0);\n    let expected = TensorData::from([5., 4., 3.]);\n\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_topk() {\n    // 3D Int\n    let tensor = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);\n\n    let values = tensor.topk(2, /*dim*/ 2);\n    let expected = TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]);\n\n    values.into_data().assert_eq(&expected, false);\n\n    // 3D Float\n    let tensor =\n        TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);\n\n    let values = tensor.topk(2, /*dim*/ 2);\n    let expected = TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]);\n\n    values\n        .into_data()\n        .assert_approx_eq::<FloatElem>(&expected, Tolerance::default());\n}\n\n#[test]\nfn test_topk_with_indices_1d() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);\n\n    let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);\n\n    let values_expected = TensorData::from([5, 4, 3]);\n    values.into_data().assert_eq(&values_expected, false);\n\n    let indices_expected = TensorData::from([4, 3, 2]);\n    indices.into_data().assert_eq(&indices_expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/transpose.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn should_support_transpose_ops_int() {\n    let tensor = TestTensorInt::<3>::from_data(\n        [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],\n        &Default::default(),\n    );\n\n    let output = tensor.transpose();\n    let expected = TensorData::from([[[0, 3], [1, 4], [2, 5]], [[6, 9], [7, 10], [8, 11]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn should_support_swap_dims_int() {\n    let tensor = TestTensorInt::<3>::from_data(\n        [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],\n        &Default::default(),\n    );\n\n    let output = tensor.swap_dims(0, 2);\n    let expected = TensorData::from([[[0, 6], [3, 9]], [[1, 7], [4, 10]], [[2, 8], [5, 11]]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/tri.rs",
    "content": "use super::*;\nuse burn_tensor::TensorData;\n\n#[test]\nfn test_triu_negative_diagonal() {\n    let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]);\n\n    let output = tensor.triu(-1);\n    let expected = TensorData::from([[1, 1, 1], [1, 1, 1], [0, 1, 1]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_triu_batch_tensors() {\n    let tensor = TestTensorInt::<4>::from([\n        [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]],\n        [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]],\n    ]);\n    let output = tensor.triu(1);\n    let expected = TensorData::from([\n        [[[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]],\n        [[[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn test_triu_too_few_dims() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let _output = tensor.triu(0);\n}\n\n#[test]\nfn test_tril() {\n    let tensor = TestTensor::<2>::from([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]);\n    let output = tensor.tril(0);\n    let expected = TensorData::from([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_tril_positive_diagonal() {\n    let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]);\n\n    let output = tensor.tril(1);\n    let expected = TensorData::from([[1, 1, 0], [1, 1, 1], [1, 1, 1]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_tril_negative_diagonal() {\n    let tensor = TestTensorInt::<2>::from([[1, 1, 1], [1, 1, 1], [1, 1, 1]]);\n\n    let output = tensor.tril(-1);\n    let expected = TensorData::from([[0, 0, 0], [1, 0, 0], [1, 1, 0]]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\nfn test_tril_batch_tensors() {\n    let tensor = TestTensorInt::<4>::from([\n        [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]],\n        [[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]],\n    ]);\n    let output = tensor.tril(1);\n    let expected = TensorData::from([\n        [[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1]]],\n        [[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 1, 1]]],\n    ]);\n\n    output.into_data().assert_eq(&expected, false);\n}\n\n#[test]\n#[should_panic]\nfn test_tril_too_few_dims() {\n    let tensor = TestTensorInt::<1>::from([1, 2, 3]);\n    let _output = tensor.tril(0);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/ops/unfold.rs",
    "content": "use super::*;\nuse burn_tensor::Distribution;\nuse burn_tensor::s;\n\n#[test]\nfn test_unfold_int() {\n    // Distribution::Default samples from [0, 255)\n    if (IntElem::MAX as u32) < 255 - 1 {\n        return;\n    }\n    let device = Default::default();\n\n    let input = TestTensorInt::<3>::random([2, 6, 6], Distribution::Default, &device);\n\n    let dim = 1;\n    let size = 3;\n    let step = 2;\n    let actual: TestTensorInt<4> = input.clone().unfold(dim, size, step);\n\n    let expected = TestTensorInt::<4>::empty([2, 2, 6, 3], &device)\n        .slice_assign(\n            s![.., 0, .., ..],\n            input\n                .clone()\n                .slice(s![.., 0..3, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        )\n        .slice_assign(\n            s![.., 1, .., ..],\n            input\n                .clone()\n                .slice(s![.., 2..5, ..])\n                .swap_dims(1, 2)\n                .unsqueeze_dim::<4>(1),\n        );\n\n    actual.to_data().assert_eq(&expected.to_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/int/primitive.rs",
    "content": "use super::*;\nuse burn_tensor::{Element, Shape};\n\n#[test]\nfn should_support_int_dtype() {\n    let tensor = TestTensorInt::<2>::from([[0, -1, 2], [3, 4, -5]]).into_primitive();\n\n    assert_eq!(\n        burn_tensor::TensorMetadata::shape(&tensor),\n        Shape::new([2, 3])\n    );\n    assert_eq!(\n        burn_tensor::TensorMetadata::dtype(&tensor),\n        IntElem::dtype() // default int elem type\n    );\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/mod.rs",
    "content": "pub use super::*; // re-export test types\n\nmod clone_invariance;\n#[cfg(feature = \"std\")]\nmod multi_threads;\n\n// Data types\nmod bool;\nmod float;\nmod int;\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor/multi_threads.rs",
    "content": "use super::*;\nuse core::time::Duration;\nuse std::sync::{\n    Arc,\n    atomic::{AtomicU32, Ordering},\n};\n\nstruct MultiThreadTestSettings {\n    num_threads: usize,\n    // The number of operations that are applied while the tensor is still alive and has a\n    // reference count > 1 on the new thread.\n    num_ops_alive: usize,\n    // The number of operations that are applied after the tensor is consumed for the last time.\n    num_ops_consumed: usize,\n    // Number of operations that needs to execute before continuing execution on the main thread.\n    sleep_before: Duration,\n    sleep_alive: Duration,\n    sleep_consumed: Duration,\n    // If the output is dropped, otherwise it will be consumed by an operation.\n    dropped: bool,\n}\n\n#[test]\nfn should_handle_multi_threads_dropped() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 5,\n        num_ops_consumed: 5,\n        sleep_before: Duration::from_millis(100),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: true,\n    })\n}\n\n#[test]\nfn should_handle_multi_threads_consumed() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 5,\n        num_ops_consumed: 5,\n        sleep_before: Duration::from_millis(100),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: false,\n    })\n}\n\n#[test]\nfn should_handle_multi_threads_drop_no_wait() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 5,\n        num_ops_consumed: 5,\n        sleep_before: Duration::from_millis(100),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: true,\n    })\n}\n\n#[test]\nfn should_handle_multi_threads_consumed_no_wait() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 5,\n        num_ops_consumed: 5,\n        sleep_before: Duration::from_millis(100),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: false,\n    })\n}\n\n#[test]\nfn should_handle_multi_threads_no_async_op() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 0,\n        num_ops_consumed: 0,\n        sleep_before: Duration::from_millis(100),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: false,\n    })\n}\n\n// Skip on metal - flaky (works when ran alone)\n// Enable once this issue is fixed: https://github.com/tracel-ai/burn/issues/4328\n#[cfg(not(feature = \"metal\"))]\n#[test]\nfn should_handle_multi_threads_no_async_op_no_wait() {\n    run_multi_thread_test(MultiThreadTestSettings {\n        num_threads: 3,\n        num_ops_alive: 0,\n        num_ops_consumed: 0,\n        sleep_before: Duration::from_millis(0),\n        sleep_alive: Duration::from_millis(100),\n        sleep_consumed: Duration::from_millis(100),\n        dropped: false,\n    })\n}\n\nfn run_multi_thread_test(settings: MultiThreadTestSettings) {\n    let tensor = TestTensor::<2>::from([[0.0, -1.0, 2.0], [3.0, 4.0, -5.0]]);\n\n    let mut joined = Vec::with_capacity(settings.num_threads);\n\n    let counter_alive = Arc::new(AtomicU32::new(0));\n    let counter_consumed = Arc::new(AtomicU32::new(0));\n\n    for i in 0..settings.num_threads {\n        let tensor_moved = tensor.clone();\n        let ca_moved = counter_alive.clone();\n        let cc_moved = counter_consumed.clone();\n\n        let handle = std::thread::spawn(move || {\n            let mut base = tensor_moved.clone();\n            std::thread::sleep(settings.sleep_before);\n\n            if settings.num_ops_alive == 0 && settings.num_ops_consumed == 0 {\n                core::mem::drop(tensor_moved);\n                core::mem::drop(base);\n            } else {\n                if settings.num_ops_alive > 1 {\n                    for j in 0..(settings.num_ops_alive - 1) {\n                        base = tensor_moved.clone() + j as u32;\n                        ca_moved.fetch_add(1, Ordering::Relaxed);\n                        std::thread::sleep(settings.sleep_alive);\n                    }\n                }\n\n                base = base * tensor_moved + i as u32;\n                ca_moved.fetch_add(1, Ordering::Relaxed);\n\n                for n in 0..settings.num_ops_consumed {\n                    base = base + n as i32;\n                    cc_moved.fetch_add(1, Ordering::Relaxed);\n                    std::thread::sleep(settings.sleep_consumed);\n                }\n                let _data = base.into_data();\n            }\n        });\n        joined.push(handle);\n    }\n\n    fn wait(counter: Arc<AtomicU32>, limit: usize) {\n        loop {\n            let counter_curr = counter.load(Ordering::Relaxed);\n            if counter_curr as usize >= limit {\n                break;\n            } else {\n                std::thread::sleep(Duration::from_millis(10));\n            }\n        }\n    }\n\n    wait(counter_alive, settings.num_ops_alive);\n    wait(counter_consumed, settings.num_ops_consumed);\n\n    if settings.dropped {\n        core::mem::drop(tensor);\n    } else {\n        let t = tensor * 2.0;\n        let _t = t.into_data();\n    }\n\n    for j in joined {\n        j.join().unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-backend-tests/tests/tensor.rs",
    "content": "//! Burn backend tensor tests.\n\n#![allow(clippy::single_range_in_vec_init, reason = \"false positive\")]\nextern crate alloc;\n\npub type FloatElemType = f32;\n#[allow(unused)]\npub type IntElemType = i32;\n\n#[path = \"common/backend.rs\"]\nmod backend;\npub use backend::*;\n\n#[path = \"common/tensor.rs\"]\nmod tensor;\n"
  },
  {
    "path": "crates/burn-candle/Cargo.toml",
    "content": "[package]\nauthors = [\"louisfd <louisfd94@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"[Deprecated] Candle backend for the Burn framework - use burn-cubecl, burn-ndarray, or burn-tch instead\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-candle\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-candle\"\ndocumentation = \"https://docs.rs/burn-candle\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\nstd = []\ndoc = [\"default\"]\ntracing = [\n    \"burn-backend/tracing\",\n    \"burn-std/tracing\",\n]\n\ncuda = [\"candle-core/cuda\"]\nmetal = [\"candle-core/metal\"]\naccelerate = [\"candle-core/accelerate\"]\n\n[dependencies]\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n# For rand utils and stub mutex\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\n\ncandle-core = { workspace = true }\nderive-new = { workspace = true }\n\n[dev-dependencies]\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-candle/README.md",
    "content": "# Burn Candle Backend\n\n> **Deprecated:** This crate is deprecated as of `0.21.0-pre.2` and will be removed in a future release.\n> Please migrate to one of the actively maintained backends:\n> - **CubeCL backends** (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration\n> - **NdArray** for portable CPU execution\n> - **LibTorch** (`burn-tch`) for a mature CPU/GPU backend\n\nThis crate provides a backend for [Burn](https://github.com/tracel-ai/burn) based on the [Candle](https://github.com/huggingface/candle) framework.\n\n## Feature Flags\n\n- `cuda` - Cuda GPU device (NVIDIA only)\n- `accelerate` - Accelerate framework (macOS only)\n"
  },
  {
    "path": "crates/burn-candle/src/backend.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn_backend::{\n    BackTrace, Backend, DType, DTypeUsage, DeviceId, DeviceOps, ExecutionError, QTensorPrimitive,\n    tensor::Device,\n};\nuse burn_std::{\n    rand::{SeedableRng, StdRng},\n    stub::Mutex,\n};\nuse candle_core::{DeviceLocation, backend::BackendDevice};\n\nuse crate::{\n    CandleTensor, IntoDType,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n};\n\n/// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.\n///\n/// It is compatible with a wide range of hardware configurations, including CPUs and GPUs\n/// that support CUDA or Metal. Additionally, the backend can be compiled to `wasm` when using the CPU.\n#[derive(Clone, Default, Debug)]\npub struct Candle<F = f32, I = i64>\nwhere\n    F: FloatCandleElement,\n    I: IntCandleElement,\n{\n    _float: PhantomData<F>,\n    _int: PhantomData<I>,\n}\n\n// Seed for CPU device\npub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);\n\npub(crate) fn get_seeded_rng() -> StdRng {\n    let mut seed = SEED.lock().unwrap();\n    seed.take().unwrap_or_else(burn_std::rand::get_seeded_rng)\n}\n\npub(crate) fn set_seeded_rng(rng_seeded: StdRng) {\n    let mut seed = SEED.lock().unwrap();\n    *seed = Some(rng_seeded);\n}\n\n/// The device type for the candle backend.\n#[derive(Clone, Debug, PartialEq, Eq)]\n/// The device struct when using the `candle` backend.\n///\n/// To create a Cuda or Metal device from the index, use the associated methods to create the variant:\n/// ```no_run\n/// use burn_candle::CandleDevice;\n///\n/// // Create a Cuda device from its index\n/// let device = CandleDevice::cuda(0);\n/// // Create a Metal device from its index\n/// let device = CandleDevice::metal(0);\n/// ```\n#[derive(Default)]\npub enum CandleDevice {\n    /// CPU device.\n    #[default]\n    Cpu,\n\n    /// Cuda device with the given index. The index is the index of the Cuda device in the list of\n    /// all Cuda devices found on the system.\n    Cuda(CudaDevice),\n\n    /// Metal device with the given index. The index is the index of the Metal device in the list of\n    /// all Metal devices found on the system.\n    Metal(MetalDevice),\n}\n\nimpl CandleDevice {\n    /// Create a Cuda device with the given index.\n    /// The index is the index of the Cuda device in the list of all Cuda devices found on the system.\n    pub fn cuda(index: usize) -> Self {\n        CandleDevice::Cuda(CudaDevice {\n            device: candle_core::CudaDevice::new(index).unwrap(),\n            index,\n        })\n    }\n\n    /// Create a Metal device with the given index.\n    /// The index is the index of the Metal device in the list of all Metal devices found on the system.\n    pub fn metal(index: usize) -> Self {\n        CandleDevice::Metal(MetalDevice {\n            device: candle_core::MetalDevice::new(index).unwrap(),\n            index,\n        })\n    }\n\n    pub(crate) fn set_seed(&self, seed: u64) {\n        match self {\n            CandleDevice::Cpu => {\n                // candle_core::cpu_backend::CpuDevice.set_seed(seed).unwrap();\n                // Candle does not support seeding the CPU rng so we use a global seed\n                let rng = StdRng::seed_from_u64(seed);\n                set_seeded_rng(rng);\n            }\n            CandleDevice::Cuda(cuda_device) => cuda_device.device.set_seed(seed).unwrap(),\n            CandleDevice::Metal(metal_device) => metal_device.device.set_seed(seed).unwrap(),\n        }\n    }\n}\n\n#[derive(Clone, Debug)]\n/// A Cuda device for the `candle` backend.\npub struct CudaDevice {\n    pub(crate) device: candle_core::CudaDevice,\n    /// The index of the Cuda device in the list of all devices on the system.\n    pub index: usize,\n}\n\nimpl PartialEq for CudaDevice {\n    fn eq(&self, other: &Self) -> bool {\n        self.device.same_device(&other.device) && self.index == other.index\n    }\n}\n\nimpl Eq for CudaDevice {}\n\n#[derive(Clone, Debug)]\n/// A Metal device for the `candle` backend.\npub struct MetalDevice {\n    pub(crate) device: candle_core::MetalDevice,\n    /// The index of the Metal device in the list of all devices on the system.\n    pub index: usize,\n}\n\nimpl PartialEq for MetalDevice {\n    fn eq(&self, other: &Self) -> bool {\n        self.device.same_device(&other.device) && self.index == other.index\n    }\n}\n\nimpl Eq for MetalDevice {}\n\nimpl From<CandleDevice> for candle_core::Device {\n    fn from(device: CandleDevice) -> Self {\n        match device {\n            CandleDevice::Cpu => candle_core::Device::Cpu,\n            CandleDevice::Cuda(device) => candle_core::Device::Cuda(device.device),\n            CandleDevice::Metal(device) => candle_core::Device::Metal(device.device),\n        }\n    }\n}\n\nimpl From<candle_core::Device> for CandleDevice {\n    fn from(device: candle_core::Device) -> Self {\n        match device.location() {\n            DeviceLocation::Cpu => CandleDevice::Cpu,\n            DeviceLocation::Cuda { gpu_id } => {\n                if let candle_core::Device::Cuda(device) = device {\n                    CandleDevice::Cuda(CudaDevice {\n                        device,\n                        index: gpu_id,\n                    })\n                } else {\n                    panic!(\"Expected CUDA device.\");\n                }\n            }\n            DeviceLocation::Metal { gpu_id } => {\n                if let candle_core::Device::Metal(device) = device {\n                    CandleDevice::Metal(MetalDevice {\n                        device,\n                        index: gpu_id,\n                    })\n                } else {\n                    panic!(\"Expected Metal device.\");\n                }\n            }\n        }\n    }\n}\n\nimpl burn_backend::Device for CandleDevice {\n    fn to_id(&self) -> burn_backend::DeviceId {\n        match self {\n            CandleDevice::Cuda(device) => DeviceId::new(0, device.index as u32),\n            CandleDevice::Metal(device) => DeviceId::new(1, device.index as u32),\n            CandleDevice::Cpu => DeviceId::new(2, 0),\n        }\n    }\n\n    fn from_id(device_id: DeviceId) -> Self {\n        match device_id.type_id {\n            0 => CandleDevice::cuda(device_id.index_id as usize),\n            1 => CandleDevice::metal(device_id.index_id as usize),\n            _ => CandleDevice::Cpu,\n        }\n    }\n\n    fn device_count(type_id: u16) -> usize {\n        // TODO: Fix that\n        1\n    }\n}\nimpl DeviceOps for CandleDevice {}\n\nimpl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {\n    type Device = CandleDevice;\n\n    type FloatTensorPrimitive = CandleTensor;\n    type FloatElem = F;\n\n    type IntTensorPrimitive = CandleTensor;\n    type IntElem = I;\n\n    type BoolTensorPrimitive = CandleTensor;\n    type BoolElem = u8;\n\n    type QuantizedTensorPrimitive = CandleTensor;\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    fn name(device: &Self::Device) -> String {\n        match device {\n            CandleDevice::Cpu => \"candle<cpu>\",\n            CandleDevice::Cuda(..) => \"candle<cuda>\",\n            CandleDevice::Metal(..) => \"candle<metal>\",\n        }\n        .to_string()\n    }\n\n    fn seed(device: &CandleDevice, seed: u64) {\n        device.set_seed(seed);\n    }\n\n    fn sync(device: &Device<Self>) -> Result<(), ExecutionError> {\n        let device: candle_core::Device = (device.clone()).into();\n\n        match device {\n            candle_core::Device::Cpu => (),\n            candle_core::Device::Cuda(device) => {\n                #[cfg(feature = \"cuda\")]\n                device\n                    .synchronize()\n                    .map_err(|err| ExecutionError::Generic {\n                        reason: format!(\"Can't sync the cuda device: {err}\"),\n                        backtrace: BackTrace::capture(),\n                    })?;\n            }\n            candle_core::Device::Metal(device) => {\n                // For some reason, device.wait_until_completed() does not seem to work,\n                // and neither does writing and reading a value with into_data\n                return Err(ExecutionError::Generic {\n                    reason:\n                        \"Device synchronization unavailable with Metal device on Candle backend\"\n                            .into(),\n                    backtrace: BackTrace::capture(),\n                });\n            }\n        }\n\n        Ok(())\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {\n        if dtype.try_into_dtype().is_ok() {\n            burn_backend::DTypeUsage::general()\n        } else {\n            burn_backend::DTypeUsageSet::empty()\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_std::{BoolStore, QuantScheme};\n\n    use super::*;\n\n    #[test]\n    fn should_support_dtypes() {\n        type B = Candle<f32>;\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F64));\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::Flex32));\n        assert!(B::supports_dtype(&device, DType::F16));\n        assert!(B::supports_dtype(&device, DType::BF16));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::U32));\n        assert!(B::supports_dtype(&device, DType::U8));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::I16));\n        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::U8)));\n\n        assert!(!B::supports_dtype(&device, DType::U64));\n        assert!(!B::supports_dtype(&device, DType::U16));\n        assert!(!B::supports_dtype(&device, DType::I8));\n        assert!(!B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n        assert!(!B::supports_dtype(\n            &device,\n            DType::QFloat(QuantScheme::default())\n        ));\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/element.rs",
    "content": "use std::borrow::Borrow;\n\nuse burn_backend::{Element, bf16, f16};\nuse candle_core::{FloatDType, Tensor, WithDType};\n\n/// Candle element\npub trait CandleElement: Element + WithDType {}\n/// Candle float element\npub trait FloatCandleElement: CandleElement + FloatDType {}\n/// Candle int element\npub trait IntCandleElement: CandleElement {}\n\nimpl CandleElement for f64 {}\nimpl FloatCandleElement for f64 {}\n\nimpl CandleElement for f32 {}\nimpl FloatCandleElement for f32 {}\n\nimpl CandleElement for f16 {}\nimpl FloatCandleElement for f16 {}\n\nimpl CandleElement for bf16 {}\nimpl FloatCandleElement for bf16 {}\n\nimpl CandleElement for u8 {}\nimpl IntCandleElement for u8 {}\n\nimpl CandleElement for u32 {}\nimpl IntCandleElement for u32 {}\n\nimpl CandleElement for i64 {}\nimpl IntCandleElement for i64 {}\n"
  },
  {
    "path": "crates/burn-candle/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![allow(unused)] // TODO remove when backend filled\n#![deprecated(\n    since = \"0.21.0\",\n    note = \"burn-candle is deprecated and will be removed in a future release. Use burn-cubecl (CUDA/ROCm/Vulkan/Metal/WebGPU), burn-ndarray, or burn-tch instead.\"\n)]\n\n//! Burn Candle Backend\n//!\n//! **Deprecated:** This backend is deprecated and will be removed in a future release.\n//! Please migrate to one of the actively maintained backends:\n//! - CubeCL backends (CUDA, ROCm, Vulkan, Metal, WebGPU) for GPU acceleration\n//! - NdArray for portable CPU execution\n//! - LibTorch (`burn-tch`) for a mature CPU/GPU backend\n\n#[macro_use]\nextern crate derive_new;\n\nmod backend;\nmod element;\nmod ops;\nmod tensor;\n\npub use backend::*;\npub use element::*;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-candle/src/ops/activation.rs",
    "content": "use burn_backend::{ops::ActivationOps, tensor::FloatTensor};\n\nuse crate::{\n    Candle, CandleTensor,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n    tensor,\n};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> ActivationOps<Self> for Candle<F, I> {\n    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.gelu().unwrap())\n    }\n\n    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.relu().unwrap())\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/base.rs",
    "content": "use std::cmp::max;\nuse std::marker::PhantomData;\n\nuse crate::{\n    Candle, CandleDevice, CandleTensor,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n};\nuse burn_backend::{\n    BackTrace, Backend, Distribution, ExecutionError, Slice, bf16, f16,\n    ops::unfold::{calculate_unfold_shape, calculate_unfold_windows},\n};\nuse burn_backend::{Element, Shape, TensorData, TensorMetadata};\nuse candle_core::{Layout, WithDType};\n\nuse super::tensor;\n\npub fn cpu_random<E: CandleElement>(shape: Shape, distribution: Distribution) -> TensorData {\n    let mut rng = crate::get_seeded_rng();\n    let data = TensorData::random::<E, _, _>(shape, distribution, &mut rng);\n    crate::set_seeded_rng(rng);\n    data\n}\n\npub fn cat(tensors: Vec<CandleTensor>, dim: usize) -> CandleTensor {\n    let tensors: Vec<candle_core::Tensor> = tensors.into_iter().map(|t| t.tensor).collect();\n    CandleTensor::new(candle_core::Tensor::cat(&tensors, dim).unwrap())\n}\n\npub fn from_data<E: CandleElement>(data: TensorData, device: &CandleDevice) -> CandleTensor {\n    CandleTensor::from_data::<E>(data, device.clone())\n}\npub fn into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {\n    fn tensor_data_from_dtype<T: WithDType + Element>(\n        tensor: &CandleTensor,\n    ) -> Result<TensorData, ExecutionError> {\n        let data = tensor\n            .tensor\n            .flatten_all()\n            .map_err(|err| ExecutionError::Generic {\n                reason: format!(\"{err}\"),\n                backtrace: BackTrace::capture(),\n            })?\n            .to_vec1::<T>()\n            .map_err(|err| ExecutionError::Generic {\n                reason: format!(\"{err}\"),\n                backtrace: BackTrace::capture(),\n            })?;\n        Ok(TensorData::new(data, tensor.shape()))\n    }\n\n    match tensor.tensor.dtype() {\n        candle_core::DType::BF16 => tensor_data_from_dtype::<bf16>(&tensor),\n        candle_core::DType::F16 => tensor_data_from_dtype::<f16>(&tensor),\n        candle_core::DType::F32 => tensor_data_from_dtype::<f32>(&tensor),\n        candle_core::DType::F64 => tensor_data_from_dtype::<f64>(&tensor),\n        candle_core::DType::U8 => tensor_data_from_dtype::<u8>(&tensor),\n        candle_core::DType::U32 => tensor_data_from_dtype::<u32>(&tensor),\n        candle_core::DType::I16 => tensor_data_from_dtype::<i16>(&tensor),\n        candle_core::DType::I32 => tensor_data_from_dtype::<i32>(&tensor),\n        candle_core::DType::I64 => tensor_data_from_dtype::<i64>(&tensor),\n        other => todo!(\"{other:?} not yet supported\"),\n    }\n}\n\npub fn to_device(tensor: CandleTensor, device: &CandleDevice) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.to_device(&(device.clone()).into()).unwrap())\n}\n\npub fn empty(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {\n    zeros(shape, device, dtype)\n}\n\npub fn zeros(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {\n    CandleTensor::new(\n        candle_core::Tensor::zeros(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(),\n    )\n}\n\npub fn ones(shape: Shape, device: &CandleDevice, dtype: candle_core::DType) -> CandleTensor {\n    CandleTensor::new(\n        candle_core::Tensor::ones(shape.to_vec(), dtype, &(device.clone()).into()).unwrap(),\n    )\n}\n\npub fn swap_dims(mut tensor: CandleTensor, dim1: usize, dim2: usize) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap())\n}\n\npub fn permute(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.permute(axes).unwrap())\n}\n\npub fn flip(tensor: CandleTensor, axes: &[usize]) -> CandleTensor {\n    // FIXME: Replace with an appropriate method when Candle provides one.\n    let mut tensor = tensor.tensor;\n    for &axis in axes {\n        // Ensure tensor is contiguous before index_select (required by Candle)\n        tensor = tensor.contiguous().unwrap();\n\n        let indexes = candle_core::Tensor::arange_step(\n            tensor.dim(axis).unwrap() as i64 - 1,\n            -1,\n            -1,\n            tensor.device(),\n        )\n        .unwrap();\n        tensor = tensor.index_select(&indexes, axis).unwrap();\n    }\n\n    CandleTensor::new(tensor)\n}\n\npub fn reshape(tensor: CandleTensor, shape: Shape) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.reshape(shape.to_vec()).unwrap())\n}\n\npub fn device(tensor: &CandleTensor) -> CandleDevice {\n    tensor.tensor.device().clone().into()\n}\n\npub fn shape(tensor: &CandleTensor) -> Shape {\n    tensor.shape()\n}\n\npub fn slice(tensor: CandleTensor, ranges: &[std::ops::Range<usize>]) -> CandleTensor {\n    let mut narrow_tensor = tensor.tensor;\n    for (i, range) in ranges.iter().enumerate().take(ranges.len()) {\n        narrow_tensor = narrow_tensor\n            .narrow(i, range.start, range.end - range.start)\n            .unwrap()\n    }\n    CandleTensor::new(narrow_tensor)\n}\n\npub fn slice_with_steps(tensor: CandleTensor, slices: &[Slice]) -> CandleTensor {\n    let mut result_tensor = tensor.tensor;\n\n    for (dim, slice) in slices.iter().enumerate() {\n        if slice.step == 1 {\n            // Use narrow for step=1 (more efficient)\n            // Convert slice to range using tensor shape\n            let dim_size = result_tensor.dim(dim).unwrap();\n            let range = slice.to_range(dim_size);\n            let start = range.start;\n            let length = range.end - range.start;\n            result_tensor = result_tensor.narrow(dim, start, length).unwrap();\n        } else {\n            // Use index_select for step != 1\n            let dim_size = result_tensor.dim(dim).unwrap();\n            let range = slice.to_range(dim_size);\n            let start = range.start;\n            let end = range.end;\n            let step = slice.step;\n\n            // Generate indices based on step direction\n            let indices_vec = if step > 0 {\n                // Forward stepping\n                let step_usize = step as usize;\n                (start..end).step_by(step_usize).collect::<Vec<_>>()\n            } else {\n                // Backward stepping (negative step)\n                let step_usize = step.unsigned_abs();\n                // Start from end-1 and go backwards\n                let mut indices = Vec::new();\n                let mut idx = end - 1;\n                while idx >= start && idx < end {\n                    // Check for underflow\n                    indices.push(idx);\n                    if idx >= step_usize {\n                        idx -= step_usize;\n                    } else {\n                        break;\n                    }\n                }\n                indices\n            };\n\n            // Convert indices to tensor and use index_select\n            let indices_len = indices_vec.len();\n            let device = result_tensor.device();\n            let indices = candle_core::Tensor::from_vec(\n                indices_vec.iter().map(|&x| x as u32).collect::<Vec<_>>(),\n                indices_len,\n                device,\n            )\n            .unwrap();\n\n            result_tensor = result_tensor.index_select(&indices, dim).unwrap();\n        }\n    }\n\n    CandleTensor::new(result_tensor)\n}\n\npub fn slice_assign(tensor: CandleTensor, slices: &[Slice], value: CandleTensor) -> CandleTensor {\n    // Check if all slices have step=1 (candle's native slice_assign requirement)\n    let all_unit_steps = slices.iter().all(|s| s.step == 1);\n\n    if all_unit_steps {\n        // Convert Slice to Range for candle's native slice_assign\n        let ranges: Vec<std::ops::Range<usize>> = slices\n            .iter()\n            .enumerate()\n            .map(|(dim, slice)| {\n                let dim_size = tensor.tensor.dim(dim).unwrap_or(usize::MAX);\n                slice.to_range(dim_size)\n            })\n            .collect();\n\n        CandleTensor::new(tensor.tensor.slice_assign(&ranges, &value.tensor).unwrap())\n    } else {\n        // Implement slice_assign with steps using scatter operations\n        slice_assign_with_steps_workaround(tensor, slices, value)\n    }\n}\n\n/// Implements slice_assign for non-unit steps using index operations\nfn slice_assign_with_steps_workaround(\n    tensor: CandleTensor,\n    slices: &[Slice],\n    value: CandleTensor,\n) -> CandleTensor {\n    let shape = tensor.shape();\n    let ndims = shape.num_dims();\n    let device = tensor.tensor.device();\n\n    // Generate indices for each dimension based on slice specifications\n    let indices_per_dim = generate_slice_indices(slices, &shape);\n\n    // Early return if no elements to assign\n    let total_elements: usize = indices_per_dim.iter().map(|v| v.len()).product();\n    if total_elements == 0 {\n        return tensor;\n    }\n\n    // Flatten tensors and get metadata\n    let value_flat = value.tensor.flatten_all().unwrap();\n    let strides = tensor.tensor.stride();\n    let tensor_shape = tensor.tensor.dims();\n\n    // Use a macro to handle different dtypes without code duplication\n    macro_rules! apply_slice_assign {\n        ($dtype:ty, $to_vec_fn:ident) => {{\n            let mut tensor_vec: Vec<$dtype> =\n                tensor.tensor.flatten_all().unwrap().$to_vec_fn().unwrap();\n            let value_vec: Vec<$dtype> = value_flat.$to_vec_fn().unwrap();\n\n            // Apply assignments using cartesian product of indices\n            for (value_idx, &value) in value_vec.iter().enumerate() {\n                let flat_idx = compute_flat_index(value_idx, &indices_per_dim, &strides);\n                if flat_idx < tensor_vec.len() {\n                    tensor_vec[flat_idx] = value;\n                }\n            }\n\n            candle_core::Tensor::from_vec(tensor_vec, tensor_shape, device).unwrap()\n        }};\n    }\n\n    use candle_core::DType;\n    let result = match tensor.tensor.dtype() {\n        DType::F32 => apply_slice_assign!(f32, to_vec1),\n        DType::F64 => apply_slice_assign!(f64, to_vec1),\n        DType::I64 => apply_slice_assign!(i64, to_vec1),\n        DType::U32 => apply_slice_assign!(u32, to_vec1),\n        DType::U8 => apply_slice_assign!(u8, to_vec1),\n        _ => panic!(\n            \"Unsupported dtype {:?} for slice_assign with steps\",\n            tensor.tensor.dtype()\n        ),\n    };\n\n    CandleTensor::new(result)\n}\n\n/// Generate indices for each dimension based on slice specifications\nfn generate_slice_indices(slices: &[Slice], tensor_dims: &[usize]) -> Vec<Vec<usize>> {\n    let ndims = tensor_dims.len();\n    let mut indices_per_dim = Vec::with_capacity(ndims);\n\n    // Process provided slices\n    for (dim_idx, slice) in slices.iter().enumerate() {\n        let dim_size = tensor_dims[dim_idx];\n        let range = slice.to_range(dim_size);\n        let indices = generate_stepped_indices(range.start, range.end, slice.step);\n        indices_per_dim.push(indices);\n    }\n\n    // Fill remaining dimensions with full ranges\n    for &dim_size in tensor_dims.iter().skip(slices.len()) {\n        indices_per_dim.push((0..dim_size).collect());\n    }\n\n    indices_per_dim\n}\n\n/// Generate indices for a single dimension with stepping\nfn generate_stepped_indices(start: usize, end: usize, step: isize) -> Vec<usize> {\n    if step > 0 {\n        // Forward stepping\n        (start..end).step_by(step as usize).collect()\n    } else if step < 0 {\n        // Backward stepping: start from end-1 and go backwards\n        let step_size = step.unsigned_abs();\n        let mut indices = Vec::new();\n        let mut idx = end.saturating_sub(1);\n\n        while idx >= start && idx < end {\n            indices.push(idx);\n            if idx >= step_size {\n                idx -= step_size;\n            } else {\n                break;\n            }\n        }\n        indices\n    } else {\n        // This branch should never be reached since step is validated to be non-zero\n        panic!(\"Step cannot be zero\")\n    }\n}\n\n/// Compute flat index from multi-dimensional indices using cartesian product logic\nfn compute_flat_index(\n    value_idx: usize,\n    indices_per_dim: &[Vec<usize>],\n    strides: &[usize],\n) -> usize {\n    let mut flat_idx = 0;\n    let mut remainder = value_idx;\n\n    // Convert value_idx to multi-dimensional indices and compute flat tensor index\n    for dim in (0..indices_per_dim.len()).rev() {\n        let dim_size = indices_per_dim[dim].len();\n        let idx_in_dim = remainder % dim_size;\n        remainder /= dim_size;\n\n        let actual_idx = indices_per_dim[dim][idx_in_dim];\n        flat_idx += actual_idx * strides[dim];\n    }\n\n    flat_idx\n}\n\npub fn narrow(tensor: CandleTensor, dim: usize, start: usize, length: usize) -> CandleTensor {\n    let tensor = tensor.tensor.narrow(dim, start, length);\n    match tensor {\n        Ok(tensor) => CandleTensor::new(tensor),\n        Err(e) => panic!(\"error narrow from Candle\"),\n    }\n}\n\npub fn chunk(tensor: CandleTensor, chunks: usize, dim: usize) -> Vec<CandleTensor> {\n    let tensors = tensor.tensor.chunk(chunks, dim);\n    match tensors {\n        Ok(tensors) => tensors.into_iter().map(CandleTensor::new).collect(),\n        Err(e) => panic!(\"error chunk from Candle\"),\n    }\n}\n\npub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.broadcast_as(shape.to_vec()).unwrap())\n}\n\npub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {\n    let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);\n    let windows = result_shape[dim];\n\n    let mut select_ranges = tensor.shape().into_ranges();\n    let new_axis = select_ranges.len();\n\n    let mut stack = Vec::with_capacity(windows);\n    for widx in 0..windows {\n        let start = widx * step;\n        let end = start + size;\n        select_ranges[dim] = start..end;\n\n        let mut window_slice = slice(tensor.clone(), &select_ranges);\n\n        window_slice = swap_dims(window_slice, dim, new_axis);\n        let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap());\n\n        stack.push(window_slice);\n    }\n    cat(stack, dim)\n}\n\npub fn sign(tensor: CandleTensor) -> CandleTensor {\n    CandleTensor::new(tensor.tensor.sign().unwrap())\n}\n\npub fn mask_where_broadcasted(\n    tensor: CandleTensor,\n    mask: CandleTensor,\n    value: CandleTensor,\n) -> CandleTensor {\n    let shape = tensor\n        .tensor\n        .shape()\n        .broadcast_shape_binary_op(mask.tensor.shape(), \"where_cond\")\n        .unwrap();\n\n    let mut tensor = tensor.tensor;\n    let mut mask = mask.tensor;\n    let mut value = value.tensor;\n\n    if shape != *tensor.shape() {\n        tensor = tensor.broadcast_as(shape.clone()).unwrap();\n    }\n    if shape != *mask.shape() {\n        mask = mask.broadcast_as(shape.clone()).unwrap();\n    }\n    if shape != *value.shape() {\n        value = value.broadcast_as(shape).unwrap();\n    }\n\n    CandleTensor::new(mask.where_cond(&value, &tensor).unwrap())\n}\n\npub fn cross(lhs: CandleTensor, rhs: CandleTensor, dim: usize) -> CandleTensor {\n    let shape_lhs = lhs.shape();\n    let shape_rhs = rhs.shape();\n    let ndims = shape_lhs.num_dims();\n\n    // Broadcast the shapes except along dim\n    let mut broadcast_shape = vec![0; ndims];\n    for (i, item) in broadcast_shape.iter_mut().enumerate().take(ndims) {\n        if i == dim {\n            *item = shape_lhs[i];\n        } else {\n            let l = shape_lhs[i];\n            let r = shape_rhs[i];\n            if l == r {\n                *item = l;\n            } else if l == 1 {\n                *item = r;\n            } else if r == 1 {\n                *item = l;\n            } else {\n                panic!(\"Tensors are not broadcastable along dimension {}\", i);\n            }\n        }\n    }\n\n    // Broadcast lhs and rhs\n    let lhs_broadcast = if shape_lhs == Shape::from(broadcast_shape.clone()) {\n        lhs\n    } else {\n        expand(lhs, Shape::from(broadcast_shape.clone()))\n    };\n    let rhs_broadcast = if shape_rhs == Shape::from(broadcast_shape.clone()) {\n        rhs\n    } else {\n        expand(rhs, Shape::from(broadcast_shape.clone()))\n    };\n\n    // Now, move dim to the last dimension\n    let mut perm = (0..ndims).collect::<Vec<_>>();\n    perm.remove(dim);\n    perm.push(dim);\n\n    let lhs_permuted = permute(lhs_broadcast, &perm);\n    let rhs_permuted = permute(rhs_broadcast, &perm);\n\n    // Reshape to (*, 3)\n    let total_elements = lhs_permuted.shape().num_elements();\n    let batch_size = total_elements / 3;\n    let lhs_reshaped = reshape(lhs_permuted, Shape::new([batch_size, 3]));\n    let rhs_reshaped = reshape(rhs_permuted, Shape::new([batch_size, 3]));\n\n    // Extract components using narrow and squeeze\n    let lhs_0 = CandleTensor::new(\n        lhs_reshaped\n            .tensor\n            .narrow(1, 0, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n    let lhs_1 = CandleTensor::new(\n        lhs_reshaped\n            .tensor\n            .narrow(1, 1, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n    let lhs_2 = CandleTensor::new(\n        lhs_reshaped\n            .tensor\n            .narrow(1, 2, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n    let rhs_0 = CandleTensor::new(\n        rhs_reshaped\n            .tensor\n            .narrow(1, 0, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n    let rhs_1 = CandleTensor::new(\n        rhs_reshaped\n            .tensor\n            .narrow(1, 1, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n    let rhs_2 = CandleTensor::new(\n        rhs_reshaped\n            .tensor\n            .narrow(1, 2, 1)\n            .unwrap()\n            .squeeze(1)\n            .unwrap(),\n    );\n\n    // Compute cross product components\n    let result_0 = CandleTensor::new(\n        lhs_1\n            .tensor\n            .mul(&rhs_2.tensor)\n            .unwrap()\n            .sub(&lhs_2.tensor.mul(&rhs_1.tensor).unwrap())\n            .unwrap(),\n    );\n    let result_1 = CandleTensor::new(\n        lhs_2\n            .tensor\n            .mul(&rhs_0.tensor)\n            .unwrap()\n            .sub(&lhs_0.tensor.mul(&rhs_2.tensor).unwrap())\n            .unwrap(),\n    );\n    let result_2 = CandleTensor::new(\n        lhs_0\n            .tensor\n            .mul(&rhs_1.tensor)\n            .unwrap()\n            .sub(&lhs_1.tensor.mul(&rhs_0.tensor).unwrap())\n            .unwrap(),\n    );\n\n    // Stack the components\n    let result_0_unsqueezed = CandleTensor::new(result_0.tensor.unsqueeze(1).unwrap());\n    let result_1_unsqueezed = CandleTensor::new(result_1.tensor.unsqueeze(1).unwrap());\n    let result_2_unsqueezed = CandleTensor::new(result_2.tensor.unsqueeze(1).unwrap());\n    let result = cat(\n        vec![\n            result_0_unsqueezed,\n            result_1_unsqueezed,\n            result_2_unsqueezed,\n        ],\n        1,\n    );\n\n    // Reshape back to the broadcast shape with dim at the end\n    let mut result_shape = broadcast_shape;\n    result_shape.remove(dim);\n    result_shape.push(3);\n    let result_reshaped = reshape(result, Shape::from(result_shape));\n\n    // Permute back\n    let mut inv_perm = vec![0; ndims];\n    for (i, &p) in perm.iter().enumerate() {\n        inv_perm[p] = i;\n    }\n    permute(result_reshaped, &inv_perm)\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/bool_tensor.rs",
    "content": "use burn_backend::{\n    BackTrace, DType, ExecutionError, Scalar, Shape, Slice, TensorData, TensorMetadata,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, Device, FloatTensor, IntTensor},\n};\n\nuse crate::{\n    Candle, CandleTensor,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n};\n\nuse super::base::{expand, permute, unfold};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {\n    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        super::base::empty(shape, device, candle_core::DType::U8)\n    }\n\n    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        super::base::zeros(shape, device, candle_core::DType::U8)\n    }\n\n    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        super::base::ones(shape, device, candle_core::DType::U8)\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {\n        let x: Vec<u8> = tensor\n            .tensor\n            .flatten_all()\n            .map_err(|err| ExecutionError::Generic {\n                reason: format!(\"{err}\"),\n                backtrace: BackTrace::capture(),\n            })?\n            .to_vec1()\n            .map_err(|err| ExecutionError::Generic {\n                reason: format!(\"{err}\"),\n                backtrace: BackTrace::capture(),\n            })?;\n\n        let y = x.iter().map(|b| !matches!(b, 0)).collect();\n\n        Ok(TensorData::new(y, tensor.shape()))\n    }\n\n    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {\n        match data.dtype {\n            DType::U8 => super::base::from_data::<u8>(data, device),\n            _ => unimplemented!(\"Unsupported dtype for `bool_from_data`\"),\n        }\n    }\n\n    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())\n    }\n\n    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())\n    }\n\n    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {\n        super::base::device(tensor)\n    }\n\n    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {\n        super::base::to_device(tensor, device)\n    }\n\n    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        super::base::reshape(tensor, shape)\n    }\n\n    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {\n        super::base::slice_with_steps(tensor, slices)\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        slices: &[Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        super::base::slice_assign(tensor, slices, value)\n    }\n\n    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {\n        super::base::cat(tensors, dim)\n    }\n\n    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())\n    }\n\n    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        let x = (candle_core::Tensor::zeros_like(&tensor.tensor).unwrap());\n        CandleTensor::new(tensor.tensor.eq(&x).unwrap())\n    }\n\n    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        let x = candle_core::Tensor::ones_like(&lhs.tensor).unwrap();\n        CandleTensor::new(lhs.tensor.add(&rhs.tensor).unwrap().gt(&x).unwrap())\n    }\n\n    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .add(&rhs.tensor)\n                .unwrap()\n                .clamp(0u32, 1u32)\n                .unwrap(),\n        )\n    }\n\n    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {\n        super::base::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        super::base::permute(tensor, axes)\n    }\n\n    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        super::base::flip(tensor, axes)\n    }\n\n    fn bool_select(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())\n    }\n\n    fn bool_select_or(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .index_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn bool_unfold(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> BoolTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        super::base::mask_where_broadcasted(tensor, mask, value)\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        CandleTensor::new(\n            mask.tensor\n                .where_cond(\n                    &super::candle_utils::fill_like::<u8>(value.elem(), &tensor.tensor),\n                    &tensor.tensor,\n                )\n                .unwrap(),\n        )\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        let tensor = tensor.tensor.contiguous().unwrap();\n        let indices = indices.tensor.contiguous().unwrap();\n        CandleTensor::new(tensor.gather(&indices, dim).unwrap())\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .scatter_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(lhs.tensor.eq(rhs.elem::<u8>()).unwrap())\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/candle_utils.rs",
    "content": "use candle_core::{DType, Device, Shape, Tensor};\n\nuse crate::element::CandleElement;\n\npub(crate) fn fill<E: CandleElement, S: Into<Shape>>(\n    value: E,\n    shape: S,\n    dtype: DType,\n    device: &Device,\n) -> Tensor {\n    let values = (Tensor::ones((1), dtype, device).unwrap() * value.elem::<f64>()).unwrap();\n    values.expand(shape).unwrap()\n}\n\npub(crate) fn fill_like<E: CandleElement>(value: E, reference_tensor: &Tensor) -> Tensor {\n    fill(\n        value,\n        reference_tensor.shape(),\n        reference_tensor.dtype(),\n        reference_tensor.device(),\n    )\n}\n\n/// Broadcasts two tensors to a common shape for comparison operations\npub(crate) fn broadcast_for_comparison(\n    lhs: &Tensor,\n    rhs: &Tensor,\n) -> Result<(Tensor, Tensor), candle_core::Error> {\n    let broadcast_shape = lhs\n        .shape()\n        .broadcast_shape_binary_op(rhs.shape(), \"comparison\")?;\n\n    let lhs = if broadcast_shape != *lhs.shape() {\n        lhs.broadcast_as(&broadcast_shape)?\n    } else {\n        lhs.clone()\n    };\n\n    let rhs = if broadcast_shape != *rhs.shape() {\n        rhs.broadcast_as(&broadcast_shape)?\n    } else {\n        rhs.clone()\n    };\n\n    Ok((lhs, rhs))\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/int_tensor.rs",
    "content": "use burn_backend::{\n    DType, Distribution, ElementConversion, ExecutionError, IntDType, Scalar, Shape, Slice,\n    TensorData,\n    ops::{FloatTensorOps, IntTensorOps},\n    tensor::{Bool, BoolTensor, Device, FloatTensor, IntElem, IntTensor},\n};\n\nuse crate::{\n    Candle, CandleDevice, CandleTensor, IntoDType,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n};\n\nuse super::base::{cpu_random, expand, permute, sign, unfold};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {\n    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        super::base::empty(shape, device, dtype.into_dtype())\n    }\n\n    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {\n        super::base::into_data(tensor)\n    }\n\n    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {\n        match data.dtype {\n            DType::I64 => super::base::from_data::<i64>(data, device),\n            DType::U32 => super::base::from_data::<u32>(data, device),\n            DType::U8 => super::base::from_data::<u8>(data, device),\n            _ => unimplemented!(\"Unsupported dtype for `int_from_data`\"),\n        }\n    }\n\n    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {\n        super::base::device(tensor)\n    }\n\n    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {\n        super::base::to_device(tensor, device)\n    }\n\n    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        super::base::reshape(tensor, shape)\n    }\n\n    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {\n        super::base::slice_with_steps(tensor, slices)\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<Self>,\n        slices: &[Slice],\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        super::base::slice_assign(tensor, slices, value)\n    }\n\n    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap())\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        source: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        super::base::mask_where_broadcasted(tensor, mask, source)\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> IntTensor<Self> {\n        CandleTensor::new(\n            mask.tensor\n                .where_cond(\n                    &super::candle_utils::fill_like::<I>(value.elem(), &tensor.tensor),\n                    &tensor.tensor,\n                )\n                .unwrap(),\n        )\n    }\n\n    fn int_gather(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let tensor = tensor.tensor.contiguous().unwrap();\n        let indices = indices.tensor.contiguous().unwrap();\n        CandleTensor::new(tensor.gather(&indices, dim).unwrap())\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .scatter_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn int_select(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .index_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {\n        super::base::cat(tensors, dim)\n    }\n\n    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())\n    }\n\n    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(lhs.tensor.eq(rhs.elem::<I>()).unwrap())\n    }\n\n    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())\n    }\n\n    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .gt(&super::candle_utils::fill_like::<I>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .ge(&super::candle_utils::fill_like::<I>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())\n    }\n\n    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .lt(&super::candle_utils::fill_like::<I>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .le(&super::candle_utils::fill_like::<I>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())\n    }\n\n    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())\n    }\n\n    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())\n    }\n\n    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())\n    }\n\n    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())\n    }\n\n    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.\n        panic!(\"Not supported by Candle\")\n    }\n\n    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        CandleTensor::new(\n            (lhs.tensor.clone()\n                - lhs\n                    .tensor\n                    .broadcast_div(&rhs.tensor)\n                    .unwrap()\n                    .broadcast_mul(&rhs.tensor)\n                    .unwrap())\n            .unwrap(),\n        )\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        // Same problem as int_div_scalar.\n        panic!(\"Not supported by Candle\")\n    }\n\n    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        CandleTensor::new(\n            candle_core::Tensor::zeros(\n                shape.to_vec(),\n                dtype.into_dtype(),\n                &(device.clone()).into(),\n            )\n            .unwrap(),\n        )\n    }\n\n    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        CandleTensor::new(\n            candle_core::Tensor::ones(shape.to_vec(), dtype.into_dtype(), &(device.clone()).into())\n                .unwrap(),\n        )\n    }\n\n    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let sum = tensor.tensor.sum_all().unwrap().to_scalar::<I>().unwrap();\n        CandleTensor::from_data::<I>(\n            TensorData::new([sum].into(), [1]),\n            Self::int_device(&tensor),\n        )\n    }\n\n    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        todo!(\n            \"prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)\"\n        )\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        todo!(\n            \"prod_int is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)\"\n        )\n    }\n\n    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        // Candle implements scalar a/b as a * (1/b). With ints 1/b is rounded to 0 so we always obtain 0.\n        panic!(\"Not supported by Candle\")\n    }\n\n    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        // Candle's cumsum doesn't support integer types, so we convert to float,\n        // compute cumsum, and convert back to int\n        let dtype = tensor.tensor.dtype();\n        let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();\n        let result_float = tensor_float.cumsum(dim).unwrap();\n        CandleTensor::new(result_float.to_dtype(dtype).unwrap())\n    }\n\n    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        // Convert to float for computation, then convert back\n        let dtype = tensor.tensor.dtype();\n        let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();\n\n        let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {\n            prev.broadcast_mul(curr)\n        });\n        CandleTensor::new(result_float.to_dtype(dtype).unwrap())\n    }\n\n    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        // Convert to float for computation, then convert back\n        let dtype = tensor.tensor.dtype();\n        let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();\n\n        let result_float = super::utils::cumulative_with_op(&tensor_float, dim, |prev, curr| {\n            prev.broadcast_minimum(curr)\n        });\n        CandleTensor::new(result_float.to_dtype(dtype).unwrap())\n    }\n\n    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {\n            prev.broadcast_maximum(curr)\n        });\n        CandleTensor::new(result)\n    }\n\n    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .argmax_keepdim(dim)\n                .unwrap()\n                .to_dtype(I::DTYPE)\n                .unwrap(),\n        )\n    }\n\n    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .argmin_keepdim(dim)\n                .unwrap()\n                .to_dtype(I::DTYPE)\n                .unwrap(),\n        )\n    }\n\n    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        // Ugly type conversion here as Candle does not support unary ops on ints\n        match tensor.tensor.dtype() {\n            candle_core::DType::U8 | candle_core::DType::U32 => tensor,\n            candle_core::DType::I64 => CandleTensor::new(\n                tensor\n                    .tensor\n                    .to_dtype(F::DTYPE)\n                    .unwrap()\n                    .abs()\n                    .unwrap()\n                    .to_dtype(candle_core::DType::I64)\n                    .unwrap(),\n            ),\n            _ => unreachable!(),\n        }\n    }\n\n    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        super::base::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> IntTensor<Self> {\n        if let CandleDevice::Cpu = device {\n            let distribution = if distribution == Distribution::Default {\n                Distribution::Uniform(0.0, 255.0)\n            } else {\n                distribution\n            };\n            // Use our own seed since candle doesn't support it on CPU\n            return Self::int_from_data(cpu_random::<I>(shape, distribution), device);\n        }\n\n        let shape = shape.to_vec();\n        let device = &(device.clone()).into();\n        match distribution {\n            Distribution::Default => CandleTensor::new(\n                candle_core::Tensor::rand(0.elem::<F>(), 255.elem::<F>(), shape, device)\n                    .unwrap()\n                    .to_dtype(I::DTYPE)\n                    .unwrap(),\n            ),\n            Distribution::Bernoulli(prob) => CandleTensor::new(\n                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)\n                    .unwrap()\n                    .to_dtype(I::DTYPE)\n                    .unwrap()\n                    .lt(&super::candle_utils::fill(prob, shape, I::DTYPE, device))\n                    .unwrap()\n                    .to_dtype(I::DTYPE)\n                    .unwrap(),\n            ),\n            Distribution::Uniform(from, to) => CandleTensor::new(\n                candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),\n            ),\n            Distribution::Normal(mean, std) => CandleTensor::new(\n                candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)\n                    .unwrap(),\n            ),\n        }\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        super::base::permute(tensor, axes)\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        super::base::flip(tensor, axes)\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n\n    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        sign(tensor)\n    }\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_and is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_and_scalar is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_or is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_or_scalar is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_xor is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_xor_scalar is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_not is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_left_shift is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_right_shift is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_left_shift_scalar is not implemented for Candle IntTensor\");\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unimplemented!(\"bitwise_right_shift_scalar is not implemented for Candle IntTensor\");\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let lhs = Self::int_into_float(lhs);\n        let rhs = Self::int_into_float(rhs);\n\n        let out = Self::float_matmul(lhs, rhs);\n        Self::float_into_int(out)\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let dtype = dtype.into_dtype();\n\n        if tensor.tensor.dtype() == dtype {\n            tensor\n        } else {\n            CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/mod.rs",
    "content": "mod activation;\nmod base;\nmod bool_tensor;\nmod candle_utils;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\nmod utils;\n"
  },
  {
    "path": "crates/burn-candle/src/ops/module.rs",
    "content": "use burn_backend::{\n    Shape,\n    ops::{\n        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,\n        InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,\n        UnfoldOptions, attention::attention_fallback,\n    },\n    tensor::{FloatTensor, IntTensor},\n};\nuse candle_core::ToUsize2;\n\nuse crate::{\n    Candle, CandleTensor,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n    ops::base::reshape,\n};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I> {\n    fn conv1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        let conv = x\n            .tensor\n            .conv1d(\n                &weight.tensor,\n                options.padding[0],\n                options.stride[0],\n                options.dilation[0],\n                options.groups,\n            )\n            .unwrap();\n        CandleTensor::new(match bias {\n            Some(bias) => conv\n                .broadcast_add(&bias.tensor.unsqueeze(1).unwrap())\n                .unwrap(),\n            None => conv,\n        })\n    }\n\n    fn conv2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        assert!(\n            options.dilation[0] == options.dilation[1]\n                && options.padding[0] == options.padding[1]\n                && options.stride[0] == options.stride[1],\n            \"Candle does not support per dimension options in convolutions\"\n        );\n        let conv = x\n            .tensor\n            .conv2d(\n                &weight.tensor,\n                options.padding[0],\n                options.stride[0],\n                options.dilation[0],\n                options.groups,\n            )\n            .unwrap();\n        CandleTensor::new(match bias {\n            Some(bias) => conv\n                .broadcast_add(\n                    &bias\n                        .tensor\n                        .unsqueeze(0)\n                        .unwrap()\n                        .unsqueeze(2)\n                        .unwrap()\n                        .unsqueeze(3)\n                        .unwrap(),\n                )\n                .unwrap(),\n            None => conv,\n        })\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        unimplemented!(\"Candle does not support deformable convolutions\")\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        unimplemented!(\"Candle does not support deformable convolutions\")\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        panic!(\"Candle does not support 3D convolutions\");\n    }\n\n    fn conv_transpose1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        let conv_transpose = x\n            .tensor\n            .conv_transpose1d(\n                &weight.tensor,\n                options.padding[0],\n                options.padding_out[0],\n                options.stride[0],\n                options.dilation[0],\n                options.groups,\n            )\n            .unwrap();\n        CandleTensor::new(match bias {\n            Some(bias) => conv_transpose\n                .broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())\n                .unwrap(),\n            None => conv_transpose,\n        })\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        assert!(\n            options.dilation[0] == options.dilation[1]\n                && options.padding[0] == options.padding[1]\n                && options.padding_out[0] == options.padding_out[1]\n                && options.stride[0] == options.stride[1],\n            \"Candle does not support per dimension options in transposed convolutions\"\n        );\n        assert!(\n            options.groups == 1,\n            \"Candle does not support groups in transposed convolutions\"\n        );\n        let conv_transpose = x\n            .tensor\n            .conv_transpose2d(\n                &weight.tensor,\n                options.padding[0],\n                options.padding_out[0],\n                options.stride[0],\n                options.dilation[0],\n            )\n            .unwrap();\n        CandleTensor::new(match bias {\n            Some(bias) => conv_transpose\n                .broadcast_add(\n                    &bias\n                        .tensor\n                        .unsqueeze(0)\n                        .unwrap()\n                        .unsqueeze(2)\n                        .unwrap()\n                        .unsqueeze(3)\n                        .unwrap(),\n                )\n                .unwrap(),\n            None => conv_transpose,\n        })\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        panic!(\"Candle does not support 3D transposed convolutions\");\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        assert!(\n            padding[0] == 0 && padding[1] == 0,\n            \"Candle does not support padding in pooling\"\n        );\n        assert!(\n            count_include_pad,\n            \"Candle does not support excluding pad count in pooling\"\n        );\n        assert!(!ceil_mode, \"Candle does not support ceil_mode in pooling\");\n        CandleTensor::new(\n            x.tensor\n                .avg_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))\n                .unwrap(),\n        )\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        _ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        panic!(\"avg_pool2d_backward is not supported by Candle\")\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        assert!(\n            padding[0] == 0 && padding[1] == 0,\n            \"Candle does not support padding in pooling\"\n        );\n        assert!(\n            dilation[0] == 1 && dilation[1] == 1,\n            \"Candle does not support dilation in pooling\"\n        );\n        assert!(!ceil_mode, \"Candle does not support ceil_mode in pooling\");\n        CandleTensor::new(\n            x.tensor\n                .max_pool2d_with_stride((kernel_size[0], kernel_size[1]), (stride[0], stride[1]))\n                .unwrap(),\n        )\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        _ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Candle<F, I>> {\n        panic!(\"max_pool2d_with_indices is not supported by Candle\")\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        _ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool2dBackward<Candle<F, I>> {\n        panic!(\"max_pool2d_with_indices_backward is not supported by Candle\")\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        panic!(\"adaptive_avg_pool2 is not supported by Candle\")\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        panic!(\"adaptive_avg_pool2d_backward is not supported by Candle\")\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        let tensor = match options.mode {\n            InterpolateMode::Nearest => x\n                .tensor\n                .upsample_nearest2d(output_size[0], output_size[1])\n                .unwrap(),\n            InterpolateMode::Bilinear => {\n                panic!(\"bilinear interpolation is not supported by Candle\")\n            }\n            InterpolateMode::Bicubic => {\n                panic!(\"bicubic interpolation is not supported by Candle\")\n            }\n            InterpolateMode::Lanczos3 => {\n                panic!(\"lanczos3 interpolation is not supported by Candle\")\n            }\n        };\n\n        CandleTensor::new(tensor)\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        panic!(\"interpolate_backward is not supported by Candle\")\n    }\n\n    fn attention(\n        query: FloatTensor<Self>,\n        key: FloatTensor<Self>,\n        value: FloatTensor<Self>,\n        mask: Option<burn_backend::tensor::BoolTensor<Self>>,\n        attn_bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::AttentionModuleOptions,\n    ) -> FloatTensor<Self> {\n        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    Backend, DType, ExecutionError, Shape, Slice, TensorData,\n    ops::QTensorOps,\n    quantization::{QuantScheme, QuantizationParametersPrimitive},\n    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},\n};\n\nuse crate::{\n    Candle,\n    element::{FloatCandleElement, IntCandleElement},\n};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F, I> {\n    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn quantize(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n        _qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {\n        unimplemented!()\n    }\n\n    fn q_to_device(\n        _tensor: QuantizedTensor<Self>,\n        _device: &Device<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unimplemented!()\n    }\n\n    fn q_swap_dims(\n        _tensor: QuantizedTensor<Self>,\n        _dim1: usize,\n        _dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_gather(\n        _dim: usize,\n        _tensor: QuantizedTensor<Self>,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_select(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/tensor.rs",
    "content": "use std::borrow::Borrow;\n\nuse burn_backend::{\n    DType, Distribution, ElementConversion, ExecutionError, FloatDType, Scalar, Shape, Slice,\n    TensorData, bf16, f16,\n    ops::FloatTensorOps,\n    tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor},\n};\nuse candle_core::{Tensor, backend::BackendStorage, shape};\n\nuse crate::{\n    Candle, CandleDevice, CandleTensor, IntoDType,\n    element::{CandleElement, FloatCandleElement, IntCandleElement},\n};\n\nuse super::base::{cpu_random, expand, permute, sign, unfold};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {\n    fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {\n        match data.dtype {\n            DType::F64 => super::base::from_data::<f64>(data, device),\n            DType::F32 => super::base::from_data::<f32>(data, device),\n            DType::F16 => super::base::from_data::<f16>(data, device),\n            DType::BF16 => super::base::from_data::<bf16>(data, device),\n            _ => unimplemented!(\"Unsupported dtype for `float_from_data`\"),\n        }\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> FloatTensor<Self> {\n        if let CandleDevice::Cpu = device {\n            // Use our own seed since candle doesn't support it on CPU\n            return Self::float_from_data(cpu_random::<F>(shape, distribution), device);\n        }\n\n        let shape = shape.to_vec();\n        let device = &(device.clone()).into();\n        match distribution {\n            Distribution::Default => CandleTensor::new(\n                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape, device)\n                    .unwrap()\n                    .to_dtype(F::DTYPE)\n                    .unwrap(),\n            ),\n            Distribution::Bernoulli(prob) => CandleTensor::new(\n                candle_core::Tensor::rand(0.elem::<F>(), 1.elem::<F>(), shape.clone(), device)\n                    .unwrap()\n                    .to_dtype(F::DTYPE)\n                    .unwrap()\n                    .lt(&super::candle_utils::fill(prob, shape, F::DTYPE, device))\n                    .unwrap()\n                    .to_dtype(F::DTYPE)\n                    .unwrap(),\n            ),\n            Distribution::Uniform(from, to) => CandleTensor::new(\n                candle_core::Tensor::rand(from.elem::<F>(), to.elem::<F>(), shape, device).unwrap(),\n            ),\n            Distribution::Normal(mean, std) => CandleTensor::new(\n                candle_core::Tensor::randn(mean.elem::<F>(), std.elem::<F>(), shape, device)\n                    .unwrap(),\n            ),\n        }\n    }\n\n    async fn float_into_data(tensor: CandleTensor) -> Result<TensorData, ExecutionError> {\n        super::base::into_data(tensor)\n    }\n\n    fn float_device(tensor: &CandleTensor) -> Device<Self> {\n        super::base::device(tensor)\n    }\n\n    fn float_to_device(tensor: CandleTensor, device: &Device<Self>) -> CandleTensor {\n        super::base::to_device(tensor, device)\n    }\n\n    fn float_into_int(tensor: CandleTensor) -> IntTensor<Self> {\n        CandleTensor::new(tensor.tensor.to_dtype(I::DTYPE).unwrap())\n    }\n\n    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        super::base::empty(shape, device, dtype.into_dtype())\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_add(&rhs.tensor).unwrap())\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new((lhs.tensor + rhs.elem::<f64>()).unwrap())\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_sub(&rhs.tensor).unwrap())\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new((lhs.tensor - rhs.elem::<f64>()).unwrap())\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_mul(&rhs.tensor).unwrap())\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new((lhs.tensor * rhs.elem::<f64>()).unwrap())\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(lhs.tensor.broadcast_div(&rhs.tensor).unwrap())\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new((lhs.tensor / rhs.elem::<f64>()).unwrap())\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(\n            (lhs.tensor.clone()\n                - lhs\n                    .tensor\n                    .broadcast_div(&rhs.tensor)\n                    .unwrap()\n                    .floor()\n                    .unwrap()\n                    .broadcast_mul(&rhs.tensor)\n                    .unwrap())\n            .unwrap(),\n        )\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        // In PyTorch, remainder can also be defined as torch.remainder(a, b) == a - a.div(b, rounding_mode=\"floor\") * b\n        let rhs_val = rhs.elem::<f64>();\n        let division_result = (lhs.tensor.clone() / rhs_val).unwrap().floor().unwrap();\n        let product = division_result * rhs_val;\n\n        CandleTensor::new((lhs.tensor - product).unwrap())\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let lhs_contiguous = if !lhs.tensor.is_contiguous() {\n            lhs.tensor.contiguous().unwrap()\n        } else {\n            lhs.tensor\n        };\n        let rhs_contiguous = if !rhs.tensor.is_contiguous() {\n            rhs.tensor.contiguous().unwrap()\n        } else {\n            rhs.tensor\n        };\n        CandleTensor::new(lhs_contiguous.broadcast_matmul(&rhs_contiguous).unwrap())\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        super::base::cross(lhs, rhs, dim)\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        super::base::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        super::base::reshape(tensor, shape)\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let tensor = tensor.tensor.contiguous().unwrap();\n        let indices = indices.tensor.contiguous().unwrap();\n        CandleTensor::new(tensor.gather(&indices, dim).unwrap())\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .scatter_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .index_add(&indices.tensor, &value.tensor, dim)\n                .unwrap(),\n        )\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        super::base::slice_with_steps(tensor, slices)\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        super::base::slice_assign(tensor, slices, value)\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        super::base::mask_where_broadcasted(tensor, mask, value)\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        let value = super::candle_utils::fill_like::<F>(value.elem(), &tensor.tensor);\n        super::base::mask_where_broadcasted(tensor, mask, CandleTensor::new(value))\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.eq(&rhs_broadcast).unwrap())\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(lhs.tensor.eq(rhs.elem::<F>()).unwrap())\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.gt(&rhs_broadcast).unwrap())\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .gt(&super::candle_utils::fill_like::<F>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.ge(&rhs_broadcast).unwrap())\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .ge(&super::candle_utils::fill_like::<F>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.lt(&rhs_broadcast).unwrap())\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .lt(&super::candle_utils::fill_like::<F>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let (lhs_broadcast, rhs_broadcast) =\n            super::candle_utils::broadcast_for_comparison(&lhs.tensor, &rhs.tensor).unwrap();\n        CandleTensor::new(lhs_broadcast.le(&rhs_broadcast).unwrap())\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        CandleTensor::new(\n            lhs.tensor\n                .le(&super::candle_utils::fill_like::<F>(\n                    rhs.elem(),\n                    &lhs.tensor,\n                ))\n                .unwrap(),\n        )\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let sum = tensor.tensor.sum_all().unwrap().to_scalar::<F>().unwrap();\n        CandleTensor::from_data::<F>(\n            TensorData::new([sum].into(), [1]),\n            Self::float_device(&tensor),\n        )\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap())\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {\n            prev.broadcast_mul(curr)\n        });\n        CandleTensor::new(result)\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {\n            prev.broadcast_minimum(curr)\n        });\n        CandleTensor::new(result)\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let result = super::utils::cumulative_with_op(&tensor.tensor, dim, |prev, curr| {\n            prev.broadcast_maximum(curr)\n        });\n        CandleTensor::new(result)\n    }\n\n    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.exp().unwrap())\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.log().unwrap())\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new((tensor.tensor + 1.).unwrap().log().unwrap())\n    }\n\n    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.powf(value.elem::<f64>()).unwrap())\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.sqrt().unwrap())\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.abs().unwrap())\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.cos().unwrap())\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // cosh(x) = (e^x + e^(-x)) / 2\n        let exp_x = tensor.tensor.exp().unwrap();\n        CandleTensor::new(((exp_x.clone() + exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.sin().unwrap())\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // sinh(x) = (e^x - e^(-x)) / 2\n        let exp_x = tensor.tensor.exp().unwrap();\n        CandleTensor::new(((exp_x.clone() - exp_x.recip().unwrap()).unwrap() / 2.0).unwrap())\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new((tensor.tensor.sin().unwrap() / tensor.tensor.cos().unwrap()).unwrap())\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.tanh().unwrap())\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // acos(x) = PI/2 - asin(x)\n        let neg_asin_x = Self::float_neg(Self::float_asin(tensor));\n        Self::float_add_scalar(neg_asin_x, core::f64::consts::FRAC_PI_2.into())\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // acosh(x) = ln(x + sqrt(x^2 - 1))\n        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());\n        let x_sq_minus_one = Self::float_sub_scalar(x_squared, 1f64.into());\n        let sqrt_term = Self::float_sqrt(x_sq_minus_one);\n        Self::float_log(Self::float_add(tensor, sqrt_term))\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // asin(x) = atan(x / sqrt(1 - x^2))\n        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());\n        let one_minus_x_sq = Self::float_add_scalar(Self::float_neg(x_squared), 1f64.into());\n        let sqrt_term = Self::float_sqrt(one_minus_x_sq);\n        Self::float_atan(Self::float_div(tensor, sqrt_term))\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // asinh(x) = ln(x + sqrt(x^2 + 1))\n        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());\n        let x_sq_plus_one = Self::float_add_scalar(x_squared, 1f64.into());\n        let sqrt_term = Self::float_sqrt(x_sq_plus_one);\n        Self::float_log(Self::float_add(tensor, sqrt_term))\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // atan(x) = asin(x / sqrt(1 + x^2))\n        let x_squared = Self::float_powi_scalar(tensor.clone(), 2.into());\n        let one_plus_x_sq = Self::float_add_scalar(x_squared, 1f64.into());\n        let sqrt_term = Self::float_sqrt(one_plus_x_sq);\n        Self::float_asin(Self::float_div(tensor, sqrt_term))\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // atanh(x) = ln((1 + x) / (1 - x)) / 2\n        let num = (1.0 + tensor.tensor.clone()).unwrap();\n        let denom = (1.0 - tensor.tensor).unwrap();\n        CandleTensor::new(((num / denom).unwrap().log().unwrap() / 2.0).unwrap())\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        // atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))\n        let x_squared = Self::float_powi_scalar(rhs.clone(), 2.into());\n        let y_squared = Self::float_powi_scalar(lhs.clone(), 2.into());\n        let r = Self::float_sqrt(Self::float_add(x_squared, y_squared));\n        let ratio = Self::float_div(lhs, Self::float_add(r, rhs));\n        Self::float_mul_scalar(Self::float_atan(ratio), 2f64.into())\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {\n            // implements round_to_even for consistent behavior vs libtorch\n            // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67\n\n            let floor_a = tensor.tensor.floor()?;\n            let frac_part = tensor.tensor.sub(&floor_a)?;\n\n            let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;\n            let mask_half = frac_part.eq(&half)?;\n            let half_tensor = tensor.tensor.mul(&half)?;\n            let rounded_half = half_tensor.round()?;\n            let doubled =\n                rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;\n            let standard_round = tensor.tensor.round()?;\n            Ok(CandleTensor::new(\n                mask_half.where_cond(&doubled, &standard_round)?,\n            ))\n        };\n        inner(tensor).unwrap()\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.floor().unwrap())\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.ceil().unwrap())\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // truncate(x) = ⌊x⌋ if x ≥ 0, and ⌈x⌉ if x < 0\n        // This preserves the sign of zero and handles all special cases correctly\n        let is_negative = tensor.tensor.lt(0.0).unwrap();\n        let floored = tensor.tensor.floor().unwrap();\n        let ceiled = tensor.tensor.ceil().unwrap();\n        CandleTensor::new(is_negative.where_cond(&ceiled, &floored).unwrap())\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.erf().unwrap())\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        super::base::cat(tensors, dim)\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .argmax_keepdim(dim)\n                .unwrap()\n                .to_dtype(I::DTYPE)\n                .unwrap(),\n        )\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .argmin_keepdim(dim)\n                .unwrap()\n                .to_dtype(I::DTYPE)\n                .unwrap(),\n        )\n    }\n\n    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.minimum(max.elem::<F>()).unwrap())\n    }\n\n    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.maximum(min.elem::<F>()).unwrap())\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        CandleTensor::new(\n            tensor\n                .tensor\n                .clamp(min.elem::<F>(), max.elem::<F>())\n                .unwrap(),\n        )\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        CandleTensor::new(tensor.tensor.recip().unwrap())\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        //broadcast_pow is in main but not yet published\n        //note: probably replace once pow once 0.3.3 is out\n        //see: https://github.com/huggingface/candle/pull/1583/files#diff-6319fa1e16dadc4c7b4e25698139703d93b70f30a1f8e2ac0999978e39efaa81R2594\n\n        CandleTensor::new(\n            rhs.tensor\n                .broadcast_mul(&lhs.tensor.log().unwrap())\n                .unwrap()\n                .exp()\n                .unwrap(),\n        )\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        super::base::permute(tensor, axes)\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        super::base::flip(tensor, axes)\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n\n    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        sign(tensor)\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let dtype = dtype.into_dtype();\n\n        if tensor.tensor.dtype() == dtype {\n            tensor\n        } else {\n            CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/transaction.rs",
    "content": "use burn_backend::{\n    Backend,\n    ops::{TransactionOps, TransactionPrimitive},\n};\n\nuse crate::{\n    Candle,\n    element::{FloatCandleElement, IntCandleElement},\n};\n\nimpl<F: FloatCandleElement, I: IntCandleElement> TransactionOps<Self> for Candle<F, I> {}\n"
  },
  {
    "path": "crates/burn-candle/src/ops/utils.rs",
    "content": "/// Helper function for cumulative operations in Candle backend\n///\n/// This function reduces code duplication for cumulative operations (cumprod, cummin, cummax)\n/// which all follow the same pattern of slicing, applying an operation, and concatenating.\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor\n/// * `dim` - The dimension along which to apply the cumulative operation\n/// * `op` - A closure that takes two tensor references and produces a result tensor\npub fn cumulative_with_op<F>(tensor: &candle_core::Tensor, dim: usize, op: F) -> candle_core::Tensor\nwhere\n    F: Fn(&candle_core::Tensor, &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor>,\n{\n    let dim_size = tensor.dims()[dim];\n    let mut slices = Vec::with_capacity(dim_size);\n\n    // First slice is the initial value\n    slices.push(tensor.narrow(dim, 0, 1).unwrap());\n\n    // Apply cumulative operation\n    for i in 1..dim_size {\n        let curr = tensor.narrow(dim, i, 1).unwrap();\n        let result = op(&slices[i - 1], &curr).unwrap();\n        slices.push(result);\n    }\n\n    candle_core::Tensor::cat(&slices, dim).unwrap()\n}\n"
  },
  {
    "path": "crates/burn-candle/src/tensor.rs",
    "content": "use burn_backend::{DType, FloatDType, IntDType, Shape, quantization::QuantScheme};\nuse burn_backend::{Element, QTensorPrimitive, TensorData, TensorMetadata};\nuse burn_std::BoolStore;\n\nuse crate::{CandleDevice, element::CandleElement};\n\n/// A tensor that uses the candle backend.\n#[derive(Debug, Clone)]\npub struct CandleTensor {\n    pub(crate) tensor: candle_core::Tensor,\n}\n\nimpl TensorMetadata for CandleTensor {\n    fn dtype(&self) -> DType {\n        match self.tensor.dtype() {\n            candle_core::DType::U8 => DType::U8,\n            candle_core::DType::U32 => DType::U32,\n            candle_core::DType::I64 => DType::I64,\n            candle_core::DType::BF16 => DType::BF16,\n            candle_core::DType::F16 => DType::F16,\n            candle_core::DType::F32 => DType::F32,\n            candle_core::DType::F64 => DType::F64,\n            candle_core::DType::I16 => DType::I16,\n            candle_core::DType::I32 => DType::I32,\n            other => todo!(\"{other:?} not yet supported\"),\n        }\n    }\n\n    fn shape(&self) -> Shape {\n        Shape::from(self.tensor.dims().to_vec())\n    }\n\n    fn rank(&self) -> usize {\n        self.tensor.dims().len()\n    }\n}\n\nimpl QTensorPrimitive for CandleTensor {\n    fn scheme(&self) -> &QuantScheme {\n        unimplemented!(\"Quantization is not supported\")\n    }\n}\n\nimpl CandleTensor {\n    /// Create a new tensor.\n    pub fn new(tensor: candle_core::Tensor) -> Self {\n        Self { tensor }\n    }\n\n    /// Creates a new tensor from data and a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The tensor's data.\n    /// * `device` - The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor.\n    pub fn from_data<E: CandleElement>(data: TensorData, device: CandleDevice) -> Self {\n        let candle_shape: candle_core::Shape = data.shape.to_vec().into();\n        let tensor = candle_core::Tensor::from_slice(\n            data.as_slice::<E>().unwrap(),\n            candle_shape,\n            &device.into(),\n        );\n        Self::new(tensor.unwrap())\n    }\n}\n\npub(crate) trait IntoDType {\n    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error>;\n\n    fn into_dtype(self) -> candle_core::DType\n    where\n        Self: Sized,\n    {\n        self.try_into_dtype().unwrap()\n    }\n}\n\nimpl IntoDType for IntDType {\n    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {\n        let dtype: DType = self.into();\n        dtype.try_into_dtype()\n    }\n}\n\nimpl IntoDType for FloatDType {\n    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {\n        let dtype: DType = self.into();\n        dtype.try_into_dtype()\n    }\n}\n\nimpl IntoDType for DType {\n    fn try_into_dtype(self) -> Result<candle_core::DType, candle_core::Error> {\n        match self {\n            DType::F64 => Ok(candle_core::DType::F64),\n            DType::F32 => Ok(candle_core::DType::F32),\n            DType::Flex32 => Ok(candle_core::DType::F32),\n            DType::F16 => Ok(candle_core::DType::F16),\n            DType::BF16 => Ok(candle_core::DType::BF16),\n            DType::I64 => Ok(candle_core::DType::I64),\n            DType::U32 => Ok(candle_core::DType::U32),\n            DType::U8 => Ok(candle_core::DType::U8),\n            DType::I16 => Ok(candle_core::DType::I16),\n            DType::I32 => Ok(candle_core::DType::I32),\n            DType::Bool(BoolStore::U8) => Ok(candle_core::DType::U8),\n            _ => Err(candle_core::Error::Msg(format!(\n                \"Unsupported dtype {self:?}\"\n            ))),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Backend extension for collective calculations.\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"collective\"]\nlicense.workspace = true\nname = \"burn-collective\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-collective\"\ndocumentation = \"https://docs.rs/burn-collective\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = []\ndoc = []\ntracing = [\n    \"dep:tracing\",\n    \"burn-std/tracing\",\n    \"burn-tensor/tracing\",\n    \"burn-communication/tracing\",\n    \"burn-ndarray?/tracing\",\n    \"burn-wgpu?/tracing\",\n    \"burn-cuda?/tracing\",\n]\norchestrator = [\"burn-communication/websocket\"]\n\n# Backends for testing\ntest-ndarray = [\"burn-ndarray\"]\ntest-wgpu = [\"burn-wgpu\", \"burn-wgpu/webgpu\"]\ntest-metal = [\"burn-wgpu\", \"burn-wgpu/metal\"]\ntest-vulkan = [\"burn-wgpu\", \"burn-wgpu/vulkan\"]\ntest-cuda = [\"burn-cuda\"]\n\n[dependencies]\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = true }\n\nlog = { workspace = true }\n\nburn-communication = { path = \"../burn-communication\", version = \"=0.21.0-pre.2\", features = [\n    \"data-service\",\n    \"websocket\",\n] }\ntokio = { workspace = true, features = [\n    \"rt-multi-thread\",\n    \"sync\",\n    \"signal\",\n    \"time\",\n    \"tracing\",\n] }\nserde = { workspace = true, features = [\"derive\"] }\nrmp-serde = { workspace = true }\nbytes = { workspace = true }\nfutures = { workspace = true }\ntokio-util = { workspace = true }\ntracing = { workspace = true, optional = true }\n\n# Tests\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true }\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true }\n\n[dev-dependencies]\nserial_test = { workspace = true }\n\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-collective/README.md",
    "content": "# burn-collective\n\nCollective operations on tensors\n\nThe following collective operation are implemented:\n\n- `all-reduce`\n    Aggregates a tensor between all peers, and distributes the result to all peers.\n    Different strategies can be used on the local and global levels. The result can only be\n    returned when all peers have called the all-reduce.\n- `reduce`\n    Aggregates a tensor from all peers onto one peer, called the \"root\"\n- `broadcast`\n    Copies a tensor from one peer to all other peers in the collective.\n\nPeers must call `register` before calling any other operation.\nThe total number of devices on the node, or nodes in the collective, must be known ahead of time.\n\nIn many libraries like NCCL and PyTorch, participating units are called \"ranks\".\nThis name is confusing in the context of tensors, so in burn-collective the participating units\nare called \"peers\".\n\n*`reduce` and `broadcast` are not yet implemented for multi-node contexts*\n\n## Local and Global\n\nInternally, there are two levels to the collective operations: local and global. Operations are done on the local level, then optionally on the global level.\n\n| Local                                      | Global                                        |\n|-----------------------------------------------|-----------------------------------------------|\n| Intra-node (typically within one machine)     | Inter-node (typically across machies)         |\n| Participants are threads (one per peer/GPU) | Participants are processes (one per node)     |\n| Communication depends on backend              | Network peer-to-peer communication            |\n| Local server is launched automatically      | Global coordinator must be launched manually  |\n| Local server does the aggregation          | Nodes do the operations themselves            |\n\nFor global operations (ie. with multiple nodes), there must be a global orchestrator available.\nStart one easily with `burn_collective::start_global_orchestrator()`.\n\nOn the global level, nodes use the `burn_communication::data_service::TensorDataService` to\nexpose and download tensors in a peer-to-peer manner, in order to be independent.\n\n## Components\n\nThe following are the important pieces of the collective operations system.\n\n| Term                           | One per...    | Meaning\n|--------------------------------|---------------|----------------------------------------------------------\n| Local Collective Client        | Peer/thread | Requests operations to the Local Collective Server\n| Local Collective Server        | Node/process  | Does local-level ops for threads in this process. In the case of global operations, passes operations on to the Global Collective Client.\n| Global Collective Client       | Node/process  | Does global-level ops for this node. Registers and requests strategies from the Global Collective Orchestrator.\n| Global Collective Orchestrator | Collective    | Responds to the Global Collective Client from each node. Responsible for aggregation strategies.\n\n## Strategies\n\nDifferent strategies can be used on the local and global level.\n\n### Centralized\n\nAn arbitrary peer is designated as the \"root\", and all others are transferred to the root's device.\nThe operation is done on that device.\nThe resulting tensor then sent to each peer.\n\n### Tree\n\nTensors in groups of N are aggregated together. This is done recursively until only one tensor\nremains. The strategy tries to put devices of the same type closer in the tree.\nWhen N=2, this is like a binary tree reduce.\nThe resulting tensor then sent to each peer\n\n### Ring\n\nSee this good explanation: <https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model>\n\nThe tensors are sliced into N parts, where N is the number of tensors to aggregate.\nThen, the slices are sent around in a series of cycles and aggregated until every tensor's slices\nis a sum of the other corresponding slices.\n\nIn the case where the tensors are too small to split into N slices, a fallback algorithm is used.\nFor now, the fallback is a binary tree.\n\n(p=3, n=3)\n\no->o  o  \no  o->o  \no  o  o->\n\no  1->o  \no  o  1->\n1->o  o  \n\no  1  2->\n2->o  1  \n1  2->o  \n\n3  1  2\n2  3  1\n1  2  3\n\n(This is essentially a reduce-scatter)\n\n3->x  x  \nx  3->x  \nx  x  3->\n\n3  3->x  \nx  3  3->\n3->x  3  \n\n3  3  3->\n3->3  3  \n3  3->3  \n\n3  3  3\n3  3  3\n3  3  3\n\n(This is essentially an all-gather)\n\nThis is done so that every peer is both sending and receiving data at any moment.\nThis is an important part of this strategy's advantages.\n\nThe ring strategy takes full advantage of the bandwidth available. The latency scales with the\nnumber of peers.\n\nSo when the tensors are very small, or when the number of peers is very large, the latency is more\nimportant in the ring strategy, and a tree algorithm is better. Otherwise, the ring algorithm is\nthe better.\n\nIn multi-node contexts, use of the Ring strategy in the local level may be less\nadvantageous. With multiple nodes, the global all-reduce step is enabled, and its result\nis redistributed to all devices.\nThe Ring strategy inherently distributes the result, which in this context would not be necessary.\n\nIt is recommended to use the Ring strategy at the global level\n\n### Double binary tree\n\n<https://developer.nvidia.com/blog/massively-scale-deep-learning-training-nccl-2-4/>\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/Cargo.toml",
    "content": "[package]\nname = \"burn-collective-multinode-tests\"\nversion.workspace = true\nedition.workspace = true\nlicense.workspace = true\n\n[features]\ndefault = [\"ndarray\"]\nndarray = [\"burn/ndarray\"]\n\n[dependencies]\nburn = { path = \"../../burn\", default-features = false, features = [\"std\"] }\nburn-std = { path = \"../../burn-std\", default-features = false }\nburn-collective = { path = \"..\", features = [\"orchestrator\"] }\nburn-communication = { path = \"../../burn-communication\" }\n\ntokio = { workspace = true, features = [\"rt-multi-thread\", \"process\"] }\n\nserde = { workspace = true, features = [\"derive\"] }\nserde_json = { workspace = true }\ninterprocess = \"2.3.1\"\nrmp-serde = { workspace = true }\ntokio-util = { workspace = true, features = [\"codec\"] }\ntokio-serde = { version = \"0.9.0\", features = [\"messagepack\"] }\nfutures = { workspace = true }\n\n\n[[bin]]\nname = \"global\"\npath = \"src/bin/global.rs\"\n\n[[bin]]\nname = \"node\"\npath = \"src/bin/node.rs\"\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/README.md",
    "content": "# Integration test for burn collective operations with multiple nodes and devices.\n\nRun `cargo run --bin test_launcher`\n\nThere are 3 binaries:\n\n## node.rs\n\nLaunches `n` threads each simulating a different device. Currently the backend is NdArray,\nso everything is CPU. The program takes a file with configurations and input data.\n\n## global.rs\n\nRuns the global orchestrator, who is responsible for responding to global collective operation\nrequests. In the case of an all-reduce, the orchestrator responds with a strategy for reducing,\nand the node can do the reduction independently.\n\n## test_launcher.rs\n\nGenerates input data, calculates the expected results, and launches the nodes each with their\nown inputs in a separate file.\n\nThe topology is [4, 4, 4, 4]. This means 4 nodes are launched,\neach with 4 threads (for each device).\n\nThe global orchestrator (`global.rs`) is also launched.\n\n## Output\n\nThe outputs and inputs for each node and the orchestrator are written to the `target/test_files` folder\n\nIf the nodes or orchestrator stall, there is a timeout.\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/src/bin/global.rs",
    "content": "//! Global orchestrator\n//!\n//! Launches the orchestrator that responds to global collective operations for nodes for the\n//! integration test\n//!\n//! This is necessary for any node who needs global collective operations\n\nuse std::env;\n\n#[tokio::main]\n/// Start the global orchestrator on the port given as first arg\npub async fn main() {\n    let args: Vec<String> = env::args().collect();\n    let port = args[1].parse::<u16>().expect(\"invalid port\");\n\n    // Launch the global orchestrator, which will listen and respond to global collective op\n    // requests from nodes\n    burn_collective::start_global_orchestrator(port).await;\n}\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/src/bin/node.rs",
    "content": "use burn::{\n    backend::NdArray,\n    prelude::Backend,\n    tensor::{Tensor, TensorPrimitive, Tolerance},\n};\nuse burn_collective::{\n    CollectiveConfig, PeerId, ReduceOperation, all_reduce, finish_collective, register,\n    reset_collective,\n};\nuse burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};\nuse std::{\n    env,\n    sync::mpsc::SyncSender,\n    time::{Duration, Instant},\n};\nuse tokio::net::TcpStream;\n\nuse futures::{SinkExt, StreamExt};\nuse std::thread::JoinHandle;\nuse tokio_serde::formats::MessagePack;\nuse tokio_util::codec::LengthDelimitedCodec;\n\ntype TestBackend = NdArray;\n\n/// Framed TCP connection channel\ntype TestChannel = tokio_serde::Framed<\n    tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,\n    NodeTest,\n    NodeTestResult,\n    MessagePack<NodeTest, NodeTestResult>,\n>;\n\n/// Start a node that will test all-reduce\n/// Args are the following:\n/// - launcher endpoint\n#[tokio::main]\npub async fn main() {\n    let args: Vec<String> = env::args().collect();\n\n    let launcher_addr = args[1].clone();\n\n    let socket = TcpStream::connect(launcher_addr).await.unwrap();\n    let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());\n    let mut socket: TestChannel = tokio_serde::Framed::new(\n        length_delimited,\n        MessagePack::<NodeTest, NodeTestResult>::default(),\n    );\n\n    // Loop: receive, do test, send result\n    while let Some(Ok(test)) = socket.next().await {\n        println!(\"Received test: {test:?}\");\n\n        let result = run_test::<NdArray>(&test);\n\n        // send the result back\n        socket.send(result).await.expect(\"failed to send Result\");\n    }\n\n    println!(\"Server closed connection; exiting.\");\n}\n\n/// Runs a test for one node\nfn run_test<B: Backend>(test_input: &NodeTest) -> NodeTestResult {\n    reset_collective::<TestBackend>();\n\n    // Channel for results\n    let (result_send, result_recv) = std::sync::mpsc::sync_channel(32);\n\n    // Launch a thread for each \"device\"\n    let handles = launch_threads::<B>(test_input.clone(), result_send);\n\n    // Receive results\n    let mut durations = vec![];\n    let tol: Tolerance<f32> = Tolerance::balanced();\n    for _ in 0..test_input.device_count {\n        // Assert all results are equal to each other as well as expected result\n        let (tensor, duration) = result_recv.recv().unwrap();\n        test_input.expected.assert_approx_eq(&tensor.to_data(), tol);\n\n        durations.push(duration);\n    }\n\n    // Threads finish\n    for handle in handles {\n        let _ = handle.join();\n    }\n\n    NodeTestResult {\n        success: true,\n        durations,\n    }\n}\n\n/// Launch a thread for each device, and run the all-reduce\nfn launch_threads<B: Backend>(\n    test_input: NodeTest,\n    result_send: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,\n) -> Vec<JoinHandle<()>> {\n    let mut handles = vec![];\n    for id in 0..test_input.device_count {\n        // Launch a thread to test\n\n        // Put all the parameters in the config\n        let config = CollectiveConfig::default()\n            .with_num_devices(test_input.device_count)\n            .with_global_address(test_input.global_address.clone())\n            .with_node_address(test_input.node_address.clone())\n            .with_data_service_port(test_input.data_service_port)\n            .with_num_nodes(test_input.node_count)\n            .with_global_all_reduce_strategy(test_input.global_strategy)\n            .with_local_all_reduce_strategy(test_input.local_strategy);\n\n        // Inputs and outputs for the test\n        let tensor_data = test_input.inputs[id].clone();\n        let tensor = Tensor::<B, TENSOR_RANK>::from_data(tensor_data, &B::Device::default());\n        let result_send = result_send.clone();\n\n        let handle = std::thread::spawn(move || {\n            run_peer::<B>(\n                id.into(),\n                config,\n                tensor,\n                result_send,\n                test_input.all_reduce_op,\n            )\n        });\n        handles.push(handle);\n    }\n\n    handles\n}\n\n/// Runs a thread in the all-reduce test.\npub fn run_peer<B: Backend>(\n    id: PeerId,\n    config: CollectiveConfig,\n    input: Tensor<B, TENSOR_RANK>,\n    output: SyncSender<(Tensor<B, TENSOR_RANK>, Duration)>,\n    all_reduce_op: ReduceOperation,\n) {\n    // Register the device\n    register::<B>(id, input.device(), config).unwrap();\n\n    let start = Instant::now();\n\n    // All-reduce\n    let input = input.into_primitive().tensor();\n    let tensor = all_reduce::<B>(id, input, all_reduce_op).unwrap();\n    let tensor = Tensor::<B, TENSOR_RANK>::from_primitive(TensorPrimitive::Float(tensor));\n\n    let duration = start.elapsed();\n\n    // Send result\n    output.send((tensor, duration)).unwrap();\n\n    finish_collective::<B>(id).unwrap();\n}\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/src/bin/test_launcher.rs",
    "content": "use burn::tensor::TensorData;\nuse burn_communication::Address;\nuse futures::{SinkExt, StreamExt};\nuse std::{\n    fmt::Display,\n    fs::{self, File},\n    str::FromStr,\n    time::{Duration, Instant},\n    vec,\n};\nuse tokio::net::TcpListener;\nuse tokio_serde::formats::MessagePack;\nuse tokio_util::codec::LengthDelimitedCodec;\n\nuse burn::{backend::NdArray, prelude::Backend, tensor::Tensor};\nuse burn_collective::{AllReduceStrategy, ReduceOperation};\nuse burn_collective_multinode_tests::shared::{NodeTest, NodeTestResult, TENSOR_RANK};\nuse burn_std::rand::{SeedableRng, StdRng};\nuse tokio::process::{Child, Command};\n\n#[derive(Clone)]\nstruct AllReduceTest {\n    shape: [usize; TENSOR_RANK],\n    op: ReduceOperation,\n    local_strategy: AllReduceStrategy,\n    global_strategy: AllReduceStrategy,\n}\n\nimpl Display for AllReduceTest {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        let op_str = match self.op {\n            ReduceOperation::Sum => \"sum\",\n            ReduceOperation::Mean => \"mean\",\n        };\n        let local_strategy_str = match self.local_strategy {\n            AllReduceStrategy::Centralized => \"local_centralized\",\n            AllReduceStrategy::Tree(n) => &format!(\"local_tree_{n}\"),\n            AllReduceStrategy::Ring => \"local_ring\",\n        };\n        let global_strategy_str = match self.global_strategy {\n            AllReduceStrategy::Centralized => \"global_centralized\",\n            AllReduceStrategy::Tree(n) => &format!(\"global_tree_{n}\"),\n            AllReduceStrategy::Ring => \"global_ring\",\n        };\n\n        write!(f, \"{op_str}_{local_strategy_str}_{global_strategy_str}\")\n    }\n}\n\n/// Framed TCP connection for sending tests and receiving results\ntype TestChannel = tokio_serde::Framed<\n    tokio_util::codec::Framed<tokio::net::TcpStream, LengthDelimitedCodec>,\n    NodeTestResult,\n    NodeTest,\n    MessagePack<NodeTestResult, NodeTest>,\n>;\n\n/// Handle for each node process\nstruct NodeProcessHandle {\n    process: Child,\n    channel: TestChannel,\n}\n\n/// Main function to run the multi-node all-reduce test.\n/// Launches a orchestrator and multiple nodes based on the provided topology.\n#[tokio::main(flavor = \"multi_thread\", worker_threads = 10)]\nasync fn main() {\n    let all_reduce_tests = vec![\n        AllReduceTest {\n            shape: [4, 64, 512],\n            op: ReduceOperation::Mean,\n            local_strategy: AllReduceStrategy::Tree(2),\n            global_strategy: AllReduceStrategy::Tree(2),\n        },\n        AllReduceTest {\n            shape: [4, 64, 512],\n            op: ReduceOperation::Mean,\n            local_strategy: AllReduceStrategy::Tree(2),\n            global_strategy: AllReduceStrategy::Ring,\n        },\n        AllReduceTest {\n            shape: [4, 64, 512],\n            op: ReduceOperation::Mean,\n            local_strategy: AllReduceStrategy::Centralized,\n            global_strategy: AllReduceStrategy::Centralized,\n        },\n    ];\n\n    let test_files_dir = \"target/test_files\";\n    fs::create_dir_all(test_files_dir).expect(\"Couldn't create test_files directory\");\n\n    let topology: Vec<usize> = vec![4; 4];\n\n    let mut orchestrator = launch_orchestrator(test_files_dir);\n\n    let launcher_endpoint = \"127.0.0.1:4000\";\n\n    // Build and run node processes\n    let mut all_tests_durations = vec![];\n    if let Ok(mut nodes) = launch_nodes(&topology, launcher_endpoint).await {\n        // Run one test\n        for test in all_reduce_tests.clone() {\n            let test_name = test.to_string();\n\n            let time =\n                test_all_reduce_centralized_no_collective::<NdArray>(&topology, test.clone());\n            println!(\n                \"{test_name}: Benchmark (no collective, centralized, single-threaded): {} secs\",\n                time.as_secs_f32()\n            );\n\n            match test_all_reduce(&topology, test, &mut nodes).await {\n                Err(node_idx) => {\n                    println!(\"{test_name}: Node with index {node_idx} failed!\");\n                    // Kill other node processes\n                    for mut node in nodes.drain(..) {\n                        node.process.kill().await.unwrap();\n                        node.process.wait().await.unwrap();\n                    }\n                    break;\n                }\n                Ok(durations) => {\n                    all_tests_durations.append(&mut durations.clone());\n                    let avg = durations.iter().map(|dur| dur.as_secs_f32()).sum::<f32>()\n                        / durations.len() as f32;\n                    println!(\"{test_name}: Success in {avg} secs\");\n                }\n            }\n        }\n    }\n\n    if !all_tests_durations.is_empty() {\n        let avg = all_tests_durations\n            .iter()\n            .map(|dur| dur.as_secs_f32())\n            .sum::<f32>()\n            / all_tests_durations.len() as f32;\n        println!(\"Average for all tests: {avg} secs\");\n    }\n\n    // Shutdown orchestrator\n    orchestrator.kill().await.unwrap();\n    orchestrator.wait().await.unwrap();\n}\n\n/// Launch a global orchestrator with an output file in the given directory.\n/// Necessary for global collective operations\n///\n/// Server listens on localhost port 3000\nfn launch_orchestrator(test_files_dir: &str) -> Child {\n    let out_path = format!(\"{test_files_dir}/orchestrator_out.txt\");\n    let out = File::create(out_path).expect(\"Could't create orchestrator output file\");\n\n    Command::new(\"cargo\")\n        .args([\"run\", \"--bin\", \"global\", \"--\", \"3000\"])\n        .stdout(out.try_clone().unwrap())\n        .stderr(out)\n        .spawn()\n        .expect(\"failed to launch orchestrator\")\n}\n\n/// Launch nodes for a all_reduce test\n/// Each node will connect to the global orchestrator and run an all-reduce operation.\n/// The topology is a vector where each element represents the number of devices in that node.\nasync fn launch_nodes(\n    topology: &[usize],\n    launcher_endpoint: &str,\n) -> Result<Vec<NodeProcessHandle>, ()> {\n    println!(\n        \"Launching {} nodes with topology: {:?}\",\n        topology.len(),\n        topology\n    );\n\n    // Listen for node connections\n    let listener = TcpListener::bind(launcher_endpoint).await.unwrap();\n    println!(\"Server listening on {launcher_endpoint}\");\n\n    let mut nodes = vec![];\n\n    for node_idx in 0..topology.len() {\n        // Create log file\n        let output_filename = format!(\"target/test_files/node_{}_log.txt\", node_idx + 1);\n        let out = File::create(output_filename).expect(\"Could't open node log file\");\n\n        // Start a process for each node. Pass on our feature flags\n        let node_process: Child = Command::new(\"cargo\")\n            .args([\n                \"run\",\n                \"--release\",\n                \"--features\",\n                #[cfg(feature = \"ndarray\")]\n                \"ndarray\",\n                \"--bin\",\n                \"node\",\n                \"--\",\n                launcher_endpoint,\n                &node_idx.to_string(),\n            ])\n            .stdout(out.try_clone().unwrap())\n            .stderr(out)\n            .spawn()\n            .expect(\"node failed\");\n\n        // Wait for child to connect for io\n        let (socket, _peer_addr) = listener.accept().await.unwrap();\n        let length_delimited = tokio_util::codec::Framed::new(socket, LengthDelimitedCodec::new());\n        let channel: TestChannel = tokio_serde::Framed::new(\n            length_delimited,\n            MessagePack::<NodeTestResult, NodeTest>::default(),\n        );\n\n        nodes.push(NodeProcessHandle {\n            process: node_process,\n            channel,\n        });\n    }\n\n    Ok(nodes)\n}\n\nasync fn test_all_reduce(\n    topology: &[usize],\n    test: AllReduceTest,\n    nodes: &mut [NodeProcessHandle],\n) -> Result<Vec<Duration>, usize> {\n    dispatch_all_reduce_test(topology, test, nodes).await;\n\n    let mut all_durations = vec![];\n    for (idx, handle) in nodes.iter_mut().enumerate() {\n        match handle.channel.next().await {\n            Some(Ok(mut result)) => {\n                if !result.success {\n                    return Err(idx);\n                }\n                all_durations.append(&mut result.durations);\n            }\n            _ => {\n                return Err(idx);\n            }\n        }\n    }\n\n    Ok(all_durations)\n}\n\nasync fn dispatch_all_reduce_test(\n    topology: &[usize],\n    test: AllReduceTest,\n    nodes: &mut [NodeProcessHandle],\n) {\n    let total_device_count: usize = topology.iter().sum();\n    let (mut all_inputs, expected) =\n        generate_random_input(test.shape, test.op, total_device_count, 42);\n\n    // URL for the global orchestrator on port 3000\n    let global_url = \"ws://localhost:3000\";\n    let global_address = Address::from_str(global_url).unwrap();\n\n    for (node_idx, &device_count) in topology.iter().enumerate() {\n        // Construct URL for node\n        // Ports 3001... are for each node\n        let data_service_port = node_idx as u16 + 3001;\n        let node_url = format!(\"ws://localhost:{data_service_port}\");\n        let node_address = Address::from_str(&node_url).unwrap();\n\n        // take input tensors for each device\n        let inputs = all_inputs[0..device_count].to_vec();\n        all_inputs = all_inputs[device_count..].to_vec();\n\n        let test = NodeTest {\n            device_count,\n            node_id: node_idx.into(),\n            node_count: topology.len() as u32,\n            global_address: global_address.clone(),\n            node_address,\n            data_service_port,\n            all_reduce_op: test.op,\n            global_strategy: test.global_strategy,\n            local_strategy: test.local_strategy,\n            inputs,\n            expected: expected.clone(),\n        };\n        let handle = &mut nodes[node_idx];\n\n        handle.channel.send(test).await.unwrap();\n    }\n\n    assert!(\n        all_inputs.is_empty(),\n        \"Not all inputs have been sent to tests\"\n    );\n}\n\n/// Run the test sequentially with no collective operations to get the optimal single-threaded speed\nfn test_all_reduce_centralized_no_collective<B: Backend>(\n    topology: &[usize],\n    test: AllReduceTest,\n) -> Duration {\n    let total_device_count: usize = topology.iter().sum();\n    let (all_inputs, _expected) =\n        generate_random_input(test.shape, test.op, total_device_count, 42);\n\n    let mut all_inputs = all_inputs\n        .into_iter()\n        .map(|data| Tensor::<B, 3>::from_data(data, &B::Device::default()));\n\n    let start = Instant::now();\n\n    // Sequential test\n    let mut result = all_inputs.next().unwrap();\n    for other in all_inputs {\n        result = result.add(other);\n    }\n    if test.op == ReduceOperation::Mean {\n        result.div_scalar(total_device_count as u32);\n    }\n\n    start.elapsed()\n}\n\n/// Generates random input tensors and expected output based on the provided shape and reduce kind.\nfn generate_random_input(\n    shape: [usize; 3],\n    reduce_kind: ReduceOperation,\n    input_count: usize,\n    seed: u64,\n) -> (Vec<TensorData>, TensorData) {\n    let mut rng = StdRng::seed_from_u64(seed);\n\n    // A random tensor for each device\n    let input: Vec<TensorData> = (0..input_count)\n        .map(|_| {\n            TensorData::random::<f32, _, _>(shape, burn::tensor::Distribution::Default, &mut rng)\n        })\n        .collect();\n\n    // Sum up the inputs\n    let device = <NdArray as Backend>::Device::default();\n    let mut expected_tensor = Tensor::<NdArray, TENSOR_RANK>::zeros(shape, &device);\n    for item in input.iter().take(input_count) {\n        let input_tensor = Tensor::<NdArray, TENSOR_RANK>::from_data(item.clone(), &device);\n        expected_tensor = expected_tensor.add(input_tensor);\n    }\n\n    if reduce_kind == ReduceOperation::Mean {\n        expected_tensor = expected_tensor.div_scalar(input_count as u32);\n    }\n\n    // All-Reduce results should have this value\n    let expected = expected_tensor.to_data();\n\n    (input, expected)\n}\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/src/lib.rs",
    "content": "pub mod shared;\n"
  },
  {
    "path": "crates/burn-collective/multinode-tests/src/shared.rs",
    "content": "use std::time::Duration;\n\nuse burn::tensor::TensorData;\nuse burn_collective::{AllReduceStrategy, NodeId, ReduceOperation};\nuse burn_communication::Address;\nuse serde::{Deserialize, Serialize};\n\n/// Ranks of inputs and outputs for all testing\npub const TENSOR_RANK: usize = 3;\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct NodeTest {\n    /// How many threads to start on this node\n    pub device_count: usize,\n    /// ID for this node\n    pub node_id: NodeId,\n    /// How many nodes in the cluster\n    pub node_count: u32,\n    /// Global server address\n    pub global_address: Address,\n    /// Node address\n    pub node_address: Address,\n    /// Node's data service port, for initializing the p2p tensor data service\n    pub data_service_port: u16,\n    /// What kind of all-reduce\n    pub all_reduce_op: ReduceOperation,\n    /// Node's data service port, for initializing the p2p tensor data service\n    pub global_strategy: AllReduceStrategy,\n    /// What kind of aggregation\n    pub local_strategy: AllReduceStrategy,\n\n    /// Input data for test: all tensors are D=3\n    pub inputs: Vec<TensorData>,\n    /// Expected output for test\n    pub expected: TensorData,\n}\n\n/// Result sent back from each node for each test\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct NodeTestResult {\n    pub success: bool,\n    pub durations: Vec<Duration>,\n}\n"
  },
  {
    "path": "crates/burn-collective/src/api.rs",
    "content": "use burn_tensor::backend::Backend;\n\nuse crate::{\n    CollectiveConfig, PeerId, ReduceOperation, global::shared::GlobalCollectiveError,\n    local::server::get_collective_client,\n};\n\n/// Errors from collective operations\n#[allow(unused)]\n#[derive(Debug, Clone)]\npub enum CollectiveError {\n    /// The [config](CollectiveConfig) was invalid.\n    /// Usually happens if only some global parameters have been defined\n    InvalidConfig,\n    /// Cannot un-register a node twice\n    MultipleUnregister,\n    /// Cannot register a node twice\n    MultipleRegister,\n    /// Trying to register a different way than is currently being done\n    RegisterParamsMismatch,\n    /// Trying to all-reduce tensors of different shapes: shape must match\n    AllReduceShapeMismatch,\n    /// Trying to all-reduce a different way than is currently being done: op must match\n    AllReduceOperationMismatch,\n    /// Trying to reduce tensors of different shapes: shape must match\n    ReduceShapeMismatch,\n    /// Trying to reduce a different way than is currently being done: op must match\n    ReduceOperationMismatch,\n    /// Trying to reduce with different roots\n    ReduceRootMismatch,\n    /// Trying to broadcast with different roots\n    BroadcastRootMismatch,\n    /// Trying to broadcast but no peer sent a tensor\n    BroadcastNoTensor,\n    /// Trying to broadcast but multiple peers sent a tensor\n    BroadcastMultipleTensors,\n    /// Local collective server couldn't respond\n    LocalServerMissing,\n    /// Another operation was called before Register\n    RegisterNotFirstOperation,\n    /// The global orchestrator had an error\n    Global(GlobalCollectiveError),\n\n    #[allow(unused)]\n    Other(String),\n}\n\n/// Registers a device. `num_devices` must be the same for every register,\n/// and `device_id` must be unique.\n///\n/// * `id` - The peer id of the caller\n///\n/// With auto-diff backends, make sure to use the inner backend.\npub fn register<B: Backend>(\n    id: PeerId,\n    device: B::Device,\n    config: CollectiveConfig,\n) -> Result<(), CollectiveError> {\n    log::info!(\"Registering peer {id} with config: {config}\");\n    let mut client = get_collective_client::<B>();\n    client.register(id, device, config)\n}\n\n/// Calls for an all-reduce operation with the given parameters, and returns the result.\n/// The `params` must be the same as the parameters passed by the other nodes.\n///\n/// * `id` - The peer id of the caller\n/// * `tensor` - The input tensor to reduce with the peers' tensors\n/// * `config` - Config of the collective operation, must be coherent with the other calls\npub fn all_reduce<B: Backend>(\n    id: PeerId,\n    tensor: B::FloatTensorPrimitive,\n    op: ReduceOperation,\n) -> Result<B::FloatTensorPrimitive, CollectiveError> {\n    let client = get_collective_client::<B>();\n    client.all_reduce(id, tensor, op)\n}\n\n/// Broadcasts, or receives a broadcasted tensor.\n///\n/// * `id` - The peer id of the caller\n/// * `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive\n///   the broadcasted tensor.\n///\n/// Returns the broadcasted tensor.\npub fn broadcast<B: Backend>(\n    id: PeerId,\n    tensor: Option<B::FloatTensorPrimitive>,\n) -> Result<B::FloatTensorPrimitive, CollectiveError> {\n    let client = get_collective_client::<B>();\n    client.broadcast(id, tensor)\n}\n\n/// Reduces a tensor onto one device.\n///\n/// * `id` - The peer id of the caller\n/// * `tensor` - The tensor to send as input\n/// * `root` - The ID of the peer that will receive the result.\n///\n/// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.\npub fn reduce<B: Backend>(\n    id: PeerId,\n    tensor: B::FloatTensorPrimitive,\n    op: ReduceOperation,\n    root: PeerId,\n) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {\n    let client = get_collective_client::<B>();\n    client.reduce(id, tensor, op, root)\n}\n\n/// Closes the collective session, unregistering the device\npub fn finish_collective<B: Backend>(id: PeerId) -> Result<(), CollectiveError> {\n    let client = get_collective_client::<B>();\n    client.finish(id)\n}\n\n/// Resets the local collective server. All registered callers and ongoing operations are forgotten\npub fn reset_collective<B: Backend>() {\n    let client = get_collective_client::<B>();\n    client.reset();\n}\n"
  },
  {
    "path": "crates/burn-collective/src/config.rs",
    "content": "use std::fmt::Display;\n\nuse burn_communication::Address;\nuse serde::{Deserialize, Serialize};\n\n/// Parameter struct for setting up and getting parameters for collective operations.\n/// Used in most collective api calls.\n/// This config is per-node. It is passed to [reduce](crate::register).\n#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]\npub struct CollectiveConfig {\n    pub(crate) num_devices: usize,\n    pub(crate) local_all_reduce_strategy: AllReduceStrategy,\n    pub(crate) local_reduce_strategy: ReduceStrategy,\n    pub(crate) local_broadcast_strategy: BroadcastStrategy,\n\n    // Global parameters (all are optional, but if one is defined they should all be)\n    pub(crate) num_nodes: Option<u32>,\n    pub(crate) global_address: Option<Address>,\n    pub(crate) node_address: Option<Address>,\n    pub(crate) data_service_port: Option<u16>,\n\n    // These strategies may be defined when no other global params are defined\n    pub(crate) global_all_reduce_strategy: Option<AllReduceStrategy>,\n    pub(crate) global_reduce_strategy: Option<ReduceStrategy>,\n    pub(crate) global_broadcast_strategy: Option<BroadcastStrategy>,\n}\n\nimpl Default for CollectiveConfig {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Display for CollectiveConfig {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        let num_devices = self.num_devices;\n        let local_all_reduce_strategy = self.local_all_reduce_strategy;\n        let local_reduce_strategy = self.local_reduce_strategy;\n        let local_broadcast_strategy = self.local_broadcast_strategy;\n        let num_nodes = self.num_nodes;\n        let global_address = &self.global_address;\n        let node_address = &self.node_address;\n        let data_service_port = self.data_service_port;\n        let global_all_reduce_strategy = self.global_all_reduce_strategy;\n        let global_reduce_strategy = self.global_reduce_strategy;\n        let global_broadcast_strategy = self.global_broadcast_strategy;\n\n        write!(\n            f,\n            r#\"\nCollectiveConfig {{\n    num_devices: {num_devices:?},\n    local_all_reduce_strategy: {local_all_reduce_strategy:?},\n    local_reduce_strategy: {local_reduce_strategy:?},\n    local_broadcast_strategy: {local_broadcast_strategy:?},\n    num_nodes: {num_nodes:?},\n    global_address: {global_address:?},\n    node_address: {node_address:?},\n    data_service_port: {data_service_port:?},\n    global_all_reduce_strategy: {global_all_reduce_strategy:?},\n    global_reduce_strategy: {global_reduce_strategy:?},\n    global_broadcast_strategy: {global_broadcast_strategy:?},\n}}\n\"#\n        )\n    }\n}\n\nimpl CollectiveConfig {\n    fn new() -> Self {\n        Self {\n            num_devices: 1,\n            local_all_reduce_strategy: AllReduceStrategy::Tree(2),\n            local_reduce_strategy: ReduceStrategy::Tree(2),\n            local_broadcast_strategy: BroadcastStrategy::Tree(2),\n\n            num_nodes: None,\n            global_address: None,\n            node_address: None,\n            data_service_port: None,\n            global_all_reduce_strategy: Some(AllReduceStrategy::Ring),\n            global_reduce_strategy: Some(ReduceStrategy::Tree(2)),\n            global_broadcast_strategy: Some(BroadcastStrategy::Tree(2)),\n        }\n    }\n\n    /// Selects the number of devices (local peers) on the current node\n    pub fn with_num_devices(mut self, num: usize) -> Self {\n        self.num_devices = num;\n        self\n    }\n\n    /// Selects an all-reduce strategy to use on the local level.\n    ///\n    /// In multi-node contexts, use of the Ring strategy in the local level may be less\n    /// advantageous. With multiple nodes, the global all-reduce step is enabled, and its result\n    /// is redistributed to all devices.\n    /// The Ring strategy inherently distributes the result, which in this context would not be\n    /// necessary.\n    ///\n    /// It is recommended to use a tree strategy locally, and a ring strategy globally.\n    pub fn with_local_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {\n        self.local_all_reduce_strategy = strategy;\n        self\n    }\n\n    /// Selects a reduce strategy to use on the local level.\n    pub fn with_local_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {\n        self.local_reduce_strategy = strategy;\n        self\n    }\n\n    /// Selects a broadcast strategy to use on the local level.\n    pub fn with_local_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {\n        self.local_broadcast_strategy = strategy;\n        self\n    }\n\n    /// Set the number of nodes in the collective\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts\n    pub fn with_num_nodes(mut self, n: u32) -> Self {\n        self.num_nodes = Some(n);\n        self\n    }\n\n    /// Set the network address of the Global Collective Orchestrator\n    ///  \n    /// This parameter is a global parameter and should only be set in multi-node contexts\n    pub fn with_global_address(mut self, addr: Address) -> Self {\n        self.global_address = Some(addr);\n        self\n    }\n\n    /// Define the address for this node\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts\n    pub fn with_node_address(mut self, addr: Address) -> Self {\n        self.node_address = Some(addr);\n        self\n    }\n\n    /// Selects the network port on which to expose the tensor data service\n    /// used for peer-to-peer tensor downloading.\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts\n    pub fn with_data_service_port(mut self, port: u16) -> Self {\n        self.data_service_port = Some(port);\n        self\n    }\n\n    /// Selects an all-reduce strategy to use on the global level.\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts.\n    /// See [the local strategy](Self::with_local_all_reduce_strategy)\n    pub fn with_global_all_reduce_strategy(mut self, strategy: AllReduceStrategy) -> Self {\n        self.global_all_reduce_strategy = Some(strategy);\n        self\n    }\n\n    /// Selects an reduce strategy to use on the global level.\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts.\n    /// See [the local strategy](Self::with_local_reduce_strategy)\n    pub fn with_global_reduce_strategy(mut self, strategy: ReduceStrategy) -> Self {\n        self.global_reduce_strategy = Some(strategy);\n        self\n    }\n\n    /// Selects an broadcst strategy to use on the global level.\n    ///\n    /// This parameter is a global parameter and should only be set in multi-node contexts.\n    /// See [the local strategy](Self::with_local_broadcast_strategy)\n    pub fn with_global_broadcast_strategy(mut self, strategy: BroadcastStrategy) -> Self {\n        self.global_broadcast_strategy = Some(strategy);\n        self\n    }\n\n    /// Returns whether the config is valid. If only some required global-level parameters are\n    /// defined and others are not, the config is invalid.  \n    pub fn is_valid(&self) -> bool {\n        match (\n            self.num_nodes,\n            &self.global_address,\n            &self.node_address,\n            self.data_service_port,\n        ) {\n            (None, None, None, None) => true,\n            (Some(_), Some(_), Some(_), Some(_)) => true,\n            // Global parameters have only been partially defined!\n            _ => false,\n        }\n    }\n\n    /// Return the global parameters for registering in a multi-node context.\n    ///\n    /// If only some global parameters are defined, returns None. Use [is_valid](Self::is_valid) to check for\n    /// validity in this case.\n    pub(crate) fn global_register_params(&self) -> Option<GlobalRegisterParams> {\n        match (\n            self.num_nodes,\n            &self.global_address,\n            &self.node_address,\n            self.data_service_port,\n        ) {\n            // Only local collective\n            (None, None, None, None) => None,\n            // Local + global collective\n            (Some(num_nodes), Some(global_addr), Some(node_addr), Some(data_service_port)) => {\n                Some(GlobalRegisterParams {\n                    num_nodes,\n                    global_address: global_addr.clone(),\n                    node_address: node_addr.clone(),\n                    data_service_port,\n                })\n            }\n            // Config is invalid!\n            _ => None,\n        }\n    }\n}\n\n/// Helper struct for parameters in a multi-node register operation. Either they are all defined,\n/// or all not defined. Passed to the global client for registering on the global level and\n/// opening the p2p tensor service.\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct GlobalRegisterParams {\n    /// The address for the connection to the global orchestrator.\n    pub global_address: Address,\n    /// The address for the connection to this node.\n    pub node_address: Address,\n    /// The port on which to open the tensor data service for peer-to-peer tensor transfers with\n    /// other nodes. Should match the port given in the node url.\n    pub data_service_port: u16,\n\n    /// The number of nodes globally. Should be the same between different nodes\n    pub num_nodes: u32,\n}\n\n/// Parameters for an all-reduce that should be the same between all devices\n#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]\npub struct SharedAllReduceParams {\n    pub op: ReduceOperation,\n    pub local_strategy: AllReduceStrategy,\n    pub global_strategy: Option<AllReduceStrategy>,\n}\n\n/// Parameters for a reduce that should be the same between all devices\n#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]\npub struct SharedReduceParams {}\n\n/// Parameters for a broadcast that should be the same between all devices\n#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]\npub struct SharedBroadcastParams {\n    pub op: ReduceOperation,\n    pub local_strategy: BroadcastStrategy,\n    pub global_strategy: Option<BroadcastStrategy>,\n}\n\n/// Reduce can be done different ways\n#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]\npub enum ReduceOperation {\n    Sum,\n    Mean,\n}\n\n/// All reduce can be implemented with different algorithms, which all have the same result.\n#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]\npub enum AllReduceStrategy {\n    /// One device is the \"central\". The other devices, \"peripherals\", send their tensors to the\n    /// central. The central does the reduction, and sends the result back to each peripheral.  \n    Centralized,\n\n    /// Devices are organized in a tree structure (with a given arity). Each node reduces its\n    /// children's tensors with its own, and sends the result to its parent. Leaf nodes will\n    /// simply send their tensors to their parents.\n    /// When the root node calculates the result, it is propagated down the tree.\n    Tree(u32),\n\n    /// Devices are organized in a ring. The tensors are split into N slices, where N is the\n    /// number of devices participating. The slices are progressively sent around the ring until\n    /// every device has one fully reduced slice of the tensor. Then, the resulting slices are sent\n    /// around until every device has the full result.\n    /// See `ring.rs` for details.\n    Ring,\n}\n\n/// Reduce can be implemented with different algorithms, which all have the same result.\n#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]\npub enum ReduceStrategy {\n    /// See [all-reduce](AllReduceStrategy::Centralized)\n    Centralized,\n\n    /// See [all-reduce](AllReduceStrategy::Tree)\n    Tree(u32),\n}\n\n/// Broadcast can be implemented with different algorithms, which all have the same result.\n#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]\npub enum BroadcastStrategy {\n    /// See [all-reduce](AllReduceStrategy::Centralized)\n    Centralized,\n\n    /// See [all-reduce](AllReduceStrategy::Tree)\n    Tree(u32),\n}\n\n/// A unique identifier for a peer in the context of collective operations.\n/// They must be unique, even in multi-node contexts.\n///\n/// This is like the rank in NCCL\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\npub struct PeerId(u32);\n\nimpl Display for PeerId {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        write!(f, \"PeerId({})\", self.0)\n    }\n}\n\nimpl From<u32> for PeerId {\n    fn from(value: u32) -> Self {\n        Self(value)\n    }\n}\n\nimpl From<i32> for PeerId {\n    fn from(value: i32) -> Self {\n        Self(value as u32)\n    }\n}\n\nimpl From<usize> for PeerId {\n    fn from(value: usize) -> Self {\n        Self(value as u32)\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/base.rs",
    "content": "use serde::{Deserialize, Serialize};\n\n/// Unique identifier for any node in the global collective.\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]\npub struct NodeId(u32);\n\nimpl From<u32> for NodeId {\n    fn from(value: u32) -> Self {\n        Self(value)\n    }\n}\n\nimpl From<usize> for NodeId {\n    fn from(value: usize) -> Self {\n        Self(value as u32)\n    }\n}\n\nimpl From<i32> for NodeId {\n    fn from(value: i32) -> Self {\n        Self(value as u32)\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/mod.rs",
    "content": "pub(crate) mod node;\npub(crate) mod shared;\n\n#[cfg(feature = \"orchestrator\")]\npub mod orchestrator;\n#[cfg(feature = \"orchestrator\")]\npub use orchestrator::*;\n\nmod base;\npub use base::*;\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/base.rs",
    "content": "use burn_communication::Protocol;\nuse burn_communication::data_service::TensorDataServer;\nuse burn_communication::{Address, ProtocolServer, data_service::TensorDataService};\nuse burn_tensor::backend::Backend;\nuse std::collections::HashMap;\nuse std::{marker::PhantomData, sync::Arc};\nuse tokio::sync::RwLock;\nuse tokio_util::sync::CancellationToken;\n\nuse crate::node::sync::SyncService;\nuse crate::{\n    AllReduceStrategy, BroadcastStrategy, GlobalRegisterParams, NodeId, PeerId, ReduceStrategy,\n};\nuse crate::{\n    ReduceOperation,\n    global::{\n        node::{\n            centralized::centralized_all_reduce_sum, ring::ring_all_reduce_sum,\n            tree::tree_all_reduce_sum, worker::GlobalClientWorker,\n        },\n        shared::{GlobalCollectiveError, RemoteRequest, RemoteResponse},\n    },\n    local::server::get_collective_server_runtime,\n};\n\n/// Must be synchronized between all nodes for collective operations to work\npub(crate) struct NodeState {\n    pub node_id: NodeId,\n    pub nodes: HashMap<NodeId, Address>,\n    pub num_global_devices: u32,\n}\n\n/// A node talks to the global orchestrator as well as other nodes with a peer-to-peer service\npub(crate) struct Node<B, P>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    // State is written during `register` and read during other operations,\n    // sometimes by multiple threads (ex. syncing during an all-reduce)\n    state: Arc<RwLock<Option<NodeState>>>,\n    data_service: Arc<TensorDataService<B, P>>,\n    sync_service: Arc<SyncService<P>>,\n    worker: GlobalClientWorker<P::Client>,\n    _n: PhantomData<P>,\n}\n\nimpl<B, P> Node<B, P>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    pub fn new(global_address: &Address, comms_server: P::Server) -> Self {\n        let state = Arc::new(tokio::sync::RwLock::new(None));\n        let cancel_token = CancellationToken::new();\n        let data_service = Arc::new(TensorDataService::new(cancel_token.clone()));\n        let sync_service = Arc::new(SyncService::new(state.clone()));\n\n        let runtime = get_collective_server_runtime();\n        let server = comms_server\n            .route_tensor_data_service(data_service.clone())\n            .route(\"/sync\", {\n                let sync_service = sync_service.clone();\n                async move |channel: <P::Server as ProtocolServer>::Channel| {\n                    sync_service.handle_sync_connection(channel).await;\n                }\n            })\n            .serve({\n                let cancel_token = cancel_token.clone();\n                async move { cancel_token.cancelled().await }\n            });\n\n        runtime.spawn(server);\n\n        let worker = GlobalClientWorker::new(&runtime, cancel_token.clone(), global_address);\n\n        Self {\n            state,\n            data_service,\n            sync_service,\n            worker,\n            _n: PhantomData,\n        }\n    }\n\n    pub async fn register(\n        &mut self,\n        peers: Vec<PeerId>,\n        global_params: GlobalRegisterParams,\n    ) -> Result<(), GlobalCollectiveError> {\n        let req = RemoteRequest::Register {\n            node_addr: global_params.node_address,\n            num_nodes: global_params.num_nodes,\n            peers,\n        };\n        match self.worker.request(req).await {\n            RemoteResponse::Register {\n                node_id,\n                nodes,\n                num_global_devices,\n            } => {\n                let mut state = self.state.write().await;\n                *state = Some(NodeState {\n                    node_id,\n                    nodes,\n                    num_global_devices,\n                });\n            }\n            RemoteResponse::Error(err) => {\n                return Err(err);\n            }\n            resp => {\n                log::error!(\"Response to a register request should be an ack, not {resp:?}\");\n                return Err(GlobalCollectiveError::WrongOrchestratorResponse);\n            }\n        }\n\n        Ok(())\n    }\n\n    /// Performs an all-reduce\n    ///\n    /// Reads the NodeState\n    pub async fn all_reduce(\n        &self,\n        tensor: B::FloatTensorPrimitive,\n        strategy: AllReduceStrategy,\n        op: ReduceOperation,\n    ) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {\n        let state = self.state.read().await;\n        let Some(ref state) = *state else {\n            return Err(GlobalCollectiveError::AllReduceBeforeRegister);\n        };\n        let node = state.node_id;\n        let nodes = &state.nodes;\n\n        let mut result = match strategy {\n            AllReduceStrategy::Centralized => {\n                centralized_all_reduce_sum(\n                    node,\n                    nodes,\n                    &self.data_service,\n                    self.sync_service.clone(),\n                    tensor,\n                )\n                .await?\n            }\n            AllReduceStrategy::Tree(arity) => {\n                tree_all_reduce_sum(\n                    node,\n                    nodes,\n                    self.data_service.clone(),\n                    self.sync_service.clone(),\n                    tensor,\n                    arity,\n                )\n                .await?\n            }\n            AllReduceStrategy::Ring => {\n                ring_all_reduce_sum(\n                    node,\n                    nodes,\n                    self.data_service.clone(),\n                    self.sync_service.clone(),\n                    tensor,\n                )\n                .await?\n            }\n        };\n\n        if op == ReduceOperation::Mean {\n            result = B::float_div_scalar(result, (state.num_global_devices as f32).into());\n        }\n\n        Ok(result)\n    }\n\n    pub async fn reduce(\n        &self,\n        _tensor: B::FloatTensorPrimitive,\n        _strategy: ReduceStrategy,\n        _root: PeerId,\n        _op: ReduceOperation,\n    ) -> Result<Option<B::FloatTensorPrimitive>, GlobalCollectiveError> {\n        unimplemented!(\"Global reduce unimplemented\");\n    }\n\n    pub async fn broadcast(\n        &self,\n        _tensor: Option<B::FloatTensorPrimitive>,\n        _strategy: BroadcastStrategy,\n    ) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError> {\n        unimplemented!(\"Global broadcast unimplemented\");\n    }\n\n    pub async fn finish(&mut self) {\n        let res = self.worker.close_connection().await;\n        if let Err(err) = res {\n            log::error!(\"Global collective client error: {err:?}\");\n        }\n        self.data_service.close().await;\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/centralized.rs",
    "content": "use std::{collections::HashMap, sync::Arc};\n\nuse crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};\nuse burn_communication::data_service::TensorDataService;\nuse burn_communication::{Address, Protocol};\nuse burn_tensor::TensorMetadata;\nuse burn_tensor::backend::Backend;\nuse futures::StreamExt;\nuse futures::stream::FuturesUnordered;\n\n/// Global all-reduce, using a centralized strategy.\n///\n/// Returns the resulting tensor on the same device as the input tensor\npub(crate) async fn centralized_all_reduce_sum<B, P>(\n    node: NodeId,\n    nodes: &HashMap<NodeId, Address>,\n    data_service: &Arc<TensorDataService<B, P>>,\n    sync_service: Arc<SyncService<P>>,\n    tensor: B::FloatTensorPrimitive,\n) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    let ids = nodes.keys().cloned().collect::<Vec<_>>();\n    let central = get_central_node(ids.clone());\n\n    let shape = tensor.shape();\n    let device = &B::float_device(&tensor);\n\n    let res = if central == node {\n        // Transfer 1: download tensors from other nodes\n        let mut futures = ids\n            .iter()\n            .filter(|id| **id != central) // Only non-central nodes\n            .map(|id| {\n                let address = nodes.get(id).unwrap();\n                let device = device.clone();\n                let data_service = data_service.clone();\n                async move {\n                    let data = data_service\n                        .download_tensor((*address).clone(), 0.into())\n                        .await\n                        .expect(\"Couldn't find the tensor for transfer id 0\");\n                    B::float_from_data(data, &device)\n                }\n            })\n            .collect::<FuturesUnordered<_>>();\n\n        // Sum all downloads async\n        let mut sum = tensor;\n        while let Some(res) = futures.next().await {\n            if shape != res.shape() {\n                return Err(GlobalCollectiveError::PeerSentIncoherentTensor);\n            }\n            sum = B::float_add(sum, res);\n        }\n\n        // Transfer 2: Expose result\n        let other_nodes_count = ids.len() as u32 - 1;\n        data_service\n            .expose(sum.clone(), other_nodes_count, 1.into())\n            .await;\n\n        sum\n    } else {\n        // Transfer 1: Expose input\n        data_service.expose(tensor, 1, 0.into()).await;\n\n        // Transfer 2: Download result\n        let central_addr = nodes.get(&central).unwrap().clone();\n        let data = data_service\n            .download_tensor(central_addr, 1.into())\n            .await\n            .expect(\"Couldn't find the tensor for transfer id 1\");\n\n        let res = B::float_from_data(data, device);\n        if shape != res.shape() {\n            return Err(GlobalCollectiveError::PeerSentIncoherentTensor);\n        }\n\n        res\n    };\n\n    // Wait for all nodes to finish\n    sync_service.sync().await;\n\n    Ok(res)\n}\n\n/// Get the central node for a centralized all-reduce\npub(crate) fn get_central_node(mut nodes: Vec<NodeId>) -> NodeId {\n    nodes.sort();\n\n    *nodes.first().unwrap()\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/mod.rs",
    "content": "pub mod base;\npub mod centralized;\npub mod ring;\npub mod sync;\npub mod tree;\npub mod worker;\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/ring.rs",
    "content": "//! Implements the collective ring all-reduce algorithm on the global level\n\nuse core::ops::Range;\nuse std::{collections::HashMap, sync::Arc};\n\nuse crate::{\n    NodeId,\n    global::shared::GlobalCollectiveError,\n    local::{get_ring_reduce_slice_ranges, get_slice_dim},\n    node::sync::SyncService,\n};\nuse burn_communication::{Address, Protocol, data_service::TensorDataService};\nuse burn_tensor::{Slice, TensorMetadata, backend::Backend};\n\n// https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model\n\n// Example: tensors=3, slices=3\n\n// phase 1\n// o->o  o\n// o  o->o\n//>o  o  o->\n\n// o  1->o\n//>o  o  1->\n// 1->o  o\n\n// o  1  2\n// 2  o  1\n// 1  2  o\n\n// phase 2\n//>o  1  2->\n// 2->o  1\n// 1  2->o\n\n// 2->1  2\n// 2  2->1\n//>1  2  2->\n\n// 2  2  2\n// 2  2  2\n// 2  2  2\n\n/// Ring all-reduce algorithm with summation\n///\n/// * `node` - The id of the current node\n/// * `nodes` - Map of all nodes in the operation\n/// * `data_service` - The data service handles peer-to-peer tensor transfers\n/// * `sync_service` - The sync service handles syncing with peers\n/// * `tensor` - The tensor to reduce. At least one dimension size must be greater than the number\n///   of nodes\npub(crate) async fn ring_all_reduce_sum<B, P>(\n    node: NodeId,\n    nodes: &HashMap<NodeId, Address>,\n    data_service: Arc<TensorDataService<B, P>>,\n    sync_service: Arc<SyncService<P>>,\n    tensor: B::FloatTensorPrimitive,\n) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    let shape = tensor.shape();\n\n    let device = &B::float_device(&tensor);\n    // Slice tensors in N parts, N is node count\n    let slice_dim = get_slice_dim(&shape);\n    if shape[slice_dim] < nodes.len() {\n        return Err(GlobalCollectiveError::RingReduceImpossible);\n    }\n\n    let ring = get_ring_topology(nodes.keys().cloned().collect::<Vec<_>>());\n    let slice_ranges = get_ring_reduce_slice_ranges(shape[slice_dim], ring.len());\n    let mut slices = slice_tensor::<B>(tensor, slice_dim, slice_ranges);\n\n    let mut send_slice_idx = ring\n        .iter()\n        .position(|id| *id == node)\n        .expect(\"Node is in ring\");\n    let prev_node_idx = (send_slice_idx + ring.len() - 1) % ring.len(); // +ring.len for overflow\n    let prev_node = nodes.get(&ring[prev_node_idx]).unwrap();\n    let mut transfer_counter: u64 = 0;\n\n    // Phase 1: add\n    do_cycles::<B, P>(\n        &mut slices,\n        &mut transfer_counter,\n        &mut send_slice_idx,\n        true,\n        prev_node.clone(),\n        &data_service,\n        device,\n    )\n    .await?;\n\n    // Phase 2: replace\n    do_cycles::<B, P>(\n        &mut slices,\n        &mut transfer_counter,\n        &mut send_slice_idx,\n        false,\n        prev_node.clone(),\n        &data_service,\n        device,\n    )\n    .await?;\n\n    // Wait for all nodes to finish\n    sync_service.sync().await;\n\n    // merge slices\n    Ok(B::float_cat(slices, slice_dim))\n}\n\n/// Do N-1 cycles of ring-reduce\n///\n/// * `slices` - Slices of the original tensor, len equal to node count\n/// * `transfer_counter` - counter for each step (one send one receive)\n/// * `send_slice_idx` - counter for the index of each slice to send\n/// * `is_phase_one` - In phase 1, the tensors are aggregated. Otherwise, they are overridden\n/// * `data_service` - TensorDataService for peer-to-peer tensor transfers\n/// * `device` - The device on which all local tensors are stored. Should match `slices`\nasync fn do_cycles<B, P>(\n    slices: &mut [B::FloatTensorPrimitive],\n    transfer_counter: &mut u64,\n    send_slice_idx: &mut usize,\n    is_phase_one: bool,\n    prev_node: Address,\n    data_service: &Arc<TensorDataService<B, P>>,\n    device: &B::Device,\n) -> Result<(), GlobalCollectiveError>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    let slice_count = slices.len();\n    for _ in 0..(slice_count - 1) {\n        let transfer_id = (*transfer_counter).into();\n        // +slice_count to avoid overflow\n        let recv_slice_idx = (*send_slice_idx + slice_count - 1) % slice_count;\n        let slice_send = slices[*send_slice_idx].clone();\n\n        let upload = {\n            let data_service = data_service.clone();\n            tokio::spawn(async move {\n                data_service\n                    .expose(slice_send.clone(), 1, transfer_id)\n                    .await\n            })\n        };\n        let download = {\n            let data_client = data_service.clone();\n            let next_node = prev_node.clone();\n            tokio::spawn(async move { data_client.download_tensor(next_node, transfer_id).await })\n        };\n\n        upload.await.unwrap();\n        let download = download.await.unwrap();\n        if is_phase_one {\n            let download = download.expect(\"Peer closed download connection\");\n            let tensor = B::float_from_data(download, device);\n            slices[recv_slice_idx] = B::float_add(slices[recv_slice_idx].clone(), tensor);\n        } else {\n            let tensor = B::float_from_data(download.unwrap(), device);\n            let old_shape = slices[recv_slice_idx].shape();\n            if old_shape != tensor.shape() {\n                return Err(GlobalCollectiveError::PeerSentIncoherentTensor);\n            }\n            slices[recv_slice_idx] = tensor;\n        }\n\n        // Move slice index\n        *send_slice_idx = recv_slice_idx;\n        *transfer_counter += 1;\n    }\n\n    Ok(())\n}\n\n/// But a tensor into even slices across a dimension\n///\n/// * `tensor` - the tensor to slice\n/// * `slice_dim` - the dimension to slice across\n/// * `slice_ranges` - The ranges of indices on `slice_dim` to use when slicing the tensor\nfn slice_tensor<B: Backend>(\n    tensor: B::FloatTensorPrimitive,\n    slice_dim: usize,\n    slice_ranges: Vec<Range<usize>>,\n) -> Vec<B::FloatTensorPrimitive> {\n    let shape = tensor.shape();\n    // full range across all dims as Slice\n    let full_range = shape\n        .iter()\n        .map(|dim| Slice::from(0..*dim))\n        .collect::<Vec<Slice>>();\n\n    // Slice tensors\n    let mut slices = vec![];\n    for range in &slice_ranges {\n        let mut all_ranges = full_range.clone();\n        all_ranges[slice_dim] = Slice::from(range.clone());\n        let slice = B::float_slice(tensor.clone(), &all_ranges);\n        slices.push(slice);\n    }\n\n    slices\n}\n\n/// Get the ring topology\nfn get_ring_topology(mut nodes: Vec<NodeId>) -> Vec<NodeId> {\n    // This ordering could be more sophisticated, using node proximities etc\n    nodes.sort();\n\n    nodes\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/sync.rs",
    "content": "use std::{\n    marker::PhantomData,\n    sync::{Arc, Mutex},\n    vec,\n};\n\nuse burn_communication::{CommunicationChannel, Message, Protocol, ProtocolClient};\nuse serde::{Deserialize, Serialize};\nuse tokio::sync::{Notify, RwLock};\n\nuse crate::{NodeId, node::base::NodeState};\n\n/// Handles the status of sync requests from other nodes\npub(crate) struct SyncService<P: Protocol> {\n    /// Current node's state, shared with the thread that does aggregations\n    node_state: Arc<RwLock<Option<NodeState>>>,\n    /// The number of peers that have requested to sync with us since the last successful sync.\n    syncing_peers: Mutex<Vec<NodeId>>,\n    /// Notification on each incoming sync request\n    sync_notif: Notify,\n\n    _p: PhantomData<P>,\n}\n\n#[derive(Debug, Serialize, Deserialize)]\nstruct SyncRequest(NodeId);\n\nimpl<P: Protocol> SyncService<P> {\n    pub fn new(node_state: Arc<RwLock<Option<NodeState>>>) -> Self {\n        Self {\n            node_state,\n            syncing_peers: Mutex::new(vec![]),\n            sync_notif: Notify::new(),\n            _p: PhantomData,\n        }\n    }\n\n    fn add_syncing_peer(&self, peer: NodeId) {\n        let mut syncing_peers = self.syncing_peers.lock().unwrap();\n        syncing_peers.push(peer);\n    }\n\n    /// Sync with all peers.\n    pub async fn sync(&self) {\n        // we can't sync while we register\n        let node_state = self.node_state.read().await;\n        let node_state = node_state\n            .as_ref()\n            .expect(\"Trying to sync a node before having registered to the orchestrator\");\n\n        // this peer is syncing\n        self.add_syncing_peer(node_state.node_id);\n        for (id, addr) in &node_state.nodes {\n            if *id == node_state.node_id {\n                continue;\n            }\n\n            let mut connection = P::Client::connect(addr.clone(), \"sync\")\n                .await\n                .expect(\"Couldn't connect to peer for sync\");\n            let msg = SyncRequest(node_state.node_id);\n            let sync_bytes = rmp_serde::to_vec(&msg).unwrap();\n            connection\n                .send(Message::new(sync_bytes.into()))\n                .await\n                .expect(\"Peer closed connection unexpectedly\");\n        }\n        loop {\n            {\n                // compare currently synced peers with list of all nodes\n                let mut syncing_peers = self.syncing_peers.lock().unwrap().to_vec();\n                syncing_peers.sort();\n\n                let mut all_node_ids = node_state.nodes.keys().cloned().collect::<Vec<_>>();\n                all_node_ids.sort();\n\n                if syncing_peers == all_node_ids {\n                    // all nodes have synced\n                    syncing_peers.clear();\n                    return;\n                }\n            }\n            // Wait for the next sync to come in\n            self.sync_notif.notified().await\n        }\n    }\n\n    pub async fn handle_sync_connection<C: CommunicationChannel>(&self, mut channel: C) {\n        let msg = channel.recv().await.unwrap();\n        let Some(msg) = msg else {\n            return;\n        };\n\n        let msg = rmp_serde::from_slice::<SyncRequest>(&msg.data).unwrap();\n\n        self.add_syncing_peer(msg.0);\n\n        self.sync_notif.notify_waiters();\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/tree.rs",
    "content": "use std::{collections::HashMap, sync::Arc};\n\nuse crate::{NodeId, global::shared::GlobalCollectiveError, node::sync::SyncService};\nuse burn_communication::{Address, Protocol, data_service::TensorDataService};\nuse burn_tensor::{TensorMetadata, backend::Backend};\nuse futures::{StreamExt, stream::FuturesUnordered};\n\nstruct TreeTopology {\n    parents: HashMap<NodeId, NodeId>,\n    children: HashMap<NodeId, Vec<NodeId>>,\n}\n\n/// Global all-reduce, using a b-tree strategy.\n///\n/// Returns the resulting tensor on the same device as the input tensor\npub(crate) async fn tree_all_reduce_sum<B, P>(\n    node: NodeId,\n    nodes: &HashMap<NodeId, Address>,\n    data_service: Arc<TensorDataService<B, P>>,\n    sync_service: Arc<SyncService<P>>,\n    tensor: B::FloatTensorPrimitive,\n    arity: u32,\n) -> Result<B::FloatTensorPrimitive, GlobalCollectiveError>\nwhere\n    B: Backend,\n    P: Protocol,\n{\n    let shape = tensor.shape();\n    let device = &B::float_device(&tensor);\n\n    // Topology could be cached based on (nodes.keys().cloned(), arity)\n    let strategy = get_tree_topology(nodes.keys().cloned().collect::<Vec<_>>(), arity);\n\n    // Transfer 1: Download and sum tensors from children\n    let mut result = tensor;\n\n    if let Some(children) = strategy.children.get(&node) {\n        let mut downloads = children\n            .iter()\n            .map(|child| {\n                let child_addr = nodes.get(child).unwrap().clone();\n                let data_service = data_service.clone();\n                async move {\n                    let data = data_service\n                        .download_tensor(child_addr.clone(), 0.into())\n                        .await\n                        .ok_or(GlobalCollectiveError::PeerLost(*child))?;\n                    Ok::<B::FloatTensorPrimitive, GlobalCollectiveError>(B::float_from_data(\n                        data, device,\n                    ))\n                }\n            })\n            .collect::<FuturesUnordered<_>>();\n\n        for _ in children {\n            let res = downloads.next().await.unwrap().unwrap();\n            if res.shape() != shape {\n                return Err(GlobalCollectiveError::PeerSentIncoherentTensor);\n            }\n            result = B::float_add(result, res);\n        }\n    }\n\n    // Transfer 2: Expose result to parent and download final result if not root\n    if let Some(parent) = strategy.parents.get(&node) {\n        data_service.expose(result.clone(), 1, 0.into()).await;\n\n        let parent_addr = nodes.get(parent).unwrap().clone();\n\n        let data = data_service\n            .download_tensor(parent_addr.clone(), 1.into())\n            .await\n            .ok_or(GlobalCollectiveError::PeerLost(*parent))?;\n\n        let parent_tensor = B::float_from_data(data, device);\n        if parent_tensor.shape() != shape {\n            return Err(GlobalCollectiveError::PeerSentIncoherentTensor);\n        }\n        result = parent_tensor;\n    }\n\n    // Transfer 3: Expose final result to children (if any)\n    if let Some(children) = strategy.children.get(&node)\n        && !children.is_empty()\n    {\n        data_service\n            .expose(result.clone(), children.len() as u32, 1.into())\n            .await;\n    }\n\n    // Final barrier\n    sync_service.sync().await;\n\n    Ok(result)\n}\n\n/// Get the tree topology.\n///\n/// * `nodes` - List of node ids. Order doesn't matter. Nodes must be unique.\nfn get_tree_topology(mut nodes: Vec<NodeId>, arity: u32) -> TreeTopology {\n    assert!(arity >= 1, \"Arity must be ≥ 1\");\n\n    nodes.sort(); // Sort \n\n    let n = nodes.len();\n    let k = arity as usize;\n\n    let mut parents: HashMap<_, _> = HashMap::with_capacity(n);\n    let mut children: HashMap<_, _> = HashMap::with_capacity(n);\n\n    for (i, &parent_id) in nodes.iter().enumerate() {\n        // compute the window [first_child, last_child)\n        let first = i * k + 1;\n        if first < n {\n            let last = usize::min(first + k, n);\n            let mut ch = Vec::with_capacity(last - first);\n            for &child_id in &nodes[first..last] {\n                parents.insert(child_id, parent_id);\n                ch.push(child_id);\n            }\n            children.insert(parent_id, ch);\n        } else {\n            // leaf‐node: no children\n            children.insert(parent_id, Vec::new());\n        }\n    }\n\n    TreeTopology { parents, children }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    /// Test the tree topology algorithm with arity 2 and 7 nodes\n    #[test]\n    fn test_get_tree_topology_arity2_size7() {\n        let mut nodes = vec![];\n        for i in 0..7 {\n            nodes.push(i.into());\n        }\n\n        let topology = get_tree_topology(nodes, 2);\n\n        // Root is 0, so it should have no parent\n        assert!(!topology.parents.contains_key(&0.into()));\n\n        // Parents:\n        //   Node 1 and 2 → parent 0\n        //   Node 3 and 4 → parent 1\n        //   Node 5 and 6 → parent 2\n        let expected_parents = [\n            (1.into(), 0.into()),\n            (2.into(), 0.into()),\n            (3.into(), 1.into()),\n            (4.into(), 1.into()),\n            (5.into(), 2.into()),\n            (6.into(), 2.into()),\n        ];\n        for (child, parent) in &expected_parents {\n            assert_eq!(\n                topology.parents.get(child),\n                Some(parent),\n                \"wrong parent for {child:?}\"\n            );\n        }\n        // There should be exactly 6 entries in parents\n        assert_eq!(topology.parents.len(), expected_parents.len());\n\n        // Children:\n        //   0 → [1, 2]\n        //   1 → [3, 4]\n        //   2 → [5, 6]\n        //   3,4,5,6 → []\n        assert_eq!(\n            topology.children.get(&0.into()),\n            Some(&vec![1.into(), 2.into()])\n        );\n        assert_eq!(\n            topology.children.get(&1.into()),\n            Some(&vec![3.into(), 4.into()])\n        );\n        assert_eq!(\n            topology.children.get(&2.into()),\n            Some(&vec![5.into(), 6.into()])\n        );\n        // Leaves\n        for leaf in 3..7 {\n            assert_eq!(\n                topology.children.get(&leaf.into()),\n                Some(&Vec::new()),\n                \"leaf {leaf:?} should have no children\"\n            );\n        }\n        // Ensure we have exactly 7 entries in children\n        assert_eq!(topology.children.len(), 7);\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/node/worker.rs",
    "content": "use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};\n\nuse burn_communication::{Address, CommunicationChannel, Message, ProtocolClient};\nuse tokio::{\n    runtime::Runtime,\n    sync::{\n        Mutex,\n        mpsc::{Receiver, Sender},\n    },\n    task::JoinHandle,\n};\nuse tokio_util::sync::CancellationToken;\n\nuse crate::global::shared::{\n    CollectiveMessage, CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest,\n    RemoteResponse, RequestId, SessionId,\n};\n\n/// Worker that handles communication with the orchestrator for global collective operations.\npub(crate) struct GlobalClientWorker<P: ProtocolClient> {\n    handle: Option<JoinHandle<Result<(), GlobalCollectiveError>>>,\n    cancel_token: CancellationToken,\n    request_sender: Sender<ClientRequest>,\n    _phantom_data: PhantomData<P>,\n}\n\n// Rename\nstruct GlobalClientWorkerState {\n    requests: HashMap<RequestId, Sender<RemoteResponse>>,\n}\n\nimpl GlobalClientWorkerState {\n    fn new() -> Self {\n        Self {\n            requests: HashMap::new(),\n        }\n    }\n}\n\n#[derive(Debug)]\npub(crate) struct ClientRequest {\n    pub request: RemoteRequest,\n    pub callback: Sender<RemoteResponse>,\n}\n\nimpl ClientRequest {\n    pub(crate) fn new(request: RemoteRequest, callback: Sender<RemoteResponse>) -> Self {\n        Self { request, callback }\n    }\n}\n\nimpl<C: ProtocolClient> GlobalClientWorker<C> {\n    /// Create a new global client worker and start the tasks.\n    pub(crate) fn new(\n        runtime: &Runtime,\n        cancel_token: CancellationToken,\n        global_address: &Address,\n    ) -> Self {\n        let (request_sender, request_recv) = tokio::sync::mpsc::channel::<ClientRequest>(10);\n\n        let state = Arc::new(Mutex::new(GlobalClientWorkerState::new()));\n\n        let handle = runtime.spawn(Self::start(\n            state,\n            cancel_token.clone(),\n            global_address.clone(),\n            request_recv,\n        ));\n\n        Self {\n            handle: Some(handle),\n            cancel_token,\n            request_sender,\n            _phantom_data: PhantomData,\n        }\n    }\n\n    /// Start the global client tasks\n    async fn start(\n        state: Arc<Mutex<GlobalClientWorkerState>>,\n        cancel_token: CancellationToken,\n        global_address: Address,\n        request_recv: Receiver<ClientRequest>,\n    ) -> Result<(), GlobalCollectiveError> {\n        // Init the connection.\n        let (request, response) = Self::init_connection(&global_address).await?;\n\n        // Websocket async worker loading responses from the server.\n        let response_handle = tokio::spawn(Self::response_loader(\n            state.clone(),\n            response,\n            cancel_token.clone(),\n        ));\n\n        // Channel async worker sending operations to the server.\n        let request_handle = tokio::spawn(Self::request_sender(\n            request_recv,\n            state,\n            request,\n            cancel_token.clone(),\n        ));\n\n        if let Err(e) = response_handle.await {\n            log::error!(\"Response handler failed: {e:?}\");\n        }\n        if let Err(e) = request_handle.await {\n            log::error!(\"Request handler failed: {e:?}\");\n        }\n\n        Ok(())\n    }\n\n    async fn init_connection(\n        address: &Address,\n    ) -> Result<(C::Channel, C::Channel), GlobalCollectiveError> {\n        let session_id = SessionId::new();\n\n        let stream_request = tokio::spawn(Self::connect_with_retry(\n            address.clone(),\n            \"request\",\n            std::time::Duration::from_secs(1),\n            None,\n            session_id,\n        ));\n        let stream_response = tokio::spawn(Self::connect_with_retry(\n            address.clone(),\n            \"response\",\n            std::time::Duration::from_secs(1),\n            None,\n            session_id,\n        ));\n\n        let Ok(Some(request)) = stream_request.await else {\n            return Err(GlobalCollectiveError::OrchestratorUnreachable);\n        };\n        let Ok(Some(response)) = stream_response.await else {\n            return Err(GlobalCollectiveError::OrchestratorUnreachable);\n        };\n\n        Ok((request, response))\n    }\n\n    /// Connect with websocket with retries.\n    async fn connect_with_retry(\n        address: Address,\n        route: &str,\n        retry_pause: Duration,\n        retry_max: Option<u32>,\n        session_id: SessionId,\n    ) -> Option<C::Channel> {\n        let mut retries = 0;\n        loop {\n            if let Some(max) = retry_max\n                && retries >= max\n            {\n                log::warn!(\"Failed to connect to {address} after {max} retries.\");\n                return None;\n            }\n\n            // Try to connect to the request address.\n            println!(\"Connecting to {address} ...\");\n            let result = C::connect(address.clone(), route).await;\n\n            if let Some(mut stream) = result {\n                let init_msg = CollectiveMessage::Init(session_id);\n                let bytes: bytes::Bytes = rmp_serde::to_vec(&init_msg).unwrap().into();\n                stream\n                    .send(Message::new(bytes))\n                    .await\n                    .expect(\"Can send the init message on the websocket.\");\n                return Some(stream);\n            }\n\n            println!(\"Failed to connect to {address}, retrying... Attempt #{retries}\");\n            tokio::time::sleep(retry_pause).await;\n            retries += 1;\n        }\n    }\n\n    /// Unregister the worker and close the connection.\n    pub(crate) async fn close_connection(&mut self) -> Result<(), GlobalCollectiveError> {\n        if let Some(handle) = self.handle.take() {\n            // Un-register from server\n            let req = RemoteRequest::Finish;\n            let resp = self.request(req).await;\n            if resp != RemoteResponse::FinishAck {\n                log::error!(\"Requested to finish, did not get FinishAck; got {resp:?}\");\n                return Err(GlobalCollectiveError::WrongOrchestratorResponse);\n            }\n\n            self.cancel_token.cancel();\n\n            if let Err(e) = handle.await.unwrap() {\n                log::error!(\"Connection error {e:?}\");\n            }\n        }\n\n        Ok(())\n    }\n\n    async fn response_loader(\n        state: Arc<Mutex<GlobalClientWorkerState>>,\n        mut stream_response: C::Channel,\n        cancel_token: CancellationToken,\n    ) {\n        loop {\n            tokio::select! {\n                // Check if the cancel token is cancelled\n                _ = cancel_token.cancelled() => {\n                    break;\n                }\n                // .. Or get a message from the websocket\n                response = stream_response.recv() => {\n                    match response {\n                        Err(err) => {\n                            log::error!(\"Error receiving message from websocket: {err:?}\");\n                            break;\n                        }\n                        Ok(response) => {\n                            let Some(response) = response else {\n                                log::warn!(\"Closed connection\");\n                                break;\n                            };\n\n                            let response: CollectiveMessageResponse = rmp_serde::from_slice(&response.data)\n                                .expect(\"Can deserialize messages from the websocket.\");\n                            let state_resp = state.lock().await;\n                            let response_callback = state_resp\n                                .requests\n                                .get(&response.request_id)\n                                .expect(\"Got a response to an unknown request\");\n                            response_callback.send(response.content).await.unwrap();\n                        }\n                    }\n                }\n            }\n        }\n\n        log::info!(\"Worker closing connection\");\n        stream_response\n            .close()\n            .await\n            .expect(\"Can close the websocket stream.\");\n    }\n\n    async fn request_sender(\n        mut request_recv: Receiver<ClientRequest>,\n        worker: Arc<Mutex<GlobalClientWorkerState>>,\n        mut stream_request: C::Channel,\n        cancel_token: CancellationToken,\n    ) {\n        loop {\n            tokio::select! {\n                _ = cancel_token.cancelled() => {\n                    break;\n                },\n                request = request_recv.recv() => {\n                    let Some(request) = request else {\n                        continue;\n                    };\n\n                    let id = RequestId::new();\n\n                    // Register the callback if there is one\n                    {\n                        let mut state = worker.lock().await;\n                        state.requests.insert(id, request.callback);\n                    }\n\n                    let request = CollectiveMessage::Request(id, request.request);\n\n                    let bytes = rmp_serde::to_vec::<CollectiveMessage>(&request)\n                        .expect(\"Can serialize tasks to bytes.\")\n                        .into();\n                    stream_request\n                        .send(Message::new(bytes))\n                        .await\n                        .expect(\"Can send the message on the websocket.\");\n                }\n            }\n        }\n\n        log::info!(\"Worker closing connection\");\n        stream_request\n            .close()\n            .await\n            .expect(\"Can send the close message on the websocket.\");\n    }\n\n    pub(crate) async fn request(&self, req: RemoteRequest) -> RemoteResponse {\n        let (callback, mut response_recv) = tokio::sync::mpsc::channel::<RemoteResponse>(10);\n        let client_req = ClientRequest::new(req, callback);\n        self.request_sender.send(client_req).await.unwrap();\n\n        response_recv.recv().await.unwrap()\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/orchestrator/base.rs",
    "content": "use std::fmt::Debug;\nuse std::sync::Arc;\nuse tokio::sync::Mutex;\n\nuse crate::global::{\n    orchestrator::state::GlobalCollectiveState,\n    shared::{CollectiveMessage, GlobalCollectiveError},\n};\nuse burn_communication::{\n    CommunicationChannel, Message, ProtocolServer, util::os_shutdown_signal, websocket::WsServer,\n};\n\n/// The global collective state manages collective operations on the global level\n#[derive(Clone)]\npub(crate) struct GlobalOrchestrator {\n    state: Arc<Mutex<GlobalCollectiveState>>,\n}\n\nimpl GlobalOrchestrator {\n    /// Starts the comms server with two routes: \"/request\" and \"/response\"\n    pub(crate) async fn start<F, S: ProtocolServer + Debug>(\n        shutdown_signal: F,\n        comms_server: S,\n    ) -> Result<(), GlobalCollectiveError>\n    where\n        F: Future<Output = ()> + Send + 'static,\n    {\n        let state = GlobalCollectiveState::new();\n        let server = Self {\n            state: Arc::new(tokio::sync::Mutex::new(state)),\n        };\n\n        comms_server\n            .route(\"/response\", {\n                let server = server.clone();\n                async move |socket| {\n                    if let Err(err) = server.handle_socket_response::<S>(socket).await {\n                        log::error!(\"[Response Handler] Error: {err:?}\")\n                    }\n                }\n            })\n            .route(\"/request\", {\n                let server = server.clone();\n                async move |socket| {\n                    if let Err(err) = server.handle_socket_request::<S>(socket).await {\n                        log::error!(\"[Request Handler] Error: {err:?}\")\n                    }\n                }\n            })\n            .serve(shutdown_signal)\n            .await\n            .map_err(|err| GlobalCollectiveError::Server(format!(\"{err:?}\")))?;\n\n        Ok(())\n    }\n\n    async fn handle_socket_response<S: ProtocolServer>(\n        self,\n        mut stream: S::Channel,\n    ) -> Result<(), GlobalCollectiveError> {\n        log::info!(\"[Response Handler] On new connection.\");\n\n        let msg = stream\n            .recv()\n            .await\n            .map_err(|err| GlobalCollectiveError::Server(format!(\"{err:?}\")))?;\n        let Some(msg) = msg else {\n            log::warn!(\"Response socket closed early!\");\n            return Ok(());\n        };\n\n        let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)\n            .map_err(|_| GlobalCollectiveError::InvalidMessage)?;\n\n        let CollectiveMessage::Init(id) = msg else {\n            return Err(GlobalCollectiveError::FirstMsgNotInit);\n        };\n\n        let mut receiver = {\n            let mut state = self.state.lock().await;\n            state.get_session_responder(id)\n        };\n\n        while let Some(response) = receiver.recv().await {\n            let bytes = rmp_serde::to_vec(&response).unwrap();\n\n            stream.send(Message::new(bytes.into())).await?;\n        }\n\n        log::info!(\"[Response Handler] Closing connection.\");\n        Ok(())\n    }\n\n    async fn handle_socket_request<S: ProtocolServer>(\n        self,\n        mut stream: S::Channel,\n    ) -> Result<(), GlobalCollectiveError> {\n        log::info!(\"[Request Handler] On new connection.\");\n\n        let mut session_id = None;\n\n        loop {\n            let packet = stream.recv().await?;\n            let Some(msg) = packet else {\n                log::info!(\"Peer closed the connection\");\n                break;\n            };\n\n            let mut state = self.state.lock().await;\n\n            let msg = rmp_serde::from_slice::<CollectiveMessage>(&msg.data)\n                .map_err(|_| GlobalCollectiveError::InvalidMessage)?;\n            match msg {\n                CollectiveMessage::Init(id) => {\n                    state.init_session(id);\n                    session_id = Some(id);\n                }\n                CollectiveMessage::Request(request_id, remote_request) => {\n                    let session_id = session_id.ok_or(GlobalCollectiveError::FirstMsgNotInit)?;\n                    state\n                        .process_request(session_id, request_id, remote_request)\n                        .await;\n                }\n            }\n        }\n\n        Ok(())\n    }\n}\n\n/// Start a global orchestrator with WebSocket on the given port\npub async fn start_global_orchestrator(port: u16) {\n    let server = WsServer::new(port);\n    let res = GlobalOrchestrator::start(os_shutdown_signal(), server).await;\n    if let Err(err) = res {\n        log::error!(\"Global Collective Orchestrator error: {err:?}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/orchestrator/mod.rs",
    "content": "pub(crate) mod base;\npub(crate) mod state;\n\npub use base::start_global_orchestrator;\n"
  },
  {
    "path": "crates/burn-collective/src/global/orchestrator/state.rs",
    "content": "use crate::{\n    PeerId,\n    global::{\n        NodeId,\n        shared::{\n            CollectiveMessageResponse, GlobalCollectiveError, RemoteRequest, RemoteResponse,\n            RequestId, SessionId,\n        },\n    },\n};\nuse burn_communication::Address;\nuse std::collections::HashMap;\nuse tokio::sync::mpsc::{Receiver, Sender};\n\npub(crate) struct Session {\n    response_sender: Sender<CollectiveMessageResponse>,\n    response_receiver: Option<Receiver<CollectiveMessageResponse>>,\n}\n\nimpl Session {\n    fn new() -> Self {\n        let (response_sender, recv) = tokio::sync::mpsc::channel::<CollectiveMessageResponse>(1);\n        Self {\n            response_sender,\n            response_receiver: Some(recv),\n        }\n    }\n\n    async fn respond(&mut self, response: CollectiveMessageResponse) {\n        self.response_sender.send(response).await.unwrap();\n    }\n}\n\npub(crate) struct GlobalCollectiveState {\n    /// The ids passed to each register so far, and their addresses\n    registered_nodes: HashMap<SessionId, NodeId>,\n    /// Address for each node\n    node_addresses: HashMap<NodeId, Address>,\n    /// Peer on each node\n    node_peers: HashMap<NodeId, Vec<PeerId>>,\n\n    /// How many total nodes for the current register operation, as defined by the first caller\n    cur_num_nodes: Option<u32>,\n    /// How many peers have registered total\n    num_global_peers: u32,\n\n    register_requests: Vec<(SessionId, RequestId, NodeId)>,\n\n    sessions: HashMap<SessionId, Session>,\n}\n\nimpl GlobalCollectiveState {\n    pub fn new() -> Self {\n        Self {\n            registered_nodes: HashMap::new(),\n            node_addresses: HashMap::new(),\n            node_peers: HashMap::new(),\n            cur_num_nodes: None,\n            num_global_peers: 0,\n            register_requests: Vec::new(),\n            sessions: HashMap::new(),\n        }\n    }\n\n    pub(crate) fn init_session(&mut self, id: SessionId) {\n        if self.sessions.contains_key(&id) {\n            return;\n        }\n        self.sessions.insert(id, Session::new());\n    }\n\n    /// Create the session with given id if necessary, and get the response receiver\n    pub(crate) fn get_session_responder(\n        &mut self,\n        id: SessionId,\n    ) -> Receiver<CollectiveMessageResponse> {\n        self.init_session(id);\n        let session = self.sessions.get_mut(&id).unwrap();\n        let response_recv = session.response_receiver.take();\n\n        response_recv.unwrap()\n    }\n\n    pub(crate) async fn respond(\n        &mut self,\n        session_id: SessionId,\n        response: CollectiveMessageResponse,\n    ) {\n        let session = self.sessions.get_mut(&session_id).unwrap();\n        session.respond(response).await;\n    }\n\n    /// Process an incoming node's request\n    pub(crate) async fn process_request(\n        &mut self,\n        session_id: SessionId,\n        request_id: RequestId,\n        request: RemoteRequest,\n    ) {\n        if let Err(err) = match request {\n            RemoteRequest::Register {\n                node_addr,\n                num_nodes,\n                peers,\n            } => {\n                self.register(session_id, request_id, node_addr, num_nodes, peers)\n                    .await\n            }\n            RemoteRequest::Finish => self.finish(session_id, request_id).await,\n        } {\n            // Error occurred, send it as response\n            let content = RemoteResponse::Error(err);\n            self.respond(\n                session_id,\n                CollectiveMessageResponse {\n                    request_id,\n                    content,\n                },\n            )\n            .await;\n        }\n    }\n\n    /// Un-register a node. Any pending requests will be cancelled, returning error responses.\n    async fn finish(\n        &mut self,\n        session_id: SessionId,\n        request_id: RequestId,\n    ) -> Result<(), GlobalCollectiveError> {\n        let node_id = self\n            .registered_nodes\n            .remove(&session_id)\n            .ok_or(GlobalCollectiveError::NotRegisteredOnFinish)?;\n        self.node_addresses.remove(&node_id);\n        self.node_peers.remove(&node_id);\n        self.num_global_peers = 0;\n\n        let mut register_requests = vec![];\n        core::mem::swap(&mut register_requests, &mut self.register_requests);\n        for (session, req, node_id) in register_requests {\n            if session == session_id {\n                // Send a response if we are finishing a session with a pending register request\n                let content = RemoteResponse::Error(GlobalCollectiveError::PendingRegisterOnFinish);\n                let response = CollectiveMessageResponse {\n                    request_id: req,\n                    content,\n                };\n                self.respond(session_id, response).await;\n            } else {\n                // keep the register request\n                self.register_requests.push((session, req, node_id));\n            }\n        }\n\n        self.respond(\n            session_id,\n            CollectiveMessageResponse {\n                request_id,\n                content: RemoteResponse::FinishAck,\n            },\n        )\n        .await;\n\n        Ok(())\n    }\n\n    async fn register(\n        &mut self,\n        session_id: SessionId,\n        request_id: RequestId,\n        node_addr: Address,\n        num_nodes: u32,\n        peers: Vec<PeerId>,\n    ) -> Result<(), GlobalCollectiveError> {\n        match &self.cur_num_nodes {\n            Some(cur_num_nodes) => {\n                if *cur_num_nodes != num_nodes {\n                    return Err(GlobalCollectiveError::RegisterParamsMismatch);\n                }\n            }\n            None => {\n                self.cur_num_nodes = Some(num_nodes);\n            }\n        }\n\n        self.num_global_peers += peers.len() as u32;\n\n        let node_id: NodeId = self.registered_nodes.len().into();\n        self.registered_nodes.insert(session_id, node_id);\n        if self.node_addresses.values().any(|addr| node_addr == *addr) {\n            return Err(GlobalCollectiveError::DoubleRegister);\n        }\n        self.node_addresses.insert(node_id, node_addr);\n        self.node_peers.insert(node_id, peers);\n\n        self.register_requests\n            .push((session_id, request_id, node_id));\n\n        if self.registered_nodes.len() == num_nodes as usize {\n            let mut callbacks = vec![];\n            core::mem::swap(&mut callbacks, &mut self.register_requests);\n\n            for (session, request, node_id) in callbacks {\n                let content = RemoteResponse::Register {\n                    node_id,\n                    nodes: self.node_addresses.clone(),\n                    num_global_devices: self.num_global_peers,\n                };\n                let resp = CollectiveMessageResponse {\n                    request_id: request,\n                    content,\n                };\n                self.respond(session, resp).await;\n            }\n        }\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/global/shared.rs",
    "content": "use std::{collections::HashMap, sync::atomic::AtomicU32};\n\nuse crate::{NodeId, PeerId};\nuse burn_communication::{Address, CommunicationError};\nuse burn_std::id::IdGenerator;\nuse serde::{Deserialize, Serialize};\n\n/// A unique identifier for each request made to a global orchestrator\n#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]\npub(crate) struct RequestId(u32);\n\nstatic REQ_ID_COUNTER: AtomicU32 = AtomicU32::new(0);\nimpl RequestId {\n    pub(crate) fn new() -> Self {\n        let id = REQ_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);\n        Self(id)\n    }\n}\n\nimpl Default for RequestId {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\n/// Unique identifier that can represent a session between a node and the orchestrator.\n#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]\npub(crate) struct SessionId {\n    id: u64,\n}\n\nimpl SessionId {\n    /// Create a new [session id](SessionId).\n    pub(crate) fn new() -> Self {\n        Self {\n            id: IdGenerator::generate(),\n        }\n    }\n}\n\n/// Requests sent from the client\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub(crate) enum CollectiveMessage {\n    Init(SessionId),\n    Request(RequestId, RemoteRequest),\n}\n\n/// Responses sent to the client\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub(crate) struct CollectiveMessageResponse {\n    pub request_id: RequestId,\n    pub content: RemoteResponse,\n}\n\n/// Requests made from a client to a server.\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub(crate) enum RemoteRequest {\n    // Register a node\n    Register {\n        /// Endpoint for this node\n        node_addr: Address,\n        /// Number of total nodes\n        num_nodes: u32,\n        /// List of peers on this node\n        peers: Vec<PeerId>,\n    },\n\n    /// Unregister node\n    Finish,\n}\n\n/// Responses for each server request\n#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]\npub(crate) enum RemoteResponse {\n    /// Response to a register request\n    Register {\n        /// The orchestrator gives the node its id\n        node_id: NodeId,\n        /// All the nodes in the collective: including self\n        nodes: HashMap<NodeId, Address>,\n        /// How many devices exist globally? For averaging values\n        num_global_devices: u32,\n    },\n\n    // Finish\n    FinishAck,\n\n    // There was a server-side error\n    Error(GlobalCollectiveError),\n}\n\n/// Errors that occur during collective operations on the global level\n#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]\npub enum GlobalCollectiveError {\n    /// Operations that can't be done before registering\n    AllReduceBeforeRegister,\n    /// Ring all-reduce can't be done if all tensor dimensions are smaller than the number of nodes.\n    RingReduceImpossible,\n\n    /// Either a node has unregistered twice, or a Finish has been called before a Register\n    NotRegisteredOnFinish,\n    /// Finish has been called before a Register operation was finished\n    PendingRegisterOnFinish,\n    /// Trying to register a different way than is currently being done\n    RegisterParamsMismatch,\n    /// Trying to register while already registered\n    DoubleRegister,\n    /// Trying to aggregate a different way than is currently being done\n    AllReduceParamsMismatch,\n\n    /// First message on socket should be Message::Init\n    FirstMsgNotInit,\n    /// Messages should be rmp_serde serialized `Message` types\n    InvalidMessage,\n    /// A peer behaved unexpectedly\n    PeerSentIncoherentTensor,\n    /// Tried to download from a peer, but the peer closed or lost the connection\n    PeerLost(NodeId),\n    /// Error from the coordinator\n    Server(String),\n\n    /// The node received an invalid response\n    WrongOrchestratorResponse,\n    /// Node couldn't connect to coordinator\n    OrchestratorUnreachable,\n}\n\nimpl<E: CommunicationError> From<E> for GlobalCollectiveError {\n    fn from(err: E) -> Self {\n        Self::Server(format!(\"{err:?}\"))\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/lib.rs",
    "content": "mod global;\npub use global::*;\n\nmod config;\npub use config::*;\n\nmod api;\npub use api::*;\n\nmod local;\n\n#[cfg(all(\n    test,\n    any(\n        feature = \"test-ndarray\",\n        feature = \"test-wgpu\",\n        feature = \"test-cuda\",\n        feature = \"test-metal\"\n    )\n))]\nmod tests;\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/base.rs",
    "content": "use crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};\nuse crate::{\n    AllReduceStrategy, CollectiveConfig, CollectiveError, ReduceOperation,\n    local::{\n        all_reduce_sum_centralized, all_reduce_sum_ring, all_reduce_sum_tree,\n        broadcast_centralized, broadcast_tree, reduce_sum_centralized, reduce_sum_tree,\n    },\n    node::base::Node,\n};\nuse burn_communication::Protocol;\nuse burn_tensor::backend::Backend;\n\n#[cfg(feature = \"tracing\")]\nuse tracing::Instrument;\n\n/// Perform an all-reduce with no multi-node operations (global ops)\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors, config))\n)]\npub(crate) async fn all_reduce_local_only<B: Backend>(\n    tensors: CollectiveTensorMap<B>,\n    op: ReduceOperation,\n    config: &CollectiveConfig,\n) -> Result<CollectiveTensorMap<B>, CollectiveError> {\n    let local_strategy = &config.local_all_reduce_strategy;\n\n    let mut reduced_tensors = match local_strategy {\n        AllReduceStrategy::Centralized => all_reduce_sum_centralized::<B>(tensors),\n        AllReduceStrategy::Tree(arity) => all_reduce_sum_tree::<B>(tensors, *arity),\n        AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors),\n    };\n\n    if op == ReduceOperation::Mean {\n        #[cfg(feature = \"tracing\")]\n        let _span = tracing::info_span!(\"mean_reduction\").entered();\n\n        // Apply mean division\n        let div = (reduced_tensors.len() as f32).into();\n\n        reduced_tensors = reduced_tensors\n            .into_iter()\n            .map(|(id, t)| (id, B::float_div_scalar(t, div)))\n            .collect();\n    }\n    Ok(reduced_tensors)\n}\n\n/// Do an all-reduce in a multi-node context\n///\n/// With Tree and Centralized strategies, the all-reduce is split between a\n/// reduce (all tensors are reduced to one device), and a broadcast (the result is sent to all\n/// other devices). The all-reduce on the global level is done between both steps.\n/// Due to the nature of the Ring strategy, this separation can't be done.\n///\n/// For the Ring strategy, this isn't possible, because it is more like a\n/// reduce-scatter plus an all-gather, so using a Ring strategy locally in a multi-node\n/// setup may be unadvantageous.\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors, config, global_client))\n)]\npub(crate) async fn all_reduce_with_global<B: Backend, P: Protocol>(\n    tensors: CollectiveTensorMap<B>,\n    op: ReduceOperation,\n    config: &CollectiveConfig,\n    global_client: &mut Node<B, P>,\n) -> Result<CollectiveTensorMap<B>, CollectiveError> {\n    let peer_devices = get_peer_devices::<B>(&tensors);\n\n    // For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later\n    let main_device = *tensors.keys().next().unwrap();\n\n    let mut main_tensor = match config.local_all_reduce_strategy {\n        AllReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &main_device),\n        AllReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &main_device, arity),\n        AllReduceStrategy::Ring => all_reduce_sum_ring::<B>(tensors)\n            .remove(&main_device)\n            .unwrap(),\n    };\n\n    // Do aggregation on global level with the main tensor\n    main_tensor = {\n        let fut = async {\n            let global_strategy = config\n                .global_all_reduce_strategy\n                .expect(\"global_all_reduce_strategy must be set\");\n\n            global_client\n                .all_reduce(main_tensor, global_strategy, op)\n                .await\n        };\n        #[cfg(feature = \"tracing\")]\n        {\n            fut.instrument(tracing::info_span!(\"global_all_reduce\"))\n        }\n        #[cfg(not(feature = \"tracing\"))]\n        {\n            fut\n        }\n    }\n    .await\n    .map_err(CollectiveError::Global)?;\n\n    // Broadcast result to all devices\n    let tensors = match config.local_all_reduce_strategy {\n        AllReduceStrategy::Tree(arity) => {\n            broadcast_tree::<B>(peer_devices, main_device, main_tensor, arity)\n        }\n        // If we chose the ring strategy and we must still broadcast the global result,\n        // we use the centralized strategy for broadcasting, but the tree may be better.\n        AllReduceStrategy::Centralized | AllReduceStrategy::Ring => {\n            broadcast_centralized::<B>(peer_devices, main_device, main_tensor)\n        }\n    };\n\n    Ok(tensors)\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/centralized.rs",
    "content": "use burn_tensor::backend::Backend;\n\nuse crate::local::tensor_map::{CollectiveTensorMap, get_peer_devices};\nuse crate::local::{broadcast_centralized, reduce_sum_centralized};\n\n/// Perform an all-reduce operation by reducing all tensors on one device, and broadcasting the\n/// result to all other devices\n///\n/// Internally, this is just a call to `reduce` followed by a `broadcast`\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\npub(crate) fn all_reduce_sum_centralized<B: Backend>(\n    tensors: CollectiveTensorMap<B>,\n) -> CollectiveTensorMap<B> {\n    // Get corresponding devices for each peer\n    let peer_devices = get_peer_devices::<B>(&tensors);\n    let central_device = *tensors.keys().next().unwrap();\n\n    // Reduce to central device\n    let central_tensor = reduce_sum_centralized::<B>(tensors, &central_device);\n\n    // Broadcast result to all\n    broadcast_centralized::<B>(peer_devices, central_device, central_tensor)\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/mod.rs",
    "content": "mod base;\nmod centralized;\nmod op;\nmod ring;\nmod tree;\n\npub(crate) use base::*;\npub(crate) use centralized::*;\npub(crate) use op::*;\npub(crate) use ring::*;\npub(crate) use tree::*;\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/op.rs",
    "content": "use crate::global::node::base::Node;\nuse crate::local::tensor_map::CollectiveTensorMap;\nuse crate::{CollectiveConfig, CollectiveError, PeerId, ReduceOperation, local};\nuse burn_communication::Protocol;\nuse burn_std::Shape;\nuse burn_tensor::TensorMetadata;\nuse burn_tensor::backend::Backend;\nuse std::sync::mpsc::SyncSender;\n\n/// An on-going all-reduce operation\n#[derive(Debug)]\npub struct AllReduceOp<B: Backend> {\n    /// all-reduce calls, one for each calling device\n    calls: Vec<AllReduceOpCall<B>>,\n    /// The reduce operation of the current all-reduce, as defined by the first caller\n    op: ReduceOperation,\n    /// The shape of the current all-reduce, as defined by the first caller\n    shape: Shape,\n}\n\n/// Struct for each device that calls an all-reduce operation\n#[derive(Debug)]\npub struct AllReduceOpCall<B: Backend> {\n    /// Id of the caller for this operation\n    caller: PeerId,\n    /// The tensor primitive passed as input\n    input: B::FloatTensorPrimitive,\n    /// Callback for the result of the all-reduce\n    result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,\n}\n\n/// Type sent to the collective client upon completion of a all-reduce aggregation\npub(crate) type AllReduceResult<T> = Result<T, CollectiveError>;\n\nimpl<B: Backend> AllReduceOp<B> {\n    pub fn new(shape: Shape, reduce_op: ReduceOperation) -> Self {\n        Self {\n            calls: vec![],\n            op: reduce_op,\n            shape,\n        }\n    }\n\n    /// Get a list of the peers.\n    fn peers(&self) -> Vec<PeerId> {\n        self.calls.iter().map(|c| c.caller).collect()\n    }\n\n    /// Register a call to all-reduce in this operation.\n    ///\n    /// # Returns\n    ///\n    /// `true` if enough peers have registered, and the all-reduce is ready\n    pub fn register_call(\n        &mut self,\n        caller: PeerId,\n        input: B::FloatTensorPrimitive,\n        result_sender: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,\n        op: ReduceOperation,\n        peer_count: usize,\n    ) -> Result<bool, CollectiveError> {\n        if self.shape != input.shape() {\n            return Err(CollectiveError::AllReduceShapeMismatch);\n        }\n        if self.op != op {\n            return Err(CollectiveError::AllReduceOperationMismatch);\n        }\n\n        self.calls.push(AllReduceOpCall {\n            caller,\n            input,\n            result_sender,\n        });\n\n        Ok(self.calls.len() == peer_count)\n    }\n\n    /// Runs the all-reduce if the operation is ready. Otherwise, do nothing\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level = \"trace\",\n        skip(self, config, global_client),\n        fields(\n            ?self.op,\n            ?self.shape,\n            self.peers = ?self.peers(),\n        )\n    ))]\n    pub async fn execute<P: Protocol>(\n        mut self,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) {\n        // all registered callers have sent a tensor to aggregate\n        match self.all_reduce(config, global_client).await {\n            Ok(mut tensors) => {\n                // Return resulting tensors\n                self.calls.iter().for_each(|call| {\n                    let result = tensors\n                        .remove(&call.caller)\n                        .expect(\"tensor/peer internal mismatch.\");\n                    call.result_sender.send(Ok(result)).unwrap();\n                });\n                assert_eq!(tensors.len(), 0, \"tensor/peer internal mismatch.\");\n            }\n            Err(err) => {\n                // Send error to all subscribers\n                self.fail(err);\n            }\n        }\n    }\n\n    /// Perform an all-reduce operation.\n    #[cfg_attr(\n        feature = \"tracing\",\n        tracing::instrument(level = \"trace\", skip(self, config, global_client))\n    )]\n    async fn all_reduce<P: Protocol>(\n        &mut self,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) -> Result<CollectiveTensorMap<B>, CollectiveError> {\n        let tensors = self\n            .calls\n            .iter()\n            .map(|call| (call.caller, call.input.clone()))\n            .collect();\n\n        if let Some(global_client) = global_client.as_mut() {\n            local::all_reduce_with_global(tensors, self.op, config, global_client).await\n        } else {\n            local::all_reduce_local_only::<B>(tensors, self.op, config).await\n        }\n    }\n\n    /// Send a collective error as result to operation caller\n    pub fn fail(self, err: CollectiveError) {\n        self.calls.iter().for_each(|op| {\n            op.result_sender.send(Err(err.clone())).unwrap();\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/ring.rs",
    "content": "use super::tree::all_reduce_sum_tree;\nuse crate::PeerId;\nuse crate::local::tensor_map;\nuse crate::local::tensor_map::CollectiveTensorMap;\nuse burn_tensor::{Shape, Slice, TensorMetadata, backend::Backend};\nuse std::{collections::HashMap, ops::Range};\n\n/// Ring implementation of All-Reduce (Ring-Reduce)\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\npub(crate) fn all_reduce_sum_ring<B: Backend>(\n    tensors: CollectiveTensorMap<B>,\n) -> CollectiveTensorMap<B> {\n    // https://blog.dailydoseofds.com/p/all-reduce-and-ring-reduce-for-model\n\n    // Example: tensors=3, slices=3\n\n    // phase 1\n    // o->o  o\n    // o  o->oå\n    // o  o  o->\n\n    // o  1->o\n    // o  o  1->\n    // 1->o  o\n\n    // o  1  2\n    // 2  o  1\n    // 1  2  o\n\n    // phase 2\n    // o  1  2->\n    // 2->o  1\n    // 1  2->o\n\n    // 2->1  2\n    // 2  2->1\n    // 1  2  2->\n\n    // 2  2  2\n    // 2  2  2\n    // 2  2  2\n\n    // Verify all shapes are the same\n    let shape = tensor_map::get_common_shape::<B>(&tensors)\n        .expect(\"Cannot aggregate tensors with different sizes\");\n\n    // Chose an axis\n    let slice_dim = get_slice_dim(&shape);\n\n    let slice_dim_size = shape[slice_dim];\n    let tensor_count = tensors.len();\n    if slice_dim_size < tensor_count {\n        // Tensor cannot be split into N slices! Use a fallback algorithm: binary tree\n        return all_reduce_sum_tree::<B>(tensors, 2);\n    }\n\n    // Split tensors into slices\n    let mut sliced_tensors = slice_tensors::<B>(tensors, shape, slice_dim);\n\n    // phase 1: aggregate in ring N-1 times (Reduce-Scatter)\n    ring_cycles::<B>(&mut sliced_tensors, true);\n\n    // phase 2: share (overwrite) in a ring N-1 times (All-Gather)\n    ring_cycles::<B>(&mut sliced_tensors, false);\n\n    // merge slices and put back in result\n    sliced_tensors\n        .into_iter()\n        .map(|(id, slices)| (id, B::float_cat(slices, slice_dim)))\n        .collect()\n}\n\n/// Get the dimension to slice across: the largest dimension of the shape\npub(crate) fn get_slice_dim(shape: &Shape) -> usize {\n    // get dimension with the greatest size.\n    shape\n        .iter()\n        .enumerate()\n        .max_by(|(_, a), (_, b)| a.cmp(b))\n        .map(|(index, _)| index)\n        .unwrap()\n}\n\n/// With a ring of N tensors, send the tensors N-1 times, either for the first of second phase.\n/// During the first phase, the tensor slices are summed.\n/// During the second, the slices are replaced.\nfn ring_cycles<B: Backend>(\n    sliced_tensors: &mut [(PeerId, Vec<B::FloatTensorPrimitive>)],\n    is_phase_one: bool,\n) {\n    let tensor_count = sliced_tensors.len();\n    for cycle in 0..(tensor_count - 1) {\n        for i in 0..tensor_count {\n            let src_tensor_idx = i;\n            let dest_tensor_idx = (i + 1) % tensor_count;\n\n            let slice_idx = if is_phase_one {\n                (i + (tensor_count - 1) * cycle) % tensor_count\n            } else {\n                // in phase 2, the starting slice is different (see diagrams)\n                (i + 1 + (tensor_count - 1) * cycle) % tensor_count\n            };\n\n            let src_slice = sliced_tensors[src_tensor_idx].1.remove(slice_idx);\n            let mut dest_slice = sliced_tensors[dest_tensor_idx].1.remove(slice_idx);\n\n            let dest_device = B::float_device(&dest_slice);\n            let src_slice_on_dest = B::float_to_device(src_slice.clone(), &dest_device);\n            if is_phase_one {\n                dest_slice = B::float_add(dest_slice, src_slice_on_dest);\n            } else {\n                let slices: Vec<Slice> = dest_slice\n                    .shape()\n                    .iter()\n                    .map(|&d| Slice::new(0, Some(d as isize), 1))\n                    .collect();\n\n                // in phase 2, we don't sum the two slices, we replace with the new one.\n                dest_slice =\n                    B::float_slice_assign(dest_slice, slices.as_slice(), src_slice_on_dest);\n            }\n\n            sliced_tensors[src_tensor_idx]\n                .1\n                .insert(slice_idx, src_slice);\n            sliced_tensors[dest_tensor_idx]\n                .1\n                .insert(slice_idx, dest_slice);\n        }\n    }\n}\n\n/// Slice a list of tensors the same way, evenly across a given dimension.\n/// The given `shape` should be the same for every tensor.\nfn slice_tensors<B: Backend>(\n    mut tensors: HashMap<PeerId, B::FloatTensorPrimitive>,\n    shape: Shape,\n    slice_dim: usize,\n) -> Vec<(PeerId, Vec<<B as Backend>::FloatTensorPrimitive>)> {\n    // Get slice index ranges\n    let ranges = get_ring_reduce_slice_ranges(shape[slice_dim], tensors.len());\n\n    // Slice tensors\n    let mut sliced_tensors = vec![];\n    for (id, tensor) in tensors.drain() {\n        let mut slices = vec![];\n        for range in &ranges {\n            let full_range = shape\n                .iter()\n                .enumerate()\n                .map(|(dim_idx, dim)| {\n                    if dim_idx == slice_dim {\n                        Slice::from(range.clone())\n                    } else {\n                        Slice::from(0..*dim)\n                    }\n                })\n                .collect::<Vec<_>>();\n            let slice = B::float_slice(tensor.clone(), &full_range);\n            slices.push(slice);\n        }\n        sliced_tensors.push((id, slices));\n    }\n\n    sliced_tensors\n}\n\n/// Get the index ranges for the slices to split a tensor evently across a given axis.\n///\n/// * `slice_dim_size` - The size of the dim to slice on\n/// * `slice_count` - The number of slices\n///\n/// Returns a vector of index ranges for each slice.\npub(crate) fn get_ring_reduce_slice_ranges(\n    slice_dim_size: usize,\n    slice_count: usize,\n) -> Vec<Range<usize>> {\n    let mut ranges: Vec<Range<usize>> = vec![];\n\n    let slice_size = slice_dim_size.div_ceil(slice_count);\n\n    for i in 0..slice_count {\n        let start = i * slice_size;\n        let end = start + slice_size;\n\n        ranges.push(Range { start, end });\n    }\n    ranges.last_mut().unwrap().end = slice_dim_size;\n\n    ranges\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/all_reduce/tree.rs",
    "content": "use crate::PeerId;\nuse crate::local::tensor_map::CollectiveTensorMap;\nuse burn_tensor::backend::{Backend, DeviceOps};\nuse std::collections::HashMap;\n\n/// Performs an all-reduce on the provided tensors in a b-tree structure with `arity`.\n/// Similar to [reduce_sum_tree](reduce_sum_tree), but this function broadcasts the result with\n/// the same tree algorithm.\n/// The returned tensors are on the same devices as the corresponding inputs\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\npub(crate) fn all_reduce_sum_tree<B: Backend>(\n    tensors: CollectiveTensorMap<B>,\n    arity: u32,\n) -> CollectiveTensorMap<B> {\n    let mut input = tensors.into_iter().collect::<Vec<_>>();\n\n    // Sort to put devices of the same type together\n    input.sort_by(|a, b| {\n        let dev_a = B::float_device(&a.1);\n        let dev_b = B::float_device(&b.1);\n        dev_a.id().cmp(&dev_b.id())\n    });\n    // Recursive all-reduce\n    let out = all_reduce_sum_tree_inner::<B>(input, arity);\n\n    let mut tensors = HashMap::new();\n    for (id, tensor) in out {\n        tensors.insert(id, tensor);\n    }\n    tensors\n}\n\n/// Recursive function that sums `tensors` and redistributes the result to the host devices\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\nfn all_reduce_sum_tree_inner<B: Backend>(\n    mut tensors: Vec<(PeerId, B::FloatTensorPrimitive)>,\n    arity: u32,\n) -> Vec<(PeerId, B::FloatTensorPrimitive)> {\n    let mut parent_tensors = vec![];\n    let mut children_groups = vec![];\n\n    // Phase 1: Sum tensors in groups of `arity` + 1\n    while !tensors.is_empty() {\n        // Maps ids to devices for each child of this parent\n        let mut children = vec![];\n        let (parent, mut parent_tensor) = tensors.remove(0);\n        let parent_device = B::float_device(&parent_tensor);\n\n        for _ in 0..arity {\n            if tensors.is_empty() {\n                break;\n            }\n            let (child, mut child_tensor) = tensors.remove(0);\n            let child_device = B::float_device(&child_tensor);\n            children.push((child, child_device));\n            child_tensor = B::float_to_device(child_tensor, &parent_device);\n            parent_tensor = B::float_add(parent_tensor, child_tensor);\n        }\n\n        parent_tensors.push((parent, parent_tensor));\n        children_groups.push(children);\n    }\n\n    if parent_tensors.len() > 1 {\n        // Parents are not yet at the root, do the upper part of the tree\n        parent_tensors = all_reduce_sum_tree_inner::<B>(parent_tensors, arity);\n    }\n\n    // Phase 2: Redistribute result from each parent to the respective devices\n    for (parent, parent_tensor) in parent_tensors {\n        let children = children_groups.remove(0);\n        for (child, child_device) in children {\n            // replace child tensors with result\n            tensors.push((\n                child,\n                B::float_to_device(parent_tensor.clone(), &child_device),\n            ));\n        }\n        tensors.push((parent, parent_tensor));\n    }\n\n    tensors\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/broadcast/centralized.rs",
    "content": "use std::collections::HashMap;\n\nuse crate::PeerId;\nuse crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};\nuse burn_tensor::backend::Backend;\n\n/// Broadcasts the tensor from one device in a map to all the others\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(devices, tensor))\n)]\npub(crate) fn broadcast_centralized<B: Backend>(\n    mut devices: PeerDeviceMap<B>,\n    central: PeerId,\n    tensor: B::FloatTensorPrimitive,\n) -> CollectiveTensorMap<B> {\n    let mut output = HashMap::new();\n\n    devices\n        .remove(&central)\n        .expect(\"Central device id is in `devices`\");\n    for (dest, dest_device) in devices {\n        let tensor = B::float_to_device(tensor.clone(), &dest_device);\n        output.insert(dest, tensor);\n    }\n    output.insert(central, tensor);\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/broadcast/mod.rs",
    "content": "mod centralized;\nmod op;\nmod tree;\n\npub(crate) use centralized::*;\npub(crate) use op::*;\npub(crate) use tree::*;\n"
  },
  {
    "path": "crates/burn-collective/src/local/broadcast/op.rs",
    "content": "use crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};\nuse crate::{\n    BroadcastStrategy, CollectiveConfig, CollectiveError, PeerId,\n    local::{broadcast_centralized, broadcast_tree},\n    node::base::Node,\n};\nuse burn_communication::Protocol;\n#[allow(unused_imports)] // TensorMetadata is used by tracing::instrument.\nuse burn_tensor::TensorMetadata;\nuse burn_tensor::backend::Backend;\nuse std::sync::mpsc::SyncSender;\n\n/// An on-going broadcast operation\npub struct BroadcastOp<B: Backend> {\n    /// broadcast calls, one for each calling device\n    calls: Vec<BroadcastOpCall<B>>,\n    /// The tensor to broadcast, as defined by the root. Should be defined before all\n    /// peers call the operation.\n    tensor: Option<B::FloatTensorPrimitive>,\n\n    /// ID of the root (or use the first call's peer).\n    root: Option<PeerId>,\n}\n\n/// Struct for each device that calls an broadcast operation\npub struct BroadcastOpCall<B: Backend> {\n    /// Id of the caller of the operation\n    caller: PeerId,\n    /// Device of the calling peer\n    device: B::Device,\n    /// Callback for the result of the broadcast\n    result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,\n}\n\n/// Type sent to the collective client upon completion of a broadcast op\npub(crate) type BroadcastResult<T> = Result<T, CollectiveError>;\n\nimpl<B: Backend> BroadcastOp<B> {\n    pub fn new() -> Self {\n        Self {\n            calls: vec![],\n            tensor: None,\n            root: None,\n        }\n    }\n\n    /// Get the effective root of the broadcast operation.\n    /// If the root is set, return it. Otherwise, return the first caller's peer.\n    pub fn effective_root(&self) -> PeerId {\n        self.root.unwrap_or(self.calls.first().unwrap().caller)\n    }\n\n    pub fn peers(&self) -> Vec<PeerId> {\n        self.calls.iter().map(|c| c.caller).collect()\n    }\n\n    fn peer_devices(&self) -> PeerDeviceMap<B> {\n        self.calls\n            .iter()\n            .map(|call| (call.caller, call.device.clone()))\n            .collect()\n    }\n\n    /// Register a call to reduce in this operation.\n    /// When the last caller registers a reduce, the operation is executed.\n    pub fn register_call(\n        &mut self,\n        caller: PeerId,\n        input: Option<B::FloatTensorPrimitive>,\n        result_sender: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,\n        device: B::Device,\n        peer_count: usize,\n    ) -> Result<bool, CollectiveError> {\n        if input.is_some() {\n            if self.tensor.is_some() {\n                return Err(CollectiveError::BroadcastMultipleTensors);\n            }\n            self.tensor = input;\n        }\n\n        self.calls.push(BroadcastOpCall {\n            caller,\n            device,\n            result_sender,\n        });\n\n        Ok(self.calls.len() == peer_count)\n    }\n\n    /// Runs the broadcast if the operation is ready. Otherwise, do nothing\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(self, config, global_client),\n        fields(\n            self.peers = ?self.peers(),\n            self.shape = ?self.tensor.as_ref().map(|t| t.shape()),\n            self.dtype = ?self.tensor.as_ref().map(|t| t.dtype()),\n        )\n    ))]\n    pub async fn execute<P: Protocol>(\n        mut self,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) {\n        // all registered callers have sent a tensor to aggregate\n        match self.broadcast(config, global_client).await {\n            Ok(mut tensors) => {\n                // Return resulting tensors\n                self.calls.iter().for_each(|call| {\n                    let result = tensors\n                        .remove(&call.caller)\n                        .expect(\"tensor/peer internal mismatch.\");\n                    call.result_sender.send(Ok(result)).unwrap();\n                });\n                assert_eq!(tensors.len(), 0, \"tensor/peer internal mismatch.\");\n            }\n            Err(err) => {\n                // Send error to all subscribers\n                self.fail(err);\n            }\n        }\n    }\n\n    #[cfg_attr(\n        feature = \"tracing\",\n        tracing::instrument(level = \"trace\", skip(self, config, global_client))\n    )]\n    async fn broadcast<P: Protocol>(\n        &mut self,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) -> Result<CollectiveTensorMap<B>, CollectiveError> {\n        // Do broadcast on global level with the main tensor\n        if let Some(global_client) = &global_client {\n            let strategy = config\n                .global_broadcast_strategy\n                .expect(\"global_broadcast_strategy not defined\");\n\n            self.tensor = Some(\n                global_client\n                    .broadcast(self.tensor.clone(), strategy)\n                    .await\n                    .map_err(CollectiveError::Global)?,\n            )\n        }\n\n        // At this point tensor must be defined\n        let Some(tensor) = self.tensor.take() else {\n            return Err(CollectiveError::BroadcastNoTensor);\n        };\n\n        let root = self.effective_root();\n        let peer_devices = self.peer_devices();\n\n        // Broadcast locally\n        Ok(match config.local_broadcast_strategy {\n            BroadcastStrategy::Tree(arity) => {\n                broadcast_tree::<B>(peer_devices, root, tensor, arity)\n            }\n            BroadcastStrategy::Centralized => {\n                broadcast_centralized::<B>(peer_devices, root, tensor)\n            }\n        })\n    }\n\n    /// Send a collective error as result to operation caller\n    pub fn fail(self, err: CollectiveError) {\n        self.calls.iter().for_each(|call| {\n            call.result_sender.send(Err(err.clone())).unwrap();\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/broadcast/tree.rs",
    "content": "use burn_tensor::backend::{Backend, DeviceOps};\nuse std::collections::HashMap;\n\nuse crate::PeerId;\nuse crate::local::tensor_map::{CollectiveTensorMap, PeerDeviceMap};\n\n/// Performs a broadcast on the provided tensors in a b-tree structure with `arity`.\n///\n/// Tensor must be on the device in the `devices` map corresponding to the `root` key.\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(devices, tensor))\n)]\npub(crate) fn broadcast_tree<B: Backend>(\n    mut devices: PeerDeviceMap<B>,\n    root: PeerId,\n    tensor: B::FloatTensorPrimitive,\n    arity: u32,\n) -> CollectiveTensorMap<B> {\n    // Convert hash map to vector of key-value pairs because order matters\n    let mut devices_vec = vec![];\n    let root_device = devices.remove(&root).unwrap();\n    for (id, tensor) in devices.drain() {\n        devices_vec.push((id, tensor));\n    }\n\n    // Sort to put devices of the same type together\n    devices_vec.sort_by(|a, b| {\n        let dev_a = &a.1;\n        let dev_b = &b.1;\n        dev_a.id().cmp(&dev_b.id())\n    });\n\n    // put the root first\n    devices_vec.insert(0, (root, root_device));\n\n    // Recursive broadcast\n    let out = broadcast_tree_inner::<B>(tensor, devices_vec, arity);\n\n    // put results in a hash map\n    let mut tensors = HashMap::new();\n    for (id, tensor) in out {\n        tensors.insert(id, tensor);\n    }\n\n    tensors\n}\n\n/// Recursive function that broadcasts tensor across the other devices. Tensor should be on the\n/// first device of the list\n///\n/// Broadcasts the tensor across the devices in the tree in a pre-order traversal.\nfn broadcast_tree_inner<B: Backend>(\n    tensor: B::FloatTensorPrimitive,\n    mut all_devices: Vec<(PeerId, B::Device)>,\n    arity: u32,\n) -> Vec<(PeerId, B::FloatTensorPrimitive)> {\n    let mut parents = vec![];\n    let mut children_groups = vec![];\n\n    // Put devices in groups of `arity` + the parent\n    while !all_devices.is_empty() {\n        let mut children = vec![];\n        let parent = all_devices.remove(0);\n\n        for _ in 0..arity {\n            if all_devices.is_empty() {\n                break;\n            }\n            children.push(all_devices.remove(0));\n        }\n\n        parents.push(parent);\n        children_groups.push(children);\n    }\n\n    let mut parents = if parents.len() > 1 {\n        broadcast_tree_inner::<B>(tensor, parents, arity)\n    } else {\n        let root = parents.first().unwrap();\n        // `tensor` should already be on the root's device, no need to call B::float_to_device\n        vec![(root.0, tensor)]\n    };\n\n    // Redistribute result from each parent to the respective devices\n    let mut tensors = vec![];\n    for children in children_groups {\n        let parent = parents.remove(0);\n        for (child_id, child_device) in children {\n            // replace child's tensor with parent's\n            let child_tensor = B::float_to_device(parent.1.clone(), &child_device);\n            tensors.push((child_id, child_tensor));\n        }\n        tensors.push(parent);\n    }\n\n    tensors\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/client.rs",
    "content": "use crate::local::all_reduce::AllReduceResult;\nuse crate::{\n    CollectiveConfig, CollectiveError, PeerId, ReduceOperation,\n    local::{\n        BroadcastResult, ReduceResult,\n        server::{FinishResult, Message, RegisterResult},\n    },\n};\nuse burn_tensor::backend::Backend;\nuse std::sync::mpsc::{Receiver, SyncSender};\n\n/// Local client to communicate with the local server. Each thread has a client.\n#[derive(Clone)]\npub(crate) struct LocalCollectiveClient<B: Backend> {\n    pub channel: SyncSender<Message<B>>,\n}\n\n/// A pending operation that can be waited on.\npub(crate) struct PendingCollectiveOperation<T> {\n    rx: Receiver<Result<T, CollectiveError>>,\n}\n\nimpl<T> From<PendingCollectiveOperation<T>> for Receiver<Result<T, CollectiveError>> {\n    fn from(value: PendingCollectiveOperation<T>) -> Self {\n        value.rx\n    }\n}\n\nimpl<T> PendingCollectiveOperation<T> {\n    /// Wait on the operation.\n    ///\n    /// Given a `Receiver<Result<T, CollectiveError>>`, this function will wait:\n    /// - Unwraps `Ok(Result<T, CollectiveError>)` into `Result<T, CollectiveError>`;\n    /// - maps `Err(RecvError)` to `Err(CollectiveError::LocalServerMissing)`.\n    pub(crate) fn wait(self) -> Result<T, CollectiveError> {\n        let tensor = self\n            .rx\n            .recv()\n            .unwrap_or(Err(CollectiveError::LocalServerMissing))?;\n\n        Ok(tensor)\n    }\n}\n\nimpl<B: Backend> LocalCollectiveClient<B> {\n    /// Common logic for starting a collective operation.\n    ///\n    /// - Allocates `(callback, recv)` channels,\n    /// - Passes the `callback` to the `Message<B>` builder,\n    /// - Sends the message through the collective channel,\n    /// - Returns the `recv`.\n    pub(crate) fn start_operation<T, F>(&self, builder: F) -> PendingCollectiveOperation<T>\n    where\n        F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,\n    {\n        let (tx, rx) = std::sync::mpsc::sync_channel(1);\n        self.channel.send((builder)(tx)).unwrap();\n        PendingCollectiveOperation { rx }\n    }\n\n    /// Common logic for starting a collective operation, with validation.\n    ///\n    /// When `valid` is `Err`, this function returns a `Receiver<Result<T, CollectiveError>>` that\n    /// immediately returns `Err(valid)`;\n    /// otherwise, it behaves like [`LocalCollectiveClient::start_operation`].\n    pub(crate) fn start_valid_operation<T, F>(\n        &self,\n        valid: Result<(), CollectiveError>,\n        builder: F,\n    ) -> PendingCollectiveOperation<T>\n    where\n        F: FnOnce(SyncSender<Result<T, CollectiveError>>) -> Message<B>,\n    {\n        match valid {\n            Err(e) => {\n                let (tx, rx) = std::sync::mpsc::sync_channel(1);\n                tx.send(Err(e)).unwrap();\n                PendingCollectiveOperation { rx }\n            }\n            _ => self.start_operation(builder),\n        }\n    }\n\n    pub(crate) fn reset(&self) {\n        self.channel.send(Message::Reset).unwrap();\n    }\n\n    pub(crate) fn register(\n        &mut self,\n        id: PeerId,\n        device: B::Device,\n        config: CollectiveConfig,\n    ) -> RegisterResult {\n        self.register_start(id, device, config).wait()\n    }\n\n    pub(crate) fn register_start(\n        &mut self,\n        id: PeerId,\n        device: B::Device,\n        config: CollectiveConfig,\n    ) -> PendingCollectiveOperation<()> {\n        self.start_valid_operation(\n            match config.is_valid() {\n                true => Ok(()),\n                false => Err(CollectiveError::InvalidConfig),\n            },\n            |callback| Message::Register {\n                device_id: id,\n                device,\n                config,\n                callback,\n            },\n        )\n    }\n\n    /// Calls for an all-reduce operation with the given parameters and returns the result.\n    /// The `params` must be the same as the parameters passed by the other nodes.\n    ///\n    /// # Arguments\n    /// * `id` - The peer id of the caller\n    /// * `tensor` - The input tensor to reduce with the peers' tensors\n    /// * `config` - Config of the collective operation. Must be coherent with the other calls.\n    ///\n    /// # Result\n    /// - `Ok(tensor)` if the operation was successful\n    /// - `Err(CollectiveError)` on error.\n    #[cfg_attr(\n        feature = \"tracing\",\n        tracing::instrument(level = \"trace\", skip(self, tensor))\n    )]\n    pub fn all_reduce(\n        &self,\n        id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n    ) -> AllReduceResult<B::FloatTensorPrimitive> {\n        self.all_reduce_start(id, tensor, op).wait()\n    }\n\n    /// Starts an all-reduce operation with the given parameters.\n    ///\n    /// The `params` must be the same as the parameters passed by the other nodes.\n    ///\n    /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].\n    ///\n    /// # Arguments\n    /// * `id` - The peer id of the caller\n    /// * `tensor` - The input tensor to reduce with the peers' tensors\n    /// * `config` - Config of the collective operation. Must be coherent with the other calls.\n    ///\n    /// # Result\n    ///\n    /// A `Receiver<>` that will yield:\n    /// - `Ok(AllReduceResult<B::FloatTensorPrimitive>)` if the operation was successful\n    /// - `Err(SendError)` if the channel was dropped.\n    pub(crate) fn all_reduce_start(\n        &self,\n        id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n    ) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {\n        self.start_operation(|callback| Message::AllReduce {\n            device_id: id,\n            tensor,\n            op,\n            callback,\n        })\n    }\n\n    /// Reduces a tensor onto one device.\n    ///\n    /// # Arguments\n    /// - `id` - The peer id of the caller.\n    /// - `tensor` - The tensor to send as input.\n    /// - `op` - The reduce operation to apply.\n    /// - `root` - The ID of the peer that will receive the result.\n    ///\n    /// Returns Ok(None) if the root tensor is not the caller. Otherwise, returns the reduced tensor.\n    pub fn reduce(\n        &self,\n        id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n        root: PeerId,\n    ) -> ReduceResult<B::FloatTensorPrimitive> {\n        self.reduce_start(id, tensor, op, root).wait()\n    }\n\n    /// Starts a reduce operation on a tensor onto one device.\n    ///\n    /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].\n    ///\n    /// # Arguments\n    /// - `id` - The peer id of the caller.\n    /// - `tensor` - The tensor to send as input.\n    /// - `op` - The reduce operation to apply.\n    /// - `root` - The ID of the peer that will receive the result.\n    ///\n    /// # Result\n    ///\n    /// A `Receiver<>` that will yield:\n    /// - `Ok(ReduceResult<B::FloatTensorPrimitive>)` if the operation was successful\n    /// - `Err(SendError)` if the channel was dropped.\n    pub(crate) fn reduce_start(\n        &self,\n        id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n        root: PeerId,\n    ) -> PendingCollectiveOperation<Option<B::FloatTensorPrimitive>> {\n        self.start_operation(|callback| Message::Reduce {\n            device_id: id,\n            tensor,\n            op,\n            root,\n            callback,\n        })\n    }\n\n    /// Broadcasts, or receives a broadcasted tensor.\n    ///\n    /// # Arguments\n    /// - `id` - The peer id of the caller\n    /// - `tensor` - If defined, this tensor will be broadcasted.\n    ///   Otherwise, this call will receive the broadcasted tensor.\n    ///\n    /// # Result\n    /// Synchronously waits on the broadcasted tensor.\n    pub fn broadcast(\n        &self,\n        id: PeerId,\n        tensor: Option<B::FloatTensorPrimitive>,\n    ) -> BroadcastResult<B::FloatTensorPrimitive> {\n        self.broadcast_start(id, tensor).wait()\n    }\n\n    /// Starts a Broadcast, or receives a broadcasted tensor.\n    ///\n    /// This receiver can be waited on using [`LocalCollectiveClient::operation_wait`].\n    ///\n    /// # Arguments\n    /// - `id` - The peer id of the caller\n    /// - `tensor` - If defined, this tensor will be broadcasted. Otherwise, this call will receive\n    ///   the broadcasted tensor.\n    ///\n    /// # Result\n    ///\n    /// A `Receiver<>` that will yield:\n    /// - `Ok(BroadcastResult<B::FloatTensorPrimitive>)` if the operation was successful\n    /// - `Err(SendError)` if the channel was dropped.\n    pub(crate) fn broadcast_start(\n        &self,\n        id: PeerId,\n        tensor: Option<B::FloatTensorPrimitive>,\n    ) -> PendingCollectiveOperation<B::FloatTensorPrimitive> {\n        self.start_operation(|callback| Message::Broadcast {\n            device_id: id,\n            tensor,\n            callback,\n        })\n    }\n\n    pub(crate) fn finish(&self, id: PeerId) -> FinishResult {\n        self.finish_start(id).wait()\n    }\n\n    pub(crate) fn finish_start(&self, id: PeerId) -> PendingCollectiveOperation<()> {\n        self.start_operation(|callback| Message::Finish { id, callback })\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/mod.rs",
    "content": "mod all_reduce;\nmod broadcast;\nmod reduce;\n\npub(crate) mod tensor_map;\n\npub(crate) use all_reduce::*;\npub(crate) use broadcast::*;\npub(crate) use reduce::*;\n\npub(crate) mod client;\npub(crate) mod server;\n"
  },
  {
    "path": "crates/burn-collective/src/local/reduce/centralized.rs",
    "content": "use burn_tensor::backend::Backend;\n\nuse crate::PeerId;\nuse crate::local::tensor_map::CollectiveTensorMap;\n\n#[cfg(feature = \"tracing\")]\nuse crate::local::tensor_map::get_common_shape;\n\n/// Sums the tensors on one device and returns the result\n#[cfg_attr(feature = \"tracing\", tracing::instrument(\n    level=\"trace\",\n    skip(tensors),\n    fields(shape = ?get_common_shape::<B>(&tensors).unwrap())\n))]\npub(crate) fn reduce_sum_centralized<B: Backend>(\n    mut tensors: CollectiveTensorMap<B>,\n    central: &PeerId,\n) -> B::FloatTensorPrimitive {\n    let mut central_tensor = tensors\n        .remove(central)\n        .expect(\"Source device id is in the map\");\n    let central_device = B::float_device(&central_tensor);\n\n    for (_, tensor) in tensors {\n        let rhs = B::float_to_device(tensor.clone(), &central_device);\n        central_tensor = B::float_add(central_tensor, rhs);\n    }\n\n    central_tensor\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/reduce/mod.rs",
    "content": "mod centralized;\nmod op;\nmod tree;\n\npub(crate) use centralized::*;\npub(crate) use op::*;\npub(crate) use tree::*;\n"
  },
  {
    "path": "crates/burn-collective/src/local/reduce/op.rs",
    "content": "use burn_communication::Protocol;\nuse burn_tensor::{Shape, TensorMetadata, backend::Backend};\nuse std::sync::mpsc::SyncSender;\n\nuse crate::{\n    CollectiveConfig, CollectiveError, PeerId, ReduceOperation, ReduceStrategy,\n    local::{reduce_sum_centralized, reduce_sum_tree},\n    node::base::Node,\n};\n\n/// An on-going reduce operation\npub struct ReduceOp<B: Backend> {\n    /// reduce calls, one for each calling device\n    calls: Vec<ReduceOpCall<B>>,\n    /// The reduce operation, as defined by the first caller\n    op: ReduceOperation,\n    /// The peer that receives the reduce result, as defined by the first caller\n    root: PeerId,\n    /// The shape of the tensor to reduce, as defined by the first caller\n    shape: Shape,\n}\n\n/// Struct for each device that calls an reduce operation\npub struct ReduceOpCall<B: Backend> {\n    /// Id of the caller of the operation\n    caller: PeerId,\n    /// The tensor primitive passed as input\n    input: B::FloatTensorPrimitive,\n    /// Callback for the result of the reduce\n    result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,\n}\n\n/// Type sent to the collective client upon completion of a reduce aggregation\npub(crate) type ReduceResult<T> = Result<Option<T>, CollectiveError>;\n\nimpl<B: Backend> ReduceOp<B> {\n    pub fn new(shape: Shape, reduce_op: ReduceOperation, root: PeerId) -> Self {\n        Self {\n            calls: vec![],\n            op: reduce_op,\n            root,\n            shape,\n        }\n    }\n\n    fn peers(&self) -> Vec<PeerId> {\n        self.calls.iter().map(|c| c.caller).collect()\n    }\n\n    /// Register a call to reduce in this operation.\n    /// When the last caller registers a reduce, the operation is executed.\n    pub fn register_call(\n        &mut self,\n        caller: PeerId,\n        input: B::FloatTensorPrimitive,\n        result_sender: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,\n        op: ReduceOperation,\n        root: PeerId,\n        peer_count: usize,\n    ) -> Result<bool, CollectiveError> {\n        if self.shape != input.shape() {\n            return Err(CollectiveError::ReduceShapeMismatch);\n        }\n        if self.op != op {\n            return Err(CollectiveError::ReduceOperationMismatch);\n        }\n        if self.root != root {\n            return Err(CollectiveError::ReduceRootMismatch);\n        }\n\n        self.calls.push(ReduceOpCall {\n            caller,\n            input,\n            result_sender,\n        });\n\n        Ok(self.calls.len() == peer_count)\n    }\n\n    /// Runs the all-reduce if the operation is ready. Otherwise, do nothing\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(self, config, global_client),\n        fields(\n            ?self.op,\n            ?self.shape,\n            self.peers = ?self.peers(),\n        )\n    ))]\n    pub async fn execute<P: Protocol>(\n        mut self,\n        root: PeerId,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) {\n        match self.reduce(config, global_client).await {\n            Ok(mut result) => {\n                // Return resulting tensor to root, None to others\n                self.calls.iter().for_each(|op| {\n                    let msg = if op.caller == root {\n                        Ok(result.take())\n                    } else {\n                        Ok(None)\n                    };\n                    op.result_sender.send(msg).unwrap();\n                });\n            }\n            Err(err) => {\n                self.fail(err);\n            }\n        }\n    }\n\n    #[cfg_attr(\n        feature = \"tracing\",\n        tracing::instrument(level = \"trace\", skip(self, config, global_client))\n    )]\n    async fn reduce<P: Protocol>(\n        &mut self,\n        config: &CollectiveConfig,\n        global_client: &mut Option<Node<B, P>>,\n    ) -> Result<Option<B::FloatTensorPrimitive>, CollectiveError> {\n        let tensors = self\n            .calls\n            .iter()\n            .map(|call| (call.caller, call.input.clone()))\n            .collect();\n\n        // For Centralized and Tree, we only need to do a reduce here, we'll do a broadcast later\n        let mut local_sum = match config.local_reduce_strategy {\n            ReduceStrategy::Centralized => reduce_sum_centralized::<B>(tensors, &self.root),\n            ReduceStrategy::Tree(arity) => reduce_sum_tree::<B>(tensors, &self.root, arity),\n        };\n\n        // Do aggregation on a global level with the main tensor\n        let result = if let Some(global_client) = global_client {\n            let strategy = config\n                .global_reduce_strategy\n                .expect(\"global_reduce_strategy not defined\");\n\n            global_client\n                .reduce(local_sum, strategy, self.root, self.op)\n                .await\n                .map_err(CollectiveError::Global)?\n        } else {\n            // Mean division locally\n            if self.op == ReduceOperation::Mean {\n                let local_tensor_count = self.calls.len() as f32;\n                local_sum = B::float_div_scalar(local_sum, local_tensor_count.into())\n            }\n            Some(local_sum)\n        };\n\n        Ok(result)\n    }\n\n    /// Send a collective error as result to operation caller\n    pub fn fail(self, err: CollectiveError) {\n        self.calls.iter().for_each(|op| {\n            op.result_sender.send(Err(err.clone())).unwrap();\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/reduce/tree.rs",
    "content": "use crate::PeerId;\nuse crate::local::tensor_map::CollectiveTensorMap;\nuse burn_tensor::backend::{Backend, DeviceOps};\n\n/// Performs a reduce on the provided tensors in a b-tree structure with `arity`.\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\npub(crate) fn reduce_sum_tree<B: Backend>(\n    mut tensors: CollectiveTensorMap<B>,\n    root: &PeerId,\n    arity: u32,\n) -> B::FloatTensorPrimitive {\n    // Convert hash map to vector of key-value pairs because order matters\n    let mut input = vec![];\n    let root_tensor = tensors.remove(root).unwrap();\n    for (_, tensor) in tensors.drain() {\n        input.push(tensor);\n    }\n\n    // Sort to put devices of the same type together\n    input.sort_by(|a, b| {\n        let dev_a = B::float_device(a);\n        let dev_b = B::float_device(b);\n        dev_a.id().cmp(&dev_b.id())\n    });\n\n    // put the root first\n    input.insert(0, root_tensor);\n\n    reduce_sum_tree_inner::<B>(input, arity)\n}\n\n/// Recursive function that sums `tensors`\n///\n/// Traverses `tensors` and reduces in a post-order traversal. The first tensor in the list is\n/// chosen as the root\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensors))\n)]\nfn reduce_sum_tree_inner<B: Backend>(\n    mut tensors: Vec<B::FloatTensorPrimitive>,\n    arity: u32,\n) -> B::FloatTensorPrimitive {\n    let mut parents = vec![];\n    let mut children_groups = vec![];\n\n    // Sum tensors in groups of `arity` + 1\n    while !tensors.is_empty() {\n        let mut children = vec![];\n        let mut parent_tensor = tensors.remove(0);\n        let parent_device = B::float_device(&parent_tensor);\n\n        for _ in 0..arity {\n            if tensors.is_empty() {\n                break;\n            }\n            let child_tensor = tensors.remove(0);\n            children.push(B::float_device(&child_tensor));\n            let rhs = B::float_to_device(child_tensor, &parent_device);\n            parent_tensor = B::float_add(parent_tensor, rhs);\n        }\n\n        parents.push(parent_tensor);\n        children_groups.push(children);\n    }\n\n    if parents.len() > 1 {\n        // Parents are not yet at the root, do the upper part of the tree\n        reduce_sum_tree_inner::<B>(parents, arity)\n    } else {\n        // Root of tree\n        parents.remove(0)\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/server.rs",
    "content": "use crate::{\n    CollectiveConfig, CollectiveError, PeerId, ReduceOperation,\n    global::node::base::Node,\n    local::{\n        AllReduceOp, AllReduceResult, BroadcastOp, BroadcastResult, ReduceOp, ReduceResult,\n        client::LocalCollectiveClient,\n    },\n};\nuse burn_communication::websocket::{WebSocket, WsServer};\nuse burn_tensor::{TensorMetadata, backend::Backend};\nuse std::sync::{MutexGuard, OnceLock};\nuse std::{\n    any::{Any, TypeId},\n    collections::HashMap,\n    fmt::Debug,\n    sync::{\n        Arc, Mutex,\n        mpsc::{Receiver, SyncSender},\n    },\n};\nuse tokio::runtime::{Builder, Runtime};\n\n/// Define the client/server communication on the network\ntype Network = WebSocket;\n/// Type sent to the collective client upon completion of a register request\npub(crate) type RegisterResult = Result<(), CollectiveError>;\n/// Type sent to the collective client upon completion of a finish request\npub(crate) type FinishResult = Result<(), CollectiveError>;\n\n/// The local collective server that manages all the collective aggregation operations\n/// (like all-reduce) between local threads.\n/// This thread takes in messages from different clients. The clients must register, than they can\n/// send an aggregate message. They must all use the same parameters for the same aggregate\n/// operation.\npub(crate) struct LocalCollectiveServer<B: Backend> {\n    /// Channel receiver for messages from clients\n    message_rec: Receiver<Message<B>>,\n\n    /// The collective configuration. Must be the same by every peer when calling register\n    config: Option<CollectiveConfig>,\n\n    /// The ids passed to each register so far\n    peers: Vec<PeerId>,\n\n    /// Callbacks for when all registers are done\n    callbacks_register: Vec<SyncSender<RegisterResult>>,\n\n    /// Map of each peer's id and its device\n    devices: HashMap<PeerId, B::Device>,\n\n    /// Current uncompleted all-reduce operation\n    all_reduce_op: Option<AllReduceOp<B>>,\n\n    /// Current uncompleted reduce call\n    reduce_op: Option<ReduceOp<B>>,\n\n    /// Uncompleted broadcast calls, one for each calling device.\n    broadcast_op: Option<BroadcastOp<B>>,\n\n    /// Client for global collective operations\n    global_client: Option<Node<B, Network>>,\n}\n\n#[derive(Debug)]\npub(crate) enum Message<B: Backend> {\n    /// Register a new peer with the collective.\n    Register {\n        device_id: PeerId,\n        device: B::Device,\n        config: CollectiveConfig,\n        callback: SyncSender<RegisterResult>,\n    },\n    /// Perform an all-reduce operation.\n    AllReduce {\n        device_id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n        callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,\n    },\n    /// Perform a reduce operation.\n    Reduce {\n        device_id: PeerId,\n        tensor: B::FloatTensorPrimitive,\n        op: ReduceOperation,\n        root: PeerId,\n        callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,\n    },\n    /// Perform a broadcast operation (one-sender, many-receiver).\n    Broadcast {\n        device_id: PeerId,\n        tensor: Option<B::FloatTensorPrimitive>,\n        callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,\n    },\n    /// Reset the collective server.\n    Reset,\n    Finish {\n        id: PeerId,\n        callback: SyncSender<FinishResult>,\n    },\n}\n\n/// The type-erased box type for [`LocalCollectiveClient<B>`].\ntype LocalClientBox = Box<dyn Any + Send + Sync>;\n\n/// Global state map from [`Backend`] to boxed [`LocalCollectiveClient<B>`].\nstatic BACKEND_CLIENT_MAP: OnceLock<Mutex<HashMap<TypeId, LocalClientBox>>> = OnceLock::new();\n\n/// Gets a locked mutable view of the `STATE_MAP`.\npub(crate) fn get_backend_client_map() -> MutexGuard<'static, HashMap<TypeId, LocalClientBox>> {\n    BACKEND_CLIENT_MAP\n        .get_or_init(Default::default)\n        .lock()\n        .unwrap()\n}\n\n/// Get a [`LocalCollectiveClient`] for the given [`Backend`].\n///\n/// Will start the local collective client/server pair if necessary.\npub(crate) fn get_collective_client<B: Backend>() -> LocalCollectiveClient<B> {\n    let typeid = TypeId::of::<B>();\n    let mut state_map = get_backend_client_map();\n    match state_map.get(&typeid) {\n        Some(val) => val.downcast_ref().cloned().unwrap(),\n        None => {\n            let client = LocalCollectiveServer::<B>::setup(LocalCollectiveClientConfig::default());\n            state_map.insert(typeid, Box::new(client.clone()));\n            client\n        }\n    }\n}\n\n/// Global runtime.\nstatic SERVER_RUNTIME: OnceLock<Arc<Runtime>> = OnceLock::new();\n\n/// Get the global [`Runtime`].\npub(crate) fn get_collective_server_runtime() -> Arc<Runtime> {\n    SERVER_RUNTIME\n        .get_or_init(|| {\n            Builder::new_multi_thread()\n                .enable_all()\n                .build()\n                .expect(\"Unable to initialize runtime\")\n                .into()\n        })\n        .clone()\n}\n\n/// Configuration for the local collective client/server pair.\npub struct LocalCollectiveClientConfig {\n    /// Channel capacity for the messaging queue from client to server.\n    pub channel_capacity: usize,\n}\n\nimpl Default for LocalCollectiveClientConfig {\n    fn default() -> Self {\n        Self {\n            channel_capacity: 50,\n        }\n    }\n}\n\nimpl From<usize> for LocalCollectiveClientConfig {\n    fn from(capacity: usize) -> Self {\n        Self {\n            channel_capacity: capacity,\n        }\n    }\n}\n\nimpl<B: Backend> LocalCollectiveServer<B> {\n    fn new(rec: Receiver<Message<B>>) -> Self {\n        Self {\n            message_rec: rec,\n            config: None,\n            peers: vec![],\n            devices: HashMap::new(),\n            all_reduce_op: None,\n            reduce_op: None,\n            broadcast_op: None,\n            callbacks_register: vec![],\n            global_client: None,\n        }\n    }\n\n    /// Setup a client/server pair with the given config.\n    pub(crate) fn setup<C>(cfg: C) -> LocalCollectiveClient<B>\n    where\n        C: Into<LocalCollectiveClientConfig>,\n    {\n        let cfg = cfg.into();\n        let (tx, rx) = std::sync::mpsc::sync_channel(cfg.channel_capacity);\n\n        get_collective_server_runtime().spawn(async {\n            let typeid = TypeId::of::<B>();\n            log::info!(\"Starting server for backend: {typeid:?}\");\n            let mut server = LocalCollectiveServer::new(rx);\n\n            loop {\n                match server.message_rec.recv() {\n                    Ok(message) => server.process_message(message).await,\n                    Err(err) => {\n                        log::error!(\n                            \"Error receiving message from local collective server: {err:?}\"\n                        );\n                        break;\n                    }\n                }\n            }\n        });\n\n        LocalCollectiveClient { channel: tx }\n    }\n\n    async fn process_message(&mut self, message: Message<B>) {\n        match message {\n            Message::Register {\n                device_id,\n                device,\n                config,\n                callback,\n            } => {\n                self.process_register_message(device_id, device, config, &callback)\n                    .await\n            }\n            Message::AllReduce {\n                device_id,\n                tensor,\n                op,\n                callback,\n            } => {\n                self.process_all_reduce_message(device_id, tensor, op, callback)\n                    .await\n            }\n            Message::Reduce {\n                device_id,\n                tensor,\n                op,\n                root,\n                callback,\n            } => {\n                self.process_reduce_message(device_id, tensor, op, root, callback)\n                    .await\n            }\n            Message::Broadcast {\n                device_id,\n                tensor,\n                callback,\n            } => {\n                self.process_broadcast_message(device_id, tensor, callback)\n                    .await\n            }\n            Message::Reset => self.reset(),\n            Message::Finish { id, callback } => self.process_finish_message(id, callback).await,\n        }\n    }\n\n    async fn process_register_message(\n        &mut self,\n        device_id: PeerId,\n        device: B::Device,\n        config: CollectiveConfig,\n        callback: &SyncSender<RegisterResult>,\n    ) {\n        if !config.is_valid() {\n            callback.send(Err(CollectiveError::InvalidConfig)).unwrap();\n            return;\n        }\n        if self.peers.contains(&device_id) {\n            callback\n                .send(Err(CollectiveError::MultipleRegister))\n                .unwrap();\n            return;\n        }\n        if self.peers.is_empty() || self.config.is_none() {\n            self.config = Some(config);\n        } else if let Some(cfg) = &self.config\n            && *cfg != config\n        {\n            callback\n                .send(Err(CollectiveError::RegisterParamsMismatch))\n                .unwrap();\n            return;\n        }\n\n        self.peers.push(device_id);\n        self.callbacks_register.push(callback.clone());\n        self.devices.insert(device_id, device);\n\n        let config = self.config.as_ref().unwrap();\n        let global_params = config.global_register_params();\n        if let Some(global_params) = &global_params\n            && self.global_client.is_none()\n        {\n            let server = WsServer::new(global_params.data_service_port);\n            let client = Node::new(&global_params.global_address, server);\n            self.global_client = Some(client)\n        }\n\n        // All have registered, callback\n        if self.peers.len() == config.num_devices {\n            let mut register_result = Ok(());\n\n            // if an error occurs on the global register, it must be passed back to every local peer\n            if let Some(global_params) = global_params {\n                let client = self\n                    .global_client\n                    .as_mut()\n                    .expect(\"Global client should be initialized\");\n\n                register_result = client\n                    .register(self.peers.clone(), global_params)\n                    .await\n                    .map_err(CollectiveError::Global);\n            };\n\n            // Send results to all callbacks.\n            self.callbacks_register\n                .drain(..)\n                .for_each(|tx| tx.send(register_result.clone()).unwrap());\n        }\n    }\n\n    /// Processes an Message::AllReduce.\n    async fn process_all_reduce_message(\n        &mut self,\n        peer_id: PeerId,\n        tensor: <B as Backend>::FloatTensorPrimitive,\n        op: ReduceOperation,\n        callback: SyncSender<AllReduceResult<B::FloatTensorPrimitive>>,\n    ) {\n        if !self.peers.contains(&peer_id) {\n            callback\n                .send(Err(CollectiveError::RegisterNotFirstOperation))\n                .unwrap();\n            return;\n        }\n\n        if self.all_reduce_op.is_none() {\n            // First call to all-reduce\n            self.all_reduce_op = Some(AllReduceOp::new(tensor.shape(), op));\n        }\n        // Take the operation, we'll put it back if we're not done\n        let mut all_reduce_op = self.all_reduce_op.take().unwrap();\n\n        // On the last caller, the all-reduce is done here\n        let res =\n            all_reduce_op.register_call(peer_id, tensor, callback.clone(), op, self.peers.len());\n\n        // Upon an error or the last call, the all_reduce_op is dropped\n        match res {\n            Ok(is_ready) => {\n                if is_ready {\n                    all_reduce_op\n                        .execute(self.config.as_ref().unwrap(), &mut self.global_client)\n                        .await;\n                } else {\n                    // Put operation back, we're waiting for more calls\n                    self.all_reduce_op = Some(all_reduce_op)\n                }\n            }\n            Err(err) => all_reduce_op.fail(err),\n        }\n    }\n\n    /// Processes a Message::Reduce.\n    async fn process_reduce_message(\n        &mut self,\n        peer_id: PeerId,\n        tensor: <B as Backend>::FloatTensorPrimitive,\n        op: ReduceOperation,\n        root: PeerId,\n        callback: SyncSender<ReduceResult<B::FloatTensorPrimitive>>,\n    ) {\n        if !self.peers.contains(&root) {\n            callback\n                .send(Err(CollectiveError::RegisterNotFirstOperation))\n                .unwrap();\n            return;\n        }\n\n        if self.reduce_op.is_none() {\n            // First call to reduce\n            self.reduce_op = Some(ReduceOp::new(tensor.shape(), op, root));\n        }\n        let mut reduce_op = self.reduce_op.take().unwrap();\n\n        // On the last caller, the all-reduce is done here\n        let res = reduce_op.register_call(\n            peer_id,\n            tensor,\n            callback.clone(),\n            op,\n            root,\n            self.peers.len(),\n        );\n\n        // Upon an error or the last call, the all_reduce_op is dropped\n        match res {\n            Ok(is_ready) => {\n                if is_ready {\n                    reduce_op\n                        .execute(root, self.config.as_ref().unwrap(), &mut self.global_client)\n                        .await;\n                } else {\n                    // Put operation back, we're waiting for more calls\n                    self.reduce_op = Some(reduce_op)\n                }\n            }\n            Err(err) => reduce_op.fail(err),\n        }\n    }\n\n    /// Processes a Message::Broadcast.\n    async fn process_broadcast_message(\n        &mut self,\n        caller: PeerId,\n        tensor: Option<<B as Backend>::FloatTensorPrimitive>,\n        callback: SyncSender<BroadcastResult<B::FloatTensorPrimitive>>,\n    ) {\n        if self.config.is_none() {\n            callback\n                .send(Err(CollectiveError::RegisterNotFirstOperation))\n                .unwrap();\n            return;\n        }\n        if !self.peers.contains(&caller) {\n            callback\n                .send(Err(CollectiveError::RegisterNotFirstOperation))\n                .unwrap();\n            return;\n        }\n\n        if self.broadcast_op.is_none() {\n            // First call to broadcast\n            self.broadcast_op = Some(BroadcastOp::new());\n        }\n        let device = self.devices.get(&caller).unwrap().clone();\n\n        let mut broadcast_op = self.broadcast_op.take().unwrap();\n\n        // On the last caller, the all-reduce is done here\n        let res =\n            broadcast_op.register_call(caller, tensor, callback.clone(), device, self.peers.len());\n\n        // Upon an error or the last call, the all_reduce_op is dropped\n        match res {\n            Ok(is_ready) => {\n                if is_ready {\n                    broadcast_op\n                        .execute(self.config.as_ref().unwrap(), &mut self.global_client)\n                        .await;\n                } else {\n                    // Put operation back, we're waiting for more calls\n                    self.broadcast_op = Some(broadcast_op)\n                }\n            }\n            Err(err) => broadcast_op.fail(err),\n        }\n    }\n\n    /// Reinitializes the collective server\n    fn reset(&mut self) {\n        self.peers.clear();\n        self.all_reduce_op = None;\n        self.reduce_op = None;\n        self.broadcast_op = None;\n    }\n\n    /// Processes a Message::Finish.\n    async fn process_finish_message(&mut self, id: PeerId, callback: SyncSender<RegisterResult>) {\n        if self.config.is_none() {\n            callback\n                .send(Err(CollectiveError::RegisterNotFirstOperation))\n                .unwrap();\n            return;\n        }\n        if !self.peers.contains(&id) {\n            callback\n                .send(Err(CollectiveError::MultipleUnregister))\n                .unwrap();\n            return;\n        }\n\n        // Remove registered with id\n        self.peers.retain(|x| *x != id);\n\n        if self.peers.is_empty()\n            && let Some(mut global_client) = self.global_client.take()\n        {\n            global_client.finish().await;\n        }\n\n        callback.send(Ok(())).unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/local/tensor_map.rs",
    "content": "//! # Common Tensor Map for Local Collective Operations\nuse crate::PeerId;\nuse burn_std::Shape;\nuse burn_tensor::TensorMetadata;\nuse burn_tensor::backend::Backend;\nuse std::collections::HashMap;\n\npub type CollectiveTensorMap<B> = HashMap<PeerId, <B as Backend>::FloatTensorPrimitive>;\n\npub type PeerDeviceMap<B> = HashMap<PeerId, <B as Backend>::Device>;\n\n/// Get the shape of the tensors. They should all have the same shape, otherwise None is returned.\npub fn get_common_shape<B: Backend>(tensors: &CollectiveTensorMap<B>) -> Option<Shape> {\n    let mut it = tensors.values();\n    if let Some(first) = it.next() {\n        let shape = first.shape();\n        for tensor in it {\n            if tensor.shape() != shape {\n                return None;\n            }\n        }\n        return Some(shape);\n    }\n    None\n}\n\n/// Get the `{ peer_id -> device }` mapping for the given tensors.\npub fn get_peer_devices<B: Backend>(tensors: &CollectiveTensorMap<B>) -> PeerDeviceMap<B> {\n    tensors\n        .iter()\n        .map(|(id, tensor)| (*id, B::float_device(tensor)))\n        .collect()\n}\n"
  },
  {
    "path": "crates/burn-collective/src/tests/all_reduce.rs",
    "content": "mod tests {\n    use std::sync::mpsc::SyncSender;\n\n    use burn_std::rand::get_seeded_rng;\n    use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};\n\n    use serial_test::serial;\n\n    #[cfg(feature = \"test-ndarray\")]\n    pub type TestBackend = burn_ndarray::NdArray<f32>;\n\n    #[cfg(feature = \"test-cuda\")]\n    pub type TestBackend = burn_cuda::Cuda<f32>;\n\n    #[cfg(feature = \"test-wgpu\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-metal\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-vulkan\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    use crate::{\n        AllReduceStrategy, CollectiveConfig, PeerId, ReduceOperation, all_reduce, register,\n        reset_collective,\n    };\n\n    pub fn run_peer<B: Backend>(\n        id: PeerId,\n        config: CollectiveConfig,\n        input: TensorData,\n        op: ReduceOperation,\n        output: SyncSender<Tensor<B, 1>>,\n    ) {\n        let device = B::Device::default();\n\n        register::<B>(id, device.clone(), config).unwrap();\n\n        let tensor = Tensor::<B, 1>::from_data(input, &device);\n\n        let tensor = Tensor::from_primitive(TensorPrimitive::Float(\n            all_reduce::<B>(id, tensor.into_primitive().tensor(), op).unwrap(),\n        ));\n\n        output.send(tensor).unwrap();\n    }\n\n    fn generate_random_input(\n        shape: Shape,\n        op: ReduceOperation,\n        thread_count: usize,\n    ) -> (Vec<TensorData>, TensorData) {\n        let input: Vec<TensorData> = (0..thread_count)\n            .map(|_| {\n                TensorData::random::<f32, _, _>(\n                    shape.clone(),\n                    burn_tensor::Distribution::Default,\n                    &mut get_seeded_rng(),\n                )\n            })\n            .collect();\n\n        let device = <TestBackend as Backend>::Device::default();\n\n        let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);\n        for item in input.iter().take(thread_count as usize) {\n            let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);\n            expected_tensor = expected_tensor.add(input_tensor);\n        }\n        if op == ReduceOperation::Mean {\n            expected_tensor = expected_tensor.div_scalar(thread_count as u32);\n        }\n\n        let expected = expected_tensor.to_data();\n\n        (input, expected)\n    }\n\n    fn test_all_reduce<B: Backend>(\n        device_count: usize,\n        op: ReduceOperation,\n        strategy: AllReduceStrategy,\n        tensor_size: usize,\n    ) {\n        reset_collective::<TestBackend>();\n\n        let (send, recv) = std::sync::mpsc::sync_channel(32);\n\n        let shape = Shape {\n            dims: vec![tensor_size],\n        };\n\n        let (input, expected) = generate_random_input(shape, op, device_count);\n\n        let config = CollectiveConfig::default()\n            .with_num_devices(device_count)\n            .with_local_all_reduce_strategy(strategy);\n\n        for id in 0..device_count {\n            let send = send.clone();\n            let input = input[id as usize].clone();\n\n            std::thread::spawn({\n                let config = config.clone();\n                move || run_peer::<B>(id.into(), config, input, op, send)\n            });\n        }\n\n        let first = recv.recv().unwrap().to_data();\n        for _ in 1..device_count {\n            let tensor = recv.recv().unwrap();\n            tensor.to_data().assert_eq(&first, true);\n        }\n\n        let tol: Tolerance<f32> = Tolerance::balanced();\n        expected.assert_approx_eq(&first, tol);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_centralized_sum() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_centralized_mean() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_binary_tree_sum() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_binary_tree_mean() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_5_tree_sum() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Tree(5), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_5_tree_mean() {\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Mean, AllReduceStrategy::Tree(5), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_ring_sum() {\n        test_all_reduce::<TestBackend>(3, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_ring_mean() {\n        test_all_reduce::<TestBackend>(3, ReduceOperation::Mean, AllReduceStrategy::Ring, 3);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_all_reduce_ring_irregular_sum() {\n        // this should trigger the fallback algorithm when the tensor is too small.\n        test_all_reduce::<TestBackend>(4, ReduceOperation::Sum, AllReduceStrategy::Ring, 3);\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/tests/broadcast.rs",
    "content": "mod tests {\n    use std::sync::mpsc::SyncSender;\n\n    use burn_std::rand::get_seeded_rng;\n    use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};\n\n    use serial_test::serial;\n\n    #[cfg(feature = \"test-ndarray\")]\n    pub type TestBackend = burn_ndarray::NdArray<f32>;\n\n    #[cfg(feature = \"test-cuda\")]\n    pub type TestBackend = burn_cuda::Cuda<f32>;\n\n    #[cfg(feature = \"test-wgpu\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-metal\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-vulkan\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    use crate::{\n        BroadcastStrategy, CollectiveConfig, PeerId, broadcast, register, reset_collective,\n    };\n\n    pub fn run_peer<B: Backend>(\n        id: PeerId,\n        config: CollectiveConfig,\n        input: Option<TensorData>,\n        output: SyncSender<Tensor<B, 1>>,\n    ) {\n        let device = B::Device::default();\n\n        register::<B>(id, device.clone(), config).unwrap();\n\n        let tensor = input.map(|data| B::float_from_data(data, &device));\n        let tensor = broadcast::<B>(id, tensor).unwrap();\n        let tensor = Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(tensor));\n\n        output.send(tensor).unwrap();\n    }\n\n    fn generate_random_input(shape: Shape) -> TensorData {\n        TensorData::random::<f32, _, _>(\n            shape.clone(),\n            burn_tensor::Distribution::Default,\n            &mut get_seeded_rng(),\n        )\n    }\n\n    fn test_broadcast<B: Backend>(\n        device_count: usize,\n        strategy: BroadcastStrategy,\n        tensor_size: usize,\n    ) {\n        reset_collective::<TestBackend>();\n\n        let (send, recv) = std::sync::mpsc::sync_channel(32);\n\n        let shape = Shape {\n            dims: vec![tensor_size],\n        };\n\n        let input = generate_random_input(shape);\n\n        let config = CollectiveConfig::default()\n            .with_num_devices(device_count)\n            .with_local_broadcast_strategy(strategy);\n\n        for id in 0..device_count {\n            // The peer #0 is the root: it sends the tensor\n            let input = if id == 0 { Some(input.clone()) } else { None };\n\n            std::thread::spawn({\n                let config = config.clone();\n                let send = send.clone();\n                move || run_peer::<B>(id.into(), config, input, send)\n            });\n        }\n\n        // Expect all peers to receive the input tensor\n        let tol: Tolerance<f32> = Tolerance::balanced();\n        for _ in 0..device_count {\n            let tensor = recv.recv().unwrap().to_data();\n            input.assert_approx_eq(&tensor, tol);\n        }\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_centralized_sum() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_centralized_mean() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_binary_tree_sum() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_binary_tree_mean() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_5_tree_sum() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_broadcast_5_tree_mean() {\n        test_broadcast::<TestBackend>(4, BroadcastStrategy::Tree(5), 4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-collective/src/tests/mod.rs",
    "content": "mod all_reduce;\nmod broadcast;\nmod reduce;\n"
  },
  {
    "path": "crates/burn-collective/src/tests/reduce.rs",
    "content": "mod tests {\n    use std::sync::mpsc::SyncSender;\n\n    use burn_std::rand::get_seeded_rng;\n    use burn_tensor::{Shape, Tensor, TensorData, TensorPrimitive, Tolerance, backend::Backend};\n\n    use serial_test::serial;\n\n    #[cfg(feature = \"test-ndarray\")]\n    pub type TestBackend = burn_ndarray::NdArray<f32>;\n\n    #[cfg(feature = \"test-cuda\")]\n    pub type TestBackend = burn_cuda::Cuda<f32>;\n\n    #[cfg(feature = \"test-wgpu\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-metal\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    #[cfg(feature = \"test-vulkan\")]\n    pub type TestBackend = burn_wgpu::Wgpu<f32>;\n\n    use crate::{\n        CollectiveConfig, PeerId, ReduceOperation, ReduceStrategy, reduce, register,\n        reset_collective,\n    };\n\n    pub fn run_peer<B: Backend>(\n        id: PeerId,\n        config: CollectiveConfig,\n        input: TensorData,\n        op: ReduceOperation,\n        root: PeerId,\n        output: SyncSender<Option<Tensor<B, 1>>>,\n    ) {\n        let device = B::Device::default();\n\n        register::<B>(id, device.clone(), config).unwrap();\n\n        let tensor = Tensor::<B, 1>::from_data(input, &device);\n\n        let tensor = tensor.into_primitive().tensor();\n        let tensor = reduce::<B>(id, tensor, op, root).unwrap();\n        let tensor = tensor.map(|t| Tensor::<B, 1>::from_primitive(TensorPrimitive::Float(t)));\n\n        output.send(tensor).unwrap();\n    }\n\n    fn generate_random_input(\n        shape: Shape,\n        op: ReduceOperation,\n        thread_count: usize,\n    ) -> (Vec<TensorData>, TensorData) {\n        let input: Vec<TensorData> = (0..thread_count)\n            .map(|_| {\n                TensorData::random::<f32, _, _>(\n                    shape.clone(),\n                    burn_tensor::Distribution::Default,\n                    &mut get_seeded_rng(),\n                )\n            })\n            .collect();\n\n        let device = <TestBackend as Backend>::Device::default();\n\n        let mut expected_tensor = Tensor::<TestBackend, 1>::zeros(shape, &device);\n        for item in input.iter().take(thread_count) {\n            let input_tensor = Tensor::<TestBackend, 1>::from_data(item.clone(), &device);\n            expected_tensor = expected_tensor.add(input_tensor);\n        }\n        if op == ReduceOperation::Mean {\n            expected_tensor = expected_tensor.div_scalar(thread_count as u32);\n        }\n\n        let expected = expected_tensor.to_data();\n\n        (input, expected)\n    }\n\n    fn test_reduce<B: Backend>(\n        device_count: usize,\n        op: ReduceOperation,\n        strategy: ReduceStrategy,\n        tensor_size: usize,\n    ) {\n        reset_collective::<TestBackend>();\n\n        let (send, recv) = std::sync::mpsc::sync_channel(32);\n\n        let shape = Shape {\n            dims: vec![tensor_size],\n        };\n\n        let (input, expected) = generate_random_input(shape, op, device_count);\n\n        let config = CollectiveConfig::default()\n            .with_num_devices(device_count)\n            .with_local_reduce_strategy(strategy);\n\n        let root: PeerId = 0.into();\n        for id in 0..device_count {\n            let send = send.clone();\n            let input = input[id as usize].clone();\n\n            std::thread::spawn({\n                let config = config.clone();\n                move || run_peer::<B>(id.into(), config, input, op, root, send)\n            });\n        }\n\n        let mut result = None;\n        for _ in 0..device_count {\n            let tensor = recv.recv().unwrap();\n            if tensor.is_some() {\n                if result.is_some() {\n                    panic!(\"Two peers received the result of an reduce!\");\n                }\n                result = tensor.map(|t| t.to_data());\n            }\n        }\n\n        let tol: Tolerance<f32> = Tolerance::balanced();\n        expected.assert_approx_eq(&result.expect(\"One peer has received the result\"), tol);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_centralized_sum() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_centralized_mean() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Centralized, 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_binary_tree_sum() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_binary_tree_mean() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(2), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_5_tree_sum() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Sum, ReduceStrategy::Tree(5), 4);\n    }\n\n    #[test]\n    #[serial]\n    pub fn test_reduce_5_tree_mean() {\n        test_reduce::<TestBackend>(4, ReduceOperation::Mean, ReduceStrategy::Tree(5), 4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-communication/Cargo.toml",
    "content": "[package]\nauthors = [\"Guilhem Ané (@Cielbird)\", \"Nathaniel Simard (@nathanielsimard)\"]\ndescription = \"Abstractions for network communication for Burn\"\nedition.workspace = true\nlicense.workspace = true\nname = \"burn-communication\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-communication\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ntracing = [\n    \"burn-std/tracing\",\n    \"burn-tensor?/tracing\",\n]\n\ndata-service = [\"burn-tensor\"]\nwebsocket = [\"axum\", \"tokio-tungstenite\", \"futures\"]\n\n[dependencies]\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = true }\nbytes = { workspace = true }\nderive-new = { workspace = true }\nfutures-util = { workspace = true }\nlog = { workspace = true }\nrmp-serde = { workspace = true }\nserde = { workspace = true, features = [\"derive\"] }\nserde_bytes = { workspace = true }\ntokio = { workspace = true, features = [\"rt-multi-thread\", \"sync\", \"signal\", \"tracing\"] }\ntokio-util = { workspace = true }\ntracing = { workspace = true, features = [\"default\"] }\ntracing-core = { workspace = true, features = [\"default\"] }\ntracing-subscriber = { workspace = true, features = [\"default\", \"fmt\", \"env-filter\"] }\n\n# Tensor Data Service\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", optional = true }\n\n# Websocket\naxum = { workspace = true, features = [\"ws\"], optional = true }\ntokio-tungstenite = { workspace = true, optional = true }\nfutures = { workspace = true, optional = true }\n"
  },
  {
    "path": "crates/burn-communication/README.md",
    "content": "# Burn Communication\n\nAbstractions for network communication\n\nThe Protocol trait defines how to communicate in a server/client style.\nThe server can set up routes with callbacks upon connection.\n\n## WebSocket\n\nCommunication with WebSockets is implemented with the `websocket` feature.\n\n## Tensor Data Service\n\nThe tensor data service provides easy utilities to share tensors peer-to-peer.\nOne peer can expose a tensor, and another can download it. Each peer is both a client and a server.\n"
  },
  {
    "path": "crates/burn-communication/src/base.rs",
    "content": "use burn_std::future::DynFut;\nuse serde::{Deserialize, Serialize};\nuse std::fmt::{Debug, Display};\nuse std::hash::Hash;\nuse std::str::FromStr;\n\n/// Allows nodes to find each other\n#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]\npub struct Address {\n    pub(crate) inner: String,\n}\n\nimpl FromStr for Address {\n    type Err = String;\n\n    fn from_str(s: &str) -> Result<Self, Self::Err> {\n        Ok(Self {\n            inner: s.to_string(),\n        })\n    }\n}\n\nimpl Display for Address {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        write!(f, \"{}\", self.inner)\n    }\n}\n\n/// The protocol used for the communications.\npub trait Protocol: Clone + Send + Sync + 'static {\n    /// The client implementation for the current protocol.\n    type Client: ProtocolClient;\n    /// The server implementation for the current protocol.\n    type Server: ProtocolServer;\n}\n\n/// Error that happens during a communication.\npub trait CommunicationError: Debug + Send + 'static {}\n\n/// The client is only used to create a [channel](CommunicationChannel), which should be use to\n/// transmit information with the [server](ProtocolServer).\npub trait ProtocolClient: Send + Sync + 'static {\n    /// Channel used by this protocol.\n    type Channel: CommunicationChannel<Error = Self::Error>;\n    /// The error type.\n    type Error: CommunicationError;\n\n    /// Opens a new [channel](CommunicationChannel) with the current protocol at the given\n    /// [address](Address) and route.\n    ///\n    /// * `address` - Address to connect to\n    /// * `route` - The name of the route (no slashes)\n    ///\n    /// Returns None if the connection can't be done.\n    fn connect(address: Address, route: &str) -> DynFut<Option<Self::Channel>>;\n}\n\n/// Data sent and received by the client and server.\n#[derive(new)]\npub struct Message {\n    /// The data is always encoded as bytes.\n    pub data: bytes::Bytes,\n}\n\n/// Defines how to create a server that respond to a [channel](CommunicationChannel).\npub trait ProtocolServer: Sized + Send + Sync + 'static {\n    /// Channel used by this protocol.\n    type Channel: CommunicationChannel<Error = Self::Error>;\n    /// The error type.\n    type Error: CommunicationError;\n\n    /// Defines an endpoint with the function that responds.\n    /// TODO Docs: does it need a slash?\n    fn route<C, Fut>(self, path: &str, callback: C) -> Self\n    where\n        C: FnOnce(Self::Channel) -> Fut + Clone + Send + Sync + 'static,\n        Fut: Future<Output = ()> + Send + 'static;\n\n    /// Start the server.\n    fn serve<F>(\n        self,\n        shutdown: F,\n    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static\n    where\n        F: Future<Output = ()> + Send + 'static;\n}\n\n/// Handles communications.\npub trait CommunicationChannel: Send + 'static {\n    type Error: CommunicationError;\n\n    /// Send a [message](Message) on the channel.\n    fn send(\n        &mut self,\n        message: Message,\n    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;\n\n    /// Receive a [message](Message) on the channel and returns a new [response message](Message).\n    fn recv(\n        &mut self,\n    ) -> impl std::future::Future<Output = Result<Option<Message>, Self::Error>> + Send;\n\n    fn close(&mut self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;\n}\n"
  },
  {
    "path": "crates/burn-communication/src/data_service.rs",
    "content": "//! This module enables direct data transfer between servers without blocking the client or any server.\n//!\n//! It eliminates the need for intermediate data transfer through the client, avoiding the process of downloading data from one server and reuploading it to another.\n//!\n//! The module provides an optimized mechanism for servers to communicate directly, streamlining data movement between them without involving the client.\n\nuse crate::Message;\nuse crate::base::Protocol;\nuse crate::base::{Address, CommunicationChannel, ProtocolClient, ProtocolServer};\nuse burn_tensor::{TensorData, backend::Backend};\nuse serde::{Deserialize, Serialize};\nuse std::{collections::HashMap, marker::PhantomData, sync::Arc};\nuse tokio::sync::Mutex;\nuse tokio::sync::Notify;\nuse tokio_util::sync::CancellationToken;\n\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]\npub struct TensorTransferId(u64);\n\nimpl From<u64> for TensorTransferId {\n    fn from(value: u64) -> Self {\n        Self(value)\n    }\n}\n\nimpl TensorTransferId {\n    pub fn next(&mut self) {\n        self.0 += 1;\n    }\n}\n\n#[derive(Debug, Serialize, Deserialize)]\nenum DataServiceMessage {\n    TensorRequest(TensorTransferId),\n    Tensor(TensorData),\n}\n\ntype ClientChannelRef<C> = Arc<Mutex<<C as ProtocolClient>::Channel>>;\n\npub struct TensorDataService<B: Backend, P: Protocol<Client: ProtocolClient>> {\n    /// Maps tensor transfer IDs to their exposed state.\n    pub exposed_tensors: Mutex<HashMap<TensorTransferId, TensorExposeState>>,\n    /// Maps node addresses to their channels.\n    pub channels: Mutex<HashMap<Address, ClientChannelRef<P::Client>>>,\n    /// Notify when a new tensor is exposed.\n    pub new_tensor_notify: Arc<Notify>,\n\n    cancel_token: CancellationToken,\n\n    _phantom_data: PhantomData<B>,\n}\n\npub struct TensorExposeState {\n    /// The bytes of the tensor data message. Message::Data(...) serialized with rmp_serde\n    pub bytes: bytes::Bytes,\n    /// How many times the tensor will be downloaded\n    pub max_downloads: u32,\n    /// How man times the tensor has been downloaded\n    pub cur_download_count: u32,\n}\n\n/// Provides a routing function for a tensor data service for a communications server\npub trait TensorDataServer<B: Backend, P: Protocol> {\n    /// Routes the tensor data service to the \"/data\" route\n    fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self;\n}\n\nimpl<B: Backend, S: ProtocolServer + Sized, P: Protocol<Server = S> + 'static>\n    TensorDataServer<B, P> for S\n{\n    fn route_tensor_data_service(self, state: Arc<TensorDataService<B, P>>) -> Self {\n        self.route(\"/data\", async move |stream: S::Channel| {\n            state.handle_data_channel(stream).await;\n        })\n    }\n}\n\nimpl<B: Backend, P: Protocol> TensorDataService<B, P> {\n    pub fn new(cancel_token: CancellationToken) -> Self {\n        Self {\n            exposed_tensors: Mutex::new(HashMap::new()),\n            channels: Mutex::new(HashMap::new()),\n            new_tensor_notify: Arc::new(Notify::new()),\n            cancel_token,\n            _phantom_data: PhantomData::<B>,\n        }\n    }\n\n    /// Exposes a tensor to the data server, allowing it to be downloaded by other nodes.\n    pub async fn expose(\n        &self,\n        tensor: B::FloatTensorPrimitive,\n        max_downloads: u32,\n        transfer_id: TensorTransferId,\n    ) {\n        let data = B::float_into_data(tensor).await.unwrap();\n        self.expose_data(data, max_downloads, transfer_id).await\n    }\n\n    /// Exposes a tensor data to the data server, allowing it to be downloaded by other nodes.\n    pub async fn expose_data(\n        &self,\n        tensor_data: TensorData,\n        max_downloads: u32,\n        transfer_id: TensorTransferId,\n    ) {\n        let bytes: bytes::Bytes = rmp_serde::to_vec(&DataServiceMessage::Tensor(tensor_data))\n            .unwrap()\n            .into();\n        let mut exposed_tensors = self.exposed_tensors.lock().await;\n        exposed_tensors.insert(\n            transfer_id,\n            TensorExposeState {\n                bytes,\n                max_downloads,\n                cur_download_count: 0,\n            },\n        );\n        core::mem::drop(exposed_tensors);\n        self.new_tensor_notify.notify_waiters();\n    }\n\n    pub async fn close(&self) {\n        // Send a closing message to every open WebSocket stream\n\n        let mut streams = self.channels.lock().await;\n        for (_, stream) in streams.drain() {\n            let mut stream = stream.lock().await;\n\n            stream\n                .close()\n                .await\n                .expect(\"Failed to close WebSocket stream\");\n        }\n    }\n\n    /// Downloads a tensor that is exposed on another server. Requires a Tokio 1.x runtime\n    ///\n    /// Returns None if the peer closes the connection\n    pub async fn download_tensor(\n        &self,\n        remote: Address,\n        transfer_id: TensorTransferId,\n    ) -> Option<TensorData> {\n        log::info!(\"Downloading tensor from {remote:?}\");\n\n        let stream = self.get_data_stream(remote).await;\n        let mut stream = stream.lock().await;\n\n        // Send the download request with the download id\n        let bytes: bytes::Bytes =\n            rmp_serde::to_vec(&DataServiceMessage::TensorRequest(transfer_id))\n                .unwrap()\n                .into();\n        stream\n            .send(Message::new(bytes))\n            .await\n            .expect(\"Failed to send download id\");\n\n        if let Ok(msg) = stream.recv().await {\n            let Some(msg) = msg else {\n                log::warn!(\"Received None message from the websocket, closing connection.\");\n                return None;\n            };\n\n            let DataServiceMessage::Tensor(data) = rmp_serde::from_slice(&msg.data)\n                .expect(\"Can deserialize messages from the websocket.\")\n            else {\n                panic!(\"Message should have been TensorData\")\n            };\n            return Some(data);\n        }\n        log::warn!(\"Closed connection\");\n        None\n    }\n\n    /// Get the WebSocket stream for the given address, or create a new one if it doesn't exist.\n    async fn get_data_stream(\n        &self,\n        address: Address,\n    ) -> Arc<Mutex<<P::Client as ProtocolClient>::Channel>> {\n        let mut streams = self.channels.lock().await;\n        match streams.get(&address) {\n            Some(stream) => stream.clone(),\n            None => {\n                // Open a new WebSocket connection to the address\n                let stream = P::Client::connect(address.clone(), \"data\").await;\n\n                let Some(stream) = stream else {\n                    panic!(\"Failed to connect to data server at {address:?}\");\n                };\n\n                let stream = Arc::new(Mutex::new(stream));\n                streams.insert(address.clone(), stream.clone());\n\n                stream\n            }\n        }\n    }\n\n    /// Get the requested exposed tensor data, and update download counter\n    async fn get_exposed_tensor_bytes(\n        &self,\n        transfer_id: TensorTransferId,\n    ) -> Option<bytes::Bytes> {\n        loop {\n            {\n                let mut exposed_tensors = self.exposed_tensors.lock().await;\n                // take the tensor out of the hashmap while we download\n                if let Some(mut exposed_state) = exposed_tensors.remove(&transfer_id) {\n                    exposed_state.cur_download_count += 1;\n                    let bytes = if exposed_state.cur_download_count == exposed_state.max_downloads {\n                        exposed_state.bytes\n                    } else {\n                        let bytes = exposed_state.bytes.clone();\n                        exposed_tensors.insert(transfer_id, exposed_state);\n                        bytes\n                    };\n                    return Some(bytes);\n                }\n            }\n            // No matching tensor, wait for a new one to come in.\n            self.new_tensor_notify.notified().await;\n        }\n    }\n\n    /// Handle incoming connections for downloading tensors.\n    pub(crate) async fn handle_data_channel(\n        &self,\n        mut channel: <P::Server as ProtocolServer>::Channel,\n    ) {\n        log::info!(\"[Data Handler] New connection for download.\");\n\n        while !self.cancel_token.is_cancelled() {\n            match channel.recv().await {\n                Ok(message) => {\n                    if let Some(msg) = message {\n                        let bytes = msg.data;\n                        let msg: DataServiceMessage = rmp_serde::from_slice(&bytes)\n                            .expect(\"Can deserialize messages from the websocket.\");\n                        let DataServiceMessage::TensorRequest(transfer_id) = msg else {\n                            panic!(\"Received a message that wasn't a tensor request! {msg:?}\");\n                        };\n\n                        let bytes = self.get_exposed_tensor_bytes(transfer_id).await.unwrap();\n\n                        channel.send(Message::new(bytes)).await.unwrap();\n                    } else {\n                        log::info!(\"Closed connection\");\n                        return;\n                    }\n                }\n                Err(err) => panic!(\"Failed to receive message from websocket: {err:?}\"),\n            };\n        }\n        log::info!(\"[Data Service] Closing connection for download.\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-communication/src/lib.rs",
    "content": "#[macro_use]\nextern crate derive_new;\n\nmod base;\npub use base::*;\n\npub mod util;\n\n#[cfg(feature = \"websocket\")]\npub mod websocket;\n\n#[cfg(feature = \"data-service\")]\npub mod data_service;\n"
  },
  {
    "path": "crates/burn-communication/src/util.rs",
    "content": "use tracing_core::{Level, LevelFilter};\nuse tracing_subscriber::{\n    Layer, filter::filter_fn, layer::SubscriberExt, registry, util::SubscriberInitExt,\n};\n\n/// Utilities to help handle communication termination.\npub async fn os_shutdown_signal() {\n    let ctrl_c = async {\n        tokio::signal::ctrl_c()\n            .await\n            .expect(\"failed to install Ctrl+C handler\");\n    };\n\n    #[cfg(unix)]\n    let terminate = async {\n        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())\n            .expect(\"failed to install signal handler\")\n            .recv()\n            .await;\n    };\n\n    #[cfg(not(unix))]\n    let terminate = std::future::pending::<()>();\n\n    tokio::select! {\n        _ = ctrl_c => {},\n        _ = terminate => {},\n    }\n}\n\npub(crate) fn init_logging() {\n    let layer = tracing_subscriber::fmt::layer()\n        .with_filter(LevelFilter::INFO)\n        .with_filter(filter_fn(|m| {\n            if let Some(path) = m.module_path() {\n                // The wgpu crate is logging too much, so we skip `info` level.\n                if path.starts_with(\"wgpu\") && *m.level() >= Level::INFO {\n                    return false;\n                }\n            }\n            true\n        }));\n\n    // If we start multiple servers in the same process, this will fail, it's ok\n    let _ = registry().with(layer).try_init();\n}\n"
  },
  {
    "path": "crates/burn-communication/src/websocket/base.rs",
    "content": "use crate::{\n    base::{Address, Protocol},\n    websocket::{client::WsClient, server::WsServer},\n};\n\n#[derive(Clone)]\n/// A websocket implements a [communication protocol](Protocol) that can be used to communicate\n/// over the internet.\npub struct WebSocket {}\n\nimpl Protocol for WebSocket {\n    type Client = WsClient;\n    type Server = WsServer;\n}\n\n/// Parse an address, add the ws:// prefix if needed, and return an error if the address is invalid\npub(crate) fn parse_ws_address(mut address: Address) -> Result<Address, String> {\n    let s = &address.inner;\n    let parts = s.split(\"://\").collect::<Vec<&str>>();\n    let num_parts = parts.len();\n    let url = if num_parts == 2 {\n        if parts[0] == \"ws\" {\n            s.to_owned()\n        } else {\n            return Err(format!(\"Invalid prefix: {}\", parts[0]));\n        }\n    } else if num_parts == 1 {\n        return Err(format!(\"ws://{s}\"));\n    } else {\n        return Err(format!(\"Invalid url: {s}\"));\n    };\n\n    address.inner = url;\n    Ok(address)\n}\n"
  },
  {
    "path": "crates/burn-communication/src/websocket/client.rs",
    "content": "use crate::{\n    base::{Address, CommunicationChannel, CommunicationError, Message, ProtocolClient},\n    websocket::base::parse_ws_address,\n};\nuse burn_std::future::DynFut;\nuse futures::{SinkExt, StreamExt};\nuse tokio::net::TcpStream;\nuse tokio_tungstenite::{\n    MaybeTlsStream, WebSocketStream, connect_async_with_config,\n    tungstenite::{self, protocol::WebSocketConfig},\n};\n\n#[derive(Clone)]\npub struct WsClient;\n\nimpl ProtocolClient for WsClient {\n    type Channel = WsClientChannel;\n    type Error = WsClientError;\n\n    fn connect(address: Address, route: &str) -> DynFut<Option<WsClientChannel>> {\n        Box::pin(connect_ws(address, route.to_owned()))\n    }\n}\n\n/// Open a new WebSocket connection to the address\nasync fn connect_ws(address: Address, route: String) -> Option<WsClientChannel> {\n    let address = parse_ws_address(address).ok()?;\n    let address = format!(\"{address}/{route}\");\n    const MB: usize = 1024 * 1024;\n    let (stream, _) = connect_async_with_config(\n        address.clone(),\n        Some(\n            WebSocketConfig::default()\n                .write_buffer_size(0)\n                .max_message_size(None)\n                .max_frame_size(Some(MB * 512))\n                .accept_unmasked_frames(true)\n                .read_buffer_size(64 * 1024), // 64 KiB (previous default)\n        ),\n        true,\n    )\n    .await\n    .ok()?;\n\n    Some(WsClientChannel { inner: stream })\n}\npub struct WsClientChannel {\n    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,\n}\n\nimpl CommunicationChannel for WsClientChannel {\n    type Error = WsClientError;\n\n    async fn send(&mut self, msg: Message) -> Result<(), WsClientError> {\n        self.inner\n            .send(tungstenite::Message::Binary(msg.data))\n            .await?;\n\n        Ok(())\n    }\n\n    async fn recv(&mut self) -> Result<Option<Message>, WsClientError> {\n        match self.inner.next().await {\n            Some(next) => match next {\n                Ok(tungstenite::Message::Binary(data)) => Ok(Some(Message { data })),\n                Ok(tungstenite::Message::Close(_close_frame)) => Ok(None),\n                Err(err) => Err(WsClientError::Tungstenite(err)),\n                msg => Err(WsClientError::UnknownMessage(format!(\"{msg:?}\"))),\n            },\n            None => todo!(),\n        }\n    }\n\n    async fn close(&mut self) -> Result<(), WsClientError> {\n        let reason = \"Peer is closing\".to_string();\n\n        self.inner\n            .send(tungstenite::Message::Close(Some(\n                tungstenite::protocol::CloseFrame {\n                    code: tungstenite::protocol::frame::coding::CloseCode::Normal,\n                    reason: reason.clone().into(),\n                },\n            )))\n            .await?;\n\n        Ok(())\n    }\n}\n\n#[derive(Debug)]\npub enum WsClientError {\n    Io(std::io::Error),\n    Tungstenite(tungstenite::Error),\n    UnknownMessage(String),\n    Other(String),\n}\nimpl CommunicationError for WsClientError {}\n\nimpl From<std::io::Error> for WsClientError {\n    fn from(err: std::io::Error) -> Self {\n        Self::Io(err)\n    }\n}\n\nimpl From<tungstenite::Error> for WsClientError {\n    fn from(err: tungstenite::Error) -> Self {\n        Self::Tungstenite(err)\n    }\n}\n"
  },
  {
    "path": "crates/burn-communication/src/websocket/mod.rs",
    "content": "mod base;\nmod client;\nmod server;\n\npub use base::*;\npub use client::*;\npub use server::*;\n"
  },
  {
    "path": "crates/burn-communication/src/websocket/server.rs",
    "content": "use std::net::SocketAddr;\n\nuse crate::{\n    base::{CommunicationChannel, CommunicationError, Message, ProtocolServer},\n    util::init_logging,\n};\nuse axum::{\n    Router,\n    extract::{\n        State, WebSocketUpgrade,\n        ws::{self, WebSocket},\n    },\n    routing::get,\n};\nuse futures::StreamExt;\n\n#[derive(Clone, Debug)]\npub struct WsServer {\n    port: u16,\n    router: Router<()>,\n}\n\npub struct WsServerChannel {\n    inner: WebSocket,\n}\n\nimpl WsServer {\n    pub fn new(port: u16) -> Self {\n        Self {\n            port,\n            router: Router::new(),\n        }\n    }\n}\n\nimpl ProtocolServer for WsServer {\n    type Channel = WsServerChannel;\n    type Error = WsServerError;\n\n    async fn serve<F>(self, shutdown: F) -> Result<(), Self::Error>\n    where\n        F: Future<Output = ()> + Send + 'static,\n    {\n        init_logging();\n\n        let address = format!(\"0.0.0.0:{}\", self.port);\n        log::info!(\"Starting server {address}\");\n\n        let listener = tokio::net::TcpListener::bind(address).await?;\n\n        axum::serve(\n            listener,\n            self.router\n                .into_make_service_with_connect_info::<SocketAddr>(),\n        )\n        .with_graceful_shutdown(shutdown)\n        .await?;\n\n        Ok(())\n    }\n\n    fn route<C, Fut>(mut self, path: &str, callback: C) -> Self\n    where\n        C: FnOnce(WsServerChannel) -> Fut + Clone + Send + Sync + 'static,\n        Fut: Future<Output = ()> + Send + 'static,\n    {\n        // Format path: should start with a /\n        let path = if path.starts_with(\"/\") {\n            path.to_owned()\n        } else {\n            format!(\"/{path}\")\n        };\n\n        let method = get(|ws: WebSocketUpgrade, _: State<()>| async {\n            ws.on_upgrade(async move |socket| {\n                callback(WsServerChannel { inner: socket }).await;\n            })\n        });\n\n        self.router = self.router.route(&path, method);\n\n        self\n    }\n}\n\nimpl CommunicationChannel for WsServerChannel {\n    type Error = WsServerError;\n\n    async fn send(&mut self, message: Message) -> Result<(), WsServerError> {\n        self.inner.send(ws::Message::Binary(message.data)).await?;\n\n        Ok(())\n    }\n\n    async fn recv(&mut self) -> Result<Option<Message>, WsServerError> {\n        match self.inner.next().await {\n            Some(next) => match next {\n                Ok(ws::Message::Binary(data)) => Ok(Some(Message { data })),\n                Ok(ws::Message::Close(_close_frame)) => Ok(None),\n                Err(err) => Err(WsServerError::Axum(err)),\n                msg => Err(WsServerError::UnknownMessage(format!(\"{msg:?}\"))),\n            },\n            None => todo!(),\n        }\n    }\n\n    async fn close(&mut self) -> Result<(), WsServerError> {\n        let reason = \"Peer is closing\".to_string();\n\n        self.inner\n            .send(ws::Message::Close(Some(ws::CloseFrame {\n                code: 1000, // code: Normal\n                reason: reason.clone().into(),\n            })))\n            .await?;\n\n        Ok(())\n    }\n}\n\n#[derive(Debug)]\npub enum WsServerError {\n    Io(std::io::Error),\n    Axum(axum::Error),\n    UnknownMessage(String),\n    Other(String),\n}\n\nimpl CommunicationError for WsServerError {}\n\nimpl From<std::io::Error> for WsServerError {\n    fn from(err: std::io::Error) -> Self {\n        Self::Io(err)\n    }\n}\n\nimpl From<axum::Error> for WsServerError {\n    fn from(err: axum::Error) -> Self {\n        Self::Axum(err)\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Flexible and Comprehensive Deep Learning Framework in Rust\"\ndocumentation = \"https://docs.rs/burn-core\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-core\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-core\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"std\",\n    \"burn-std/default\",\n    \"burn-dataset?/default\",\n    \"burn-tensor/default\",\n]\ndoc = [\n    \"std\",\n    \"dataset\",\n    \"audio\",\n    # Doc features\n    \"burn-std/doc\",\n    \"burn-dataset/doc\",\n    \"burn-tensor/doc\",\n]\ntracing = [\n    \"burn-std/tracing\",\n    \"burn-tensor/tracing\",\n    \"burn-dataset?/tracing\",\n    \"burn-vision?/tracing\",\n]\n\n\ndataset = [\"burn-dataset\"]\n\nnetwork = [\"burn-std/network\"]\nsqlite = [\"burn-dataset?/sqlite\"]\nsqlite-bundled = [\"burn-dataset?/sqlite-bundled\"]\nstd = [\n    \"bincode/std\",\n    \"burn-std/std\",\n    \"burn-tensor/std\",\n    \"flate2\",\n    \"half/std\",\n    \"log\",\n    \"rand/std\",\n    \"rmp-serde\",\n    \"serde/std\",\n    \"serde_json/std\",\n    \"num-traits/std\",\n]\nvision = [\"burn-vision\", \"burn-dataset?/vision\"]\naudio = [\"burn-dataset?/audio\"]\n\n# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.\nrecord-item-custom-serde = [\"thiserror\"]\n\ntest-cuda = [\n    \"burn-cuda/default\",\n] # To use cuda during testing, default uses ndarray.\ntest-rocm = [\n    \"burn-rocm/default\",\n] # To use hip during testing, default uses ndarray.\ntest-tch = [\n    \"burn-tch/default\",\n] # To use tch during testing, default uses ndarray.\ntest-wgpu = [\n    \"burn-wgpu/default\",\n] # To use wgpu during testing, default uses ndarray.\ntest-vulkan = [\n    \"test-wgpu\",\n    \"burn-wgpu/vulkan\",\n] # To use wgpu-spirv during testing, default uses ndarray.\ntest-metal = [\n    \"test-wgpu\",\n    \"burn-wgpu/metal\",\n] # To use wgpu-spirv during testing, default uses ndarray.\n\n# Memory checks are disabled by default\ntest-memory-checks = [\"burn-fusion/memory-checks\"]\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std when std is disabled **\n\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-dataset = { path = \"../burn-dataset\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-derive = { path = \"../burn-derive\", version = \"=0.21.0-pre.2\" }\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-vision = { path = \"../burn-vision\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\ndata-encoding = { workspace = true }\nuuid = { workspace = true }\n\nderive-new = { workspace = true }\nlog = { workspace = true, optional = true }\nrand = { workspace = true }\n\n# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)\nhashbrown = { workspace = true, features = [\"serde\"] } # no_std compatible\n\n# Serialize Deserialize\nflate2 = { workspace = true, optional = true }\nserde = { workspace = true, features = [\"derive\"] }\n\nahash = { workspace = true }\nbincode = { workspace = true }\nhalf = { workspace = true }\nnum-traits = { workspace = true }\nrmp-serde = { workspace = true, optional = true }\nserde_json = { workspace = true, features = [\"alloc\"] } #Default enables std\nspin = { workspace = true }                             # Using in place of use std::sync::Mutex when std is disabled\nthiserror = { workspace = true, optional = true }\n\n[target.'cfg(target_has_atomic = \"ptr\")'.dependencies]\nregex = { workspace = true }\n\n# FOR TESTING\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-remote = { path = \"../burn-remote\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\n\n[target.'cfg(not(target_has_atomic = \"ptr\"))'.dependencies]\nportable-atomic-util = { workspace = true }\nportable-atomic = { workspace = true }\n\n[dev-dependencies]\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\" }\nburn-dataset = { path = \"../burn-dataset\", version = \"=0.21.0-pre.2\", features = [\n    \"fake\",\n] }\nrstest = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-core/README.md",
    "content": "# Burn Core\n\nThis crate should be used with [burn](https://github.com/tracel-ai/burn). It contains the core\ntraits and components for building and training deep learning models with Burn.\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-core.svg)](https://crates.io/crates/burn-core)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-core/blob/master/README.md)\n\n## Feature Flags\n\nThis crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the\ndefault `std` feature.\n\n- `std` - enables the standard library. Enabled by default.\n"
  },
  {
    "path": "crates/burn-core/src/config.rs",
    "content": "use alloc::{format, string::String, string::ToString};\npub use burn_derive::Config;\nuse core::fmt::Debug;\n\n/// Configuration IO error.\n#[derive(Debug)]\npub enum ConfigError {\n    /// Invalid format.\n    InvalidFormat(String),\n\n    /// File not found.\n    FileNotFound(String),\n}\n\nimpl core::fmt::Display for ConfigError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        let mut message = \"Config error => \".to_string();\n\n        match self {\n            Self::InvalidFormat(err) => {\n                message += format!(\"Invalid format: {err}\").as_str();\n            }\n            Self::FileNotFound(err) => {\n                message += format!(\"File not found: {err}\").as_str();\n            }\n        };\n\n        f.write_str(message.as_str())\n    }\n}\n\nimpl core::error::Error for ConfigError {}\n\n/// Configuration trait.\npub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned {\n    /// Saves the configuration to a file.\n    ///\n    /// # Arguments\n    ///\n    /// * `file` - File to save the configuration to.\n    ///\n    /// # Returns\n    ///\n    /// The output of the save operation.\n    #[cfg(feature = \"std\")]\n    fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {\n        std::fs::write(file, config_to_json(self))\n    }\n\n    /// Loads the configuration from a file.\n    ///\n    /// # Arguments\n    ///\n    /// * `file` - File to load the configuration from.\n    ///\n    /// # Returns\n    ///\n    /// The loaded configuration.\n    #[cfg(feature = \"std\")]\n    fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {\n        let content = std::fs::read_to_string(file.as_ref())\n            .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;\n        config_from_str(&content)\n    }\n\n    /// Loads the configuration from a binary buffer.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - Binary buffer to load the configuration from.\n    ///\n    /// # Returns\n    ///\n    /// The loaded configuration.\n    fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {\n        let content = core::str::from_utf8(data).map_err(|_| {\n            ConfigError::InvalidFormat(\"Could not parse data as utf-8.\".to_string())\n        })?;\n        config_from_str(content)\n    }\n}\n\n/// Converts a configuration to a JSON string.\n///\n/// # Arguments\n///\n/// * `config` - Configuration to convert.\n///\n/// # Returns\n///\n/// The JSON string.\npub fn config_to_json<C: Config>(config: &C) -> String {\n    serde_json::to_string_pretty(config).unwrap()\n}\n\nfn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {\n    serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!(\"{err}\")))\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/base.rs",
    "content": "use burn_tensor::backend::Backend;\n\npub use crate::data::dataset::{Dataset, DatasetIterator};\nuse core::iter::Iterator;\nuse std::sync::Arc;\n\n/// A progress struct that can be used to track the progress of a data loader.\n#[derive(new, Clone, Debug)]\npub struct Progress {\n    /// The number of items that have been processed.\n    pub items_processed: usize,\n\n    /// The total number of items that need to be processed.\n    pub items_total: usize,\n}\n\n/// A data loader iterator that can be used to iterate over a data loader.\npub trait DataLoaderIterator<O>: Iterator<Item = O> {\n    /// Returns the progress of the data loader.\n    fn progress(&self) -> Progress;\n}\n\n/// A data loader that can be used to iterate over a dataset.\npub trait DataLoader<B: Backend, O>: Send + Sync {\n    /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.\n    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;\n\n    /// The number of items (not the number of batches nor the number of iterations),\n    /// corresponding to the items_total of the progress returned by the iterator.\n    fn num_items(&self) -> usize;\n\n    /// Move the data loader to the given device, ensuring the batches are assigned to the correct device.\n    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>>;\n\n    /// Returns a new data loader containing a subset of the data.\n    ///\n    /// The subset includes items from `start` (inclusive) to `end` (exclusive),\n    /// preserving the batch size and ordering of the original data loader.\n    ///\n    /// # Arguments\n    ///\n    /// * `start` - The starting index of the subset (inclusive).\n    /// * `end` - The ending index of the subset (exclusive).\n    ///\n    /// # Returns\n    ///\n    /// A boxed [`DataLoader`] instance containing only the specified range.\n    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>>;\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/batch.rs",
    "content": "use super::{BatchStrategy, DataLoader, DataLoaderIterator, Progress, batcher::Batcher};\nuse burn_dataset::{\n    Dataset,\n    transform::{PartialDataset, ShuffledDataset},\n};\nuse burn_tensor::backend::Backend;\nuse rand::SeedableRng;\nuse std::ops::DerefMut;\nuse std::sync::Arc;\n\n/// A data loader that can be used to iterate over a dataset in batches.\npub struct BatchDataLoader<B: Backend, I, O> {\n    strategy: Box<dyn BatchStrategy<I>>,\n    dataset: Arc<dyn Dataset<I>>,\n    batcher: Arc<dyn Batcher<B, I, O>>,\n    device: B::Device,\n    rng: Option<Arc<spin::Mutex<rand::rngs::StdRng>>>,\n}\n\nimpl<B: Backend, I, O> Clone for BatchDataLoader<B, I, O> {\n    fn clone(&self) -> Self {\n        Self {\n            strategy: self.strategy.clone_dyn(),\n            dataset: self.dataset.clone(),\n            batcher: self.batcher.clone(),\n            device: self.device.clone(),\n            rng: self.rng.clone(),\n        }\n    }\n}\n\nimpl<B: Backend, I, O> BatchDataLoader<B, I, O> {\n    /// Creates a new batch data loader.\n    ///\n    /// # Arguments\n    ///\n    /// * `strategy` - The batch strategy.\n    /// * `dataset` - The dataset.\n    /// * `batcher` - The batcher.\n    /// * `device`  - The device to use when loading a batch.\n    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader\n    ///   iterator is created.\n    ///\n    /// # Returns\n    ///\n    /// The batch data loader.\n    pub fn new(\n        strategy: Box<dyn BatchStrategy<I>>,\n        dataset: Arc<dyn Dataset<I>>,\n        batcher: Arc<dyn Batcher<B, I, O>>,\n        device: B::Device,\n        rng: Option<rand::rngs::StdRng>,\n    ) -> Self {\n        Self {\n            strategy,\n            dataset,\n            batcher,\n            device,\n            rng: rng.map(|rng| Arc::new(spin::Mutex::new(rng))),\n        }\n    }\n}\n\n/// A data loader iterator that can be used to iterate over a data loader.\nstruct BatchDataloaderIterator<B: Backend, I, O> {\n    current_index: usize,\n    strategy: Box<dyn BatchStrategy<I>>,\n    dataset: Arc<dyn Dataset<I>>,\n    batcher: Arc<dyn Batcher<B, I, O>>,\n    device: B::Device,\n}\n\nimpl<B, I, O> DataLoader<B, O> for BatchDataLoader<B, I, O>\nwhere\n    B: Backend,\n    I: Send + Sync + Clone + 'static,\n    O: Send + 'static,\n{\n    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {\n        // When starting a new iteration, we first check if the dataloader was created with an rng,\n        // implying that we should shuffle the dataset beforehand, while advancing the current\n        // rng to ensure that each new iteration shuffles the dataset differently.\n        let dataset = match &self.rng {\n            Some(rng) => Arc::new(ShuffledDataset::new(\n                self.dataset.clone(),\n                rng.lock().deref_mut(),\n            )),\n            None => self.dataset.clone(),\n        };\n        Box::new(BatchDataloaderIterator::new(\n            self.strategy.clone_dyn(),\n            dataset,\n            self.batcher.clone(),\n            self.device.clone(),\n        ))\n    }\n\n    fn num_items(&self) -> usize {\n        self.dataset.len()\n    }\n\n    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {\n        let rng = self.rng.as_ref().map(|rng| {\n            let mut rng = rng.lock();\n            rng.fork()\n        });\n        Arc::new(Self::new(\n            self.strategy.clone_dyn(),\n            self.dataset.clone(),\n            self.batcher.clone(),\n            device.clone(),\n            rng,\n        ))\n    }\n\n    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {\n        let rng = self.rng.as_ref().map(|rng| {\n            let mut rng = rng.lock();\n            rng.fork()\n        });\n        let dataloader = Self::new(\n            self.strategy.clone_dyn(),\n            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),\n            self.batcher.clone(),\n            self.device.clone(),\n            rng,\n        );\n        Arc::new(dataloader)\n    }\n}\n\nimpl<B: Backend, I, O> BatchDataloaderIterator<B, I, O> {\n    /// Creates a new batch data loader iterator.\n    ///\n    /// # Arguments\n    ///\n    /// * `strategy` - The batch strategy.\n    /// * `dataset` - The dataset.\n    /// * `batcher` - The batcher.\n    /// * `device`  - The device to use when loading a batch.\n    ///\n    /// # Returns\n    ///\n    /// The batch data loader iterator.\n    pub fn new(\n        strategy: Box<dyn BatchStrategy<I>>,\n        dataset: Arc<dyn Dataset<I>>,\n        batcher: Arc<dyn Batcher<B, I, O>>,\n        device: B::Device,\n    ) -> Self {\n        BatchDataloaderIterator {\n            current_index: 0,\n            strategy,\n            dataset,\n            batcher,\n            device,\n        }\n    }\n}\n\nimpl<B: Backend, I, O> Iterator for BatchDataloaderIterator<B, I, O> {\n    type Item = O;\n\n    fn next(&mut self) -> Option<O> {\n        while let Some(item) = self.dataset.get(self.current_index) {\n            self.current_index += 1;\n            self.strategy.add(item);\n\n            if let Some(items) = self.strategy.batch(false) {\n                return Some(self.batcher.batch(items, &self.device));\n            }\n        }\n\n        if let Some(items) = self.strategy.batch(true) {\n            return Some(self.batcher.batch(items, &self.device));\n        }\n\n        None\n    }\n}\n\nimpl<B: Backend, I, O> DataLoaderIterator<O> for BatchDataloaderIterator<B, I, O> {\n    fn progress(&self) -> Progress {\n        Progress::new(self.current_index, self.dataset.len())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::collections::HashSet;\n\n    use super::*;\n    use crate::data::dataloader::FixBatchStrategy;\n    use crate::data::dataloader::batcher::TestBatcher;\n    use crate::data::dataset::FakeDataset;\n\n    #[test]\n    fn test_batch_dataloader() {\n        let batcher = Arc::new(TestBatcher::new());\n        let dataset = Arc::new(FakeDataset::<String>::new(27));\n        let dataloader = BatchDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset.clone(),\n            batcher,\n            Default::default(),\n            None,\n        );\n\n        let mut items_dataset = HashSet::new();\n        let mut items_dataloader = HashSet::new();\n\n        for item in dataset.iter() {\n            items_dataset.insert(item);\n        }\n\n        for items in dataloader.iter() {\n            for item in items {\n                items_dataloader.insert(item);\n            }\n        }\n\n        assert_eq!(items_dataset, items_dataloader);\n    }\n\n    #[test]\n    fn test_batch_dataloader_slice() {\n        let batcher = Arc::new(TestBatcher::new());\n        let dataset = Arc::new(FakeDataset::<String>::new(27));\n        let dataloader = BatchDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset.clone(),\n            batcher,\n            Default::default(),\n            None,\n        );\n        let dataloader_slice = dataloader.slice(5, 15);\n\n        let mut items_dataloader = HashSet::new();\n        let mut items_dataloader_slice = HashSet::new();\n\n        let mut idx = 0;\n        for items in dataloader.iter() {\n            for item in items {\n                if (5..15).contains(&idx) {\n                    items_dataloader.insert(item);\n                }\n                idx += 1;\n            }\n        }\n\n        for items in dataloader_slice.iter() {\n            for item in items {\n                items_dataloader_slice.insert(item);\n            }\n        }\n\n        assert_eq!(items_dataloader, items_dataloader_slice);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/batcher.rs",
    "content": "use burn_tensor::backend::Backend;\n\n#[cfg(test)]\nuse crate::TestBackend;\n\n/// A trait for batching items of type `I` into items of type `O`.\npub trait Batcher<B: Backend, I, O>: Send + Sync {\n    /// Batches the given items on the specified device.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - The items to batch.\n    /// * `device` - The backend device to use.\n    ///\n    /// # Returns\n    ///\n    /// The batched items.\n    fn batch(&self, items: Vec<I>, device: &B::Device) -> O;\n}\n\n/// Test batcher\n#[cfg(test)]\n#[derive(new, Clone)]\npub struct TestBatcher;\n\n#[cfg(test)]\nimpl<I> Batcher<TestBackend, I, Vec<I>> for TestBatcher {\n    fn batch(&self, items: Vec<I>, _device: &<TestBackend as Backend>::Device) -> Vec<I> {\n        items\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/builder.rs",
    "content": "use super::{\n    BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,\n    batcher::Batcher,\n};\nuse burn_dataset::Dataset;\nuse burn_tensor::backend::Backend;\nuse rand::{SeedableRng, rngs::StdRng};\nuse std::sync::Arc;\n\n/// A builder for data loaders.\npub struct DataLoaderBuilder<B: Backend, I, O> {\n    strategy: Option<Box<dyn BatchStrategy<I>>>,\n    batcher: Arc<dyn Batcher<B, I, O>>,\n    num_threads: Option<usize>,\n    shuffle: Option<u64>,\n    device: Option<B::Device>,\n}\n\nimpl<B, I, O> DataLoaderBuilder<B, I, O>\nwhere\n    B: Backend,\n    I: Send + Sync + Clone + std::fmt::Debug + 'static,\n    O: Send + Clone + std::fmt::Debug + 'static,\n{\n    /// Creates a new data loader builder.\n    ///\n    /// # Arguments\n    ///\n    /// * `batcher` - The batcher.\n    ///\n    /// # Returns\n    ///\n    /// The data loader builder.\n    pub fn new<Bt>(batcher: Bt) -> Self\n    where\n        Bt: Batcher<B, I, O> + 'static,\n    {\n        Self {\n            batcher: Arc::new(batcher),\n            strategy: None,\n            num_threads: None,\n            shuffle: None,\n            device: None,\n        }\n    }\n\n    /// Sets the batch size to a fix number.\n    ///\n    /// The [fix batch strategy](FixBatchStrategy) will be used.\n    ///\n    /// # Arguments\n    ///\n    /// * `batch_size` - The batch size.\n    ///\n    /// # Returns\n    ///\n    /// The data loader builder.\n    pub fn batch_size(mut self, batch_size: usize) -> Self {\n        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));\n        self\n    }\n\n    /// Sets the seed for shuffling.\n    ///\n    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.\n    ///\n    /// # Arguments\n    ///\n    /// * `seed` - The seed.\n    ///\n    /// # Returns\n    ///\n    /// The data loader builder.\n    pub fn shuffle(mut self, seed: u64) -> Self {\n        self.shuffle = Some(seed);\n        self\n    }\n\n    /// Sets the number of workers.\n    ///\n    /// - `Some(0)` or `None`: the dataloader will run without work threads.\n    /// - `Some(n); n > 0`: the dataloader will run with `n` background threads.\n    ///\n    /// A 1-worker threaded dataloader will run loads in a background thread,\n    /// while a 0-worker threaded dataloader will run loads in the main thread.\n    ///\n    /// # Arguments\n    ///\n    /// * `num_workers` - The number of workers.\n    ///\n    /// # Returns\n    ///\n    /// The data loader builder.\n    pub fn num_workers(mut self, num_workers: usize) -> Self {\n        self.num_threads = Some(num_workers);\n        self\n    }\n\n    /// Sets the data loader device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - The device to use when loading a batch.\n    ///\n    /// # Returns\n    ///\n    /// The data loader builder.\n    pub fn set_device(mut self, device: B::Device) -> Self {\n        self.device = Some(device);\n        self\n    }\n\n    /// Builds the data loader.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The dataset.\n    ///\n    /// # Returns\n    ///\n    /// The data loader.\n    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>\n    where\n        D: Dataset<I> + 'static,\n    {\n        let dataset = Arc::new(dataset);\n\n        let device = self.device.unwrap_or_default();\n        let rng = self.shuffle.map(StdRng::seed_from_u64);\n        let strategy = match self.strategy {\n            Some(strategy) => strategy,\n            None => Box::new(FixBatchStrategy::new(1)),\n        };\n\n        if let Some(num_threads) = self.num_threads\n            && num_threads > 0\n        {\n            return Arc::new(MultiThreadDataLoader::new(\n                strategy,\n                dataset,\n                self.batcher,\n                num_threads,\n                device,\n                rng,\n            ));\n        }\n\n        Arc::new(BatchDataLoader::new(\n            strategy,\n            dataset,\n            self.batcher,\n            device,\n            rng,\n        ))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use crate::data::dataset::FakeDataset;\n\n    #[derive(new, Clone)]\n    struct TestBatcherDevice;\n\n    #[cfg(test)]\n    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {\n        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {\n            *device\n        }\n    }\n\n    type TestDevice = <TestBackend as Backend>::Device;\n\n    #[test]\n    fn test_dataloader_no_workers() {\n        type TestDevice = <TestBackend as Backend>::Device;\n\n        let default_device = TestDevice::default();\n        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())\n            .batch_size(1)\n            .build(FakeDataset::<String>::new(9));\n\n        assert_eq!(dataloader.num_items(), 9);\n\n        for device in dataloader.iter() {\n            assert_eq!(device, default_device)\n        }\n    }\n\n    #[test]\n    fn test_dataloader_default_device() {\n        let default_device = TestDevice::default();\n        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())\n            .batch_size(1)\n            .num_workers(1)\n            .build(FakeDataset::<String>::new(9));\n\n        assert_eq!(dataloader.num_items(), 9);\n\n        for device in dataloader.iter() {\n            assert_eq!(device, default_device)\n        }\n    }\n\n    #[test]\n    fn test_dataloader_slice_multi_device() {\n        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())\n            .batch_size(1)\n            .num_workers(1)\n            .build(FakeDataset::<String>::new(11));\n\n        #[cfg(all(\n            test,\n            not(feature = \"test-tch\"),\n            not(feature = \"test-wgpu\"),\n            not(feature = \"test-cuda\")\n        ))]\n        // Only one device exists...\n        let (device1, device2) = (\n            burn_ndarray::NdArrayDevice::Cpu,\n            burn_ndarray::NdArrayDevice::Cpu,\n        );\n\n        #[cfg(all(test, feature = \"test-tch\"))]\n        let (device1, device2) = (\n            burn_tch::LibTorchDevice::Cuda(0),\n            burn_tch::LibTorchDevice::Cuda(1),\n        );\n\n        #[cfg(all(test, feature = \"test-wgpu\"))]\n        let (device1, device2) = (\n            burn_wgpu::WgpuDevice::DiscreteGpu(0),\n            burn_wgpu::WgpuDevice::DiscreteGpu(1),\n        );\n\n        #[cfg(all(test, feature = \"test-cuda\"))]\n        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));\n\n        assert_eq!(dataloader.num_items(), 11);\n        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);\n        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);\n\n        assert_eq!(dataloader_1.num_items(), 5);\n        assert_eq!(dataloader_2.num_items(), 6);\n\n        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());\n\n        for _ in 0..5 {\n            assert_eq!(iterator_1.next(), Some(device1));\n            assert_eq!(iterator_2.next(), Some(device2));\n        }\n\n        assert_eq!(iterator_1.next(), None);\n        // For uneven split, the last dataloader (partial dataset) will have the remaining item\n        assert_eq!(iterator_2.next(), Some(device2));\n        assert_eq!(iterator_2.next(), None);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/mod.rs",
    "content": "mod base;\nmod batch;\nmod builder;\nmod multithread;\nmod strategy;\n\n/// Module for batching items.\npub mod batcher;\n/// Module to split a dataloader.\npub mod split;\n\npub use base::*;\npub use batch::*;\npub use builder::*;\npub use multithread::*;\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/multithread.rs",
    "content": "use burn_dataset::Dataset;\nuse burn_dataset::transform::PartialDataset;\nuse burn_tensor::backend::Backend;\nuse rand::distr::{Distribution, StandardUniform};\nuse rand::rngs::StdRng;\nuse rand::{Rng, SeedableRng};\n\nuse super::batcher::Batcher;\nuse super::{BatchDataLoader, BatchStrategy, DataLoader, DataLoaderIterator, Progress};\nuse std::sync::{Arc, OnceLock, mpsc};\nuse std::thread;\n\nconst MAX_QUEUED_ITEMS: usize = 100;\n\ntype RngSeed = <StdRng as SeedableRng>::Seed;\n\n/// A multi-threaded data loader that can be used to iterate over a dataset.\npub struct MultiThreadDataLoader<B: Backend, I, O> {\n    // Configuration parameters needed for initialization\n    strategy: Box<dyn BatchStrategy<I>>,\n    dataset: Arc<dyn Dataset<I>>,\n    batcher: Arc<dyn Batcher<B, I, O>>,\n    device: B::Device,\n    seed: Option<RngSeed>,\n    num_threads: usize,\n\n    // The lazily initialized data loaders\n    dataloaders: OnceLock<Vec<BatchDataLoader<B, I, O>>>,\n}\n\n/// A message that can be sent between threads.\n#[derive(Debug)]\npub enum Message<O> {\n    /// A batch of items.\n    Batch(usize, O, Progress),\n\n    /// The thread is done.\n    Done,\n}\n\nstruct MultiThreadsDataloaderIterator<O> {\n    num_done: usize,\n    workers: Vec<thread::JoinHandle<()>>,\n    receiver: mpsc::Receiver<Message<O>>,\n    progresses: Vec<Progress>,\n}\n\nimpl<B: Backend, I, O> MultiThreadDataLoader<B, I, O>\nwhere\n    I: Send + Sync + Clone + 'static,\n    O: Send + 'static,\n{\n    /// Creates a new multi-threaded batch data loader.\n    ///\n    /// # Arguments\n    ///\n    /// * `strategy` - The batch strategy.\n    /// * `dataset` - The dataset.\n    /// * `batcher` - The batcher.\n    /// * `num_threads` - The number of threads.\n    /// * `device`  - The device to use when loading a batch.\n    /// * `rng`     - The rng determining if the dataset is shuffled each time a dataloader\n    ///   iterator is created.\n    ///\n    /// # Returns\n    ///\n    /// The multi-threaded batch data loader.\n    pub fn new(\n        strategy: Box<dyn BatchStrategy<I>>,\n        dataset: Arc<dyn Dataset<I>>,\n        batcher: Arc<dyn Batcher<B, I, O>>,\n        num_threads: usize,\n        device: B::Device,\n        rng: Option<rand::rngs::StdRng>,\n    ) -> Self {\n        let mut seed = None;\n        if let Some(mut rng) = rng {\n            // RNG stream splitting (not state cloning): derive a new seed from the RNG's output.\n            // This is exactly what `rng.fork()` does.\n            let mut s = RngSeed::default();\n            rng.fill_bytes(&mut s);\n\n            seed = Some(s);\n        }\n        Self::from_seed(strategy, dataset, batcher, num_threads, device, seed)\n    }\n\n    fn from_seed(\n        strategy: Box<dyn BatchStrategy<I>>,\n        dataset: Arc<dyn Dataset<I>>,\n        batcher: Arc<dyn Batcher<B, I, O>>,\n        num_threads: usize,\n        device: B::Device,\n        seed: Option<RngSeed>,\n    ) -> Self {\n        Self {\n            strategy,\n            dataset,\n            batcher,\n            num_threads,\n            device,\n            seed,\n            dataloaders: OnceLock::new(),\n        }\n    }\n\n    /// Force initialization if needed.\n    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {\n        self.dataloaders\n            .get_or_init(|| {\n                let mut dataset = self.dataset.clone();\n                if let Some(seed) = self.seed.as_ref() {\n                    // Pre-shuffle the dataset before split if shuffle is enabled.\n                    // This ensures that each thread gets a uniform random sample of the dataset.\n                    let mut rng = StdRng::from_seed(*seed);\n                    dataset = Arc::new(burn_dataset::transform::ShuffledDataset::new(\n                        dataset, &mut rng,\n                    ));\n                }\n\n                let datasets = match self.strategy.batch_size() {\n                    Some(batch_size) => {\n                        PartialDataset::split_chunks(dataset, self.num_threads, batch_size)\n                    }\n                    None => PartialDataset::split(dataset, self.num_threads),\n                };\n\n                // Create more rngs from the first one, one for each new dataloader.\n                let mut rng = self.seed.map(StdRng::from_seed);\n                let rngs = (0..self.num_threads).map(|_| {\n                    rng.as_mut().map(|rng| {\n                        StdRng::seed_from_u64(Distribution::sample(&StandardUniform, rng))\n                    })\n                });\n\n                datasets\n                    .into_iter()\n                    .zip(rngs)\n                    .map(|(dataset, rng)| {\n                        let strategy = self.strategy.clone_dyn();\n                        BatchDataLoader::new(\n                            strategy,\n                            Arc::new(dataset),\n                            self.batcher.clone(),\n                            self.device.clone(),\n                            rng,\n                        )\n                    })\n                    .collect()\n            })\n            .as_ref()\n    }\n}\n\nimpl<B: Backend, I, O> DataLoader<B, O> for MultiThreadDataLoader<B, I, O>\nwhere\n    I: Send + Sync + Clone + 'static,\n    O: Send + 'static + std::fmt::Debug,\n{\n    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a> {\n        // This will initialize the loader if it hasn't been initialized yet\n        let dataloaders = self.initialize();\n\n        let (sender, receiver) = mpsc::sync_channel::<Message<O>>(MAX_QUEUED_ITEMS);\n\n        let mut progresses = Vec::with_capacity(dataloaders.len());\n\n        let handlers: Vec<_> = dataloaders\n            .iter()\n            .enumerate()\n            .map(|(index, dataloader)| {\n                let dataloader_cloned = dataloader.clone();\n                let sender_cloned = sender.clone();\n                progresses.push(Progress::new(0, dataloader_cloned.num_items()));\n\n                std::thread::Builder::new()\n                    .name(std::format!(\"dataloader-{index}\"))\n                    .spawn(move || {\n                        let mut iterator = dataloader_cloned.iter();\n                        while let Some(item) = iterator.next() {\n                            let progress = iterator.progress();\n\n                            match sender_cloned.send(Message::Batch(index, item, progress)) {\n                                Ok(_) => {}\n                                // The receiver is probably gone, no need to panic, just need to stop\n                                // iterating.\n                                Err(_) => return,\n                            };\n                        }\n                        // Same thing.\n                        sender_cloned.send(Message::Done).ok();\n                    })\n                    .unwrap()\n            })\n            .collect();\n\n        Box::new(MultiThreadsDataloaderIterator::new(\n            receiver, handlers, progresses,\n        ))\n    }\n\n    fn num_items(&self) -> usize {\n        // For num_items, we can directly use the dataset size without\n        // necessarily initializing the full loader\n        self.dataset.len()\n    }\n\n    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>> {\n        Arc::new(Self::from_seed(\n            self.strategy.clone_dyn(),\n            self.dataset.clone(),\n            self.batcher.clone(),\n            self.num_threads,\n            device.clone(),\n            self.seed,\n        ))\n    }\n\n    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>> {\n        let dataloader = Self::from_seed(\n            self.strategy.clone_dyn(),\n            Arc::new(PartialDataset::new(self.dataset.clone(), start, end)),\n            self.batcher.clone(),\n            self.num_threads,\n            self.device.clone(),\n            self.seed,\n        );\n        Arc::new(dataloader)\n    }\n}\n\nimpl<O> MultiThreadsDataloaderIterator<O> {\n    pub fn new(\n        receiver: mpsc::Receiver<Message<O>>,\n        workers: Vec<thread::JoinHandle<()>>,\n        progresses: Vec<Progress>,\n    ) -> Self {\n        MultiThreadsDataloaderIterator {\n            num_done: 0,\n            workers,\n            receiver,\n            progresses,\n        }\n    }\n}\nimpl<O: std::fmt::Debug> DataLoaderIterator<O> for MultiThreadsDataloaderIterator<O> {\n    fn progress(&self) -> Progress {\n        let mut items_total = 0;\n        let mut items_processed = 0;\n\n        for progress in self.progresses.iter() {\n            items_total += progress.items_total;\n            items_processed += progress.items_processed;\n        }\n\n        Progress::new(items_processed, items_total)\n    }\n}\n\nimpl<O: std::fmt::Debug> Iterator for MultiThreadsDataloaderIterator<O> {\n    type Item = O;\n\n    fn next(&mut self) -> Option<O> {\n        if self.workers.is_empty() {\n            return None;\n        }\n\n        loop {\n            let item = self.receiver.recv();\n            let item = item.unwrap();\n\n            match item {\n                Message::Batch(index, item, progress) => {\n                    if let Some(current) = self.progresses.get_mut(index) {\n                        *current = progress;\n                    }\n                    return Some(item);\n                }\n                Message::Done => {\n                    self.num_done += 1;\n                }\n            };\n\n            if self.num_done == self.workers.len() {\n                while let Some(worker) = self.workers.pop() {\n                    worker.join().unwrap();\n                }\n                return None;\n            }\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::data::dataloader::FixBatchStrategy;\n    use crate::data::dataloader::batcher::TestBatcher;\n    use crate::data::dataset::FakeDataset;\n    use burn_dataset::InMemDataset;\n    use std::collections::HashSet;\n\n    #[test]\n    fn test_multi_thread_batch_dataloader() {\n        let batcher = Arc::new(TestBatcher::new());\n        let dataset = Arc::new(FakeDataset::<String>::new(27));\n        let dataloader_single_thread = BatchDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset.clone(),\n            batcher.clone(),\n            Default::default(),\n            None,\n        );\n        let dataloader_multi_thread = MultiThreadDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset,\n            batcher,\n            4,\n            Default::default(),\n            None,\n        );\n\n        let mut items_single_thread = HashSet::new();\n        let mut items_multi_thread = HashSet::new();\n\n        for items in dataloader_single_thread.iter() {\n            for item in items {\n                items_single_thread.insert(item);\n            }\n        }\n\n        for items in dataloader_multi_thread.iter() {\n            for item in items {\n                items_multi_thread.insert(item);\n            }\n        }\n\n        assert_eq!(items_single_thread, items_multi_thread);\n    }\n\n    #[test]\n    fn test_multi_thread_batch_dataloader_shuffle() {\n        let num_classes = 2;\n        let class_size = 100;\n        let batch_size = 10;\n\n        // Items is a deliberately ordered dataset.\n        let mut items = Vec::new();\n        for class in 0..num_classes {\n            items.extend(vec![class; class_size]);\n        }\n\n        {\n            // Unshuffled multithreaded loader\n            let dataset = Arc::new(InMemDataset::new(items.clone()));\n            let batcher = Arc::new(TestBatcher::new());\n\n            let loader = MultiThreadDataLoader::new(\n                Box::new(FixBatchStrategy::new(batch_size)),\n                dataset,\n                batcher,\n                num_classes,\n                Default::default(),\n                // No rng means no shuffling.\n                None,\n            );\n\n            for batch in loader.iter() {\n                let mut batch_items = HashSet::new();\n                for item in batch {\n                    batch_items.insert(item);\n                }\n\n                // Since the dataset is not shuffled, we expect each batch to contain the same item.\n                assert_eq!(batch_items.len(), 1);\n            }\n        }\n\n        {\n            // Shuffled multithreaded loader\n            let dataset = Arc::new(InMemDataset::new(items.clone()));\n            let batcher = Arc::new(TestBatcher::new());\n\n            let loader = MultiThreadDataLoader::new(\n                Box::new(FixBatchStrategy::new(batch_size)),\n                dataset.clone(),\n                batcher.clone(),\n                num_classes,\n                Default::default(),\n                // The rng enables shuffling.\n                Some(StdRng::seed_from_u64(42)),\n            );\n\n            for batch in loader.iter() {\n                let mut batch_items = HashSet::new();\n                for item in batch {\n                    batch_items.insert(item);\n                }\n\n                // Since the dataset is shuffled, we expect to see all items.\n                assert_eq!(batch_items.len(), num_classes);\n            }\n        }\n    }\n\n    #[test]\n    fn test_multi_thread_batch_dataloader_incomplete_batches() {\n        let batcher = Arc::new(TestBatcher::new());\n        let dataset = Arc::new(FakeDataset::<String>::new(27));\n        let dataloader_single_thread = BatchDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset.clone(),\n            batcher.clone(),\n            Default::default(),\n            None,\n        );\n        let dataloader_multi_thread = MultiThreadDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset,\n            batcher,\n            4,\n            Default::default(),\n            None,\n        );\n\n        let mut items_single_thread = HashSet::new();\n        let mut items_multi_thread = HashSet::new();\n\n        let mut single_thread_cnt = 0;\n        let mut multi_thread_cnt = 0;\n        for items in dataloader_single_thread.iter() {\n            items_single_thread.insert(items);\n            single_thread_cnt += 1;\n        }\n\n        for items in dataloader_multi_thread.iter() {\n            items_multi_thread.insert(items);\n            multi_thread_cnt += 1;\n        }\n\n        assert_eq!(single_thread_cnt, multi_thread_cnt);\n        assert_eq!(items_single_thread, items_multi_thread);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/split.rs",
    "content": "use std::sync::Arc;\n\nuse burn_tensor::backend::Backend;\n\nuse super::DataLoader;\n\n/// Splits a dataloader into multiple partial dataloaders (one per device).\npub fn split_dataloader<B: Backend, O>(\n    dataloader: Arc<dyn DataLoader<B, O>>,\n    devices: &[B::Device],\n) -> Vec<Arc<dyn DataLoader<B, O>>> {\n    let num_splits = devices.len();\n    if num_splits > 1 {\n        let num_items = dataloader.num_items();\n        let mut dataloaders = Vec::with_capacity(num_splits);\n\n        let mut start = 0;\n        let step = num_items / num_splits;\n        for (i, device) in devices.iter().enumerate() {\n            let end = if i == (num_splits - 1) {\n                num_items\n            } else {\n                start + step\n            };\n            let dataloader = dataloader.slice(start, end).to_device(device);\n            dataloaders.push(dataloader);\n            start = end;\n        }\n        dataloaders\n    } else {\n        vec![dataloader]\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::collections::HashSet;\n\n    use super::*;\n    use crate::TestBackend;\n    use crate::data::dataloader::batcher::Batcher;\n    use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};\n    use crate::data::dataset::FakeDataset;\n\n    #[test]\n    fn test_split_batch_dataloader() {\n        type TestDevice = <TestBackend as Backend>::Device;\n\n        #[derive(new, Clone)]\n        pub struct TestBatcher;\n\n        #[cfg(test)]\n        impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {\n            fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {\n                (items, *device)\n            }\n        }\n\n        let batcher = Arc::new(TestBatcher::new());\n        let dataset = Arc::new(FakeDataset::<String>::new(11));\n\n        #[allow(clippy::arc_with_non_send_sync)]\n        let dataloader = Arc::new(BatchDataLoader::new(\n            Box::new(FixBatchStrategy::new(5)),\n            dataset.clone(),\n            batcher,\n            Default::default(),\n            None,\n        ));\n\n        #[cfg(all(\n            test,\n            not(feature = \"test-tch\"),\n            not(feature = \"test-wgpu\"),\n            not(feature = \"test-cuda\")\n        ))]\n        // Only one device exists...\n        let (device1, device2) = (\n            burn_ndarray::NdArrayDevice::Cpu,\n            burn_ndarray::NdArrayDevice::Cpu,\n        );\n\n        #[cfg(all(test, feature = \"test-tch\"))]\n        let (device1, device2) = (\n            burn_tch::LibTorchDevice::Cuda(0),\n            burn_tch::LibTorchDevice::Cuda(1),\n        );\n\n        #[cfg(all(test, feature = \"test-wgpu\"))]\n        let (device1, device2) = (\n            burn_wgpu::WgpuDevice::DiscreteGpu(0),\n            burn_wgpu::WgpuDevice::DiscreteGpu(1),\n        );\n\n        #[cfg(all(test, feature = \"test-cuda\"))]\n        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));\n\n        let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]);\n\n        assert_eq!(dataloaders.len(), 2);\n\n        let [dataloader_1, dataloader_2] = match dataloaders.try_into() {\n            Ok(arr) => arr,\n            Err(_) => unreachable!(),\n        };\n        assert_eq!(dataloader_1.num_items(), 5);\n        assert_eq!(dataloader_2.num_items(), 6);\n\n        let mut items_dataloader = HashSet::new();\n        let mut items_dataloader_split = HashSet::new();\n\n        for (items, _device) in dataloader.iter() {\n            for item in items {\n                items_dataloader.insert(item);\n            }\n        }\n\n        for (items, device) in dataloader_1.iter() {\n            assert_eq!(device, device1);\n            for item in items {\n                items_dataloader_split.insert(item);\n            }\n        }\n\n        for (items, device) in dataloader_2.iter() {\n            assert_eq!(device, device2);\n            for item in items {\n                items_dataloader_split.insert(item);\n            }\n        }\n\n        assert_eq!(items_dataloader, items_dataloader_split);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/dataloader/strategy.rs",
    "content": "/// A strategy to batch items.\npub trait BatchStrategy<I>: Send + Sync {\n    /// Adds an item to the strategy.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The item to add.\n    fn add(&mut self, item: I);\n\n    /// Batches the items.\n    ///\n    /// # Arguments\n    ///\n    /// * `force` - Whether to force batching.\n    ///\n    /// # Returns\n    ///\n    /// The batched items.\n    fn batch(&mut self, force: bool) -> Option<Vec<I>>;\n\n    /// Creates a new strategy of the same type.\n    ///\n    /// # Returns\n    ///\n    /// The new strategy.\n    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>>;\n\n    /// Returns the expected batch size for this strategy.\n    ///\n    /// # Returns\n    ///\n    /// The batch size, or None if the strategy doesn't have a fixed batch size.\n    fn batch_size(&self) -> Option<usize>;\n}\n\n/// A strategy to batch items with a fixed batch size.\npub struct FixBatchStrategy<I> {\n    items: Vec<I>,\n    batch_size: usize,\n}\n\nimpl<I> FixBatchStrategy<I> {\n    /// Creates a new strategy to batch items with a fixed batch size.\n    ///\n    /// # Arguments\n    ///\n    /// * `batch_size` - The batch size.\n    ///\n    /// # Returns\n    ///\n    /// The strategy.\n    pub fn new(batch_size: usize) -> Self {\n        FixBatchStrategy {\n            items: Vec::with_capacity(batch_size),\n            batch_size,\n        }\n    }\n}\n\nimpl<I: Send + Sync + 'static> BatchStrategy<I> for FixBatchStrategy<I> {\n    fn add(&mut self, item: I) {\n        self.items.push(item);\n    }\n\n    fn batch(&mut self, force: bool) -> Option<Vec<I>> {\n        if self.items.len() < self.batch_size && !force {\n            return None;\n        }\n\n        let mut items = Vec::with_capacity(self.batch_size);\n        std::mem::swap(&mut items, &mut self.items);\n\n        if items.is_empty() {\n            return None;\n        }\n\n        Some(items)\n    }\n\n    fn clone_dyn(&self) -> Box<dyn BatchStrategy<I>> {\n        Box::new(Self::new(self.batch_size))\n    }\n\n    fn batch_size(&self) -> Option<usize> {\n        Some(self.batch_size)\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/data/mod.rs",
    "content": "/// Dataloader module.\n#[cfg(feature = \"dataset\")]\npub mod dataloader;\n\n/// Dataset module.\n#[cfg(feature = \"dataset\")]\npub mod dataset {\n    pub use burn_dataset::*;\n}\n\n/// Network module.\n#[cfg(feature = \"network\")]\npub mod network {\n    pub use burn_std::network::*;\n}\n"
  },
  {
    "path": "crates/burn-core/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![recursion_limit = \"135\"]\n\n//! The core crate of Burn.\n\n#[macro_use]\nextern crate derive_new;\n\n/// Re-export serde for proc macros.\npub use serde;\n\n/// The configuration module.\npub mod config;\n\n/// Data module.\n#[cfg(feature = \"std\")]\npub mod data;\n\n/// Module for the neural network module.\npub mod module;\n\n/// Module for the recorder.\npub mod record;\n\n/// Module for the tensor.\npub mod tensor;\n// Tensor at root: `burn::Tensor`\npub use tensor::Tensor;\n\n/// Module for visual operations\n#[cfg(feature = \"vision\")]\npub mod vision;\n\nextern crate alloc;\n\n/// Backend for test cases\n#[cfg(all(\n    test,\n    not(feature = \"test-tch\"),\n    not(feature = \"test-wgpu\"),\n    not(feature = \"test-cuda\"),\n    not(feature = \"test-rocm\")\n))]\npub type TestBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(all(test, feature = \"test-tch\"))]\n/// Backend for test cases\npub type TestBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(all(test, feature = \"test-wgpu\"))]\n/// Backend for test cases\npub type TestBackend = burn_wgpu::Wgpu;\n\n#[cfg(all(test, feature = \"test-cuda\"))]\n/// Backend for test cases\npub type TestBackend = burn_cuda::Cuda;\n\n#[cfg(all(test, feature = \"test-rocm\"))]\n/// Backend for test cases\npub type TestBackend = burn_rocm::Rocm;\n\n/// Backend for autodiff test cases\n#[cfg(test)]\npub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;\n\n#[cfg(all(test, feature = \"test-memory-checks\"))]\nmod tests {\n    burn_fusion::memory_checks!();\n}\n\n#[cfg(test)]\nmod test_utils {\n    use crate as burn;\n    use crate::module::Module;\n    use crate::module::Param;\n    use burn_tensor::Tensor;\n    use burn_tensor::backend::Backend;\n\n    /// Simple linear module.\n    #[derive(Module, Debug)]\n    pub struct SimpleLinear<B: Backend> {\n        pub weight: Param<Tensor<B, 2>>,\n        pub bias: Option<Param<Tensor<B, 1>>>,\n    }\n\n    impl<B: Backend> SimpleLinear<B> {\n        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {\n            let weight = Tensor::random(\n                [out_features, in_features],\n                burn_tensor::Distribution::Default,\n                device,\n            );\n            let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);\n\n            Self {\n                weight: Param::from_tensor(weight),\n                bias: Some(Param::from_tensor(bias)),\n            }\n        }\n    }\n}\n\npub mod prelude {\n    //! Structs and macros used by most projects. Add `use\n    //! burn::prelude::*` to your code to quickly get started with\n    //! Burn.\n    pub use crate::{\n        config::Config,\n        module::Module,\n        tensor::{\n            Bool, Device, ElementConversion, Float, Int, Shape, SliceArg, Tensor, TensorData,\n            backend::Backend, cast::ToElement, s,\n        },\n    };\n    pub use burn_std::device::Device as DeviceOps;\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/base.rs",
    "content": "use super::{Param, ParamId, Quantizer};\nuse crate::{\n    record::Record,\n    tensor::backend::{AutodiffBackend, Backend},\n};\nuse alloc::{string::String, vec::Vec};\npub use burn_derive::Module;\nuse burn_tensor::{Bool, Int, Tensor, ops::Device};\n\n/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using\n/// the `alloc` crate.\npub type Devices<B> = Vec<Device<B>>;\n\n// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.\n// We may consider making it public in the future.\nmacro_rules! module {\n    (map=$module:ident, ops=$item:expr) => {{\n        struct Mapper;\n        impl<B: Backend> ModuleMapper<B> for Mapper {\n            fn map_float<const D: usize>(\n                &mut self,\n                param: Param<Tensor<B, D>>,\n            ) -> Param<Tensor<B, D>> {\n                let (id, tensor, mapper) = param.consume();\n                let func = $item;\n                let tensor = func(tensor);\n                Param::from_mapped_value(id, tensor, mapper)\n            }\n        }\n        let mut mapper = Mapper;\n        $module.map(&mut mapper)\n    }};\n    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{\n        struct Visitor<'a, B: Backend> {\n            state: &'a mut $state_ty,\n            backend: core::marker::PhantomData<B>,\n        }\n        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {\n            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n                let func = $item;\n                func(&param.val(), &mut self.state)\n            }\n        }\n        #[allow(clippy::redundant_closure_call)]\n        let mut state = $init();\n        let mut visitor = Visitor {\n            state: &mut state,\n            backend: core::marker::PhantomData,\n        };\n        $module.visit(&mut visitor);\n        state\n    }};\n}\n\n/// Trait for all neural network modules.\n///\n/// Modules should be created using the [derive](burn_derive::Module) attribute.\n/// This will make your module trainable, savable and loadable via\n/// `state` and `load`.\n///\n/// # Example\n///\n/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic\n/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code\n/// necessary to optimize and train the module on any backend.\n///\n/// ```rust, ignore\n/// // Not necessary when using the burn crate directly.\n/// use burn_core as burn;\n///\n/// use burn::{\n///     module::Module,\n///     nn::Linear,\n///     tensor::Tensor,\n///     tensor::backend::Backend,\n/// };\n///\n/// #[derive(Module, Debug)]\n/// struct MyModule<B: Backend> {\n///   my_param: Linear<B>,\n///   my_other_field: usize,\n/// }\n/// ```\npub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {\n    /// Type to save and load the module.\n    type Record: Record<B>;\n\n    /// Return all the devices found in the underneath module tree added to the given vector\n    /// without duplicates.\n    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;\n\n    /// Return all the devices found in the underneath module tree without duplicates.\n    fn devices(&self) -> Devices<B> {\n        self.collect_devices(Devices::<B>::new())\n    }\n\n    /// Fork the module and all of its sub-modules to the given device.\n    ///\n    /// # Notes\n    ///\n    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the\n    /// new device will have its own autodiff graph.\n    fn fork(self, device: &B::Device) -> Self;\n\n    /// Move the module and all of its sub-modules to the given device.\n    ///\n    /// # Warnings\n    ///\n    /// The operation supports autodiff and it will be registered when activated. However, this may\n    /// not be what you want. The output model will be an intermediary model, meaning that you\n    /// can't optimize it with gradient descent. If you want to optimize the output network on the\n    /// target device, use [fork](Module::fork) instead.\n    fn to_device(self, device: &B::Device) -> Self;\n\n    /// Each tensor in the module tree will not require grad.\n    ///\n    /// # Warnings\n    ///\n    /// This should not be used for inference, use [valid](AutodiffModule::valid) when using\n    /// AD modules. This is mostly useful when performing partial finetuning, which is updating only\n    /// a small fraction of the parameters instead of finetuning all of them.\n    fn no_grad(self) -> Self {\n        module!(\n            map = self,\n            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)\n        )\n    }\n\n    /// Move the module and all of its sub-modules to the autodiff backend.\n    ///\n    /// # Notes\n    ///\n    /// * Only plain modules (not already on an autodiff backend) can be moved.\n    /// * Calling `train()` on a module that is already on an autodiff backend\n    ///   will result in a type error, because the module's inner backend does not match.\n    fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule\n    where\n        AB: AutodiffBackend<InnerBackend = B>,\n        Self: HasAutodiffModule<AB>,\n    {\n        <Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)\n    }\n\n    /// Get the number of parameters the module has, including all of its sub-modules.\n    fn num_params(&self) -> usize {\n        module!(\n            visit_float = self,\n            ops = |tensor: &Tensor<B, D>, state: &mut usize| {\n                *state += tensor.shape().num_elements();\n            },\n            state = usize,\n            init = || 0\n        )\n    }\n    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).\n    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);\n\n    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).\n    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;\n\n    /// Load the module state from a record.\n    fn load_record(self, record: Self::Record) -> Self;\n\n    /// Convert the module into a record containing the state.\n    fn into_record(self) -> Self::Record;\n\n    #[cfg(feature = \"std\")]\n    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).\n    ///\n    /// List of supported file recorders:\n    ///\n    /// * [default](crate::record::DefaultFileRecorder)\n    /// * [bincode](crate::record::BinFileRecorder)\n    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)\n    /// * [json pretty](crate::record::PrettyJsonFileRecorder)\n    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)\n    /// * [named mpk](crate::record::NamedMpkFileRecorder)\n    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)\n    ///\n    /// ## Notes\n    ///\n    /// The file extension is automatically added depending on the file recorder provided, you\n    /// don't have to specify it.\n    fn save_file<FR, PB>(\n        self,\n        file_path: PB,\n        recorder: &FR,\n    ) -> Result<(), crate::record::RecorderError>\n    where\n        FR: crate::record::FileRecorder<B>,\n        PB: Into<std::path::PathBuf>,\n    {\n        let record = Self::into_record(self);\n        recorder.record(record, file_path.into())\n    }\n\n    #[cfg(feature = \"std\")]\n    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).\n    ///\n    /// The recorder should be the same as the one used to save the module, see\n    /// [save_file](Self::save_file).\n    ///\n    /// ## Notes\n    ///\n    /// The file extension is automatically added depending on the file recorder provided, you\n    /// don't have to specify it.\n    fn load_file<FR, PB>(\n        self,\n        file_path: PB,\n        recorder: &FR,\n        device: &B::Device,\n    ) -> Result<Self, crate::record::RecorderError>\n    where\n        FR: crate::record::FileRecorder<B>,\n        PB: Into<std::path::PathBuf>,\n    {\n        let record = recorder.load(file_path.into(), device)?;\n\n        Ok(self.load_record(record))\n    }\n\n    /// Quantize the weights of the module.\n    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {\n        self.map(quantizer)\n    }\n}\n\n/// Module visitor trait for traversing and inspecting module parameters.\npub trait ModuleVisitor<B: Backend> {\n    /// Visit a float parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The float parameter to visit\n    #[allow(unused_variables)]\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}\n\n    /// Visit an int parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The integer parameter to visit\n    #[allow(unused_variables)]\n    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}\n\n    /// Visit a bool parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The boolean parameter to visit\n    #[allow(unused_variables)]\n    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}\n\n    /// Called when entering a submodule.\n    ///\n    /// # Parameters\n    /// - `name`: The name of the submodule being entered\n    /// - `container_type`: The type of the container with format:\n    ///   - For user-defined structs: \"Struct:TypeName\" (e.g., \"Struct:Linear\")\n    ///   - For user-defined enums: \"Enum:TypeName\" (e.g., \"Enum:MyEnum\")\n    ///   - For Vec containers: \"Vec\" (name is the index)\n    ///   - For Tuple containers: \"Tuple\" (name is the index)\n    ///   - For Array containers: \"Array\" (name is the index)\n    ///\n    /// Note: Option containers do not call enter_module/exit_module to preserve\n    /// the field name in the path (e.g., \"bias\" instead of \"bias.Some\")\n    #[allow(unused_variables)]\n    fn enter_module(&mut self, name: &str, container_type: &str) {}\n\n    /// Called when exiting a submodule.\n    ///\n    /// # Parameters\n    /// - `name`: The name of the submodule being exited\n    /// - `container_type`: The type of the container with format:\n    ///   - For user-defined structs: \"Struct:TypeName\" (e.g., \"Struct:Linear\")\n    ///   - For user-defined enums: \"Enum:TypeName\" (e.g., \"Enum:MyEnum\")\n    ///   - For Vec containers: \"Vec\" (name is the index)\n    ///   - For Tuple containers: \"Tuple\" (name is the index)\n    ///   - For Array containers: \"Array\" (name is the index)\n    ///\n    /// Note: Option containers do not call enter_module/exit_module to preserve\n    /// the field name in the path (e.g., \"bias\" instead of \"bias.Some\")\n    #[allow(unused_variables)]\n    fn exit_module(&mut self, name: &str, container_type: &str) {}\n\n    /// Visit a float tensor with its full module path.\n    ///\n    /// # Parameters\n    /// - `path`: The path components to the tensor as a slice (e.g., &[\"encoder\", \"layer1\", \"weight\"]).\n    ///   Each element represents a module name in the hierarchy, with the final element\n    ///   being the parameter name. This allows efficient reuse of the path stack.\n    /// - `id`: The unique identifier of the parameter\n    /// - `tensor`: The float tensor to visit\n    #[allow(unused_variables)]\n    fn visit_float_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D>,\n    ) {\n    }\n\n    /// Visit an int tensor with its full module path.\n    ///\n    /// # Parameters\n    /// - `path`: The path components to the tensor as a slice (e.g., &[\"encoder\", \"layer1\", \"weight\"]).\n    ///   Each element represents a module name in the hierarchy, with the final element\n    ///   being the parameter name. This allows efficient reuse of the path stack.\n    /// - `id`: The unique identifier of the parameter\n    /// - `tensor`: The integer tensor to visit\n    #[allow(unused_variables)]\n    fn visit_int_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D, Int>,\n    ) {\n    }\n\n    /// Visit a bool tensor with its full module path.\n    ///\n    /// # Parameters\n    /// - `path`: The path components to the tensor as a slice (e.g., &[\"encoder\", \"layer1\", \"weight\"]).\n    ///   Each element represents a module name in the hierarchy, with the final element\n    ///   being the parameter name. This allows efficient reuse of the path stack.\n    /// - `id`: The unique identifier of the parameter\n    /// - `tensor`: The boolean tensor to visit\n    #[allow(unused_variables)]\n    fn visit_bool_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D, Bool>,\n    ) {\n    }\n}\n\n/// Module mapper trait for transforming module parameters.\npub trait ModuleMapper<B: Backend> {\n    /// Called when entering a submodule.\n    ///\n    /// # Parameters\n    /// - `name`: The name of the submodule being entered\n    /// - `container_type`: The type of the container with format:\n    ///   - For user-defined structs: \"Struct:TypeName\" (e.g., \"Struct:Linear\")\n    ///   - For user-defined enums: \"Enum:TypeName\" (e.g., \"Enum:MyEnum\")\n    ///   - For Vec containers: \"Vec\" (name is the index)\n    ///   - For Tuple containers: \"Tuple\" (name is the index)\n    ///   - For Array containers: \"Array\" (name is the index)\n    ///\n    /// Note: Option containers do not call enter_module/exit_module to preserve\n    /// the field name in the path (e.g., \"bias\" instead of \"bias.Some\")\n    #[allow(unused_variables)]\n    fn enter_module(&mut self, name: &str, container_type: &str) {}\n\n    /// Called when exiting a submodule.\n    ///\n    /// # Parameters\n    /// - `name`: The name of the submodule being exited\n    /// - `container_type`: The type of the container with format:\n    ///   - For user-defined structs: \"Struct:TypeName\" (e.g., \"Struct:Linear\")\n    ///   - For user-defined enums: \"Enum:TypeName\" (e.g., \"Enum:MyEnum\")\n    ///   - For Vec containers: \"Vec\" (name is the index)\n    ///   - For Tuple containers: \"Tuple\" (name is the index)\n    ///   - For Array containers: \"Array\" (name is the index)\n    ///\n    /// Note: Option containers do not call enter_module/exit_module to preserve\n    /// the field name in the path (e.g., \"bias\" instead of \"bias.Some\")\n    #[allow(unused_variables)]\n    fn exit_module(&mut self, name: &str, container_type: &str) {}\n\n    /// Map a float parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The float parameter to transform\n    ///\n    /// # Returns\n    /// The transformed parameter\n    #[allow(unused_variables)]\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n\n    /// Map an int parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The integer parameter to transform\n    ///\n    /// # Returns\n    /// The transformed parameter\n    #[allow(unused_variables)]\n    fn map_int<const D: usize>(\n        &mut self,\n        param: Param<Tensor<B, D, Int>>,\n    ) -> Param<Tensor<B, D, Int>> {\n        let (id, tensor, mapper) = param.consume();\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n\n    /// Map a bool parameter in the module.\n    ///\n    /// # Parameters\n    /// - `param`: The boolean parameter to transform\n    ///\n    /// # Returns\n    /// The transformed parameter\n    #[allow(unused_variables)]\n    fn map_bool<const D: usize>(\n        &mut self,\n        param: Param<Tensor<B, D, Bool>>,\n    ) -> Param<Tensor<B, D, Bool>> {\n        let (id, tensor, mapper) = param.consume();\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n}\n\n/// Module with auto-differentiation backend.\npub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {\n    /// Inner module without auto-differentiation.\n    type InnerModule: Module<B::InnerBackend>;\n\n    /// Returns the same module, but on the inner backend without auto-differentiation.\n    fn valid(&self) -> Self::InnerModule;\n\n    /// Wraps an inner module back into an auto-diff module.\n    fn from_inner(module: Self::InnerModule) -> Self;\n}\n\n/// Helper trait to associate a module with its autodiff version.\npub trait HasAutodiffModule<B: AutodiffBackend> {\n    /// The module with auto-differentiation.\n    type TrainModule: AutodiffModule<B, InnerModule = Self>;\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use crate::TestAutodiffBackend;\n    use crate::test_utils::SimpleLinear;\n\n    #[test]\n    fn test_module_val_train_stateful() {\n        let device = Default::default();\n        let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);\n\n        assert!(module.weight.is_require_grad());\n        assert!(module.weight.require_grad);\n\n        let module = module.valid();\n        assert!(!module.weight.is_require_grad());\n        assert!(module.weight.require_grad); // stateful\n\n        // Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying\n        // let module: SimpleLinear<TestAutodiffBackend> = module.train();\n        let module = module.train::<TestAutodiffBackend>();\n        assert!(module.weight.is_require_grad());\n        assert!(module.weight.require_grad); // stateful\n\n        let module = module.no_grad();\n        assert!(!module.weight.is_require_grad());\n        assert!(!module.weight.require_grad); // stateful\n\n        let module = module.valid();\n        assert!(!module.weight.is_require_grad()); // always\n        assert!(!module.weight.require_grad); // stateful\n\n        let module = module.train::<TestAutodiffBackend>();\n        assert!(!module.weight.is_require_grad());\n        assert!(!module.weight.require_grad); // stateful\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/display.rs",
    "content": "use alloc::{\n    borrow::ToOwned,\n    format,\n    string::{String, ToString},\n    vec::Vec,\n};\nuse core::any;\nuse core::fmt::{Debug, Display, Write};\n\n/// Default display settings for a module.\npub trait ModuleDisplayDefault {\n    /// Attributes of the module used for display purposes.\n    ///\n    /// # Arguments\n    ///\n    /// * `_content` - The content object that contains display settings and attributes.\n    ///\n    /// # Returns\n    ///\n    /// An optional content object containing the display attributes.\n    fn content(&self, _content: Content) -> Option<Content>;\n\n    /// Gets the number of the parameters of the module.\n    fn num_params(&self) -> usize {\n        0\n    }\n}\n\n/// Trait to implement custom display settings for a module.\n///\n/// In order to implement custom display settings for a module,\n/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)]\n/// 2. Implement ModuleDisplay trait for the module\npub trait ModuleDisplay: ModuleDisplayDefault {\n    /// Formats the module with provided display settings.\n    ///\n    /// # Arguments\n    ///\n    /// * `passed_settings` - Display settings passed to the module.\n    ///\n    /// # Returns\n    ///\n    /// A string representation of the formatted module.\n    fn format(&self, passed_settings: DisplaySettings) -> String {\n        let settings = if let Some(custom_settings) = self.custom_settings() {\n            custom_settings.inherit(passed_settings)\n        } else {\n            passed_settings\n        };\n\n        let indent = \" \".repeat(settings.level * settings.indentation_size());\n        let indent_close_braces = \" \".repeat((settings.level - 1) * settings.indentation_size());\n\n        let settings = settings.level_up();\n\n        let self_type = extract_type_name::<Self>();\n\n        // Use custom content if it is implemented and show_all_attributes is false,\n        // otherwise use default content\n        let content = if !settings.show_all_attributes() {\n            self.custom_content(Content::new(settings.clone()))\n                .unwrap_or_else(|| {\n                    self.content(Content::new(settings.clone()))\n                        .unwrap_or_else(|| {\n                            panic!(\"Default content should be implemented for {self_type}.\")\n                        })\n                })\n        } else {\n            self.content(Content::new(settings.clone()))\n                .unwrap_or_else(|| panic!(\"Default content should be implemented for {self_type}.\"))\n        };\n\n        let top_level_type = if let Some(top_level_type) = content.top_level_type {\n            top_level_type.to_owned()\n        } else {\n            self_type.to_owned()\n        };\n\n        // If there is only one item in the content, return it or no attributes\n        if let Some(item) = content.single_item {\n            return item;\n        } else if content.attributes.is_empty() {\n            return top_level_type.to_string();\n        }\n\n        let mut result = String::new();\n\n        // Print the struct name\n        if settings.new_line_after_attribute() {\n            writeln!(result, \"{top_level_type} {{\").unwrap();\n        } else {\n            write!(result, \"{top_level_type} {{\").unwrap();\n        }\n\n        for (i, attribute) in content.attributes.iter().enumerate() {\n            if settings.new_line_after_attribute() {\n                writeln!(result, \"{indent}{}: {}\", attribute.name, attribute.value).unwrap();\n            } else if i == 0 {\n                write!(result, \"{}: {}\", attribute.name, attribute.value).unwrap();\n            } else {\n                write!(result, \", {}: {}\", attribute.name, attribute.value).unwrap();\n            }\n        }\n\n        if settings.show_num_parameters() {\n            let num_params = self.num_params();\n            if num_params > 0 {\n                if settings.new_line_after_attribute() {\n                    writeln!(result, \"{indent}params: {num_params}\").unwrap();\n                } else {\n                    write!(result, \", params: {num_params}\").unwrap();\n                }\n            }\n        }\n\n        if settings.new_line_after_attribute() {\n            write!(result, \"{indent_close_braces}}}\").unwrap();\n        } else {\n            write!(result, \"}}\").unwrap();\n        }\n\n        result\n    }\n\n    /// Custom display settings for the module.\n    ///\n    /// # Returns\n    ///\n    /// An optional display settings object.\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        None\n    }\n\n    /// Custom attributes for the module.\n    ///\n    /// # Arguments\n    ///\n    /// * `_content` - The content object that contains display settings and attributes.\n    ///\n    /// # Returns\n    ///\n    /// An optional content object containing the custom attributes.\n    fn custom_content(&self, _content: Content) -> Option<Content> {\n        None\n    }\n}\n\n/// Custom module display settings.\n#[derive(Debug, Clone)]\npub struct DisplaySettings {\n    /// Whether to print the module parameter ids.\n    show_param_id: Option<bool>,\n\n    /// Whether to print the module attributes.\n    show_all_attributes: Option<bool>,\n\n    /// Whether to print the module number of parameters.\n    show_num_parameters: Option<bool>,\n\n    /// Print new line after an attribute.\n    new_line_after_attribute: Option<bool>,\n\n    /// Indentation size.\n    indentation_size: Option<usize>,\n\n    /// Level of indentation.\n    level: usize,\n}\n\nimpl Default for DisplaySettings {\n    fn default() -> Self {\n        DisplaySettings {\n            show_param_id: None,\n            show_all_attributes: None,\n            show_num_parameters: None,\n            new_line_after_attribute: None,\n            indentation_size: None,\n            level: 1,\n        }\n    }\n}\n\nimpl DisplaySettings {\n    /// Create a new format settings.\n    ///\n    /// # Returns\n    ///\n    /// A new instance of `DisplaySettings`.\n    pub fn new() -> Self {\n        Default::default()\n    }\n\n    /// Sets a flag to show module parameters.\n    ///\n    /// # Arguments\n    ///\n    /// * `flag` - Boolean flag to show module parameters.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn with_show_param_id(mut self, flag: bool) -> Self {\n        self.show_param_id = Some(flag);\n        self\n    }\n\n    /// Sets a flag to show module attributes.\n    ///\n    /// # Arguments\n    ///\n    /// * `flag` - Boolean flag to show all module attributes.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn with_show_all_attributes(mut self, flag: bool) -> Self {\n        self.show_all_attributes = Some(flag);\n        self\n    }\n\n    /// Sets a flag to show the number of module parameters.\n    ///\n    /// # Arguments\n    ///\n    /// * `flag` - Boolean flag to show the number of module parameters.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn with_show_num_parameters(mut self, flag: bool) -> Self {\n        self.show_num_parameters = Some(flag);\n        self\n    }\n\n    /// Sets a flag to print a new line after an attribute.\n    ///\n    /// # Arguments\n    ///\n    /// * `flag` - Boolean flag to print a new line after an attribute.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self {\n        self.new_line_after_attribute = Some(flag);\n        self\n    }\n\n    /// Sets the indentation size.\n    ///\n    /// # Arguments\n    ///\n    /// * `size` - The size of the indentation.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn with_indentation_size(mut self, size: usize) -> Self {\n        self.indentation_size = Some(size);\n        self\n    }\n\n    /// Inherits settings from the provided settings and return a new settings object.\n    ///\n    /// # Arguments\n    ///\n    /// * `top` - The top level `DisplaySettings` to inherit from.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance.\n    pub fn inherit(self, top: Self) -> Self {\n        let mut updated = self.clone();\n\n        if let Some(show_param_id) = top.show_param_id {\n            updated.show_param_id = Some(show_param_id);\n        };\n\n        if let Some(show_all_attributes) = top.show_all_attributes {\n            updated.show_all_attributes = Some(show_all_attributes);\n        }\n\n        if let Some(show_num_parameters) = top.show_num_parameters {\n            updated.show_num_parameters = Some(show_num_parameters);\n        }\n\n        if let Some(new_line_after_attribute) = top.new_line_after_attribute {\n            updated.new_line_after_attribute = Some(new_line_after_attribute);\n        }\n\n        if let Some(indentation_size) = top.indentation_size {\n            updated.indentation_size = Some(indentation_size);\n        }\n\n        updated.level = top.level;\n\n        updated\n    }\n\n    /// A convenience method to wrap the DisplaySettings struct in an option.\n    ///\n    /// # Returns\n    ///\n    /// An optional `DisplaySettings`.\n    pub fn optional(self) -> Option<Self> {\n        Some(self)\n    }\n\n    /// Increases the level of indentation.\n    ///\n    /// # Returns\n    ///\n    /// Updated `DisplaySettings` instance with increased indentation level.\n    pub fn level_up(mut self) -> Self {\n        self.level += 1;\n        self\n    }\n\n    /// Gets `show_param_id` flag, substitutes false if not set.\n    ///\n    /// This flag is used to print the module parameter ids.\n    ///\n    /// # Returns\n    ///\n    /// A boolean value indicating whether to show parameter ids.\n    pub fn show_param_id(&self) -> bool {\n        self.show_param_id.unwrap_or(false)\n    }\n\n    /// Gets `show_all_attributes`, substitutes false if not set.\n    ///\n    /// This flag is used to force to print all module attributes, overriding custom attributes.\n    ///\n    /// # Returns\n    ///\n    /// A boolean value indicating whether to show all attributes.\n    pub fn show_all_attributes(&self) -> bool {\n        self.show_all_attributes.unwrap_or(false)\n    }\n\n    /// Gets `show_num_parameters`, substitutes true if not set.\n    ///\n    /// This flag is used to print the number of module parameters.\n    ///\n    /// # Returns\n    ///\n    /// A boolean value indicating whether to show the number of parameters.\n    pub fn show_num_parameters(&self) -> bool {\n        self.show_num_parameters.unwrap_or(true)\n    }\n\n    /// Gets `new_line_after_attribute`, substitutes true if not set.\n    ///\n    /// This flag is used to print a new line after an attribute.\n    ///\n    /// # Returns\n    ///\n    /// A boolean value indicating whether to print a new line after an attribute.\n    pub fn new_line_after_attribute(&self) -> bool {\n        self.new_line_after_attribute.unwrap_or(true)\n    }\n\n    /// Gets `indentation_size`, substitutes 2 if not set.\n    ///\n    /// This flag is used to set the size of indentation.\n    ///\n    /// # Returns\n    ///\n    /// An integer value indicating the size of indentation.\n    pub fn indentation_size(&self) -> usize {\n        self.indentation_size.unwrap_or(2)\n    }\n}\n\n/// Struct to store the attributes of a module for formatting.\n#[derive(Clone, Debug)]\npub struct Content {\n    /// List of attributes.\n    pub attributes: Vec<Attribute>,\n\n    /// Single item content.\n    pub single_item: Option<String>,\n\n    /// Display settings.\n    pub display_settings: DisplaySettings,\n\n    /// Top level type name.\n    pub top_level_type: Option<String>,\n}\n\nimpl Content {\n    /// Creates a new attributes struct.\n    ///\n    /// # Arguments\n    ///\n    /// * `display_settings` - Display settings for the content.\n    ///\n    /// # Returns\n    ///\n    /// A new instance of `Content`.\n    pub fn new(display_settings: DisplaySettings) -> Self {\n        Content {\n            attributes: Vec::new(),\n            single_item: None,\n            display_settings,\n            top_level_type: None,\n        }\n    }\n\n    /// Adds an attribute to the format settings. The value will be formatted and stored as a string.\n    ///\n    /// # Arguments\n    ///\n    /// * `name` - Name of the attribute.\n    /// * `value` - Value of the attribute.\n    ///\n    /// # Returns\n    ///\n    /// Updated `Content` instance with the new attribute added.\n    pub fn add<T: ModuleDisplay + ?Sized>(mut self, name: &str, value: &T) -> Self {\n        if self.single_item.is_some() {\n            panic!(\"Cannot add multiple attributes when single item is set.\");\n        }\n\n        let attribute = Attribute {\n            name: name.to_owned(),\n            value: value.format(self.display_settings.clone()), // TODO level + 1\n            ty: any::type_name::<T>().to_string(),\n        };\n        self.attributes.push(attribute);\n        self\n    }\n\n    /// Adds an attribute using its `Debug` representation.\n    ///\n    /// This is intended for fields that do not implement [`ModuleDisplay`].\n    ///\n    /// # Arguments\n    ///\n    /// * `name` - Name of the attribute.\n    /// * `value` - Value of the attribute.\n    ///\n    /// # Returns\n    ///\n    /// Updated `Content` instance with the new attribute added.\n    pub fn add_debug_attribute<T: Debug>(mut self, name: &str, value: &T) -> Self {\n        if self.single_item.is_some() {\n            panic!(\"Cannot add multiple attributes when single item is set.\");\n        }\n        self.attributes.push(Attribute {\n            name: name.to_owned(),\n            value: DisplayAdapter(value).format(self.display_settings.clone()),\n            ty: any::type_name::<T>().to_string(),\n        });\n        self\n    }\n\n    /// Adds a single item.\n    ///\n    /// # Arguments\n    ///\n    /// * `value` - Rendered string of the single item.\n    ///\n    /// # Returns\n    ///\n    /// Updated `Content` instance with the single item added.\n    pub fn add_single<T: ModuleDisplay + ?Sized>(mut self, value: &T) -> Self {\n        if !self.attributes.is_empty() {\n            panic!(\"Cannot add single item when attributes are set.\");\n        }\n\n        self.single_item = Some(value.format(self.display_settings.clone()));\n\n        self\n    }\n\n    /// Adds a single item.\n    ///\n    /// # Arguments\n    ///\n    /// * `value` - Formatted display value.\n    ///\n    /// # Returns\n    ///\n    /// Updated `Content` instance with the formatted single item added.\n    pub fn add_formatted<T: Display>(mut self, value: &T) -> Self {\n        if !self.attributes.is_empty() {\n            panic!(\"Cannot add single item when attributes are set.\");\n        }\n\n        self.single_item = Some(format!(\"{value}\"));\n        self\n    }\n\n    /// A convenience method to wrap the Attributes struct in an option\n    /// because it is often used as an optional field.\n    ///\n    /// # Returns\n    ///\n    /// An optional `Content`.\n    pub fn optional(self) -> Option<Self> {\n        if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none()\n        {\n            None\n        } else {\n            Some(self)\n        }\n    }\n\n    /// Sets the top level type name.\n    ///\n    /// # Arguments\n    ///\n    /// * `ty` - The type name to set.\n    ///\n    /// # Returns\n    ///\n    /// Updated `Content` instance with the top level type name set.\n    pub fn set_top_level_type(mut self, ty: &str) -> Self {\n        self.top_level_type = Some(ty.to_owned());\n        self\n    }\n}\n\n/// Minimal display adapter for non-module types.\nstruct DisplayAdapter<'a, T: Debug>(&'a T);\n\nimpl<'a, T: Debug> ModuleDisplayDefault for DisplayAdapter<'a, T> {\n    fn content(&self, content: Content) -> Option<Content> {\n        content.add_single(&format!(\"{:?}\", self.0)).optional()\n    }\n}\n\nimpl<'a, T: Debug> ModuleDisplay for DisplayAdapter<'a, T> {}\n\n/// Attribute to print in the display method.\n#[derive(Clone, Debug)]\npub struct Attribute {\n    /// Name of the attribute.\n    pub name: String,\n\n    /// Value of the attribute.\n    pub value: String,\n\n    /// Type of the attribute.\n    pub ty: String,\n}\n\n/// Extracts the short name of a type T\n///\n/// # Returns\n///\n/// A string slice representing the short name of the type.\npub fn extract_type_name<T: ?Sized>() -> &'static str {\n    // Get the full type name of T, including module path and generic parameters\n    let ty = any::type_name::<T>();\n\n    // Find the first occurrence of '<' in the full type name\n    // If not found, use the length of the type name\n    let end = ty.find('<').unwrap_or(ty.len());\n\n    // Slice the type name up to the first '<' or the end\n    let ty = &ty[0..end];\n\n    // Find the last occurrence of \"::\" in the sliced type name\n    // If found, add 2 to skip the \"::\" itself\n    // If not found, start from the beginning of the type name\n    let start = ty.rfind(\"::\").map(|i| i + 2).unwrap_or(0);\n\n    // Find the last occurrence of '<' in the sliced type name\n    // If not found, use the length of the type name\n    let end = ty.rfind('<').unwrap_or(ty.len());\n\n    // If the start index is less than the end index,\n    // return the slice of the type name from start to end\n    // Otherwise, return the entire sliced type name\n    if start < end { &ty[start..end] } else { ty }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/initializer.rs",
    "content": "use crate::tensor::Shape;\n\nuse crate::config::Config;\nuse crate::module::{Param, ParamId};\nuse crate::tensor::backend::Backend;\nuse crate::tensor::{Distribution, Tensor, s};\n\nuse crate as burn;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Enum specifying with what values a tensor should be initialized\n#[derive(Config, Debug, PartialEq)]\npub enum Initializer {\n    /// Fills tensor with specified value everywhere\n    Constant {\n        /// The value to fill the tensor with\n        value: f64,\n    },\n    /// Fills tensor with 1s everywhere\n    Ones,\n    /// Fills tensor with 0s everywhere\n    Zeros,\n    /// Fills tensor with values drawn uniformly between specified values\n    Uniform {\n        /// The minimum value to draw from\n        min: f64,\n\n        /// The maximum value to draw from\n        max: f64,\n    },\n    /// Fills tensor with values drawn from normal distribution with specified mean and std\n    Normal {\n        /// The mean of the normal distribution\n        mean: f64,\n\n        /// The standard deviation of the normal distribution\n        std: f64,\n    },\n    /// Fills tensor with values according to the uniform version of Kaiming initialization\n    KaimingUniform {\n        /// The gain to use in initialization formula\n        gain: f64,\n\n        /// Whether to use fan out only in initialization formula\n        fan_out_only: bool,\n    },\n    /// Fills tensor with values according to the uniform version of Kaiming initialization\n    KaimingNormal {\n        /// The gain to use in initialization formula\n        gain: f64,\n\n        /// Whether to use fan out only in initialization formula\n        fan_out_only: bool,\n    },\n    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization\n    /// described in [Understanding the difficulty of training deep feedforward neural networks\n    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)\n    XavierUniform {\n        /// The gain to use in initialization formula\n        gain: f64,\n    },\n    /// Fills tensor with values according to the normal version of Xavier Glorot initialization\n    /// described in [Understanding the difficulty of training deep feedforward neural networks\n    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)\n    XavierNormal {\n        /// The gain to use in initialization formula\n        gain: f64,\n    },\n    /// Fills tensor with values according to the (semi) orthogonal initialization\n    /// described in [Exact solutions to the nonlinear dynamics of learning in deep linear neural networks`\n    ///  - [Saxe, A. et al. (2013)](https://arxiv.org/abs/1312.6120)\n    Orthogonal {\n        /// The gain to use in initialization formula\n        gain: f64,\n    },\n}\n\nimpl Initializer {\n    /// Inits a tensor parameter of given shape with values depending on initializer kind.\n    ///\n    /// # Params\n    ///\n    /// - shape: Shape of the initiated tensor.\n    pub fn init<B: Backend, const D: usize, S: Into<Shape>>(\n        &self,\n        shape: S,\n        device: &B::Device,\n    ) -> Param<Tensor<B, D>> {\n        self.init_with(shape, None, None, device)\n    }\n\n    /// Inits a tensor parameter of given shape with values depending on initializer kind.\n    ///\n    /// # Params\n    ///\n    /// - shape: Shape of the initiated tensor.\n    pub fn init_with<B: Backend, const D: usize, S: Into<Shape>>(\n        &self,\n        shape: S,\n        fan_in: Option<usize>,\n        fan_out: Option<usize>,\n        device: &B::Device,\n    ) -> Param<Tensor<B, D>> {\n        let device = device.clone();\n        let shape: Shape = shape.into();\n        let config = self.clone();\n        let shape_for_closure = shape.clone();\n\n        Param::uninitialized(\n            ParamId::new(),\n            move |device, require_grad| {\n                B::memory_persistent_allocations(device, (), move |_| {\n                    let mut tensor = config.init_tensor(shape.clone(), fan_in, fan_out, device);\n\n                    if require_grad {\n                        tensor = tensor.require_grad();\n                    }\n\n                    tensor\n                })\n            },\n            device,\n            true,\n            shape_for_closure,\n        )\n    }\n\n    fn init_tensor<B: Backend, const D: usize, S: Into<Shape>>(\n        &self,\n        shape: S,\n        fan_in: Option<usize>,\n        fan_out: Option<usize>,\n        device: &B::Device,\n    ) -> Tensor<B, D> {\n        let shape = shape.into();\n        match self {\n            Initializer::Constant { value } => Tensor::<B, D>::full(shape, *value, device),\n            Initializer::Ones => Tensor::<B, D>::ones(shape, device),\n            Initializer::Zeros => Tensor::<B, D>::zeros(shape, device),\n            Initializer::Uniform { min, max } => uniform_draw(shape, *min, *max, device),\n            Initializer::Normal { mean, std } => normal_draw(shape, *mean, *std, device),\n            Initializer::KaimingUniform { gain, fan_out_only } => {\n                let a = 3.0f64.sqrt() * *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);\n                uniform_draw(shape, -a, a, device)\n            }\n            Initializer::KaimingNormal { gain, fan_out_only } => {\n                let std = *gain * self.kaiming_std(*fan_out_only, fan_in, fan_out);\n                normal_draw(shape, 0.0, std, device)\n            }\n            Initializer::XavierUniform { gain } => {\n                let a = 3.0f64.sqrt() * *gain * self.xavier_std(fan_in, fan_out);\n                uniform_draw(shape, -a, a, device)\n            }\n            Initializer::XavierNormal { gain } => {\n                let std = *gain * self.xavier_std(fan_in, fan_out);\n                normal_draw(shape, 0.0, std, device)\n            }\n            Initializer::Orthogonal { gain } => {\n                // following the implementation in pytorch:\n                // https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/init.py#L574\n\n                assert!(\n                    D >= 2,\n                    \"Expected D (in Tensor<B, D>) to be greater or equal 2; (D >= 2)\"\n                );\n\n                let rows: usize = shape.dims::<D>()[0];\n                let cols: usize = shape.num_elements() / rows;\n\n                let mut t: Tensor<B, 2> = normal_draw([rows, cols], 0.0, 1.0, device);\n\n                if rows < cols {\n                    t = t.transpose();\n                }\n\n                let (q, r) = qr_decomposition(t, device);\n                let [r_rows, r_cols] = r.clone().dims();\n\n                let diag_r = Tensor::<B, 2>::ones([1, r_rows], device)\n                    .matmul(Tensor::<B, 2>::eye(r_cols, device).mul(r.clone()));\n\n                let ph = diag_r.clone().sign();\n\n                let mut q = q.mul(ph);\n\n                if rows < cols {\n                    q = q.transpose();\n                }\n\n                q.reshape(shape).mul_scalar(*gain)\n            }\n        }\n    }\n\n    fn kaiming_std(\n        &self,\n        fan_out_only: bool,\n        fan_in: Option<usize>,\n        fan_out: Option<usize>,\n    ) -> f64 {\n        let fan = if fan_out_only { fan_out } else { fan_in };\n        let fan = fan.expect(\n            \"Can't use Kaiming initialization without specifying fan. Use init_with method.\",\n        );\n\n        1.0 / (fan as f64).sqrt()\n    }\n\n    fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {\n        let fan_in = fan_in.expect(\n            \"Can't use Xavier initialization without specifying fan in. Use init_with method and \\\n             provide fan_in.\",\n        );\n        let fan_out = fan_out.expect(\n            \"Can't use Xavier initialization without specifying fan out. Use init_with method and \\\n             provide fan_out.\",\n        );\n        (2.0 / (fan_in + fan_out) as f64).sqrt()\n    }\n}\n\nfn uniform_draw<B: Backend, const D: usize, S: Into<Shape>>(\n    shape: S,\n    low: f64,\n    high: f64,\n    device: &B::Device,\n) -> Tensor<B, D> {\n    let distribution = Distribution::Uniform(low, high);\n    Tensor::<B, D>::random(shape, distribution, device)\n}\n\nfn normal_draw<B: Backend, const D: usize, S: Into<Shape>>(\n    shape: S,\n    mean: f64,\n    std: f64,\n    device: &B::Device,\n) -> Tensor<B, D> {\n    let distribution = Distribution::Normal(mean, std);\n    Tensor::<B, D>::random(shape, distribution, device)\n}\n\nfn qr_decomposition<B: Backend>(\n    a: Tensor<B, 2>,\n    device: &B::Device,\n) -> (Tensor<B, 2>, Tensor<B, 2>) {\n    // Calculate the QR decomposition using Gram-Schmidt-process: https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process\n\n    let [m, n] = a.clone().dims();\n    let mut q = Tensor::<B, 2>::zeros([m, n], device);\n    let mut r = Tensor::<B, 2>::zeros([n, n], device);\n\n    for j in 0..n {\n        let mut v: Tensor<B, 1> = a.clone().slice(s![.., j..=j]).squeeze_dim(1);\n\n        for i in 0..j {\n            let q_i: Tensor<B, 1> = q.clone().slice(s![.., i..=i]).squeeze_dim(1);\n            let r_ij = q_i.clone().mul(v.clone()).sum();\n\n            r = r\n                .clone()\n                .slice_assign([i..i + 1, j..j + 1], r_ij.clone().unsqueeze());\n\n            v = v - q_i.mul(r_ij);\n        }\n\n        // norm of v\n        let r_jj = v\n            .clone()\n            .powf(Tensor::from_floats([2.0], device))\n            .sum()\n            .sqrt();\n\n        r = r\n            .clone()\n            .slice_assign([j..j + 1, j..j + 1], r_jj.clone().unsqueeze());\n\n        let q_j = v / r_jj;\n\n        q = q\n            .clone()\n            .slice_assign([0..m, j..j + 1], q_j.unsqueeze_dim(1));\n    }\n\n    (q, r)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use burn_tensor::{ElementConversion, TensorData};\n    use num_traits::Pow;\n\n    pub type TB = burn_ndarray::NdArray<f32>;\n    use burn_tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TB>;\n\n    fn assert_normal_init(expected_mean: f64, expected_var: f64, tensor: &Tensor<TB, 2>) {\n        let (actual_vars, actual_means) = tensor.clone().var_mean(0);\n        let actual_vars = actual_vars.to_data();\n        let actual_vars = actual_vars.as_slice::<FT>().unwrap();\n        let actual_means = actual_means.to_data();\n        let actual_means = actual_means.as_slice::<FT>().unwrap();\n\n        for i in 0..tensor.shape()[0] {\n            let actual_var = actual_vars[i] as f64;\n            let actual_mean = actual_means[i] as f64;\n\n            assert!(\n                (expected_var - actual_var).abs() <= 0.1,\n                \"Expected variance to be between {expected_var} += 0.1, but got {actual_var}\"\n            );\n            assert!(\n                (expected_mean - actual_mean).abs() <= 0.1,\n                \"Expected mean to be between {expected_mean} += 0.1, but got {actual_mean}\"\n            );\n        }\n    }\n\n    #[test]\n    fn initializer_uniform_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let (min, max) = (0.0, 1.0);\n        let uniform = Initializer::Uniform { min, max };\n        let tensor: Tensor<TB, 4> = uniform.init([2, 2, 2, 2], &Default::default()).into_value();\n\n        tensor\n            .into_data()\n            .assert_within_range::<FT>(min.elem()..max.elem());\n    }\n\n    #[test]\n    fn initializer_normal_init() {\n        // seed random generator\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let (mean, std) = (0.0, 1.0);\n        let normal: Tensor<TB, 1> = Initializer::Normal { mean, std }\n            .init([1000], &Default::default())\n            .into_value();\n        let (var_act, mean_act) = normal.var_mean(0);\n\n        let var_act: f32 = var_act.into_scalar().elem();\n        let mean_act: f32 = mean_act.into_scalar().elem();\n\n        assert!(\n            var_act > 0.9 && var_act < 1.1,\n            \"Expected variance to be between 1.0 += 0.1, but got {var_act}\"\n        );\n        assert!(\n            mean_act > -0.1 && mean_act < 0.1,\n            \"Expected mean to be between 0.0 += 0.1, but got {mean_act}\"\n        );\n    }\n\n    #[test]\n    fn initializer_constant_init() {\n        let value = 5.0;\n        let constants: Tensor<TB, 4> = Initializer::Constant { value }\n            .init([2, 2, 2, 2], &Default::default())\n            .into_value();\n        constants.sum().to_data().assert_approx_eq::<FT>(\n            &TensorData::from([value as f32 * 16.0]),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn initializer_zeros_init() {\n        let zeros: Tensor<TB, 4> = Initializer::Zeros\n            .init([2, 2, 2, 2], &Default::default())\n            .into_value();\n        zeros\n            .sum()\n            .to_data()\n            .assert_approx_eq::<FT>(&TensorData::from([0.0]), Tolerance::default());\n    }\n\n    #[test]\n    fn initializer_ones_init() {\n        let ones: Tensor<TB, 4> = Initializer::Ones\n            .init([2, 2, 2, 2], &Default::default())\n            .into_value();\n        ones.sum()\n            .to_data()\n            .assert_approx_eq::<FT>(&TensorData::from([16.0]), Tolerance::default());\n    }\n\n    #[test]\n    fn initializer_kaiming_uniform_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2_f64;\n        let (fan_in, fan_out) = (5, 6);\n        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();\n\n        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {\n            gain,\n            fan_out_only: false,\n        }\n        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())\n        .into_value();\n        tensor.into_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_kaiming_normal_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2.;\n        let (fan_in, fan_out) = (1000, 10);\n        let expected_mean = 0_f64;\n\n        let expected_var = (gain * (1. / (fan_in as f64)).sqrt()).pow(2.);\n        let tensor: Tensor<TB, 2> = Initializer::KaimingNormal {\n            gain,\n            fan_out_only: false,\n        }\n        .init_with([fan_out, fan_in], Some(fan_in), None, &Default::default())\n        .into_value();\n        assert_normal_init(expected_mean, expected_var, &tensor)\n    }\n\n    #[test]\n    fn initializer_kaiming_uniform_init_bias() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2_f64;\n        let shape = [3];\n        let fan_in = 5;\n        let k = (gain * (3.0 / fan_in as f64).sqrt()).elem::<FT>();\n\n        let tensor: Tensor<TB, 1> = Initializer::KaimingUniform {\n            gain,\n            fan_out_only: false,\n        }\n        .init_with(shape, Some(fan_in), None, &Default::default())\n        .into_value();\n        tensor.into_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_kaiming_uniform_init_fan_out() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2_f64;\n        let (fan_in, fan_out) = (5, 6);\n        let k = (gain * (3.0 / fan_out as f64).sqrt()).elem::<FT>();\n\n        let tensor: Tensor<TB, 2> = Initializer::KaimingUniform {\n            gain,\n            fan_out_only: true,\n        }\n        .init_with([fan_out, fan_in], None, Some(fan_out), &Default::default())\n        .into_value();\n        tensor.into_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    #[should_panic]\n    fn initializer_kaiming_uniform_no_fan() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2_f64;\n        let (fan_in, fan_out) = (5, 6);\n\n        let _: Tensor<TB, 2> = Initializer::KaimingUniform {\n            gain,\n            fan_out_only: false,\n        }\n        .init([fan_out, fan_in], &Default::default())\n        .into_value();\n    }\n\n    #[test]\n    fn initializer_xavier_uniform_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2.;\n        let (fan_in, fan_out) = (5, 6);\n        let bound = (gain * (6. / (fan_in + fan_out) as f64).sqrt()).elem::<FT>();\n        let tensor: Tensor<TB, 2> = Initializer::XavierUniform { gain }\n            .init_with(\n                [fan_out, fan_in],\n                Some(fan_in),\n                Some(fan_out),\n                &Default::default(),\n            )\n            .into_value();\n\n        tensor.into_data().assert_within_range(-bound..bound);\n    }\n\n    #[test]\n    fn initializer_xavier_normal_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2.;\n        let (fan_in, fan_out) = (1000, 10);\n        let expected_mean = 0_f64;\n\n        let expected_var = (gain * (2. / (fan_in as f64 + fan_out as f64)).sqrt()).powf(2.);\n        let tensor: Tensor<TB, 2> = Initializer::XavierNormal { gain }\n            .init_with(\n                [fan_out, fan_in],\n                Some(fan_in),\n                Some(fan_out),\n                &Default::default(),\n            )\n            .into_value();\n        assert_normal_init(expected_mean, expected_var, &tensor)\n    }\n\n    #[test]\n    #[should_panic]\n    fn initializer_xavier_uniform_no_fan() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 2.;\n        let (fan_in, fan_out) = (5, 6);\n        let _: Tensor<TB, 2> = Initializer::XavierUniform { gain }\n            .init([fan_out, fan_in], &Default::default())\n            .into_value();\n    }\n\n    #[test]\n    fn test_qr_decomposition() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        // test values follow the example from https://pytorch.org/docs/stable/generated/torch.linalg.qr.html#torch.linalg.qr\n        let a = Tensor::<TB, 2>::from_floats(\n            [[12., -51., 4.], [6., 167., -68.], [-4., 24., -41.]],\n            &Default::default(),\n        );\n        let qr = qr_decomposition(a.clone(), &Default::default());\n\n        // Q @ R should reconstruct input `a`\n        let q_matmul_r = qr.0.clone().matmul(qr.1.clone());\n\n        // assert that the difference between input (`a`) and Q @ R is (almost) zero\n        q_matmul_r\n            .into_data()\n            .assert_approx_eq::<FT>(&a.into_data(), Tolerance::rel_abs(0.1, 0.1));\n    }\n\n    #[test]\n    fn initializer_orthogonal_correct() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 1.;\n\n        // test 2D tensor\n        let size = 10;\n        let q: Tensor<TB, 2> = Initializer::Orthogonal { gain }\n            .init([size, size], &Default::default())\n            .into_value();\n        let eye = Tensor::<TB, 2>::eye(size, &Default::default());\n\n        // Q.T @ Q should be close to identity matrix\n        q.clone()\n            .transpose()\n            .matmul(q)\n            .into_data()\n            .assert_approx_eq::<FT>(&eye.into_data(), Tolerance::rel_abs(0.1, 0.1));\n    }\n\n    #[test]\n    fn initializer_orthogonal_init() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 1.;\n\n        // test 2D tensor\n        let shape = [25, 30];\n        let t: Tensor<TB, 2> = Initializer::Orthogonal { gain }\n            .init(shape, &Default::default())\n            .into_value();\n        let dims = t.dims();\n        assert_eq!(\n            shape, dims,\n            \"Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})\"\n        );\n\n        // test 3D tensor\n        let shape = [24, 6, 85];\n        let t: Tensor<TB, 3> = Initializer::Orthogonal { gain }\n            .init(shape, &Default::default())\n            .into_value();\n        let dims = t.dims();\n        assert_eq!(\n            shape, dims,\n            \"Expected the shape of the input tensor to match the shape of the output. ({shape:?}, {dims:?})\"\n        );\n    }\n\n    #[test]\n    #[should_panic]\n    fn initializer_orthogonal_init_1d() {\n        let device = Default::default();\n        TB::seed(&device, 0);\n\n        let gain = 1.;\n\n        // test 1D tensor\n        let shape = [3];\n        let _: Tensor<TB, 1> = Initializer::Orthogonal { gain }\n            .init(shape, &Default::default())\n            .into_value();\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/mod.rs",
    "content": "mod base;\nmod display;\nmod initializer;\nmod param;\nmod quantize;\n#[cfg(feature = \"std\")]\nmod reinit;\n\npub use base::*;\npub use display::*;\npub use initializer::*;\npub use param::*;\npub use quantize::*;\n\n#[cfg(feature = \"std\")]\npub use reinit::*;\n"
  },
  {
    "path": "crates/burn-core/src/module/param/base.rs",
    "content": "use super::ParamId;\nuse alloc::{boxed::Box, format};\nuse burn_std::stub::RwLock;\nuse burn_tensor::Shape;\nuse core::cell::OnceCell;\nuse core::ops::Deref;\n\n#[cfg(target_has_atomic = \"ptr\")]\nuse alloc::sync::Arc;\n\n#[cfg(not(target_has_atomic = \"ptr\"))]\nuse portable_atomic_util::Arc;\n\n#[cfg(target_has_atomic = \"ptr\")]\ntype Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;\n\n#[cfg(not(target_has_atomic = \"ptr\"))]\ntype Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;\n\n#[cfg(target_has_atomic = \"ptr\")]\nfn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {\n    Arc::new(func)\n}\n\n#[cfg(not(target_has_atomic = \"ptr\"))]\nfn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {\n    Arc::new(Box::new(func))\n}\n\n/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they\n/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during\n/// training, and loaded during inference. If you don't want to save the tensors\n/// and/or don't want to update it during training, you don't need this type to wrap your tensor.\n///\n/// # Core Lazy Initialization Architecture\n///\n/// `Param<T>` has a dual-state design using `OnceCell<T>`:\n///\n/// ## State Management\n///\n/// **Two possible states:**\n///\n/// 1. **Initialized**: `state: OnceCell<T>` contains value, `initialization: None`\n/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock<Option<Uninitialized<T>>>)`\npub struct Param<T: Parameter> {\n    /// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter.\n    pub id: ParamId,\n    /// The OnceCell holding the initialized parameter value.\n    /// Empty for uninitialized parameters, populated after first access or explicit initialization.\n    pub(crate) state: OnceCell<T>,\n    /// The deferred initialization state for lazy parameters.\n    ///\n    /// **State Transitions:**\n    /// - Initialized params: `None`\n    /// - Uninitialized params: `Some(RwLock<Some(Uninitialized<T>)>)`\n    /// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)\n    pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,\n    pub(crate) param_mapper: ParamMapper<T>,\n    // For stateful `module.valid()` <> `module.train()`\n    pub(crate) require_grad: bool,\n}\n\n#[derive(Clone)]\n/// Applies transformations when loading and saving parameters.\n///\n/// # Mapper System\n///\n/// `ParamMapper<T>` allows applying transformations during serialization and deserialization:\n/// - `load: Option<Mapper<T>>` - transformation during deserialization (applied in `transform_for_load()`)\n/// - `save: Option<Mapper<T>>` - transformation during serialization (applied in `transform_for_save()`)\n///\n/// These are commonly used for:\n/// - Quantization/dequantization\n/// - Precision conversion (e.g., FP32 ↔ FP16)\n/// - Custom parameter transformations\npub struct ParamMapper<T: Parameter> {\n    load: Option<Mapper<T>>,\n    save: Option<Mapper<T>>,\n}\n\nimpl<T: Parameter> core::fmt::Debug for ParamMapper<T> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_fmt(format_args!(\n            \"ParamMapper {{ load: {}, save: {} }}\",\n            self.load.is_some(),\n            self.save.is_some()\n        ))\n    }\n}\n\nimpl<T: Parameter> ParamMapper<T> {\n    /// Applies the transformation when loading the given parameter.\n    pub fn on_load(&self, param: T) -> T {\n        match &self.load {\n            Some(mapper) => mapper(param),\n            None => param,\n        }\n    }\n    /// Applies the transformation when saving the given parameter.\n    pub fn on_save(&self, param: T) -> T {\n        match &self.save {\n            Some(mapper) => mapper(param),\n            None => param,\n        }\n    }\n}\n\nimpl<T: Parameter> Default for ParamMapper<T> {\n    fn default() -> Self {\n        Self {\n            load: None,\n            save: None,\n        }\n    }\n}\n\nimpl<T: Parameter> core::fmt::Display for Param<T> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(format!(\"Param: {}\", self.id).as_str())\n    }\n}\n\nimpl<T: Parameter> core::fmt::Debug for Param<T> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(format!(\"Param: {} - {:?}\", self.id, self.param_mapper).as_str())\n    }\n}\n\n/// Trait that defines what is necessary for a type to be a parameter.\npub trait Parameter: Clone + core::fmt::Debug + Send {\n    /// The device type to be used.\n    type Device: Clone;\n\n    /// Fetch the device.\n    fn device(&self) -> Self::Device;\n\n    /// Fetch the gradient requirement.\n    fn is_require_grad(&self) -> bool;\n\n    /// Set the gradient requirement.\n    fn set_require_grad(self, require_grad: bool) -> Self;\n}\n\n/// The deferred initialization state for lazy parameters.\n#[allow(clippy::type_complexity)]\npub(crate) struct Uninitialized<P: Parameter> {\n    /// The initialization function. Called with `(device, is_require_grad) -> Parameter`.\n    /// This function is consumed during initialization via `FnOnce`.\n    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,\n    /// The target device on which the parameter should be initialized.\n    /// Used by `lazy_device()` to provide device information without triggering initialization.\n    pub(crate) device: P::Device,\n    /// The gradient requirement for the parameter.\n    /// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization.\n    pub(crate) is_require_grad: bool,\n    /// The shape of the tensor parameter.\n    /// Used by `lazy_shape()` to provide shape information without triggering initialization.\n    pub(crate) shape: Shape,\n}\n\nimpl<P: Parameter> Uninitialized<P> {\n    /// Consumes the uninitialized state and runs the initialization function.\n    ///\n    /// This is called by [Param::val] when accessing an uninitialized parameter for the first time.\n    /// The function is given the stored device and gradient requirement, and returns the initialized parameter.\n    fn initialize(self) -> P {\n        let init = self.init;\n        init(&self.device, self.is_require_grad)\n    }\n}\n\nimpl<T: Parameter> Param<T> {\n    /// Create a new parameter that is already initialized.\n    pub fn initialized(id: ParamId, value: T) -> Self {\n        let require_grad = value.is_require_grad();\n        Self {\n            id,\n            state: OnceCell::from(value),\n            initialization: None,\n            param_mapper: Default::default(),\n            require_grad,\n        }\n    }\n\n    /// Create a new parameter that is not already initialized.\n    pub fn uninitialized<F>(\n        id: ParamId,\n        init: F,\n        device: T::Device,\n        is_require_grad: bool,\n        shape: Shape,\n    ) -> Self\n    where\n        F: FnOnce(&T::Device, bool) -> T + Send + 'static,\n    {\n        Self {\n            id,\n            state: OnceCell::new(),\n            initialization: Some(RwLock::new(Some(Uninitialized {\n                init: Box::new(init),\n                device,\n                is_require_grad,\n                shape,\n            }))),\n            param_mapper: Default::default(),\n            require_grad: is_require_grad,\n        }\n    }\n\n    /// Gets the parameter value, initializing it lazily if needed.\n    ///\n    /// For initialized parameters, this returns a clone of the cached value.\n    /// For uninitialized parameters, this triggers initialization:\n    pub fn val(&self) -> T {\n        self.state\n            .get_or_init(|| {\n                let mut result = self\n                    .initialization\n                    .as_ref()\n                    .expect(\"Should have an initialization when no state provided.\")\n                    .write()\n                    .unwrap();\n                let state = result.take().expect(\"Should exist when not initialized\");\n                state.initialize()\n            })\n            .clone()\n    }\n\n    /// Check if the parameter has been initialized.\n    ///\n    /// Returns `true` if the parameter's value has been computed and cached,\n    /// `false` if it's still lazy and will be initialized on first access.\n    pub fn is_initialized(&self) -> bool {\n        self.state.get().is_some()\n    }\n\n    /// Gets the parameter's value while consuming the parameter.\n    pub fn into_value(self) -> T {\n        self.consume().1\n    }\n\n    /// Gets the parameter id and value while consuming the parameter.\n    pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {\n        let tensor = self.val();\n\n        core::mem::drop(self.state);\n\n        (self.id, tensor, self.param_mapper)\n    }\n\n    /// Execute the given function on the inner value.\n    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {\n        let (id, tensor, param_mapper) = self.consume();\n        let tensor = func(tensor);\n        let require_grad = tensor.is_require_grad();\n\n        Self {\n            id,\n            state: OnceCell::from(tensor),\n            initialization: None,\n            param_mapper,\n            require_grad,\n        }\n    }\n\n    /// Create an initialized parameter with the given id, value, and param mapper.\n    ///\n    /// This is a helper method for creating parameters while preserving the param mapper,\n    /// typically used in ModuleMapper implementations.\n    pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {\n        let require_grad = value.is_require_grad();\n        Self {\n            id,\n            state: OnceCell::from(value),\n            initialization: None,\n            param_mapper,\n            require_grad,\n        }\n    }\n\n    /// Runs a transformation on the parameter when loading.\n    pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {\n        self.param_mapper.load = Some(new_mapper(func));\n\n        self\n    }\n\n    /// Runs a transformation on the parameter when saving.\n    pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {\n        self.param_mapper.save = Some(new_mapper(func));\n\n        self\n    }\n\n    /// Execute the given function on the inner value.\n    pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self\n    where\n        T: 'static,\n    {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.map(func),\n        };\n\n        let mut init = initialization.write().unwrap();\n\n        match init.as_mut() {\n            Some(value) => {\n                #[allow(clippy::type_complexity)]\n                let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =\n                    Box::new(|_, _| panic!(\"Fake func to not have null ref.\"));\n                core::mem::swap(&mut prev, &mut value.init);\n\n                value.init = Box::new(|a, b| {\n                    let tensor = prev(a, b);\n                    func(tensor)\n                });\n                core::mem::drop(init);\n                self\n            }\n            None => {\n                core::mem::drop(init);\n                self.map(func)\n            }\n        }\n    }\n\n    /// The device on which the parameter is or will be initialized, **without triggering initialization**.\n    ///\n    /// This is critical for the load optimization: when loading tensors into an uninitialized parameter,\n    /// we need to know the target device to move the loaded tensor appropriately, but we don't want to\n    /// trigger the initialization function (which would allocate an unnecessary tensor).\n    ///\n    /// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to\n    /// preserve lazy initialization.\n    pub fn lazy_device(&self) -> T::Device {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.device(),\n        };\n\n        let init = initialization.read().unwrap();\n\n        match init.as_ref() {\n            Some(value) => value.device.clone(),\n            None => self.device(),\n        }\n    }\n\n    /// The gradient requirement on which the parameter is or will be initialized, **without triggering initialization**.\n    ///\n    /// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization.\n    /// When loading tensors into an uninitialized parameter, we need to apply the correct gradient\n    /// setting to the loaded tensor without triggering the initialization function.\n    ///\n    /// # Notes\n    ///\n    /// This is a crate-private function, since users are not expected to use `is_require_grad` of an\n    /// uninitialized module to then override its value. All low-level functions should be provided\n    /// by `burn` and should handle those details.\n    pub(crate) fn lazy_is_require_grad(&self) -> bool {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.is_require_grad(),\n        };\n\n        let init = initialization.read().unwrap();\n\n        match init.as_ref() {\n            Some(value) => value.is_require_grad,\n            None => self.is_require_grad(),\n        }\n    }\n\n    /// Override the gradient requirement for the current parameter.\n    pub fn set_require_grad(self, require_grad: bool) -> Self {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),\n        };\n\n        let mut init = initialization.write().unwrap();\n        let mut is_lazy = false;\n\n        if let Some(value) = init.as_mut() {\n            is_lazy = true;\n            value.is_require_grad = require_grad;\n        };\n\n        core::mem::drop(init);\n\n        if is_lazy {\n            return self;\n        }\n\n        self.map(|tensor| tensor.set_require_grad(require_grad))\n    }\n}\n\nimpl<T: Parameter> Clone for Param<T> {\n    fn clone(&self) -> Self {\n        let mut param = Param::initialized(self.id, self.val());\n        param.param_mapper = self.param_mapper.clone();\n        param\n    }\n}\n\nimpl<T: Parameter> Deref for Param<T> {\n    type Target = T;\n\n    fn deref(&self) -> &Self::Target {\n        self.state.get_or_init(|| {\n            let mut result = self\n                .initialization\n                .as_ref()\n                .expect(\"Should have an initialization when no state provided.\")\n                .write()\n                .unwrap();\n\n            let state = result.take().expect(\"Should exist when not initialized\");\n            state.initialize()\n        })\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/constant.rs",
    "content": "use alloc::{format, string::ToString};\nuse core::{fmt::Display, marker::PhantomData};\n\nuse crate as burn;\nuse crate::{\n    module::{\n        AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault,\n        ModuleMapper, ModuleVisitor,\n    },\n    record::{PrecisionSettings, Record},\n};\nuse burn_tensor::{\n    BasicAutodiffOps, BasicOps, Tensor,\n    backend::{AutodiffBackend, Backend},\n    ops::Device,\n};\n\n#[deprecated(\n    since = \"0.21.0\",\n    note = \"ConstantRecord is misleading as it doesn't persist data. Use EmptyRecord instead.\"\n)]\n/// A record representing the absence of persistent module state.\npub type ConstantRecord = EmptyRecord;\n\n/// A record representing the absence of persistent module state.\n///\n/// `EmptyRecord` is used for modules that do not store any data to be\n/// serialized or restored (e.g., modules marked with `#[module(skip)]`\n/// or modules without parameters).\n///\n/// This record contains no fields and serializes to `None`.\n#[derive(Debug, Clone, Copy, new, Default, PartialEq, Eq)]\npub struct EmptyRecord;\n\nimpl serde::Serialize for EmptyRecord {\n    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n    where\n        S: serde::Serializer,\n    {\n        // nothing to serialize\n        S::serialize_none(serializer)\n    }\n}\n\nimpl<'de> serde::Deserialize<'de> for EmptyRecord {\n    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>\n    where\n        D: serde::Deserializer<'de>,\n    {\n        deserializer.deserialize_option(serde::de::IgnoredAny).ok();\n        Ok(EmptyRecord::new())\n    }\n}\n\nimpl<B: Backend> Record<B> for EmptyRecord {\n    type Item<S: PrecisionSettings> = EmptyRecord;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        self\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {\n        item\n    }\n}\n/// Constant macro.\n#[macro_export]\nmacro_rules! empty {\n    (module) => {\n        type Record = burn::module::EmptyRecord;\n\n        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {\n            // Nothing to do\n        }\n\n        fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {\n            self\n        }\n\n        fn load_record(self, _record: Self::Record) -> Self {\n            self\n        }\n\n        fn into_record(self) -> Self::Record {\n            burn::module::EmptyRecord::new()\n        }\n\n        fn to_device(self, _: &B::Device) -> Self {\n            self\n        }\n\n        fn fork(self, _: &B::Device) -> Self {\n            self\n        }\n\n        fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {\n            devices\n        }\n    };\n\n    (ad_module, $type:ty) => {\n        type InnerModule = $type;\n\n        fn valid(&self) -> Self::InnerModule {\n            self.clone()\n        }\n\n        fn from_inner(module: Self::InnerModule) -> Self {\n            module\n        }\n    };\n\n    ($type:ty) => {\n        impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {\n            empty!(module);\n        }\n\n        impl<B: burn::tensor::backend::AutodiffBackend> burn::module::AutodiffModule<B> for $type {\n            empty!(ad_module, $type);\n        }\n\n        impl burn::module::ModuleDisplayDefault for $type {\n            fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {\n                let string = format!(\"{}\", self);\n                content.add_formatted(&string).optional()\n            }\n        }\n\n        impl burn::module::ModuleDisplay for $type {}\n    };\n}\n\n// TODO: breaking change for these constant types (currently empty record, non-persistent)?\n\n// General Types\nempty!(alloc::string::String);\nempty!(bool);\n\n// Float Types\nempty!(f64);\nempty!(f32);\nempty!(half::bf16);\nempty!(half::f16);\n\n// Unsigned Integer Types\nempty!(usize);\nempty!(u64);\nempty!(u32);\nempty!(u16);\nempty!(u8);\n\n// Signed Integer Types\nempty!(isize);\nempty!(i64);\nempty!(i32);\nempty!(i16);\nempty!(i8);\n\nimpl burn::module::ModuleDisplay for str {}\nimpl burn::module::ModuleDisplayDefault for str {\n    fn content(&self, content: burn::module::Content) -> Option<burn::module::Content> {\n        content.add_formatted(&self).optional()\n    }\n}\n\n// TODO: tensor record should persist\nimpl<const D: usize, B: Backend, K: BasicOps<B>> Module<B> for Tensor<B, D, K> {\n    type Record = EmptyRecord;\n\n    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {}\n\n    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {\n        self\n    }\n\n    fn into_record(self) -> Self::Record {\n        EmptyRecord\n    }\n\n    fn load_record(self, _record: Self::Record) -> Self {\n        self\n    }\n\n    fn to_device(self, device: &B::Device) -> Self {\n        self.to_device(device)\n    }\n\n    fn fork(self, device: &B::Device) -> Self {\n        self.to_device(device)\n    }\n\n    fn collect_devices(&self, mut devices: Devices<B>) -> Devices<B> {\n        let device = self.device();\n\n        if !devices.contains(&device) {\n            devices.push(device)\n        }\n\n        devices\n    }\n}\n\nimpl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplayDefault for Tensor<B, D, K> {\n    fn content(&self, content: Content) -> Option<Content> {\n        let string = format!(\"Tensor {{rank: {D}, shape: {:?}}}\", self.shape().as_slice());\n        content.add_single(&string).optional()\n    }\n}\n\nimpl<const D: usize, B: Backend, K: BasicOps<B>> ModuleDisplay for Tensor<B, D, K> {}\n\nimpl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> AutodiffModule<B>\n    for Tensor<B, D, K>\n{\n    type InnerModule = Tensor<B::InnerBackend, D, K::InnerKind>;\n\n    fn valid(&self) -> Self::InnerModule {\n        self.clone().inner()\n    }\n\n    fn from_inner(tensor: Self::InnerModule) -> Self {\n        Tensor::from_inner(tensor)\n    }\n}\n\nimpl<B: Backend> Module<B> for PhantomData<B> {\n    type Record = EmptyRecord;\n\n    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {\n        // Nothing to do\n    }\n\n    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {\n        self\n    }\n\n    fn load_record(self, _record: Self::Record) -> Self {\n        self\n    }\n\n    fn into_record(self) -> Self::Record {\n        EmptyRecord::new()\n    }\n\n    fn to_device(self, _: &Device<B>) -> Self {\n        self\n    }\n\n    fn fork(self, _: &Device<B>) -> Self {\n        self\n    }\n\n    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {\n        devices\n    }\n}\n\nimpl<B: Backend> ModuleDisplayDefault for PhantomData<B> {\n    fn content(&self, content: Content) -> Option<Content> {\n        content.add_single(&\"PhantomData\".to_string()).optional()\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for PhantomData<B> {}\n\nimpl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {\n    type InnerModule = PhantomData<B::InnerBackend>;\n\n    fn valid(&self) -> Self::InnerModule {\n        PhantomData\n    }\n\n    fn from_inner(_module: Self::InnerModule) -> Self {\n        PhantomData\n    }\n}\n\n/// Container to satisfy the Module trait for types that are not modules.\n#[derive(Clone, Debug)]\n#[deprecated(\n    since = \"0.21.0\",\n    note = \"Ignored<T> is deprecated. Use #[module(skip)] for non-persistent fields (same behavior).\"\n)]\npub struct Ignored<T>(pub T);\n\n#[allow(deprecated)]\nimpl<B, T> Module<B> for Ignored<T>\nwhere\n    B: Backend,\n    T: Sync + Send + core::fmt::Debug + Clone,\n{\n    type Record = EmptyRecord;\n\n    fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {\n        // Nothing to do\n    }\n\n    fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {\n        self\n    }\n\n    fn load_record(self, _record: Self::Record) -> Self {\n        self\n    }\n\n    fn into_record(self) -> Self::Record {\n        EmptyRecord::new()\n    }\n\n    fn to_device(self, _: &Device<B>) -> Self {\n        self\n    }\n\n    fn fork(self, _: &Device<B>) -> Self {\n        self\n    }\n\n    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {\n        devices\n    }\n}\n\n#[allow(deprecated)]\nimpl<T> ModuleDisplayDefault for Ignored<T>\nwhere\n    T: Sync + Send + core::fmt::Debug + Clone,\n{\n    fn content(&self, content: Content) -> Option<Content> {\n        // For now, just print the debug representation of the ignored value\n        content.add_single(&format!(\"{:?}\", self.0)).optional()\n    }\n}\n\n#[allow(deprecated)]\nimpl<T> ModuleDisplay for Ignored<T> where T: Sync + Send + core::fmt::Debug + Clone {}\n\n#[allow(deprecated)]\nimpl<T> Display for Ignored<T>\nwhere\n    T: Sync + Send + core::fmt::Debug + Clone,\n{\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        write!(f, \"{:?}\", self.0)\n    }\n}\n\n#[allow(deprecated)]\nimpl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>\nwhere\n    B: AutodiffBackend,\n    T: Sync + Send + core::fmt::Debug + Clone,\n{\n    type InnerModule = Ignored<T>;\n\n    fn valid(&self) -> Self::InnerModule {\n        self.clone()\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        module\n    }\n}\n\n#[allow(deprecated)]\n// Implement deref for Ignored\nimpl<T> core::ops::Deref for Ignored<T> {\n    type Target = T;\n\n    fn deref(&self) -> &Self::Target {\n        &self.0\n    }\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use core::marker::PhantomData;\n\n    use burn_tensor::backend::Backend;\n    use burn_tensor::{Device, Tensor};\n\n    use crate::TestBackend;\n    use crate::{\n        TestAutodiffBackend,\n        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},\n    };\n    use burn::module::Module;\n\n    use crate as burn;\n\n    #[test]\n    fn tensor_load_record_setting() {\n        let device: &Device<TestAutodiffBackend> = &Default::default();\n        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], device);\n\n        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();\n        let bytes = Recorder::<TestAutodiffBackend>::record(\n            &byte_recorder,\n            tensor.clone().into_record(),\n            (),\n        )\n        .unwrap();\n\n        let no_grad_is_require_grad = tensor\n            .clone()\n            .no_grad()\n            .load_record(\n                Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)\n                    .unwrap(),\n            )\n            .is_require_grad();\n\n        let with_default_is_require_grad = tensor\n            .load_record(\n                Recorder::<TestAutodiffBackend>::load(&byte_recorder, bytes.clone(), device)\n                    .unwrap(),\n            )\n            .is_require_grad();\n\n        assert!(!no_grad_is_require_grad);\n        assert!(!with_default_is_require_grad);\n    }\n\n    #[test]\n    fn empty_module_with_phantom() {\n        #[derive(Module, Debug, new)]\n        struct EmptyModule<B: Backend> {\n            _phantom: PhantomData<B>,\n        }\n\n        let _module = EmptyModule::<TestBackend>::new();\n\n        assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/id.rs",
    "content": "use core::hash::{BuildHasher, Hasher};\n\nuse alloc::string::String;\nuse burn_std::id::IdGenerator;\nuse data_encoding::BASE32_DNSSEC;\n\n// Hashbrown changed its default hasher in 0.15, but there are some issues\n// https://github.com/rust-lang/hashbrown/issues/577\n// Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher.\ntype DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;\n\n/// Parameter ID.\n#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]\npub struct ParamId {\n    value: u64,\n}\n\nimpl From<u64> for ParamId {\n    fn from(value: u64) -> Self {\n        Self { value }\n    }\n}\n\nimpl Default for ParamId {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl ParamId {\n    /// Create a new parameter ID.\n    pub fn new() -> Self {\n        Self {\n            value: IdGenerator::generate(),\n        }\n    }\n\n    /// Gets the internal value of the id.\n    pub fn val(&self) -> u64 {\n        self.value\n    }\n\n    /// Convert the parameter ID into a string.\n    pub fn serialize(self) -> String {\n        BASE32_DNSSEC.encode(&self.value.to_le_bytes())\n    }\n\n    /// Deserialize a param id.\n    ///\n    /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).\n    pub fn deserialize(encoded: &str) -> ParamId {\n        let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {\n            Ok(bytes) => {\n                let mut buffer = [0u8; 8];\n                buffer[..bytes.len()].copy_from_slice(&bytes);\n                u64::from_le_bytes(buffer)\n            }\n            Err(err) => match uuid::Uuid::try_parse(encoded) {\n                // Backward compatibility with uuid parameter identifiers\n                Ok(id) => {\n                    // Hash the 128-bit uuid to 64-bit\n                    // Though not *theoretically* unique, the probability of a collision should be extremely low\n                    let mut hasher = DefaultHashBuilder::default().build_hasher();\n                    // let mut hasher = DefaultHasher::new();\n                    hasher.write(id.as_bytes());\n                    hasher.finish()\n                }\n                Err(_) => panic!(\"Invalid id. {err}\"),\n            },\n        };\n\n        ParamId::from(u64_id)\n    }\n}\n\nimpl core::fmt::Display for ParamId {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(&self.serialize())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn param_serde_deserialize() {\n        let val = ParamId::from(123456u64);\n        let deserialized = ParamId::deserialize(&val.serialize());\n        assert_eq!(val, deserialized);\n    }\n\n    #[test]\n    fn param_serde_deserialize_legacy() {\n        let legacy_val = [45u8; 6];\n        let param_id = ParamId::deserialize(&BASE32_DNSSEC.encode(&legacy_val));\n        assert_eq!(param_id.val().to_le_bytes()[0..6], legacy_val);\n        assert_eq!(param_id.val().to_le_bytes()[6..], [0, 0]);\n    }\n\n    #[test]\n    fn param_serde_deserialize_legacy_uuid() {\n        // Ensure support for legacy uuid deserialization and make sure it results in the same output\n        let legacy_id = \"30b82c23-788d-4d63-a743-ada258d5f13c\";\n        let param_id1 = ParamId::deserialize(legacy_id);\n        let param_id2 = ParamId::deserialize(legacy_id);\n        assert_eq!(param_id1, param_id2);\n    }\n\n    #[test]\n    #[should_panic = \"Invalid id.\"]\n    fn param_serde_deserialize_invalid_id() {\n        let invalid_uuid = \"30b82c23-788d-4d63-ada258d5f13c\";\n        let _ = ParamId::deserialize(invalid_uuid);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/mod.rs",
    "content": "mod base;\nmod constant;\nmod id;\nmod primitive;\nmod running;\nmod tensor;\nmod visitor;\n\npub use base::*;\npub use constant::*;\npub use id::*;\npub use running::*;\npub use visitor::*;\n"
  },
  {
    "path": "crates/burn-core/src/module/param/primitive.rs",
    "content": "use crate::module::{\n    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,\n    ModuleVisitor,\n};\n\nuse alloc::{format, string::ToString, vec::Vec};\n\nuse burn_tensor::{\n    backend::{AutodiffBackend, Backend},\n    ops::Device,\n};\nuse core::fmt::Debug;\n\nimpl<T, B> Module<B> for Option<T>\nwhere\n    T: Module<B> + Debug + Send + Clone,\n    B: Backend,\n{\n    type Record = Option<T::Record>;\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        if let Some(module) = self {\n            module.visit(visitor)\n        }\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        self.map(|module| module.map(mapper))\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        let is_constant = self.num_params() == 0;\n\n        if is_constant {\n            return self;\n        }\n\n        self.zip(record)\n            .map(|(module, record)| module.load_record(record))\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.map(Module::into_record)\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.map(|module| module.to_device(device))\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.map(|module| module.fork(device))\n    }\n\n    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {\n        if let Some(module) = self.as_ref() {\n            devices = module.collect_devices(devices);\n        }\n\n        devices\n    }\n}\n\nimpl<T: ModuleDisplay> ModuleDisplayDefault for Option<T> {\n    fn content(&self, content: Content) -> Option<Content> {\n        match self {\n            Some(module) => content.add_single(module).optional(),\n            None => content.add_single(\"None\").optional(),\n        }\n    }\n}\n\nimpl<T: ModuleDisplay> ModuleDisplay for Option<T> {}\n\nimpl<T, B> AutodiffModule<B> for Option<T>\nwhere\n    T: AutodiffModule<B> + Debug + Send + Clone,\n    B: AutodiffBackend,\n{\n    type InnerModule = Option<T::InnerModule>;\n\n    fn valid(&self) -> Self::InnerModule {\n        self.as_ref().map(|module| module.valid())\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        module.map(|module| T::from_inner(module))\n    }\n}\n\nimpl<T, B> Module<B> for Vec<T>\nwhere\n    T: Module<B> + Debug + Send + Clone,\n    B: Backend,\n{\n    type Record = Vec<T::Record>;\n\n    fn num_params(&self) -> usize {\n        let mut num_params = 0;\n        for module in self.iter() {\n            num_params += module.num_params();\n        }\n\n        num_params\n    }\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        for (i, module) in self.iter().enumerate() {\n            let index_str = alloc::format!(\"{}\", i);\n            visitor.enter_module(&index_str, \"Vec\");\n            module.visit(visitor);\n            visitor.exit_module(&index_str, \"Vec\");\n        }\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        self.into_iter()\n            .enumerate()\n            .map(|(i, module)| {\n                let index_str = alloc::format!(\"{}\", i);\n                mapper.enter_module(&index_str, \"Vec\");\n                let mapped = module.map(mapper);\n                mapper.exit_module(&index_str, \"Vec\");\n                mapped\n            })\n            .collect()\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.into_iter().map(Module::into_record).collect()\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        assert_eq!(\n            self.len(),\n            record.len(),\n            r#\"[Load Record Error] The vec record does not the same length as the module.\n            Make sure you module initialization is compatible with the record being loaded.\n            \"#,\n        );\n\n        self.into_iter()\n            .zip(record)\n            .map(|(module, record)| module.load_record(record))\n            .collect()\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.into_iter()\n            .map(|module| module.to_device(device))\n            .collect()\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.into_iter().map(|module| module.fork(device)).collect()\n    }\n\n    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {\n        for module in self.iter() {\n            devices = module.collect_devices(devices);\n        }\n\n        devices\n    }\n}\n\nimpl<T: ModuleDisplay> ModuleDisplayDefault for Vec<T> {\n    fn content(&self, content: Content) -> Option<Content> {\n        self.iter()\n            .enumerate()\n            .fold(content, |acc, (i, module)| {\n                let index = format!(\"{i}\");\n                acc.add(&index, module)\n            })\n            .set_top_level_type(format!(\"Vec<0..{}>\", self.len()).as_str())\n            .optional()\n    }\n}\n\nimpl<T: ModuleDisplay> ModuleDisplay for Vec<T> {}\n\nimpl<T, B> AutodiffModule<B> for Vec<T>\nwhere\n    T: AutodiffModule<B> + Debug + Send + Clone,\n    B: AutodiffBackend,\n{\n    type InnerModule = Vec<T::InnerModule>;\n\n    fn valid(&self) -> Self::InnerModule {\n        self.iter().map(|module| module.valid()).collect()\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        module\n            .into_iter()\n            .map(|module| T::from_inner(module))\n            .collect()\n    }\n}\n\nimpl<const N: usize, T, B> Module<B> for [T; N]\nwhere\n    T: Module<B> + Debug + Send + Clone,\n    B: Backend,\n{\n    type Record = [T::Record; N];\n\n    fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {\n        for module in self.iter() {\n            devices = module.collect_devices(devices);\n        }\n\n        devices\n    }\n\n    fn num_params(&self) -> usize {\n        let mut num_params = 0;\n        for module in self.iter() {\n            num_params += module.num_params();\n        }\n\n        num_params\n    }\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        for (i, module) in self.iter().enumerate() {\n            let index_str = alloc::format!(\"{}\", i);\n            visitor.enter_module(&index_str, \"Array\");\n            module.visit(visitor);\n            visitor.exit_module(&index_str, \"Array\");\n        }\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        let mut result = Vec::with_capacity(N);\n        for (i, module) in IntoIterator::into_iter(self).enumerate() {\n            let index_str = alloc::format!(\"{}\", i);\n            mapper.enter_module(&index_str, \"Array\");\n            let mapped = module.map(mapper);\n            mapper.exit_module(&index_str, \"Array\");\n            result.push(mapped);\n        }\n        result\n            .try_into()\n            .unwrap_or_else(|v: Vec<T>| panic!(\"Expected array of length {}, got {}\", N, v.len()))\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        self.into_iter()\n            .zip(record)\n            .map(|(module, record)| module.load_record(record))\n            .collect::<Vec<_>>()\n            .try_into()\n            .unwrap()\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.map(Module::into_record)\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.map(|module| module.to_device(device))\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.map(|module| module.fork(device))\n    }\n}\n\nimpl<const N: usize, T: ModuleDisplay> ModuleDisplayDefault for [T; N] {\n    fn content(&self, content: Content) -> Option<Content> {\n        self.iter()\n            .enumerate()\n            .fold(content, |acc, (i, module)| {\n                let index = format!(\"{i}\");\n                acc.add(&index, module)\n            })\n            .set_top_level_type(format!(\"[0..{}]\", self.len()).as_str())\n            .optional()\n    }\n}\n\nimpl<const N: usize, T: ModuleDisplay> ModuleDisplay for [T; N] {}\n\nimpl<const N: usize, T, B> AutodiffModule<B> for [T; N]\nwhere\n    T: AutodiffModule<B> + Debug + Send + Clone,\n    T::InnerModule: Debug,\n    B: AutodiffBackend,\n{\n    type InnerModule = [T::InnerModule; N];\n\n    fn valid(&self) -> Self::InnerModule {\n        self.clone().map(|module| module.valid())\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        module.map(|module| T::from_inner(module))\n    }\n}\n\n/// A macro for generating implementations for tuple modules of different sizes.\n/// For example: `impl_module_tuple!([L0, L1][0, 1])`.\n/// Would generate an implementation for a tuple of size 2.\n/// For this macro to work properly, please adhere to the convention:\n/// `impl_module_tuple!([L0, L1, ..., Ln][0, 1, ..., n])`.\nmacro_rules! impl_module_tuple {\n    // `$l` represents the generic modules.\n    // `$i` represents the indices of the modules in the tuple.\n    ([$($l:ident),*][$($i:tt),*]) => {\n        impl<B, $($l,)*> Module<B> for ($($l,)*)\n        where\n            B: Backend,\n            $($l: Module<B> + Debug + Send + Clone,)*\n        {\n            type Record = ($($l::Record),*);\n\n            fn collect_devices(&self, mut devices: Vec<B::Device>) -> Vec<B::Device> {\n                $(devices = self.$i.collect_devices(devices);)*\n                devices\n            }\n\n            fn fork(self, device: &Device<B>) -> Self {\n                ($(self.$i.fork(device),)*)\n            }\n\n            fn to_device(self, device: &Device<B>) -> Self {\n                ($(self.$i.to_device(device),)*)\n            }\n\n            fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n                $(\n                    let index_str = $i.to_string();\n                    visitor.enter_module(&index_str, \"Tuple\");\n                    self.$i.visit(visitor);\n                    visitor.exit_module(&index_str, \"Tuple\");\n                )*\n            }\n\n            fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n                ($(\n                    {\n                        let index_str = $i.to_string();\n                        mapper.enter_module(&index_str, \"Tuple\");\n                        let mapped = self.$i.map(mapper);\n                        mapper.exit_module(&index_str, \"Tuple\");\n                        mapped\n                    }\n                ,)*)\n            }\n\n            fn load_record(self, record: Self::Record) -> Self {\n                ($(self.$i.load_record(record.$i),)*)\n            }\n\n            fn into_record(self) -> Self::Record {\n                ($(self.$i.into_record(),)*)\n            }\n        }\n\n        impl<B, $($l,)*> AutodiffModule<B> for ($($l,)*)\n        where\n            B: AutodiffBackend,\n            $($l: AutodiffModule<B> + Debug + Send + Clone,)*\n        {\n            type InnerModule = ($($l::InnerModule,)*);\n\n            fn valid(&self) -> Self::InnerModule {\n                ($(self.$i.valid(),)*)\n            }\n\n            fn from_inner(module: Self::InnerModule) -> Self {\n                ($($l::from_inner(module.$i),)*)\n            }\n        }\n\n        impl<$($l,)*> ModuleDisplayDefault for ($($l,)*)\n        where\n            $($l: ModuleDisplay,)*\n        {\n            fn content(&self, content: Content) -> Option<Content> {\n                let content = content\n                    $(.add(&format!(\"{}\", $i), &self.$i))*\n                    .set_top_level_type(format!(\"({})\", stringify!($($l),*)).as_str());\n                content.optional()\n            }\n        }\n\n        impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {}\n\n    };\n}\n\nimpl_module_tuple!([L0, L1][0, 1]);\nimpl_module_tuple!([L0, L1, L2][0, 1, 2]);\nimpl_module_tuple!([L0, L1, L2, L3][0, 1, 2, 3]);\nimpl_module_tuple!([L0, L1, L2, L3, L4][0, 1, 2, 3, 4]);\nimpl_module_tuple!([L0, L1, L2, L3, L4, L5][0, 1, 2, 3, 4, 5]);\nimpl_module_tuple!([L0, L1, L2, L3, L4, L5, L6][0, 1, 2, 3, 4, 5, 6]);\nimpl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7][0, 1, 2, 3, 4, 5, 6, 7]);\nimpl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8][0, 1, 2, 3, 4, 5, 6, 7, 8]);\nimpl_module_tuple!([L0, L1, L2, L3, L4, L5, L6, L7, L8, L9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn dont_override_constant_module_when_loading_record() {\n        let module = Some(42);\n\n        let record = Module::<TestBackend>::into_record(module);\n        let loaded = Module::<TestBackend>::load_record(module, record);\n\n        assert_eq!(loaded, module);\n    }\n    #[test]\n    fn dont_override_constant_module_when_loading_none_record() {\n        let module = Some(42);\n\n        let record = None;\n        let loaded = Module::<TestBackend>::load_record(module, record);\n\n        assert_eq!(loaded, module);\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/running.rs",
    "content": "use super::ParamId;\nuse crate::module::{\n    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,\n    ModuleVisitor, Param,\n};\n\nuse alloc::string::ToString;\nuse alloc::vec::Vec;\n\n#[cfg(target_has_atomic = \"ptr\")]\nuse alloc::sync::Arc;\n\n#[cfg(not(target_has_atomic = \"ptr\"))]\nuse portable_atomic_util::Arc;\n\nuse burn_std::stub::Mutex;\nuse burn_tensor::{\n    Tensor,\n    backend::{AutodiffBackend, Backend},\n    ops::Device,\n};\n\n#[cfg(feature = \"std\")]\nmod threading {\n    pub(super) use std::collections::HashMap;\n    pub(super) use std::thread::ThreadId;\n\n    #[inline(always)]\n    pub(super) fn get_thread_current_id() -> ThreadId {\n        std::thread::current().id()\n    }\n}\n\n#[cfg(not(feature = \"std\"))]\nmod threading {\n    pub(super) use burn_std::stub::ThreadId;\n    pub(super) use hashbrown::HashMap;\n\n    #[inline(always)]\n    pub(super) fn get_thread_current_id() -> ThreadId {\n        panic!(\"Current thread id is not available\")\n    }\n}\n\n// Re-export items from the disabled/enabled blocks\nuse threading::*;\n\n/// A state that can be updated during the forward pass while being thread safe.\n///\n/// # Note\n///\n/// The state value is the average of all updates on all threads.\n#[derive(Clone, Debug)]\npub struct RunningState<V> {\n    id: ParamId,\n    values: Arc<Mutex<HashMap<ThreadId, V>>>,\n    value: Arc<Mutex<V>>,\n}\n\n// Implement display for the module\n\nimpl<V> core::fmt::Display for RunningState<V> {\n    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {\n        write!(f, \"RunningState(id={})\", self.id)\n    }\n}\n\nimpl<V> ModuleDisplayDefault for RunningState<V> {\n    fn content(&self, content: Content) -> Option<Content> {\n        content\n            .add_formatted(&\"RunningState\".to_string())\n            .optional()\n    }\n}\n\nimpl<V> ModuleDisplay for RunningState<V> {}\n\nimpl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {\n    type Record = Param<Tensor<B, D>>;\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        let tensor = self.value.lock().unwrap();\n        let param = Param::initialized(self.id, tensor.clone());\n        visitor.visit_float(&param)\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        let mut tensor = self.value.lock().unwrap();\n        let param = Param::initialized(self.id, tensor.clone());\n        let param_out = mapper.map_float(param);\n        let (_, tensor_out, _) = param_out.consume();\n\n        *tensor = tensor_out;\n        core::mem::drop(tensor);\n\n        self\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.sync();\n        let tensor = self.value.lock().unwrap();\n\n        Param::initialized(self.id, tensor.clone())\n    }\n\n    fn load_record(mut self, record: Self::Record) -> Self {\n        let mut tensor = self.value.lock().unwrap();\n        *tensor = record.val().to_device(&tensor.device());\n        self.id = record.id;\n\n        core::mem::drop(tensor);\n\n        self\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        let mut tensor = self.value.lock().unwrap();\n        let tensor_out = tensor.clone().to_device(device);\n\n        *tensor = tensor_out;\n        core::mem::drop(tensor);\n\n        self\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.to_device(device) // Same thing here since no grad.\n    }\n\n    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {\n        let device = self.value.lock().unwrap().device();\n\n        if !devices.contains(&device) {\n            devices.push(device)\n        }\n\n        devices\n    }\n}\n\nimpl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {\n    /// Create a new running state.\n    pub fn new(value: Tensor<B, D>) -> Self {\n        Self {\n            id: ParamId::new(),\n            values: Arc::new(Mutex::new(HashMap::new())),\n            value: Arc::new(Mutex::new(value)),\n        }\n    }\n\n    /// Create a new running state.\n    pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {\n        Self {\n            id,\n            values: Arc::new(Mutex::new(HashMap::new())),\n            value: Arc::new(Mutex::new(value)),\n        }\n    }\n\n    /// Create a new running state from a record.\n    pub fn from_record(record: Param<Tensor<B, D>>) -> Self {\n        let tensor = record.val();\n        Self {\n            id: record.id,\n            values: Arc::new(Mutex::new(HashMap::new())),\n            value: Arc::new(Mutex::new(tensor)),\n        }\n    }\n\n    /// Update the value on the current thread.\n    pub fn update(&self, value: Tensor<B, D>) {\n        let thread_id = get_thread_current_id();\n        let mut map = self.values.lock().unwrap();\n\n        if map.contains_key(&thread_id) {\n            self.update_value(&mut map);\n        }\n\n        map.insert(thread_id, value);\n    }\n\n    /// Get the current value,\n    ///\n    /// # Note\n    ///\n    /// The current value might be outdated by one update.\n    pub fn value(&self) -> Tensor<B, D> {\n        let value = self.value.lock().unwrap();\n        value.clone()\n    }\n\n    /// Get the current value and make sure it is sync.\n    ///\n    /// # Note\n    ///\n    /// Don't use this function after an update on the same thread where other threads might have to\n    /// register their update before the actual synchronization needs to happen.\n    pub fn value_sync(&self) -> Tensor<B, D> {\n        let thread_id = get_thread_current_id();\n        let mut map = self.values.lock().unwrap();\n\n        if map.contains_key(&thread_id) {\n            self.update_value(&mut map);\n        }\n\n        let value = self.value.lock().unwrap();\n        value.clone()\n    }\n\n    fn sync(&self) {\n        let mut map = self.values.lock().unwrap();\n\n        if !map.is_empty() {\n            self.update_value(&mut map);\n        }\n    }\n\n    fn update_value(&self, map: &mut HashMap<ThreadId, Tensor<B, D>>) {\n        let mut value_updated: Option<Tensor<B, D>> = None;\n        let mut counter = 0;\n\n        for (_key, tensor) in map.drain() {\n            counter += 1;\n\n            value_updated = match value_updated {\n                Some(current) => {\n                    let device = current.device();\n                    Some(tensor.to_device(&device).add(current))\n                }\n                None => Some(tensor),\n            };\n        }\n\n        if let Some(value) = value_updated {\n            let value = value.div_scalar(counter);\n            let mut value_old = self.value.lock().unwrap();\n            *value_old = value;\n        }\n    }\n}\n\nimpl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {\n    type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;\n\n    fn valid(&self) -> Self::InnerModule {\n        self.sync();\n        let value = self.value();\n\n        RunningState::with_id(self.id, value.inner())\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        module.sync();\n        let value = module.value();\n\n        RunningState::with_id(module.id, Tensor::from_inner(value))\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/tensor.rs",
    "content": "use super::{Param, ParamId, Parameter};\nuse crate::module::{\n    AutodiffModule, Content, HasAutodiffModule, Module, ModuleDisplay, ModuleDisplayDefault,\n    ModuleMapper, ModuleVisitor,\n};\nuse crate::tensor::{\n    Tensor,\n    backend::{AutodiffBackend, Backend},\n};\nuse alloc::{format, string::ToString, vec::Vec};\nuse burn_tensor::{Bool, Float, Int, TensorData, ops::Device};\n\nimpl<B: Backend, const D: usize> Parameter for Tensor<B, D, Float> {\n    type Device = B::Device;\n\n    fn device(&self) -> Self::Device {\n        Tensor::device(self)\n    }\n\n    fn is_require_grad(&self) -> bool {\n        Tensor::is_require_grad(self)\n    }\n\n    fn set_require_grad(self, require_grad: bool) -> Self {\n        Tensor::set_require_grad(self, require_grad)\n    }\n}\n\nimpl<B: Backend, const D: usize> Parameter for Tensor<B, D, Int> {\n    type Device = B::Device;\n\n    fn device(&self) -> Self::Device {\n        Tensor::device(self)\n    }\n\n    fn is_require_grad(&self) -> bool {\n        false\n    }\n\n    fn set_require_grad(self, _require_grad: bool) -> Self {\n        self\n    }\n}\n\nimpl<B: Backend, const D: usize> Parameter for Tensor<B, D, Bool> {\n    type Device = B::Device;\n\n    fn device(&self) -> Self::Device {\n        Tensor::device(self)\n    }\n\n    fn is_require_grad(&self) -> bool {\n        false\n    }\n\n    fn set_require_grad(self, _require_grad: bool) -> Self {\n        self\n    }\n}\n\nimpl<B: Backend, const D: usize> Param<Tensor<B, D>> {\n    /// Create a new parameter from a float tensor.\n    ///\n    /// # Warnings\n    ///\n    /// We strongly recommend using [Param::uninitialized] if you are using this method to\n    /// initialize parameters inside a module, since the tensor initialization will be lazy,\n    /// making the loading of weights more performant.\n    pub fn from_tensor(value: Tensor<B, D>) -> Self {\n        // When creating a parameter from a float tensor, we automatically mark it as requiring\n        // gradients, so that it can be updated by an optimizer.\n        Param::initialized(ParamId::new(), value.require_grad())\n    }\n\n    /// The shape of the parameter, **without triggering initialization**.\n    ///\n    /// This is critical for shape validation during loading: when applying tensors to an\n    /// uninitialized parameter, we need to validate the shape without triggering the\n    /// initialization function (which would allocate an unnecessary tensor).\n    ///\n    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to\n    /// preserve lazy initialization.\n    pub fn lazy_shape(&self) -> burn_tensor::Shape {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.shape(),\n        };\n\n        let init = initialization.read().unwrap();\n\n        match init.as_ref() {\n            Some(value) => value.shape.clone(),\n            None => self.shape(),\n        }\n    }\n\n    /// Create a new parameter from data.\n    pub fn from_data<T>(data: T, device: &B::Device) -> Self\n    where\n        T: Into<TensorData>,\n    {\n        let data: TensorData = data.into();\n        // When creating a parameter from a float tensor, we automatically mark it as requiring\n        // gradients, so that it can be updated by an optimizer.\n        B::memory_persistent_allocations(device, data, |data| {\n            let value = Tensor::from_data(data, device);\n            Param::initialized(ParamId::new(), value.require_grad())\n        })\n    }\n\n    /// Transform a parameter for loading by applying load transformations.\n    ///\n    /// This method is used to restore a parameter from a tensor (typically during deserialization).\n    /// It ensures the tensor is moved to the expected device, applies the param mapper's\n    /// `on_load` transformation, and preserves the autodiff settings (require_grad).\n    pub fn transform_for_load(self, tensor: Tensor<B, D>, param_id: ParamId) -> Self {\n        let mut new_tensor = tensor;\n\n        let mapper = self.param_mapper.clone();\n\n        let expected_device = self.lazy_device();\n        let expected_require_grad = self.lazy_is_require_grad();\n\n        // Make sure we load the tensor into the same module device.\n        if new_tensor.device() != expected_device {\n            new_tensor = new_tensor.to_device(&expected_device).detach();\n        }\n\n        new_tensor = mapper.on_load(new_tensor);\n\n        // Make sure we load the tensor with the same autodiff setting.\n        new_tensor = new_tensor.set_require_grad(expected_require_grad);\n\n        let mut loaded = Self::initialized(param_id, new_tensor);\n        loaded.param_mapper = mapper;\n        loaded\n    }\n\n    /// Transform a parameter for saving by applying save transformations.\n    ///\n    /// This method is used to prepare a parameter for saving (typically during serialization).\n    /// It applies the param mapper's `on_save` transformation, which can be used\n    /// to modify the tensor before serialization (e.g., quantization, precision conversion).\n    pub fn transform_for_save(&self) -> Self {\n        let mut tensor = self.val();\n        let mapper = self.param_mapper.clone();\n\n        tensor = mapper.on_save(tensor);\n\n        Self::initialized(self.id, tensor)\n    }\n}\n\nimpl<B: Backend, const D: usize> Param<Tensor<B, D, Int>> {\n    /// The shape of the parameter, **without triggering initialization**.\n    ///\n    /// This is critical for shape validation during loading: when applying tensors to an\n    /// uninitialized parameter, we need to validate the shape without triggering the\n    /// initialization function (which would allocate an unnecessary tensor).\n    ///\n    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to\n    /// preserve lazy initialization.\n    pub fn lazy_shape(&self) -> burn_tensor::Shape {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.shape(),\n        };\n\n        let init = initialization.read().unwrap();\n\n        match init.as_ref() {\n            Some(value) => value.shape.clone(),\n            None => self.shape(),\n        }\n    }\n\n    /// Transform a parameter for loading by applying load transformations.\n    ///\n    /// This method is used to restore a parameter from a tensor (typically during deserialization).\n    /// It ensures the tensor is moved to the expected device and applies the param mapper's\n    /// `on_load` transformation.\n    pub fn transform_for_load(self, tensor: Tensor<B, D, Int>, param_id: ParamId) -> Self {\n        let mut new_tensor = tensor;\n\n        let mapper = self.param_mapper.clone();\n\n        let expected_device = self.lazy_device();\n\n        // Make sure we load the tensor into the same module device.\n        if new_tensor.device() != expected_device {\n            new_tensor = new_tensor.to_device(&expected_device);\n        }\n\n        new_tensor = mapper.on_load(new_tensor);\n\n        let mut loaded = Self::initialized(param_id, new_tensor);\n        loaded.param_mapper = mapper;\n        loaded\n    }\n\n    /// Transform a parameter for saving by applying save transformations.\n    ///\n    /// This method is used to prepare a parameter for saving (typically during serialization).\n    /// It applies the param mapper's `on_save` transformation, which can be used\n    /// to modify the tensor before serialization (e.g., quantization, precision conversion).\n    pub fn transform_for_save(&self) -> Self {\n        let mut tensor = self.val();\n        let mapper = self.param_mapper.clone();\n\n        tensor = mapper.on_save(tensor);\n\n        Self::initialized(self.id, tensor)\n    }\n}\n\nimpl<B: Backend, const D: usize> Param<Tensor<B, D, Bool>> {\n    /// The shape of the parameter, **without triggering initialization**.\n    ///\n    /// This is critical for shape validation during loading: when applying tensors to an\n    /// uninitialized parameter, we need to validate the shape without triggering the\n    /// initialization function (which would allocate an unnecessary tensor).\n    ///\n    /// **Returns:**\n    /// - For uninitialized params: the shape from the `Uninitialized` struct\n    /// - For initialized params: the actual shape from the tensor\n    ///\n    /// Use this instead of [crate::tensor::Tensor::shape] when you need the shape but want to\n    /// preserve lazy initialization.\n    pub fn lazy_shape(&self) -> burn_tensor::Shape {\n        let initialization = match &self.initialization {\n            Some(init) => init,\n            None => return self.shape(),\n        };\n\n        let init = initialization.read().unwrap();\n\n        match init.as_ref() {\n            Some(value) => value.shape.clone(),\n            None => self.shape(),\n        }\n    }\n\n    /// Transform a parameter for loading by applying load transformations.\n    ///\n    /// This method is used to restore a parameter from a tensor (typically during deserialization).\n    /// It ensures the tensor is moved to the expected device and applies the param mapper's\n    /// `on_load` transformation.\n    pub fn transform_for_load(self, tensor: Tensor<B, D, Bool>, param_id: ParamId) -> Self {\n        let mut new_tensor = tensor;\n\n        let mapper = self.param_mapper.clone();\n\n        let expected_device = self.lazy_device();\n\n        // Make sure we load the tensor into the same module device.\n        if new_tensor.device() != expected_device {\n            new_tensor = new_tensor.to_device(&expected_device);\n        }\n\n        new_tensor = mapper.on_load(new_tensor);\n\n        let mut loaded = Self::initialized(param_id, new_tensor);\n        loaded.param_mapper = mapper;\n        loaded\n    }\n\n    /// Transform a parameter for saving by applying save transformations.\n    ///\n    /// This method is used to prepare a parameter for saving (typically during serialization).\n    /// It applies the param mapper's `on_save` transformation, which can be used\n    /// to modify the tensor before serialization (e.g., quantization, precision conversion).\n    pub fn transform_for_save(&self) -> Self {\n        let mut tensor = self.val();\n        let mapper = self.param_mapper.clone();\n\n        tensor = mapper.on_save(tensor);\n\n        Self::initialized(self.id, tensor)\n    }\n}\n\nimpl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {\n    type Record = Param<Tensor<B, D>>;\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        visitor.visit_float(self)\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        mapper.map_float(self)\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.transform_for_save()\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        let (record_param_id, record_tensor, _) = record.consume();\n        self.transform_for_load(record_tensor, record_param_id)\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.map(|tensor| tensor.to_device(device))\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.map(|tensor| {\n            let is_require_grad = tensor.is_require_grad();\n            let mut tensor = tensor.to_device(device).detach();\n\n            if is_require_grad {\n                tensor = tensor.require_grad();\n            }\n\n            tensor\n        })\n    }\n\n    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {\n        let device = self.val().device();\n\n        if !devices.contains(&device) {\n            devices.push(device)\n        }\n\n        devices\n    }\n}\n\nimpl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D>> {\n    fn content(&self, content: Content) -> Option<Content> {\n        let id = if content.display_settings.show_param_id() {\n            format!(\", id: {}\", self.id)\n        } else {\n            \"\".to_string()\n        };\n        let string = format!(\n            \"ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}\",\n            self.shape().as_slice()\n        );\n        content.add_formatted(&string).optional()\n    }\n}\nimpl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D>> {}\n\nimpl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Int>> {\n    type Record = Param<Tensor<B, D, Int>>;\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        visitor.visit_int(self)\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        mapper.map_int(self)\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.transform_for_save()\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        let (record_param_id, record_tensor, _) = record.consume();\n        self.transform_for_load(record_tensor, record_param_id)\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.map(|tensor| tensor.to_device(device))\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.to_device(device) // Don't support autodiff.\n    }\n\n    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {\n        let device = self.val().device();\n\n        if !devices.contains(&device) {\n            devices.push(device)\n        }\n\n        devices\n    }\n}\n\nimpl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Int>> {\n    fn content(&self, content: Content) -> Option<Content> {\n        let id = if content.display_settings.show_param_id() {\n            format!(\", id: {}\", self.id)\n        } else {\n            \"\".to_string()\n        };\n        let string = format!(\n            \"ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}\",\n            self.shape().as_slice()\n        );\n        content.add_formatted(&string).optional()\n    }\n}\nimpl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Int>> {}\n\nimpl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D, Bool>> {\n    type Record = Param<Tensor<B, D, Bool>>;\n\n    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {\n        visitor.visit_bool(self)\n    }\n\n    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {\n        mapper.map_bool(self)\n    }\n\n    fn into_record(self) -> Self::Record {\n        self.transform_for_save()\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        let (record_param_id, record_tensor, _) = record.consume();\n        self.transform_for_load(record_tensor, record_param_id)\n    }\n\n    fn to_device(self, device: &Device<B>) -> Self {\n        self.map(|tensor| tensor.to_device(device))\n    }\n\n    fn fork(self, device: &Device<B>) -> Self {\n        self.to_device(device) // Don't support autodiff.\n    }\n\n    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {\n        let device = self.val().device();\n\n        if !devices.contains(&device) {\n            devices.push(device)\n        }\n\n        devices\n    }\n}\n\nimpl<const D: usize, B: Backend> ModuleDisplayDefault for Param<Tensor<B, D, Bool>> {\n    fn content(&self, content: Content) -> Option<Content> {\n        let id = if content.display_settings.show_param_id() {\n            format!(\", id: {}\", self.id)\n        } else {\n            \"\".to_string()\n        };\n\n        let string = format!(\n            \"ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}\",\n            self.shape().as_slice()\n        );\n        content.add_formatted(&string).optional()\n    }\n}\n\nimpl<const D: usize, B: Backend> ModuleDisplay for Param<Tensor<B, D, Bool>> {}\n\nimpl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D>> {\n    type InnerModule = Param<Tensor<B::InnerBackend, D>>;\n\n    fn valid(&self) -> Self::InnerModule {\n        // Preserve initialized param `require_grad` state, but reset the inner value's\n        let require_grad = self.require_grad;\n        let mut param = Param::initialized(self.id, self.val().inner().set_require_grad(false));\n        param.require_grad = require_grad;\n        param\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        // Reinstate the param's `require_grad` state\n        let tensor = Tensor::from_inner(module.val()).set_require_grad(module.require_grad);\n        Param::initialized(module.id, tensor)\n    }\n}\n\nimpl<const D: usize, B: AutodiffBackend> HasAutodiffModule<B>\n    for Param<Tensor<B::InnerBackend, D>>\n{\n    type TrainModule = Param<Tensor<B, D>>;\n}\n\nimpl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Int>> {\n    type InnerModule = Param<Tensor<B::InnerBackend, D, Int>>;\n\n    fn valid(&self) -> Self::InnerModule {\n        Param::initialized(self.id, self.val().inner())\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        Param::initialized(module.id, Tensor::from_inner(module.val()))\n    }\n}\n\nimpl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for Param<Tensor<B, D, Bool>> {\n    type InnerModule = Param<Tensor<B::InnerBackend, D, Bool>>;\n\n    fn valid(&self) -> Self::InnerModule {\n        Param::initialized(self.id, self.val().inner())\n    }\n\n    fn from_inner(module: Self::InnerModule) -> Self {\n        Param::initialized(module.id, Tensor::from_inner(module.val()))\n    }\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use super::*;\n    use crate::{\n        TestAutodiffBackend,\n        module::Module,\n        record::{BinBytesRecorder, FullPrecisionSettings, Recorder},\n    };\n\n    #[test]\n    fn test_load_record_setting() {\n        let device = Default::default();\n        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();\n\n        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();\n        let bytes = byte_recorder\n            .record(\n                Param::initialized(ParamId::new(), tensor.clone()).into_record(),\n                (),\n            )\n            .unwrap();\n\n        let no_grad_is_require_grad = Param::initialized(ParamId::new(), tensor.clone())\n            .no_grad()\n            .load_record(byte_recorder.load(bytes.clone(), &device).unwrap())\n            .is_require_grad();\n\n        let with_default_is_require_grad = Param::initialized(ParamId::new(), tensor)\n            .load_record(byte_recorder.load(bytes, &device).unwrap())\n            .is_require_grad();\n\n        assert!(!no_grad_is_require_grad);\n        assert!(with_default_is_require_grad);\n    }\n\n    #[test]\n    fn test_param_require_grad_stateful() {\n        let device = Default::default();\n        let tensor = Tensor::<TestAutodiffBackend, 2>::ones([3, 3], &device).require_grad();\n\n        let param = Param::initialized(ParamId::new(), tensor);\n        assert!(param.is_require_grad());\n        assert!(param.require_grad);\n\n        let param = param.valid();\n        assert!(!param.is_require_grad());\n        assert!(param.require_grad); // stateful\n\n        // Without `HasAutodiffModule`, we would need to specify the param type as well, which would be annoying:\n        // let param: Param<Tensor<TestAutodiffBackend, _>> = param.train();\n        let param = param.train::<TestAutodiffBackend>();\n        assert!(param.is_require_grad());\n        assert!(param.require_grad); // stateful\n\n        let param = param.no_grad();\n        assert!(!param.is_require_grad());\n        assert!(!param.require_grad); // stateful\n\n        let param = param.valid();\n        assert!(!param.is_require_grad()); // always\n        assert!(!param.require_grad); // stateful\n\n        let param = param.train::<TestAutodiffBackend>();\n        assert!(!param.is_require_grad());\n        assert!(!param.require_grad); // stateful\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/param/visitor.rs",
    "content": "use super::{Param, ParamId};\nuse crate::module::{Module, ModuleVisitor};\nuse alloc::vec::Vec;\nuse burn_tensor::{Bool, Int, Tensor, backend::Backend};\nuse core::marker::PhantomData;\n\nstruct ParamIdCollector<'a, M> {\n    param_ids: &'a mut Vec<ParamId>,\n    phantom: PhantomData<M>,\n}\n\nimpl<B, M> ModuleVisitor<B> for ParamIdCollector<'_, M>\nwhere\n    B: Backend,\n    M: Module<B>,\n{\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        self.param_ids.push(param.id);\n    }\n    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {\n        self.param_ids.push(param.id);\n    }\n    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {\n        self.param_ids.push(param.id);\n    }\n}\n\n/// List all the parameter ids in a module.\npub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {\n    let mut params_ids = Vec::new();\n    let mut visitor = ParamIdCollector {\n        param_ids: &mut params_ids,\n        phantom: PhantomData::<M>,\n    };\n    module.visit(&mut visitor);\n\n    params_ids\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/quantize.rs",
    "content": "use burn_tensor::{\n    Tensor,\n    backend::Backend,\n    quantization::{Calibration, QuantScheme, compute_q_params, compute_range},\n};\n\nuse crate::module::{ModuleMapper, Param};\n\n/// Describes how to quantize a module.\npub struct Quantizer {\n    /// The calibration method used in quantization.\n    pub calibration: Calibration,\n    /// The quantization scheme.\n    pub scheme: QuantScheme,\n}\n\nimpl<B: Backend> ModuleMapper<B> for Quantizer {\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        let range = compute_range(&self.scheme, &tensor, &self.calibration);\n        let qparams = compute_q_params(&self.scheme, range);\n        let tensor = tensor.quantize(&self.scheme, qparams);\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n}\n\n#[cfg(all(test, not(feature = \"test-tch\")))]\nmod tests {\n    use crate::test_utils::SimpleLinear;\n    use crate::{\n        TestBackend,\n        module::{Module, Quantizer},\n    };\n    use burn_tensor::{\n        Device, Tolerance,\n        ops::QuantizedTensor,\n        quantization::{Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue},\n    };\n\n    type B = TestBackend;\n\n    #[test]\n    fn should_quantize_module() {\n        let device: Device<B> = Default::default();\n        let module = SimpleLinear::<B>::new(32, 32, &device);\n        let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n            .with_value(QuantValue::Q8S)\n            .with_level(QuantLevel::Tensor)\n            .with_param(QuantParam::F32);\n\n        let result = module.weight.val();\n\n        let calibration = Calibration::MinMax;\n        let mut quantizer = Quantizer {\n            calibration,\n            scheme,\n        };\n        let q_module = module.quantize_weights(&mut quantizer);\n        let q_result = q_module.weight.val().dequantize();\n\n        result\n            .into_data()\n            .assert_approx_eq::<f32>(&q_result.into_data(), Tolerance::permissive());\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/module/reinit.rs",
    "content": "use super::{Module, ModuleMapper};\nuse burn_tensor::{\n    Element, ElementConversion, Tensor, TensorData,\n    backend::Backend,\n    ops::{FloatElem, IntElem},\n};\nuse rand::{RngExt, SeedableRng};\n\n#[derive(Debug)]\n/// Overrides float and int tensors of [burn modules](super::Module).\n///\n/// This is useful for testing.\npub struct Reinitializer<B: Backend> {\n    float: ReinitStrategy<FloatElem<B>>,\n    int: ReinitStrategy<IntElem<B>>,\n}\n\n#[derive(Debug)]\n#[allow(missing_docs)]\nenum ReinitStrategy<E> {\n    Range { min: E, max: E },\n    Constant { value: E },\n    Random { seed: u64, min: E, max: E },\n}\n\nimpl<B: Backend> Default for Reinitializer<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> Reinitializer<B> {\n    /// Create a new [reinitializer](Reinitializer).\n    pub fn new() -> Self {\n        Self {\n            float: ReinitStrategy::Constant {\n                value: 0.elem::<FloatElem<B>>(),\n            },\n            int: ReinitStrategy::Constant {\n                value: 0.elem::<IntElem<B>>(),\n            },\n        }\n    }\n\n    /// Apply the reinitialization to the given [module](Module).\n    pub fn apply<M: Module<B>>(mut self, module: M) -> M {\n        module.map(&mut self)\n    }\n\n    /// Set the reinitialization strategy to constant for all tensors.\n    pub fn constant(self, constant: f64) -> Self {\n        self.constant_float(constant).constant_int(constant as i64)\n    }\n\n    /// Set the reinitialization strategy to constant for float tensors.\n    pub fn constant_float(mut self, constant: f64) -> Self {\n        self.float = ReinitStrategy::Constant {\n            value: constant.elem(),\n        };\n        self\n    }\n\n    /// Set the reinitialization strategy to constant for int tensors.\n    pub fn constant_int(mut self, constant: i64) -> Self {\n        self.int = ReinitStrategy::Constant {\n            value: constant.elem(),\n        };\n        self\n    }\n    /// Set the reinitialization strategy to random for all tensors.\n    pub fn random(self, seed: u64, min: f64, max: f64) -> Self {\n        self.random_float(seed, min, max)\n            .random_int(seed, min as i64, max as i64)\n    }\n\n    /// Set the reinitialization strategy to random for float tensors.\n    pub fn random_float(mut self, seed: u64, min: f64, max: f64) -> Self {\n        self.float = ReinitStrategy::Random {\n            seed,\n            min: min.elem(),\n            max: max.elem(),\n        };\n        self\n    }\n\n    /// Set the reinitialization strategy to random for int tensors.\n    pub fn random_int(mut self, seed: u64, min: i64, max: i64) -> Self {\n        self.int = ReinitStrategy::Random {\n            seed,\n            min: min.elem(),\n            max: max.elem(),\n        };\n        self\n    }\n\n    /// Set the reinitialization strategy to range for all tensors.\n    pub fn range(self, min: f64, max: f64) -> Self {\n        self.range_float(min, max).range_int(min as i64, max as i64)\n    }\n\n    /// Set the reinitialization strategy to range for float tensors.\n    pub fn range_float(mut self, min: f64, max: f64) -> Self {\n        self.float = ReinitStrategy::Range {\n            min: min.elem(),\n            max: max.elem(),\n        };\n        self\n    }\n\n    /// Set the reinitialization strategy to range for int tensors.\n    pub fn range_int(mut self, min: i64, max: i64) -> Self {\n        self.int = ReinitStrategy::Range {\n            min: min.elem(),\n            max: max.elem(),\n        };\n        self\n    }\n}\n\nimpl<B: Backend> ModuleMapper<B> for Reinitializer<B> {\n    fn map_float<const D: usize>(\n        &mut self,\n        param: super::Param<Tensor<B, D>>,\n    ) -> super::Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        let device = tensor.device();\n        let shape = tensor.shape();\n        let num_elements = shape.num_elements();\n\n        let tensor = match &self.float {\n            ReinitStrategy::Range { min, max } => {\n                let tensor = Tensor::arange(0..num_elements as i64, &device)\n                    .reshape(shape)\n                    .float();\n                let (factor, bias) = resolve::<FloatElem<B>>(*min, *max, num_elements);\n                tensor * factor + bias\n            }\n            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),\n            ReinitStrategy::Random { seed, min, max } => {\n                let data = TensorData::new(\n                    random_vector::<FloatElem<B>>(*seed, min.elem(), max.elem(), num_elements),\n                    shape,\n                );\n                Tensor::from_data(data, &device)\n            }\n        };\n\n        super::Param::from_mapped_value(id, tensor, mapper)\n    }\n\n    fn map_int<const D: usize>(\n        &mut self,\n        param: super::Param<Tensor<B, D, burn_tensor::Int>>,\n    ) -> super::Param<Tensor<B, D, burn_tensor::Int>> {\n        let (id, tensor, mapper) = param.consume();\n        let device = tensor.device();\n        let shape = tensor.shape();\n        let num_elements = shape.num_elements();\n\n        let tensor = match &self.int {\n            ReinitStrategy::Range { min, max } => {\n                let tensor = Tensor::arange(0..num_elements as i64, &device).reshape(shape);\n                let (factor, bias) = resolve::<IntElem<B>>(*min, *max, num_elements);\n                tensor * factor + bias\n            }\n            ReinitStrategy::Constant { value } => Tensor::full(shape, *value, &device),\n            ReinitStrategy::Random { seed, min, max } => {\n                let data = TensorData::new(\n                    random_vector::<IntElem<B>>(*seed, min.elem(), max.elem(), num_elements),\n                    shape,\n                );\n                Tensor::from_data(data, &device)\n            }\n        };\n\n        super::Param::from_mapped_value(id, tensor, mapper)\n    }\n\n    fn map_bool<const D: usize>(\n        &mut self,\n        param: super::Param<Tensor<B, D, burn_tensor::Bool>>,\n    ) -> super::Param<Tensor<B, D, burn_tensor::Bool>> {\n        let (id, tensor, mapper) = param.consume();\n        super::Param::from_mapped_value(id, tensor, mapper)\n    }\n}\n\nfn resolve<E: Element>(min: E, max: E, num_elements: usize) -> (E, E) {\n    let range = max.elem::<f64>() - min.elem::<f64>();\n    let factor = range / num_elements as f64;\n    let bias = min.elem::<f64>();\n\n    (factor.elem(), bias.elem())\n}\n\nfn random_vector<E: Element>(seed: u64, min: f64, max: f64, num_elements: usize) -> Vec<E> {\n    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);\n    let dist = rand::distr::Uniform::new(min, max).unwrap();\n    (0..num_elements)\n        .map(|_| rng.sample(dist))\n        .map(|e| e.elem::<E>())\n        .collect()\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/base.rs",
    "content": "pub use burn_derive::Record;\nuse burn_tensor::backend::Backend;\n\nuse super::PrecisionSettings;\nuse serde::{Serialize, de::DeserializeOwned};\n\n/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).\npub trait Record<B: Backend>: Send {\n    /// Type of the item that can be serialized and deserialized.\n    type Item<S: PrecisionSettings>: Serialize + DeserializeOwned + Clone;\n\n    /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;\n\n    /// Convert the given item into a record.\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self;\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/file.rs",
    "content": "use super::{PrecisionSettings, Recorder, RecorderError, bin_config};\nuse burn_tensor::backend::Backend;\nuse core::marker::PhantomData;\nuse flate2::{Compression, read::GzDecoder, write::GzEncoder};\nuse serde::{Serialize, de::DeserializeOwned};\nuse std::io::{BufReader, BufWriter};\nuse std::{fs::File, path::PathBuf};\n\n/// Recorder trait specialized to save and load data to and from files.\npub trait FileRecorder<B: Backend>:\n    Recorder<B, RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>\n{\n    /// File extension of the format used by the recorder.\n    fn file_extension() -> &'static str;\n}\n\n/// Default [file recorder](FileRecorder).\npub type DefaultFileRecorder<S> = NamedMpkFileRecorder<S>;\n\n/// File recorder using the [bincode format](bincode).\n#[derive(new, Debug, Default, Clone)]\npub struct BinFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\n/// File recorder using the [bincode format](bincode) compressed with gzip.\n#[derive(new, Debug, Default, Clone)]\npub struct BinGzFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\n/// File recorder using the [json format](serde_json) compressed with gzip.\n#[derive(new, Debug, Default, Clone)]\npub struct JsonGzFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\n/// File recorder using [pretty json format](serde_json) for easy readability.\n#[derive(new, Debug, Default, Clone)]\npub struct PrettyJsonFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\n/// File recorder using the [named msgpack](rmp_serde) format compressed with gzip.\n#[derive(new, Debug, Default, Clone)]\npub struct NamedMpkGzFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\n/// File recorder using the [named msgpack](rmp_serde) format.\n#[derive(new, Debug, Default, Clone)]\npub struct NamedMpkFileRecorder<S: PrecisionSettings> {\n    _settings: PhantomData<S>,\n}\n\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinGzFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"bin.gz\"\n    }\n}\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for BinFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"bin\"\n    }\n}\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for JsonGzFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"json.gz\"\n    }\n}\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for PrettyJsonFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"json\"\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkGzFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"mpk.gz\"\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> FileRecorder<B> for NamedMpkFileRecorder<S> {\n    fn file_extension() -> &'static str {\n        \"mpk\"\n    }\n}\n\nmacro_rules! str2reader {\n    (\n        $file:expr\n    ) => {{\n        $file.set_extension(<Self as FileRecorder<B>>::file_extension());\n        let path = $file.as_path();\n\n        File::open(path)\n            .map_err(|err| match err.kind() {\n                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),\n                _ => RecorderError::Unknown(err.to_string()),\n            })\n            .map(|file| BufReader::new(file))\n    }};\n}\n\nmacro_rules! str2writer {\n    (\n        $file:expr\n    ) => {{\n        $file.set_extension(<Self as FileRecorder<B>>::file_extension());\n        let path = $file.as_path();\n\n        log::debug!(\"Writing to file: {:?}\", path);\n\n        // Add parent directories if they don't exist\n        if let Some(parent) = path.parent() {\n            std::fs::create_dir_all(parent).ok();\n        }\n\n        if path.exists() {\n            log::warn!(\"File exists, replacing\");\n            std::fs::remove_file(path).map_err(|err| RecorderError::Unknown(err.to_string()))?;\n        }\n\n        File::create(path)\n            .map_err(|err| match err.kind() {\n                std::io::ErrorKind::NotFound => RecorderError::FileNotFound(err.to_string()),\n                _ => RecorderError::Unknown(err.to_string()),\n            })\n            .map(|file| BufWriter::new(file))\n    }};\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for BinGzFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let config = bin_config();\n        let writer = str2writer!(file)?;\n        let mut writer = GzEncoder::new(writer, Compression::default());\n\n        bincode::serde::encode_into_std_write(&item, &mut writer, config)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let reader = str2reader!(file)?;\n        let mut reader = GzDecoder::new(reader);\n        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(state)\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for BinFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let config = bin_config();\n        let mut writer = str2writer!(file)?;\n        bincode::serde::encode_into_std_write(&item, &mut writer, config)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let mut reader = str2reader!(file)?;\n        let state = bincode::serde::decode_from_std_read(&mut reader, bin_config())\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n        Ok(state)\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for JsonGzFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let writer = str2writer!(file)?;\n        let writer = GzEncoder::new(writer, Compression::default());\n        serde_json::to_writer(writer, &item)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let reader = str2reader!(file)?;\n        let reader = GzDecoder::new(reader);\n        let state = serde_json::from_reader(reader)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(state)\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for PrettyJsonFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let writer = str2writer!(file)?;\n        serde_json::to_writer_pretty(writer, &item)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let reader = str2reader!(file)?;\n        let state = serde_json::from_reader(reader)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(state)\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkGzFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let writer = str2writer!(file)?;\n        let mut writer = GzEncoder::new(writer, Compression::default());\n        rmp_serde::encode::write_named(&mut writer, &item)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let reader = str2reader!(file)?;\n        let reader = GzDecoder::new(reader);\n        let state = rmp_serde::decode::from_read(reader)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(state)\n    }\n}\n\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkFileRecorder<S> {\n    type Settings = S;\n    type RecordArgs = PathBuf;\n    type RecordOutput = ();\n    type LoadArgs = PathBuf;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        mut file: Self::RecordArgs,\n    ) -> Result<(), RecorderError> {\n        let mut writer = str2writer!(file)?;\n\n        rmp_serde::encode::write_named(&mut writer, &item)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        file: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let reader = str2reader!(file)?;\n        let state = rmp_serde::decode::from_read(reader)\n            .map_err(|err| RecorderError::Unknown(err.to_string()))?;\n\n        Ok(state)\n    }\n}\n\n#[allow(deprecated)]\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate as burn;\n    use crate::config::Config;\n    use crate::module::Ignored;\n    use crate::test_utils::SimpleLinear;\n    use crate::{\n        TestBackend,\n        module::Module,\n        record::{BinBytesRecorder, FullPrecisionSettings},\n    };\n    use burn_tensor::Tensor;\n    use burn_tensor::backend::Backend;\n\n    #[inline(always)]\n    fn file_path(file: &str) -> PathBuf {\n        std::env::temp_dir().as_path().join(file)\n    }\n\n    #[test]\n    fn test_can_save_and_load_jsongz_format() {\n        test_can_save_and_load(JsonGzFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[test]\n    fn test_can_save_and_load_bin_format() {\n        test_can_save_and_load(BinFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[test]\n    fn test_can_save_and_load_bingz_format() {\n        test_can_save_and_load(BinGzFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[test]\n    fn test_can_save_and_load_pretty_json_format() {\n        test_can_save_and_load(PrettyJsonFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[test]\n    fn test_can_save_and_load_mpkgz_format() {\n        test_can_save_and_load(NamedMpkGzFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[test]\n    fn test_can_save_and_load_mpk_format() {\n        test_can_save_and_load(NamedMpkFileRecorder::<FullPrecisionSettings>::default())\n    }\n\n    fn test_can_save_and_load<Recorder>(recorder: Recorder)\n    where\n        Recorder: FileRecorder<TestBackend>,\n    {\n        let filename = \"burn_test_file_recorder\";\n\n        let device = Default::default();\n        let mut model_before = create_model(&device);\n\n        // NOTE: Non-module fields currently act like `#[module(skip)]`, meaning their state\n        // is not persistent. These fields hold `EmptyRecord`s.\n        // So `model_bytes_after == model_bytes_before` because the changes do not persist in the record.\n        model_before.tensor = Tensor::full([4], 2., &device);\n        model_before.arr = [3, 3];\n        model_before.int = 1;\n        model_before.ignore = Ignored(PaddingConfig2d::Valid);\n\n        recorder\n            .record(model_before.clone().into_record(), file_path(filename))\n            .unwrap();\n\n        let model_after =\n            create_model(&device).load_record(recorder.load(file_path(filename), &device).unwrap());\n\n        // State is not persisted for empty record fields\n        assert_eq!(model_after.arr, [2, 2]);\n        assert_eq!(model_after.int, 0);\n        assert_eq!(model_after.ignore.0, PaddingConfig2d::Same);\n\n        let byte_recorder = BinBytesRecorder::<FullPrecisionSettings>::default();\n        let model_bytes_before = byte_recorder\n            .record(model_before.into_record(), ())\n            .unwrap();\n        let model_bytes_after = byte_recorder.record(model_after.into_record(), ()).unwrap();\n\n        assert_eq!(model_bytes_after, model_bytes_before);\n    }\n\n    #[derive(Config, Debug, PartialEq, Eq)]\n    pub enum PaddingConfig2d {\n        Same,\n        Valid,\n        Explicit(usize, usize),\n    }\n\n    // Dummy model with different record types\n    #[derive(Module, Debug)]\n    pub struct Model<B: Backend> {\n        linear1: SimpleLinear<B>,\n        phantom: PhantomData<B>,\n        tensor: Tensor<B, 1>,\n        arr: [usize; 2],\n        int: usize,\n        ignore: Ignored<PaddingConfig2d>,\n    }\n\n    pub fn create_model(device: &<TestBackend as Backend>::Device) -> Model<TestBackend> {\n        let linear1 = SimpleLinear::new(32, 32, device);\n\n        Model {\n            linear1,\n            phantom: PhantomData,\n            tensor: Tensor::zeros([2], device),\n            arr: [2, 2],\n            int: 0,\n            ignore: Ignored(PaddingConfig2d::Same),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/memory.rs",
    "content": "use super::{PrecisionSettings, Recorder, RecorderError, bin_config};\nuse alloc::vec::Vec;\nuse burn_tensor::backend::Backend;\nuse serde::{Serialize, de::DeserializeOwned};\n\n/// Recorder trait specialized to save and load data to and from bytes.\n///\n/// # Notes\n///\n/// This is especially useful in no_std environment where weights are stored directly in\n/// compiled binaries.\npub trait BytesRecorder<\n    B: Backend,\n    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,\n>: Recorder<B, RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = L>\n{\n}\n\n/// In memory recorder using the [bincode format](bincode).\n#[derive(new, Debug, Default, Clone)]\npub struct BinBytesRecorder<\n    S: PrecisionSettings,\n    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default = Vec<u8>,\n> {\n    _settings: core::marker::PhantomData<S>,\n    _loadargs: core::marker::PhantomData<L>,\n}\n\nimpl<\n    S: PrecisionSettings,\n    B: Backend,\n    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,\n> BytesRecorder<B, L> for BinBytesRecorder<S, L>\n{\n}\n\nimpl<\n    S: PrecisionSettings,\n    B: Backend,\n    L: AsRef<[u8]> + Send + Sync + core::fmt::Debug + Clone + core::default::Default,\n> Recorder<B> for BinBytesRecorder<S, L>\n{\n    type Settings = S;\n    type RecordArgs = ();\n    type RecordOutput = Vec<u8>;\n    type LoadArgs = L;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        _args: Self::RecordArgs,\n    ) -> Result<Self::RecordOutput, RecorderError> {\n        Ok(bincode::serde::encode_to_vec(item, bin_config()).unwrap())\n    }\n\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        args: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        let state = bincode::borrow_decode_from_slice::<'_, bincode::serde::BorrowCompat<I>, _>(\n            args.as_ref(),\n            bin_config(),\n        )\n        .unwrap()\n        .0;\n        Ok(state.0)\n    }\n}\n\n#[cfg(feature = \"std\")]\n/// In memory recorder using the [Named MessagePack](rmp_serde).\n#[derive(new, Debug, Default, Clone)]\npub struct NamedMpkBytesRecorder<S: PrecisionSettings> {\n    _settings: core::marker::PhantomData<S>,\n}\n\n#[cfg(feature = \"std\")]\nimpl<S: PrecisionSettings, B: Backend> BytesRecorder<B, Vec<u8>> for NamedMpkBytesRecorder<S> {}\n\n#[cfg(feature = \"std\")]\nimpl<S: PrecisionSettings, B: Backend> Recorder<B> for NamedMpkBytesRecorder<S> {\n    type Settings = S;\n    type RecordArgs = ();\n    type RecordOutput = Vec<u8>;\n    type LoadArgs = Vec<u8>;\n\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        _args: Self::RecordArgs,\n    ) -> Result<Self::RecordOutput, RecorderError> {\n        rmp_serde::encode::to_vec_named(&item).map_err(|e| RecorderError::Unknown(e.to_string()))\n    }\n    fn load_item<I: DeserializeOwned>(\n        &self,\n        args: &mut Self::LoadArgs,\n    ) -> Result<I, RecorderError> {\n        rmp_serde::decode::from_slice(args).map_err(|e| RecorderError::Unknown(e.to_string()))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::test_utils::SimpleLinear;\n    use crate::{\n        TestBackend, module::Module, record::FullPrecisionSettings, tensor::backend::Backend,\n    };\n\n    #[test]\n    fn test_can_save_and_load_bin_format() {\n        test_can_save_and_load(BinBytesRecorder::<FullPrecisionSettings>::default())\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn test_can_save_and_load_named_mpk_format() {\n        test_can_save_and_load(NamedMpkBytesRecorder::<FullPrecisionSettings>::default())\n    }\n\n    fn test_can_save_and_load<Recorder>(recorder: Recorder)\n    where\n        Recorder: BytesRecorder<TestBackend, Vec<u8>>,\n    {\n        let device = Default::default();\n        let model1 = create_model::<TestBackend>(&device);\n        let model2 = create_model::<TestBackend>(&device);\n        let bytes1 = recorder.record(model1.into_record(), ()).unwrap();\n        let bytes2 = recorder.record(model2.clone().into_record(), ()).unwrap();\n\n        let model2_after = model2.load_record(recorder.load(bytes1.clone(), &device).unwrap());\n        let bytes2_after = recorder.record(model2_after.into_record(), ()).unwrap();\n\n        assert_ne!(bytes1, bytes2);\n        assert_eq!(bytes1, bytes2_after);\n    }\n\n    pub fn create_model<B: Backend>(device: &B::Device) -> SimpleLinear<B> {\n        SimpleLinear::new(32, 32, device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/mod.rs",
    "content": "mod primitive;\nmod tensor;\n\nmod base;\nmod memory;\nmod recorder;\nmod settings;\n\npub use base::*;\npub use memory::*;\npub use recorder::*;\npub use settings::*;\n\n#[cfg(feature = \"std\")]\nmod file;\n#[cfg(feature = \"std\")]\npub use file::*;\n\npub use primitive::ParamSerde;\n\n#[cfg(feature = \"record-item-custom-serde\")]\npub mod serde;\n"
  },
  {
    "path": "crates/burn-core/src/record/primitive.rs",
    "content": "use alloc::{string::String, vec, vec::Vec};\nuse core::{fmt, marker::PhantomData};\n\nuse super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde};\nuse super::{PrecisionSettings, Record};\nuse crate::module::{Param, ParamId};\n\nuse burn_tensor::{Bool, Int, Tensor, backend::Backend};\n\nuse hashbrown::HashMap;\nuse serde::{\n    Deserialize, Serialize,\n    de::{Error, SeqAccess, Visitor},\n    ser::SerializeTuple,\n};\n\nimpl<B> Record<B> for ()\nwhere\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = ();\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {}\n\n    fn from_item<S: PrecisionSettings>(_item: Self::Item<S>, _device: &B::Device) -> Self {}\n}\n\nimpl<T, B> Record<B> for Vec<T>\nwhere\n    T: Record<B>,\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        self.into_iter().map(Record::into_item).collect()\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        item.into_iter()\n            .map(|i| Record::from_item(i, device))\n            .collect()\n    }\n}\n\nimpl<T, B> Record<B> for Option<T>\nwhere\n    T: Record<B>,\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = Option<T::Item<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        self.map(Record::into_item)\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        item.map(|i| Record::from_item(i, device))\n    }\n}\n\nimpl<const N: usize, T, B> Record<B> for [T; N]\nwhere\n    T: Record<B>,\n    B: Backend,\n{\n    /// The record item is an array of the record item of the elements.\n    /// The reason why we wrap the array in a struct is because serde does not support\n    /// deserializing arrays of variable size,\n    /// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937).\n    /// for backward compatibility reasons. Serde APIs were created before const generics.\n    type Item<S: PrecisionSettings> = Array<N, T::Item<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        Array(self.map(Record::into_item))\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        item.0.map(|i| Record::from_item(i, device))\n    }\n}\n\n/// A macro for generating implementations for tuple records of different sizes.\n/// For example: `impl_record_tuple!([R0, R1][0, 1])`.\n/// Would generate an implementation for a tuple of size 2.\n/// For this macro to work properly, please adhere to the convention:\n/// `impl_record_tuple!([R0, R1, ..., Rn][0, 1, ..., n])`.\nmacro_rules! impl_record_tuple {\n    // `$r` represents the generic records.\n    // `$i` represents the indices of the records in the tuple.\n    ([$($r:ident),*][$($i:tt),*]) => {\n        impl<B, $($r,)*> Record<B> for ($($r,)*)\n        where\n            B: Backend,\n            $($r: Record<B>),*\n        {\n            type Item<S: PrecisionSettings> = ($($r::Item<S>,)*);\n\n            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n                ($(self.$i.into_item(),)*)\n            }\n\n            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n                ($(Record::from_item(item.$i, device),)*)\n            }\n        }\n    };\n}\n\nimpl_record_tuple!([R0, R1][0, 1]);\nimpl_record_tuple!([R0, R1, R2][0, 1, 2]);\nimpl_record_tuple!([R0, R1, R2, R3][0, 1, 2, 3]);\nimpl_record_tuple!([R0, R1, R2, R3, R4][0, 1, 2, 3, 4]);\nimpl_record_tuple!([R0, R1, R2, R3, R4, R5][0, 1, 2, 3, 4, 5]);\nimpl_record_tuple!([R0, R1, R2, R3, R4, R5, R6][0, 1, 2, 3, 4, 5, 6]);\nimpl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7][0, 1, 2, 3, 4, 5, 6, 7]);\nimpl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8][0, 1, 2, 3, 4, 5, 6, 7, 8]);\nimpl_record_tuple!([R0, R1, R2, R3, R4, R5, R6, R7, R8, R9][0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);\n\nimpl<T, B> Record<B> for HashMap<ParamId, T>\nwhere\n    T: Record<B>,\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        let mut items = HashMap::with_capacity(self.len());\n        self.into_iter().for_each(|(id, record)| {\n            items.insert(id.serialize(), record.into_item());\n        });\n        items\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        let mut record = HashMap::with_capacity(item.len());\n        item.into_iter().for_each(|(id, item)| {\n            record.insert(ParamId::deserialize(&id), T::from_item(item, device));\n        });\n        record\n    }\n}\n\n/// (De)serialize parameters into a clean format.\n#[derive(new, Debug, Clone, Serialize, Deserialize)]\npub struct ParamSerde<T> {\n    id: String,\n    param: T,\n}\n\nimpl<B, const D: usize> Record<B> for Param<Tensor<B, D>>\nwhere\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = ParamSerde<FloatTensorSerde<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        let (id, tensor, mapper) = self.consume();\n        let tensor = mapper.on_save(tensor);\n        ParamSerde::new(id.serialize(), tensor.into_item())\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        B::memory_persistent_allocations(device, item, |item| {\n            Param::initialized(\n                ParamId::deserialize(&item.id),\n                Tensor::from_item(item.param, device).require_grad(), // Same behavior as when we create a new\n                                                                      // Param from a tensor.\n            )\n        })\n    }\n}\n\nimpl<B, const D: usize> Record<B> for Param<Tensor<B, D, Int>>\nwhere\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = ParamSerde<IntTensorSerde<S>>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        let (id, tensor, mapper) = self.consume();\n        let tensor = mapper.on_save(tensor);\n        ParamSerde::new(id.serialize(), tensor.into_item())\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        B::memory_persistent_allocations(device, item, |item| {\n            Param::initialized(\n                ParamId::deserialize(&item.id),\n                Tensor::from_item(item.param, device),\n            )\n        })\n    }\n}\n\nimpl<B, const D: usize> Record<B> for Param<Tensor<B, D, Bool>>\nwhere\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = ParamSerde<BoolTensorSerde>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        let (id, tensor, mapper) = self.consume();\n        let tensor = mapper.on_save(tensor);\n        ParamSerde::new(id.serialize(), tensor.into_item::<S>())\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        B::memory_persistent_allocations(device, item, |item| {\n            Param::initialized(\n                ParamId::deserialize(&item.id),\n                Tensor::from_item::<S>(item.param, device),\n            )\n        })\n    }\n}\n\n// Type that can be serialized as is without any conversion.\nmacro_rules! primitive {\n    ($type:ty) => {\n        impl<B: Backend> Record<B> for $type {\n            type Item<S: PrecisionSettings> = $type;\n\n            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n                self\n            }\n\n            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {\n                item\n            }\n        }\n    };\n}\n\n// General Types\nprimitive!(alloc::string::String);\nprimitive!(bool);\n\n// Float Types\nprimitive!(f64);\nprimitive!(f32);\n\nprimitive!(half::bf16);\nprimitive!(half::f16);\n\n// Unsigned Integer Types\nprimitive!(usize);\nprimitive!(u64);\nprimitive!(u32);\nprimitive!(u16);\nprimitive!(u8);\n\n// Signed Integer Types\nprimitive!(isize);\nprimitive!(i64);\nprimitive!(i32);\nprimitive!(i16);\nprimitive!(i8);\n\n/// A wrapper around an array of size N, so that it can be serialized and deserialized\n/// using serde.\n///\n/// The reason why we wrap the array in a struct is because serde does not support\n/// deserializing arrays of variable size,\n/// see [serde/issues/1937](https://github.com/serde-rs/serde/issues/1937)\n/// for backward compatibility reasons. Serde APIs were created before const generics.\n#[derive(Clone)]\npub struct Array<const N: usize, T>([T; N]);\n\nimpl<T: Serialize, const N: usize> Serialize for Array<N, T> {\n    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n    where\n        S: serde::Serializer,\n    {\n        let mut seq = serializer.serialize_tuple(self.0.len())?;\n        for element in &self.0 {\n            seq.serialize_element(element)?;\n        }\n        seq.end()\n    }\n}\n\nimpl<'de, T, const N: usize> Deserialize<'de> for Array<N, T>\nwhere\n    T: Deserialize<'de>,\n{\n    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>\n    where\n        D: serde::Deserializer<'de>,\n    {\n        struct ArrayVisitor<T, const N: usize> {\n            marker: PhantomData<T>,\n        }\n\n        impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>\n        where\n            T: Deserialize<'de>,\n        {\n            type Value = Array<N, T>;\n\n            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {\n                formatter.write_str(\"a fixed size array\")\n            }\n\n            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>\n            where\n                A: SeqAccess<'de>,\n            {\n                let mut items = vec![];\n\n                for i in 0..N {\n                    let item = seq\n                        .next_element()?\n                        .ok_or_else(|| Error::invalid_length(i, &self))?;\n                    items.push(item);\n                }\n\n                let array: [T; N] = items\n                    .into_iter()\n                    .collect::<Vec<_>>()\n                    .try_into()\n                    .map_err(|_| \"An array of size {N}\")\n                    .unwrap();\n\n                Ok(Array(array))\n            }\n        }\n\n        deserializer.deserialize_tuple(\n            N,\n            ArrayVisitor {\n                marker: PhantomData,\n            },\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/recorder.rs",
    "content": "use core::any::type_name;\nuse core::marker::PhantomData;\n\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse burn_tensor::backend::Backend;\nuse serde::{Deserialize, Serialize, de::DeserializeOwned};\n\nuse super::{BinBytesRecorder, FullPrecisionSettings, PrecisionSettings, Record};\n\n#[cfg(feature = \"std\")]\nuse super::{\n    BinFileRecorder, BinGzFileRecorder, DefaultFileRecorder, HalfPrecisionSettings,\n    PrettyJsonFileRecorder,\n};\n\n/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).\npub trait Recorder<B: Backend>:\n    Send + Sync + core::default::Default + core::fmt::Debug + Clone\n{\n    /// Type of the settings used by the recorder.\n    type Settings: PrecisionSettings;\n\n    /// Arguments used to record objects.\n    type RecordArgs: Clone;\n\n    /// Record output type.\n    type RecordOutput;\n\n    /// Arguments used to load recorded objects.\n    type LoadArgs;\n\n    /// Records an item.\n    ///\n    /// # Arguments\n    ///\n    /// * `record` - The item to record.\n    /// * `args` - Arguments used to record the item.\n    ///\n    /// # Returns\n    ///\n    /// The output of the recording.\n    fn record<R>(\n        &self,\n        record: R,\n        args: Self::RecordArgs,\n    ) -> Result<Self::RecordOutput, RecorderError>\n    where\n        R: Record<B>,\n    {\n        let item = record.into_item::<Self::Settings>();\n        let item = BurnRecord::new::<Self>(item);\n\n        self.save_item(item, args)\n    }\n\n    /// Load an item from the given arguments.\n    fn load<R>(&self, mut args: Self::LoadArgs, device: &B::Device) -> Result<R, RecorderError>\n    where\n        R: Record<B>,\n    {\n        let item: BurnRecord<R::Item<Self::Settings>, B> =\n            self.load_item(&mut args).map_err(|err| {\n                if let Ok(record) = self.load_item::<BurnRecordNoItem>(&mut args) {\n                    let mut message = \"Unable to load record.\".to_string();\n                    let metadata = recorder_metadata::<Self, B>();\n                    if metadata.float != record.metadata.float {\n                        message += format!(\n                            \"\\nMetadata has a different float type: Actual {:?}, Expected {:?}\",\n                            record.metadata.float, metadata.float\n                        )\n                        .as_str();\n                    }\n                    if metadata.int != record.metadata.int {\n                        message += format!(\n                            \"\\nMetadata has a different int type: Actual {:?}, Expected {:?}\",\n                            record.metadata.int, metadata.int\n                        )\n                        .as_str();\n                    }\n                    if metadata.format != record.metadata.format {\n                        message += format!(\n                            \"\\nMetadata has a different format: Actual {:?}, Expected {:?}\",\n                            record.metadata.format, metadata.format\n                        )\n                        .as_str();\n                    }\n                    if metadata.version != record.metadata.version {\n                        message += format!(\n                            \"\\nMetadata has a different Burn version: Actual {:?}, Expected {:?}\",\n                            record.metadata.version, metadata.version\n                        )\n                        .as_str();\n                    }\n\n                    message += format!(\"\\nError: {err:?}\").as_str();\n\n                    return RecorderError::Unknown(message);\n                }\n\n                err\n            })?;\n\n        Ok(R::from_item(item.item, device))\n    }\n\n    /// Saves an item.\n    ///\n    /// This method is used by [record](Recorder::record) to save the item.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - Item to save.\n    /// * `args` - Arguments to use to save the item.\n    ///\n    /// # Returns\n    ///\n    /// The output of the save operation.\n    fn save_item<I: Serialize>(\n        &self,\n        item: I,\n        args: Self::RecordArgs,\n    ) -> Result<Self::RecordOutput, RecorderError>;\n\n    /// Loads an item.\n    ///\n    /// This method is used by [load](Recorder::load) to load the item.\n    ///\n    /// # Arguments\n    ///\n    /// * `args` - Arguments to use to load the item.\n    ///\n    /// # Returns\n    ///\n    /// The loaded item.\n    fn load_item<I>(&self, args: &mut Self::LoadArgs) -> Result<I, RecorderError>\n    where\n        I: DeserializeOwned;\n}\n\nfn recorder_metadata<R, B>() -> BurnMetadata\nwhere\n    R: Recorder<B>,\n    B: Backend,\n{\n    BurnMetadata::new(\n        type_name::<<R::Settings as PrecisionSettings>::FloatElem>().to_string(),\n        type_name::<<R::Settings as PrecisionSettings>::IntElem>().to_string(),\n        type_name::<R>().to_string(),\n        env!(\"CARGO_PKG_VERSION\").to_string(),\n        format!(\"{:?}\", R::Settings::default()),\n    )\n}\n\n/// Error that can occur when using a [Recorder](Recorder).\n#[derive(Debug)]\npub enum RecorderError {\n    /// File not found.\n    FileNotFound(String),\n\n    /// Failed to read file.\n    DeserializeError(String),\n\n    /// Other error.\n    Unknown(String),\n}\n\nimpl core::fmt::Display for RecorderError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(format!(\"{self:?}\").as_str())\n    }\n}\n\nimpl core::error::Error for RecorderError {}\n\npub(crate) fn bin_config() -> bincode::config::Configuration {\n    bincode::config::standard()\n}\n\n/// Metadata of a record.\n#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]\npub struct BurnMetadata {\n    /// Float type used to record the item.\n    pub float: String,\n\n    /// Int type used to record the item.\n    pub int: String,\n\n    /// Format used to record the item.\n    pub format: String,\n\n    /// Burn record version used to record the item.\n    pub version: String,\n\n    /// Settings used to record the item.\n    pub settings: String,\n}\n\n/// Record that can be saved by a [Recorder](Recorder).\n#[derive(Serialize, Deserialize, Debug)]\npub struct BurnRecord<I, B: Backend> {\n    /// Metadata of the record.\n    pub metadata: BurnMetadata,\n\n    /// Item to record.\n    pub item: I,\n\n    _b: PhantomData<B>,\n}\n\nimpl<I, B: Backend> BurnRecord<I, B> {\n    /// Creates a new record.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - Item to record.\n    ///\n    /// # Returns\n    ///\n    /// The new record.\n    pub fn new<R: Recorder<B>>(item: I) -> Self {\n        let metadata = recorder_metadata::<R, B>();\n\n        Self {\n            metadata,\n            item,\n            _b: PhantomData,\n        }\n    }\n}\n\n/// Record that can be saved by a [Recorder](Recorder) without the item.\n#[derive(new, Debug, Serialize, Deserialize)]\npub struct BurnRecordNoItem {\n    /// Metadata of the record.\n    pub metadata: BurnMetadata,\n}\n\n/// Default recorder.\n///\n/// It uses the [named msgpack](rmp_serde) format for serialization with full precision.\n#[cfg(feature = \"std\")]\npub type DefaultRecorder = DefaultFileRecorder<FullPrecisionSettings>;\n\n/// Recorder optimized for compactness.\n///\n/// It uses the [named msgpack](rmp_serde) format for serialization with half precision.\n/// If you are looking for the recorder that offers the smallest file size, have a look at\n/// [sensitive compact recorder](SensitiveCompactRecorder).\n#[cfg(feature = \"std\")]\npub type CompactRecorder = DefaultFileRecorder<HalfPrecisionSettings>;\n\n/// Recorder optimized for compactness making it a good choice for model deployment.\n///\n/// It uses the [bincode](bincode) format for serialization and half precision.\n/// This format is not resilient to type changes since no metadata is encoded.\n/// Favor [default recorder](DefaultRecorder) or [compact recorder](CompactRecorder)\n/// for long term data storage.\n#[cfg(feature = \"std\")]\npub type SensitiveCompactRecorder = BinGzFileRecorder<HalfPrecisionSettings>;\n\n/// Training recorder compatible with no-std inference.\n#[cfg(feature = \"std\")]\npub type NoStdTrainingRecorder = BinFileRecorder<FullPrecisionSettings>;\n\n/// Inference recorder compatible with no-std.\npub type NoStdInferenceRecorder = BinBytesRecorder<FullPrecisionSettings, &'static [u8]>;\n\n/// Debug recorder.\n///\n/// It uses the [pretty json](serde_json) format for serialization with full precision making it\n/// human readable.\n#[cfg(feature = \"std\")]\npub type DebugRecordSettings = PrettyJsonFileRecorder<FullPrecisionSettings>;\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    static FILE_PATH: &str = \"/tmp/burn_test_record\";\n\n    use crate::TestBackend;\n\n    use super::*;\n    use burn_tensor::{Device, ElementConversion};\n\n    #[test]\n    #[should_panic]\n    fn err_when_invalid_item() {\n        #[derive(new, Serialize, Deserialize, Clone)]\n        struct Item<S: PrecisionSettings> {\n            value: S::FloatElem,\n        }\n\n        impl<D, B> Record<B> for Item<D>\n        where\n            D: PrecisionSettings,\n            B: Backend,\n        {\n            type Item<S: PrecisionSettings> = Item<S>;\n\n            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n                Item {\n                    value: self.value.elem(),\n                }\n            }\n\n            fn from_item<S: PrecisionSettings>(item: Self::Item<S>, _device: &B::Device) -> Self {\n                Item {\n                    value: item.value.elem(),\n                }\n            }\n        }\n\n        let item = Item::<FullPrecisionSettings>::new(16.elem());\n        let device: Device<TestBackend> = Default::default();\n\n        // Serialize in f32.\n        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();\n        Recorder::<TestBackend>::record(&recorder, item, FILE_PATH.into()).unwrap();\n\n        // Can't deserialize f32 into f16.\n        let recorder = DefaultFileRecorder::<HalfPrecisionSettings>::new();\n        Recorder::<TestBackend>::load::<Item<FullPrecisionSettings>>(\n            &recorder,\n            FILE_PATH.into(),\n            &device,\n        )\n        .unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/adapter.rs",
    "content": "use super::data::NestedValue;\n\n/// A trait that defines the adapter for a Burn module.\n///\n/// This is used to adapt an incoming module to a Burn module.\npub trait BurnModuleAdapter: Sized {\n    /// Adapts a module.\n    fn adapt(name: &str, data: NestedValue) -> NestedValue {\n        match name {\n            \"BatchNorm\" => Self::adapt_batch_norm(data),\n            \"Conv1d\" => Self::adapt_conv1d(data),\n            \"Conv2d\" => Self::adapt_conv2d(data),\n            \"Conv3d\" => Self::adapt_conv3d(data),\n            \"ConvTranspose1d\" => Self::adapt_conv_transpose_1d(data),\n            \"ConvTranspose2d\" => Self::adapt_conv_transpose_2d(data),\n            \"ConvTranspose3d\" => Self::adapt_conv_transpose_3d(data),\n            \"Embedding\" => Self::adapt_embedding(data),\n            \"GroupNorm\" => Self::adapt_group_norm(data),\n            \"LayerNorm\" => Self::adapt_layer_norm(data),\n            \"Linear\" => Self::adapt_linear(data),\n            _ => data,\n        }\n    }\n\n    /// Adapts a linear module.\n    fn adapt_linear(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts a Convolution 1D module.\n    fn adapt_conv1d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts a Convolution 2D module.\n    fn adapt_conv2d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts a Convolution 3D module.\n    fn adapt_conv3d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts convolution transpose 1D module.\n    fn adapt_conv_transpose_1d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts convolution transpose 2D module.\n    fn adapt_conv_transpose_2d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts convolution transpose 2D module.\n    fn adapt_conv_transpose_3d(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts embedding module.\n    fn adapt_embedding(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts group normalization module.\n    fn adapt_group_norm(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts layer normalization module.\n    fn adapt_layer_norm(data: NestedValue) -> NestedValue {\n        data\n    }\n\n    /// Adapts batch normalization module.\n    fn adapt_batch_norm(data: NestedValue) -> NestedValue {\n        data\n    }\n}\n\n/// Default adapter that takes no action.\npub struct DefaultAdapter;\nimpl BurnModuleAdapter for DefaultAdapter {}\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/data.rs",
    "content": "use std::collections::HashMap;\n\nuse super::adapter::BurnModuleAdapter;\nuse super::de::Deserializer;\nuse super::error::Error;\nuse super::ser::Serializer;\nuse crate::record::{PrecisionSettings, Record};\nuse crate::tensor::backend::Backend;\n\nuse alloc::fmt;\nuse burn_tensor::Bytes;\nuse num_traits::cast::ToPrimitive;\nuse regex::Regex;\nuse serde::Deserialize;\n\n/// The main data structure used for deserialization.\n///\n/// It can hold tree-like structures of nested maps and vectors.\n#[derive(Clone)]\npub enum NestedValue {\n    /// The default value, which actually does not hold any value and it is used to indicate that\n    /// the value should be populated with the default value. It contains an optional string with\n    /// the originator field name.\n    Default(Option<String>),\n\n    /// A boolean value.\n    Bool(bool),\n\n    /// A string value.\n    String(String),\n\n    /// Floating point 32-bit value.\n    F32(f32),\n\n    /// Floating point 64-bit value.\n    F64(f64),\n\n    /// Signed 16-bit integer value.\n    I16(i16),\n\n    /// Signed 32-bit integer value.\n    I32(i32),\n\n    /// Signed 64-bit integer value.\n    I64(i64),\n\n    /// Unsigned 8-bit integer value.\n    U8(u8),\n\n    /// Unsigned 16-bit integer value used for bf16 and f16 serialization\n    U16(u16),\n\n    /// Unsigned 64-bit integer value.\n    U64(u64),\n\n    /// A map of nested values (typically used for structs)\n    Map(HashMap<String, NestedValue>),\n\n    /// A vector of nested values (typically used for vector of structs or numbers)\n    Vec(Vec<NestedValue>),\n\n    /// A vector of 8-bit unsigned integer values.\n    U8s(Vec<u8>),\n\n    /// A vector of 16-bit unsigned integer values.\n    U16s(Vec<u16>),\n\n    /// A vector of 32-bit floating point values.\n    F32s(Vec<f32>),\n\n    /// An opaque vector of bytes, with alignment.\n    Bytes(Bytes),\n}\n\nimpl NestedValue {\n    /// Get the nested value as a map.\n    pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {\n        match self {\n            NestedValue::Map(map) => Some(map),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a boolean.\n    pub fn as_bool(self) -> Option<bool> {\n        match self {\n            NestedValue::Bool(bool) => Some(bool),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a string.\n    pub fn as_string(self) -> Option<String> {\n        match self {\n            NestedValue::String(string) => Some(string),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a f32.\n    pub fn as_f32(self) -> Option<f32> {\n        match self {\n            NestedValue::F32(f32) => Some(f32),\n            NestedValue::F64(f) => f.to_f32(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a f64.\n    pub fn as_f64(self) -> Option<f64> {\n        match self {\n            NestedValue::F64(f64) => Some(f64),\n            NestedValue::F32(f) => f.to_f64(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as an i16.\n    pub fn as_i16(self) -> Option<i16> {\n        match self {\n            NestedValue::I16(i16) => Some(i16),\n            NestedValue::I32(i) => i.to_i16(),\n            NestedValue::I64(i) => i.to_i16(),\n            NestedValue::U16(u) => u.to_i16(),\n            NestedValue::U64(u) => u.to_i16(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as an i32.\n    pub fn as_i32(self) -> Option<i32> {\n        match self {\n            NestedValue::I32(i32) => Some(i32),\n            NestedValue::I16(i) => i.to_i32(),\n            NestedValue::I64(i) => i.to_i32(),\n            NestedValue::U16(u) => u.to_i32(),\n            NestedValue::U64(u) => u.to_i32(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as an i64.\n    pub fn as_i64(self) -> Option<i64> {\n        match self {\n            NestedValue::I64(i64) => Some(i64),\n            NestedValue::I16(i) => i.to_i64(),\n            NestedValue::I32(i) => i.to_i64(),\n            NestedValue::U16(u) => u.to_i64(),\n            NestedValue::U64(u) => u.to_i64(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a u8.\n    pub fn as_u8(self) -> Option<u8> {\n        match self {\n            NestedValue::U8(u8) => Some(u8),\n            NestedValue::I16(i) => i.to_u8(),\n            NestedValue::I32(i) => i.to_u8(),\n            NestedValue::I64(i) => i.to_u8(),\n            NestedValue::U16(u) => u.to_u8(),\n            NestedValue::U64(u) => u.to_u8(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a u16.\n    pub fn as_u16(self) -> Option<u16> {\n        match self {\n            NestedValue::U16(u16) => Some(u16),\n            NestedValue::I16(i) => i.to_u16(),\n            NestedValue::I32(i) => i.to_u16(),\n            NestedValue::I64(i) => i.to_u16(),\n            NestedValue::U64(u) => u.to_u16(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a u64.\n    pub fn as_u64(self) -> Option<u64> {\n        match self {\n            NestedValue::U64(u64) => Some(u64),\n            NestedValue::I16(i) => i.to_u64(),\n            NestedValue::I32(i) => i.to_u64(),\n            NestedValue::I64(i) => i.to_u64(),\n            NestedValue::U16(u) => u.to_u64(),\n            _ => None,\n        }\n    }\n\n    /// Get the nested value as a vector of bytes.\n    pub fn as_bytes(self) -> Option<Bytes> {\n        match self {\n            NestedValue::Bytes(u) => Some(u),\n            NestedValue::U8s(u) => Some(Bytes::from_elems(u)),\n            _ => None,\n        }\n    }\n\n    /// Deserialize a nested value into a record type.\n    pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>\n    where\n        B: Backend,\n        T: Record<B>,\n        PS: PrecisionSettings,\n        A: BurnModuleAdapter,\n    {\n        let deserializer = Deserializer::<A>::new(self, false);\n\n        let item = T::Item::deserialize(deserializer)?;\n\n        // Convert the deserialized item into a Record instance\n        Ok(T::from_item::<PS>(item, device))\n    }\n}\n\n/// Remap the tensor locations according to the key remapping.\n///\n/// # Arguments\n///\n/// * `tensors` - A map of tensors.\n/// * `key_remap` - A vector of tuples containing a regular expression and a replacement string.\n///   See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)\n///   for more information.\n///\n/// # Returns\n///\n/// A map of tensors with the remapped keys and\n/// a vector of tuples containing the remapped and original.\npub fn remap<T>(\n    mut tensors: HashMap<String, T>,\n    key_remap: Vec<(Regex, String)>,\n) -> (HashMap<String, T>, Vec<(String, String)>) {\n    if key_remap.is_empty() {\n        let remapped_names = tensors\n            .keys()\n            .cloned()\n            .map(|s| (s.clone(), s)) // Name is the same as the remapped name\n            .collect();\n        return (tensors, remapped_names);\n    }\n\n    let mut remapped = HashMap::new();\n    let mut remapped_names = Vec::new();\n\n    for (name, tensor) in tensors.drain() {\n        let mut new_name = name.clone();\n        for (pattern, replacement) in &key_remap {\n            if pattern.is_match(&new_name) {\n                new_name = pattern\n                    .replace_all(&new_name, replacement.as_str())\n                    .to_string();\n            }\n        }\n\n        remapped_names.push((new_name.clone(), name));\n        remapped.insert(new_name, tensor);\n    }\n\n    (remapped, remapped_names)\n}\n\n/// Helper function to insert a value into a nested map/vector of tensors.\nfn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {\n    if keys.is_empty() {\n        *current = value;\n        return;\n    }\n\n    match current {\n        NestedValue::Map(map) => {\n            if !map.contains_key(keys[0]) {\n                let next = if keys[1..]\n                    .first()\n                    .and_then(|k| k.parse::<usize>().ok())\n                    .is_some()\n                {\n                    NestedValue::Vec(Vec::new())\n                } else {\n                    NestedValue::Map(HashMap::new())\n                };\n                map.insert(keys[0].to_string(), next);\n            }\n            insert_nested_value(map.get_mut(keys[0]).unwrap(), &keys[1..], value);\n        }\n        NestedValue::Vec(vec) => {\n            let index = keys[0].parse::<usize>().unwrap();\n            if index >= vec.len() {\n                vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));\n            }\n            insert_nested_value(&mut vec[index], &keys[1..], value);\n        }\n        _ => panic!(\"Invalid structure encountered\"),\n    }\n}\n\n/// A trait for encapsulating the serialization logic.\npub trait Serializable {\n    /// Serializes the object into a `NestedValue` using the provided `Serializer`.\n    /// This method is generic over the precision settings `PS`.\n    ///\n    /// # Parameters\n    /// - `serializer`: The `Serializer` to use for serializing the object.\n    ///\n    /// # Returns\n    /// - `Result<NestedValue, Error>`: The result of serialization.\n    ///   Returns a `NestedValue` on success,\n    ///   or an `Error` on failure.\n    ///\n    /// # Type Parameters\n    /// - `PS`: The precision settings to use during serialization.\n    ///   This is a generic parameter and can be any type\n    ///   that implements the `PrecisionSettings` trait.\n    fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>\n    where\n        PS: PrecisionSettings;\n}\n\n/// Convert a vector of tensors to a nested value.\npub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>\nwhere\n    PS: PrecisionSettings,\n    T: Serializable,\n{\n    let mut result = NestedValue::Map(HashMap::new());\n\n    for (key, value) in input {\n        let parts: Vec<&str> = key.split('.').collect();\n        let st = value.serialize::<PS>(Serializer::new())?;\n\n        insert_nested_value(&mut result, &parts, st);\n    }\n\n    cleanup_empty_maps(&mut result);\n\n    Ok(result)\n}\n\n/// Removes empty maps from the nested value.\n///\n/// We need to clean up empty maps from the nested value\n/// in some cases when there is non-contiguous indices in keys.\nfn cleanup_empty_maps(current: &mut NestedValue) {\n    match current {\n        NestedValue::Map(map) => {\n            map.values_mut().for_each(cleanup_empty_maps);\n        }\n        NestedValue::Vec(vec) => {\n            vec.iter_mut().for_each(cleanup_empty_maps);\n            vec.retain(|v| !matches!(v, NestedValue::Map(m) if m.is_empty()));\n        }\n        _ => {}\n    }\n}\n\nfn write_vec_truncated<T: core::fmt::Debug>(\n    vec: &[T],\n    f: &mut core::fmt::Formatter,\n) -> fmt::Result {\n    write!(f, \"Vec([\")?;\n    for (i, v) in vec.iter().take(3).enumerate() {\n        if i > 0 {\n            write!(f, \", \")?;\n        }\n        write!(f, \"{v:?}\")?;\n    }\n    write!(f, \", ...] len={})\", vec.len())\n}\n\nimpl fmt::Debug for NestedValue {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        match self {\n            // Truncate values for vector\n            NestedValue::Vec(vec) if vec.len() > 3 => write_vec_truncated(vec, f),\n            NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),\n            NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),\n            NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),\n            NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f),\n            // Handle other variants as usual\n            NestedValue::Default(origin) => f.debug_tuple(\"Default\").field(origin).finish(),\n            NestedValue::Bool(b) => f.debug_tuple(\"Bool\").field(b).finish(),\n            NestedValue::String(s) => f.debug_tuple(\"String\").field(s).finish(),\n            NestedValue::F32(val) => f.debug_tuple(\"F32\").field(val).finish(),\n            NestedValue::F64(val) => f.debug_tuple(\"F64\").field(val).finish(),\n            NestedValue::I16(val) => f.debug_tuple(\"I16\").field(val).finish(),\n            NestedValue::I32(val) => f.debug_tuple(\"I32\").field(val).finish(),\n            NestedValue::I64(val) => f.debug_tuple(\"I64\").field(val).finish(),\n            NestedValue::U8(val) => f.debug_tuple(\"U8\").field(val).finish(),\n            NestedValue::U16(val) => f.debug_tuple(\"U16\").field(val).finish(),\n            NestedValue::U64(val) => f.debug_tuple(\"U64\").field(val).finish(),\n            NestedValue::Map(map) => f.debug_map().entries(map.iter()).finish(),\n            NestedValue::Vec(vec) => f.debug_list().entries(vec.iter()).finish(),\n            NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(),\n            NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),\n            NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),\n            NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/de.rs",
    "content": "use core::ptr;\nuse std::collections::HashMap;\n\nuse super::data::NestedValue;\nuse super::{adapter::BurnModuleAdapter, error::Error};\n\nuse serde::de::{EnumAccess, VariantAccess};\nuse serde::{\n    de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor},\n    forward_to_deserialize_any,\n};\n\nconst RECORD_ITEM_SUFFIX: &str = \"RecordItem\";\n\n/// A deserializer for the nested value data structure.\npub struct Deserializer<A: BurnModuleAdapter> {\n    // This string starts with the input data and characters are truncated off\n    // the beginning as data is parsed.\n    value: Option<NestedValue>,\n    default_for_missing_fields: bool,\n    phantom: std::marker::PhantomData<A>,\n}\n\nimpl<A: BurnModuleAdapter> Deserializer<A> {\n    /// Creates a new deserializer with the given nested value.\n    ///\n    /// # Arguments\n    ///\n    /// * `value` - A nested value.\n    /// * `default_for_missing_fields` - A boolean indicating whether to add missing fields with default value.\n    pub fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {\n        Self {\n            value: Some(value),\n            default_for_missing_fields,\n            phantom: std::marker::PhantomData,\n        }\n    }\n}\n\nimpl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {\n    type Error = Error;\n\n    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_any is not implemented\")\n    }\n\n    fn deserialize_struct<V>(\n        self,\n        name: &'static str,\n        fields: &'static [&'static str],\n        visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        let value = match self.value {\n            Some(value) => {\n                // Adapt modules\n                if let Some(name) = name.strip_suffix(RECORD_ITEM_SUFFIX) {\n                    A::adapt(name, value)\n                } else {\n                    value\n                }\n            }\n            None => {\n                return Err(de::Error::custom(format!(\n                    \"Expected some value but got {:?}\",\n                    self.value\n                )));\n            }\n        };\n\n        match value {\n            NestedValue::Map(map) => {\n                // Add missing fields into the map with default value if needed.\n                let map = if self.default_for_missing_fields {\n                    let mut map = map;\n                    for field in fields.iter().map(|s| s.to_string()) {\n                        map.entry(field.clone())\n                            .or_insert(NestedValue::Default(Some(field)));\n                    }\n                    map\n                } else {\n                    map\n                };\n\n                visitor.visit_map(HashMapAccess::<A>::new(\n                    map,\n                    self.default_for_missing_fields,\n                ))\n            }\n\n            _ => Err(de::Error::custom(format!(\n                \"Expected struct but got {value:?}\"\n            ))),\n        }\n    }\n\n    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_string(self.value.unwrap().as_string().unwrap().to_string())\n    }\n\n    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_unit()\n    }\n\n    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        match self.value {\n            Some(NestedValue::Map(map)) => visitor.visit_map(HashMapAccess::<A>::new(\n                map,\n                self.default_for_missing_fields,\n            )),\n\n            _ => Err(de::Error::custom(format!(\n                \"Expected map value but got {:?}\",\n                self.value\n            ))),\n        }\n    }\n\n    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_bool(self.value.unwrap().as_bool().unwrap())\n    }\n\n    fn deserialize_i8<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_i8 is not implemented\")\n    }\n\n    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i16(self.value.unwrap().as_i16().unwrap().to_owned())\n    }\n\n    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i32(self.value.unwrap().as_i32().unwrap().to_owned())\n    }\n\n    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i64(self.value.unwrap().as_i64().unwrap().to_owned())\n    }\n\n    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u8(self.value.unwrap().as_u8().unwrap().to_owned())\n    }\n\n    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u16(self.value.unwrap().as_u16().unwrap().to_owned())\n    }\n\n    fn deserialize_u32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_u32 is not implemented\")\n    }\n\n    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u64(self.value.unwrap().as_u64().unwrap().to_owned())\n    }\n\n    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_f32(self.value.unwrap().as_f32().unwrap().to_owned())\n    }\n\n    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_f64(self.value.unwrap().as_f64().unwrap().to_owned())\n    }\n\n    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_char is not implemented\")\n    }\n\n    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_str(self.value.unwrap().as_string().unwrap().as_ref())\n    }\n\n    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_bytes is not implemented\")\n    }\n\n    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        let bytes = self.value.unwrap().as_bytes().unwrap();\n        match bytes.try_into_vec::<u8>() {\n            Ok(bytes) => visitor.visit_byte_buf(bytes),\n            Err(bytes) => visitor.visit_bytes(&bytes),\n        }\n    }\n\n    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        if let Some(value) = self.value {\n            visitor.visit_some(Deserializer::<A>::new(\n                value,\n                self.default_for_missing_fields,\n            ))\n        } else {\n            visitor.visit_none()\n        }\n    }\n\n    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_unit is not implemented\")\n    }\n\n    fn deserialize_unit_struct<V>(\n        self,\n        _name: &'static str,\n        _visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_unit_struct is not implemented\")\n    }\n\n    fn deserialize_newtype_struct<V>(\n        self,\n        _name: &'static str,\n        visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_newtype_struct(Deserializer::<A>::new(\n            self.value.unwrap(),\n            self.default_for_missing_fields,\n        ))\n    }\n\n    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        if let Some(value) = self.value {\n            match value {\n                NestedValue::Vec(_) => visitor.visit_seq(VecSeqAccess::<A, NestedValue>::new(\n                    value,\n                    self.default_for_missing_fields,\n                )),\n                NestedValue::U8s(_) => visitor.visit_seq(VecSeqAccess::<A, u8>::new(\n                    value,\n                    self.default_for_missing_fields,\n                )),\n                NestedValue::U16s(_) => visitor.visit_seq(VecSeqAccess::<A, u16>::new(\n                    value,\n                    self.default_for_missing_fields,\n                )),\n                NestedValue::F32s(_) => visitor.visit_seq(VecSeqAccess::<A, f32>::new(\n                    value,\n                    self.default_for_missing_fields,\n                )),\n                _ => Err(de::Error::custom(format!(\"Expected Vec but got {value:?}\"))),\n            }\n        } else {\n            Err(de::Error::custom(\"Expected Vec but got None\"))\n        }\n    }\n\n    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_tuple is not implemented\")\n    }\n\n    fn deserialize_tuple_struct<V>(\n        self,\n        _name: &'static str,\n        _len: usize,\n        _visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_tuple_struct is not implemented\")\n    }\n\n    /// Deserializes an enum by attempting to match its variants against the provided data.\n    ///\n    /// This function attempts to deserialize an enum by iterating over its possible variants\n    /// and trying to deserialize the data into each until one succeeds. We need to do this\n    /// because we don't have a way to know which variant to deserialize from the data.\n    ///\n    /// This is similar to Serde's\n    /// [untagged enum deserialization](https://serde.rs/enum-representations.html#untagged),\n    /// but it's on the deserializer side. Using `#[serde(untagged)]` on the enum will force\n    /// using `deserialize_any`, which is not what we want because we want to use methods, such\n    /// as `visit_struct`. Also we do not wish to use auto generate code for Deserialize just\n    /// for enums because it will affect other serialization and deserialization, such\n    /// as JSON and Bincode.\n    ///\n    /// # Safety\n    /// The function uses an unsafe block to clone the `visitor`. This is necessary because\n    /// the `Visitor` trait does not have a `Clone` implementation, and we need to clone it\n    /// as we are going to use it multiple times. The Visitor is a code generated unit struct\n    /// with no states or mutations, so it is safe to clone it in this case. We mainly care\n    /// about the `visit_enum` method, which is the only method that will be called on the\n    /// cloned visitor.\n    fn deserialize_enum<V>(\n        self,\n        _name: &'static str,\n        variants: &'static [&'static str],\n        visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        fn clone_unsafely<T>(thing: &T) -> T {\n            unsafe {\n                // Allocate memory for the clone.\n                let mut clone = std::mem::MaybeUninit::<T>::uninit();\n                // Get a mutable pointer to the allocated memory.\n                let clone_ptr = clone.as_mut_ptr();\n                // Copy the memory\n                ptr::copy_nonoverlapping(thing as *const T, clone_ptr, 1);\n                // Assume the cloned data is initialized and convert it to an owned instance of T.\n                clone.assume_init()\n            }\n        }\n\n        // Try each variant in order\n        for &variant in variants {\n            // clone visitor to avoid moving it\n            let cloned_visitor = clone_unsafely(&visitor);\n            let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(\n                self.value.clone().unwrap(),\n                variant.to_owned(),\n                self.default_for_missing_fields,\n            ));\n\n            if result.is_ok() {\n                return result;\n            }\n        }\n\n        Err(de::Error::custom(\"No variant match\"))\n    }\n\n    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"deserialize_identifier is not implemented\")\n    }\n}\n\n/// A sequence access for a vector in the nested value data structure.\nstruct VecSeqAccess<A: BurnModuleAdapter, I> {\n    iter: Box<dyn Iterator<Item = I>>,\n    default_for_missing_fields: bool,\n    phantom: std::marker::PhantomData<A>,\n}\n\n// Concrete implementation for `Vec<NestedValue>`\nimpl<A: BurnModuleAdapter> VecSeqAccess<A, NestedValue> {\n    fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {\n        match vec {\n            NestedValue::Vec(v) => VecSeqAccess {\n                iter: Box::new(v.into_iter()),\n                default_for_missing_fields,\n                phantom: std::marker::PhantomData,\n            },\n            _ => panic!(\"Invalid vec sequence\"),\n        }\n    }\n}\n\n// Concrete implementation for `Vec<u8>`\nimpl<A: BurnModuleAdapter> VecSeqAccess<A, u8> {\n    fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {\n        match vec {\n            NestedValue::U8s(v) => VecSeqAccess {\n                iter: Box::new(v.into_iter()),\n                default_for_missing_fields,\n                phantom: std::marker::PhantomData,\n            },\n            _ => panic!(\"Invalid vec sequence\"),\n        }\n    }\n}\n\n// Concrete implementation for `Vec<u16>`\nimpl<A: BurnModuleAdapter> VecSeqAccess<A, u16> {\n    fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {\n        match vec {\n            NestedValue::U16s(v) => VecSeqAccess {\n                iter: Box::new(v.into_iter()),\n                default_for_missing_fields,\n                phantom: std::marker::PhantomData,\n            },\n            _ => panic!(\"Invalid vec sequence\"),\n        }\n    }\n}\n\n// Concrete implementation for `Vec<f32>`\nimpl<A: BurnModuleAdapter> VecSeqAccess<A, f32> {\n    fn new(vec: NestedValue, default_for_missing_fields: bool) -> Self {\n        match vec {\n            NestedValue::F32s(v) => VecSeqAccess {\n                iter: Box::new(v.into_iter()),\n                default_for_missing_fields,\n                phantom: std::marker::PhantomData,\n            },\n            _ => panic!(\"Invalid vec sequence\"),\n        }\n    }\n}\n\n// Concrete implementation for `Vec<NestedValue>`\nimpl<'de, A> SeqAccess<'de> for VecSeqAccess<A, NestedValue>\nwhere\n    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        let item = match self.iter.next() {\n            Some(v) => v,\n            None => return Ok(None),\n        };\n\n        seed.deserialize(\n            NestedValueWrapper::<A>::new(item, self.default_for_missing_fields).into_deserializer(),\n        )\n        .map(Some)\n    }\n}\n\n// Concrete implementation for `Vec<u8>`\nimpl<'de, A> SeqAccess<'de> for VecSeqAccess<A, u8>\nwhere\n    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        let item = match self.iter.next() {\n            Some(v) => v,\n            None => return Ok(None),\n        };\n\n        seed.deserialize(\n            NestedValueWrapper::<A>::new(NestedValue::U8(item), self.default_for_missing_fields)\n                .into_deserializer(),\n        )\n        .map(Some)\n    }\n}\n\n// Concrete implementation for `Vec<u16>`\nimpl<'de, A> SeqAccess<'de> for VecSeqAccess<A, u16>\nwhere\n    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        let item = match self.iter.next() {\n            Some(v) => v,\n            None => return Ok(None),\n        };\n\n        seed.deserialize(\n            NestedValueWrapper::<A>::new(NestedValue::U16(item), self.default_for_missing_fields)\n                .into_deserializer(),\n        )\n        .map(Some)\n    }\n}\n\n// Concrete implementation for `Vec<f32>`\nimpl<'de, A> SeqAccess<'de> for VecSeqAccess<A, f32>\nwhere\n    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        let item = match self.iter.next() {\n            Some(v) => v,\n            None => return Ok(None),\n        };\n\n        seed.deserialize(\n            NestedValueWrapper::<A>::new(NestedValue::F32(item), self.default_for_missing_fields)\n                .into_deserializer(),\n        )\n        .map(Some)\n    }\n}\n\n/// A map access for a map in the nested value data structure.\nstruct HashMapAccess<A: BurnModuleAdapter> {\n    iter: std::collections::hash_map::IntoIter<String, NestedValue>,\n    next_value: Option<NestedValue>,\n    default_for_missing_fields: bool,\n    phantom: std::marker::PhantomData<A>,\n}\n\nimpl<A: BurnModuleAdapter> HashMapAccess<A> {\n    fn new(map: HashMap<String, NestedValue>, default_for_missing_fields: bool) -> Self {\n        HashMapAccess {\n            iter: map.into_iter(),\n            next_value: None,\n            default_for_missing_fields,\n            phantom: std::marker::PhantomData,\n        }\n    }\n}\n\nimpl<'de, A> MapAccess<'de> for HashMapAccess<A>\nwhere\n    String: IntoDeserializer<'de, Error>,\n    NestedValueWrapper<A>: IntoDeserializer<'de, Error>,\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        match self.iter.next() {\n            Some((k, v)) => {\n                // Keep the value for the next call to next_value_seed.\n                self.next_value = Some(v);\n                // Deserialize the key.\n                seed.deserialize(k.into_deserializer()).map(Some)\n            }\n            None => Ok(None),\n        }\n    }\n\n    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        match self.next_value.take() {\n            Some(NestedValue::Default(originator)) => {\n                seed.deserialize(DefaultDeserializer::new(originator))\n            }\n            Some(v) => seed.deserialize(\n                NestedValueWrapper::new(v, self.default_for_missing_fields).into_deserializer(),\n            ),\n            None => seed.deserialize(DefaultDeserializer::new(None)),\n        }\n    }\n}\n\nstruct ProbeEnumAccess<A: BurnModuleAdapter> {\n    value: NestedValue,\n    current_variant: String,\n    default_for_missing_fields: bool,\n    phantom: std::marker::PhantomData<A>,\n}\n\nimpl<A: BurnModuleAdapter> ProbeEnumAccess<A> {\n    fn new(value: NestedValue, current_variant: String, default_for_missing_fields: bool) -> Self {\n        ProbeEnumAccess {\n            value,\n            current_variant,\n            default_for_missing_fields,\n            phantom: std::marker::PhantomData,\n        }\n    }\n}\n\nimpl<'de, A> EnumAccess<'de> for ProbeEnumAccess<A>\nwhere\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n    type Variant = Self;\n\n    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>\n    where\n        V: DeserializeSeed<'de>,\n    {\n        seed.deserialize(self.current_variant.clone().into_deserializer())\n            .map(|v| (v, self))\n    }\n}\n\nimpl<'de, A> VariantAccess<'de> for ProbeEnumAccess<A>\nwhere\n    A: BurnModuleAdapter,\n{\n    type Error = Error;\n\n    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        let value = seed.deserialize(\n            NestedValueWrapper::<A>::new(self.value, self.default_for_missing_fields)\n                .into_deserializer(),\n        )?;\n        Ok(value)\n    }\n\n    fn unit_variant(self) -> Result<(), Self::Error> {\n        // Support tensor `DType` deserialization\n        match self.value {\n            NestedValue::Map(value) if value.contains_key(\"DType\") => {\n                match value.get(\"DType\") {\n                    Some(NestedValue::String(variant)) => {\n                        if *variant == self.current_variant {\n                            Ok(())\n                        } else {\n                            Err(Error::Other(\"Wrong variant\".to_string())) // wrong match\n                        }\n                    }\n                    _ => panic!(\"expected DType variant as string\"),\n                }\n            }\n            _ => unimplemented!(\n                \"unit variant is not implemented because it is not used in the burn module\"\n            ),\n        }\n    }\n\n    fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\"tuple variant is not implemented because it is not used in the burn module\")\n    }\n\n    fn struct_variant<V>(\n        self,\n        _fields: &'static [&'static str],\n        _visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!(\n            \"struct variant is not implemented because it is not used in the burn module\"\n        )\n    }\n}\n\n/// A wrapper for the nested value data structure with a burn module adapter.\nstruct NestedValueWrapper<A: BurnModuleAdapter> {\n    value: NestedValue,\n    default_for_missing_fields: bool,\n    phantom: std::marker::PhantomData<A>,\n}\n\nimpl<A: BurnModuleAdapter> NestedValueWrapper<A> {\n    fn new(value: NestedValue, default_for_missing_fields: bool) -> Self {\n        Self {\n            value,\n            default_for_missing_fields,\n            phantom: std::marker::PhantomData,\n        }\n    }\n}\n\nimpl<A: BurnModuleAdapter> IntoDeserializer<'_, Error> for NestedValueWrapper<A> {\n    type Deserializer = Deserializer<A>;\n\n    fn into_deserializer(self) -> Self::Deserializer {\n        Deserializer::<A>::new(self.value, self.default_for_missing_fields)\n    }\n}\n\n/// A default deserializer that always returns the default value.\nstruct DefaultDeserializer {\n    /// The originator field name (the top-level missing field name)\n    originator_field_name: Option<String>,\n}\n\nimpl DefaultDeserializer {\n    fn new(originator_field_name: Option<String>) -> Self {\n        Self {\n            originator_field_name,\n        }\n    }\n}\n\nimpl<'de> serde::Deserializer<'de> for DefaultDeserializer {\n    type Error = Error;\n\n    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        unimplemented!()\n    }\n\n    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i32(Default::default())\n    }\n\n    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_f32(Default::default())\n    }\n\n    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i16(Default::default())\n    }\n\n    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i64(Default::default())\n    }\n\n    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u16(Default::default())\n    }\n\n    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u64(Default::default())\n    }\n\n    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_f64(Default::default())\n    }\n\n    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_bool(Default::default())\n    }\n\n    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_char(Default::default())\n    }\n\n    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_str(Default::default())\n    }\n\n    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_i8(Default::default())\n    }\n\n    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u8(Default::default())\n    }\n\n    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_u32(Default::default())\n    }\n\n    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_none()\n    }\n\n    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_seq(DefaultSeqAccess::new(None))\n    }\n\n    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_string(Default::default())\n    }\n\n    fn deserialize_struct<V>(\n        self,\n        name: &'static str,\n        _fields: &'static [&'static str],\n        _visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        // Return an error if the originator field name is not set\n        Err(Error::Other(format!(\n            \"Missing source values for the '{}' field of type '{}'. Please verify the source data and ensure the field name is correct\",\n            self.originator_field_name.unwrap_or(\"UNKNOWN\".to_string()),\n            name,\n        )))\n    }\n\n    fn deserialize_tuple_struct<V>(\n        self,\n        _name: &'static str,\n        len: usize,\n        visitor: V,\n    ) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_seq(DefaultSeqAccess::new(Some(len)))\n    }\n\n    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_seq(DefaultSeqAccess::new(Some(len)))\n    }\n\n    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_map(DefaultMapAccess::new())\n    }\n\n    forward_to_deserialize_any! {\n        u128 bytes byte_buf unit unit_struct newtype_struct\n        enum identifier ignored_any\n    }\n}\n\n/// A default sequence access that always returns None (empty sequence).\npub struct DefaultSeqAccess {\n    size: Option<usize>,\n}\n\nimpl Default for DefaultSeqAccess {\n    fn default() -> Self {\n        Self::new(None)\n    }\n}\n\nimpl DefaultSeqAccess {\n    /// Creates a new default sequence access with the given size hint.\n    pub fn new(size: Option<usize>) -> Self {\n        DefaultSeqAccess { size }\n    }\n}\n\nimpl<'de> SeqAccess<'de> for DefaultSeqAccess {\n    type Error = Error;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        match self.size {\n            Some(0) => Ok(None),\n            Some(ref mut size) => {\n                *size -= 1;\n                seed.deserialize(DefaultDeserializer::new(None)).map(Some)\n            }\n            None => Ok(None),\n        }\n    }\n\n    fn size_hint(&self) -> Option<usize> {\n        self.size\n    }\n}\n\n/// A default map access that always returns None (empty map).\npub struct DefaultMapAccess;\n\nimpl Default for DefaultMapAccess {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl DefaultMapAccess {\n    /// Creates a new default map access.\n    pub fn new() -> Self {\n        DefaultMapAccess\n    }\n}\n\nimpl<'de> MapAccess<'de> for DefaultMapAccess {\n    type Error = Error;\n\n    fn next_key_seed<T>(&mut self, _seed: T) -> Result<Option<T::Value>, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        // Since this is a default implementation, we'll just return None.\n        Ok(None)\n    }\n\n    fn next_value_seed<T>(&mut self, _seed: T) -> Result<T::Value, Self::Error>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        unimplemented!(\"This should never be called since next_key_seed always returns None\")\n    }\n\n    fn size_hint(&self) -> Option<usize> {\n        // Since this is a default implementation, we'll just return None.\n        None\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/error.rs",
    "content": "use crate::record::RecorderError;\n\n/// The error type for Record serde.\n#[derive(thiserror::Error, Debug)]\npub enum Error {\n    /// Failed to deserialize.\n    #[error(\"failed to deserialize: {0}\")]\n    Deserialize(#[from] serde::de::value::Error),\n\n    /// Failed to serialize.\n    #[error(\"failed to serialize\")]\n    Serialize(String),\n\n    /// Encountered an invalid state.\n    #[error(\"invalid state\")]\n    InvalidState,\n\n    /// Other error.\n    #[error(\"other error: {0}\")]\n    Other(String),\n}\n\nimpl serde::de::Error for Error {\n    fn custom<T: std::fmt::Display>(msg: T) -> Self {\n        Error::Deserialize(serde::de::value::Error::custom(msg.to_string()))\n    }\n}\n\nimpl serde::ser::Error for Error {\n    fn custom<T: std::fmt::Display>(msg: T) -> Self {\n        Error::Serialize(msg.to_string())\n    }\n}\n\n// Implement From trait for Error to RecorderError\nimpl From<Error> for RecorderError {\n    fn from(error: Error) -> Self {\n        RecorderError::DeserializeError(error.to_string())\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/mod.rs",
    "content": "//! Module contains the serde implementation for the record module\n//! useful for custom importing model weights, such as PyTorch's pt file format.\n\n/// The adapter trait that is used to convert the nested value to the module type.\npub mod adapter;\n\n/// The main data structure used for deserialization.\npub mod data;\n\n/// The deserializer that is used to convert the nested value to the record.\npub mod ser;\n\n/// The deserializer that is used to convert the nested value to the record.\npub mod de;\n\n/// Error types.\npub mod error;\n"
  },
  {
    "path": "crates/burn-core/src/record/serde/ser.rs",
    "content": "use std::collections::HashMap;\n\nuse super::{\n    data::NestedValue,\n    error::{self, Error},\n};\n\nuse serde::{\n    Serialize,\n    ser::{self, SerializeSeq, SerializeStruct, Serializer as SerializerTrait},\n};\n\n/// Simple struct serializer that converts a struct into NestedValues.\n///\n/// NOTE: This is used to serialize Param structs into NestedValues and not so much for\n/// the actual serialization of modules (although it could be used for that as well if all\n/// primitive types are implemented).\n#[derive(Clone)]\npub struct Serializer {\n    /// The state of the serialization process\n    state: Option<NestedValue>,\n}\n\nimpl Serializer {\n    /// Creates a new serializer.\n    pub fn new() -> Self {\n        Serializer { state: None }\n    }\n}\n\nimpl Default for Serializer {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl SerializerTrait for Serializer {\n    type Ok = NestedValue;\n    type Error = Error;\n    type SerializeSeq = Self;\n    type SerializeTuple = ser::Impossible<NestedValue, Self::Error>;\n    type SerializeTupleStruct = ser::Impossible<NestedValue, Self::Error>;\n    type SerializeTupleVariant = ser::Impossible<NestedValue, Self::Error>;\n    type SerializeMap = ser::Impossible<NestedValue, Self::Error>;\n    type SerializeStruct = Self;\n    type SerializeStructVariant = ser::Impossible<NestedValue, Self::Error>;\n\n    fn serialize_struct(\n        self,\n        _name: &'static str,\n        _len: usize,\n    ) -> Result<Self::SerializeStruct, Self::Error> {\n        Ok(self)\n    }\n\n    fn serialize_newtype_struct<T>(\n        self,\n        _name: &'static str,\n        value: &T,\n    ) -> Result<Self::Ok, Self::Error>\n    where\n        T: Serialize + ?Sized,\n    {\n        value.serialize(self)\n    }\n\n    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {\n        Ok(self)\n    }\n\n    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::I32(v))\n    }\n\n    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::String(v.to_string()))\n    }\n\n    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::I16(v))\n    }\n\n    fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::I64(v))\n    }\n\n    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::U16(v))\n    }\n\n    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::U64(v))\n    }\n\n    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::F32(v))\n    }\n\n    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::F64(v))\n    }\n\n    // The following methods are not implemented because they are not needed for the\n    // serialization of Param structs.\n\n    fn serialize_char(self, _v: char) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::U8s(v.to_vec()))\n    }\n\n    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::Default(None))\n    }\n    fn serialize_u32(self, _v: u32) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n    fn serialize_bool(self, _v: bool) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_i8(self, _v: i8) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::U8(v))\n    }\n\n    fn serialize_some<T>(self, value: &T) -> Result<Self::Ok, Self::Error>\n    where\n        T: Serialize + ?Sized,\n    {\n        value.serialize(self)\n    }\n\n    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_unit_variant(\n        self,\n        _name: &'static str,\n        _variant_index: u32,\n        _variant: &'static str,\n    ) -> Result<Self::Ok, Self::Error> {\n        Ok(NestedValue::Map(HashMap::from([(\n            _name.to_string(),\n            NestedValue::String(_variant.to_string()),\n        )])))\n    }\n\n    fn serialize_newtype_variant<T>(\n        self,\n        _name: &'static str,\n        _variant_index: u32,\n        _variant: &'static str,\n        _value: &T,\n    ) -> Result<Self::Ok, Self::Error>\n    where\n        T: Serialize + ?Sized,\n    {\n        unimplemented!()\n    }\n\n    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_tuple_struct(\n        self,\n        _name: &'static str,\n        _len: usize,\n    ) -> Result<Self::SerializeTupleStruct, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_tuple_variant(\n        self,\n        _name: &'static str,\n        _variant_index: u32,\n        _variant: &'static str,\n        _len: usize,\n    ) -> Result<Self::SerializeTupleVariant, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {\n        unimplemented!()\n    }\n\n    fn serialize_struct_variant(\n        self,\n        _name: &'static str,\n        _variant_index: u32,\n        _variant: &'static str,\n        _len: usize,\n    ) -> Result<Self::SerializeStructVariant, Self::Error> {\n        unimplemented!()\n    }\n}\n\n// Implementing the SerializeStruct trait for Serializer\nimpl SerializeStruct for Serializer {\n    type Ok = NestedValue;\n    type Error = Error;\n\n    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error>\n    where\n        T: Serialize + ?Sized,\n    {\n        let serialized_value = value.serialize(Serializer::new())?;\n\n        match self.state {\n            Some(NestedValue::Map(ref mut map)) => {\n                map.insert(key.to_string(), serialized_value); // Inserting into the state\n            }\n            Some(_) => {\n                panic!(\"Invalid state encountered\");\n            }\n            None => {\n                let mut map = HashMap::new();\n                map.insert(key.to_string(), serialized_value); // Inserting into the state\n                self.state = Some(NestedValue::Map(map));\n            }\n        }\n\n        Ok(())\n    }\n\n    fn end(self) -> Result<Self::Ok, Self::Error> {\n        if self.state.is_none() {\n            // If the state is empty, return an empty map\n            Ok(NestedValue::Map(HashMap::new()))\n        } else {\n            self.state.ok_or(error::Error::InvalidState)\n        }\n    }\n}\n\nimpl SerializeSeq for Serializer {\n    type Ok = NestedValue;\n    type Error = Error;\n\n    fn serialize_element<T>(&mut self, value: &T) -> Result<(), Self::Error>\n    where\n        T: Serialize + ?Sized,\n    {\n        let serialized_value = value.serialize(Serializer::new())?;\n\n        match self.state {\n            Some(NestedValue::Vec(ref mut vec)) => {\n                vec.push(serialized_value); // Inserting into the state\n            }\n            Some(NestedValue::U8s(ref mut vec)) => {\n                if let NestedValue::U8(val) = serialized_value {\n                    vec.push(val);\n                } else {\n                    panic!(\"Invalid value type encountered\");\n                }\n            }\n            Some(NestedValue::U16s(ref mut vec)) => {\n                if let NestedValue::U16(val) = serialized_value {\n                    vec.push(val);\n                } else {\n                    panic!(\"Invalid value type encountered\");\n                }\n            }\n            Some(NestedValue::F32s(ref mut vec)) => {\n                if let NestedValue::F32(val) = serialized_value {\n                    vec.push(val);\n                } else {\n                    panic!(\"Invalid value type encountered\");\n                }\n            }\n            Some(_) => {\n                panic!(\"Invalid state encountered\");\n            }\n            None => {\n                let val = match serialized_value {\n                    NestedValue::U8(val) => NestedValue::U8s(vec![val]),\n                    NestedValue::U16(val) => NestedValue::U16s(vec![val]),\n                    NestedValue::F32(val) => NestedValue::F32s(vec![val]),\n                    _ => NestedValue::Vec(vec![serialized_value]),\n                };\n                self.state = Some(val);\n            }\n        }\n\n        Ok(())\n    }\n\n    fn end(self) -> Result<Self::Ok, Self::Error> {\n        if self.state.is_none() {\n            // If the state is empty, return an empty vector\n            Ok(NestedValue::Vec(Vec::new()))\n        } else {\n            self.state.ok_or(error::Error::InvalidState)\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::{\n        TestBackend,\n        module::{Param, ParamId},\n        record::{FullPrecisionSettings, Record},\n        tensor::Tensor,\n    };\n    use serde::Deserialize;\n\n    use super::*;\n\n    #[derive(Serialize, Deserialize, Debug, Clone)]\n    struct MyStruct1 {\n        a: MyStruct3,\n        b: MyStruct2,\n    }\n\n    #[derive(Serialize, Deserialize, Debug, Clone)]\n    struct MyStruct2 {\n        a: i32,\n        b: Option<i32>,\n        c: String,\n        d: Option<String>,\n    }\n\n    #[derive(Serialize, Deserialize, Debug, Clone)]\n    struct MyStruct3 {\n        x: String,\n        y: String,\n    }\n\n    #[test]\n    fn test_serialize() {\n        let my_struct = MyStruct1 {\n            a: MyStruct3 {\n                x: \"Hello\".to_owned(),\n                y: \"World\".to_owned(),\n            },\n            b: MyStruct2 {\n                a: 1,\n                b: None,\n                c: \"Hello\".to_owned(),\n                d: Some(\"World\".to_owned()),\n            },\n        };\n\n        let serialized = my_struct\n            .serialize(Serializer::new())\n            .expect(\"Should serialize item successfully\");\n\n        let serialized_str = format!(\"{serialized:?}\");\n\n        // Compare the lengths of expected and actual serialized strings because\n        // the order of the fields is not guaranteed for HashMaps.\n        assert_eq!(serialized_str.len(), 135);\n    }\n\n    #[test]\n    fn test_param_serde() {\n        let device = Default::default();\n        let tensor: Tensor<TestBackend, 2> = Tensor::ones([2, 2], &device);\n        let param = Param::initialized(ParamId::new(), tensor);\n        let param_item = param.into_item::<FullPrecisionSettings>();\n\n        let serialized = param_item\n            .serialize(Serializer::new())\n            .expect(\"Should serialize item successfully\");\n\n        let bytes = serialized.as_map().expect(\"is a map\")[\"param\"]\n            .clone()\n            .as_map()\n            .expect(\"param is a map\")[\"bytes\"]\n            .clone()\n            .as_bytes()\n            .expect(\"has bytes vec\");\n        assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/settings.rs",
    "content": "use burn_tensor::Element;\nuse serde::{Serialize, de::DeserializeOwned};\n\n/// Settings allowing to control the precision when (de)serializing items.\npub trait PrecisionSettings:\n    Send + Sync + core::fmt::Debug + core::default::Default + Clone\n{\n    /// Float element type.\n    type FloatElem: Element + Serialize + DeserializeOwned;\n\n    /// Integer element type.\n    type IntElem: Element + Serialize + DeserializeOwned;\n}\n\n/// Default precision settings.\n#[derive(Debug, Default, Clone)]\npub struct FullPrecisionSettings;\n\n/// Precision settings optimized for compactness.\n#[derive(Debug, Default, Clone)]\npub struct HalfPrecisionSettings;\n\n/// Precision settings optimized for precision.\n#[derive(Debug, Default, Clone)]\npub struct DoublePrecisionSettings;\n\nimpl PrecisionSettings for FullPrecisionSettings {\n    type FloatElem = f32;\n    type IntElem = i32;\n}\n\nimpl PrecisionSettings for DoublePrecisionSettings {\n    type FloatElem = f64;\n    type IntElem = i64;\n}\n\nimpl PrecisionSettings for HalfPrecisionSettings {\n    type FloatElem = half::f16;\n    type IntElem = i16;\n}\n"
  },
  {
    "path": "crates/burn-core/src/record/tensor.rs",
    "content": "use core::marker::PhantomData;\n\nuse super::{PrecisionSettings, Record};\nuse burn_tensor::{Bool, DType, Element, Int, Tensor, TensorData, backend::Backend};\nuse serde::{Deserialize, Serialize};\n\nuse alloc::format;\n\n/// Deserialize the value into [`TensorData`].\nfn deserialize_data<'de, E, De>(deserializer: De) -> Result<TensorData, De::Error>\nwhere\n    E: Element + Deserialize<'de>,\n    De: serde::Deserializer<'de>,\n{\n    let data = TensorData::deserialize(deserializer).map_err(|e| {\n        serde::de::Error::custom(format!(\n            \"{e:?}\\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\\n\"\n        ))\n    })?;\n    let data = if let DType::QFloat(_) = data.dtype {\n        data // do not convert quantized tensors\n    } else {\n        data.convert::<E>()\n    };\n    Ok(data)\n}\n\n/// This struct implements serde to lazily serialize and deserialize a float tensor\n/// using the given [record settings](RecordSettings).\n#[derive(new, Clone, Debug)]\npub struct FloatTensorSerde<S: PrecisionSettings> {\n    data: TensorData,\n    _e: PhantomData<S::FloatElem>,\n}\n\n/// This struct implements serde to lazily serialize and deserialize an int tensor\n/// using the given [record settings](RecordSettings).\n#[derive(new, Clone, Debug)]\npub struct IntTensorSerde<S: PrecisionSettings> {\n    data: TensorData,\n    _e: PhantomData<S::IntElem>,\n}\n\n/// This struct implements serde to lazily serialize and deserialize an bool tensor.\n#[derive(new, Clone, Debug)]\npub struct BoolTensorSerde {\n    data: TensorData,\n}\n\n// --- SERDE IMPLEMENTATIONS --- //\n\nimpl<S: PrecisionSettings> Serialize for FloatTensorSerde<S> {\n    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>\n    where\n        Se: serde::Serializer,\n    {\n        self.data.serialize(serializer)\n    }\n}\n\nimpl<'de, S: PrecisionSettings> Deserialize<'de> for FloatTensorSerde<S> {\n    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>\n    where\n        De: serde::Deserializer<'de>,\n    {\n        let data = deserialize_data::<S::FloatElem, De>(deserializer)?;\n\n        Ok(Self::new(data))\n    }\n}\n\nimpl<S: PrecisionSettings> Serialize for IntTensorSerde<S> {\n    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>\n    where\n        Se: serde::Serializer,\n    {\n        self.data.serialize(serializer)\n    }\n}\n\nimpl<'de, S: PrecisionSettings> Deserialize<'de> for IntTensorSerde<S> {\n    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>\n    where\n        De: serde::Deserializer<'de>,\n    {\n        let data = deserialize_data::<S::IntElem, De>(deserializer)?;\n\n        Ok(Self::new(data))\n    }\n}\n\nimpl Serialize for BoolTensorSerde {\n    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>\n    where\n        Se: serde::Serializer,\n    {\n        self.data.serialize(serializer)\n    }\n}\n\nimpl<'de> Deserialize<'de> for BoolTensorSerde {\n    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>\n    where\n        De: serde::Deserializer<'de>,\n    {\n        let data = deserialize_data::<bool, De>(deserializer)?;\n\n        Ok(Self::new(data))\n    }\n}\n\n// --- RECORD IMPLEMENTATIONS --- //\n\nimpl<B: Backend, const D: usize> Record<B> for Tensor<B, D> {\n    type Item<S: PrecisionSettings> = FloatTensorSerde<S>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        let data = self.into_data();\n        let data = if let DType::QFloat(_) = data.dtype {\n            data // do not convert quantized tensors\n        } else {\n            data.convert::<S::FloatElem>()\n        };\n        FloatTensorSerde::new(data)\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        let data = if let DType::QFloat(_) = item.data.dtype {\n            item.data // do not convert quantized tensors\n        } else {\n            item.data.convert::<B::FloatElem>()\n        };\n        Tensor::from_data(data, device)\n    }\n}\n\nimpl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Int> {\n    type Item<S: PrecisionSettings> = IntTensorSerde<S>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        IntTensorSerde::new(self.into_data().convert::<S::IntElem>())\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        Tensor::from_data(item.data.convert::<B::IntElem>(), device)\n    }\n}\n\nimpl<B: Backend, const D: usize> Record<B> for Tensor<B, D, Bool> {\n    type Item<S: PrecisionSettings> = BoolTensorSerde;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        BoolTensorSerde::new(self.into_data())\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        Tensor::from_data(item.data, device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/src/tensor.rs",
    "content": "pub use burn_tensor::*;\n"
  },
  {
    "path": "crates/burn-core/src/vision.rs",
    "content": "pub use burn_vision::*;\n"
  },
  {
    "path": "crates/burn-core/tests/test_derive_config.rs",
    "content": "use burn::config::{Config, config_to_json};\nuse burn_core as burn;\n\n#[derive(Config, Debug, PartialEq, Eq)]\npub struct TestEmptyStructConfig {}\n\n#[derive(Config, Debug, PartialEq)]\npub struct TestStructConfig {\n    int: i32,\n    #[config(default = 2)]\n    int_default: i32,\n    float: f32,\n    #[config(default = 2.0)]\n    float_default: f32,\n    string: String,\n    other_config: TestEmptyStructConfig,\n}\n\n#[derive(Config, Debug, PartialEq)]\npub enum TestEnumConfig {\n    None,\n    Single(f32),\n    Multiple(f32, String),\n    Named { first: f32, second: String },\n}\n\n#[cfg(feature = \"std\")]\n#[inline(always)]\nfn file_path(file_name: &str) -> std::path::PathBuf {\n    std::env::temp_dir().join(file_name)\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn struct_config_should_impl_serde() {\n    let config = TestStructConfig::new(2, 3.0, \"Allow\".to_string(), TestEmptyStructConfig::new());\n    let file_path = file_path(\"test_struct_config.json\");\n\n    config.save(&file_path).unwrap();\n\n    let config_loaded = TestStructConfig::load(&file_path).unwrap();\n    assert_eq!(config, config_loaded);\n}\n\n#[test]\nfn struct_config_should_impl_clone() {\n    let config = TestStructConfig::new(2, 3.0, \"Allow\".to_string(), TestEmptyStructConfig::new());\n    assert_eq!(config, config.clone());\n}\n\n#[test]\nfn struct_config_should_impl_display() {\n    let config = TestStructConfig::new(2, 3.0, \"Allow\".to_string(), TestEmptyStructConfig::new());\n    assert_eq!(burn::config::config_to_json(&config), config.to_string());\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn enum_config_no_value_should_impl_serde() {\n    let config = TestEnumConfig::None;\n    let file_path = file_path(\"test_enum_no_value_config.json\");\n\n    config.save(&file_path).unwrap();\n\n    let config_loaded = TestEnumConfig::load(&file_path).unwrap();\n    assert_eq!(config, config_loaded);\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn enum_config_one_value_should_impl_serde() {\n    let config = TestEnumConfig::Single(42.0);\n    let file_path = file_path(\"test_enum_one_value_config.json\");\n\n    config.save(&file_path).unwrap();\n\n    let config_loaded = TestEnumConfig::load(&file_path).unwrap();\n    assert_eq!(config, config_loaded);\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn enum_config_multiple_values_should_impl_serde() {\n    let config = TestEnumConfig::Multiple(42.0, \"Allow\".to_string());\n    let file_path = file_path(\"test_enum_multiple_values_config.json\");\n\n    config.save(&file_path).unwrap();\n\n    let config_loaded = TestEnumConfig::load(&file_path).unwrap();\n    assert_eq!(config, config_loaded);\n}\n\n#[test]\nfn enum_config_should_impl_clone() {\n    let config = TestEnumConfig::Multiple(42.0, \"Allow\".to_string());\n    assert_eq!(config, config.clone());\n}\n\n#[test]\nfn enum_config_should_impl_display() {\n    let config = TestEnumConfig::Multiple(42.0, \"Allow\".to_string());\n    assert_eq!(burn::config::config_to_json(&config), config.to_string());\n}\n\n#[test]\nfn struct_config_can_load_binary() {\n    let config = TestStructConfig::new(2, 3.0, \"Allow\".to_string(), TestEmptyStructConfig::new());\n\n    let binary = config_to_json(&config).as_bytes().to_vec();\n\n    let config_loaded = TestStructConfig::load_binary(&binary).unwrap();\n    assert_eq!(config, config_loaded);\n}\n"
  },
  {
    "path": "crates/burn-core/tests/test_derive_module.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn::module::Initializer;\nuse burn::module::{Module, Param};\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{Int, Tensor};\nuse burn_core as burn;\n\npub type TestBackend = burn_ndarray::NdArray<f32>;\n#[cfg(feature = \"std\")]\npub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;\n\n#[derive(Module, Debug)]\npub struct ModuleBasic<B: Backend> {\n    weight_basic: Param<Tensor<B, 2>>,\n}\n\n#[derive(Module, Debug)]\n#[allow(unused)]\nstruct ModuleTensorConstInt<B: Backend> {\n    weight_basic: Tensor<B, 2, Int>,\n}\n\nimpl<B: Backend> ModuleBasic<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight_basic: Initializer::Normal {\n                std: 1.0,\n                mean: 0.0,\n            }\n            .init([20, 20], device),\n        }\n    }\n}\n\n#[derive(Module, Debug)]\nstruct ModuleWithConstGeneric<B: Backend, const N: usize> {\n    modules: [ModuleBasic<B>; N],\n}\n\n#[derive(Module, Debug)]\nstruct ModuleWithGenericModule<B: Backend, M> {\n    module: M,\n    _backend: PhantomData<B>,\n}\n\n#[derive(Module, Debug)]\n#[allow(clippy::large_enum_variant)]\nenum ModuleEnum<B: Backend> {\n    Basic(ModuleBasic<B>),\n    Composed(ModuleComposed<B>),\n}\n\n#[derive(Module, Debug)]\n#[allow(unused)]\nenum ModuleEnumNested<B: Backend> {\n    AnotherEnum(ModuleEnum<B>),\n}\n\n#[derive(Module, Debug)]\nenum ModuleEnumWithGenericModule<B: Backend, M: Module<B>> {\n    Basic(ModuleBasic<B>),\n    Generic(ModuleWithGenericModule<B, M>),\n}\n\n#[derive(Module, Debug)]\npub struct ModuleComposed<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    basic: ModuleBasic<B>,\n    tuple: (ModuleBasic<B>, ModuleBasic<B>),\n}\n\nimpl<B: Backend> ModuleComposed<B> {\n    fn new(device: &B::Device) -> Self {\n        let weight = Initializer::Normal {\n            std: 1.0,\n            mean: 0.0,\n        }\n        .init([20, 20], device);\n\n        Self {\n            weight,\n            basic: ModuleBasic::new(device),\n            tuple: (ModuleBasic::new(device), ModuleBasic::new(device)),\n        }\n    }\n}\n\n#[derive(Debug, Clone)]\npub enum PaddingConfig {\n    Default,\n    Other,\n}\n\n#[derive(Module, Debug)]\npub struct ModuleWithAttributes<B: Backend, M: Module<B>, N> {\n    /// A normal parameter.\n    weight: Param<Tensor<B, 2>>,\n    /// A nested module.\n    nested: ModuleEnumWithGenericModule<B, M>,\n    /// By default, primitives were not persistent (same as `#[module(skip)]`).\n    other_prob: f64,\n    /// By default, tensors were not persistent and not visited/mapped (same as `#[module(skip)]`).\n    tensor: Tensor<B, 1>,\n    /// A field that is recomputed at runtime.\n    #[module(skip)]\n    cached_mask: Option<Tensor<B, 2>>,\n    /// A field that contains some debug state.\n    debug_state: String,\n    /// Hint required: this generic is NOT a module.\n    #[module(skip)]\n    config: N,\n}\n\nimpl<B: Backend> ModuleWithAttributes<B, ModuleBasic<B>, PaddingConfig> {\n    fn new(device: &B::Device) -> Self {\n        let basic = ModuleBasic::new(device);\n        let weight = basic.weight_basic.clone();\n\n        Self {\n            weight,\n            nested: ModuleEnumWithGenericModule::Basic(basic),\n            other_prob: 1.,\n            tensor: Tensor::ones([2], device),\n            cached_mask: Some(Tensor::ones([2, 2], device)),\n            debug_state: \"Hello World\".into(),\n            config: PaddingConfig::Default,\n        }\n    }\n}\n\n#[allow(dead_code)]\nmod compiletime_clone_impl_check {\n    use burn_core::{\n        module::{Module, ModuleDisplay},\n        prelude::Backend,\n        record::{PrecisionSettings, Record},\n    };\n\n    use super::*;\n\n    type RecordItem<M, B, S> = <<M as Module<B>>::Record as Record<B>>::Item<S>;\n\n    fn implements_clone<T: Clone>() {}\n\n    fn basic_implements_clone<B: Backend, S: PrecisionSettings>() {\n        implements_clone::<RecordItem<ModuleBasic<B>, B, S>>();\n        implements_clone::<RecordItem<ModuleComposed<B>, B, S>>();\n    }\n\n    fn generic_implements_clone<B, S, M>()\n    where\n        B: Backend,\n        S: PrecisionSettings,\n        M: Module<B> + ModuleDisplay,\n        RecordItem<M, B, S>: Clone,\n    {\n        implements_clone::<RecordItem<ModuleWithGenericModule<B, M>, B, S>>();\n        implements_clone::<RecordItem<ModuleEnumWithGenericModule<B, M>, B, S>>();\n    }\n}\n\nmod state {\n    use burn_core::module::EmptyRecord;\n\n    use super::*;\n\n    #[test]\n    fn should_load_from_record_basic() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module_1 = ModuleBasic::<TestBackend>::new(&device);\n        let mut module_2 = ModuleBasic::<TestBackend>::new(&device);\n        let state_1 = module_1.clone().into_record();\n\n        assert_ne!(\n            module_1.weight_basic.to_data(),\n            module_2.weight_basic.to_data()\n        );\n\n        module_2 = module_2.load_record(state_1);\n\n        assert_eq!(\n            module_1.weight_basic.to_data(),\n            module_2.weight_basic.to_data()\n        );\n    }\n\n    #[test]\n    fn should_load_from_record_compose() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module_1 = ModuleComposed::<TestBackend>::new(&device);\n        let mut module_2 = ModuleComposed::<TestBackend>::new(&device);\n        assert_ne!(module_1.weight.to_data(), module_2.weight.to_data());\n        assert_ne!(\n            module_1.basic.weight_basic.to_data(),\n            module_2.basic.weight_basic.to_data()\n        );\n\n        let state_1 = module_1.clone().into_record();\n        module_2 = module_2.load_record(state_1);\n\n        assert_eq!(module_1.weight.to_data(), module_2.weight.to_data());\n        assert_eq!(\n            module_1.basic.weight_basic.to_data(),\n            module_2.basic.weight_basic.to_data()\n        );\n    }\n\n    #[test]\n    fn should_load_from_record_enum() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));\n        let mut module_2 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));\n        let state_1 = module_1.clone().into_record();\n\n        let ModuleEnum::Basic(module_1_basic) = module_1 else {\n            panic!(\"Invalid module type\")\n        };\n        let ModuleEnum::Basic(module_2_basic) = module_2.clone() else {\n            panic!(\"Invalid module type\")\n        };\n        assert_ne!(\n            module_1_basic.weight_basic.to_data(),\n            module_2_basic.weight_basic.to_data()\n        );\n\n        module_2 = module_2.load_record(state_1);\n\n        let ModuleEnum::Basic(module_2_basic) = module_2 else {\n            panic!(\"Invalid module type\")\n        };\n        assert_eq!(\n            module_1_basic.weight_basic.to_data(),\n            module_2_basic.weight_basic.to_data()\n        );\n    }\n\n    #[test]\n    fn should_load_from_record_based_on_attributes() {\n        let device = <TestBackend as Backend>::Device::default();\n        let mut module_1 = ModuleWithAttributes::<TestBackend, _, _>::new(&device);\n        let mut module_2 = ModuleWithAttributes::new(&device);\n\n        assert_ne!(module_1.weight.to_data(), module_2.weight.to_data(),);\n\n        let ModuleEnumWithGenericModule::Basic(ref m1_basic) = module_1.nested else {\n            panic!(\"Invalid module type\")\n        };\n        let ModuleEnumWithGenericModule::Basic(ref m2_basic) = module_2.nested else {\n            panic!(\"Invalid module type\")\n        };\n\n        assert_ne!(\n            m1_basic.weight_basic.to_data(),\n            m2_basic.weight_basic.to_data(),\n        );\n\n        assert_eq!(module_1.tensor.to_data(), module_2.tensor.to_data());\n        assert_eq!(\n            module_1.cached_mask.as_ref().unwrap().to_data(),\n            module_2.cached_mask.as_ref().unwrap().to_data()\n        );\n\n        assert_eq!(module_1.other_prob, module_2.other_prob);\n        assert_eq!(module_1.debug_state, module_2.debug_state);\n\n        // Alter state of skipped fields to validate persistence\n        module_1.cached_mask = Some(module_1.cached_mask.unwrap() * 2);\n        module_1.tensor = module_1.tensor * 2;\n        module_1.other_prob = 0.;\n        module_1.debug_state = \"Hello World!\".into();\n        module_1.config = PaddingConfig::Other;\n\n        let state_1 = module_1.clone().into_record();\n\n        assert_eq!(state_1.cached_mask, EmptyRecord);\n        assert_eq!(state_1.other_prob, EmptyRecord);\n        assert_eq!(state_1.debug_state, EmptyRecord);\n        assert_eq!(state_1.config, EmptyRecord);\n\n        module_2 = module_2.load_record(state_1);\n\n        let ModuleEnumWithGenericModule::Basic(m2_basic) = module_2.nested else {\n            panic!(\"Invalid module type\")\n        };\n\n        // Modules & params\n        assert_eq!(module_1.weight.to_data(), module_2.weight.to_data(),);\n        assert_eq!(\n            m1_basic.weight_basic.to_data(),\n            m2_basic.weight_basic.to_data(),\n        );\n\n        // `#[module(skip)]` field and other skip-by-default\n        assert_ne!(module_1.other_prob, module_2.other_prob);\n        assert_ne!(module_1.debug_state, module_2.debug_state);\n        assert!(matches!(module_1.config, PaddingConfig::Other));\n        assert!(matches!(module_2.config, PaddingConfig::Default));\n        assert_ne!(module_1.tensor.to_data(), module_2.tensor.to_data());\n        assert_ne!(\n            module_1.cached_mask.as_ref().unwrap().to_data(),\n            module_2.cached_mask.as_ref().unwrap().to_data()\n        );\n    }\n\n    #[test]\n    fn should_load_from_record_const_generic() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module_1 = ModuleWithConstGeneric {\n            modules: [\n                ModuleBasic::<TestBackend>::new(&device),\n                ModuleBasic::<TestBackend>::new(&device),\n            ],\n        };\n        let mut module_2 = ModuleWithConstGeneric {\n            modules: [\n                ModuleBasic::<TestBackend>::new(&device),\n                ModuleBasic::<TestBackend>::new(&device),\n            ],\n        };\n        let state_1 = module_1.clone().into_record();\n\n        assert_ne!(\n            module_1.modules[0].weight_basic.to_data(),\n            module_2.modules[0].weight_basic.to_data(),\n        );\n        assert_ne!(\n            module_1.modules[1].weight_basic.to_data(),\n            module_2.modules[1].weight_basic.to_data(),\n        );\n\n        module_2 = module_2.load_record(state_1);\n\n        assert_eq!(\n            module_1.modules[0].weight_basic.to_data(),\n            module_2.modules[0].weight_basic.to_data(),\n        );\n        assert_eq!(\n            module_1.modules[1].weight_basic.to_data(),\n            module_2.modules[1].weight_basic.to_data(),\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"Can't parse record from a different variant\")]\n    fn should_panic_load_from_incorrect_enum_variant() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module_1 = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));\n        let module_2 = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));\n        let state_1 = module_1.clone().into_record();\n\n        module_2.load_record(state_1);\n    }\n}\n\nmod num_params {\n    use super::*;\n\n    #[test]\n    fn should_calculate_num_params_basic() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module = ModuleBasic::<TestBackend>::new(&device);\n        assert_eq!(20 * 20, module.num_params());\n    }\n\n    #[test]\n    fn should_output_state_composed() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module = ModuleComposed::<TestBackend>::new(&device);\n        assert_eq!(4 * 20 * 20, module.num_params());\n    }\n\n    #[test]\n    fn should_calculate_num_params_enum() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module = ModuleEnum::Basic(ModuleBasic::<TestBackend>::new(&device));\n        assert_eq!(20 * 20, module.num_params());\n\n        let module = ModuleEnum::Composed(ModuleComposed::<TestBackend>::new(&device));\n        assert_eq!(4 * 20 * 20, module.num_params());\n    }\n\n    #[test]\n    fn should_calculate_num_params_based_on_attributes() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module = ModuleWithAttributes::<TestBackend, _, _>::new(&device);\n        assert_eq!(20 * 20 * 2, module.num_params());\n    }\n}\n\n#[cfg(feature = \"std\")]\nmod require_grad {\n    use burn_tensor::backend::AutodiffBackend;\n\n    use super::*;\n\n    #[test]\n    fn should_have_grad_by_default() {\n        let device = <TestBackend as Backend>::Device::default();\n        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);\n        let mut grads = calculate_grads(&module);\n\n        let grad_x = module.weight_basic.grad_remove(&mut grads);\n\n        assert!(grad_x.is_some());\n    }\n\n    #[test]\n    fn should_have_no_grad_after_no_grad() {\n        let device = <TestAutodiffBackend as Backend>::Device::default();\n        let module = ModuleBasic::<TestAutodiffBackend>::new(&device).no_grad();\n        let mut grads = calculate_grads(&module);\n\n        let grad_x = module.weight_basic.grad_remove(&mut grads);\n\n        assert!(grad_x.is_none());\n    }\n\n    #[test]\n    fn should_have_grad_when_from_record() {\n        let device = <TestAutodiffBackend as Backend>::Device::default();\n        let module = ModuleBasic::<TestAutodiffBackend>::new(&device);\n        let record = ModuleBasicRecord {\n            weight_basic: module.weight_basic.clone(), // Even when param is no_grad,\n        };\n        let module = module.load_record(record);\n        let mut grads = calculate_grads(&module);\n\n        let grad_x = module.weight_basic.grad_remove(&mut grads);\n\n        assert!(grad_x.is_some());\n    }\n\n    fn calculate_grads(\n        module: &ModuleBasic<TestAutodiffBackend>,\n    ) -> <TestAutodiffBackend as AutodiffBackend>::Gradients {\n        let device = module.weight_basic.device();\n        let x = Tensor::ones([20, 20], &device).require_grad();\n        let y = module.weight_basic.val().matmul(x);\n\n        y.backward()\n    }\n}\n"
  },
  {
    "path": "crates/burn-core/tests/test_derive_record.rs",
    "content": "use burn_core as burn;\nuse burn_core::record::Record;\n\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\n// It compiles\n#[derive(Record)]\npub struct TestWithBackendRecord<B: Backend> {\n    tensor: Tensor<B, 2>,\n}\n\n// It compiles\n#[derive(Record)]\npub struct TestWithoutBackendRecord {\n    _tensor: usize,\n}\n"
  },
  {
    "path": "crates/burn-core/tests/test_record_resilience.rs",
    "content": "#[cfg(feature = \"std\")]\nmod tests {\n    use burn::{\n        module::{Module, Param},\n        record::{\n            BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings,\n            PrettyJsonFileRecorder, RecorderError,\n        },\n    };\n    use burn_core as burn;\n    use burn_ndarray::NdArrayDevice;\n    use burn_tensor::{Tensor, backend::Backend};\n    use std::path::PathBuf;\n\n    type TestBackend = burn_ndarray::NdArray<f32>;\n\n    /// Simple linear module.\n    #[derive(Module, Debug)]\n    pub struct Linear<B: Backend> {\n        pub weight: Param<Tensor<B, 2>>,\n        pub bias: Option<Param<Tensor<B, 1>>>,\n    }\n\n    impl<B: Backend> Linear<B> {\n        pub fn new(in_features: usize, out_features: usize, device: &B::Device) -> Self {\n            let weight = Tensor::random(\n                [out_features, in_features],\n                burn_tensor::Distribution::Default,\n                device,\n            );\n            let bias = Tensor::random([out_features], burn_tensor::Distribution::Default, device);\n\n            Self {\n                weight: Param::from_tensor(weight),\n                bias: Some(Param::from_tensor(bias)),\n            }\n        }\n    }\n\n    #[derive(Module, Debug)]\n    pub struct Model<B: Backend> {\n        single_const: f32,\n        linear1: Linear<B>,\n        array_const: [usize; 2],\n        linear2: Linear<B>,\n        array_lin: [Linear<B>; 2],\n    }\n\n    #[derive(Module, Debug)]\n    pub struct ModelNewOptionalField<B: Backend> {\n        single_const: f32,\n        linear1: Linear<B>,\n        array_const: [usize; 2],\n        linear2: Linear<B>,\n        array_lin: [Linear<B>; 2],\n        new_field: Option<usize>,\n    }\n\n    #[derive(Module, Debug)]\n    pub struct ModelNewConstantField<B: Backend> {\n        single_const: f32,\n        linear1: Linear<B>,\n        array_const: [usize; 2],\n        linear2: Linear<B>,\n        array_lin: [Linear<B>; 2],\n        new_field: usize,\n    }\n\n    #[derive(Module, Debug)]\n    #[allow(unused)]\n    pub struct ModelNewFieldOrders<B: Backend> {\n        array_const: [usize; 2],\n        linear2: Linear<B>,\n        single_const: f32,\n        array_lin: [Linear<B>; 2],\n        linear1: Linear<B>,\n    }\n\n    #[test]\n    fn deserialize_with_new_optional_field_works_with_default_file_recorder() {\n        deserialize_with_new_optional_field(\n            \"default\",\n            DefaultFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_optional_field_works_with_default_file_recorder() {\n        deserialize_with_removed_optional_field(\n            \"default\",\n            DefaultFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_constant_field_works_with_default_file_recorder() {\n        deserialize_with_new_constant_field(\n            \"default\",\n            DefaultFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_constant_field_works_with_default_file_recorder() {\n        deserialize_with_removed_constant_field(\n            \"default\",\n            DefaultFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_field_order_works_with_default_file_recorder() {\n        deserialize_with_new_field_order(\n            \"default\",\n            DefaultFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n    #[test]\n    fn deserialize_with_new_optional_field_works_with_pretty_json() {\n        deserialize_with_new_optional_field(\n            \"pretty-json\",\n            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_optional_field_works_with_pretty_json() {\n        deserialize_with_removed_optional_field(\n            \"pretty-json\",\n            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_constant_field_works_with_pretty_json() {\n        deserialize_with_new_constant_field(\n            \"pretty-json\",\n            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_constant_field_works_with_pretty_json() {\n        deserialize_with_removed_constant_field(\n            \"pretty-json\",\n            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_field_order_works_with_pretty_json() {\n        deserialize_with_new_field_order(\n            \"pretty-json\",\n            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_optional_field_works_with_bin_file_recorder() {\n        deserialize_with_new_optional_field(\"bin\", BinFileRecorder::<FullPrecisionSettings>::new())\n            .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() {\n        deserialize_with_removed_optional_field(\n            \"bin\",\n            BinFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_constant_field_works_with_bin_file_recorder() {\n        deserialize_with_new_constant_field(\"bin\", BinFileRecorder::<FullPrecisionSettings>::new())\n            .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() {\n        deserialize_with_removed_constant_field(\n            \"bin\",\n            BinFileRecorder::<FullPrecisionSettings>::new(),\n        )\n        .unwrap();\n    }\n\n    #[test]\n    fn deserialize_with_new_field_order_works_with_bin_file_recorder() {\n        deserialize_with_new_field_order(\"bin\", BinFileRecorder::<FullPrecisionSettings>::new())\n            .unwrap();\n    }\n\n    #[inline(always)]\n    fn file_path(filename: String) -> PathBuf {\n        std::env::temp_dir().join(filename)\n    }\n\n    #[test]\n    fn test_tensor_serde() {\n        let tensor: burn_tensor::Tensor<TestBackend, 1> =\n            burn_tensor::Tensor::ones([1], &NdArrayDevice::default());\n        let encoded = serde_json::to_string(&tensor).unwrap();\n        let decoded: burn_tensor::Tensor<TestBackend, 1> = serde_json::from_str(&encoded).unwrap();\n        assert_eq!(tensor.into_data(), decoded.into_data());\n    }\n\n    fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>\n    where\n        R: FileRecorder<TestBackend>,\n    {\n        let device = Default::default();\n        let file_path: PathBuf = file_path(format!(\"deserialize_with_new_optional_field-{name}\"));\n        let model = Model {\n            single_const: 32.0,\n            linear1: Linear::<TestBackend>::new(20, 20, &device),\n            array_const: [2, 2],\n            linear2: Linear::<TestBackend>::new(20, 20, &device),\n            array_lin: [\n                Linear::<TestBackend>::new(20, 20, &device),\n                Linear::<TestBackend>::new(20, 20, &device),\n            ],\n        };\n\n        recorder\n            .record(model.into_record(), file_path.clone())\n            .unwrap();\n        let result =\n            recorder.load::<ModelNewOptionalFieldRecord<TestBackend>>(file_path.clone(), &device);\n        std::fs::remove_file(file_path).ok();\n\n        result?;\n        Ok(())\n    }\n\n    fn deserialize_with_removed_optional_field<R>(\n        name: &str,\n        recorder: R,\n    ) -> Result<(), RecorderError>\n    where\n        R: FileRecorder<TestBackend>,\n    {\n        let device = Default::default();\n        let file_path: PathBuf =\n            file_path(format!(\"deserialize_with_removed_optional_field-{name}\"));\n        let model = ModelNewOptionalField {\n            single_const: 32.0,\n            linear1: Linear::<TestBackend>::new(20, 20, &device),\n            array_const: [2, 2],\n            linear2: Linear::<TestBackend>::new(20, 20, &device),\n            array_lin: [\n                Linear::<TestBackend>::new(20, 20, &device),\n                Linear::<TestBackend>::new(20, 20, &device),\n            ],\n            new_field: None,\n        };\n\n        recorder\n            .record(model.into_record(), file_path.clone())\n            .unwrap();\n        let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);\n        std::fs::remove_file(file_path).ok();\n\n        result?;\n        Ok(())\n    }\n\n    fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>\n    where\n        R: FileRecorder<TestBackend>,\n    {\n        let device = Default::default();\n        let file_path: PathBuf = file_path(format!(\"deserialize_with_new_constant_field-{name}\"));\n        let model = Model {\n            single_const: 32.0,\n            array_const: [2, 2],\n            linear1: Linear::<TestBackend>::new(20, 20, &device),\n            linear2: Linear::<TestBackend>::new(20, 20, &device),\n            array_lin: [\n                Linear::<TestBackend>::new(20, 20, &device),\n                Linear::<TestBackend>::new(20, 20, &device),\n            ],\n        };\n\n        recorder\n            .record(model.into_record(), file_path.clone())\n            .unwrap();\n        let result =\n            recorder.load::<ModelNewConstantFieldRecord<TestBackend>>(file_path.clone(), &device);\n        std::fs::remove_file(file_path).ok();\n\n        result?;\n        Ok(())\n    }\n\n    fn deserialize_with_removed_constant_field<R>(\n        name: &str,\n        recorder: R,\n    ) -> Result<(), RecorderError>\n    where\n        R: FileRecorder<TestBackend>,\n    {\n        let device = Default::default();\n        let file_path: PathBuf =\n            file_path(format!(\"deserialize_with_removed_constant_field-{name}\"));\n        let model = ModelNewConstantField {\n            single_const: 32.0,\n            array_const: [2, 2],\n            linear1: Linear::<TestBackend>::new(20, 20, &device),\n            linear2: Linear::<TestBackend>::new(20, 20, &device),\n            array_lin: [\n                Linear::<TestBackend>::new(20, 20, &device),\n                Linear::<TestBackend>::new(20, 20, &device),\n            ],\n            new_field: 0,\n        };\n\n        recorder\n            .record(model.into_record(), file_path.clone())\n            .unwrap();\n        let result = recorder.load::<ModelRecord<TestBackend>>(file_path.clone(), &device);\n        std::fs::remove_file(file_path).ok();\n\n        result?;\n        Ok(())\n    }\n\n    fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>\n    where\n        R: FileRecorder<TestBackend>,\n    {\n        let device = Default::default();\n        let file_path: PathBuf = file_path(format!(\"deserialize_with_new_field_order-{name}\"));\n        let model = Model {\n            array_const: [2, 2],\n            single_const: 32.0,\n            linear1: Linear::<TestBackend>::new(20, 20, &device),\n            linear2: Linear::<TestBackend>::new(20, 20, &device),\n            array_lin: [\n                Linear::<TestBackend>::new(20, 20, &device),\n                Linear::<TestBackend>::new(20, 20, &device),\n            ],\n        };\n\n        recorder\n            .record(model.into_record(), file_path.clone())\n            .unwrap();\n\n        let result =\n            recorder.load::<ModelNewFieldOrdersRecord<TestBackend>>(file_path.clone(), &device);\n        std::fs::remove_file(file_path).ok();\n\n        result?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cpu/Cargo.toml",
    "content": "[package]\nauthors = [\"marcantoinem <marc-antoine.m@outlook.com>\"]\ncategories = [\"science\"]\ndescription = \"MLIR based CPU backend for the Burn framework\"\ndocumentation = \"https://docs.rs/burn-cpu\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"cpu\"]\nlicense.workspace = true\nname = \"burn-cpu\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-cpu\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\", \"fusion\", \"autotune\", \"burn-cubecl/default\", \"cubecl/default\"]\ndoc = [\"burn-cubecl/doc\"]\nfusion = [\"burn-fusion\", \"burn-cubecl/fusion\"]\nstd = [\"burn-cubecl/std\", \"cubecl/std\"]\ntracing = [\n    \"burn-backend/tracing\",\n    \"burn-cubecl/tracing\",\n    \"burn-fusion?/tracing\",\n    \"cubecl/tracing\",\n]\n\nautotune = [\"burn-cubecl/autotune\"]\nautotune-checks = [\"burn-cubecl/autotune-checks\"]\n\n[dependencies]\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", features = [\n    \"cubecl-cpu\",\n] }\ncubecl = { workspace = true, features = [\"cpu\"] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-cpu/README.md",
    "content": "# Burn CPU Backend\n\n[Burn](https://github.com/tracel-ai/burn) CubeCL CPU backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-cuda.svg)](https://crates.io/crates/burn-cuda)\n\nThis crate provides a MLIR based CPU backend for [Burn](https://github.com/tracel-ai/burn) using the\n[cubecl](https://github.com/tracel-ai/cubecl.git) crates.\n\n## Usage Example\n\nExample coming soon\n"
  },
  {
    "path": "crates/burn-cpu/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n\nextern crate alloc;\n\nuse burn_cubecl::CubeBackend;\npub use cubecl::cpu::CpuDevice;\nuse cubecl::cpu::CpuRuntime;\n\n#[cfg(not(feature = \"fusion\"))]\npub type Cpu<F = f32, I = i32> = CubeBackend<CpuRuntime, F, I, u8>;\n\n#[cfg(feature = \"fusion\")]\npub type Cpu<F = f32, I = i32> = burn_fusion::Fusion<CubeBackend<CpuRuntime, F, I, u8>>;\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive};\n    use burn_cubecl::tensor::CubeTensor;\n\n    #[test]\n    fn should_support_dtypes() {\n        type B = Cpu;\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F64));\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::F16));\n        assert!(B::supports_dtype(&device, DType::BF16));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::I16));\n        assert!(B::supports_dtype(&device, DType::I8));\n        assert!(B::supports_dtype(&device, DType::U64));\n        assert!(B::supports_dtype(&device, DType::U32));\n        assert!(B::supports_dtype(&device, DType::U16));\n        assert!(B::supports_dtype(&device, DType::U8));\n        assert!(B::supports_dtype(\n            &device,\n            DType::QFloat(CubeTensor::<CpuRuntime>::default_scheme())\n        ));\n\n        // Currently not registered in supported types\n        assert!(!B::supports_dtype(&device, DType::Flex32));\n        assert!(!B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Generic backend that can be compiled just-in-time to any shader language target\"\ndocumentation = \"https://docs.rs/burn-cubecl\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\"]\nlicense.workspace = true\nname = \"burn-cubecl\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"autotune\",\n    \"std\",\n    \"fusion\",\n    \"cubecl/default\",\n    \"burn-fusion?/default\",\n    \"burn-cubecl-fusion?/default\",\n]\nstd = [\n    \"cubecl/std\",\n    \"burn-backend/std\",\n    \"burn-fusion?/std\",\n    \"burn-cubecl-fusion?/std\",\n]\ndoc = [\"default\"]\nmemory-checks = [\"burn-fusion?/memory-checks\"]\ntracing = [\n    \"dep:tracing\",\n    \"cubecl/tracing\",\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n    \"burn-fusion?/tracing\",\n    \"burn-cubecl-fusion?/tracing\",\n]\n\nautotune = [\"burn-cubecl-fusion?/autotune\"]\nautotune-checks = [\n    \"autotune\",\n    \"cubecl/autotune-checks\",\n    \"burn-cubecl-fusion?/autotune-checks\",\n]\n\nfusion = [\"burn-fusion\", \"burn-cubecl-fusion\"]\nfusion-experimental = [\"fusion\"]\n\ntemplate = []\n\n[dependencies]\nburn-cubecl-fusion = { path = \"../burn-cubecl-fusion\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"cubecl\",\n] }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"cubecl\",\n] }\ncubecl = { workspace = true, features = [\"stdlib\"] }\ncubek = { workspace = true, features = [\n    \"attention\",\n    \"matmul\",\n    \"convolution\",\n    \"reduce\",\n    \"random\",\n    \"quantization\",\n] }\ntracing = { workspace = true, features = [\"attributes\"], optional = true }\n\nderive-new = { workspace = true }\nlog = { workspace = true }\n\n# Async\nfutures-lite = { workspace = true, features = [\"std\"] }\n\n# Template\nserde = { workspace = true }\ntext_placeholder = { workspace = true, features = [\"struct_context\"] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-cubecl/README.md",
    "content": "# Burn CubeCL Backend\n\nGeneric backend that can be compiled just-in-time (JIT) to any shader language target.\n"
  },
  {
    "path": "crates/burn-cubecl/src/backend.rs",
    "content": "use crate::{CubeRuntime, FloatElement, IntElement, element::BoolElement, tensor::CubeTensor};\nuse burn_backend::{Backend, DTypeUsage, DTypeUsageSet, DeviceOps, ExecutionError, TensorData};\nuse burn_std::DType;\nuse cubecl::{\n    features::{MmaConfig, TypeUsage},\n    server::ComputeServer,\n};\nuse std::marker::PhantomData;\n\n#[cfg(not(feature = \"fusion\"))]\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};\n#[cfg(not(feature = \"fusion\"))]\nuse burn_ir::{BackendIr, TensorHandle};\n\n/// Generic tensor backend that can be compiled just-in-time to any shader runtime\n#[derive(new)]\npub struct CubeBackend<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> {\n    _runtime: PhantomData<R>,\n    _float_elem: PhantomData<F>,\n    _int_elem: PhantomData<I>,\n    _bool_elem: PhantomData<BT>,\n}\n\nimpl<R, F, I, BT> Backend for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    R::Server: ComputeServer,\n    R::Device: DeviceOps,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    type Device = R::Device;\n\n    type FloatElem = F;\n    type IntElem = I;\n    type BoolElem = BT;\n\n    type FloatTensorPrimitive = CubeTensor<R>;\n    type IntTensorPrimitive = CubeTensor<R>;\n    type BoolTensorPrimitive = CubeTensor<R>;\n    type QuantizedTensorPrimitive = CubeTensor<R>;\n\n    fn name(device: &Self::Device) -> String {\n        let client = R::client(device);\n        format!(\"cubecl<{}>\", R::name(&client))\n    }\n\n    fn seed(_device: &Self::Device, seed: u64) {\n        cubek::random::seed(seed);\n    }\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {\n        let client = R::client(device);\n        futures_lite::future::block_on(client.sync()).map_err(|err| ExecutionError::WithContext {\n            reason: format!(\"{err}\"),\n        })\n    }\n\n    fn memory_persistent_allocations<\n        Output: Send,\n        Input: Send,\n        Func: Fn(Input) -> Output + Send,\n    >(\n        device: &Self::Device,\n        input: Input,\n        func: Func,\n    ) -> Output {\n        let client = R::client(device);\n        client.memory_persistent_allocation(input, func).unwrap()\n    }\n\n    fn memory_cleanup(device: &Self::Device) {\n        let client = R::client(device);\n        client.memory_cleanup();\n    }\n\n    fn staging<'a, Iter>(data: Iter, device: &Self::Device)\n    where\n        Iter: Iterator<Item = &'a mut TensorData>,\n    {\n        let client = R::client(device);\n        client.staging(data.map(|td| &mut td.bytes), false);\n    }\n\n    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {\n        let client = R::client(device);\n\n        let type_usage = client.properties().type_usage(dtype.into());\n        // Same as `TypeUsage::all_scalar()`, but we make the usage explicit here\n        type_usage.is_superset(\n            TypeUsage::Buffer\n                | TypeUsage::Conversion\n                | TypeUsage::Arithmetic\n                | TypeUsage::DotProduct,\n        )\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> DTypeUsageSet {\n        let client = R::client(device);\n\n        let props = client.properties();\n        let storage = dtype.into();\n        let usage = props.type_usage(storage);\n\n        let mut out = DTypeUsageSet::new();\n\n        if usage.is_superset(TypeUsage::Buffer | TypeUsage::Conversion) {\n            out |= DTypeUsage::Storage;\n        }\n\n        if usage.contains(TypeUsage::Arithmetic) {\n            out |= DTypeUsage::Arithmetic;\n        }\n\n        let has_mma = |cfg: &MmaConfig| {\n            cfg.a_type == storage || cfg.b_type == storage || cfg.cd_type == storage\n        };\n        if props.features.cmma.iter().any(has_mma) || props.features.mma.iter().any(has_mma) {\n            out |= DTypeUsage::Accelerated;\n        }\n\n        out\n    }\n}\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> core::fmt::Debug\n    for CubeBackend<R, F, I, BT>\n{\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_str(\"CubeCLBackend\")\n    }\n}\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Clone\n    for CubeBackend<R, F, I, BT>\n{\n    fn clone(&self) -> Self {\n        Self::new()\n    }\n}\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Default\n    for CubeBackend<R, F, I, BT>\n{\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<R: cubecl::Runtime> CubeRuntime for R\nwhere\n    R::Device: DeviceOps,\n{\n    type CubeDevice = R::Device;\n    type CubeServer = R::Server;\n}\n\n#[cfg(not(feature = \"fusion\"))]\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr\n    for CubeBackend<R, F, I, BT>\n{\n    type Handle = CubeTensor<R>;\n\n    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {\n        handle.handle\n    }\n\n    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {\n        handle.handle\n    }\n\n    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {\n        handle.handle\n    }\n\n    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {\n        handle.handle\n    }\n\n    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {\n        tensor\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/element.rs",
    "content": "use burn_backend::{Element, bf16, f16};\nuse burn_std::DType;\nuse cubecl::{\n    CubeElement as CubeElem, flex32,\n    prelude::{Float, Int, Numeric},\n};\nuse cubek::{\n    matmul::definition::{MatmulPrecision, MatrixPrecision},\n    reduce::ReducePrecision,\n};\n\n/// The base element trait for the jit backend.\npub trait CubeElement: Element + CubeElem + PartialEq + Numeric {}\n\n/// Element that can be used for matrix multiplication. Includes ints and floats.\npub trait MatmulElement:\n    CubeElement + MatmulPrecision<Acc: MatrixPrecision<Global: CubeElement>>\n{\n}\n\n/// The float element type for the jit backend.\npub trait FloatElement: MatmulElement + Float {}\n\n/// The int element type for the jit backend.\npub trait IntElement:\n    MatmulElement + Int + ReducePrecision<EI: CubeElement, EA: CubeElement>\n{\n}\n\n/// The element type for booleans for the jit backend.\npub trait BoolElement: CubeElement + Int {\n    /// The true value for the boolean element.\n    fn true_val() -> Self {\n        Self::from_int(1)\n    }\n\n    /// The false value for the boolean element.\n    fn false_val() -> Self {\n        Self::from_int(0)\n    }\n\n    /// New bool element from Rust bool.\n    fn new_bool(val: bool) -> Self {\n        match val {\n            true => Self::true_val(),\n            false => Self::false_val(),\n        }\n    }\n}\n\nimpl CubeElement for u64 {}\nimpl CubeElement for u32 {}\nimpl CubeElement for u16 {}\nimpl CubeElement for u8 {}\nimpl CubeElement for i64 {}\nimpl CubeElement for i32 {}\nimpl CubeElement for i16 {}\nimpl CubeElement for i8 {}\nimpl CubeElement for f64 {}\nimpl CubeElement for f32 {}\nimpl CubeElement for flex32 {}\nimpl CubeElement for f16 {}\nimpl CubeElement for bf16 {}\n\nimpl FloatElement for f64 {}\nimpl FloatElement for f32 {}\nimpl FloatElement for flex32 {}\nimpl FloatElement for bf16 {}\nimpl FloatElement for f16 {}\nimpl IntElement for i64 {}\nimpl IntElement for i32 {}\nimpl IntElement for i16 {}\nimpl IntElement for i8 {}\nimpl IntElement for u64 {}\nimpl IntElement for u32 {}\nimpl IntElement for u16 {}\nimpl IntElement for u8 {}\n\nimpl BoolElement for u8 {}\nimpl BoolElement for u32 {}\n\nimpl MatmulElement for f64 {}\nimpl MatmulElement for f32 {}\nimpl MatmulElement for flex32 {}\nimpl MatmulElement for bf16 {}\nimpl MatmulElement for f16 {}\n\nimpl MatmulElement for i64 {}\nimpl MatmulElement for i32 {}\nimpl MatmulElement for i16 {}\nimpl MatmulElement for i8 {}\nimpl MatmulElement for u64 {}\nimpl MatmulElement for u32 {}\nimpl MatmulElement for u16 {}\nimpl MatmulElement for u8 {}\n\n// TODO: remove once backends no longer rely on generics for default elem types\n/// Returns the bool element dtype.\npub(crate) fn bool_dtype<BT: burn_backend::Element>() -> DType {\n    match BT::dtype() {\n        DType::U32 => DType::Bool(burn_backend::BoolStore::U32),\n        DType::U8 => DType::Bool(burn_backend::BoolStore::U8),\n        other => unimplemented!(\"Invalid bool dtye {other:?}\"),\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/fusion.rs",
    "content": "use crate::BoolElement;\nuse crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};\nuse burn_backend::{DType, Shape};\nuse burn_cubecl_fusion::optim::reduce::ReduceSettings;\nuse burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;\nuse burn_cubecl_fusion::{\n    CubeFusionHandle, FallbackOperation,\n    optim::{\n        CubeOptimization, CubeOptimizationState,\n        elemwise::{ElementWiseFuser, ElemwiseOptimization},\n        matmul::{MatmulFuser, MatmulOptimization},\n        reduce::{ReduceFuser, ReduceOptimization},\n        reduce_broadcasted::ReduceBroadcastedOptimization,\n    },\n};\nuse burn_fusion::{\n    FusionBackend, FusionRuntime,\n    stream::{Operation, OrderedExecution},\n};\nuse burn_ir::{BackendIr, TensorHandle};\nuse burn_std::Metadata;\nuse core::marker::PhantomData;\nuse std::sync::Arc;\n\nimpl<R> burn_fusion::Optimization<FusionCubeRuntime<R>> for CubeOptimization<R>\nwhere\n    R: CubeRuntime,\n{\n    fn execute(\n        &mut self,\n        context: &mut burn_fusion::stream::Context<\n            '_,\n            <FusionCubeRuntime<R> as FusionRuntime>::FusionHandle,\n        >,\n        execution: &OrderedExecution<FusionCubeRuntime<R>>,\n    ) {\n        match self {\n            Self::ElementWise(op) => op.execute(context),\n            Self::Matmul(op) => op.execute(context, |index| {\n                let operation = execution.operation_within_optimization(index);\n                Box::new(FallbackOperationWrapper::new(operation))\n            }),\n            Self::Reduce(op) => op.execute(context, |index| {\n                let operation = execution.operation_within_optimization(index);\n                Box::new(FallbackOperationWrapper::new(operation))\n            }),\n            Self::ReduceBroadcasted(op) => op.execute(context, |index| {\n                let operation = execution.operation_within_optimization(index);\n                Box::new(FallbackOperationWrapper::new(operation))\n            }),\n        }\n    }\n\n    fn to_state(&self) -> CubeOptimizationState {\n        self.to_opt_state()\n    }\n\n    fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {\n        match state {\n            CubeOptimizationState::ElementWise(state) => {\n                Self::ElementWise(ElemwiseOptimization::from_state(device, state))\n            }\n            CubeOptimizationState::Matmul(state) => {\n                Self::Matmul(MatmulOptimization::from_state(device, state))\n            }\n            CubeOptimizationState::Reduce(state) => {\n                Self::Reduce(ReduceOptimization::from_state(device, state))\n            }\n            CubeOptimizationState::ReduceBroadcasted(state) => {\n                Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))\n            }\n        }\n    }\n}\n\nstruct FallbackOperationWrapper<O: Clone> {\n    operation: O,\n}\n\nimpl<O: Clone> FallbackOperationWrapper<O> {\n    fn new(op: O) -> Self {\n        Self { operation: op }\n    }\n}\n\nimpl<R: CubeRuntime> FallbackOperation<R>\n    for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R>>>>\n{\n    fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {\n        self.operation.as_ref().execute(context.handles);\n    }\n}\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr\n    for CubeBackend<R, F, I, BT>\n{\n    type Handle = CubeFusionHandle<R>;\n\n    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {\n        into_tensor(handle.handle, handle.shape)\n    }\n\n    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {\n        into_tensor(handle.handle, handle.shape)\n    }\n\n    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {\n        into_tensor(handle.handle, handle.shape)\n    }\n\n    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {\n        into_tensor(handle.handle, handle.shape)\n    }\n\n    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {\n        tensor.into()\n    }\n\n    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {\n        tensor.into()\n    }\n\n    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {\n        tensor.into()\n    }\n\n    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {\n        tensor.into()\n    }\n}\n\nimpl<R: CubeRuntime> FusionRuntime for FusionCubeRuntime<R> {\n    type OptimizationState = CubeOptimizationState;\n    type Optimization = CubeOptimization<R>;\n    type FusionHandle = CubeFusionHandle<R>;\n    type FusionDevice = R::CubeDevice;\n\n    fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {\n        vec![\n            Box::new(ElementWiseFuser::new(device.clone())),\n            Box::new(MatmulFuser::new(device.clone())),\n            Box::new(ReduceFuser::new(device.clone(), ReduceSettings::Always)),\n            Box::new(ReduceBroadcastedFuser::new(device.clone())),\n        ]\n    }\n}\n\n/// Fusion runtime for JIT runtimes.\n#[derive(Debug)]\npub struct FusionCubeRuntime<R: CubeRuntime> {\n    _b: PhantomData<R>,\n}\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend\n    for CubeBackend<R, F, I, BT>\n{\n    type FusionRuntime = FusionCubeRuntime<R>;\n\n    type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;\n\n    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {\n        kernel::cast(tensor, dtype).into()\n    }\n}\n\nfn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {\n    CubeTensor {\n        client: handle.client.clone(),\n        handle: handle.handle.clone(),\n        device: handle.device.clone(),\n        meta: Box::new(Metadata::new(shape, handle.strides.clone())),\n        dtype: handle.dtype,\n        qparams: handle.qparams.clone(),\n    }\n}\n\nimpl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {\n    fn from(value: CubeTensor<R>) -> Self {\n        Self {\n            client: value.client.clone(),\n            handle: value.handle.clone(),\n            device: value.device.clone(),\n            strides: value.meta.strides.clone(),\n            dtype: value.dtype,\n            qparams: value.qparams.clone(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/attention/base.rs",
    "content": "use crate::{\n    CubeBackend, CubeRuntime, kernel::attention::attention_autotune,\n    ops::numeric::empty_device_dtype, tensor::CubeTensor,\n};\nuse burn_backend::{\n    DType, Shape,\n    ops::{AttentionModuleOptions, attention::attention_fallback},\n};\nuse cubek::attention::launch;\nuse cubek::attention::{\n    definition::{\n        AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,\n    },\n    routines::blackbox_accelerated::BlackboxAcceleratedStrategy,\n};\n\n#[derive(Debug)]\n/// Strategy used to select which attention implementation to run.\npub enum AttentionStrategy {\n    /// Flash Attention using accelerated inner matmuls.\n    FlashBlackboxAccelerated(BlackboxAcceleratedStrategy),\n\n    /// Flash Attention using unit inner matmuls.\n    FlashUnit,\n\n    /// Fallback implementation using multiple separate kernels.\n    Fallback,\n\n    /// Automatically benchmark and select the best strategy at runtime.\n    #[cfg(feature = \"autotune\")]\n    Autotune,\n}\n\nimpl Default for AttentionStrategy {\n    fn default() -> Self {\n        // if autotune is enabled, default to autotune\n        #[cfg(feature = \"autotune\")]\n        return AttentionStrategy::Autotune;\n\n        // if autotune is disabled, default to fallback to make sure it runs\n        #[cfg(not(feature = \"autotune\"))]\n        AttentionStrategy::Fallback\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\n/// Launch an attention kernel with given strategy\npub fn attention<R: CubeRuntime>(\n    query: CubeTensor<R>,\n    key: CubeTensor<R>,\n    value: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    attn_bias: Option<CubeTensor<R>>,\n    options: AttentionModuleOptions,\n    strategy: AttentionStrategy,\n    out: Option<CubeTensor<R>>,\n) -> Result<CubeTensor<R>, AttentionSetupError> {\n    let mut out = out.unwrap_or_else(|| init_attention_output(&query, &value));\n    match strategy {\n        AttentionStrategy::FlashBlackboxAccelerated(strategy) => flash_attention(\n            query,\n            key,\n            value,\n            mask,\n            attn_bias,\n            options,\n            out,\n            launch::Strategy::BlackboxAccelerated(\n                cubek::attention::launch::BlueprintStrategy::Inferred(strategy),\n            ),\n        ),\n        AttentionStrategy::FlashUnit => flash_attention(\n            query,\n            key,\n            value,\n            mask,\n            attn_bias,\n            options,\n            out,\n            launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),\n        ),\n        AttentionStrategy::Fallback => {\n            out = attention_fallback::<CubeBackend<R, f32, i32, u8>>(\n                query, key, value, mask, attn_bias, options,\n            );\n            Ok(out)\n        }\n        #[cfg(feature = \"autotune\")]\n        AttentionStrategy::Autotune => {\n            attention_autotune(query, key, value, mask, attn_bias, options, out)\n        }\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\n/// Launch a flash attention kernel\npub fn flash_attention<R: CubeRuntime>(\n    query: CubeTensor<R>,\n    key: CubeTensor<R>,\n    value: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    _attn_bias: Option<CubeTensor<R>>,\n    options: AttentionModuleOptions,\n    out: CubeTensor<R>,\n    strategy: launch::Strategy,\n) -> Result<CubeTensor<R>, AttentionSetupError> {\n    let client = query.client.clone();\n\n    let dtypes = AttentionGlobalTypes {\n        query: query.dtype.into(),\n        key: key.dtype.into(),\n        value: value.dtype.into(),\n        mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(),\n        out: out.dtype.into(),\n    };\n\n    cubek::attention::launch::launch_ref::<R>(\n        strategy,\n        &client,\n        query.binding(),\n        key.binding(),\n        value.binding(),\n        mask.map(|mask| mask.binding()),\n        out.clone().binding(),\n        &dtypes,\n        AttentionOptions {\n            causal: options.is_causal,\n            accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(\n                cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),\n            )),\n        },\n    )?;\n\n    Ok(out)\n}\n\npub(crate) fn init_attention_output<R: CubeRuntime>(\n    query: &CubeTensor<R>,\n    value: &CubeTensor<R>,\n) -> CubeTensor<R> {\n    let num_batches = query.meta.shape[0];\n    let num_heads = query.meta.shape[1];\n    let seq_q = query.meta.shape[2];\n    let val_dim = value.meta.shape[3];\n    let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);\n\n    empty_device_dtype::<R>(\n        query.client.clone(),\n        query.device.clone(),\n        out_shape,\n        query.dtype,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/attention/mod.rs",
    "content": "mod base;\nmod tune;\n\npub use base::*;\npub use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/attention/tune.rs",
    "content": "use crate::{\n    CubeRuntime, CubeTuneId,\n    kernel::attention::{AttentionStrategy, attention},\n    tensor::CubeTensor,\n};\nuse burn_backend::ops::AttentionModuleOptions;\nuse cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};\nuse cubek::attention::{\n    definition::AttentionSetupError, launch::AttentionAutotuneKey,\n    routines::blackbox_accelerated::BlackboxAcceleratedStrategy,\n};\n\n/// Executes autotune on attention operations\npub fn attention_autotune<R: CubeRuntime>(\n    query: CubeTensor<R>,\n    key: CubeTensor<R>,\n    value: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    attn_bias: Option<CubeTensor<R>>,\n    options: AttentionModuleOptions,\n    out: CubeTensor<R>,\n) -> Result<CubeTensor<R>, AttentionSetupError> {\n    let client = query.client.clone();\n\n    static TUNER: LocalTuner<AttentionAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 3;\n        const PRIORITY_MIN: i8 = 0;\n\n        let flash_attention =\n            TuneGroup::<AttentionAutotuneKey>::new(\"flash_attention\", |_key| PRIORITY_MAX);\n\n        let fallback = TuneGroup::<AttentionAutotuneKey>::new(\"fallback\", |key| {\n            if key.seq_q > 4096 {\n                PRIORITY_MIN\n            } else {\n                PRIORITY_MAX\n            }\n        });\n\n        let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);\n\n        // First entry should always work, since it is considered the fallback.\n        set = set.with(\n            Tunable::new(\n                \"fallback\",\n                |query, key, value, mask, attn_bias, out, options| {\n                    attention::<R>(\n                        query,\n                        key,\n                        value,\n                        mask,\n                        attn_bias,\n                        options,\n                        AttentionStrategy::Fallback,\n                        Some(out),\n                    )\n                    .map_err(|err| std::format!(\"{err:?}\"))\n                },\n            )\n            .group(&fallback, |_key| PRIORITY_MAX),\n        );\n\n        let seq_q = 1;\n        let seq_kv = 1;\n        for num_planes in [2, 4, 8] {\n            let name = format!(\"blackbox_accelerated_{num_planes}_planes_p_{seq_q}-{seq_kv}\");\n            set = set.with(\n                Tunable::new(\n                    &name,\n                    move |query, key, value, mask, attn_bias, out, options| {\n                        attention::<R>(\n                            query,\n                            key,\n                            value,\n                            mask,\n                            attn_bias,\n                            options,\n                            AttentionStrategy::FlashBlackboxAccelerated(\n                                BlackboxAcceleratedStrategy {\n                                    num_planes,\n                                    seq_q,\n                                    seq_kv,\n                                },\n                            ),\n                            Some(out),\n                        )\n                        .map_err(|err| std::format!(\"{err:?}\"))\n                    },\n                )\n                .group(&flash_attention, |_key| PRIORITY_MAX),\n            );\n        }\n\n        set = set.with(\n            Tunable::new(\n                \"unit\",\n                |query, key, value, mask, attn_bias, out, options| {\n                    attention::<R>(\n                        query,\n                        key,\n                        value,\n                        mask,\n                        attn_bias,\n                        options,\n                        AttentionStrategy::FlashUnit,\n                        Some(out),\n                    )\n                    .map_err(|err| std::format!(\"{err:?}\"))\n                },\n            )\n            .group(&flash_attention, |_key| PRIORITY_MIN),\n        );\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&client, &query.device),\n        &client,\n        tunables,\n        (query, key, value, mask, attn_bias, out.clone(), options),\n    );\n\n    Ok(out)\n}\n\nfn create_key<R: CubeRuntime>(\n    query: &CubeTensor<R>,\n    key: &CubeTensor<R>,\n    value: &CubeTensor<R>,\n    mask: &Option<CubeTensor<R>>,\n    _attn_bias: &Option<CubeTensor<R>>,\n    out: &CubeTensor<R>,\n    _options: &AttentionModuleOptions,\n) -> AttentionAutotuneKey {\n    let total_batches = query.meta.shape[0] * query.meta.shape[1];\n    let seq_q = query.meta.shape[2];\n    let head_dim = query.meta.shape[3];\n    let seq_kv = value.meta.shape[2];\n    let val_dim = value.meta.shape[3];\n\n    AttentionAutotuneKey::generate(\n        query.dtype.into(),\n        key.dtype.into(),\n        value.dtype.into(),\n        out.dtype.into(),\n        total_batches,\n        seq_q,\n        head_dim,\n        seq_kv,\n        val_dim,\n        mask.is_some(),\n    )\n}\n\n#[allow(clippy::type_complexity)]\n#[allow(clippy::too_many_arguments)]\nfn input_gen<R: CubeRuntime>(\n    _key: &AttentionAutotuneKey,\n    query: &CubeTensor<R>,\n    key: &CubeTensor<R>,\n    value: &CubeTensor<R>,\n    mask: &Option<CubeTensor<R>>,\n    attn_bias: &Option<CubeTensor<R>>,\n    out: &CubeTensor<R>,\n    options: &AttentionModuleOptions,\n) -> (\n    CubeTensor<R>,\n    CubeTensor<R>,\n    CubeTensor<R>,\n    Option<CubeTensor<R>>,\n    Option<CubeTensor<R>>,\n    CubeTensor<R>,\n    AttentionModuleOptions,\n) {\n    (\n        query.clone(),\n        key.clone(),\n        value.clone(),\n        mask.clone(),\n        attn_bias.clone(),\n        out.copy(),\n        *options,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/binary.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::{TensorMetadata, bf16, f16};\nuse cubecl::{\n    calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView,\n};\n\npub(crate) trait BinaryOpFamily: Send + Sync + 'static {\n    type BinaryOp<C: Numeric, N: Size>: BinaryOp<C, N>;\n}\n\n#[cube]\npub(crate) trait BinaryOp<C: Numeric, N: Size>: 'static + Send + Sync {\n    /// Execute a binary operation.\n    fn execute(lhs: Vector<C, N>, rhs: Vector<C, N>) -> Vector<C, N>;\n}\n\npub(crate) struct AddOp;\npub(crate) struct SubOp;\npub(crate) struct MulOp;\npub(crate) struct DivOp;\npub(crate) struct RemainderOp;\npub(crate) struct AndOp;\npub(crate) struct OrOp;\npub(crate) struct PowOp;\n\nimpl BinaryOpFamily for AddOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for SubOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for MulOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for DivOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for RemainderOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for PowOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for AndOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\nimpl BinaryOpFamily for OrOp {\n    type BinaryOp<C: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for AddOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs + rhs\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for SubOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs - rhs\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for MulOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs * rhs\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for DivOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs / rhs\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for RemainderOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        Vector::rem(lhs, rhs)\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for PowOp {\n    #[allow(unused)]\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        intrinsic!(|scope| {\n            let elem = T::as_type(scope).elem_type();\n\n            if let cubecl::ir::ElemType::Float(kind) = elem {\n                match kind {\n                    cubecl::ir::FloatKind::F16 => {\n                        let lhs = <Vector<f16, N> as Cast>::__expand_cast_from(scope, lhs);\n                        let rhs = <Vector<f16, N> as Cast>::__expand_cast_from(scope, rhs);\n                        let out = Vector::__expand_powf(scope, lhs, rhs);\n                        return <Vector<T, N> as Cast>::__expand_cast_from(scope, out);\n                    }\n                    cubecl::ir::FloatKind::BF16 => {\n                        let lhs = <Vector<bf16, N> as Cast>::__expand_cast_from(scope, lhs);\n                        let rhs = <Vector<bf16, N> as Cast>::__expand_cast_from(scope, rhs);\n                        let out = Vector::__expand_powf(scope, lhs, rhs);\n                        return <Vector<T, N> as Cast>::__expand_cast_from(scope, out);\n                    }\n                    cubecl::ir::FloatKind::F64 => {\n                        let lhs = <Vector<f64, N> as Cast>::__expand_cast_from(scope, lhs);\n                        let rhs = <Vector<f64, N> as Cast>::__expand_cast_from(scope, rhs);\n                        let out = Vector::__expand_powf(scope, lhs, rhs);\n                        return <Vector<T, N> as Cast>::__expand_cast_from(scope, out);\n                    }\n                    _ => {}\n                }\n            };\n\n            let lhs = <Vector<f32, N> as Cast>::__expand_cast_from(scope, lhs);\n            let rhs = <Vector<f32, N> as Cast>::__expand_cast_from(scope, rhs);\n            let out = Vector::__expand_powf(scope, lhs, rhs);\n            return <Vector<T, N> as Cast>::__expand_cast_from(scope, out);\n        })\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for AndOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        Vector::cast_from(Vector::<bool, N>::cast_from(lhs).and(Vector::<bool, N>::cast_from(rhs)))\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> BinaryOp<T, N> for OrOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        Vector::cast_from(Vector::<bool, N>::cast_from(lhs).or(Vector::<bool, N>::cast_from(rhs)))\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_scalar_binop<C: Numeric, N: Size, O: BinaryOpFamily>(\n    input: &LinearView<Vector<C, N>>,\n    scalar: InputScalar,\n    output: &mut LinearView<Vector<C, N>, ReadWrite>,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] =\n        O::BinaryOp::<C, N>::execute(input[ABSOLUTE_POS], Vector::new(scalar.get::<C>()));\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_binop<C: Numeric, N: Size, O: BinaryOpFamily>(\n    lhs: &LinearView<Vector<C, N>>,\n    rhs: &LinearView<Vector<C, N>>,\n    out: &mut LinearView<Vector<C, N>, ReadWrite>,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !out.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    out[ABSOLUTE_POS] = O::BinaryOp::<C, N>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);\n}\n\npub(crate) fn launch_binop<R: CubeRuntime, O: BinaryOpFamily>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let vector_size_lhs = max_vector_size(&lhs);\n    let vector_size_rhs = max_vector_size(&rhs);\n    let vector_size = Ord::min(vector_size_lhs, vector_size_rhs);\n\n    let shape_out = broadcast_shape(&[&lhs, &rhs]);\n    let dtype = lhs.dtype;\n\n    let client = lhs.client.clone();\n    let num_elems = shape_out.num_elements();\n    let working_units = num_elems / vector_size as usize;\n\n    let cube_dim = CubeDim::new(&lhs.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);\n\n    unsafe {\n        if lhs.can_mut_broadcast(&rhs) {\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.clone().into_linear_view(),\n                rhs.into_linear_view_like(&lhs),\n                lhs.as_linear_view_alias(0),\n                dtype.into(),\n            );\n\n            lhs\n        } else if rhs.can_mut_broadcast(&lhs) {\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.into_linear_view_like(&rhs),\n                rhs.clone().into_linear_view(),\n                rhs.as_linear_view_alias(1),\n                dtype.into(),\n            );\n\n            rhs\n        } else {\n            let output =\n                empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);\n\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs, output),\n                vector_size,\n                lhs.into_linear_view_like(&output),\n                rhs.into_linear_view_like(&output),\n                output.clone().into_linear_view(),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n\npub(crate) fn launch_scalar_binop<R: CubeRuntime, O: BinaryOpFamily>(\n    tensor: CubeTensor<R>,\n    scalar: InputScalar,\n) -> CubeTensor<R> {\n    // Vectorization is only enabled when the last dimension is contiguous.\n    let vector_size = max_vector_size(&tensor);\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n    let dtype = tensor.dtype;\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    unsafe {\n        if tensor.can_mut() && tensor.is_nonoverlapping() {\n            kernel_scalar_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                scalar,\n                tensor.as_linear_view_alias(0),\n                dtype.into(),\n            );\n\n            tensor\n        } else {\n            let output = empty_device_dtype(\n                tensor.client.clone(),\n                tensor.device.clone(),\n                tensor.shape(),\n                dtype,\n            );\n\n            kernel_scalar_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                scalar,\n                output.clone().into_linear_view(),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/binary_float.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\npub(crate) trait BinaryOpFloatFamily: Send + Sync + 'static {\n    type BinaryOp<C: Float, N: Size>: BinaryOpFloat<C, N>;\n}\n\n#[cube]\npub(crate) trait BinaryOpFloat<C: Float, N: Size>: 'static + Send + Sync {\n    /// Execute a binary operation.\n    fn execute(lhs: Vector<C, N>, rhs: Vector<C, N>) -> Vector<C, N>;\n}\n\npub(crate) struct ArcTan2Op;\n\nimpl BinaryOpFloatFamily for ArcTan2Op {\n    type BinaryOp<C: Float, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Float, N: Size> BinaryOpFloat<T, N> for ArcTan2Op {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        Vector::atan2(lhs, rhs)\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_binop<C: Float, N: Size, O: BinaryOpFloatFamily>(\n    lhs: &LinearView<Vector<C, N>>,\n    rhs: &LinearView<Vector<C, N>>,\n    out: &mut LinearView<Vector<C, N>, ReadWrite>,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !out.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    out[ABSOLUTE_POS] = O::BinaryOp::<C, N>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);\n}\n\npub(crate) fn launch_binop_float<R: CubeRuntime, O: BinaryOpFloatFamily>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let vector_size_lhs = max_vector_size(&lhs);\n    let vector_size_rhs = max_vector_size(&rhs);\n    let vector_size = Ord::min(vector_size_lhs, vector_size_rhs);\n\n    let shape_out = broadcast_shape(&[&lhs, &rhs]);\n    let dtype = lhs.dtype;\n\n    let client = lhs.client.clone();\n    let num_elems = shape_out.num_elements();\n    let working_units = num_elems / vector_size as usize;\n\n    let cube_dim = CubeDim::new(&lhs.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);\n\n    unsafe {\n        if lhs.can_mut_broadcast(&rhs) {\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.clone().into_linear_view(),\n                rhs.clone().into_linear_view_like(&lhs),\n                lhs.as_linear_view_alias(0),\n                dtype.into(),\n            );\n\n            lhs\n        } else if rhs.can_mut_broadcast(&lhs) {\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.into_linear_view_like(&rhs),\n                rhs.clone().into_linear_view(),\n                rhs.as_linear_view_alias(1),\n                dtype.into(),\n            );\n\n            rhs\n        } else {\n            let output =\n                empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, dtype);\n\n            kernel_binop::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs, output),\n                vector_size,\n                lhs.into_linear_view_like(&output),\n                rhs.into_linear_view_like(&output),\n                output.clone().into_linear_view(),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n\n/// Calculate the four-quadrant inverse tangent of `lhs / rhs`.\npub fn atan2<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop_float::<R, ArcTan2Op>(lhs, rhs)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/binary_int.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\npub(crate) trait BinaryOpIntFamily: Send + Sync + 'static {\n    type BinaryOp<C: Int, N: Size>: BinaryOpInt<C, N>;\n}\n\n#[cube]\npub(crate) trait BinaryOpInt<C: Int, N: Size>: 'static + Send + Sync {\n    /// Execute a binary operation.\n    fn execute(lhs: Vector<C, N>, rhs: Vector<C, N>) -> Vector<C, N>;\n}\n\npub(crate) struct BitwiseAndOp;\npub(crate) struct BitwiseOrOp;\npub(crate) struct BitwiseXorOp;\npub(crate) struct BitwiseShrOp;\npub(crate) struct BitwiseShlOp;\n\nimpl BinaryOpIntFamily for BitwiseAndOp {\n    type BinaryOp<C: Int, N: Size> = Self;\n}\n\nimpl BinaryOpIntFamily for BitwiseOrOp {\n    type BinaryOp<C: Int, N: Size> = Self;\n}\n\nimpl BinaryOpIntFamily for BitwiseXorOp {\n    type BinaryOp<C: Int, N: Size> = Self;\n}\n\nimpl BinaryOpIntFamily for BitwiseShrOp {\n    type BinaryOp<C: Int, N: Size> = Self;\n}\n\nimpl BinaryOpIntFamily for BitwiseShlOp {\n    type BinaryOp<C: Int, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Int, N: Size> BinaryOpInt<T, N> for BitwiseAndOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs & rhs\n    }\n}\n\n#[cube]\nimpl<T: Int, N: Size> BinaryOpInt<T, N> for BitwiseOrOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs | rhs\n    }\n}\n\n#[cube]\nimpl<T: Int, N: Size> BinaryOpInt<T, N> for BitwiseXorOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs ^ rhs\n    }\n}\n\n#[cube]\nimpl<T: Int, N: Size> BinaryOpInt<T, N> for BitwiseShrOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs >> rhs\n    }\n}\n\n#[cube]\nimpl<T: Int, N: Size> BinaryOpInt<T, N> for BitwiseShlOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> Vector<T, N> {\n        lhs << rhs\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_scalar_binop_int<C: Int, N: Size, O: BinaryOpIntFamily>(\n    input: &LinearView<Vector<C, N>>,\n    scalar: InputScalar,\n    output: &mut LinearView<Vector<C, N>, ReadWrite>,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] =\n        O::BinaryOp::<C, N>::execute(input[ABSOLUTE_POS], Vector::new(scalar.get::<C>()));\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_binop_int<C: Int, N: Size, O: BinaryOpIntFamily>(\n    lhs: &LinearView<Vector<C, N>>,\n    rhs: &LinearView<Vector<C, N>>,\n    out: &mut LinearView<Vector<C, N>, ReadWrite>,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !out.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    out[ABSOLUTE_POS] = O::BinaryOp::<C, N>::execute(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);\n}\n\npub(crate) fn launch_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let vector_size_lhs = max_vector_size(&lhs);\n    let vector_size_rhs = max_vector_size(&rhs);\n    let vector_size = Ord::min(vector_size_lhs, vector_size_rhs);\n\n    let shape_out = broadcast_shape(&[&lhs, &rhs]);\n\n    let client = lhs.client.clone();\n    let num_elems = shape_out.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&lhs.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);\n    let dtype = lhs.dtype;\n\n    unsafe {\n        if lhs.can_mut_broadcast(&rhs) {\n            kernel_binop_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.clone().into_linear_view(),\n                rhs.into_linear_view_like(&lhs),\n                lhs.as_linear_view_alias(0),\n                dtype.into(),\n            );\n\n            lhs\n        } else if rhs.can_mut_broadcast(&lhs) {\n            kernel_binop_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.into_linear_view_like(&rhs),\n                rhs.clone().into_linear_view(),\n                rhs.as_linear_view_alias(1),\n                dtype.into(),\n            );\n\n            rhs\n        } else {\n            let output =\n                empty_device_dtype(lhs.client.clone(), lhs.device.clone(), shape_out, lhs.dtype);\n\n            kernel_binop_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs, output),\n                vector_size,\n                lhs.into_linear_view_like(&output),\n                rhs.into_linear_view_like(&output),\n                output.clone().into_linear_view(),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n\npub(crate) fn launch_scalar_binop_int<R: CubeRuntime, O: BinaryOpIntFamily>(\n    tensor: CubeTensor<R>,\n    scalar: InputScalar,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&tensor);\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.shape.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    unsafe {\n        if tensor.can_mut() && tensor.is_nonoverlapping() {\n            kernel_scalar_binop_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                scalar,\n                tensor.as_linear_view_alias(0),\n                tensor.dtype.into(),\n            );\n\n            tensor\n        } else {\n            let output = empty_device_dtype(\n                tensor.client.clone(),\n                tensor.device.clone(),\n                tensor.shape(),\n                tensor.dtype,\n            );\n\n            kernel_scalar_binop_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                scalar,\n                output.clone().into_linear_view(),\n                output.dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/cast/base.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, TensorMetadata};\nuse cubecl::std::tensor::layout::linear::LinearView;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\n#[cube(launch, address_type = \"dynamic\")]\npub(crate) fn cast_element<I: Numeric, O: Numeric, N: Size>(\n    input: &LinearView<Vector<I, N>>,\n    output: &mut LinearView<Vector<O, N>, ReadWrite>,\n    #[define(I, O)] _dtypes: [StorageType; 2],\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS]);\n}\n\n/// Cast a tensor to the given element type.\n///\n/// Note: When input element is semantically a boolean, prefer bool_cast function.\npub fn cast<R: CubeRuntime>(input: CubeTensor<R>, dtype: DType) -> CubeTensor<R> {\n    let dtype_output = match dtype {\n        DType::Flex32 => DType::F32,\n        _ => dtype,\n    };\n    let dtype_input = match input.dtype {\n        DType::Flex32 => DType::F32,\n        _ => input.dtype,\n    };\n\n    if dtype_input == dtype_output {\n        return input;\n    }\n\n    let client = input.client.clone();\n\n    let vector_size = max_vector_size(&input);\n\n    let num_elems: usize = input.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);\n\n    let output = empty_device_dtype(\n        client.clone(),\n        input.device.clone(),\n        input.shape(),\n        dtype, // We take the same dtype as passed as input (Flex32 not F32)\n    );\n\n    cast_element::launch(\n        &client,\n        cube_count,\n        cube_dim,\n        address_type!(input, output),\n        vector_size,\n        input.into_linear_view(),\n        output.clone().into_linear_view(),\n        [dtype_input.into(), dtype_output.into()],\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/cast/bool_cast.rs",
    "content": "use crate::{\n    CubeElement, CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size, numeric::empty_device},\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::{\n    CubeDim, calculate_cube_count_elemwise, num_traits::One, prelude::*,\n    std::tensor::layout::linear::LinearView,\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn bool_cast_kernel<B: Int, T: Numeric, N: Size>(\n    input: &LinearView<Vector<B, N>>,\n    output: &mut LinearView<Vector<T, N>, ReadWrite>,\n    #[define(B)] _input_ty: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = Vector::cast_from(input[ABSOLUTE_POS] & Vector::one());\n}\n\n/// Cast a bool tensor to the given element type.\n///\n/// This alternative to cast is necessary because bool are represented as u32 or u8\n/// where any non-zero value means true. Depending how it was created\n/// it may hold an uncanny bit combination. Naively casting it would not\n/// necessarily yield 0 or 1.\npub fn bool_cast<R: CubeRuntime, EO: CubeElement>(tensor: CubeTensor<R>) -> CubeTensor<R> {\n    let output =\n        empty_device::<R, EO>(tensor.client.clone(), tensor.device.clone(), tensor.shape());\n\n    let vector_size = max_vector_size(&tensor);\n    let num_elems = tensor.meta.num_elements();\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    let dtype = tensor.dtype;\n\n    unsafe {\n        bool_cast_kernel::launch_unchecked::<EO, R>(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, output),\n            vector_size,\n            tensor.into_linear_view(),\n            output.clone().into_linear_view(),\n            dtype.into(),\n        )\n    };\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/cast/mod.rs",
    "content": "mod base;\nmod bool_cast;\n\npub use base::*;\npub use bool_cast::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/clamp.rs",
    "content": "use cubecl::prelude::*;\n\nuse crate::{\n    CubeRuntime,\n    kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric},\n    tensor::CubeTensor,\n};\n\n#[derive(CubeLaunch, CubeType)]\nstruct Options {\n    min_value: InputScalar,\n    max_value: InputScalar,\n}\n\npub(crate) fn clamp<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    min_value: InputScalar,\n    max_value: InputScalar,\n) -> CubeTensor<R> {\n    struct ClampOp;\n\n    #[cube]\n    impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for ClampOp {\n        type Options = Options;\n\n        fn execute(input: Vector<T, N>, options: &Self::Options) -> Vector<T, N> {\n            cubecl::prelude::clamp(\n                input,\n                Vector::new(options.min_value.get::<T>()),\n                Vector::new(options.max_value.get::<T>()),\n            )\n        }\n    }\n\n    impl NumericUnaryOpFamily for ClampOp {\n        type Options = Options;\n        type Unary<T: Numeric, N: Size> = Self;\n    }\n\n    launch_unary_numeric::<R, ClampOp, _>(input, |_| OptionsLaunch::new(min_value, max_value))\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/comparison.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, TensorMetadata};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\n#[cube]\npub(crate) trait ComparisonOpFamily: 'static + Send + Sync {\n    type Operation<T: Numeric, N: Size>: ComparisonOp<T, N>;\n}\n\n#[cube]\npub(crate) trait ComparisonOp<C: Numeric, N: Size>: 'static + Send + Sync {\n    /// Execute a comparison operation.\n    fn execute(lhs: Vector<C, N>, rhs: Vector<C, N>) -> bool;\n}\n\nstruct EqualOp;\nstruct GreaterEqualOp;\nstruct LowerEqualOp;\nstruct GreaterOp;\nstruct LowerOp;\n\nimpl ComparisonOpFamily for EqualOp {\n    type Operation<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> ComparisonOp<T, N> for EqualOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> bool {\n        lhs == rhs\n    }\n}\n\nimpl ComparisonOpFamily for GreaterEqualOp {\n    type Operation<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> ComparisonOp<T, N> for GreaterEqualOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> bool {\n        lhs >= rhs\n    }\n}\n\nimpl ComparisonOpFamily for LowerEqualOp {\n    type Operation<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> ComparisonOp<T, N> for LowerEqualOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> bool {\n        lhs <= rhs\n    }\n}\n\nimpl ComparisonOpFamily for GreaterOp {\n    type Operation<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> ComparisonOp<T, N> for GreaterOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> bool {\n        lhs > rhs\n    }\n}\n\nimpl ComparisonOpFamily for LowerOp {\n    type Operation<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> ComparisonOp<T, N> for LowerOp {\n    fn execute(lhs: Vector<T, N>, rhs: Vector<T, N>) -> bool {\n        lhs < rhs\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_scalar_cmp<T: Numeric, Bool: Numeric, N: Size, O: ComparisonOpFamily>(\n    input: &LinearView<Vector<T, N>>,\n    scalar: InputScalar,\n    output: &mut LinearView<Vector<Bool, N>, ReadWrite>,\n    #[define(T, Bool)] _dtypes: [StorageType; 2],\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = Vector::cast_from(O::Operation::<T, N>::execute(\n        input[ABSOLUTE_POS],\n        Vector::new(scalar.get::<T>()),\n    ));\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_cmp<T: Numeric, Bool: Numeric, N: Size, O: ComparisonOpFamily>(\n    lhs: &LinearView<Vector<T, N>>,\n    rhs: &LinearView<Vector<T, N>>,\n    out: &mut LinearView<Vector<Bool, N>, ReadWrite>,\n    #[define(T, Bool)] _dtype: [StorageType; 2],\n) {\n    if !out.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    out[ABSOLUTE_POS] = Vector::cast_from(O::Operation::<T, N>::execute(\n        lhs[ABSOLUTE_POS],\n        rhs[ABSOLUTE_POS],\n    ));\n}\n\npub(crate) fn launch_cmp<R: CubeRuntime, O: ComparisonOpFamily>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let vector_size_lhs = max_vector_size(&lhs);\n    let vector_size_rhs = max_vector_size(&rhs);\n\n    let vector_size = Ord::min(vector_size_lhs, vector_size_rhs);\n\n    let shape_out = broadcast_shape(&[&lhs, &rhs]);\n    let client = lhs.client.clone();\n    let num_elems = shape_out.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&lhs.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&lhs.client, working_units, cube_dim);\n\n    let dtypes = [lhs.dtype.into(), dtype_bool.into()];\n    let same_tensor_type = dtypes[0] == dtypes[1];\n    if same_tensor_type && lhs.can_mut_broadcast(&rhs) {\n        unsafe {\n            kernel_cmp::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.clone().into_linear_view(),\n                rhs.into_linear_view_like(&lhs),\n                lhs.as_linear_view_alias(0),\n                dtypes,\n            );\n        }\n\n        CubeTensor::new(\n            lhs.client.clone(),\n            lhs.handle.clone(),\n            *lhs.meta.clone(),\n            lhs.device.clone(),\n            dtype_bool,\n        )\n    } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {\n        unsafe {\n            kernel_cmp::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs),\n                vector_size,\n                lhs.into_linear_view_like(&rhs),\n                rhs.clone().into_linear_view(),\n                rhs.as_linear_view_alias(1),\n                dtypes,\n            );\n        };\n\n        CubeTensor::new(\n            rhs.client.clone(),\n            rhs.handle.clone(),\n            *rhs.meta.clone(),\n            rhs.device.clone(),\n            dtype_bool,\n        )\n    } else {\n        let output = empty_device_dtype(\n            lhs.client.clone(),\n            lhs.device.clone(),\n            shape_out,\n            dtype_bool,\n        );\n\n        unsafe {\n            kernel_cmp::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(lhs, rhs, output),\n                vector_size,\n                lhs.into_linear_view_like(&output),\n                rhs.into_linear_view_like(&output),\n                output.clone().into_linear_view(),\n                dtypes,\n            );\n        };\n\n        output\n    }\n}\n\npub(crate) fn launch_scalar_cmp<R: CubeRuntime, O: ComparisonOpFamily>(\n    tensor: CubeTensor<R>,\n    scalar: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&tensor);\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    let dtypes = [tensor.dtype.into(), dtype_bool.into()];\n    let same_tensor_type = dtypes[0] == dtypes[1];\n\n    if same_tensor_type && tensor.can_mut() && tensor.is_nonoverlapping() {\n        unsafe {\n            kernel_scalar_cmp::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                scalar,\n                tensor.as_linear_view_alias(0),\n                dtypes,\n            );\n        }\n\n        CubeTensor::new(\n            tensor.client.clone(),\n            tensor.handle.clone(),\n            *tensor.meta.clone(),\n            tensor.device.clone(),\n            dtype_bool,\n        )\n    } else {\n        let output = empty_device_dtype(\n            tensor.client.clone(),\n            tensor.device.clone(),\n            tensor.shape(),\n            dtype_bool,\n        );\n\n        unsafe {\n            kernel_scalar_cmp::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                scalar,\n                output.clone().into_linear_view(),\n                dtypes,\n            );\n        }\n\n        output\n    }\n}\n\npub fn equal<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)\n}\n\npub fn greater<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)\n}\n\npub fn greater_equal<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)\n}\n\npub fn lower<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)\n}\n\npub fn lower_equal<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)\n}\n\npub fn equal_elem<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_scalar_cmp::<R, EqualOp>(lhs, rhs, dtype_bool)\n}\n\npub fn greater_elem<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_scalar_cmp::<R, GreaterOp>(lhs, rhs, dtype_bool)\n}\n\npub fn lower_elem<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_scalar_cmp::<R, LowerOp>(lhs, rhs, dtype_bool)\n}\n\npub fn greater_equal_elem<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_scalar_cmp::<R, GreaterEqualOp>(lhs, rhs, dtype_bool)\n}\n\npub fn lower_equal_elem<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    launch_scalar_cmp::<R, LowerEqualOp>(lhs, rhs, dtype_bool)\n}\n\n// Unary comparison / predicate / relational ops\n\n#[cube]\npub(crate) trait PredicateOp<F: Float, N: Size>: 'static + Send + Sync {\n    /// Execute a predicate operation.\n    fn execute(input: Vector<F, N>) -> Vector<bool, N>;\n}\n\npub(crate) trait PredicateOpFamily: 'static + Send + Sync {\n    type Operation<F: Float, N: Size>: PredicateOp<F, N>;\n}\n\nstruct IsNanOp;\nstruct IsInfOp;\n\nimpl PredicateOpFamily for IsNanOp {\n    type Operation<F: Float, N: Size> = Self;\n}\n\n#[cube]\nimpl<F: Float, N: Size> PredicateOp<F, N> for IsNanOp {\n    fn execute(input: Vector<F, N>) -> Vector<bool, N> {\n        Vector::is_nan(input)\n    }\n}\n\nimpl PredicateOpFamily for IsInfOp {\n    type Operation<F: Float, N: Size> = Self;\n}\n#[cube]\nimpl<F: Float, N: Size> PredicateOp<F, N> for IsInfOp {\n    fn execute(input: Vector<F, N>) -> Vector<bool, N> {\n        Vector::is_inf(input)\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn kernel_predicate<F: Float, Bool: Numeric, N: Size, O: PredicateOpFamily>(\n    input: &LinearView<Vector<F, N>>,\n    output: &mut LinearView<Vector<Bool, N>, ReadWrite>,\n    #[define(F, Bool)] _dtypes: [StorageType; 2],\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = Vector::cast_from(O::Operation::<F, N>::execute(input[ABSOLUTE_POS]));\n}\n\npub(crate) fn launch_predicate<R: CubeRuntime, O: PredicateOpFamily>(\n    tensor: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&tensor);\n\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n\n    let dtypes = [tensor.dtype.into(), dtype_bool.into()];\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        tensor.shape(),\n        dtype_bool,\n    );\n\n    unsafe {\n        kernel_predicate::launch_unchecked::<O, R>(\n            &client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, output),\n            vector_size,\n            tensor.into_linear_view_like(&output),\n            output.clone().into_linear_view(),\n            dtypes,\n        );\n    }\n\n    output\n}\n\npub fn is_nan<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {\n    launch_predicate::<R, IsNanOp>(tensor, dtype_bool)\n}\n\npub fn is_inf<R: CubeRuntime>(tensor: CubeTensor<R>, dtype_bool: DType) -> CubeTensor<R> {\n    launch_predicate::<R, IsInfOp>(tensor, dtype_bool)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/contiguous.rs",
    "content": "use burn_backend::{DType, QTensorPrimitive, TensorMetadata};\nuse cubecl::quant::scheme::{QuantStore, QuantValue};\nuse cubecl::server::MemoryLayoutStrategy;\n\nuse crate::{CubeRuntime, ops::empty_qtensor, tensor::CubeTensor};\n\n/// Make a jit tensor contiguous.\npub fn into_contiguous<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {\n    if tensor.is_contiguous() {\n        return tensor;\n    }\n\n    if tensor.qparams.is_some() {\n        return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Contiguous);\n    }\n\n    let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);\n\n    let output = cubecl::std::tensor::into_contiguous(&client, tensor.binding(), dtype.into());\n\n    CubeTensor::new(\n        client.clone(),\n        output.handle,\n        *output.metadata,\n        device,\n        dtype,\n    )\n}\n\n/// Make a jit tensor contiguous with an aligned last stride. Tensor is considered already contiguous\n/// if runtime can read it as is. This is equivalent in practice.\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensor))\n)]\npub fn into_contiguous_aligned<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {\n    if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) {\n        return tensor;\n    }\n\n    if tensor.qparams.is_some() {\n        return into_contiguous_quantized(tensor, MemoryLayoutStrategy::Optimized);\n    }\n\n    let (client, device, dtype) = (tensor.client.clone(), tensor.device.clone(), tensor.dtype);\n\n    let output =\n        cubecl::std::tensor::into_contiguous_pitched(&client, tensor.binding(), dtype.into());\n\n    CubeTensor::new(\n        client.clone(),\n        output.handle,\n        *output.metadata,\n        device,\n        dtype,\n    )\n}\n\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensor))\n)]\nfn into_contiguous_quantized<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    strategy: MemoryLayoutStrategy,\n) -> CubeTensor<R> {\n    let scheme = tensor.scheme();\n    let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, strategy);\n    let (values, scales) = tensor.quantized_handles().unwrap();\n    let (out_values, out_scales) = output.quantized_handles().unwrap();\n\n    let (client, dtype_scales, dtype_value) = (scales.client.clone(), scales.dtype, values.dtype);\n\n    match scheme.store {\n        QuantStore::PackedU32(packed_dim) => {\n            cubecl::std::tensor::into_contiguous_packed_ref(\n                &client,\n                values.binding(),\n                out_values.binding(),\n                packed_dim,\n                tensor.meta.shape(),\n                scheme.num_quants(),\n                DType::U32.into(),\n            );\n        }\n        // e2m1 is special because it has a native packed representation, `e2m1x2`.\n        // It's internally stored as `u8` with a packing factor of 2.\n        QuantStore::PackedNative(packed_dim) if scheme.value == QuantValue::E2M1 => {\n            cubecl::std::tensor::into_contiguous_packed_ref(\n                &client,\n                values.binding(),\n                out_values.binding(),\n                packed_dim,\n                tensor.meta.shape(),\n                scheme.num_quants(),\n                DType::U8.into(),\n            );\n        }\n        _ => {\n            cubecl::std::tensor::copy_into(\n                &client,\n                values.binding(),\n                out_values.binding(),\n                dtype_value.into(),\n            );\n        }\n    }\n\n    cubecl::std::tensor::copy_into(\n        &client,\n        scales.binding(),\n        out_scales.binding(),\n        dtype_scales.into(),\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs",
    "content": "use burn_backend::{\n    TensorMetadata,\n    ops::{ConvOptions, ConvTransposeOptions},\n};\nuse burn_std::Shape;\nuse cubek::convolution::components::ConvSetupError;\n\nuse crate::{\n    CubeRuntime,\n    kernel::conv::{conv_transpose2d, conv_transpose3d},\n    ops::{permute_nchw_to_nhwc, permute_nhwc_to_nchw, reshape},\n    tensor::CubeTensor,\n};\n\npub(crate) fn conv_data_backward_fallback<R: CubeRuntime, const N_DIM: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    in_shape: Shape,\n    options: ConvOptions<N_DIM>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let dim_c = out_grad.rank();\n\n    let kernel_size = &weights.meta.shape()[1..dim_c];\n    let in_shape = &in_shape[1..dim_c];\n    let out_shape = &out_grad.meta.shape()[1..dim_c];\n\n    let mut padding_out = [0; N_DIM];\n\n    for i in 0..N_DIM {\n        padding_out[i] = calculate_padding_out(\n            kernel_size[i],\n            options.stride[i],\n            options.padding[i],\n            options.dilation[i],\n            in_shape[i],\n            out_shape[i],\n        );\n    }\n\n    // We don't yet have NHWC kernels for conv_transpose so need to do this.\n    // Should eventually use NHWC kernels instead\n    let out_grad = permute_nhwc_to_nchw(out_grad);\n    let weights = permute_nhwc_to_nchw(weights);\n\n    let in_grad = match N_DIM {\n        1 => conv_transpose1d_from_conv_transpose2d(\n            out_grad,\n            weights,\n            ConvTransposeOptions::new(\n                [options.stride[0]],\n                [options.padding[0]],\n                [padding_out[0]],\n                [options.dilation[0]],\n                options.groups,\n            ),\n        ),\n        2 => conv_transpose2d(\n            out_grad,\n            weights,\n            None,\n            ConvTransposeOptions::new(\n                [options.stride[0], options.stride[1]],\n                [options.padding[0], options.padding[1]],\n                [padding_out[0], padding_out[1]],\n                [options.dilation[0], options.dilation[1]],\n                options.groups,\n            ),\n            Default::default(),\n        ),\n        3 => Ok(conv_transpose3d(\n            out_grad,\n            weights,\n            None,\n            ConvTransposeOptions::new(\n                [options.stride[0], options.stride[1], options.stride[2]],\n                [options.padding[0], options.padding[1], options.padding[2]],\n                [padding_out[0], padding_out[1], padding_out[2]],\n                [\n                    options.dilation[0],\n                    options.dilation[1],\n                    options.dilation[2],\n                ],\n                options.groups,\n            ),\n        )\n        .unwrap()),\n        _ => unimplemented!(\"Invalid dimensionality\"),\n    }?;\n    Ok(permute_nchw_to_nhwc(in_grad))\n}\n\nfn calculate_padding_out(\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    size_in: usize,\n    size_out: usize,\n) -> usize {\n    if stride <= 1 {\n        return 0;\n    }\n\n    let out = 1\n        + ((size_in + 2 * padding - dilation * (kernel_size - 1) - 1) as f64 / stride as f64).ceil()\n            as usize;\n    i64::max(0, out as i64 - size_out as i64) as usize\n}\n\nfn conv_transpose1d_from_conv_transpose2d<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    options: ConvTransposeOptions<1>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let [channels_in, channels_out, kernel_size] = weight.shape().dims();\n    let [batch_size, _channels_in, length_in] = x.shape().dims();\n\n    let weight = reshape(\n        weight,\n        Shape::new([channels_in, channels_out, kernel_size, 1]),\n    );\n    let x = reshape(x, Shape::new([batch_size, channels_in, length_in, 1]));\n\n    let tensor = conv_transpose2d(\n        x,\n        weight,\n        None,\n        ConvTransposeOptions::new(\n            [options.stride[0], 1],\n            [options.padding[0], 0],\n            [options.padding_out[0], 0],\n            [options.dilation[0], 1],\n            options.groups,\n        ),\n        Default::default(),\n    )?;\n    let [batch_size, channels_out, height_out, _weight_out] = tensor.shape().dims();\n    Ok(reshape(\n        tensor,\n        Shape::from([batch_size, channels_out, height_out]),\n    ))\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse burn_std::Shape;\nuse cubek::{\n    convolution::{\n        AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_data,\n        components::ConvSetupError,\n    },\n    matmul::{\n        definition::{MatmulElems, MatmulGlobalElems},\n        launch::MatmulInputBinding,\n    },\n};\n\nuse crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\n\npub fn dgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    input_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::Strided,\n    };\n    launch_backwards_data::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        out_grad,\n        weights,\n        input_shape,\n        options,\n    )\n}\n\npub fn dgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    input_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,\n    };\n    launch_backwards_data::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        out_grad,\n        weights,\n        input_shape,\n        options,\n    )\n}\n\npub fn dgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    input_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    launch_backwards_data::<R, N>(\n        &Strategy::Simple {\n            read_strategy: ReadingStrategy::Tma,\n            tile_kind,\n        },\n        out_grad,\n        weights,\n        input_shape,\n        options,\n    )\n}\n\n/// Perform a convolution backwards data pass using the implicit GEMM (im2col) algorithm, using\n/// cubecl tiling matmul components.\n///\n/// * `input` - The input feature map\n/// * `out_grad` - The output gradients\n/// * `weight_shape` - The shape of the weights/weight gradients\n/// * `options` - The options to use for the convolution\npub fn launch_backwards_data<R: CubeRuntime, const N: usize>(\n    strategy: &Strategy,\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    input_shape: Shape,\n    options: ConvOptions<N>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {\n        return Err(ConvSetupError::Groups(options.groups));\n    }\n\n    let out_dtype = out_grad.dtype;\n\n    let in_grad = empty_device_dtype(\n        out_grad.client.clone(),\n        out_grad.device.clone(),\n        input_shape,\n        out_dtype,\n    );\n\n    let client = out_grad.client.clone();\n    let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {\n        lhs: out_grad.dtype.into(),\n        rhs: weights.dtype.into(),\n        out: out_dtype.into(),\n    });\n    let out_grad_dtype = out_grad.dtype;\n    let weights_dtype = weights.dtype;\n    let out_grad = MatmulInputBinding::new(out_grad.binding(), out_grad_dtype.into());\n    let weights = MatmulInputBinding::new(weights.binding(), weights_dtype.into());\n\n    backward_data::launch_ref::<R, N>(\n        strategy,\n        &client,\n        out_grad,\n        weights,\n        in_grad.clone().binding(),\n        ConvolutionArgs {\n            stride: options.stride,\n            padding: options.padding,\n            dilation: options.dilation,\n        },\n        dtypes,\n    )?;\n\n    Ok(in_grad)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/mod.rs",
    "content": "pub mod launch;\npub use launch::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_data/mod.rs",
    "content": "pub mod fallback;\npub mod implicit_gemm;\n\n#[cfg(feature = \"autotune\")]\npub mod tune;\n\n#[cfg(feature = \"autotune\")]\npub(crate) use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse burn_std::Shape;\nuse cubecl::{\n    ir::StorageType,\n    tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},\n};\nuse cubek::convolution::AcceleratedTileKind;\n\nuse crate::{\n    CubeAutotuneKey, CubeRuntime, CubeTuneId,\n    kernel::conv::{\n        ConvAutotuneKey,\n        backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},\n    },\n    tensor::CubeTensor,\n};\n\n/// Executes autotune on conv2d operations\npub fn dgrad_autotune<R: CubeRuntime, const N: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    input_shape: Shape,\n    options: ConvOptions<N>,\n) -> CubeTensor<R> {\n    let client = out_grad.client.clone();\n\n    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();\n\n    // Note: TMA isn't currently implemented properly, and will always error.\n    // It's kept here so it gets automatically enabled as soon as cubek updates.\n    // No CMMA for TMA because swizzling will be mandatory for good performance on dgrad.\n    let tunables = TUNER.init(|| {\n        TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)\n            .with(Tunable::new(\n                \"wgrad_fallback\",\n                conv_data_backward_fallback::<R, N>,\n            ))\n            .with(Tunable::new(\n                \"simple_sync_cmma\",\n                |input, grad, shape, options| {\n                    dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_sync_mma\",\n                |input, grad, shape, options| {\n                    dgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_cmma\",\n                |input, grad, shape, options| {\n                    dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_mma\",\n                |input, grad, shape, options| {\n                    dgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_tma_mma\",\n                |input, grad, shape, options| {\n                    dgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&out_grad.client, &out_grad.device),\n        &client,\n        tunables,\n        (out_grad, weights, input_shape, options),\n    )\n}\n\npub fn create_wgrad_input<R: CubeRuntime, const N: usize>(\n    _key: &CubeAutotuneKey,\n    out_grad: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    input_shape: &Shape,\n    options: &ConvOptions<N>,\n) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {\n    (\n        out_grad.clone(),\n        weights.clone(),\n        input_shape.clone(),\n        options.clone(),\n    )\n}\n\nfn create_key<R: CubeRuntime, const N: usize>(\n    out_grad: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    input_shape: &Shape,\n    options: &ConvOptions<N>,\n) -> CubeAutotuneKey {\n    let dtype = out_grad.dtype;\n    let rank = out_grad.meta.num_dims();\n    let dim_c = rank - 1;\n\n    let batch_size = out_grad.meta.shape()[0];\n    let in_channels = input_shape[dim_c];\n    let out_channels = out_grad.meta.shape()[dim_c];\n\n    let kernel_size = weights.meta.shape()[1..dim_c].to_vec();\n    let in_shape = input_shape[1..dim_c]\n        .iter()\n        .map(|shape| anchor(*shape, None, None, None))\n        .collect();\n\n    let ConvOptions {\n        stride,\n        padding,\n        dilation,\n        groups,\n    } = options.clone();\n\n    let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {\n        stride_align(out_grad.meta.strides(), out_grad.dtype.into())\n    } else {\n        0\n    };\n    let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);\n    let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {\n        stride_align(weights.meta.strides(), weights.dtype.into())\n    } else {\n        0\n    };\n    let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);\n\n    CubeAutotuneKey::Conv(ConvAutotuneKey::new(\n        kernel_size,\n        stride.to_vec(),\n        padding.to_vec(),\n        dilation.to_vec(),\n        groups,\n        in_channels,\n        out_channels,\n        in_shape,\n        batch_size,\n        false,\n        dtype,\n        lhs_shape_align,\n        lhs_stride_align,\n        rhs_shape_align,\n        rhs_stride_align,\n    ))\n}\n\n/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's\n/// repeat number, so it's the largest align that can have performance impacts.\nconst MAX_STRIDE_FACTOR: u32 = 10;\n\n/// Defines the non-contiguous stride alignment in terms of powers of two\nfn stride_align(strides: &[usize], elem: StorageType) -> u8 {\n    let max = MAX_STRIDE_FACTOR;\n    let dim_c = strides.len() - 1;\n    let factor = strides[..dim_c]\n        .iter()\n        .map(|it| (*it * elem.size_bits()) / 8)\n        .map(|it| it.trailing_zeros())\n        .min()\n        .unwrap_or(max);\n    factor.min(max) as u8\n}\n\n/// Defines the potential vectorization.\nfn pow2_factor(axis: usize) -> u8 {\n    axis.trailing_zeros().min(4) as u8\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs",
    "content": "use burn_backend::{TensorMetadata, ops::ConvOptions};\nuse burn_std::{Shape, Slice};\nuse cubek::convolution::components::ConvSetupError;\n\nuse crate::{\n    CubeRuntime,\n    kernel::{conv::base::conv_forward_nhwc, slice, slice_assign},\n    ops::{numeric::empty_device_dtype, swap_dims},\n    tensor::CubeTensor,\n};\n\n/// Calculate the convolution backward pass with regard to the weight gradients.\npub fn conv_weight_backward_fallback<R: CubeRuntime, const N_DIM: usize>(\n    input: CubeTensor<R>,\n    output_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N_DIM>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    match options.groups == 1 {\n        true => conv_weight_grad_no_groups::<R, N_DIM>(input, output_grad, weight_shape, options),\n        false => conv_weight_grad_groups::<R, N_DIM>(input, output_grad, weight_shape, options),\n    }\n}\n\nfn conv_weight_grad_no_groups<R: CubeRuntime, const N_DIM: usize>(\n    input: CubeTensor<R>,\n    output_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N_DIM>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let dim_c = input.rank() - 1;\n\n    let input_swapped = swap_dims(input, 0, dim_c);\n    let out_grad_swapped = swap_dims(output_grad, 0, dim_c);\n    let weight_grad_swapped = conv_forward_nhwc(\n        input_swapped,\n        out_grad_swapped,\n        None,\n        ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n        Default::default(),\n    )?;\n    let mut weight_grad = swap_dims(weight_grad_swapped, 0, dim_c);\n    if weight_grad.shape() != weight_shape {\n        let ranges = weight_shape.iter().map(|&s| 0..s).collect::<Vec<_>>();\n        weight_grad = slice(weight_grad, &ranges);\n    }\n\n    Ok(weight_grad)\n}\n\n#[allow(clippy::single_range_in_vec_init, reason = \"False positive\")]\nfn conv_weight_grad_groups<R: CubeRuntime, const N_DIM: usize>(\n    input: CubeTensor<R>,\n    output_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N_DIM>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let mut weight_grad = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        weight_shape.clone(),\n        input.dtype,\n    );\n\n    let dim_c = input.rank() - 1;\n\n    let channels_out = weight_shape[0];\n    let increment_co = channels_out / options.groups;\n\n    let input_swapped = swap_dims(input, 0, dim_c);\n    let output_grad_swapped = swap_dims(output_grad, 0, dim_c);\n\n    let kernel_size = &weight_shape[1..dim_c];\n    let kernel_size_slice = kernel_size.iter().map(|&s| 0..s).collect::<Vec<_>>();\n    let increment_ci = weight_grad.meta.shape()[dim_c];\n\n    for g in 0..options.groups {\n        let start_idx_ci = g * increment_ci;\n        let end_idx_ci = (g + 1) * increment_ci;\n        let start_idx_co = g * increment_co;\n        let end_idx_co = (g + 1) * increment_co;\n\n        let input = slice(input_swapped.clone(), &[start_idx_ci..end_idx_ci]);\n        let grad = slice(output_grad_swapped.clone(), &[start_idx_co..end_idx_co]);\n\n        let weight_grad_tmp = conv_forward_nhwc(\n            input,\n            grad,\n            None,\n            ConvOptions::new(options.dilation, options.padding, options.stride, 1),\n            Default::default(),\n        )?;\n        let mut weight_grad_tmp = swap_dims(weight_grad_tmp, 0, dim_c);\n        let kernel_size_tmp = &weight_grad_tmp.meta.shape()[1..dim_c];\n\n        if kernel_size != kernel_size_tmp {\n            let mut slices = vec![0..increment_co];\n            slices.extend(kernel_size_slice.clone());\n            slices.push(0..increment_ci);\n            weight_grad_tmp = slice(weight_grad_tmp, &slices);\n        }\n\n        let mut slices = vec![start_idx_co..end_idx_co];\n        slices.extend(kernel_size_slice.clone());\n        slices.push(0..increment_ci);\n        let slices = slices.into_iter().map(Slice::from).collect::<Vec<_>>();\n\n        weight_grad = slice_assign(weight_grad, &slices, weight_grad_tmp);\n    }\n\n    Ok(weight_grad)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse burn_std::Shape;\nuse cubek::{\n    convolution::{\n        AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy, backward_weight,\n        components::ConvSetupError,\n    },\n    matmul::{\n        definition::{MatmulElems, MatmulGlobalElems},\n        launch::MatmulInputBinding,\n    },\n};\n\nuse crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\n\npub(crate) fn wgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::Strided,\n    };\n    launch_backwards_weight::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        input,\n        out_grad,\n        weight_shape,\n        options,\n    )\n}\n\npub(crate) fn wgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,\n    };\n    launch_backwards_weight::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        input,\n        out_grad,\n        weight_shape,\n        options,\n    )\n}\n\npub(crate) fn wgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    launch_backwards_weight::<R, N>(\n        &Strategy::Simple {\n            read_strategy: ReadingStrategy::Tma,\n            tile_kind,\n        },\n        input,\n        out_grad,\n        weight_shape,\n        options,\n    )\n}\n\n/// Perform a convolution backwards weight pass using the implicit GEMM (im2col) algorithm, using\n/// cubecl tiling matmul components.\n///\n/// * `input` - The input feature map\n/// * `out_grad` - The output gradients\n/// * `weight_shape` - The shape of the weights/weight gradients\n/// * `options` - The options to use for the convolution\npub fn launch_backwards_weight<R: CubeRuntime, const N: usize>(\n    strategy: &Strategy,\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    if options.groups != 1 {\n        return Err(ConvSetupError::Groups(options.groups));\n    }\n\n    let out_dtype = out_grad.dtype;\n\n    let weight_grad = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        weight_shape,\n        out_dtype,\n    );\n\n    let client = input.client.clone();\n    let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {\n        lhs: input.dtype.into(),\n        rhs: out_grad.dtype.into(),\n        out: out_dtype.into(),\n    });\n    let input_dtype = input.dtype;\n    let out_grad_dtype = out_grad.dtype;\n    let input = MatmulInputBinding::new(input.binding(), input_dtype.into());\n    let out_grad = MatmulInputBinding::new(out_grad.binding(), out_grad_dtype.into());\n\n    backward_weight::launch_ref::<R, N>(\n        strategy,\n        &client,\n        input,\n        out_grad,\n        weight_grad.clone().binding(),\n        ConvolutionArgs {\n            stride: options.stride,\n            padding: options.padding,\n            dilation: options.dilation,\n        },\n        dtypes,\n    )?;\n\n    Ok(weight_grad)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/mod.rs",
    "content": "pub mod launch;\npub use launch::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_weight/mod.rs",
    "content": "pub mod fallback;\npub mod implicit_gemm;\n\n#[cfg(feature = \"autotune\")]\npub mod tune;\n\n#[cfg(feature = \"autotune\")]\npub(crate) use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse burn_std::Shape;\nuse cubecl::{\n    ir::StorageType,\n    tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},\n};\nuse cubek::convolution::AcceleratedTileKind;\n\nuse crate::{\n    CubeAutotuneKey, CubeRuntime, CubeTuneId,\n    kernel::conv::{\n        ConvAutotuneKey,\n        backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},\n    },\n    tensor::CubeTensor,\n};\n\n/// Executes autotune on the weight gradients pass for convolution\npub fn wgrad_autotune<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n) -> CubeTensor<R> {\n    let client = input.client.clone();\n\n    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        TunableSet::new(create_key::<R, N>, create_wgrad_input::<R, N>)\n            .with(Tunable::new(\n                \"wgrad_fallback\",\n                conv_weight_backward_fallback::<R, N>,\n            ))\n            .with(Tunable::new(\n                \"simple_sync_cmma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_sync_mma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_sync(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_cmma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_mma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_async(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_tma_cmma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_tma_mma\",\n                |input, grad, shape, options| {\n                    wgrad_gemm_simple_tma(input, grad, shape, options, AcceleratedTileKind::Mma)\n                },\n            ))\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&input.client, &input.device),\n        &client,\n        tunables,\n        (input, out_grad, weight_shape, options),\n    )\n}\n\npub fn create_wgrad_input<R: CubeRuntime, const N: usize>(\n    _key: &CubeAutotuneKey,\n    input: &CubeTensor<R>,\n    out_grad: &CubeTensor<R>,\n    weight_shape: &Shape,\n    options: &ConvOptions<N>,\n) -> (CubeTensor<R>, CubeTensor<R>, Shape, ConvOptions<N>) {\n    (\n        input.clone(),\n        out_grad.clone(),\n        weight_shape.clone(),\n        options.clone(),\n    )\n}\n\nfn create_key<R: CubeRuntime, const N: usize>(\n    input: &CubeTensor<R>,\n    out_grad: &CubeTensor<R>,\n    weight_shape: &Shape,\n    options: &ConvOptions<N>,\n) -> CubeAutotuneKey {\n    let dtype = input.dtype;\n    let rank = input.meta.num_dims();\n    let dim_c = rank - 1;\n\n    let batch_size = input.meta.shape()[0];\n    let in_channels = input.meta.shape()[dim_c];\n    let out_channels = weight_shape[0];\n\n    let kernel_size = weight_shape[1..dim_c].to_vec();\n    let in_shape = input.meta.shape()[1..dim_c]\n        .iter()\n        .map(|shape| anchor(*shape, None, None, None))\n        .collect();\n\n    let ConvOptions {\n        stride,\n        padding,\n        dilation,\n        groups,\n    } = options.clone();\n\n    let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 {\n        stride_align(out_grad.meta.strides(), out_grad.dtype.into())\n    } else {\n        0\n    };\n    let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align);\n    let rhs_stride_align = if input.meta.strides()[dim_c] == 1 {\n        stride_align(input.meta.strides(), input.dtype.into())\n    } else {\n        0\n    };\n    let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);\n\n    CubeAutotuneKey::Conv(ConvAutotuneKey::new(\n        kernel_size,\n        stride.to_vec(),\n        padding.to_vec(),\n        dilation.to_vec(),\n        groups,\n        in_channels,\n        out_channels,\n        in_shape,\n        batch_size,\n        false,\n        dtype,\n        lhs_shape_align,\n        lhs_stride_align,\n        rhs_shape_align,\n        rhs_stride_align,\n    ))\n}\n\n/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's\n/// repeat number, so it's the largest align that can have performance impacts.\nconst MAX_STRIDE_FACTOR: u32 = 10;\n\n/// Defines the non-contiguous stride alignment in terms of powers of two\nfn stride_align(strides: &[usize], elem: StorageType) -> u8 {\n    let max = MAX_STRIDE_FACTOR;\n    let dim_c = strides.len() - 1;\n    let factor = strides[..dim_c]\n        .iter()\n        .map(|it| (*it * elem.size_bits()) / 8)\n        .map(|it| it.trailing_zeros())\n        .min()\n        .unwrap_or(max);\n    factor.min(max) as u8\n}\n\n/// Defines the potential vectorization.\nfn pow2_factor(axis: usize) -> u8 {\n    axis.trailing_zeros().min(4) as u8\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/base.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse burn_std::Shape;\nuse cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};\n\n#[cfg(feature = \"autotune\")]\nuse crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};\nuse crate::{\n    CubeRuntime,\n    kernel::conv::{\n        backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},\n        backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},\n        forward::implicit_gemm::conv_gemm_simple_sync,\n    },\n    ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},\n    tensor::CubeTensor,\n};\n\nuse super::conv_direct;\n#[cfg(feature = \"autotune\")]\nuse super::forward::conv_autotune;\n\n/// The strategy to be used when launching a convolution kernel.\npub enum ConvStrategy {\n    /// A simple direct convolution.\n    Direct,\n    #[cfg(feature = \"autotune\")]\n    /// Using autotune to choose the best kernel based on runtime information.\n    Autotune,\n    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and\n    /// has constraints on tensor shape.\n    ImplicitGemm,\n}\n\nimpl Default for ConvStrategy {\n    fn default() -> Self {\n        // if autotune is enabled, default to autotune\n        #[cfg(feature = \"autotune\")]\n        return ConvStrategy::Autotune;\n\n        // if autotune is disabled, default to the more memory-conservative algorithm\n        #[cfg(not(feature = \"autotune\"))]\n        ConvStrategy::Direct\n    }\n}\n\n/// Performs an N-dimensional convolution with the given strategy\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\n/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.\npub fn conv_forward<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n    strategy: ConvStrategy,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let input = permute_nchw_to_nhwc(input);\n    let weight = permute_nchw_to_nhwc(weight);\n\n    let out = conv_forward_nhwc(input, weight, bias, options, strategy)?;\n\n    Ok(permute_nhwc_to_nchw(out))\n}\n\n/// Performs an N-dimensional convolution with the given strategy on NHWC inputs/outputs\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\n/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.\npub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n    strategy: ConvStrategy,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    match strategy {\n        ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),\n        #[cfg(feature = \"autotune\")]\n        ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),\n        ConvStrategy::ImplicitGemm => {\n            if options.groups != 1 {\n                conv_direct::<R, N>(input, weight, bias, options)\n            } else {\n                conv_gemm_simple_sync::<R, N>(\n                    input,\n                    weight,\n                    bias,\n                    options,\n                    AcceleratedTileKind::Cmma,\n                )\n            }\n        }\n    }\n}\n\n/// Performs an N-dimensional convolution backwards pass with regard to weight, with the given strategy\n///\n/// * `input` - The input feature map\n/// * `out_grad` - The output gradients\n/// * `weight_shape` - The shape of the weights/weight gradients\n/// * `options` - The options used for the convolution\n/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.\npub fn conv_weight_backward<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    weight_shape: Shape,\n    options: ConvOptions<N>,\n    strategy: ConvStrategy,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let input = permute_nchw_to_nhwc(input);\n    let out_grad = permute_nchw_to_nhwc(out_grad);\n    let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);\n\n    let weight_grad = match strategy {\n        ConvStrategy::Direct => {\n            conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)\n        }\n        #[cfg(feature = \"autotune\")]\n        ConvStrategy::Autotune => Ok(wgrad_autotune::<R, N>(\n            input,\n            out_grad,\n            weight_shape,\n            options,\n        )),\n        ConvStrategy::ImplicitGemm => {\n            if options.groups != 1 {\n                conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)\n            } else {\n                wgrad_gemm_simple_sync::<R, N>(\n                    input,\n                    out_grad,\n                    weight_shape,\n                    options,\n                    AcceleratedTileKind::Cmma,\n                )\n            }\n        }\n    }?;\n\n    Ok(permute_nhwc_to_nchw(weight_grad))\n}\n\n/// Performs an N-dimensional convolution backwards data pass with the given strategy\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `in_shape` - The shape of the input to the layer\n/// * `options` - The options to use for the convolution\n/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.\npub fn conv_data_backward<R: CubeRuntime, const N: usize>(\n    out_grad: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    in_shape: Shape,\n    options: ConvOptions<N>,\n    strategy: ConvStrategy,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let out_grad = permute_nchw_to_nhwc(out_grad);\n    let weights = permute_nchw_to_nhwc(weights);\n    let in_shape = permute_nchw_to_nhwc_shape(in_shape);\n\n    let weight_grad = match strategy {\n        ConvStrategy::Direct => {\n            conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?\n        }\n        #[cfg(feature = \"autotune\")]\n        ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),\n        ConvStrategy::ImplicitGemm => {\n            if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {\n                conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?\n            } else {\n                dgrad_gemm_simple_sync::<R, N>(\n                    out_grad,\n                    weights,\n                    in_shape,\n                    options,\n                    AcceleratedTileKind::Cmma,\n                )?\n            }\n        }\n    };\n\n    Ok(permute_nhwc_to_nchw(weight_grad))\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose2d/base.rs",
    "content": "use crate::{CubeRuntime, tensor::CubeTensor};\nuse burn_backend::ops::ConvTransposeOptions;\nuse cubek::convolution::components::ConvSetupError;\n\n#[cfg(feature = \"autotune\")]\nuse super::conv_transpose2d_autotune;\nuse super::{conv_transpose2d_col2im, conv_transpose2d_direct};\n\n/// The strategy to be used when launching a conv_transpose kernel.\npub enum ConvTranspose2dStrategy {\n    /// A simple direct convolution.\n    Direct,\n    #[cfg(feature = \"autotune\")]\n    /// Using autotune to choose the best kernel based on runtime information.\n    Autotune,\n    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.\n    Gemm,\n}\n\nimpl Default for ConvTranspose2dStrategy {\n    fn default() -> Self {\n        // if autotune is enabled, default to autotune\n        #[cfg(feature = \"autotune\")]\n        return ConvTranspose2dStrategy::Autotune;\n\n        // if autotune is disabled, default to the more memory-conservative algorithm\n        #[cfg(not(feature = \"autotune\"))]\n        ConvTranspose2dStrategy::Direct\n    }\n}\n\n/// Performs a 2D convolution with the given strategy\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\n/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.\npub fn conv_transpose2d<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvTransposeOptions<2>,\n    strategy: ConvTranspose2dStrategy,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    match strategy {\n        ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options),\n        #[cfg(feature = \"autotune\")]\n        ConvTranspose2dStrategy::Autotune => {\n            Ok(conv_transpose2d_autotune(input, weight, bias, options))\n        }\n        ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options),\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        conv::batches_per_run,\n        into_contiguous_aligned,\n        matmul::{MatmulStrategy, matmul},\n        slice,\n        utils::{address_type, decompose_linear, shape_divmod},\n    },\n    ops::{numeric::empty_device_dtype, reshape, swap_dims},\n    tensor::CubeTensor,\n};\nuse burn_backend::{\n    Shape,\n    ops::{ConvTransposeOptions, conv::calculate_conv_transpose_output_size},\n};\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\nuse cubek::convolution::components::ConvSetupError;\n\n/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm.\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\npub fn conv_transpose2d_col2im<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvTransposeOptions<2>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();\n    let [batch_size, _, input_h, input_w] = input.meta.shape().dims();\n    let groups = options.groups;\n    let input_ch_per_group = input_channels / groups;\n    let ConvTransposeOptions {\n        padding: [padding_h, padding_w],\n        padding_out: [padding_out_h, padding_out_w],\n        dilation: [dilation_h, dilation_w],\n        stride: [stride_h, stride_w],\n        ..\n    } = options.clone();\n\n    let im_h = calculate_conv_transpose_output_size(\n        kernel_h,\n        stride_h,\n        padding_h,\n        padding_out_h,\n        dilation_h,\n        input_h,\n    );\n    let im_w = calculate_conv_transpose_output_size(\n        kernel_w,\n        stride_w,\n        padding_w,\n        padding_out_w,\n        dilation_w,\n        input_w,\n    );\n    let im_channels = im_ch_per_group * groups;\n\n    let batches_per_run = batches_per_run(\n        batch_size,\n        input_h * input_w,\n        input.client.properties().hardware.plane_size_max as usize,\n    )?;\n    let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;\n\n    let weight = reshape(\n        weight.clone(),\n        Shape::new([groups, input_ch_per_group, col_shape_0]),\n    );\n    let weight = into_contiguous_aligned(swap_dims(weight, 1, 2));\n\n    if batches_per_run != batch_size {\n        let runs = batch_size / batches_per_run;\n\n        let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]);\n        let image = empty_device_dtype(\n            input.client.clone(),\n            input.device.clone(),\n            im_shape,\n            input.dtype,\n        );\n\n        let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]);\n        let input = reshape(input, input_shape);\n        let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]);\n\n        for run in 0..runs {\n            let input = index(input.clone(), run);\n            let input = reshape(input, input_shape_run.clone());\n            let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);\n            let image_slice = index(image.clone(), run);\n            let image_slice = reshape(image_slice, im_shape);\n            execute(\n                input,\n                weight.clone(),\n                bias.clone(),\n                image_slice,\n                options.clone(),\n                kernel_h,\n                kernel_w,\n            )?;\n        }\n        Ok(reshape(\n            image,\n            Shape::new([batch_size, im_channels, im_h, im_w]),\n        ))\n    } else {\n        let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);\n        let image = empty_device_dtype(\n            input.client.clone(),\n            input.device.clone(),\n            im_shape,\n            input.dtype,\n        );\n        execute(\n            input,\n            weight,\n            bias,\n            image.clone(),\n            options,\n            kernel_h,\n            kernel_w,\n        )?;\n        Ok(image)\n    }\n}\n\npub(crate) fn index<R: CubeRuntime>(tensor: CubeTensor<R>, i: usize) -> CubeTensor<R> {\n    #[allow(clippy::single_range_in_vec_init)]\n    let mut indices = vec![i..i + 1];\n    for dim in tensor.meta.shape()[1..].iter() {\n        indices.push(0..*dim);\n    }\n    let mut tensor = slice(tensor, &indices);\n    tensor.meta.remove(0);\n    tensor\n}\n\n#[allow(clippy::too_many_arguments)]\nfn execute<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    image: CubeTensor<R>,\n    options: ConvTransposeOptions<2>,\n    kernel_h: usize,\n    kernel_w: usize,\n) -> Result<(), ConvSetupError> {\n    let [batch_size, _, input_h, input_w] = input.meta.shape().dims();\n    let [groups, col_shape_0, input_ch_per_group] = weight.meta.shape().dims();\n\n    let col_shape_1 = batch_size * input_h * input_w;\n\n    let input = swap_dims(input, 0, 1);\n    let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);\n    let input = reshape(input, input_shape);\n\n    let dtype = input.dtype;\n    let columns = matmul(weight, input, None, MatmulStrategy::default(), dtype)?;\n    let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));\n\n    col2im(\n        columns, bias, image, kernel_h, kernel_w, input_h, input_w, options,\n    )?;\n\n    Ok(())\n}\n\n#[allow(clippy::too_many_arguments)]\nfn col2im<R: CubeRuntime>(\n    columns: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    out: CubeTensor<R>,\n    kernel_h: usize,\n    kernel_w: usize,\n    out_h: usize,\n    out_w: usize,\n    options: ConvTransposeOptions<2>,\n) -> Result<(), LaunchError> {\n    let dtype = columns.dtype;\n\n    let columns = into_contiguous_aligned(columns);\n    let bias = bias.map(into_contiguous_aligned);\n\n    let num_elems = out.meta.num_elements();\n\n    let cube_dim = CubeDim::new(&columns.client, num_elems);\n    let cube_count = calculate_cube_count_elemwise(&columns.client, num_elems, cube_dim);\n\n    let shape = shape_divmod(&out);\n    unsafe {\n        col2im_kernel::launch_unchecked(\n            &columns.client.clone(),\n            cube_count,\n            cube_dim,\n            address_type!(columns, bias, out),\n            columns.into_tensor_arg(),\n            bias.map(|bias| bias.into_tensor_arg()).into(),\n            out.into_linear_view(),\n            shape,\n            Col2ImArgsLaunch::new(\n                out_h,\n                out_w,\n                kernel_h,\n                kernel_w,\n                options.padding[0],\n                options.padding[1],\n                options.dilation[0],\n                options.dilation[1],\n                options.stride[0],\n                options.stride[1],\n            ),\n            dtype.into(),\n        )\n    };\n\n    Ok(())\n}\n\n#[derive(CubeLaunch, CubeType)]\nstruct Col2ImArgs {\n    out_h: usize,\n    out_w: usize,\n\n    kernel_h: usize,\n    kernel_w: usize,\n\n    pad_h: usize,\n    pad_w: usize,\n    dilation_h: usize,\n    dilation_w: usize,\n    stride_h: usize,\n    stride_w: usize,\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn col2im_kernel<E: Numeric>(\n    columns: &Tensor<E>,\n    bias: &ComptimeOption<Tensor<E>>,\n    image: &mut LinearView<E, ReadWrite>,\n    image_shape: Sequence<FastDivmod<usize>>,\n    args: &Col2ImArgs,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= image.shape() {\n        terminate!();\n    }\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS, &image_shape);\n    let [batch, ch_im, im_y, im_x] = *pos else {\n        unreachable!()\n    };\n\n    let im_x = im_x + args.pad_w;\n    let im_y = im_y + args.pad_h;\n\n    let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1;\n    let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1;\n\n    let mut val = E::zero();\n\n    let x_col_start = if im_x >= kernel_extent_w {\n        (im_x - kernel_extent_w) / args.stride_w + 1\n    } else {\n        0usize.runtime()\n    };\n    let x_col_end = clamp_max(im_x / args.stride_w + 1, args.out_w);\n    let y_col_start = if im_y >= kernel_extent_h {\n        (im_y - kernel_extent_h) / args.stride_h + 1\n    } else {\n        0usize.runtime()\n    };\n    let y_col_end = clamp_max(im_y / args.stride_h + 1, args.out_h);\n\n    for col_y in y_col_start..y_col_end {\n        let kernel_y = im_y - col_y * args.stride_h;\n        for col_x in x_col_start..x_col_end {\n            let kernel_x = im_x - col_x * args.stride_w;\n\n            if kernel_y.is_multiple_of(args.dilation_h) && kernel_x.is_multiple_of(args.dilation_w)\n            {\n                let kernel_y = kernel_y / args.dilation_h;\n                let kernel_x = kernel_x / args.dilation_w;\n\n                let col_k =\n                    ch_im * args.kernel_h * args.kernel_w + kernel_y * args.kernel_w + kernel_x;\n                let col_n = batch * args.out_h * args.out_w + col_y * args.out_w + col_x;\n                let col_pos = col_k * columns.stride(0) + col_n * columns.stride(1);\n                val += columns[col_pos];\n            }\n        }\n    }\n\n    #[comptime]\n    match bias {\n        ComptimeOption::Some(bias) => image[ABSOLUTE_POS] = val + bias[ch_im],\n        ComptimeOption::None => image[ABSOLUTE_POS] = val,\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose2d/mod.rs",
    "content": "mod base;\nmod col2im;\n\nmod transpose_direct;\n\n#[cfg(feature = \"autotune\")]\nmod tune;\n\npub use base::*;\npub use col2im::*;\n\npub use transpose_direct::*;\n\n#[cfg(feature = \"autotune\")]\npub use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, decompose_linear, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse burn_backend::{Shape, ops::ConvTransposeOptions};\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\nuse cubek::convolution::components::ConvSetupError;\n\n#[derive(CubeLaunch, CubeType)]\nstruct ConvArgs {\n    conv_stride_0: usize,\n    conv_stride_1: usize,\n    dilation_0: usize,\n    dilation_1: usize,\n    padding_0: usize,\n    padding_1: usize,\n    groups: usize,\n}\n\n#[cube(launch, address_type = \"dynamic\")]\nfn conv_transpose2d_direct_kernel<E: Numeric>(\n    input: &Tensor<E>,\n    weight: &Tensor<E>,\n    bias: &ComptimeOption<Tensor<E>>,\n    output: &mut LinearView<E, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    args: ConvArgs,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.shape() {\n        terminate!();\n    }\n\n    let in_c_per_group = weight.shape(0) / args.groups;\n    let out_c_per_group = weight.shape(1);\n    let kernel_h = weight.shape(2);\n    let kernel_w = weight.shape(3);\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);\n    let [batch, oc_out, out_y, out_x] = *pos else {\n        unreachable!()\n    };\n\n    let k = oc_out / out_c_per_group;\n    let group = k % args.groups;\n    let out_c = oc_out - out_c_per_group * group;\n\n    let in_c_start = group * in_c_per_group;\n    let in_c_end = in_c_start + in_c_per_group;\n\n    let stride_0_i = args.conv_stride_0 as i32;\n    let stride_1_i = args.conv_stride_1 as i32;\n\n    let kms_h = (kernel_h * args.dilation_0) as i32 - stride_0_i;\n    let kms_w = (kernel_w * args.dilation_1) as i32 - stride_1_i;\n\n    let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i;\n    let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i;\n\n    let y_end = clamp(kms_h + y_start + 1, 0, input.shape(2) as i32) as usize;\n    let x_end = clamp(kms_w + x_start + 1, 0, input.shape(3) as i32) as usize;\n    let y_start = clamp_min(y_start, 0) as usize;\n    let x_start = clamp_min(x_start, 0) as usize;\n\n    let idx_input_batch = batch * input.stride(0);\n    let idx_weight_oc = out_c * weight.stride(1);\n\n    let bias: ComptimeOption<E> = bias.map(|bias| bias[oc_out]);\n    let mut sum = bias.unwrap_or_default();\n\n    let numerator_h_base = out_y + args.padding_0;\n    let numerator_w_base = out_x + args.padding_1;\n\n    for in_c in in_c_start..in_c_end {\n        let idx_input_ic = in_c * input.stride(1);\n        let idx_weight_ic = in_c * weight.stride(0);\n\n        for in_y in y_start..y_end {\n            let numerator_tmp = in_y * args.conv_stride_0;\n            let numerator_h = numerator_h_base - numerator_tmp;\n\n            if numerator_h_base >= numerator_tmp && numerator_h.is_multiple_of(args.dilation_0) {\n                let kernel_y = numerator_h / args.dilation_0;\n                let idx_input_y = in_y * input.stride(2);\n                let idx_weight_ky = kernel_y * weight.stride(2);\n\n                for in_x in x_start..x_end {\n                    let numerator_tmp = in_x * args.conv_stride_1;\n                    let numerator_w = numerator_w_base - numerator_tmp;\n\n                    if numerator_w_base >= numerator_tmp\n                        && numerator_w.is_multiple_of(args.dilation_1)\n                    {\n                        let kernel_x = numerator_w / args.dilation_1;\n                        let idx_input_x = in_x * input.stride(3);\n                        let idx_weight_kx = kernel_x * weight.stride(3);\n\n                        let index_input =\n                            idx_input_batch + idx_input_ic + idx_input_y + idx_input_x;\n                        let index_weight =\n                            idx_weight_ic + idx_weight_oc + idx_weight_ky + idx_weight_kx;\n\n                        let value = input[index_input];\n                        let weight = weight[index_weight];\n\n                        sum += value * weight;\n                    }\n                }\n            }\n        }\n    }\n\n    output[ABSOLUTE_POS] = sum;\n}\n\n/// Perform a 2D convolution transposition using the direct algorithm.\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\n///\npub fn conv_transpose2d_direct<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvTransposeOptions<2>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let [batch_size, _, in_height, in_width] = input.meta.shape().dims();\n    let [_, out_channels, kernel_0, kernel_1] = weight.meta.shape().dims();\n\n    let out_0 = (in_height - 1) * options.stride[0]\n        + options.dilation[0] * (kernel_0 - 1)\n        + options.padding_out[0]\n        - 2 * options.padding[0]\n        + 1;\n    let out_1 = (in_width - 1) * options.stride[1]\n        + options.dilation[1] * (kernel_1 - 1)\n        + options.padding_out[1]\n        - 2 * options.padding[1]\n        + 1;\n\n    let shape_out = Shape::new([batch_size, out_channels * options.groups, out_0, out_1]);\n\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        shape_out.clone(),\n        input.dtype,\n    );\n\n    let num_elems = output.meta.num_elements();\n    let cube_dim = CubeDim::new(&input.client, num_elems);\n    let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);\n    let dtype = input.dtype;\n\n    conv_transpose2d_direct_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, weight, bias, output),\n        input.into_tensor_arg(),\n        weight.into_tensor_arg(),\n        bias.map(|bias| bias.into_tensor_arg()).into(),\n        output.clone().into_linear_view(),\n        shape_divmod(&output),\n        ConvArgsLaunch::new(\n            options.stride[0],\n            options.stride[1],\n            options.dilation[0],\n            options.dilation[1],\n            options.padding[0],\n            options.padding[1],\n            options.groups,\n        ),\n        dtype.into(),\n    );\n\n    Ok(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs",
    "content": "use burn_backend::ops::ConvTransposeOptions;\nuse cubecl::tune::{LocalTuner, Tunable, TunableSet, local_tuner};\n\nuse crate::{\n    CubeAutotuneKey, CubeRuntime, CubeTuneId,\n    kernel::conv::{ConvTranspose2dAutotuneKey, conv_transpose2d_col2im, conv_transpose2d_direct},\n    tensor::CubeTensor,\n};\n\n/// Executes autotune on conv2d operations\npub fn conv_transpose2d_autotune<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weights: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvTransposeOptions<2>,\n) -> CubeTensor<R> {\n    let client = input.client.clone();\n\n    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tune_set = TUNER.init(|| {\n        TunableSet::new(create_key::<R>, create_transpose2d_input::<R>)\n            .with(Tunable::new(\n                \"conv_transpose2d_direct\",\n                conv_transpose2d_direct::<R>,\n            ))\n            .with(Tunable::new(\n                \"conv_transpose2d_col2im\",\n                conv_transpose2d_col2im::<R>,\n            ))\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&input.client, &input.device),\n        &client,\n        tune_set,\n        (input, weights, bias, options),\n    )\n}\n\npub fn create_transpose2d_input<R: CubeRuntime>(\n    _key: &CubeAutotuneKey,\n    input: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    bias: &Option<CubeTensor<R>>,\n    options: &ConvTransposeOptions<2>,\n) -> (\n    CubeTensor<R>,\n    CubeTensor<R>,\n    Option<CubeTensor<R>>,\n    ConvTransposeOptions<2>,\n) {\n    (\n        input.clone(),\n        weights.clone(),\n        bias.clone(),\n        options.clone(),\n    )\n}\n\nfn create_key<R: CubeRuntime>(\n    input: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    bias: &Option<CubeTensor<R>>,\n    options: &ConvTransposeOptions<2>,\n) -> CubeAutotuneKey {\n    let [batch_size, in_channels, height, width] = input.meta.shape().dims();\n    let [out_channels, _, kernel_h, kernel_w] = weights.meta.shape().dims();\n    let ConvTransposeOptions {\n        stride,\n        padding,\n        dilation,\n        groups,\n        padding_out,\n    } = options.clone();\n    CubeAutotuneKey::ConvTranspose(ConvTranspose2dAutotuneKey::new(\n        [kernel_h, kernel_w],\n        stride,\n        padding,\n        padding_out,\n        dilation,\n        groups,\n        in_channels,\n        out_channels,\n        height,\n        width,\n        batch_size,\n        bias.is_some(),\n        input.dtype,\n    ))\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs",
    "content": "use cubecl::{\n    calculate_cube_count_elemwise,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, decompose_linear, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse burn_backend::{Shape, ops::ConvTransposeOptions};\n\n#[derive(CubeLaunch, CubeType)]\nstruct ConvArgs {\n    conv_stride_0: usize,\n    conv_stride_1: usize,\n    conv_stride_2: usize,\n    dilation_0: usize,\n    dilation_1: usize,\n    dilation_2: usize,\n    padding_0: usize,\n    padding_1: usize,\n    padding_2: usize,\n    groups: usize,\n}\n\n#[cube(launch, address_type = \"dynamic\")]\nfn conv_transpose3d_kernel<E: Numeric>(\n    input: &Tensor<E>,\n    weight: &Tensor<E>,\n    bias: &ComptimeOption<Tensor<E>>,\n    output: &mut LinearView<E, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    args: ConvArgs,\n    #[define(E)] _dtype: StorageType,\n) {\n    let in_channels = weight.shape(0);\n    let out_c_per_group = weight.shape(1);\n    let kernel_size_0 = weight.shape(2);\n    let kernel_size_1 = weight.shape(3);\n    let kernel_size_2 = weight.shape(4);\n\n    let stride_0_i = args.conv_stride_0 as i32;\n    let stride_1_i = args.conv_stride_1 as i32;\n    let stride_2_i = args.conv_stride_2 as i32;\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS, &out_shape);\n    let [batch, out_c_out, out_z, out_y, out_x] = *pos else {\n        unreachable!()\n    };\n\n    let groups = args.groups;\n    let in_c_per_group = in_channels / groups;\n\n    let k = out_c_out / out_c_per_group;\n    let group = k % groups;\n    let out_channel = out_c_out - out_c_per_group * group;\n\n    let in_c_start = group * in_c_per_group;\n    let in_c_end = in_c_start + in_c_per_group;\n\n    let kernel_d = (kernel_size_0 * args.dilation_0 - args.conv_stride_0) as i32;\n    let kernel_h = (kernel_size_1 * args.dilation_1 - args.conv_stride_1) as i32;\n    let kernel_w = (kernel_size_2 * args.dilation_2 - args.conv_stride_2) as i32;\n\n    let z_start = ((out_z + args.padding_0) as i32 - kernel_d) / stride_0_i;\n    let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i;\n    let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i;\n\n    let z_end = clamp(kernel_d + z_start + 1, 0, input.shape(2) as i32) as usize;\n    let y_end = clamp(kernel_h + y_start + 1, 0, input.shape(3) as i32) as usize;\n    let x_end = clamp(kernel_w + x_start + 1, 0, input.shape(4) as i32) as usize;\n\n    let z_start = clamp_min(z_start, 0) as usize;\n    let y_start = clamp_min(y_start, 0) as usize;\n    let x_start = clamp_min(x_start, 0) as usize;\n\n    let index_input_batch = batch * input.stride(0);\n    let index_weight_out_c = out_channel * weight.stride(1);\n\n    let bias: ComptimeOption<E> = bias.map(|bias| bias[out_c_out]);\n    let mut sum = bias.unwrap_or_default();\n\n    let numerator_d_base = out_z + args.padding_0;\n    let numerator_h_base = out_y + args.padding_1;\n    let numerator_w_base = out_x + args.padding_2;\n\n    for in_c in in_c_start..in_c_end {\n        let index_input_in_c = in_c * input.stride(1);\n        let index_weight_in_c = in_c * weight.stride(0);\n\n        for in_z in z_start..z_end {\n            let numerator_tmp = in_z * args.conv_stride_0;\n            let numerator_d = numerator_d_base - numerator_tmp;\n\n            if numerator_d_base >= numerator_tmp && numerator_d.is_multiple_of(args.dilation_0) {\n                let kernel_z = numerator_d / args.dilation_0;\n                let index_input_z = in_z * input.stride(2);\n                let index_weight_kz = kernel_z * weight.stride(2);\n\n                for in_y in y_start..y_end {\n                    let numerator_tmp = in_y * args.conv_stride_1;\n                    let numerator_h = numerator_h_base - numerator_tmp;\n\n                    if numerator_h_base >= numerator_tmp\n                        && numerator_h.is_multiple_of(args.dilation_1)\n                    {\n                        let kernel_y = numerator_h / args.dilation_1;\n                        let index_input_y = in_y * input.stride(3);\n                        let index_weight_ky = kernel_y * weight.stride(3);\n\n                        for in_x in x_start..x_end {\n                            let numerator_tmp = in_x * args.conv_stride_2;\n                            let numerator_w = numerator_w_base - numerator_tmp;\n\n                            if numerator_w_base >= numerator_tmp\n                                && numerator_w.is_multiple_of(args.dilation_2)\n                            {\n                                let kernel_x = numerator_w / args.dilation_2;\n                                let index_input_x = in_x * input.stride(4);\n                                let index_weight_kx = kernel_x * weight.stride(4);\n\n                                let index_input = index_input_batch\n                                    + index_input_in_c\n                                    + index_input_z\n                                    + index_input_y\n                                    + index_input_x;\n\n                                let index_weight = index_weight_in_c\n                                    + index_weight_out_c\n                                    + index_weight_kz\n                                    + index_weight_ky\n                                    + index_weight_kx;\n\n                                let value = input[index_input];\n                                let weight = weight[index_weight];\n\n                                sum += value * weight;\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    output[ABSOLUTE_POS] = sum;\n}\n\npub(crate) fn conv_transpose3d<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvTransposeOptions<3>,\n) -> Result<CubeTensor<R>, LaunchError> {\n    let [batch_size, _, in_depth, in_height, in_width] = input.meta.shape().dims();\n    let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.meta.shape().dims();\n\n    let out_0 = (in_depth - 1) * options.stride[0]\n        + options.dilation[0] * (kernel_0 - 1)\n        + options.padding_out[0]\n        - 2 * options.padding[0]\n        + 1;\n    let out_1 = (in_height - 1) * options.stride[1]\n        + options.dilation[1] * (kernel_1 - 1)\n        + options.padding_out[1]\n        - 2 * options.padding[1]\n        + 1;\n    let out_2 = (in_width - 1) * options.stride[2]\n        + options.dilation[2] * (kernel_2 - 1)\n        + options.padding_out[2]\n        - 2 * options.padding[2]\n        + 1;\n\n    let shape_out = Shape::new([\n        batch_size,\n        out_channels * options.groups,\n        out_0,\n        out_1,\n        out_2,\n    ]);\n\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        shape_out.clone(),\n        input.dtype,\n    );\n\n    let num_elems = output.meta.num_elements();\n    let cube_dim = CubeDim::new(&input.client, num_elems);\n    let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim);\n\n    let dtype = input.dtype;\n    conv_transpose3d_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, weight, bias, output),\n        input.into_tensor_arg(),\n        weight.into_tensor_arg(),\n        bias.map(|bias| bias.into_tensor_arg()).into(),\n        output.clone().into_linear_view(),\n        shape_divmod(&output),\n        ConvArgsLaunch::new(\n            options.stride[0],\n            options.stride[1],\n            options.stride[2],\n            options.dilation[0],\n            options.dilation[1],\n            options.dilation[2],\n            options.padding[0],\n            options.padding[1],\n            options.padding[2],\n            options.groups,\n        ),\n        dtype.into(),\n    );\n\n    Ok(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs",
    "content": "use cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod};\nuse cubek::convolution::components::ConvSetupError;\n\nuse burn_backend::{\n    Shape,\n    ops::{DeformConvOptions, conv::calculate_conv_output_size},\n};\n\nuse crate::{\n    CubeRuntime,\n    kernel::{\n        AddOp, into_contiguous_aligned, launch_binop,\n        matmul::{MatmulStrategy, matmul},\n        utils::address_type,\n    },\n    ops::{numeric::zeros_client, reshape, swap_dims},\n    tensor::CubeTensor,\n};\n\n#[derive(CubeLaunch, CubeType)]\nstruct DeformConv2dArgs {\n    conv_stride_h: usize,\n    conv_stride_w: usize,\n    dilation_h: usize,\n    dilation_w: usize,\n    padding_h: InputScalar,\n    padding_w: InputScalar,\n    offset_groups: usize,\n\n    kernel_height: usize,\n    kernel_width: usize,\n    out_h: usize,\n    out_w: usize,\n}\n\n#[cube(launch, address_type = \"dynamic\")]\nfn deform_im2col_kernel<F: Float>(\n    input: &Tensor<F>,\n    offset: &Tensor<F>,\n    mask: &ComptimeOption<Tensor<F>>,\n    columns: &mut Tensor<F>,\n    pos_shape: Sequence<FastDivmod<usize>>,\n    args: &DeformConv2dArgs,\n    #[comptime] kernel_h_unroll: Option<usize>,\n    #[comptime] kernel_w_unroll: Option<usize>,\n    #[define(F)] _dtype: StorageType,\n) {\n    // position shape: [in_channels, batch_size, out_h, out_w]\n    // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]\n\n    let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height);\n    let unroll_h = kernel_h_unroll.is_some();\n    let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width);\n    let unroll_w = kernel_w_unroll.is_some();\n\n    let out_h = args.out_h;\n    let out_w = args.out_w;\n    let in_channels = input.shape(1);\n    let height = input.shape(2);\n    let width = input.shape(3);\n    let col_stride_0 = columns.stride(0);\n\n    let (rem, out_x) = pos_shape[3].div_mod(ABSOLUTE_POS);\n    let (rem, out_y) = pos_shape[2].div_mod(rem);\n    let (in_channel, batch) = pos_shape[1].div_mod(rem);\n\n    if in_channel >= in_channels {\n        terminate!()\n    }\n\n    let out_k_base = in_channel * kernel_height * kernel_width;\n    let out_n = batch * out_h * out_w + out_y * out_w + out_x;\n\n    let channels_per_offset_group = in_channels / args.offset_groups;\n    let group_index = in_channel / channels_per_offset_group;\n\n    let mut col_base_idx = out_k_base * columns.stride(0) + out_n * columns.stride(1);\n\n    let input_base_idx = batch * input.stride(0) + in_channel * input.stride(1);\n\n    let offset_base_idx = batch * offset.stride(0)\n        + group_index * kernel_height * kernel_width * 2 * offset.stride(1);\n\n    let mask_base_idx = mask.as_ref().map(|mask| {\n        batch * mask.stride(0) + group_index * kernel_height * kernel_width * mask.stride(1)\n    });\n\n    #[unroll(unroll_h)]\n    for kernel_y in 0..kernel_height {\n        #[unroll(unroll_w)]\n        for kernel_x in 0..kernel_width {\n            let mask_index = kernel_y * kernel_width + kernel_x;\n            let offset_index = mask_index * 2;\n\n            let offset_y = offset[offset_base_idx\n                + offset_index * offset.stride(1)\n                + out_y * offset.stride(2)\n                + out_x * offset.stride(3)];\n            let offset_x = offset[offset_base_idx\n                + (offset_index + 1) * offset.stride(1)\n                + out_y * offset.stride(2)\n                + out_x * offset.stride(3)];\n            let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h)\n                - args.padding_h.get::<F>()\n                + offset_y;\n            let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w)\n                - args.padding_w.get::<F>()\n                + offset_x;\n\n            let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx);\n            #[comptime]\n            let value = match mask.zip::<usize>(mask_base_idx) {\n                ComptimeOption::Some((mask, base_idx)) => {\n                    let mask_value = mask[base_idx\n                        + mask_index * mask.stride(1)\n                        + out_y * mask.stride(2)\n                        + out_x * mask.stride(3)];\n                    mask_value * interpolated\n                }\n                ComptimeOption::None => interpolated,\n            };\n\n            columns[col_base_idx] = value;\n            col_base_idx += col_stride_0;\n        }\n    }\n}\n\n#[cube]\npub(crate) fn bilinear_interpolate<F: Float>(\n    input: &Tensor<F>,\n    height: usize,\n    width: usize,\n    y: F,\n    x: F,\n    offset: usize,\n) -> F {\n    // To simplify code\n    let y = f32::cast_from(y);\n    let x = f32::cast_from(x);\n    let stride_y = input.stride(2);\n    let stride_x = input.stride(3);\n\n    let mut result = F::new(0.0);\n    if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {\n        let y_low = y.floor();\n        let x_low = x.floor();\n        let y_high = (y_low + 1.) as usize;\n        let x_high = (x_low + 1.) as usize;\n\n        let zero = F::new(0.0);\n        let v1: F = if y_low >= 0. && x_low >= 0. {\n            input[offset + y_low as usize * stride_y + x_low as usize * stride_x]\n        } else {\n            zero\n        };\n        let v2: F = if y_low >= 0. && x_high < width {\n            input[offset + y_low as usize * stride_y + x_high * stride_x]\n        } else {\n            zero\n        };\n        let v3: F = if y_high < height && x_low >= 0. {\n            input[offset + y_high * stride_y + x_low as usize * stride_x]\n        } else {\n            zero\n        };\n        let v4: F = if y_high < height && x_high < width {\n            input[offset + y_high * stride_y + x_high * stride_x]\n        } else {\n            zero\n        };\n\n        let l_y = y - y_low;\n        let l_x = x - x_low;\n        let h_y = 1.0 - l_y;\n        let h_x = 1.0 - l_x;\n\n        let w1 = F::cast_from(h_y * h_x);\n        let w2 = F::cast_from(h_y * l_x);\n        let w3 = F::cast_from(l_y * h_x);\n        let w4 = F::cast_from(l_y * l_x);\n\n        result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;\n    }\n    result\n}\n\npub(crate) fn deform_im2col<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    options: DeformConvOptions<2>,\n    out_dims: (usize, usize),\n    kernel_dims: (usize, usize),\n) -> Result<CubeTensor<R>, LaunchError> {\n    let client = input.client.clone();\n    let device = input.device.clone();\n    let dtype = input.dtype;\n\n    let [batch_size, in_channels, _, _] = input.meta.shape().dims();\n    let (out_height, out_width) = out_dims;\n    let (kernel_height, kernel_width) = kernel_dims;\n\n    let shape_out = Shape::new([\n        in_channels * kernel_height * kernel_width,\n        batch_size * out_height * out_width,\n    ]);\n\n    let pos_shape = [in_channels, batch_size, out_height, out_width]\n        .into_iter()\n        .collect();\n\n    let output = zeros_client(client.clone(), device.clone(), shape_out.clone(), dtype);\n\n    let num_kernels = in_channels * batch_size * out_height * out_width;\n    let cube_dim = CubeDim::new(&input.client, num_kernels);\n    let cube_count = calculate_cube_count_elemwise(&input.client, num_kernels, cube_dim);\n\n    deform_im2col_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, offset, mask, output),\n        input.into_tensor_arg(),\n        offset.into_tensor_arg(),\n        mask.map(|mask| mask.into_tensor_arg()).into(),\n        output.clone().binding().into_tensor_arg(),\n        pos_shape,\n        DeformConv2dArgsLaunch::new(\n            options.stride[0],\n            options.stride[1],\n            options.dilation[0],\n            options.dilation[1],\n            {\n                let val = options.padding[0] as f32;\n                InputScalar::new(val, dtype)\n            },\n            {\n                let val = options.padding[1] as f32;\n                InputScalar::new(val, dtype)\n            },\n            options.offset_groups,\n            kernel_height,\n            kernel_width,\n            out_height,\n            out_width,\n        ),\n        Some(kernel_height),\n        Some(kernel_width),\n        dtype.into(),\n    );\n\n    Ok(output)\n}\n\npub(crate) fn deform_conv2d<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    bias: Option<CubeTensor<R>>,\n    options: DeformConvOptions<2>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let input = into_contiguous_aligned(input);\n    let offset = into_contiguous_aligned(offset);\n    let weight = into_contiguous_aligned(weight);\n    let mask = mask.map(|it| into_contiguous_aligned(it));\n    let bias = bias.map(|it| into_contiguous_aligned(it));\n\n    let [batch_size, _, in_height, in_width] = input.meta.shape().dims();\n    let [out_channels, _, kernel_h, kernel_w] = weight.meta.shape().dims();\n    let groups = options.weight_groups;\n\n    let out_h = calculate_conv_output_size(\n        kernel_h,\n        options.stride[0],\n        options.padding[0],\n        options.dilation[0],\n        in_height,\n    );\n    let out_w = calculate_conv_output_size(\n        kernel_w,\n        options.stride[1],\n        options.padding[1],\n        options.dilation[1],\n        in_width,\n    );\n    let out_dims = (out_h, out_w);\n\n    let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w))?;\n\n    let [col_size_0, col_size_1] = columns.meta.shape().dims();\n    let col_size_0 = col_size_0 / groups;\n    let out_c_per_group = out_channels / groups;\n\n    let dtype = weight.dtype;\n    let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));\n    let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));\n    let out = matmul(weight, columns, None, MatmulStrategy::default(), dtype)?;\n\n    let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));\n    let out = swap_dims(out, 0, 1);\n\n    if let Some(bias) = bias {\n        let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));\n        Ok(launch_binop::<R, AddOp>(out, bias))\n    } else {\n        Ok(out)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs",
    "content": "use super::{bilinear_interpolate, deform_im2col, index};\nuse crate::{\n    CubeRuntime,\n    kernel::{\n        cast, into_contiguous_aligned,\n        matmul::{MatmulStrategy, matmul},\n        reduce::reduce_dim,\n        slice_assign,\n        utils::{address_type, decompose_linear},\n    },\n    ops::{\n        numeric::{empty_device_dtype, zeros_client},\n        reshape, swap_dims,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions};\nuse cubecl::{\n    CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube,\n    features::TypeUsage,\n    ir::FloatKind,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\nuse cubek::{\n    convolution::components::ConvSetupError,\n    reduce::components::instructions::ReduceOperationConfig,\n};\nuse std::marker::PhantomData;\n\n/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.\n#[allow(\n    clippy::single_range_in_vec_init,\n    clippy::type_complexity,\n    clippy::too_many_arguments\n)]\npub(crate) fn deform_conv2d_backward<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    bias: Option<CubeTensor<R>>,\n    out_grad: CubeTensor<R>,\n    options: DeformConvOptions<2>,\n) -> Result<\n    (\n        CubeTensor<R>,\n        CubeTensor<R>,\n        CubeTensor<R>,\n        Option<CubeTensor<R>>,\n        Option<CubeTensor<R>>,\n    ),\n    ConvSetupError,\n> {\n    let [_, _, out_h, out_w] = out_grad.meta.shape().dims();\n    let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims();\n\n    let gradient_bias = bias.map(|bias| {\n        let grad = reduce_dim(\n            out_grad.clone(),\n            None,\n            0,\n            Default::default(),\n            ReduceOperationConfig::Sum,\n        )\n        .unwrap();\n        let grad = reduce_dim(\n            grad,\n            None,\n            2,\n            Default::default(),\n            ReduceOperationConfig::Sum,\n        )\n        .unwrap();\n        let grad = reduce_dim(\n            grad,\n            None,\n            3,\n            Default::default(),\n            ReduceOperationConfig::Sum,\n        )\n        .unwrap();\n\n        reshape(grad, bias.meta.shape.clone())\n    });\n\n    let input = into_contiguous_aligned(input);\n    let offset = into_contiguous_aligned(offset);\n    let weight = into_contiguous_aligned(weight);\n    let mask = mask.map(|it| into_contiguous_aligned(it));\n\n    let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(\n        input.clone(),\n        weight.clone(),\n        offset.clone(),\n        mask.clone(),\n        out_grad.clone(),\n        &options,\n        (kernel_h, kernel_w),\n    )?;\n\n    let weight_grad = compute_weight_grad(\n        input,\n        offset,\n        mask,\n        out_grad,\n        options,\n        (kernel_h, kernel_w),\n        (out_h, out_w),\n    )?;\n\n    Ok((\n        input_gradient,\n        offset_gradient,\n        weight_grad,\n        mask_gradient,\n        gradient_bias,\n    ))\n}\n\nfn compute_weight_grad<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    out_grad: CubeTensor<R>,\n    options: DeformConvOptions<2>,\n    kernel_dims: (usize, usize),\n    out_dims: (usize, usize),\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let [_, in_channels, _, _] = input.meta.shape().dims();\n    let [_, out_channels, _, _] = out_grad.meta.shape().dims();\n    let (kernel_h, kernel_w) = kernel_dims;\n    let groups = options.weight_groups;\n    let dtype = input.dtype;\n\n    let in_c_per_group = in_channels / groups;\n    let out_c_per_group = out_channels / groups;\n\n    let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?;\n    let [col_size_0, col_size_1] = columns.meta.shape().dims();\n    let col_size_0 = col_size_0 / groups;\n\n    let out_grad = swap_dims(out_grad, 0, 1);\n    let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1]));\n\n    let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));\n    let columns = swap_dims(columns, 1, 2);\n\n    let grad_weight = matmul(out_grad, columns, None, MatmulStrategy::default(), dtype)?;\n\n    Ok(reshape(\n        grad_weight,\n        Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]),\n    ))\n}\n\ntype InputGradients<R> = (CubeTensor<R>, CubeTensor<R>, Option<CubeTensor<R>>);\n\nfn backward_gradient_inputs<R: CubeRuntime>(\n    image: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    out_grad: CubeTensor<R>,\n    options: &DeformConvOptions<2>,\n    kernel_dims: (usize, usize),\n) -> Result<InputGradients<R>, ConvSetupError> {\n    let client = out_grad.client.clone();\n    let device = out_grad.device.clone();\n\n    let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims();\n    let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims();\n\n    let groups = options.weight_groups;\n    let out_c_per_group = out_channels / groups;\n\n    let col_shape_0 = in_c_per_group * kernel_h * kernel_w;\n    let col_shape_1 = batch_size * out_h * out_w;\n    let col_shape = Shape::new([groups, col_shape_0, col_shape_1]);\n    let mut columns = empty_device_dtype(client, device, col_shape, weight.dtype);\n\n    let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));\n\n    let out_grad = swap_dims(out_grad, 0, 1);\n    let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]);\n    let out_grad = reshape(out_grad, out_grad_shape);\n\n    for group in 0..groups {\n        let dtype = weight.dtype;\n        let weight = swap_dims(index(weight.clone(), group), 0, 1);\n        let out_grad = index(out_grad.clone(), group);\n        let values = matmul(weight, out_grad, None, MatmulStrategy::default(), dtype)?;\n        let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));\n        columns = slice_assign(\n            columns,\n            &[\n                burn_backend::Slice::from(group..group + 1),\n                burn_backend::Slice::from(0..col_shape_0),\n                burn_backend::Slice::from(0..col_shape_1),\n            ],\n            values,\n        );\n    }\n\n    let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));\n\n    let input_shape = image.shape();\n    let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(\n        columns.clone(),\n        image,\n        offset.clone(),\n        mask.clone(),\n        options,\n        kernel_dims,\n    )?;\n\n    let input_gradient =\n        compute_input_grad(columns, offset, mask, options, kernel_dims, input_shape)?;\n\n    Ok((input_gradient, offset_gradient, mask_gradient))\n}\n\nfn compute_offset_and_mask_gradient<R: CubeRuntime>(\n    columns: CubeTensor<R>,\n    image: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    options: &DeformConvOptions<2>,\n    kernel_dims: (usize, usize),\n) -> Result<(CubeTensor<R>, Option<CubeTensor<R>>), ConvSetupError> {\n    let client = offset.client.clone();\n    let device = offset.device.clone();\n    let (kernel_h, kernel_w) = kernel_dims;\n\n    let [batches, _, out_h, out_w] = offset.meta.shape().dims();\n    let offset_groups = options.offset_groups;\n\n    let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w];\n    let pos_shape = pos_shape.into_iter().collect();\n\n    let grad_offset =\n        empty_device_dtype(client.clone(), device.clone(), offset.shape(), offset.dtype);\n    let grad_mask = mask\n        .as_ref()\n        .map(|mask| empty_device_dtype(client.clone(), device.clone(), mask.shape(), mask.dtype));\n\n    let num_elements_offset = offset.meta.num_elements();\n    let cube_dim = CubeDim::new(&image.client, num_elements_offset);\n    let cube_count = calculate_cube_count_elemwise(&image.client, num_elements_offset, cube_dim);\n\n    let dtype: StorageType = image.dtype.into();\n    unsafe {\n        deform_col2img_coord_kernel::launch_unchecked(\n            &grad_offset.client,\n            cube_count,\n            cube_dim,\n            address_type!(image, offset, mask, grad_offset, grad_mask),\n            image.into_tensor_arg(),\n            offset.into_tensor_arg(),\n            mask.map(|mask| mask.into_tensor_arg()).into(),\n            columns.into_tensor_arg(),\n            grad_offset.clone().into_linear_view(),\n            grad_mask\n                .clone()\n                .map(|grad_mask| grad_mask.into_tensor_arg())\n                .into(),\n            pos_shape,\n            DeformConv2dCol2ImgCoordArgsLaunch::new(\n                options.stride[0],\n                options.stride[1],\n                options.dilation[0],\n                options.dilation[1],\n                InputScalar::new(options.padding[0] as f32, dtype.elem_type()),\n                InputScalar::new(options.padding[1] as f32, dtype.elem_type()),\n                offset_groups,\n                kernel_h,\n                kernel_w,\n            ),\n            dtype,\n        )\n    };\n\n    Ok((grad_offset, grad_mask))\n}\n\n#[derive(CubeLaunch, CubeType)]\nstruct DeformConv2dCol2ImgCoordArgs {\n    stride_h: usize,\n    stride_w: usize,\n    dilation_h: usize,\n    dilation_w: usize,\n    pad_h: InputScalar,\n    pad_w: InputScalar,\n    offset_groups: usize,\n    kernel_height: usize,\n    kernel_width: usize,\n}\n\n#[allow(clippy::collapsible_if)]\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn deform_col2img_coord_kernel<F: Float>(\n    image: &Tensor<F>,\n    offset: &Tensor<F>,\n    mask: &ComptimeOption<Tensor<F>>,\n    columns: &Tensor<F>,\n    grad_offset: &mut LinearView<F, ReadWrite>,\n    grad_mask: &mut ComptimeOption<Tensor<F>>,\n    pos_shape: Sequence<FastDivmod<usize>>,\n    args: &DeformConv2dCol2ImgCoordArgs,\n    #[define(F)] _dtype: StorageType,\n) {\n    // Position format: [batch, [offset_groups, kernel_h, kernel_w, 2], out_h, out_w]\n    // Columns format: [[in_channel, kernel_h, kernel_w], [batch, out_h, out_w]]\n    // Alternatively : [batch, offset_channels, out_h, out_w]\n\n    if ABSOLUTE_POS >= grad_offset.shape() {\n        terminate!();\n    }\n\n    let out_h = offset.shape(2);\n    let out_w = offset.shape(3);\n    let in_channels = image.shape(1);\n    let height = image.shape(2);\n    let width = image.shape(3);\n    let kernel_w = args.kernel_width;\n    let kernel_h = args.kernel_height;\n\n    let mut grad_offset_val = F::new(0.0);\n    let mut grad_mask_val = F::new(0.0);\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);\n    let [batch, offset_group, kernel_y, kernel_x, dir, out_y, out_x] = *pos else {\n        unreachable!()\n    };\n\n    let channels_per_offset_group = in_channels / args.offset_groups;\n\n    let col_n = batch * out_h * out_w + out_y * out_w + out_x;\n\n    let col_base_idx =\n        offset_group * channels_per_offset_group * kernel_h * kernel_w * columns.stride(0)\n            + col_n * columns.stride(1);\n    let mut image_base_idx =\n        batch * image.stride(0) + offset_group * channels_per_offset_group * image.stride(1);\n\n    let offset_pos_1 =\n        offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;\n    let offset_base_idx = batch * offset.stride(0)\n        + offset_pos_1 * offset.stride(1)\n        + out_y * offset.stride(2)\n        + out_x * offset.stride(3);\n\n    let offset_y_idx = offset_base_idx;\n    let offset_x_idx = offset_base_idx + offset.stride(1);\n\n    let offset_y = offset[offset_y_idx];\n    let offset_x = offset[offset_x_idx];\n\n    let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;\n    #[comptime]\n    let mask_value = match &mask {\n        ComptimeOption::Some(mask) => {\n            let mask_idx = batch * mask.stride(0)\n                + mask_pos_1 * mask.stride(1)\n                + out_y * mask.stride(2)\n                + out_x * mask.stride(3);\n            mask[mask_idx]\n        }\n        ComptimeOption::None => F::new(1.0),\n    };\n\n    let is_y_direction = dir == 0;\n\n    for col_c in 0..channels_per_offset_group {\n        let col_pos = col_base_idx + col_c * kernel_h * kernel_w * columns.stride(0);\n\n        let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)\n            - args.pad_h.get::<F>()\n            + offset_y;\n        let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)\n            - args.pad_w.get::<F>()\n            + offset_x;\n\n        let weight =\n            get_coordinate_weight(image, image_base_idx, height, width, y, x, is_y_direction);\n        let columns_value = columns[col_pos];\n\n        grad_offset_val += mask_value * weight * columns_value;\n\n        if grad_mask.is_some() && is_y_direction {\n            grad_mask_val +=\n                columns_value * bilinear_interpolate(image, height, width, y, x, image_base_idx);\n        }\n\n        image_base_idx += image.stride(1);\n    }\n\n    grad_offset[ABSOLUTE_POS] = grad_offset_val;\n\n    #[comptime]\n    if let ComptimeOption::Some(grad_mask) = grad_mask {\n        if is_y_direction {\n            let idx = batch * grad_mask.stride(0)\n                + mask_pos_1 * grad_mask.stride(1)\n                + out_y * grad_mask.stride(2)\n                + out_x * grad_mask.stride(3);\n\n            grad_mask[idx] = grad_mask_val\n        }\n    }\n}\n\n#[cube]\nfn get_coordinate_weight<F: Float>(\n    input: &Tensor<F>,\n    offset: usize,\n    height: usize,\n    width: usize,\n    y: F,\n    x: F,\n    is_y_direction: bool,\n) -> F {\n    let stride_y = input.stride(2);\n    let stride_x = input.stride(3);\n\n    let y = f32::cast_from(y);\n    let x = f32::cast_from(x);\n\n    let y_low = f32::floor(y);\n    let x_low = f32::floor(x);\n    let y_high = y_low + 1.;\n    let x_high = x_low + 1.;\n\n    let valid_y_low = y_low >= 0. && y_low < height as f32;\n    let valid_y_high = y_high >= 0. && y_high < height as f32;\n    let valid_x_low = x_low >= 0. && x_low < width as f32;\n    let valid_x_high = x_high >= 0. && x_high < width as f32;\n\n    let bottom_left = if valid_y_low && valid_x_low {\n        input[offset + y_low as usize * stride_y + x_low as usize * stride_x]\n    } else {\n        F::new(0.0)\n    };\n    let bottom_right = if valid_y_low && valid_x_high {\n        input[offset + y_low as usize * stride_y + x_high as usize * stride_x]\n    } else {\n        F::new(0.0)\n    };\n    let top_left = if valid_y_high && valid_x_low {\n        input[offset + y_high as usize * stride_y + x_low as usize * stride_x]\n    } else {\n        F::new(0.0)\n    };\n    let top_right = if valid_y_high && valid_x_high {\n        input[offset + y_high as usize * stride_y + x_high as usize * stride_x]\n    } else {\n        F::new(0.0)\n    };\n\n    if is_y_direction {\n        let delta_x = F::cast_from(x - x_low);\n        delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)\n    } else {\n        let delta_y = F::cast_from(y - y_low);\n        delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)\n    }\n}\n\nfn compute_input_grad<R: CubeRuntime>(\n    columns: CubeTensor<R>,\n    offset: CubeTensor<R>,\n    mask: Option<CubeTensor<R>>,\n    options: &DeformConvOptions<2>,\n    kernel_dims: (usize, usize),\n    input_shape: Shape,\n) -> Result<CubeTensor<R>, LaunchError> {\n    let client = offset.client.clone();\n    let device = offset.device.clone();\n\n    let supports_fadd = client\n        .properties()\n        .type_usage(StorageType::Atomic(FloatKind::F32.into()))\n        .contains(TypeUsage::AtomicAdd);\n    let supports_same_type = client\n        .properties()\n        .type_usage(StorageType::Atomic(columns.dtype.into()))\n        .contains(TypeUsage::AtomicAdd);\n\n    let [batches, in_channels, height, width] = input_shape.dims();\n    let [_, _, out_h, out_w] = offset.meta.shape().dims();\n    let (kernel_h, kernel_w) = kernel_dims;\n\n    let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w];\n    let pos_shape = pos_shape.into_iter().collect();\n\n    let shape = Shape::new([batches, in_channels, height, width]);\n    let grad_in = match supports_fadd && supports_same_type {\n        // Use type as is to save a cast\n        true => zeros_client(client.clone(), device.clone(), shape, columns.dtype),\n        // Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported\n        false => zeros_client(client.clone(), device.clone(), shape, DType::F32),\n    };\n    let grad_arg = grad_in.clone().into_tensor_arg();\n\n    let num_elements = columns.meta.num_elements();\n    let cube_dim = CubeDim::new(&offset.client, num_elements);\n    let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim);\n\n    let launch = match supports_fadd {\n        true => deform_col2img_kernel::launch_unchecked::<IntrinsicFloatAtomicAddFamily, R>,\n        false => deform_col2img_kernel::launch_unchecked::<CASFloatAtomicAdd, R>,\n    };\n    let dtype = offset.dtype;\n    let dtypes: [StorageType; 2] = match supports_same_type {\n        true => [dtype.into(), dtype.into()],\n        false => [dtype.into(), DType::F32.into()],\n    };\n\n    unsafe {\n        launch(\n            &grad_in.client,\n            cube_count,\n            cube_dim,\n            address_type!(offset, mask, columns, grad_in),\n            offset.into_tensor_arg(),\n            mask.map(|mask| mask.into_tensor_arg()).into(),\n            columns.into_linear_view(),\n            grad_arg,\n            pos_shape,\n            DeformConv2dCol2ImgArgsLaunch::new(\n                options.stride[0],\n                options.stride[1],\n                options.dilation[0],\n                options.dilation[1],\n                InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()),\n                InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()),\n                options.offset_groups,\n                kernel_h,\n                kernel_w,\n            ),\n            dtypes,\n        )\n    };\n\n    Ok(if !supports_same_type || !supports_fadd {\n        cast(grad_in, dtype)\n    } else {\n        grad_in\n    })\n}\n\n#[derive(CubeLaunch, CubeType)]\nstruct DeformConv2dCol2ImgArgs {\n    stride_h: usize,\n    stride_w: usize,\n    dilation_h: usize,\n    dilation_w: usize,\n    pad_h: InputScalar,\n    pad_w: InputScalar,\n    offset_groups: usize,\n    kernel_height: usize,\n    kernel_width: usize,\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn deform_col2img_kernel<F: Float, FP: Float, FAdd: FloatAtomicAddFamily>(\n    offset: &Tensor<F>,\n    mask: &ComptimeOption<Tensor<F>>,\n    columns: &LinearView<F>,\n    grad_input: &mut Tensor<Atomic<ProxyType<FAdd, FP>>>,\n    pos_shape: Sequence<FastDivmod<usize>>,\n    args: &DeformConv2dCol2ImgArgs,\n    #[define(F, FP)] _dtype: [StorageType; 2],\n) {\n    // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]\n    if ABSOLUTE_POS >= columns.shape() {\n        terminate!();\n    }\n\n    let n_in_channels = grad_input.shape(1);\n    let height = grad_input.shape(2);\n    let width = grad_input.shape(3);\n    let kernel_h = args.kernel_height;\n    let kernel_w = args.kernel_width;\n    let n_offset_groups = args.offset_groups;\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS, &pos_shape);\n    let [in_channel, kernel_y, kernel_x, batch, out_y, out_x] = *pos else {\n        unreachable!()\n    };\n\n    let channels_per_offset_group = n_in_channels / n_offset_groups;\n    let offset_group = in_channel / channels_per_offset_group;\n\n    let offset_pos_1 =\n        offset_group * kernel_h * kernel_w * 2 + kernel_y * kernel_w * 2 + kernel_x * 2;\n    let offset_base_idx = batch * offset.stride(0)\n        + offset_pos_1 * offset.stride(1)\n        + out_y * offset.stride(2)\n        + out_x * offset.stride(3);\n\n    let offset_y_idx = offset_base_idx;\n    let offset_x_idx = offset_base_idx + offset.stride(1);\n\n    let offset_y = offset[offset_y_idx];\n    let offset_x = offset[offset_x_idx];\n\n    #[comptime]\n    let mask_value = match mask {\n        ComptimeOption::Some(mask) => {\n            let mask_pos_1 = offset_group * kernel_h * kernel_w + kernel_y * kernel_w + kernel_x;\n            mask[batch * mask.stride(0)\n                + mask_pos_1 * mask.stride(1)\n                + out_y * mask.stride(2)\n                + out_x * mask.stride(3)]\n        }\n        ComptimeOption::None => F::new(1.0),\n    };\n\n    let y = F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h)\n        - args.pad_h.get::<F>()\n        + offset_y;\n    let x = F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w)\n        - args.pad_w.get::<F>()\n        + offset_x;\n\n    for dy in -1..=1i32 {\n        #[unroll]\n        for dx in -1..=1i32 {\n            let yp = y.floor() + F::cast_from(dy);\n            let xp = x.floor() + F::cast_from(dx);\n\n            if yp >= F::new(0.0)\n                && yp < F::cast_from(height)\n                && xp >= F::new(0.0)\n                && xp < F::cast_from(width)\n                && F::abs(y - yp) < F::new(1.0)\n                && F::abs(x - xp) < F::new(1.0)\n            {\n                let gradient_pos = batch * grad_input.stride(0)\n                    + in_channel * grad_input.stride(1)\n                    + usize::cast_from(yp) * grad_input.stride(2)\n                    + usize::cast_from(xp) * grad_input.stride(3);\n\n                let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp));\n\n                let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];\n\n                FAdd::Op::<FP>::float_atomic_add::<F>(&mut grad_input[gradient_pos], value);\n            }\n        }\n    }\n}\n\ntype ProxyType<FADF, FP> = <<FADF as FloatAtomicAddFamily>::Op<FP> as FloatAtomicAdd>::ProxyType;\n\n#[cube]\ntrait FloatAtomicAddFamily: Send + Sync + 'static {\n    type Op<ProxyType: Float>: FloatAtomicAdd;\n}\n\n#[cube]\ntrait FloatAtomicAdd: Send + Sync + 'static {\n    type ProxyType: Numeric;\n\n    fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F);\n}\n\n#[derive(CubeType)]\nstruct IntrinsicFloatAtomicAdd<F: Float> {\n    #[cube(comptime)]\n    _ty: PhantomData<F>,\n}\n\n#[derive(CubeType)]\nstruct CASFloatAtomicAdd;\n\nstruct IntrinsicFloatAtomicAddFamily;\n\nimpl FloatAtomicAddFamily for IntrinsicFloatAtomicAddFamily {\n    type Op<ProxyType: Float> = IntrinsicFloatAtomicAdd<ProxyType>;\n}\n\nimpl FloatAtomicAddFamily for CASFloatAtomicAdd {\n    type Op<ProxyType: Float> = Self;\n}\n\n#[cube]\nimpl<FAdd: Float> FloatAtomicAdd for IntrinsicFloatAtomicAdd<FAdd> {\n    type ProxyType = FAdd;\n\n    fn float_atomic_add<F: Float>(ptr: &mut Atomic<FAdd>, value: F) {\n        let value = FAdd::cast_from(value);\n        ptr.fetch_add(value);\n    }\n}\n\n#[cube]\nimpl FloatAtomicAdd for CASFloatAtomicAdd {\n    type ProxyType = u32;\n\n    fn float_atomic_add<F: Float>(ptr: &mut Atomic<Self::ProxyType>, value: F) {\n        let value = f32::cast_from(value);\n        if value != 0.0 {\n            let mut v = ptr.load();\n            loop {\n                let prev = v;\n                let v_float = f32::from_bits(v);\n                let new = (v_float + value).to_bits();\n                v = ptr.compare_exchange_weak(v, new);\n                if prev == v {\n                    break;\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/direct.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{into_contiguous_aligned, utils::address_type},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\nuse crate::{kernel::utils::decompose_linear, ops::numeric::empty_device_dtype};\nuse burn_backend::{\n    TensorMetadata,\n    ops::{ConvOptions, conv::calculate_conv_output_sizes},\n};\nuse cubecl::{\n    calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,\n    tensor_vector_size_parallel,\n};\nuse cubecl::{num_traits::Zero, std::FastDivmod};\nuse cubek::convolution::components::ConvSetupError;\n\n#[derive(CubeLaunch, CubeType, Clone)]\npub(crate) struct ConvParam {\n    pub stride: u32,\n    pub dilation: u32,\n    pub padding: i32,\n}\n\n#[derive(CubeLaunch, CubeType)]\nstruct Conv2dArgs {\n    conv_params: Sequence<ConvParam>,\n    channels_per_group: u32,\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\n#[allow(clippy::redundant_closure)]\nfn direct_conv2d_kernel<E: Numeric, NIn: Size, NOut: Size>(\n    input: &Tensor<Vector<E, NIn>>,\n    weight: &Tensor<Vector<E, NIn>>,\n    bias: ComptimeOption<Tensor<Vector<E, NOut>>>,\n    output: &mut LinearView<Vector<E, NOut>, ReadWrite>,\n    args: Conv2dArgs,\n    shape_out: Sequence<FastDivmod<u32>>,\n    shape_out_c: FastDivmod<u32>,\n    #[comptime] has_padding: bool,\n    #[define(E)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let n_spatial = comptime![shape_out.len()];\n\n    let vector_size_out = output.vector_size();\n    let pos = ABSOLUTE_POS * vector_size_out;\n\n    let in_c_per_group = weight.shape(weight.rank() - 1) as u32;\n\n    let (rem, out_c) = shape_out_c.div_mod(pos as u32);\n    let (b, spatial_pos) = decompose_linear(rem, &shape_out);\n\n    let g = out_c / args.channels_per_group;\n    let ic_start = in_c_per_group * g;\n\n    let bias: ComptimeOption<Vector<E, NOut>> =\n        bias.map(|bias| bias[out_c as usize / vector_size_out]);\n    let mut sum = bias.unwrap_or_else(|| Vector::zero());\n\n    let in_offs = b as usize * input.stride(0) + ic_start as usize;\n\n    let stride_oc = weight.stride(0);\n\n    let mut in_shape = Sequence::new();\n    let mut in_strides = Sequence::new();\n    let mut kernel_shape = Sequence::new();\n    let mut kernel_strides = Sequence::new();\n\n    #[unroll]\n    for i in 0..n_spatial {\n        in_shape.push(input.shape(i + 1) as u32);\n        in_strides.push(input.stride(i + 1));\n        kernel_shape.push(weight.shape(i + 1) as u32);\n        kernel_strides.push(weight.stride(i + 1));\n    }\n\n    let weight_offs = out_c as usize * stride_oc;\n\n    let loop_params = LoopParams {\n        out_pos: spatial_pos,\n        in_shape,\n        in_strides,\n        kernel_shape,\n        kernel_strides,\n        conv_params: args.conv_params,\n        in_c_per_group,\n        stride_oc,\n    };\n\n    kernel_loop(\n        input,\n        weight,\n        &mut sum,\n        in_offs,\n        true,\n        weight_offs,\n        &loop_params,\n        0usize,\n        has_padding,\n    );\n\n    output[ABSOLUTE_POS] = sum;\n}\n\n#[derive(CubeType, Clone)]\nstruct LoopParams {\n    out_pos: Sequence<u32>,\n    in_shape: Sequence<u32>,\n    in_strides: Sequence<usize>,\n    kernel_shape: Sequence<u32>,\n    kernel_strides: Sequence<usize>,\n    conv_params: Sequence<ConvParam>,\n\n    in_c_per_group: u32,\n    stride_oc: usize,\n}\n\n#[cube]\nfn kernel_loop<E: Numeric, NIn: Size, NOut: Size>(\n    input: &Tensor<Vector<E, NIn>>,\n    weight: &Tensor<Vector<E, NIn>>,\n    sum: &mut Vector<E, NOut>,\n    in_offs: usize,\n    in_bounds: bool,\n    weight_offs: usize,\n    params: &LoopParams,\n    #[comptime] kernel_dim: usize,\n    #[comptime] has_padding: bool,\n) {\n    if comptime![kernel_dim < params.kernel_shape.len()] {\n        let out_idx = *params.out_pos.index(kernel_dim);\n        let conv = params.conv_params.index(kernel_dim);\n        let shape = *params.in_shape.index(kernel_dim);\n        let stride = *params.in_strides.index(kernel_dim);\n        let k_stride = *params.kernel_strides.index(kernel_dim);\n\n        for pos in 0..*params.kernel_shape.index(kernel_dim) {\n            let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding;\n            let in_offs = in_offs + in_pos as usize * stride;\n            let weight_offs = weight_offs + pos as usize * k_stride;\n            let mut in_bounds = in_bounds;\n\n            if has_padding {\n                in_bounds &= in_pos >= 0 && (in_pos as u32) < shape;\n            }\n\n            kernel_loop(\n                input,\n                weight,\n                sum,\n                in_offs,\n                in_bounds,\n                weight_offs,\n                params,\n                comptime![kernel_dim + 1],\n                has_padding,\n            );\n        }\n    } else {\n        kernel_loop_inner(\n            input,\n            weight,\n            sum,\n            in_offs,\n            in_bounds,\n            weight_offs,\n            params.in_c_per_group,\n            params.stride_oc,\n        );\n    }\n}\n\n#[cube]\nfn kernel_loop_inner<E: Numeric, NIn: Size, NOut: Size>(\n    input: &Tensor<Vector<E, NIn>>,\n    weight: &Tensor<Vector<E, NIn>>,\n    sum: &mut Vector<E, NOut>,\n    in_offs: usize,\n    in_bounds: bool,\n    weight_offs: usize,\n    in_c_per_group: u32,\n    stride_oc: usize,\n) {\n    let vector_size_in = input.vector_size();\n    let vector_size_out = sum.size();\n\n    if in_bounds {\n        for in_c in range_stepped(0, in_c_per_group, vector_size_in as u32) {\n            let in_pos = in_offs + in_c as usize;\n            let mut weight_pos = weight_offs + in_c as usize;\n\n            let val = input[in_pos / vector_size_in];\n\n            #[unroll]\n            for v in 0..vector_size_out {\n                let weight = weight[weight_pos / vector_size_in];\n                let val = val * weight;\n\n                #[unroll]\n                for i in 0..vector_size_in {\n                    sum[v] += val[i];\n                }\n                weight_pos += stride_oc;\n            }\n        }\n    }\n}\n\n/// Perform a 2D convolution using the direct convolution algorithm.\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\n///\npub fn conv_direct<R: CubeRuntime, const N: usize>(\n    mut input: CubeTensor<R>,\n    mut weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let out_dtype = input.dtype;\n    let rank = input.meta.shape().num_dims();\n    let dim_c = rank - 1;\n\n    // We only care about the channels here, everything else can be permuted\n    if input.meta.strides()[dim_c] != 1 {\n        input = into_contiguous_aligned(input);\n    }\n    if weight.meta.strides()[dim_c] != 1 {\n        weight = into_contiguous_aligned(weight);\n    }\n\n    let batch_size = input.meta.shape()[0];\n    let in_shape = &input.meta.shape()[1..dim_c];\n    let out_channels = weight.meta.shape()[0];\n    let kernel_shape = &weight.meta.shape()[1..dim_c];\n\n    let channels_per_group = out_channels / options.groups;\n\n    let out_size = calculate_conv_output_sizes(\n        kernel_shape,\n        &options.stride,\n        &options.padding,\n        &options.dilation,\n        in_shape,\n    );\n\n    let mut shape_out = vec![batch_size];\n    shape_out.extend(out_size.iter().copied());\n    shape_out.push(out_channels);\n\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        shape_out.into(),\n        out_dtype,\n    );\n\n    // Need custom vector size calculation here to account for the groups division. Need to vectorize\n    // over `channels_per_group` instead.\n    let mut grouped_out_shape = output.shape();\n    grouped_out_shape[dim_c] = channels_per_group;\n    let vector_size_out = tensor_vector_size_parallel(\n        input.client.io_optimized_vector_sizes(input.dtype.size()),\n        &grouped_out_shape,\n        output.meta.strides(),\n        dim_c,\n    );\n    // Use channels_per_group instead of in_channels to avoid issues here\n    let vector_size_in = max_vector_size(&weight);\n\n    let shape_out = output.meta.shape()[1..dim_c]\n        .iter()\n        .map(|s| *s as u32)\n        .collect();\n    let shape_out_c = out_channels as u32;\n\n    let mut conv_params = SequenceArg::new();\n\n    for i in 0..kernel_shape.len() {\n        conv_params.push(ConvParamLaunch::new(\n            options.stride[i] as u32,\n            options.dilation[i] as u32,\n            options.padding[i] as i32,\n        ));\n    }\n\n    let working_units = output.meta.num_elements() / vector_size_out;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    unsafe {\n        direct_conv2d_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(input, weight, bias, output),\n            vector_size_in,\n            vector_size_out,\n            input.into_tensor_arg(),\n            weight.into_tensor_arg(),\n            bias.map(|b| b.into_tensor_arg()).into(),\n            output.clone().into_linear_view(),\n            Conv2dArgsLaunch::new(conv_params, channels_per_group as u32),\n            shape_out,\n            shape_out_c,\n            options.padding.iter().any(|it| *it != 0),\n            out_dtype.into(),\n        )\n    };\n\n    Ok(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs",
    "content": "use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes};\nuse cubek::{\n    convolution::{\n        AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,\n        components::ConvSetupError, forward,\n    },\n    matmul::{\n        definition::{MatmulElems, MatmulGlobalElems},\n        launch::MatmulInputBinding,\n    },\n};\n\n/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul\n/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\npub fn conv_gemm_simple_sync<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::Cyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::Strided,\n    };\n    launch_convolution_forward::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        input,\n        weight,\n        bias,\n        options,\n    )\n}\n\npub fn conv_gemm_simple_async<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    let read_strategy = match tile_kind {\n        AcceleratedTileKind::Cmma => ReadingStrategy::AsyncCyclic,\n        AcceleratedTileKind::Mma => ReadingStrategy::AsyncStrided,\n    };\n    launch_convolution_forward::<R, N>(\n        &Strategy::Simple {\n            read_strategy,\n            tile_kind,\n        },\n        input,\n        weight,\n        bias,\n        options,\n    )\n}\n\n/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul\n/// components. Uses [`CmmaLargeMAlgorithm`] for the stage size\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\npub fn conv_gemm_simple_tma<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n    tile_kind: AcceleratedTileKind,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    launch_convolution_forward::<R, N>(\n        &Strategy::Simple {\n            read_strategy: ReadingStrategy::Tma,\n            tile_kind,\n        },\n        input,\n        weight,\n        bias,\n        options,\n    )\n}\n\n/// Perform a 2D convolution using the implicit GEMM (im2col) algorithm, using cubecl tiling matmul\n/// components, using the specified algorithm.\n///\n/// * `input` - The input feature map\n/// * `weight` - The weights (filter) applied to each kernel\n/// * `bias` - The bias added to each channel\n/// * `options` - The options to use for the convolution\npub fn launch_convolution_forward<R: CubeRuntime, const N: usize>(\n    strategy: &Strategy,\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    if options.groups != 1 {\n        return Err(ConvSetupError::Groups(options.groups));\n    }\n\n    let out_dtype = input.dtype;\n    let rank = input.meta.shape().num_dims();\n    let batch_size = input.meta.shape()[0];\n    let dim_c = rank - 1;\n    let shape = &input.meta.shape()[1..dim_c];\n\n    let out_channels = weight.meta.shape()[0];\n    let weight_shape = &weight.meta.shape()[1..dim_c];\n\n    let mut out_shape = calculate_conv_output_sizes(\n        weight_shape,\n        &options.stride,\n        &options.padding,\n        &options.dilation,\n        shape,\n    );\n\n    out_shape.insert(0, batch_size);\n    out_shape.push(out_channels);\n\n    let out = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        out_shape.into(),\n        out_dtype,\n    );\n\n    let bias = bias.map(|bias| {\n        let dtype = bias.dtype;\n        MatmulInputBinding::Normal(bias.binding(), dtype.into())\n    });\n\n    let client = input.client.clone();\n    let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {\n        lhs: input.dtype.into(),\n        rhs: weight.dtype.into(),\n        out: out_dtype.into(),\n    });\n    let input_dtype = input.dtype;\n    let weight_dtype = weight.dtype;\n    let input = MatmulInputBinding::new(input.binding(), input_dtype.into());\n    let weight = MatmulInputBinding::new(weight.binding(), weight_dtype.into());\n\n    forward::launch_ref::<R, N>(\n        strategy,\n        &client,\n        input,\n        weight,\n        bias,\n        out.clone().binding(),\n        ConvolutionArgs {\n            stride: options.stride,\n            padding: options.padding,\n            dilation: options.dilation,\n        },\n        dtypes,\n    )?;\n\n    Ok(out)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/mod.rs",
    "content": "pub mod launch;\npub use launch::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/forward/mod.rs",
    "content": "pub mod implicit_gemm;\n\n#[cfg(feature = \"autotune\")]\npub mod tune;\n\n#[cfg(feature = \"autotune\")]\npub(crate) use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/forward/tune.rs",
    "content": "use burn_backend::ops::ConvOptions;\nuse cubecl::{\n    ir::StorageType,\n    tune::{LocalTuner, Tunable, TunableSet, anchor, local_tuner},\n};\nuse cubek::convolution::AcceleratedTileKind;\n\nuse crate::{\n    CubeAutotuneKey, CubeRuntime, CubeTuneId,\n    kernel::conv::{ConvAutotuneKey, conv_direct, conv_im2col_1x1, forward::implicit_gemm::*},\n    tensor::CubeTensor,\n};\n\n/// Executes autotune on convolution operations\npub fn conv_autotune<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n) -> CubeTensor<R> {\n    let client = input.client.clone();\n\n    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        TunableSet::new(create_key::<R, N>, create_conv_input::<R, N>)\n            .with(Tunable::new(\"conv_direct\", conv_direct::<R, N>))\n            .with(Tunable::new(\"conv_im2col_1x1\", conv_im2col_1x1::<R, N>))\n            .with(Tunable::new(\n                \"simple_sync_cmma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_sync_mma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_sync(input, weight, bias, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_cmma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_async_mma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_async(input, weight, bias, options, AcceleratedTileKind::Mma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_tma_cmma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Cmma)\n                },\n            ))\n            .with(Tunable::new(\n                \"simple_tma_mma\",\n                |input, weight, bias, options| {\n                    conv_gemm_simple_tma(input, weight, bias, options, AcceleratedTileKind::Mma)\n                },\n            ))\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&input.client, &input.device),\n        &client,\n        tunables,\n        (input, weight, bias, options),\n    )\n}\n\npub fn create_conv_input<R: CubeRuntime, const N: usize>(\n    _key: &CubeAutotuneKey,\n    input: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    bias: &Option<CubeTensor<R>>,\n    options: &ConvOptions<N>,\n) -> (\n    CubeTensor<R>,\n    CubeTensor<R>,\n    Option<CubeTensor<R>>,\n    ConvOptions<N>,\n) {\n    (\n        input.clone(),\n        weights.clone(),\n        bias.clone(),\n        options.clone(),\n    )\n}\n\nfn create_key<R: CubeRuntime, const N: usize>(\n    input: &CubeTensor<R>,\n    weights: &CubeTensor<R>,\n    bias: &Option<CubeTensor<R>>,\n    options: &ConvOptions<N>,\n) -> CubeAutotuneKey {\n    let dtype = input.dtype;\n    let rank = input.meta.shape().num_dims();\n    let dim_c = rank - 1;\n\n    let batch_size = input.meta.shape()[0];\n    let in_channels = input.meta.shape()[dim_c];\n    let out_channels = weights.meta.shape()[0];\n\n    let kernel_size = weights.meta.shape()[1..dim_c].to_vec();\n    let in_shape = input.meta.shape()[1..dim_c]\n        .iter()\n        .map(|shape| anchor(*shape, None, None, None))\n        .collect();\n\n    let ConvOptions {\n        stride,\n        padding,\n        dilation,\n        groups,\n    } = options.clone();\n\n    let lhs_stride_align = if input.meta.strides()[dim_c] == 1 {\n        stride_align(input.meta.strides(), input.dtype.into())\n    } else {\n        0\n    };\n    let lhs_shape_align = pow2_factor(in_channels).min(lhs_stride_align);\n    let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 {\n        stride_align(weights.meta.strides(), weights.dtype.into())\n    } else {\n        0\n    };\n    let rhs_shape_align = pow2_factor(in_channels).min(rhs_stride_align);\n\n    CubeAutotuneKey::Conv(ConvAutotuneKey::new(\n        kernel_size,\n        stride.to_vec(),\n        padding.to_vec(),\n        dilation.to_vec(),\n        groups,\n        in_channels,\n        out_channels,\n        in_shape,\n        batch_size,\n        bias.is_some(),\n        dtype,\n        lhs_shape_align,\n        lhs_stride_align,\n        rhs_shape_align,\n        rhs_stride_align,\n    ))\n}\n\n/// Maximum factor relevant for strides. Currently set to 2^10 because that's 128-byte swizzle's\n/// repeat number, so it's the largest align that can have performance impacts.\nconst MAX_STRIDE_FACTOR: u32 = 10;\n\n/// Defines the non-contiguous stride alignment in terms of powers of two\nfn stride_align(strides: &[usize], elem: StorageType) -> u8 {\n    let max = MAX_STRIDE_FACTOR;\n    let dim_c = strides.len() - 1;\n    let factor = strides[..dim_c]\n        .iter()\n        .map(|it| (*it * elem.size_bits()) / 8)\n        .map(|it| it.trailing_zeros())\n        .min()\n        .unwrap_or(max);\n    factor.min(max) as u8\n}\n\n/// Defines the potential vectorization.\nfn pow2_factor(axis: usize) -> u8 {\n    axis.trailing_zeros().min(4) as u8\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/im2col.rs",
    "content": "use burn_backend::{\n    DType,\n    ops::{ConvOptions, conv::calculate_conv_output_sizes},\n};\nuse burn_std::{Metadata, Shape};\nuse core::iter;\nuse cubecl::{\n    prelude::*,\n    std::tensor::{TensorHandle, into_contiguous_pitched},\n};\nuse cubek::convolution::components::ConvSetupError;\n\nuse crate::{\n    CubeRuntime,\n    kernel::{\n        AddOp, into_contiguous_aligned, launch_binop,\n        matmul::{MatmulStrategy, matmul},\n        utils::split_dim,\n    },\n    ops::{reshape, swap_dims},\n    tensor::CubeTensor,\n};\n\n#[cfg(not(test))]\npub(crate) fn batches_per_run(\n    batch_size: usize,\n    out_shape: usize,\n    plane_size: usize,\n) -> Result<usize, ConvSetupError> {\n    use cubek::matmul::definition::MatmulAvailabilityError;\n\n    let cube_count_per_batch = out_shape.div_ceil(plane_size);\n    let max_cube_count = u16::MAX as usize;\n    let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);\n    if max_simultaneous == 0 {\n        return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(\n            cube_count_per_batch as u32,\n            1,\n            1,\n        ))\n        .into());\n    }\n    Ok((0..=max_simultaneous)\n        .rev()\n        .find(|per_run| batch_size.is_multiple_of(*per_run))\n        .expect(\"Logically not possible\"))\n}\n\n#[cfg(test)]\n#[allow(unused)]\npub(crate) fn batches_per_run(\n    batch_size: usize,\n    out_shape: usize,\n    plane_size: usize,\n) -> Result<usize, ConvSetupError> {\n    Ok(1)\n}\n\npub fn conv_im2col_1x1<R: CubeRuntime, const N: usize>(\n    input: CubeTensor<R>,\n    mut weight: CubeTensor<R>,\n    bias: Option<CubeTensor<R>>,\n    options: ConvOptions<N>,\n) -> Result<CubeTensor<R>, ConvSetupError> {\n    if options.groups != 1 {\n        return Err(ConvSetupError::Groups(options.groups));\n    }\n\n    let rank = input.meta.num_dims();\n    let dim_c = rank - 1;\n\n    let batch_size = input.meta.shape()[0];\n    let in_channels = input.meta.shape()[dim_c];\n    let in_shape = &input.meta.shape()[1..dim_c];\n    let out_channels = weight.meta.shape()[0];\n    let kernel_shape = &weight.meta.shape()[1..dim_c];\n\n    if kernel_shape.iter().any(|s| *s != 1) {\n        return Err(ConvSetupError::Unknown);\n    }\n\n    let out_shape = calculate_conv_output_sizes(\n        kernel_shape,\n        &options.stride,\n        &options.padding,\n        &options.dilation,\n        in_shape,\n    );\n\n    let mut split_m = vec![batch_size];\n    split_m.extend(out_shape.iter().copied());\n\n    if kernel_shape.iter().any(|it| *it != 1) || in_shape != out_shape {\n        return Err(ConvSetupError::Unknown);\n    }\n\n    let input = reshape_input(input); // [(NHW), C] : [M, K]\n    let dtype = input.dtype;\n\n    // Efficient permutation that takes the stride required for TMA into account\n    let weight = if weight.meta.strides()[dim_c] != 1 {\n        // Remove kernel dims so padded dim is channels\n        *weight.meta = Metadata::new(\n            [out_channels, in_channels], // [N, K]\n            [weight.meta.strides()[0], weight.meta.strides()[dim_c]],\n        );\n        // Pitched contiguous to skip running another kernel for TMA\n        into_contiguous_aligned(weight)\n    } else {\n        // Already compatible, skip initial reshape\n        *weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]);\n        weight\n    };\n\n    // Permute to N-major, while keeping memory layout K-major. K-major for both sides is the most\n    // efficient for matmul, and allows skipping a contiguous kernel\n    let weight = swap_dims(weight, 0, 1); // [K, N]\n\n    let out = matmul(input, weight, None, MatmulStrategy::default(), dtype)?; // [M, N]\n\n    // Skip reshape to avoid potential `into_contiguous`. We're only splitting dims so it's safe.\n    let mut out = split_dim(out, 0, &split_m); // [N, H, W, C]\n\n    if let Some(bias) = bias {\n        let mut bias_shape = iter::repeat_n(1, rank - 1).collect::<Vec<_>>();\n        bias_shape.push(out_channels);\n        let bias = reshape(bias, bias_shape.into());\n        out = launch_binop::<R, AddOp>(out, bias);\n    }\n\n    Ok(out)\n}\n\n/// Reshapes NHWC input to [(N, H, W), C]\nfn reshape_input<R: CubeRuntime>(input: CubeTensor<R>) -> CubeTensor<R> {\n    let rank = input.meta.num_dims();\n    let dim_c = rank - 1;\n    let dtype = input.dtype;\n\n    let batch_size = input.meta.shape()[0];\n    let in_c: usize = input.meta.shape()[dim_c];\n    let in_shape: Shape = input.meta.shape()[1..dim_c].into();\n\n    let mut input = if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) {\n        let (client, device) = (input.client.clone(), input.device.clone());\n        let contiguous = into_contiguous_pitched(&client, input.binding(), dtype.into());\n        from_handle(client, device, contiguous, dtype)\n    } else {\n        input\n    };\n\n    *input.meta = Metadata::new(\n        [batch_size * in_shape.num_elements(), in_c], // [M, K]\n        [input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]],\n    );\n    input\n}\n\nfn is_spatial_contiguous(shape: &[usize], strides: &[usize]) -> bool {\n    let rank = shape.len();\n    let dim_c = rank - 1;\n\n    // Channel must be contiguous for the [(N, H, W), C] reshape to be valid\n    if strides[dim_c] != 1 {\n        return false;\n    }\n\n    for i in (1..dim_c).rev() {\n        if strides[i + 1] * shape[i + 1] != strides[i] {\n            return false;\n        }\n    }\n    true\n}\n\nfn from_handle<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    handle: TensorHandle<R>,\n    dtype: DType,\n) -> CubeTensor<R> {\n    CubeTensor::new(\n        client.clone(),\n        handle.handle,\n        *handle.metadata,\n        device.clone(),\n        dtype,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/mod.rs",
    "content": "mod backward_data;\nmod backward_weight;\nmod base;\nmod conv_transpose2d;\nmod conv_transpose3d;\nmod deform_conv2d;\nmod deform_conv_transpose2d;\nmod direct;\nmod forward;\nmod im2col;\n\nmod tune_key;\n\npub(crate) use backward_data::*;\npub(crate) use conv_transpose2d::*;\npub(crate) use conv_transpose3d::*;\npub(crate) use deform_conv_transpose2d::*;\npub(crate) use deform_conv2d::*;\npub(crate) use direct::*;\npub(crate) use im2col::*;\n\npub use base::*;\npub use conv_transpose2d::{ConvTranspose2dStrategy, conv_transpose2d};\n\npub(crate) use tune_key::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/conv/tune_key.rs",
    "content": "use burn_backend::DType;\nuse cubecl::AutotuneKey;\nuse serde::{Deserialize, Serialize};\n\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\n/// Autotune key representative of matmul versions\npub struct ConvAutotuneKey {\n    pub kernel_size: Vec<usize>,\n    pub stride: Vec<usize>,\n    pub padding: Vec<usize>,\n    pub dilation: Vec<usize>,\n    pub groups: usize,\n    #[autotune(anchor)]\n    pub in_channels: usize,\n    #[autotune(anchor)]\n    pub out_channels: usize,\n    pub shape: Vec<usize>,\n    #[autotune(anchor)]\n    pub batch_size: usize,\n    pub has_bias: bool,\n    pub dtype: DType,\n\n    pub lhs_shape_align: u8,\n    pub lhs_stride_align: u8,\n    pub rhs_shape_align: u8,\n    pub rhs_stride_align: u8,\n}\n\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\n/// Autotune key representative of matmul versions\npub struct ConvTranspose2dAutotuneKey {\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub padding_out: [usize; 2],\n    pub dilation: [usize; 2],\n    pub groups: usize,\n    #[autotune(anchor)]\n    pub in_channels: usize,\n    #[autotune(anchor)]\n    pub out_channels: usize,\n    #[autotune(anchor)]\n    pub height: usize,\n    #[autotune(anchor)]\n    pub width: usize,\n    #[autotune(anchor)]\n    pub batch_size: usize,\n    pub has_bias: bool,\n    pub dtype: DType,\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/cross.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse cubecl::std::tensor::layout::linear::LinearView;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn cross_kernel<E: Float>(\n    lhs: &LinearView<E>,\n    rhs: &LinearView<E>,\n    output: &mut LinearView<E, ReadWrite>,\n    #[define(E)] _dtype: StorageType,\n) {\n    // Each thread processes one 3-element vector\n    let vector_idx = ABSOLUTE_POS;\n    let base_pos = vector_idx * 3;\n\n    if !output.is_in_bounds(base_pos) {\n        terminate!();\n    }\n\n    // Extract vectors\n    let a0 = lhs[base_pos];\n    let a1 = lhs[base_pos + 1];\n    let a2 = lhs[base_pos + 2];\n    let b0 = rhs[base_pos];\n    let b1 = rhs[base_pos + 1];\n    let b2 = rhs[base_pos + 2];\n\n    // Compute cross product: a × b\n    let x = a1 * b2 - a2 * b1;\n    let y = a2 * b0 - a0 * b2;\n    let z = a0 * b1 - a1 * b0;\n\n    // Store result\n    output[base_pos] = x;\n    output[base_pos + 1] = y;\n    output[base_pos + 2] = z;\n}\n\npub(crate) fn cross<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    dim: usize,\n) -> CubeTensor<R> {\n    let ndims = lhs.meta.num_dims();\n\n    // Validate that the cross dimension has size 3\n    if lhs.meta.shape()[dim] != 3 || rhs.meta.shape()[dim] != 3 {\n        panic!(\n            \"Cross product requires dimension {} to have size 3, but got {} and {}\",\n            dim,\n            lhs.meta.shape()[dim],\n            rhs.meta.shape()[dim]\n        );\n    }\n\n    // For now, only support cross on the last dimension\n    if dim != ndims - 1 {\n        unimplemented!(\n            \"Cross product on non-last dimension not yet implemented for CubeCL backend\"\n        );\n    }\n\n    let output_shape = broadcast_shape(&[&lhs, &rhs]);\n\n    let output = empty_device_dtype(\n        lhs.client.clone(),\n        lhs.device.clone(),\n        output_shape.clone(),\n        lhs.dtype,\n    );\n\n    // Number of vectors to process\n    let num_vectors = output_shape.num_elements() / 3;\n\n    let cube_dim = CubeDim::new(&lhs.client, num_vectors);\n    let cube_count = calculate_cube_count_elemwise(&lhs.client, num_vectors, cube_dim);\n    let dtype = lhs.dtype;\n\n    unsafe {\n        cross_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(lhs, rhs, output),\n            lhs.into_linear_view_like(&output),\n            rhs.into_linear_view_like(&output),\n            output.clone().into_linear_view(),\n            dtype.into(),\n        );\n    };\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/grid_sample/base.rs",
    "content": "use cubecl::prelude::*;\n\nuse crate::{CubeRuntime, tensor::CubeTensor};\nuse burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};\n\nuse super::bilinear::grid_sample_bilinear_launch;\n\n/// Grid sample operation supporting bilinear interpolation\npub fn grid_sample<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    grid: CubeTensor<R>,\n    options: GridSampleOptions,\n) -> CubeTensor<R> {\n    match options.mode {\n        InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options),\n        _ => panic!(\n            \"Unsupported grid_sample interpolation mode: {:?}\",\n            options.mode\n        ),\n    }\n}\n\n/// Compile-time padding mode for kernel specialization\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub enum PaddingMode {\n    /// Fill with zeros for out-of-bounds coordinates.\n    Zeros,\n    /// Clamp coordinates to the border (use nearest edge value).\n    Border,\n    /// Reflect coordinates at the boundary.\n    Reflection,\n}\n\nimpl From<GridSamplePaddingMode> for PaddingMode {\n    fn from(mode: GridSamplePaddingMode) -> Self {\n        match mode {\n            GridSamplePaddingMode::Zeros => PaddingMode::Zeros,\n            GridSamplePaddingMode::Border => PaddingMode::Border,\n            GridSamplePaddingMode::Reflection => PaddingMode::Reflection,\n        }\n    }\n}\n\n/// Fetch value based on padding mode (dispatch to appropriate handler)\n#[cube]\npub(crate) fn fetch_value<F: Float>(\n    input: &Tensor<F>,\n    base: usize,\n    stride_h: usize,\n    stride_w: usize,\n    y: i32,\n    x: i32,\n    h: i32,\n    w: i32,\n    #[comptime] padding_mode: PaddingMode,\n) -> F {\n    match padding_mode {\n        PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w),\n        PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w),\n        PaddingMode::Reflection => {\n            fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)\n        }\n    }\n}\n\n/// Fetch value with zeros padding (return 0 for out-of-bounds).\n#[cube]\npub(crate) fn fetch_with_zeros<F: Float>(\n    input: &Tensor<F>,\n    base: usize,\n    stride_h: usize,\n    stride_w: usize,\n    y: i32,\n    x: i32,\n    h: i32,\n    w: i32,\n) -> F {\n    let in_bounds = x >= 0 && x < w && y >= 0 && y < h;\n    let x_clamped = clamp(x, 0, w - 1) as usize;\n    let y_clamped = clamp(y, 0, h - 1) as usize;\n    let idx = base + y_clamped * stride_h + x_clamped * stride_w;\n    select(in_bounds, input[idx], F::new(0.0))\n}\n\n/// Fetch value with border padding (clamp to edge).\n#[cube]\npub(crate) fn fetch_with_border<F: Float>(\n    input: &Tensor<F>,\n    base: usize,\n    stride_h: usize,\n    stride_w: usize,\n    y: i32,\n    x: i32,\n    h: i32,\n    w: i32,\n) -> F {\n    let x_clamped = clamp(x, 0, w - 1) as usize;\n    let y_clamped = clamp(y, 0, h - 1) as usize;\n    let idx = base + y_clamped * stride_h + x_clamped * stride_w;\n    input[idx]\n}\n\n/// Fetch value with reflection padding.\n/// Assumes float reflection was applied to center, so indices are at most 2 steps out of bounds.\n#[cube]\npub(crate) fn fetch_with_reflection<F: Float>(\n    input: &Tensor<F>,\n    base: usize,\n    stride_h: usize,\n    stride_w: usize,\n    y: i32,\n    x: i32,\n    h: i32,\n    w: i32,\n) -> F {\n    let x_reflected = reflect_coord_bounded(x, w);\n    let y_reflected = reflect_coord_bounded(y, h);\n    let idx = base + y_reflected * stride_h + x_reflected * stride_w;\n    input[idx]\n}\n\n/// Reflect an integer index that may be out of bounds.\n/// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear).\n#[cube]\nfn reflect_coord_bounded(idx: i32, size: i32) -> usize {\n    let max_idx = size - 1;\n    let neg_reflected = -idx - 1;\n    let pos_reflected = 2 * max_idx + 1 - idx;\n    let result = select(\n        idx < 0,\n        neg_reflected,\n        select(idx > max_idx, pos_reflected, idx),\n    );\n    clamp(result, 0, max_idx) as usize\n}\n\n/// Reflect a float coordinate into the valid sampling range.\n#[cube]\npub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {\n    let size_f = F::cast_from(size);\n    if align_corners {\n        reflect_float_impl::<F>(coord, F::new(0.0), size_f - F::new(1.0))\n    } else {\n        reflect_float_impl::<F>(coord, F::new(-0.5), size_f - F::new(0.5))\n    }\n}\n\n/// Reflect a float coordinate into [min_val, max_val] using a triangle wave pattern.\n#[cube]\nfn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {\n    let span = max_val - min_val;\n\n    let is_valid = span > F::new(0.0);\n    let safe_span = select(is_valid, span, F::new(1.0));\n\n    // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val\n    let period = safe_span * F::new(2.0);\n    let x = (coord - min_val).abs();\n    let x_mod = x - (x / period).floor() * period;\n    let reflected = safe_span - (x_mod - safe_span).abs() + min_val;\n\n    select(is_valid, reflected, min_val)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs",
    "content": "use cubecl::std::FastDivmod;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\nuse crate::{\n    CubeRuntime, kernel::utils::address_type, ops::numeric::empty_device_dtype, tensor::CubeTensor,\n};\nuse burn_backend::{Shape, ops::GridSampleOptions};\n\nuse super::base::{PaddingMode, fetch_value, reflect_coord};\n\n/// Grid sample with bilinear interpolation.\n///\n/// Each thread processes all channels for one spatial output position:\n/// 1. Reading (x, y) coordinates from the grid tensor (once per spatial position)\n/// 2. Converting normalized [-1, 1] coords to pixel coordinates (once)\n/// 3. For each channel: fetch 4 corner values, interpolate, and write output\n#[cube(launch, address_type = \"dynamic\")]\nfn grid_sample_bilinear_kernel<F: Float>(\n    input: &Tensor<F>,                          // [N, C, H_in, W_in]\n    grid: &Tensor<F>,                           // [N, H_out, W_out, 2]\n    output: &mut Tensor<F>,                     // [N, C, H_out, W_out]\n    shape_spatial: Sequence<FastDivmod<usize>>, // [N, H_out, W_out] for thread decomposition\n    #[comptime] align_corners: bool,\n    #[comptime] pad_mode: PaddingMode,\n    #[define(F)] _dtype: StorageType,\n) {\n    // Thread index maps to spatial position (n, h_out, w_out) only\n    let spatial_idx = ABSOLUTE_POS;\n    let num_spatial = output.shape(0) * output.shape(2) * output.shape(3);\n    if spatial_idx >= num_spatial {\n        terminate!();\n    }\n\n    // Decompose spatial index into (n, h_out, w_out)\n    let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx);\n    let (n, h_out) = shape_spatial[1].div_mod(rem);\n\n    let channels = input.shape(1) as u32;\n    let h_in = input.shape(2) as u32;\n    let w_in = input.shape(3) as u32;\n\n    // Read grid coordinates once per spatial position\n    let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2);\n    let gx = grid[grid_offset]; // x coordinate in [-1, 1]\n    let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1]\n\n    // Convert normalized coordinates to pixel coordinates\n    let (px, py) = if align_corners {\n        let px = (gx + F::new(1.0)) * F::cast_from((w_in - 1) as f32) / F::new(2.0);\n        let py = (gy + F::new(1.0)) * F::cast_from((h_in - 1) as f32) / F::new(2.0);\n        (px, py)\n    } else {\n        let px = (gx + F::new(1.0)) * F::cast_from(w_in as f32) / F::new(2.0) - F::new(0.5);\n        let py = (gy + F::new(1.0)) * F::cast_from(h_in as f32) / F::new(2.0) - F::new(0.5);\n        (px, py)\n    };\n\n    // For reflection padding, reflect the coordinate into the valid sampling range.\n    // This ensures integer indices are at most 1 step out of bounds.\n    let (px, py) = if comptime!(pad_mode == PaddingMode::Reflection) {\n        let px = reflect_coord::<F>(px, w_in, align_corners);\n        let py = reflect_coord::<F>(py, h_in, align_corners);\n        (px, py)\n    } else {\n        (px, py)\n    };\n\n    // Compute floor and ceil indices\n    let x0_f = px.floor();\n    let y0_f = py.floor();\n    let x1_f = x0_f + F::new(1.0);\n    let y1_f = y0_f + F::new(1.0);\n\n    // Compute interpolation weights\n    let wx = px - x0_f;\n    let wy = py - y0_f;\n    let wx_ = F::new(1.0) - wx;\n    let wy_ = F::new(1.0) - wy;\n\n    // Convert to integers for indexing\n    let x0 = i32::cast_from(x0_f);\n    let y0 = i32::cast_from(y0_f);\n    let x1 = i32::cast_from(x1_f);\n    let y1 = i32::cast_from(y1_f);\n\n    let w_in = w_in as i32;\n    let h_in = h_in as i32;\n\n    // Pre-compute strides\n    let stride_n = input.stride(0);\n    let stride_c = input.stride(1);\n    let stride_h = input.stride(2);\n    let stride_w = input.stride(3);\n    let out_stride_n = output.stride(0);\n    let out_stride_c = output.stride(1);\n    let out_stride_h = output.stride(2);\n    let out_stride_w = output.stride(3);\n\n    // Base offsets for this spatial position\n    let in_base_n = n * stride_n;\n    let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w;\n\n    // Loop over all channels - grid coords and weights are reused\n    for c in 0..channels {\n        let in_base = in_base_n + c as usize * stride_c;\n\n        let v00 = fetch_value(\n            input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode,\n        );\n        let v01 = fetch_value(\n            input, in_base, stride_h, stride_w, y1, x0, h_in, w_in, pad_mode,\n        );\n        let v10 = fetch_value(\n            input, in_base, stride_h, stride_w, y0, x1, h_in, w_in, pad_mode,\n        );\n        let v11 = fetch_value(\n            input, in_base, stride_h, stride_w, y1, x1, h_in, w_in, pad_mode,\n        );\n\n        // Bilinear interpolation\n        let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11;\n\n        let out_idx = out_base_spatial + c as usize * out_stride_c;\n        output[out_idx] = result;\n    }\n}\n\n/// Launch the grid sample bilinear kernel\npub(crate) fn grid_sample_bilinear_launch<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    grid: CubeTensor<R>,\n    options: GridSampleOptions,\n) -> CubeTensor<R> {\n    let [batch_size, channels, _h_in, _w_in] = input.meta.shape().dims();\n    let [_n, h_out, w_out, two] = grid.meta.shape().dims();\n    assert_eq!(two, 2, \"Grid last dimension must be 2\");\n\n    // Create output tensor [N, C, H_out, W_out]\n    let output_shape = Shape::new([batch_size, channels, h_out, w_out]);\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        output_shape,\n        input.dtype,\n    );\n\n    // Spatial threading: one thread per (n, h_out, w_out)\n    let spatial_shape = Shape::new([batch_size, h_out, w_out]);\n    let num_spatial = spatial_shape.num_elements();\n\n    let mut shape_spatial = SequenceArg::new();\n    for dim in spatial_shape.iter() {\n        shape_spatial.push(*dim);\n    }\n\n    let cube_dim = CubeDim::new(&input.client, num_spatial);\n    let cube_count = calculate_cube_count_elemwise(&input.client, num_spatial, cube_dim);\n\n    let padding_mode: PaddingMode = options.padding_mode.into();\n\n    let dtype = input.dtype;\n\n    grid_sample_bilinear_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, grid, output),\n        input.into_tensor_arg(),\n        grid.into_tensor_arg(),\n        output.clone().into_tensor_arg(),\n        shape_spatial,\n        options.align_corners,\n        padding_mode,\n        dtype.into(),\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/grid_sample/mod.rs",
    "content": "mod base;\nmod bilinear;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/flip.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, TensorMetadata};\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn flip_kernel<E: Numeric, Bool: Int>(\n    input: &Tensor<E>,\n    output: &mut LinearView<E, ReadWrite>,\n    in_shape: Sequence<FastDivmod<usize>>,\n    indices: Sequence<InputScalar>,\n    #[define(E, Bool)] _dtypes: [StorageType; 2],\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let rank = in_shape.len().comptime();\n\n    let mut offset = ABSOLUTE_POS;\n    let mut offset_input = 0;\n\n    #[unroll]\n    for i in 0..rank {\n        let dim = rank - i - 1;\n        let shape = input.shape(dim);\n\n        let (rem, offset_local) = in_shape[dim].div_mod(offset);\n        offset = rem;\n\n        let flip = indices.index(dim).get::<Bool>() == Bool::from_int(1);\n        let offset_local = select(flip, shape - offset_local - 1, offset_local);\n\n        offset_input += offset_local * input.stride(dim);\n    }\n\n    output[ABSOLUTE_POS] = input[offset_input];\n}\n\npub(crate) fn flip<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    indices: &[usize],\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        tensor.shape(),\n        tensor.dtype,\n    );\n    flip_on_output(tensor, output, indices, dtype_bool)\n}\n\npub(crate) fn flip_on_output<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    output: CubeTensor<R>,\n    indices: &[usize],\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let dtype_input = tensor.dtype;\n    let ndims = tensor.meta.num_dims();\n    let mut indices_sequence = SequenceArg::<R, InputScalar>::new();\n\n    for i in 0..ndims {\n        indices_sequence.push({\n            let val = indices.contains(&i) as u8;\n            InputScalar::new(val, dtype_bool)\n        });\n    }\n\n    let num_elements = output.meta.num_elements();\n    let cube_dim = CubeDim::new(&tensor.client, num_elements);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elements, cube_dim);\n\n    let shape = shape_divmod(&tensor);\n    unsafe {\n        flip_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, output),\n            tensor.into_tensor_arg(),\n            output.clone().into_linear_view(),\n            shape,\n            indices_sequence,\n            [dtype_input.into(), dtype_bool.into()],\n        )\n    }\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/gather.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_strides, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor};\nuse cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod};\nuse cubecl::{CubeDim, std::tensor::layout::linear::LinearView};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn gather_kernel<T: Numeric, I: Numeric>(\n    input: &Tensor<T>,\n    indices: &LinearView<I>,\n    output: &mut LinearView<T, ReadWrite>,\n    in_strides: Sequence<usize>, // zeroed out for broadcast dims and `dim`\n    out_shape: Sequence<FastDivmod<usize>>,\n    dim: usize,\n    #[define(T, I)] _dtypes: [StorageType; 2],\n) {\n    if !indices.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let mut offset = index_offset_contiguous_fastdivmod(\n        ABSOLUTE_POS,\n        &out_shape,\n        &in_strides,\n        input.vector_size(),\n    );\n\n    offset += usize::cast_from(indices[ABSOLUTE_POS]) * input.stride(dim);\n\n    output[ABSOLUTE_POS] = input[offset];\n}\n\npub(crate) fn gather<R: CubeRuntime>(\n    dim: usize,\n    tensor: CubeTensor<R>,\n    indices: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let shape_output = indices.shape();\n    let total_elem = shape_output.num_elements();\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        shape_output,\n        tensor.dtype,\n    );\n\n    let cube_dim = CubeDim::new(&tensor.client, total_elem);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, total_elem, cube_dim);\n    let mut in_strides = broadcast_strides(&output, &tensor);\n    in_strides.values[dim] = 0; // Zero `dim` to exclude it from the indexing\n\n    let (dtype, indices_dtype) = (tensor.dtype, indices.dtype);\n\n    unsafe {\n        gather_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, indices, output),\n            tensor.into_tensor_arg(),\n            indices.into_linear_view(),\n            output.clone().into_linear_view(),\n            in_strides,\n            shape_divmod(&output),\n            dim,\n            [dtype.into(), indices_dtype.into()],\n        )\n    }\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/mod.rs",
    "content": "mod flip;\nmod gather;\nmod repeat_dim;\nmod scatter;\nmod select;\nmod select_assign;\nmod slice;\nmod slice_assign;\n\npub(crate) use flip::*;\npub(crate) use repeat_dim::*;\npub(crate) use select::*;\npub(crate) use select_assign::*;\npub use slice::*;\npub(crate) use slice_assign::*;\n\npub(crate) use gather::*;\npub(crate) use scatter::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/repeat_dim.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::FastDivmod};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn repeat_dim_kernel<E: Numeric>(\n    input: &Tensor<E>,\n    output: &mut Tensor<E>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    in_shape: FastDivmod<usize>,\n    #[comptime] dim: usize,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let rank = out_shape.len().comptime();\n\n    let mut pos = ABSOLUTE_POS;\n    let mut offset_input = 0;\n    let mut offset_output = 0;\n\n    #[unroll]\n    for i in 0..rank {\n        let i = rank - i - 1;\n\n        let (rem, mut local_pos) = out_shape[i].div_mod(pos);\n        pos = rem;\n\n        offset_output += local_pos * output.stride(i);\n\n        if i == dim {\n            local_pos = in_shape.modulo(local_pos);\n        }\n\n        offset_input += local_pos * input.stride(i);\n    }\n\n    output[offset_output] = input[offset_input];\n}\n\npub(crate) fn repeat_dim<R: CubeRuntime>(\n    mut input: CubeTensor<R>,\n    dim: usize,\n    times: usize,\n) -> CubeTensor<R> {\n    if input.meta.shape()[dim] == 1 {\n        input.meta.strides[dim] = 0;\n        input.meta.shape = input.meta.shape.clone().repeat(dim, times).unwrap();\n        return input;\n    }\n\n    let shape = input.meta.shape.clone().repeat(dim, times).unwrap();\n\n    // Create output handle\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        shape,\n        input.dtype,\n    );\n\n    let working_units = output.meta.num_elements();\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    let shape_arg = input.meta.shape()[dim];\n\n    unsafe {\n        repeat_dim_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(input, output),\n            input.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n            shape_divmod(&output),\n            shape_arg,\n            dim,\n            output.dtype.into(),\n        )\n    };\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/scatter.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        AddOp, BinaryOp, BinaryOpFamily, OrOp,\n        utils::{address_type, shape_divmod},\n    },\n    tensor::CubeTensor,\n};\nuse cubecl::{CubeDim, calculate_cube_count_elemwise};\nuse cubecl::{prelude::*, std::FastDivmod};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn scatter_kernel<T: Numeric, I: Int, Op: BinaryOpFamily>(\n    input: &mut Tensor<T>,\n    indices: &Tensor<I>,\n    value: &Tensor<T>,\n    in_shape: Sequence<FastDivmod<usize>>,\n    #[comptime] dim: usize,\n    #[define(T, I)] _dtypes: [StorageType; 2],\n) {\n    let rank = in_shape.len().comptime();\n    let stride_input = input.stride(dim);\n    let stride_value = value.stride(dim);\n    let stride_indices = indices.stride(dim);\n    let shape_value = value.shape(dim);\n\n    let mut offset = ABSOLUTE_POS;\n    let mut offset_input = 0;\n    let mut offset_indices = 0;\n    let mut offset_value = 0;\n    let mut num_elems = 1;\n\n    #[unroll]\n    for i in 0..rank {\n        let i = rank - i - 1;\n        if i != dim {\n            let shape_input_loop = input.shape(i);\n\n            let (rem, local_pos) = in_shape[i].div_mod(offset);\n            offset = rem;\n\n            offset_input += local_pos * input.stride(i);\n            offset_indices += local_pos * indices.stride(i);\n            offset_value += local_pos * value.stride(i);\n\n            num_elems *= shape_input_loop;\n        }\n    }\n\n    let should_stop = ABSOLUTE_POS >= num_elems;\n    if should_stop {\n        terminate!();\n    }\n\n    for i in 0..shape_value {\n        let value_idx = (stride_value * i) + offset_value;\n        let index_idx = (stride_indices * i) + offset_indices;\n\n        let value = value[value_idx];\n        let index = usize::cast_from(indices[index_idx]);\n\n        let input_idx = (stride_input * index) + offset_input;\n\n        let value = Op::BinaryOp::<T, Const<1>>::execute(\n            Vector::cast_from(input[input_idx]),\n            Vector::cast_from(value),\n        );\n        input[input_idx] = value[0];\n    }\n}\n\npub(crate) fn scatter<R: CubeRuntime>(\n    dim: usize,\n    tensor: CubeTensor<R>,\n    indices: CubeTensor<R>,\n    value: CubeTensor<R>,\n    is_bool: bool,\n) -> CubeTensor<R> {\n    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {\n        true => tensor,\n        false => tensor.copy(),\n    };\n\n    let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];\n\n    let working_units = num_elems;\n    let cube_dim = CubeDim::new(&indices.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);\n\n    let launch = match is_bool {\n        true => scatter_kernel::launch_unchecked::<OrOp, R>,\n        false => scatter_kernel::launch_unchecked::<AddOp, R>,\n    };\n\n    let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype);\n\n    unsafe {\n        launch(\n            &tensor.client.clone(),\n            cube_count,\n            cube_dim,\n            address_type!(tensor, indices, value),\n            tensor.clone().into_tensor_arg(),\n            indices.into_tensor_arg(),\n            value.into_tensor_arg(),\n            shape_divmod(&tensor),\n            dim,\n            [tensor_dtype.into(), indices_dtype.into()],\n        )\n    }\n    tensor\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/select.rs",
    "content": "use crate::{CubeRuntime, kernel::utils::address_type, tensor::CubeTensor};\nuse crate::{kernel::utils::shape_divmod, ops::numeric::empty_device_dtype};\nuse burn_backend::TensorMetadata;\nuse cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};\nuse cubecl::{prelude::*, std::FastDivmod};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn select_kernel<T: Numeric, I: Numeric>(\n    input: &Tensor<T>,\n    indices: &LinearView<I>,\n    output: &mut LinearView<T, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    dim: usize,\n    #[define(T, I)] _dtypes: [StorageType; 2],\n) {\n    if ABSOLUTE_POS >= output.shape() {\n        terminate!();\n    }\n\n    let rank = out_shape.len().comptime();\n\n    let mut offset = ABSOLUTE_POS;\n    let mut offset_input = 0;\n\n    #[unroll]\n    for i in 0..rank {\n        let i = rank - i - 1;\n        let (rem, offset_local) = out_shape[i].div_mod(offset);\n        offset = rem;\n\n        let offset_local = cubecl::prelude::select(\n            i == dim,\n            usize::cast_from(indices[offset_local]),\n            offset_local,\n        );\n\n        offset_input += offset_local * input.stride(i);\n    }\n\n    output[ABSOLUTE_POS] = input[offset_input];\n}\n\npub(crate) fn select<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    dim: usize,\n    indices: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let mut shape_output = tensor.shape();\n    shape_output[dim] = indices.meta.shape()[0];\n    let total_elem = shape_output.num_elements();\n\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        shape_output,\n        tensor.dtype,\n    );\n\n    let working_units = total_elem;\n    let cube_dim = CubeDim::new(&indices.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);\n\n    let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype);\n\n    unsafe {\n        select_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, indices, output),\n            tensor.into_tensor_arg(),\n            indices.into_linear_view(),\n            output.clone().into_linear_view(),\n            shape_divmod(&output),\n            dim,\n            [tensor_dtype.into(), indices_dtype.into()],\n        )\n    };\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/select_assign.rs",
    "content": "use crate::kernel::{\n    AddOp, BinaryOp, BinaryOpFamily, OrOp,\n    utils::{address_type, shape_divmod},\n};\nuse crate::{CubeRuntime, tensor::CubeTensor};\nuse cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView};\nuse cubecl::{prelude::*, std::FastDivmod};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn select_assign_kernel<F: Numeric, I: Numeric, Op: BinaryOpFamily>(\n    tensor: &mut Tensor<F>,\n    indices: &LinearView<I>,\n    value: &Tensor<F>,\n    value_shape: Sequence<FastDivmod<usize>>,\n    num_elems: usize,\n    #[comptime] dim: usize,\n    #[define(F, I)] _dtypes: [StorageType; 2],\n) {\n    if ABSOLUTE_POS >= num_elems {\n        terminate!();\n    }\n\n    let rank = value_shape.len().comptime();\n\n    let mut offset = ABSOLUTE_POS;\n    let mut offset_tensor = 0;\n    let mut offset_value = 0;\n\n    // Calculate offsets and num_elems\n    #[unroll]\n    for i in 0..rank {\n        let i = rank - i - 1;\n        if i != dim {\n            let (rem, local_pos) = value_shape[i].div_mod(offset);\n            offset = rem;\n\n            offset_tensor += local_pos * tensor.stride(i);\n            offset_value += local_pos * value.stride(i);\n        }\n    }\n\n    let strides_tensor_dim = tensor.stride(dim);\n    let strides_value_dim = value.stride(dim);\n\n    // Main operation\n    for i in 0..value.shape(dim) {\n        let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;\n        let index_value = i * strides_value_dim + offset_value;\n\n        let value = Op::BinaryOp::<F, Const<1>>::execute(\n            Vector::cast_from(tensor[index_tensor]),\n            Vector::cast_from(value[index_value]),\n        );\n        tensor[index_tensor] = F::cast_from(value);\n    }\n}\n\npub(crate) fn select_assign<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    dim: usize,\n    indices: CubeTensor<R>,\n    value: CubeTensor<R>,\n    is_bool: bool,\n) -> CubeTensor<R> {\n    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {\n        true => tensor,\n        false => tensor.copy(),\n    };\n\n    let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim];\n    let working_units = num_elems;\n    let cube_dim = CubeDim::new(&indices.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);\n\n    let launch = match is_bool {\n        true => select_assign_kernel::launch_unchecked::<OrOp, R>,\n        false => select_assign_kernel::launch_unchecked::<AddOp, R>,\n    };\n\n    let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype);\n\n    let shape = shape_divmod(&value);\n    unsafe {\n        launch(\n            &tensor.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, indices, value),\n            tensor.clone().into_tensor_arg(),\n            indices.into_linear_view(),\n            value.into_tensor_arg(),\n            shape,\n            num_elems,\n            dim,\n            [tensor_dtype.into(), indices_dtype.into()],\n        )\n    };\n\n    tensor\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/slice.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, shape_divmod},\n    ops::numeric::empty_device_dtype,\n    tensor::CubeTensor,\n};\nuse burn_backend::{Slice, TensorMetadata};\nuse burn_std::{Metadata, SliceOps};\nuse cubecl::{\n    calculate_cube_count_elemwise, intrinsic,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\nuse std::ops::Range;\n\n/// Slice a jit tensor with a set of ranges\npub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {\n    let mut dims = tensor.shape();\n    let mut offset_start = 0u64;\n    let mut offset_end = 0u64;\n\n    for i in 0..indices.len() {\n        offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64;\n        offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64;\n        dims[i] = indices[i].end - indices[i].start;\n    }\n\n    let offset_start = offset_start * tensor.dtype.size() as u64;\n    let offset_end = offset_end * tensor.dtype.size() as u64;\n\n    let memory_offset_alignment = tensor.client.properties().memory.alignment;\n\n    if offset_start.is_multiple_of(memory_offset_alignment)\n        && offset_end.is_multiple_of(memory_offset_alignment)\n    {\n        CubeTensor::new(\n            tensor.client.clone(),\n            tensor\n                .handle\n                .clone()\n                .offset_start(offset_start)\n                .offset_end(offset_end),\n            Metadata::new(dims, tensor.meta.strides.clone()),\n            tensor.device.clone(),\n            tensor.dtype,\n        )\n    } else {\n        let output = empty_device_dtype(\n            tensor.client.clone(),\n            tensor.device.clone(),\n            dims,\n            tensor.dtype,\n        );\n        slice_on_output(tensor, output, indices)\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn slice_kernel<E: Numeric>(\n    input: &Tensor<E>,\n    output: &mut LinearView<E, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    indices: Sequence<usize>,\n    #[define(E)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let rank = comptime![out_shape.len()];\n    let mut offset_output = ABSOLUTE_POS;\n    let mut offset_input = 0;\n\n    #[unroll]\n    for i in 0..rank {\n        // Iterate in reverse to use divmod\n        let dim = rank - i - 1;\n\n        let range_start = indices[dim];\n        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);\n        offset_output = rem;\n\n        let offset_local = offset_local + range_start;\n\n        offset_input += offset_local * input.stride(dim);\n    }\n\n    output[ABSOLUTE_POS] = input[offset_input];\n}\n\npub(crate) fn slice_on_output<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    output: CubeTensor<R>,\n    indices: &[Range<usize>],\n) -> CubeTensor<R> {\n    let ndims = tensor.meta.num_dims();\n    let mut indices_sequence = SequenceArg::<R, usize>::new();\n\n    for i in 0..ndims {\n        let start = indices.get(i).map(|index| index.start).unwrap_or(0);\n        indices_sequence.push(start);\n    }\n\n    let working_units = output.meta.num_elements();\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n    let dtype = tensor.dtype;\n\n    unsafe {\n        slice_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, output),\n            tensor.into_tensor_arg(),\n            output.clone().into_linear_view(),\n            shape_divmod(&output),\n            indices_sequence,\n            dtype.into(),\n        )\n    };\n\n    output\n}\n\n/// Kernel for slicing with steps\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn slice_with_steps_kernel<E: Numeric>(\n    input: &Tensor<E>,\n    output: &mut LinearView<E, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    starts: Sequence<usize>,\n    ends: Sequence<usize>,\n    steps: Sequence<i32>,\n    #[define(E)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let rank = comptime![out_shape.len()];\n    let mut output_offset = ABSOLUTE_POS;\n    let mut input_offset = 0;\n\n    // Calculate the input offset based on output position and slice info\n    #[unroll]\n    for i in 0..rank {\n        // Iterate in reverse to use divmod\n        let dim = rank - i - 1;\n        let start = starts[dim];\n        let end = ends[dim];\n        let step = steps[dim];\n\n        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);\n        output_offset = rem;\n\n        let input_idx = if step > 0 {\n            // Forward stepping\n            start + output_idx * (step as usize)\n        } else {\n            // Backward stepping - start from end-1\n            let abs_step = (-step) as usize;\n            let end_minus_1 = end - 1;\n            end_minus_1 - output_idx * abs_step\n        };\n\n        input_offset += input_idx * input.stride(dim);\n    }\n\n    output[ABSOLUTE_POS] = input[input_offset];\n}\n\n/// Slice a tensor with steps\npub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {\n    // Check if all steps are 1 - if so, use the optimized regular slice\n    let all_steps_one = slices.iter().all(|info| info.step == 1);\n\n    if all_steps_one {\n        // Convert Slice to Range for step=1\n        let simple_ranges: Vec<Range<usize>> = slices\n            .iter()\n            .enumerate()\n            .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))\n            .collect();\n        return slice(tensor, &simple_ranges);\n    }\n\n    // Calculate output shape\n    let shape_output = tensor.shape().slice(slices).unwrap();\n\n    // Create output tensor\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        shape_output.clone(),\n        tensor.dtype,\n    );\n\n    // Prepare three separate sequences for kernel\n    let mut starts = SequenceArg::<R, usize>::new();\n    let mut ends = SequenceArg::<R, usize>::new();\n    let mut steps = SequenceArg::<R, i32>::new();\n\n    for (dim, slice) in slices.iter().enumerate() {\n        let range = slice.to_range(tensor.meta.shape()[dim]);\n        starts.push(range.start);\n        ends.push(range.end);\n        steps.push(slice.step as i32);\n    }\n\n    // Pad with default values if needed to match tensor dimensions\n    for dim in slices.len()..tensor.meta.num_dims() {\n        starts.push(0);\n        ends.push(tensor.meta.shape[dim]);\n        steps.push(1);\n    }\n\n    // Launch kernel\n    let working_units = shape_output.num_elements();\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n    let dtype = tensor.dtype;\n\n    unsafe {\n        slice_with_steps_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, output),\n            tensor.into_tensor_arg(),\n            output.clone().into_linear_view(),\n            shape_divmod(&output),\n            starts,\n            ends,\n            steps,\n            dtype.into(),\n        );\n    }\n\n    output\n}\n\n/// This is annoying and we need to find a way to do this automatically at some point\n#[allow(unused)]\n#[cube]\nfn unwrap(value: u32) -> comptime_type!(u32) {\n    intrinsic!(|_| value.constant().unwrap().as_u32())\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/index/slice_assign.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, shape_divmod},\n    tensor::CubeTensor,\n};\nuse cubecl::{\n    calculate_cube_count_elemwise, intrinsic,\n    prelude::*,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn slice_assign_kernel<E: Numeric, N: Size>(\n    input: &mut Tensor<Vector<E, N>>,\n    value: &LinearView<Vector<E, N>>,\n    slice_shape: Sequence<FastDivmod<usize>>,\n    slice_offsets: Sequence<usize>,\n    #[define(E)] _dtype: StorageType,\n) {\n    if !value.is_in_bounds(ABSOLUTE_POS) {\n        terminate!()\n    }\n\n    let rank = comptime!(slice_shape.len());\n\n    let line_size = input.vector_size();\n    let mut offset_remainder = ABSOLUTE_POS * line_size;\n    let mut offset_input = 0;\n\n    #[allow(clippy::explicit_counter_loop)]\n    #[unroll]\n    for i in 0..rank {\n        let dim = rank - i - 1;\n        let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);\n\n        let range_start = slice_offsets[dim];\n        let offset_local_input = offset_local + range_start;\n\n        offset_input += offset_local_input * input.stride(dim);\n        offset_remainder = rem;\n    }\n\n    // Value tensor is accessed linearly since it's a LinearView\n    input[offset_input / line_size] = value[ABSOLUTE_POS];\n}\n\n/// Kernel for slice assign with steps\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn slice_assign_with_steps_kernel<E: Numeric>(\n    input: &mut Tensor<E>,\n    value: &LinearView<E>,\n    value_shape: Sequence<FastDivmod<usize>>,\n    starts: Sequence<usize>,\n    ends: Sequence<usize>,\n    steps: Sequence<i32>,\n    #[define(E)] _dtype: StorageType,\n) {\n    if !value.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let rank = comptime![value_shape.len()];\n    let mut value_offset = ABSOLUTE_POS;\n    let mut input_offset = 0;\n\n    // Calculate the input offset based on value position and slice info\n    #[unroll]\n    for i in 0..rank {\n        // Iterate in reverse to use divmod\n        let dim = rank - i - 1;\n        let start = starts[dim];\n        let end = ends[dim];\n        let step = steps[dim];\n\n        let (rem, value_idx) = value_shape[dim].div_mod(value_offset);\n        value_offset = rem;\n\n        let input_idx = if step > 0 {\n            // Forward stepping\n            start + value_idx * (step as usize)\n        } else if step < 0 {\n            // Backward stepping - start from end-1\n            // For negative steps, we iterate backwards through the selected indices\n            let abs_step = (-step) as usize;\n            let end_minus_1 = end - 1;\n            end_minus_1 - value_idx * abs_step\n        } else {\n            // step == 0, shouldn't happen\n            value_idx\n        };\n\n        input_offset += input_idx * input.stride(dim);\n    }\n\n    input[input_offset] = value[ABSOLUTE_POS];\n}\n\npub(crate) fn slice_assign<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    indices: &[burn_backend::Slice],\n    value: CubeTensor<R>,\n) -> CubeTensor<R> {\n    // Check if any slice has non-unit step\n    let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);\n\n    if has_non_unit_step {\n        // Use slice_assign_with_steps\n        return slice_assign_with_steps(tensor, indices, value);\n    }\n\n    let client = tensor.client.clone();\n    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {\n        true => tensor,\n        false => tensor.copy(),\n    };\n    let ndims = tensor.meta.num_dims();\n\n    let vector_size =\n        if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1 {\n            let last = indices\n                .get(ndims - 1)\n                .cloned()\n                .unwrap_or(burn_backend::Slice {\n                    start: 0,\n                    end: Some(tensor.meta.shape()[ndims - 1] as isize),\n                    step: 1,\n                });\n            let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize);\n            let shape = (end - last.start) as usize;\n            let offset = last.start as usize;\n            client\n                .io_optimized_vector_sizes(tensor.dtype.size())\n                .filter(|&it| {\n                    shape.is_multiple_of(it)\n                        && strides_compatible(tensor.meta.strides(), it)\n                        && strides_compatible(value.meta.strides(), it)\n                        && offset.is_multiple_of(it)\n                })\n                .max()\n                .unwrap_or(1)\n        } else {\n            1\n        };\n\n    let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();\n    let mut offsets = SequenceArg::<R, usize>::new();\n\n    for i in 0..ndims {\n        let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {\n            start: 0,\n            end: Some(tensor.meta.shape()[i] as isize),\n            step: 1,\n        });\n        let start = slice.start as usize;\n        let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize);\n        let length = (end - slice.start) as usize;\n\n        shape.push(length);\n        offsets.push(start);\n    }\n\n    let working_units = value.meta.num_elements() / vector_size;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    unsafe {\n        slice_assign_kernel::launch_unchecked(\n            &tensor.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, value),\n            vector_size,\n            tensor.clone().into_tensor_arg(),\n            value.into_linear_view(),\n            shape,\n            offsets,\n            tensor.dtype.into(),\n        )\n    };\n\n    tensor\n}\n\n/// Slice assign with steps support\n///\n/// This function handles slice assignment with arbitrary step values, including negative steps.\n/// It follows NumPy/PyTorch semantics where values[i] is assigned to selected_indices[i].\n///\n/// For example, with s![0..6;-1] which selects indices [5,4,3,2,1,0]:\n/// - values[0] goes to index 5\n/// - values[1] goes to index 4\n/// - etc.\npub(crate) fn slice_assign_with_steps<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    slices: &[burn_backend::Slice],\n    value: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {\n        true => tensor,\n        false => tensor.copy(),\n    };\n\n    // Prepare sequences for kernel\n    let mut starts = SequenceArg::<R, usize>::new();\n    let mut ends = SequenceArg::<R, usize>::new();\n    let mut steps = SequenceArg::<R, i32>::new();\n\n    for (dim, slice) in slices.iter().enumerate() {\n        let range = slice.to_range(tensor.meta.shape()[dim]);\n        starts.push(range.start);\n        ends.push(range.end);\n        steps.push(slice.step as i32);\n    }\n\n    // Pad with default values if needed to match tensor dimensions\n    for dim in slices.len()..tensor.meta.num_dims() {\n        starts.push(0);\n        ends.push(tensor.meta.shape[dim]);\n        steps.push(1);\n    }\n\n    // Launch kernel\n    let working_units = value.meta.num_elements();\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n\n    let shape = shape_divmod(&value);\n    unsafe {\n        slice_assign_with_steps_kernel::launch_unchecked(\n            &tensor.client,\n            cube_count,\n            cube_dim,\n            address_type!(tensor, value),\n            tensor.clone().into_tensor_arg(),\n            value.into_linear_view(),\n            shape,\n            starts,\n            ends,\n            steps,\n            tensor.dtype.into(),\n        );\n    }\n\n    tensor\n}\n\nfn strides_compatible(strides: &[usize], vec: usize) -> bool {\n    strides\n        .iter()\n        .all(|stride| *stride % vec == 0 || *stride == 1)\n}\n\n/// Helper function for unwrap\n#[allow(unused)]\n#[cube]\nfn unwrap(value: u32) -> comptime_type!(u32) {\n    intrinsic!(|_| value.constant().unwrap().as_u32())\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/base.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::into_contiguous,\n    ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},\n    tensor::CubeTensor,\n};\nuse burn_backend::{\n    Shape, TensorMetadata,\n    ops::{InterpolateMode, InterpolateOptions},\n};\n\nuse super::{\n    bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch,\n    lanczos3::interpolate_lanczos3_launch, nearest::interpolate_nearest_launch,\n    nearest_backward::interpolate_nearest_backward_launch,\n};\n\n/// Interpolate operation\n///\n/// Supports nearest, bilinear, bicubic and lanczos3 modes\npub fn interpolate<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output_size: [usize; 2],\n    options: InterpolateOptions,\n) -> CubeTensor<R> {\n    let [batch_size, channels, _, _] = input.meta.shape().dims();\n    let [out_height, out_width] = output_size;\n\n    let input = into_contiguous(permute_nchw_to_nhwc(input));\n\n    let shape_out = Shape::new([batch_size, out_height, out_width, channels]);\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        shape_out,\n        input.dtype,\n    );\n\n    let align_corners = options.align_corners;\n    let output = match options.mode {\n        InterpolateMode::Nearest => interpolate_nearest_launch(input, output),\n        InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners),\n        InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners),\n        InterpolateMode::Lanczos3 => interpolate_lanczos3_launch(input, output, align_corners),\n    };\n\n    permute_nhwc_to_nchw(output)\n}\n\n/// Backward interpolate operation\n///\n/// Note: only nearest mode is supported\npub fn interpolate_backward<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n    _output_size: [usize; 2],\n    options: InterpolateOptions,\n) -> CubeTensor<R> {\n    let input = permute_nchw_to_nhwc(input);\n    let out_grad = permute_nchw_to_nhwc(out_grad);\n\n    let output_shape = input.shape();\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        output_shape,\n        input.dtype,\n    );\n\n    let output = match options.mode {\n        InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output),\n        InterpolateMode::Bilinear => {\n            panic!(\"bilinear interpolation backward is not supported by JIT backend\")\n        }\n        InterpolateMode::Bicubic => {\n            panic!(\"bicubic interpolation backward is not supported by JIT backend\")\n        }\n        InterpolateMode::Lanczos3 => {\n            panic!(\"lanczos3 interpolation backward is not supported by JIT backend\")\n        }\n    };\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/bicubic.rs",
    "content": "use cubecl::std::{\n    FastDivmod,\n    tensor::layout::{linear::LinearLayout, *},\n};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, linear_layout, shape_divmod},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn interpolate_bicubic_kernel<F: Float, N: Size>(\n    input: &Tensor<Vector<F, N>>,\n    output: &mut Tensor<Vector<F, N>>,\n    shape_out: Sequence<FastDivmod<usize>>,\n    out_layout: LinearLayout,\n    #[comptime] align_corners: bool,\n    #[define(F)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let vector_size = input.vector_size();\n    let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);\n\n    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size);\n    let (rem, x) = shape_out[2].div_mod(rem);\n    let (b, y) = shape_out[1].div_mod(rem);\n\n    let input_height = input.shape(1) - 1;\n    let input_height_f = input_height as f32;\n\n    let frac = if align_corners {\n        let output_height = clamp_min(output.shape(1) - 1, 1) as f32;\n        (y * input_height) as f32 / output_height\n    } else {\n        let in_size = (input_height + 1) as f32;\n        let out_size = output.shape(1) as f32;\n        (y as f32 + 0.5) * (in_size / out_size) - 0.5\n    };\n    let y_in_f = frac.floor();\n    let yw = Vector::new(F::cast_from(frac - y_in_f));\n\n    // Clamp indices in float space to handle negative coordinates from half_pixel\n    let y0 = clamp(y_in_f - 1.0, 0.0, input_height_f) as usize;\n    let y1 = clamp(y_in_f, 0.0, input_height_f) as usize;\n    let y2 = clamp(y_in_f + 1.0, 0.0, input_height_f) as usize;\n    let y3 = clamp(y_in_f + 2.0, 0.0, input_height_f) as usize;\n\n    let input_width = input.shape(2) - 1;\n    let input_width_f = input_width as f32;\n\n    let frac = if align_corners {\n        let output_width = clamp_min(output.shape(2) - 1, 1) as f32;\n        (x * input_width) as f32 / output_width\n    } else {\n        let in_size = (input_width + 1) as f32;\n        let out_size = output.shape(2) as f32;\n        (x as f32 + 0.5) * (in_size / out_size) - 0.5\n    };\n    let x_in_f = frac.floor();\n    let xw = Vector::new(F::cast_from(frac - x_in_f));\n\n    // Clamp indices in float space to handle negative coordinates from half_pixel\n    let x0 = clamp(x_in_f - 1.0, 0.0, input_width_f) as usize;\n    let x1 = clamp(x_in_f, 0.0, input_width_f) as usize;\n    let x2 = clamp(x_in_f + 1.0, 0.0, input_width_f) as usize;\n    let x3 = clamp(x_in_f + 2.0, 0.0, input_width_f) as usize;\n\n    let index_base = b * input.stride(0) + c * input.stride(3);\n    let in_stride_y = input.stride(1);\n    let in_stride_x = input.stride(2);\n\n    let y0_stride = y0 * in_stride_y;\n    let y1_stride = y1 * in_stride_y;\n    let y2_stride = y2 * in_stride_y;\n    let y3_stride = y3 * in_stride_y;\n    let x0_stride = x0 * in_stride_x;\n    let x1_stride = x1 * in_stride_x;\n    let x2_stride = x2 * in_stride_x;\n    let x3_stride = x3 * in_stride_x;\n\n    let inp_0 = input[(index_base + y0_stride + x0_stride) / vector_size];\n    let inp_1 = input[(index_base + y0_stride + x1_stride) / vector_size];\n    let inp_2 = input[(index_base + y0_stride + x2_stride) / vector_size];\n    let inp_3 = input[(index_base + y0_stride + x3_stride) / vector_size];\n\n    let coefficients0 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw);\n\n    let inp_0 = input[(index_base + y1_stride + x0_stride) / vector_size];\n    let inp_1 = input[(index_base + y1_stride + x1_stride) / vector_size];\n    let inp_2 = input[(index_base + y1_stride + x2_stride) / vector_size];\n    let inp_3 = input[(index_base + y1_stride + x3_stride) / vector_size];\n\n    let coefficients1 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw);\n\n    let inp_0 = input[(index_base + y2_stride + x0_stride) / vector_size];\n    let inp_1 = input[(index_base + y2_stride + x1_stride) / vector_size];\n    let inp_2 = input[(index_base + y2_stride + x2_stride) / vector_size];\n    let inp_3 = input[(index_base + y2_stride + x3_stride) / vector_size];\n\n    let coefficients2 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw);\n\n    let inp_0 = input[(index_base + y3_stride + x0_stride) / vector_size];\n    let inp_1 = input[(index_base + y3_stride + x1_stride) / vector_size];\n    let inp_2 = input[(index_base + y3_stride + x2_stride) / vector_size];\n    let inp_3 = input[(index_base + y3_stride + x3_stride) / vector_size];\n\n    let coefficients3 = cubic_interp_1d(inp_0, inp_1, inp_2, inp_3, xw);\n\n    let val = cubic_interp_1d(\n        coefficients0,\n        coefficients1,\n        coefficients2,\n        coefficients3,\n        yw,\n    );\n\n    output[out_idx] = val;\n}\n\n#[cube]\nfn cubic_interp_1d<F: Float, N: Size>(\n    x0: Vector<F, N>,\n    x1: Vector<F, N>,\n    x2: Vector<F, N>,\n    x3: Vector<F, N>,\n    t: Vector<F, N>,\n) -> Vector<F, N> {\n    let a = float(-0.75);\n\n    let coeffs0 = cubic_convolution_2(t + float(1.0), a);\n    let coeffs1 = cubic_convolution_1(t, a);\n    let coeffs2 = cubic_convolution_1(float(1.0) - t, a);\n    let coeffs3 = cubic_convolution_2(float(2.0) - t, a);\n\n    x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3\n}\n\n#[cube]\nfn cubic_convolution_1<F: Float, N: Size>(x: Vector<F, N>, a: Vector<F, N>) -> Vector<F, N> {\n    let conv = (a + float(2.0)) * x;\n    let tmp = a + float(3.0);\n    (conv - tmp) * x * x + float(1.0)\n}\n\n#[cube]\nfn cubic_convolution_2<F: Float, N: Size>(x: Vector<F, N>, a: Vector<F, N>) -> Vector<F, N> {\n    let conv = a * x;\n    let conv = (conv - float(5.0) * a) * x;\n    let tmp = float(8.0) * a;\n    let conv = (conv + tmp) * x;\n\n    conv - float(4.0) * a\n}\n\n#[cube]\nfn float<F: Float, N: Size>(#[comptime] v: f32) -> Vector<F, N> {\n    Vector::new(F::new(v))\n}\n\npub(crate) fn interpolate_bicubic_launch<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output: CubeTensor<R>,\n    align_corners: bool,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&input);\n    let out_shape = shape_divmod(&output);\n    let out_layout = linear_layout(&output, vector_size);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    interpolate_bicubic_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, output),\n        vector_size,\n        input.into_tensor_arg(),\n        output.clone().into_tensor_arg(),\n        out_shape,\n        out_layout,\n        align_corners,\n        output.dtype.into(),\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/bilinear.rs",
    "content": "use cubecl::{calculate_cube_count_elemwise, prelude::*};\nuse cubecl::{\n    num_traits::Zero,\n    std::{\n        FastDivmod,\n        tensor::layout::{linear::LinearLayout, *},\n    },\n};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, linear_layout, shape_divmod},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn interpolate_bilinear_kernel<F: Float, N: Size>(\n    input: &Tensor<Vector<F, N>>,\n    output: &mut Tensor<Vector<F, N>>,\n    shape_out: Sequence<FastDivmod<usize>>,\n    out_layout: LinearLayout,\n    #[comptime] align_corners: bool,\n    #[define(F)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let vector_size = input.vector_size();\n    let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);\n\n    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size);\n    let (rem, x) = shape_out[2].div_mod(rem);\n    let (b, y) = shape_out[1].div_mod(rem);\n\n    let frac = if align_corners {\n        let numerator = (input.shape(1) - 1) as f32;\n        let denominator = clamp_min(output.shape(1) - 1, 1) as f32;\n        y as f32 * (numerator / denominator)\n    } else {\n        let in_size = input.shape(1) as f32;\n        let out_size = output.shape(1) as f32;\n        clamp(\n            (y as f32 + 0.5) * (in_size / out_size) - 0.5,\n            0.0,\n            in_size - 1.0,\n        )\n    };\n\n    let v0 = frac.floor();\n    let v1 = frac.ceil();\n    let yw = F::cast_from(frac - v0);\n    let yw_ = Vector::new(F::one() - yw);\n    let yw = Vector::new(yw);\n    let y0_ok = v0 >= 0.0;\n    let y0 = v0 as usize;\n    let y1 = v1 as usize;\n\n    let frac = if align_corners {\n        let numerator = (input.shape(2) - 1) as f32;\n        let denominator = clamp_min(output.shape(2) - 1, 1) as f32;\n        x as f32 * (numerator / denominator)\n    } else {\n        let in_size = input.shape(2) as f32;\n        let out_size = output.shape(2) as f32;\n        clamp(\n            (x as f32 + 0.5) * (in_size / out_size) - 0.5,\n            0.0,\n            in_size - 1.0,\n        )\n    };\n    let v0 = frac.floor();\n    let v1 = frac.ceil();\n    let xw = F::cast_from(frac - v0);\n    let xw_ = Vector::new(F::one() - xw);\n    let xw = Vector::new(xw);\n    let x0_ok = v0 >= 0.0;\n    let x0 = v0 as usize;\n    let x1 = v1 as usize;\n\n    let index_base = b * input.stride(0) + c * input.stride(3);\n\n    let in_stride_y = input.stride(1);\n    let in_stride_x = input.stride(2);\n\n    let y0_stride = y0 * in_stride_y;\n    let y1_stride = y1 * in_stride_y;\n    let x0_stride = x0 * in_stride_x;\n    let x1_stride = x1 * in_stride_x;\n\n    let height = input.shape(1);\n    let width = input.shape(2);\n\n    let y1_ok = y1 < height;\n    let x1_ok = x1 < width;\n\n    let zero = Vector::zero();\n\n    let p_a = select(\n        x0_ok && y0_ok,\n        input[(index_base + y0_stride + x0_stride) / vector_size] * xw_ * yw_,\n        zero,\n    );\n    let p_b = select(\n        x1_ok && y0_ok,\n        input[(index_base + y0_stride + x1_stride) / vector_size] * xw * yw_,\n        zero,\n    );\n    let p_c = select(\n        x0_ok && y1_ok,\n        input[(index_base + y1_stride + x0_stride) / vector_size] * xw_ * yw,\n        zero,\n    );\n    let p_d = select(\n        x1_ok && y1_ok,\n        input[(index_base + y1_stride + x1_stride) / vector_size] * xw * yw,\n        zero,\n    );\n\n    output[out_idx] = p_a + p_b + p_c + p_d;\n}\n\npub(crate) fn interpolate_bilinear_launch<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output: CubeTensor<R>,\n    align_corners: bool,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&input);\n    let out_shape = shape_divmod(&output);\n    let out_layout = linear_layout(&output, vector_size);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    interpolate_bilinear_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, output),\n        vector_size,\n        input.into_tensor_arg(),\n        output.clone().into_tensor_arg(),\n        out_shape,\n        out_layout,\n        align_corners,\n        output.dtype.into(),\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/lanczos3.rs",
    "content": "use cubecl::{calculate_cube_count_elemwise, prelude::*};\nuse cubecl::{\n    num_traits::Zero,\n    std::{\n        FastDivmod,\n        tensor::layout::{linear::LinearLayout, *},\n    },\n};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, linear_layout, shape_divmod},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn interpolate_lanczos3_kernel<F: Float, N: Size>(\n    input: &Tensor<Vector<F, N>>,\n    output: &mut Tensor<Vector<F, N>>,\n    shape_out: Sequence<FastDivmod<usize>>,\n    out_layout: LinearLayout,\n    #[comptime] align_corners: bool,\n    #[define(F)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let vector_size = input.vector_size();\n    let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);\n\n    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size);\n    let (rem, x) = shape_out[2].div_mod(rem);\n    let (b, y) = shape_out[1].div_mod(rem);\n\n    let input_height = input.shape(1) - 1;\n    let input_height_f = input_height as f32;\n\n    let y_frac = if align_corners {\n        let output_height = clamp_min(output.shape(1) - 1, 1) as f32;\n        (y * input_height) as f32 / output_height\n    } else {\n        let in_size = (input_height + 1) as f32;\n        let out_size = output.shape(1) as f32;\n        (y as f32 + 0.5) * (in_size / out_size) - 0.5\n    };\n    let y0 = f32::floor(y_frac);\n\n    let input_width = input.shape(2) - 1;\n    let input_width_f = input_width as f32;\n\n    let x_frac = if align_corners {\n        let output_width = clamp_min(output.shape(2) - 1, 1) as f32;\n        (x * input_width) as f32 / output_width\n    } else {\n        let in_size = (input_width + 1) as f32;\n        let out_size = output.shape(2) as f32;\n        (x as f32 + 0.5) * (in_size / out_size) - 0.5\n    };\n    let x0 = f32::floor(x_frac);\n\n    let index_base = b * input.stride(0) + c * input.stride(3);\n    let in_stride_y = input.stride(1);\n    let in_stride_x = input.stride(2);\n\n    let mut result = Vector::zero();\n    let mut weight_sum = 0.0f32;\n\n    // 6-tap separable Lanczos3 filter: ky in -2..=3, kx in -2..=3\n    // Skip out-of-bounds positions instead of clamping (matches TF/JAX/PIL)\n    #[unroll]\n    for ky in -2..4i32 {\n        let y_pos = y0 + ky as f32;\n        if y_pos >= 0.0 && y_pos <= input_height_f {\n            let y_idx = y_pos as usize;\n            let wy = lanczos3_weight(y_frac - y_pos);\n\n            #[unroll]\n            for kx in -2..4i32 {\n                let x_pos = x0 + kx as f32;\n                if x_pos >= 0.0 && x_pos <= input_width_f {\n                    let x_idx = x_pos as usize;\n                    let wx = lanczos3_weight(x_frac - x_pos);\n\n                    let wt = wy * wx;\n                    let idx = index_base + y_idx * in_stride_y + x_idx * in_stride_x;\n                    let pixel = input[idx / vector_size];\n                    let w = Vector::new(F::cast_from(wt));\n                    result += pixel * w;\n                    weight_sum += wt;\n                }\n            }\n        }\n    }\n\n    if weight_sum != 0.0 {\n        let inv_w = Vector::new(F::cast_from(1.0 / weight_sum));\n        result *= inv_w;\n    }\n\n    output[out_idx] = result;\n}\n\n#[cube]\nfn lanczos3_weight(x: f32) -> f32 {\n    let abs_x = f32::abs(x);\n    let mut result = 0.0f32;\n    if abs_x < 1e-7 {\n        result = 1.0;\n    } else if abs_x < 3.0 {\n        let pi = core::f32::consts::PI;\n        let pi_x = pi * x;\n        let pi_x_over_3 = pi_x / 3.0;\n        result = (f32::sin(pi_x) * f32::sin(pi_x_over_3)) / (pi_x * pi_x_over_3);\n    }\n    result\n}\n\npub(crate) fn interpolate_lanczos3_launch<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output: CubeTensor<R>,\n    align_corners: bool,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&input);\n    let out_shape = shape_divmod(&output);\n    let out_layout = linear_layout(&output, vector_size);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    interpolate_lanczos3_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, output),\n        vector_size,\n        input.into_tensor_arg(),\n        output.clone().into_tensor_arg(),\n        out_shape,\n        out_layout,\n        align_corners,\n        output.dtype.into(),\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/mod.rs",
    "content": "mod base;\nmod bicubic;\nmod bilinear;\nmod lanczos3;\nmod nearest;\nmod nearest_backward;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/nearest.rs",
    "content": "use cubecl::std::{\n    FastDivmod,\n    tensor::layout::{linear::LinearLayout, *},\n};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, linear_layout, shape_divmod},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn interpolate_nearest_kernel<F: Float, N: Size>(\n    input: &Tensor<Vector<F, N>>,\n    output: &mut Tensor<Vector<F, N>>,\n    shape_out: Sequence<FastDivmod<usize>>,\n    out_layout: LinearLayout,\n    #[define(F)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let vector_size = input.vector_size();\n    let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);\n\n    let out_pos = ABSOLUTE_POS * vector_size;\n\n    let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32);\n    let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32);\n\n    let (rem, c) = shape_out[3].div_mod(out_pos);\n    let (rem, x) = shape_out[2].div_mod(rem);\n    let (b, y) = shape_out[1].div_mod(rem);\n\n    let y = y as f32 * (h_in / h_out);\n    let x = x as f32 * (w_in / w_out);\n\n    let in_idx = b * input.stride(0)\n        + y as usize * input.stride(1)\n        + x as usize * input.stride(2)\n        + c * input.stride(3);\n\n    output[out_idx] = input[in_idx / vector_size];\n}\n\npub(crate) fn interpolate_nearest_launch<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let client = input.client.clone();\n\n    let vector_size = max_vector_size(&input);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    let shape_out = shape_divmod(&output);\n    let out_layout = linear_layout(&output, vector_size);\n\n    unsafe {\n        interpolate_nearest_kernel::launch_unchecked(\n            &client,\n            cube_count,\n            cube_dim,\n            address_type!(input, output),\n            vector_size,\n            input.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n            shape_out,\n            out_layout,\n            output.dtype.into(),\n        )\n    };\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs",
    "content": "use cubecl::{calculate_cube_count_elemwise, prelude::*};\nuse cubecl::{\n    num_traits::Zero,\n    std::{\n        FastDivmod,\n        tensor::layout::{linear::LinearLayout, *},\n    },\n};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, linear_layout, shape_divmod},\n    ops::max_vector_size,\n    tensor::CubeTensor,\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn interpolate_nearest_backward_kernel<F: Float, N: Size>(\n    grad: &Tensor<Vector<F, N>>,\n    output: &mut Tensor<Vector<F, N>>,\n    shape_out: Sequence<FastDivmod<usize>>,\n    out_layout: LinearLayout,\n    #[define(F)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= output.len() {\n        terminate!();\n    }\n\n    let vector_size = grad.vector_size();\n    let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);\n\n    let out_h = output.shape(1);\n    let out_w = output.shape(2);\n    let grad_h = grad.shape(1);\n    let grad_w = grad.shape(2);\n\n    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * vector_size);\n    let (rem, out_x) = shape_out[2].div_mod(rem);\n    let (b, out_y) = shape_out[1].div_mod(rem);\n\n    let grad_y_start = start_index::<F>(out_y, grad_h, out_h);\n    let grad_y_end = end_index::<F>(out_y, grad_h, out_h);\n    let grad_x_start = start_index::<F>(out_x, grad_w, out_w);\n    let grad_x_end = end_index::<F>(out_x, grad_w, out_w);\n\n    let index_grad_base = b * grad.stride(0) + c * grad.stride(3);\n\n    let mut sum = Vector::zero();\n\n    for grad_y in grad_y_start..grad_y_end {\n        for grad_x in grad_x_start..grad_x_end {\n            let index_grad = index_grad_base + grad_y * grad.stride(1) + grad_x * grad.stride(2);\n\n            sum += grad[index_grad];\n        }\n    }\n\n    output[out_idx] = sum;\n}\n\n#[cube]\nfn start_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {\n    let numerator = F::cast_from(input_index * output_size);\n    let div = (numerator / F::cast_from(input_size)).ceil();\n\n    usize::cast_from(div)\n}\n\n#[cube]\nfn end_index<F: Float>(input_index: usize, output_size: usize, input_size: usize) -> usize {\n    let numerator = F::cast_from((input_index + 1) * output_size);\n    let div = (numerator / F::cast_from(input_size)).ceil();\n    let index = usize::cast_from(div);\n\n    clamp_max(index, output_size)\n}\n\npub(crate) fn interpolate_nearest_backward_launch<R: CubeRuntime>(\n    out_grad: CubeTensor<R>,\n    output: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size(&out_grad);\n    let out_shape = shape_divmod(&output);\n    let out_layout = linear_layout(&output, vector_size);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&out_grad.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&out_grad.client, working_units, cube_dim);\n\n    unsafe {\n        interpolate_nearest_backward_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(out_grad, output),\n            vector_size,\n            out_grad.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n            out_shape,\n            out_layout,\n            output.dtype.into(),\n        )\n    };\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/mask/base.rs",
    "content": "use burn_backend::DType;\nuse cubecl::prelude::InputScalar;\n\nuse super::{MaskFillStrategy, mask_where::MaskWhereStrategy};\nuse crate::{CubeRuntime, tensor::CubeTensor};\n\n/// Execute the mask fill kernel.\npub(crate) fn mask_fill_auto<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    mask: CubeTensor<R>,\n    value: InputScalar,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let strategy = if tensor.can_mut() && tensor.is_nonoverlapping() {\n        MaskFillStrategy::Inplace\n    } else {\n        MaskFillStrategy::Readonly\n    };\n\n    super::mask_fill(tensor, mask, value, strategy, dtype_bool)\n}\n\n/// Execute the mask where kernel.\npub(crate) fn mask_where_auto<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    mask: CubeTensor<R>,\n    value: CubeTensor<R>,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let strategy = if tensor.can_mut_broadcast(&value) {\n        MaskWhereStrategy::InplaceLhs\n    } else if value.can_mut_broadcast(&tensor) {\n        MaskWhereStrategy::InplaceRhs\n    } else {\n        MaskWhereStrategy::Readonly\n    };\n\n    super::mask_where(tensor, mask, value, strategy, dtype_bool)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/mask/mask_fill.rs",
    "content": "use burn_backend::{DType, TensorMetadata};\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size_many, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn mask_fill_kernel<T: Numeric, B: Int, N: Size>(\n    input: &LinearView<Vector<T, N>>,\n    mask: &LinearView<Vector<B, N>>,\n    output: &mut LinearView<Vector<T, N>, ReadWrite>,\n    value: InputScalar,\n    #[define(T, B)] _dtypes: [StorageType; 2],\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let mask = Vector::cast_from(mask[ABSOLUTE_POS]);\n    let input = input[ABSOLUTE_POS];\n    let value = Vector::new(value.get::<T>());\n\n    output[ABSOLUTE_POS] = select_many(mask, value, input);\n}\n\n#[derive(Clone, Copy, Debug)]\n/// Define how to run the mask fill kernel.\n///\n/// # Notes\n///\n/// All assertions should be done before choosing the strategy.\npub enum MaskFillStrategy {\n    /// Don't mutate any input.\n    Readonly,\n    /// Reuse the input tensor inplace.\n    Inplace,\n}\n\n/// Execute the mask fill kernel with the given strategy.\npub fn mask_fill<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    mask: CubeTensor<R>,\n    value: InputScalar,\n    strategy: MaskFillStrategy,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let ndims = input.meta.num_dims();\n    let output = match strategy {\n        MaskFillStrategy::Readonly => empty_device_dtype(\n            input.client.clone(),\n            input.device.clone(),\n            input.shape(),\n            input.dtype,\n        ),\n        MaskFillStrategy::Inplace => input.clone(),\n    };\n\n    let vector_size = max_vector_size_many(&[&input, &mask], ndims - 1);\n    let working_units = input.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    let out_arg = match strategy {\n        MaskFillStrategy::Readonly => output.clone().into_linear_view(),\n        MaskFillStrategy::Inplace => output.as_linear_view_alias(0),\n    };\n\n    let at = address_type!(input, mask, output);\n    let mask = mask.into_linear_view_like(&input);\n\n    unsafe {\n        mask_fill_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            at,\n            vector_size,\n            input.into_linear_view(),\n            mask,\n            out_arg,\n            value,\n            [output.dtype.into(), dtype_bool.into()],\n        );\n    }\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/mask/mask_where.rs",
    "content": "use burn_backend::DType;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\nuse crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, broadcast_shape},\n    ops::{max_vector_size_many, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn mask_where_kernel<T: Numeric, B: Int, N: Size>(\n    input: &LinearView<Vector<T, N>>,\n    value: &LinearView<Vector<T, N>>,\n    mask: &LinearView<Vector<B, N>>,\n    output: &mut LinearView<Vector<T, N>, ReadWrite>,\n    #[define(T, B)] _dtypes: [StorageType; 2],\n) {\n    let pos = ABSOLUTE_POS;\n    if !output.is_in_bounds(pos) {\n        terminate!();\n    }\n\n    output[pos] = select_many(Vector::cast_from(mask[pos]), value[pos], input[pos]);\n}\n\n#[derive(Clone, Copy, Debug)]\n/// Define how to run the mask where kernel.\n///\n/// # Notes\n///\n/// All assertions should be done before choosing the strategy.\npub enum MaskWhereStrategy {\n    /// Don't mutate any input.\n    Readonly,\n    /// Reuse the lhs tensor inplace.\n    InplaceLhs,\n    /// Reuse the rhs tensor inplace.\n    InplaceRhs,\n}\n\n/// Execute the mask where kernel with the given strategy.\npub fn mask_where<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    mask: CubeTensor<R>,\n    value: CubeTensor<R>,\n    strategy: MaskWhereStrategy,\n    dtype_bool: DType,\n) -> CubeTensor<R> {\n    let vector_size = max_vector_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1);\n\n    let working_units = input.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    let out_shape = broadcast_shape(&[&input, &mask, &value]);\n\n    let output = match strategy {\n        MaskWhereStrategy::Readonly => empty_device_dtype(\n            input.client.clone(),\n            input.device.clone(),\n            out_shape,\n            input.dtype,\n        ),\n        MaskWhereStrategy::InplaceLhs => input.clone(),\n        MaskWhereStrategy::InplaceRhs => value.clone(),\n    };\n\n    let out = match strategy {\n        MaskWhereStrategy::Readonly => output.clone().into_linear_view(),\n        MaskWhereStrategy::InplaceLhs => output.as_linear_view_alias(0),\n        MaskWhereStrategy::InplaceRhs => output.as_linear_view_alias(1),\n    };\n\n    mask_where_kernel::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, value, mask, output),\n        vector_size,\n        input.into_linear_view_like(&output),\n        value.into_linear_view_like(&output),\n        mask.into_linear_view_like(&output),\n        out,\n        [output.dtype.into(), dtype_bool.into()],\n    );\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/mask/mod.rs",
    "content": "mod base;\nmod mask_fill;\nmod mask_where;\n\npub(crate) use base::*;\n\npub use mask_fill::*;\npub use mask_where::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/matmul/base.rs",
    "content": "use super::init_matmul_output;\nuse crate::{CubeRuntime, kernel::quantization::dequantize, tensor::CubeTensor};\nuse burn_backend::{DType, QTensorPrimitive};\nuse burn_std::QuantLevel;\nuse cubek::matmul::{\n    definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},\n    launch::{MatmulInputBinding, Strategy},\n};\n\n#[cfg(feature = \"autotune\")]\nuse super::matmul_autotune;\n\n/// The strategy to be used when launching a matmul kernel.\npub enum MatmulStrategy {\n    #[cfg(feature = \"autotune\")]\n    /// Using autotune to choose the best kernel based on runtime information.\n    Autotune,\n    /// Cube implementation of matmul.\n    Cube,\n}\n\nimpl Default for MatmulStrategy {\n    fn default() -> Self {\n        // if autotune is enabled, default to autotune\n        #[cfg(feature = \"autotune\")]\n        return MatmulStrategy::Autotune;\n\n        #[cfg(not(feature = \"autotune\"))]\n        MatmulStrategy::Cube\n    }\n}\n\n/// Launch a matmul kernel using the given strategy.\npub fn matmul<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    out: Option<CubeTensor<R>>,\n    strategy: MatmulStrategy,\n    out_dtype: DType,\n) -> Result<CubeTensor<R>, MatmulSetupError> {\n    match strategy {\n        MatmulStrategy::Cube => {\n            let out = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));\n            launch_matmul(&Default::default(), lhs, rhs, out.clone())?;\n            Ok(out)\n        }\n        #[cfg(feature = \"autotune\")]\n        MatmulStrategy::Autotune => Ok(matmul_autotune(lhs, rhs, out, out_dtype)),\n    }\n}\n\npub(crate) fn launch_matmul_naive<R: CubeRuntime>(\n    strategy: &Strategy,\n    mut lhs: CubeTensor<R>,\n    mut rhs: CubeTensor<R>,\n    out: CubeTensor<R>,\n) -> Result<(), MatmulSetupError> {\n    // Naive has very specific layout requirements for block scaled tensors, so we need to manually\n    // dequantize if it fails to launch normally. This is because naive is assumed to always work.\n    if lhs.qparams.is_some() || rhs.qparams.is_some() {\n        match launch_matmul(strategy, lhs.clone(), rhs.clone(), out.clone()) {\n            Err(_) => {\n                if lhs.qparams.is_some() {\n                    lhs = dequantize(lhs, out.dtype);\n                }\n                if rhs.qparams.is_some() {\n                    rhs = dequantize(rhs, out.dtype);\n                }\n                launch_matmul(strategy, lhs, rhs, out)\n            }\n            Ok(_) => Ok(()),\n        }\n    } else {\n        launch_matmul(strategy, lhs, rhs, out)\n    }\n}\n\npub(crate) fn launch_matmul<R: CubeRuntime>(\n    strategy: &Strategy,\n    lhs: CubeTensor<R>,\n    mut rhs: CubeTensor<R>,\n    out: CubeTensor<R>,\n) -> Result<(), MatmulSetupError> {\n    let client = &out.client;\n\n    let lhs_quant_handles = lhs.quantized_handles();\n    let out_dtype: DType = out.dtype;\n\n    let (lhs_dtype, lhs_handle) = match lhs_quant_handles {\n        None => {\n            let lhs_dtype = lhs.dtype;\n            (\n                lhs_dtype,\n                MatmulInputBinding::new(lhs.binding(), lhs_dtype.into()),\n            )\n        }\n        Some((data, scale)) => {\n            let scheme = *lhs.scheme();\n            let data_dtype = data.dtype;\n            let scale_dtype = scale.dtype;\n            (\n                out_dtype,\n                MatmulInputBinding::quantized(\n                    data.binding(),\n                    scale.binding(),\n                    lhs.meta.shape().clone(),\n                    scheme,\n                    data_dtype.into(),\n                    scale_dtype.into(),\n                ),\n            )\n        }\n    };\n\n    let rhs_quant_handles = rhs.quantized_handles();\n\n    let (rhs_dtype, rhs_handle) = match rhs_quant_handles {\n        None => (\n            lhs_dtype,\n            MatmulInputBinding::new(rhs.binding(), lhs_dtype.into()),\n        ),\n        Some((data, scale)) => {\n            // Extremely hacky fix to ensure naive can run in every case\n            if matches!(strategy, Strategy::Naive)\n                && matches!(rhs.scheme().level, QuantLevel::Block(_))\n            {\n                rhs = dequantize(rhs.clone(), lhs_dtype);\n                let rhs_dtype = rhs.dtype;\n                (\n                    lhs_dtype,\n                    MatmulInputBinding::new(rhs.binding(), rhs_dtype.into()),\n                )\n            } else {\n                let scheme = *rhs.scheme();\n                let data_dtype = data.dtype;\n                let scale_dtype = scale.dtype;\n                (\n                    out_dtype,\n                    MatmulInputBinding::quantized(\n                        data.binding(),\n                        scale.binding(),\n                        rhs.meta.shape().clone(),\n                        scheme,\n                        data_dtype.into(),\n                        scale_dtype.into(),\n                    ),\n                )\n            }\n        }\n    };\n\n    let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {\n        lhs: lhs_dtype.into(),\n        rhs: rhs_dtype.into(),\n        out: out_dtype.into(),\n    });\n\n    cubek::matmul::launch::launch_ref(\n        strategy,\n        client,\n        lhs_handle,\n        rhs_handle,\n        out.clone().binding(),\n        &mut dtypes,\n    )?;\n\n    Ok(())\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/matmul/mod.rs",
    "content": "mod base;\nmod tune;\n\n/// Contains utilities for matmul operation\npub mod utils;\n\npub use base::*;\n#[cfg(feature = \"autotune\")]\npub use tune::*;\npub use utils::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/matmul/tune/base.rs",
    "content": "use crate::{\n    CubeRuntime, CubeTuneId,\n    kernel::matmul::{launch_matmul, launch_matmul_naive, utils::init_matmul_output},\n    tensor::CubeTensor,\n};\nuse burn_backend::DType;\nuse cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};\nuse cubek::matmul::{\n    definition::MatmulKind,\n    launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},\n    routines::{\n        BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,\n        double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs,\n        simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs,\n    },\n};\n\nfn matmul_input_gen<R: CubeRuntime>(\n    _key: &MatmulAutotuneKey,\n    lhs: &CubeTensor<R>,\n    rhs: &CubeTensor<R>,\n    out: &CubeTensor<R>,\n) -> (CubeTensor<R>, CubeTensor<R>, CubeTensor<R>) {\n    (lhs.clone(), rhs.clone(), out.copy())\n}\n\n/// Executes autotune on matmul operations\npub fn matmul_autotune<R: CubeRuntime>(\n    lhs: CubeTensor<R>,\n    rhs: CubeTensor<R>,\n    out: Option<CubeTensor<R>>,\n    out_dtype: DType,\n) -> CubeTensor<R> {\n    let output = out.unwrap_or_else(|| init_matmul_output(&lhs, &rhs, out_dtype));\n\n    let client = lhs.client.clone();\n\n    static TUNER: LocalTuner<MatmulAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 3;\n        const PRIORITY_HIGH: i8 = 2;\n        const PRIORITY_MEDIUM: i8 = 1;\n        const PRIORITY_MIN: i8 = 0;\n        const PRIORITY_NEVER: i8 = -1;\n\n        let cmma = TuneGroup::<MatmulAutotuneKey>::new(\"cmma\", |key| {\n            if matches!(\n                key.analysis.kind,\n                MatmulKind::General\n                // Those variants are just because the unit alternatives aren't very good yet.\n                | MatmulKind::VecMat | MatmulKind::MatVec\n            ) {\n                PRIORITY_HIGH\n            } else {\n                PRIORITY_MEDIUM\n            }\n        });\n\n        let mma = TuneGroup::<MatmulAutotuneKey>::new(\"mma\", |key| {\n            if matches!(\n                key.analysis.kind,\n                // General is usually bad, but I think shapes like 16x8196 would be classed as\n                // general and are very good with MMA\n                // Should highly degenerated matrices that aren't VecMat have their own class?\n                MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec\n            ) {\n                PRIORITY_HIGH\n            } else {\n                PRIORITY_MEDIUM\n            }\n        });\n\n        let unit = TuneGroup::<MatmulAutotuneKey>::new(\"unit\", |key| {\n            if !matches!(key.analysis.kind, MatmulKind::General)\n                || matches!(key.analysis.scale_global, MatmulGlobalScale::Small)\n            {\n                PRIORITY_HIGH\n            } else {\n                PRIORITY_MIN\n            }\n        });\n\n        let tma = TuneGroup::<MatmulAutotuneKey>::new(\"tma\", |key| {\n            // For large matmul, we set the max priority to TMA kernels, higher than any other\n            // matmuls, since they are the best kernels no matter what.\n            //\n            // But only when all axis are large.\n            let max_axis = usize::max(key.definition.m, key.definition.n);\n            let max_axis = usize::max(key.definition.k, max_axis);\n\n            let min_axis = usize::min(key.definition.m, key.definition.n);\n            let min_axis = usize::min(key.definition.k, min_axis);\n\n            let skewed_factor = max_axis / min_axis;\n\n            let priority_max = if matches!(key.analysis.kind, MatmulKind::General)\n                && matches!(key.analysis.scale_global, MatmulGlobalScale::Large)\n                && skewed_factor < 4\n            {\n                PRIORITY_MAX\n            } else {\n                PRIORITY_HIGH\n            };\n\n            if key.definition.lhs_stride_factor >= 4 && key.definition.rhs_stride_factor >= 4 {\n                priority_max\n            } else {\n                PRIORITY_NEVER\n            }\n        });\n\n        fn double_buffering_priority(key: &MatmulAutotuneKey, max: i8, min: i8) -> i8 {\n            if should_tune_double_buffering(false, key) {\n                max\n            } else {\n                min\n            }\n        }\n\n        let mut set = TunableSet::new(create_key::<R>, matmul_input_gen::<R>);\n\n        // First entry should always work, since it is considered the fallback.\n        set = set.with(\n            Tunable::new(\"matmul_naive\", |lhs, rhs, out| {\n                launch_matmul_naive::<R>(&Strategy::Naive, lhs, rhs, out)\n                    .map_err(|err| std::format!(\"{err:?}\"))\n            })\n            .group(&unit, |key| {\n                if matches!(key.analysis.scale_global, MatmulGlobalScale::Small)\n                    || matches!(key.analysis.kind, MatmulKind::InnerProduct)\n                {\n                    PRIORITY_MAX\n                } else {\n                    PRIORITY_MIN\n                }\n            }),\n        );\n\n        // Unit VecMat\n        for (strategy, double_buf) in [\n            (\n                Strategy::SimpleVecMat(BlueprintStrategy::Inferred(().into())),\n                false,\n            ),\n            (\n                Strategy::DoubleVecMat(BlueprintStrategy::Inferred(().into())),\n                true,\n            ),\n        ] {\n            set = set.with(\n                Tunable::new(strategy.to_string(), move |lhs, rhs, out| {\n                    launch_matmul::<R>(&strategy, lhs, rhs, out)\n                        .map_err(|err| std::format!(\"{err:?}\"))\n                })\n                .group(&unit, move |key| match double_buf {\n                    false => PRIORITY_MAX,\n                    true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),\n                }),\n            );\n        }\n\n        // Unit matmuls\n        for tile_size in [\n            TileSizeSelection::MaxTileSize,\n            TileSizeSelection::MinTileSize,\n        ] {\n            for (strategy, double_buf) in [\n                (\n                    Strategy::SimpleUnit(BlueprintStrategy::Inferred(SimpleUnitSelectionArgs {\n                        tile_size,\n                    })),\n                    false,\n                ),\n                (\n                    Strategy::DoubleUnit(BlueprintStrategy::Inferred(DoubleUnitSelectionArgs {\n                        tile_size,\n                    })),\n                    true,\n                ),\n            ] {\n                set = set.with(\n                    Tunable::new(strategy.to_string(), move |lhs, rhs, out| {\n                        launch_matmul::<R>(&strategy, lhs, rhs, out)\n                            .map_err(|err| format!(\"{err:?}\"))\n                    })\n                    .group(&unit, move |key| match double_buf {\n                        false => PRIORITY_MAX,\n                        true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),\n                    }),\n                )\n            }\n        }\n\n        // Accelerated matmuls\n        for (strategy, double_buf, group_extra, tile_group) in [\n            (\n                Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: false,\n                })),\n                false,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: false,\n                })),\n                false,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::SimpleCyclicCmma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: true,\n                })),\n                false,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::SimpleCyclicMma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: true,\n                })),\n                false,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {\n                    partition_k: Some(2),\n                    row_count: Some(4),\n                    rows_per_plane: Some(2),\n                })),\n                true,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {\n                    partition_k: Some(2),\n                    row_count: Some(4),\n                    rows_per_plane: Some(2),\n                })),\n                true,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::OrderedDoubleCmma(BlueprintStrategy::Inferred(OrderedSelectionArgs {\n                    partition_k: Some(2),\n                    row_count: Some(8),\n                    rows_per_plane: Some(2),\n                })),\n                true,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::OrderedDoubleMma(BlueprintStrategy::Inferred(OrderedSelectionArgs {\n                    partition_k: Some(2),\n                    row_count: Some(8),\n                    rows_per_plane: Some(2),\n                })),\n                true,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {\n                    specialized: false,\n                })),\n                true,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {\n                    specialized: false,\n                })),\n                true,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::DoubleCyclicCmma(BlueprintStrategy::Inferred(DoubleBufferingArgs {\n                    specialized: true,\n                })),\n                true,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::DoubleCyclicMma(BlueprintStrategy::Inferred(DoubleBufferingArgs {\n                    specialized: true,\n                })),\n                true,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::SpecializedCyclicCmma(BlueprintStrategy::Inferred(().into())),\n                true,\n                None,\n                &cmma,\n            ),\n            (\n                Strategy::SpecializedCyclicMma(BlueprintStrategy::Inferred(().into())),\n                true,\n                None,\n                &mma,\n            ),\n            (\n                Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: false,\n                })),\n                false,\n                Some(&tma),\n                &cmma,\n            ),\n            (\n                Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: false,\n                })),\n                false,\n                Some(&tma),\n                &mma,\n            ),\n            (\n                Strategy::SimpleTmaCmma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: true,\n                })),\n                false,\n                Some(&tma),\n                &cmma,\n            ),\n            (\n                Strategy::SimpleTmaMma(BlueprintStrategy::Inferred(SimpleArgs {\n                    multi_rows: true,\n                })),\n                false,\n                Some(&tma),\n                &mma,\n            ),\n            (\n                Strategy::SpecializedTmaCmma(BlueprintStrategy::Inferred(().into())),\n                true,\n                Some(&tma),\n                &cmma,\n            ),\n            (\n                Strategy::SpecializedTmaMma(BlueprintStrategy::Inferred(().into())),\n                true,\n                Some(&tma),\n                &mma,\n            ),\n        ] {\n            let priority_within_group = |key: &MatmulAutotuneKey, double_buf: bool| match double_buf\n            {\n                false => PRIORITY_MAX,\n                true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),\n            };\n            let mut tunable = Tunable::new(strategy.to_string(), move |lhs, rhs, out| {\n                launch_matmul::<R>(&strategy, lhs, rhs, out).map_err(|err| format!(\"{err:?}\"))\n            });\n\n            // tile group\n            tunable = tunable.group(tile_group, move |key| {\n                priority_within_group(key, double_buf)\n            });\n\n            // extra group\n            if let Some(group) = group_extra {\n                tunable = tunable.group(group, move |key| priority_within_group(key, double_buf));\n            }\n            set = set.with(tunable);\n        }\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&lhs.client, &lhs.device),\n        &client,\n        tunables,\n        (lhs, rhs, output.clone()),\n    );\n\n    output\n}\n\nfn create_key<R: CubeRuntime>(\n    lhs: &CubeTensor<R>,\n    rhs: &CubeTensor<R>,\n    out: &CubeTensor<R>,\n) -> MatmulAutotuneKey {\n    MatmulAutotuneKey::generate(\n        &lhs.client,\n        lhs.meta.shape(),\n        rhs.meta.shape(),\n        lhs.meta.strides(),\n        rhs.meta.strides(),\n        lhs.dtype.into(),\n        rhs.dtype.into(),\n        out.dtype.into(),\n        lhs.try_scheme(),\n        rhs.try_scheme(),\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/matmul/tune/mod.rs",
    "content": "#[cfg(feature = \"autotune\")]\nmod base;\n\n#[cfg(feature = \"autotune\")]\npub use base::matmul_autotune;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/matmul/utils.rs",
    "content": "use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::{DType, calculate_matmul_output};\n\n/// Creates an empty output tensor with matmul output shape\npub fn init_matmul_output<R: CubeRuntime>(\n    lhs: &CubeTensor<R>,\n    rhs: &CubeTensor<R>,\n    dtype: DType,\n) -> CubeTensor<R> {\n    empty_device_dtype(\n        lhs.client.clone(),\n        lhs.device.clone(),\n        calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(),\n        dtype,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/mod.rs",
    "content": "mod binary;\nmod binary_float;\nmod binary_int;\nmod cast;\nmod clamp;\nmod comparison;\nmod contiguous;\nmod cross;\nmod index;\nmod mask;\nmod unary_float;\nmod unary_int;\nmod unary_numeric;\n\npub(crate) use binary::*;\npub(crate) use binary_float::*;\npub(crate) use binary_int::*;\npub use cast::*;\npub use contiguous::*;\npub(crate) use cross::*;\npub use mask::*;\npub(crate) use unary_float::*;\npub(crate) use unary_int::*;\npub(crate) use unary_numeric::*;\n\npub use crate::cubecl::prelude::KernelMetadata;\n\n/// Attention kernels\npub mod attention;\n/// Convolution kernels\npub mod conv;\n/// Grid sampling kernels\npub mod grid_sample;\n/// Interpolation kernels\npub mod interpolate;\n/// Matmul kernels\npub mod matmul;\n/// Pooling kernels\npub mod pool;\n/// Pseudo-random number generator kernels\npub mod prng;\n/// Quantization operations\npub mod quantization;\n/// Reduction algorithms\npub mod reduce;\n\npub(crate) use clamp::*;\npub(crate) use comparison::*;\npub use index::*;\n\npub(crate) mod utils;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        into_contiguous_aligned,\n        pool::pool2d::{Position, view4d},\n        utils::{address_type, decompose_linear, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::Shape;\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    num_traits::Zero,\n    prelude::*,\n    std::{FastDivmod, tensor::View},\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn adaptive_avg_pool2d_direct<E: Numeric, N: Size>(\n    input: &Tensor<Vector<E, N>>,\n    output: &mut View<Vector<E, N>, Position, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    working_units: usize,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= working_units {\n        terminate!();\n    }\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);\n    let [b, oh, ow, c] = *pos else { unreachable!() };\n\n    let (_, out_h, out_w, _) = output.shape();\n    let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));\n    let (in_h, in_w) = (input.shape(1), input.shape(2));\n\n    let ih_start = start_index(oh, out_h, in_h);\n    let ih_end = end_index(oh, out_h, in_h);\n\n    let iw_start = start_index(ow, out_w, in_w);\n    let iw_end = end_index(ow, out_w, in_w);\n\n    let mut sum = Vector::zero();\n\n    let index_input_base = b * input.stride(0) + c * input.stride(3);\n\n    for ih in ih_start..ih_end {\n        let index_input_2 = ih * in_stride_h;\n\n        for iw in iw_start..iw_end {\n            let index_input_3 = iw * in_stride_w;\n\n            let index_input = index_input_base + index_input_2 + index_input_3;\n            sum += input[index_input / input.vector_size()];\n        }\n    }\n\n    let num_ih = ih_end - ih_start;\n    let num_iw = iw_end - iw_start;\n\n    output[(b, oh, ow, c)] = sum / Vector::cast_from(num_ih * num_iw);\n}\n\n#[cube]\nfn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    (output_size_index * input_size) / output_size\n}\n\n#[cube]\nfn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    let index = (output_size_index + 1) * input_size;\n    let index = index.div_ceil(output_size);\n\n    if input_size < index {\n        input_size\n    } else {\n        index\n    }\n}\n\npub(crate) fn adaptive_avg_pool2d<R: CubeRuntime>(\n    input: CubeTensor<R>,\n    output_size: [usize; 2],\n) -> CubeTensor<R> {\n    let [batch_size, channels, _, _] = input.meta.shape().dims();\n\n    let input = into_contiguous_aligned(permute_nchw_to_nhwc(input));\n    let vector_size = max_vector_size(&input);\n\n    let output_shape = Shape::new([batch_size, output_size[0], output_size[1], channels]);\n    let num_elems: usize = output_shape.num_elements();\n    let output = empty_device_dtype(\n        input.client.clone(),\n        input.device.clone(),\n        output_shape,\n        input.dtype,\n    );\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&input.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);\n\n    adaptive_avg_pool2d_direct::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(input, output),\n        vector_size,\n        input.into_tensor_arg(),\n        view4d(output.clone(), vector_size),\n        shape_divmod(&output),\n        working_units,\n        output.dtype.into(),\n    );\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        into_contiguous_aligned,\n        pool::pool2d::{Position, view4d},\n        utils::{address_type, decompose_linear, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::Shape;\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    num_traits::Zero,\n    prelude::*,\n    std::{FastDivmod, tensor::View},\n};\n\n#[cube(launch, address_type = \"dynamic\")]\nfn adaptive_avg_pool2d_backward_direct<E: Numeric, N: Size>(\n    grad: &Tensor<Vector<E, N>>,\n    output: &mut View<Vector<E, N>, Position, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    working_units: usize,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= working_units {\n        terminate!();\n    }\n\n    let (_, out_h, out_w, _) = output.shape();\n    let (grad_stride_h, grad_stride_w) = (grad.stride(1), grad.stride(2));\n    let (grad_h, grad_w) = (grad.shape(1), grad.shape(2));\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);\n    let [b, ih, iw, c] = *pos else { unreachable!() };\n\n    let oh_start = start_index(ih, out_h, grad_h);\n    let oh_end = end_index(ih, out_h, grad_h);\n\n    let ow_start = start_index(iw, out_w, grad_w);\n    let ow_end = end_index(iw, out_w, grad_w);\n\n    let mut grad_acc = Vector::zero();\n\n    let index_base = b * grad.stride(0) + (c * grad.stride(3));\n\n    for oh in oh_start..oh_end {\n        let ih_start = start_index(oh, grad_h, out_h);\n        let ih_end = end_index(oh, grad_h, out_h);\n\n        if ih >= ih_start && ih < ih_end {\n            for ow in ow_start..ow_end {\n                let iw_start = start_index(ow, grad_w, out_w);\n                let iw_end = end_index(ow, grad_w, out_w);\n\n                if iw >= iw_start && iw < iw_end {\n                    let num_ih = ih_end - ih_start;\n                    let num_iw = iw_end - iw_start;\n\n                    let index = index_base + (oh * grad_stride_h) + (ow * grad_stride_w);\n                    grad_acc +=\n                        grad[index / grad.vector_size()] / Vector::cast_from(num_iw * num_ih);\n                }\n            }\n        }\n    }\n\n    output[(b, ih, iw, c)] = grad_acc;\n}\n\n#[cube]\nfn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    (output_size_index * input_size) / output_size\n}\n\n#[cube]\nfn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    let index = (output_size_index + 1) * input_size;\n    let index = index.div_ceil(output_size);\n\n    if input_size < index {\n        input_size\n    } else {\n        index\n    }\n}\n\npub(crate) fn adaptive_avg_pool2d_backward<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    out_grad: CubeTensor<R>,\n) -> CubeTensor<R> {\n    let [batches, channels, height, width] = x.meta.shape().dims();\n\n    let out_grad = into_contiguous_aligned(permute_nchw_to_nhwc(out_grad));\n    let vector_size = max_vector_size(&out_grad);\n\n    let out_shape = Shape::new([batches, height, width, channels]);\n    let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);\n\n    let num_elems = output.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n\n    adaptive_avg_pool2d_backward_direct::launch(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(out_grad, output),\n        vector_size,\n        out_grad.into_tensor_arg(),\n        view4d(output.clone(), vector_size),\n        shape_divmod(&output),\n        working_units,\n        output.dtype.into(),\n    );\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs",
    "content": "use super::pool2d::{\n    Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,\n};\nuse crate::{\n    CubeRuntime,\n    kernel::{\n        into_contiguous_aligned,\n        pool::pool2d::{Position, view4d},\n        utils::{address_type, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::{Shape, ops::conv::calculate_pool_output_size};\nuse cubecl::{CubeDim, calculate_cube_count_elemwise, num_traits::Zero};\nuse cubecl::{prelude::*, std::tensor::View};\n\nstruct AvgPoolStrategy;\n\nimpl Pool2dDirectStrategyFamily for AvgPoolStrategy {\n    type Indices<N: Size> = ();\n    type Config = AvgPoolStrategyConfig;\n    type Pool2d<T: Numeric, N: Size> = Self;\n}\n\n#[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)]\npub struct AvgPoolStrategyConfig {\n    count_include_pad: bool,\n    /// Total padded height (input_height + 2 * padding_0)\n    padded_h: u32,\n    /// Total padded width (input_width + 2 * padding_1)\n    padded_w: u32,\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> Pool2dDirectStrategy<T, N> for AvgPoolStrategy {\n    type Accumulator = (Vector<T, N>, u32);\n    type Config = AvgPoolStrategyConfig;\n    type Indices = ();\n\n    fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator {\n        let sum = Vector::zero();\n        // Count will be set dynamically: either by accumulate (count_include_pad=false)\n        // or by set_padded_count (count_include_pad=true)\n        let count = 0u32;\n\n        (sum, count)\n    }\n\n    fn accumulate(\n        #[comptime] config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        _index: usize,\n        result: Vector<T, N>,\n    ) {\n        let (sum, count) = accumulator;\n\n        // Only count valid positions when count_include_pad=false\n        if comptime![!config.count_include_pad] {\n            *count += 1;\n        }\n\n        *sum += result;\n    }\n\n    fn count_position(\n        #[comptime] config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        ih: u32,\n        iw: u32,\n    ) {\n        // When count_include_pad=true, count positions within padded bounds\n        // (excludes ceil_mode extensions beyond the padded input)\n        if comptime![config.count_include_pad] && ih < config.padded_h && iw < config.padded_w {\n            let (_sum, count) = accumulator;\n            *count += 1;\n        }\n    }\n\n    fn store(\n        #[comptime] _config: &Self::Config,\n        position: Position,\n        output: &mut View<Vector<T, N>, Position, ReadWrite>,\n        _output_indices: &mut (),\n        accumulator: Self::Accumulator,\n    ) {\n        let (sum, count) = accumulator;\n        output[position] = sum / Vector::cast_from(count);\n    }\n}\n\npub(crate) fn avg_pool2d<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> CubeTensor<R> {\n    let [batch_size, channels, in_h, in_w] = x.meta.shape().dims();\n    let dilation = 1;\n\n    let size_0 = calculate_pool_output_size(\n        kernel_size[0],\n        stride[0],\n        padding[0],\n        dilation,\n        in_h,\n        ceil_mode,\n    );\n    let size_1 = calculate_pool_output_size(\n        kernel_size[1],\n        stride[1],\n        padding[1],\n        dilation,\n        in_w,\n        ceil_mode,\n    );\n\n    // Padded dimensions (for count_include_pad with ceil_mode)\n    let padded_0 = in_h + 2 * padding[0];\n    let padded_1 = in_w + 2 * padding[1];\n\n    let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));\n    let vector_size = max_vector_size(&x);\n\n    let shape_out = Shape::new([batch_size, size_0, size_1, channels]);\n    let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n\n    pool2d_direct::launch::<AvgPoolStrategy, R>(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(x, output),\n        vector_size,\n        x.into_tensor_arg(),\n        view4d(output.clone(), vector_size),\n        (),\n        shape_divmod(&output),\n        working_units,\n        Pool2dDirectArgsLaunch::new(\n            stride[0] as u32,\n            stride[1] as u32,\n            dilation as u32,\n            dilation as u32,\n            padding[0] as u32,\n            padding[1] as u32,\n        ),\n        (kernel_size[0] as u32, kernel_size[1] as u32),\n        AvgPoolStrategyConfig {\n            count_include_pad,\n            padded_h: padded_0 as u32,\n            padded_w: padded_1 as u32,\n        },\n        output.dtype.into(),\n    );\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        pool::pool2d::{Position, view4d},\n        utils::{address_type, decompose_linear, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::Shape;\nuse cubecl::{\n    calculate_cube_count_elemwise,\n    num_traits::Zero,\n    prelude::*,\n    std::{FastDivmod, tensor::View},\n};\n\n#[derive(CubeLaunch, CubeType)]\npub(crate) struct PoolBackwardArgs {\n    pub stride_0: i32,\n    pub stride_1: i32,\n    pub dilation_0: i32,\n    pub dilation_1: i32,\n    pub padding_0: i32,\n    pub padding_1: i32,\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn avg_pool2d_backward_kernel<E: Numeric, N: Size>(\n    grad: &Tensor<Vector<E, N>>,\n    output: &mut View<Vector<E, N>, Position, ReadWrite>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    working_units: usize,\n    args: &PoolBackwardArgs,\n    #[comptime] kernel_size_0: i32,\n    #[comptime] kernel_size_1: i32,\n    #[comptime] count_include_pad: bool,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= working_units {\n        terminate!();\n    }\n\n    let vector_size = grad.vector_size();\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);\n    let [batch, ih, iw, channel] = *pos else {\n        unreachable!()\n    };\n\n    let mut grad_acc = Vector::zero();\n\n    let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(\n        ih as i32,\n        iw as i32,\n        grad.shape(1) as u32,\n        grad.shape(2) as u32,\n        args,\n        kernel_size_0,\n        kernel_size_1,\n    );\n\n    let padding_0 = args.padding_0 as u32;\n    let padding_1 = args.padding_1 as u32;\n    let stride_0 = args.stride_0 as u32;\n    let stride_1 = args.stride_1 as u32;\n    let kernel_size_0 = comptime![kernel_size_0 as u32];\n    let kernel_size_1 = comptime![kernel_size_1 as u32];\n\n    let index_base = batch * grad.stride(0) + channel * grad.stride(3);\n    let border_bottom = output.shape().1 as u32 + padding_0;\n    let border_right = output.shape().2 as u32 + padding_1;\n    let begin_h = ih as u32 + padding_0;\n    let begin_w = iw as u32 + padding_1;\n\n    for oh in oh_start..oh_end {\n        let ih_start = oh * stride_0;\n        let ih_end = clamp_max(ih_start + kernel_size_0, border_bottom);\n        let ih_start = clamp_min(ih_start, padding_0);\n\n        if begin_h >= ih_start && (ih as u32) < ih_end {\n            for ow in ow_start..ow_end {\n                let index =\n                    index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);\n\n                let iw_start = ow * stride_1;\n                let iw_end = clamp_max(iw_start + kernel_size_1, border_right);\n                let iw_start = clamp_min(iw_start, padding_1);\n\n                if begin_w >= iw_start && (iw as u32) < iw_end {\n                    if count_include_pad {\n                        grad_acc += grad[index / vector_size]\n                            / Vector::cast_from(kernel_size_0 * kernel_size_1);\n                    } else {\n                        let ih_diff = ih_end - ih_start;\n                        let iw_diff = iw_end - iw_start;\n                        let count = Vector::cast_from(ih_diff * iw_diff);\n                        grad_acc += grad[index / vector_size] / count;\n                    }\n                }\n            }\n        }\n    }\n\n    output[(batch, ih, iw, channel)] = grad_acc;\n}\n\n#[cube]\nfn loop_ranges(\n    ih: i32,\n    iw: i32,\n    grad_h: u32,\n    grad_w: u32,\n    args: &PoolBackwardArgs,\n    #[comptime] kernel_size_0: i32,\n    #[comptime] kernel_size_1: i32,\n) -> (u32, u32, u32, u32) {\n    let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;\n    let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;\n\n    let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;\n    let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;\n    let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;\n    let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;\n\n    (oh_start, oh_end, ow_start, ow_end)\n}\n\npub(crate) fn avg_pool2d_backward<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    grad: CubeTensor<R>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    count_include_pad: bool,\n    _ceil_mode: bool,\n) -> CubeTensor<R> {\n    let [batches, channels, height, width] = x.meta.shape().dims();\n\n    let grad = permute_nchw_to_nhwc(grad);\n\n    let vector_size = if x.meta.strides()[3] == grad.meta.strides()[3] {\n        max_vector_size(&x)\n    } else {\n        1\n    };\n\n    let dilation = 1;\n\n    let out_shape = Shape::new([batches, height, width, channels]);\n    let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n\n    unsafe {\n        avg_pool2d_backward_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(grad, output),\n            vector_size,\n            grad.into_tensor_arg(),\n            view4d(output.clone(), vector_size),\n            shape_divmod(&output),\n            working_units,\n            PoolBackwardArgsLaunch::new(\n                stride[0] as i32,\n                stride[1] as i32,\n                dilation,\n                dilation,\n                padding[0] as i32,\n                padding[1] as i32,\n            ),\n            kernel_size[0] as i32,\n            kernel_size[1] as i32,\n            count_include_pad,\n            output.dtype.into(),\n        )\n    };\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/max_pool2d.rs",
    "content": "use super::pool2d::{\n    Pool2dDirectArgsLaunch, Pool2dDirectStrategy, Pool2dDirectStrategyFamily, pool2d_direct,\n};\nuse crate::{\n    CubeRuntime,\n    kernel::{\n        into_contiguous_aligned,\n        pool::pool2d::{Position, view4d},\n        utils::{address_type, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, Shape, ops::conv::calculate_pool_output_size};\nuse cubecl::{\n    CubeDim, calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::tensor::View,\n};\n\nstruct MaxPoolStrategy;\nstruct MaxPoolWithIndicesStrategy;\n\nimpl Pool2dDirectStrategyFamily for MaxPoolStrategy {\n    type Indices<N: Size> = ();\n    type Config = ();\n    type Pool2d<T: Numeric, N: Size> = Self;\n}\n\nimpl Pool2dDirectStrategyFamily for MaxPoolWithIndicesStrategy {\n    type Indices<N: Size> = View<Vector<i32, N>, Position, ReadWrite>;\n    type Config = ();\n    type Pool2d<T: Numeric, N: Size> = Self;\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> Pool2dDirectStrategy<T, N> for MaxPoolStrategy {\n    type Accumulator = Vector<T, N>;\n    type Config = ();\n    type Indices = ();\n\n    fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator {\n        Vector::new(T::min_value())\n    }\n\n    fn accumulate(\n        #[comptime] _config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        _index: VectorSize,\n        result: Vector<T, N>,\n    ) {\n        *accumulator = max(*accumulator, result);\n    }\n\n    fn count_position(\n        #[comptime] _config: &Self::Config,\n        _accumulator: &mut Self::Accumulator,\n        _ih: u32,\n        _iw: u32,\n    ) {\n    }\n\n    fn store(\n        #[comptime] _config: &Self::Config,\n        position: Position,\n        output: &mut View<Vector<T, N>, Position, ReadWrite>,\n        _output_indices: &mut (),\n        accumulator: Self::Accumulator,\n    ) {\n        output[position] = accumulator;\n    }\n}\n\n#[cube]\nimpl<T: Numeric, N: Size> Pool2dDirectStrategy<T, N> for MaxPoolWithIndicesStrategy {\n    type Accumulator = (Vector<T, N>, Vector<i32, N>);\n    type Config = ();\n    type Indices = View<Vector<i32, N>, Position, ReadWrite>;\n\n    fn initialize(#[comptime] _config: &Self::Config) -> Self::Accumulator {\n        let val = Vector::new(T::min_value());\n        let idx = Vector::zero();\n        (val, idx)\n    }\n\n    fn accumulate(\n        #[comptime] _config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        index: usize,\n        result: Vector<T, N>,\n    ) {\n        let indices = Vector::cast_from(index);\n        accumulator.1 = select_many(result.greater_than(accumulator.0), indices, accumulator.1);\n        accumulator.0 = max(result, accumulator.0);\n    }\n\n    fn count_position(\n        #[comptime] _config: &Self::Config,\n        _accumulator: &mut Self::Accumulator,\n        _ih: u32,\n        _iw: u32,\n    ) {\n    }\n\n    fn store(\n        #[comptime] _config: &Self::Config,\n        position: Position,\n        output: &mut View<Vector<T, N>, Position, ReadWrite>,\n        output_indices: &mut View<Vector<i32, N>, Position, ReadWrite>,\n        accumulator: Self::Accumulator,\n    ) {\n        output[position] = accumulator.0;\n        output_indices[position] = accumulator.1;\n    }\n}\n\npub(crate) fn max_pool2d<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n) -> CubeTensor<R> {\n    let [batch_size, channels, height, width] = x.meta.shape().dims();\n\n    let size_0 = calculate_pool_output_size(\n        kernel_size[0],\n        stride[0],\n        padding[0],\n        dilation[0],\n        height,\n        ceil_mode,\n    );\n    let size_1 = calculate_pool_output_size(\n        kernel_size[1],\n        stride[1],\n        padding[1],\n        dilation[1],\n        width,\n        ceil_mode,\n    );\n\n    let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));\n\n    let vector_size = max_vector_size(&x);\n\n    let shape_out = Shape::new([batch_size, size_0, size_1, channels]);\n    let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n\n    pool2d_direct::launch::<MaxPoolStrategy, R>(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(x, output),\n        vector_size,\n        x.into_tensor_arg(),\n        view4d(output.clone(), vector_size),\n        (),\n        shape_divmod(&output),\n        working_units,\n        Pool2dDirectArgsLaunch::new(\n            stride[0] as u32,\n            stride[1] as u32,\n            dilation[0] as u32,\n            dilation[1] as u32,\n            padding[0] as u32,\n            padding[1] as u32,\n        ),\n        (kernel_size[0] as u32, kernel_size[1] as u32),\n        (),\n        output.dtype.into(),\n    );\n\n    permute_nhwc_to_nchw(output)\n}\n\npub(crate) fn max_pool2d_with_indices<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n    dtype_indices: DType,\n) -> (CubeTensor<R>, CubeTensor<R>) {\n    let [batch_size, channels, size_0, size_1] = x.meta.shape().dims();\n\n    let size_0 = calculate_pool_output_size(\n        kernel_size[0],\n        stride[0],\n        padding[0],\n        dilation[0],\n        size_0,\n        ceil_mode,\n    );\n    let size_1 = calculate_pool_output_size(\n        kernel_size[1],\n        stride[1],\n        padding[1],\n        dilation[1],\n        size_1,\n        ceil_mode,\n    );\n\n    let x = into_contiguous_aligned(permute_nchw_to_nhwc(x));\n    let vector_size = max_vector_size(&x);\n\n    let shape_out = Shape::new([batch_size, size_0, size_1, channels]);\n    let output = empty_device_dtype(\n        x.client.clone(),\n        x.device.clone(),\n        shape_out.clone(),\n        x.dtype,\n    );\n    let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n\n    pool2d_direct::launch::<MaxPoolWithIndicesStrategy, R>(\n        &output.client,\n        cube_count,\n        cube_dim,\n        address_type!(x, output, indices),\n        vector_size,\n        x.into_tensor_arg(),\n        view4d(output.clone(), vector_size),\n        view4d(indices.clone(), vector_size),\n        shape_divmod(&output),\n        working_units,\n        Pool2dDirectArgsLaunch::new(\n            stride[0] as u32,\n            stride[1] as u32,\n            dilation[0] as u32,\n            dilation[1] as u32,\n            padding[0] as u32,\n            padding[1] as u32,\n        ),\n        (kernel_size[0] as u32, kernel_size[1] as u32),\n        (),\n        output.dtype.into(),\n    );\n\n    let output = permute_nhwc_to_nchw(output);\n    let indices = permute_nhwc_to_nchw(indices);\n    (output, indices)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::{\n        into_contiguous_aligned,\n        utils::{address_type, decompose_linear, shape_divmod},\n    },\n    ops::{\n        max_vector_size, numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw,\n    },\n    tensor::CubeTensor,\n};\nuse burn_backend::Shape;\nuse cubecl::{calculate_cube_count_elemwise, num_traits::Zero, prelude::*, std::FastDivmod};\n\nuse super::{PoolBackwardArgs, PoolBackwardArgsLaunch};\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn max_pool2d_with_indices_backward_kernel<E: Numeric, I: Int, N: Size>(\n    grad: &Tensor<Vector<E, N>>,\n    indices: &Tensor<Vector<I, N>>,\n    output: &mut Tensor<Vector<E, N>>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    working_units: usize,\n    args: &PoolBackwardArgs,\n    #[comptime] kernel_size_0: i32,\n    #[comptime] kernel_size_1: i32,\n    #[define(E, I)] _dtypes: [StorageType; 2],\n) {\n    if ABSOLUTE_POS >= working_units {\n        terminate!();\n    }\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);\n    let [batch, ih, iw, channel] = *pos else {\n        unreachable!()\n    };\n\n    let vector_size = grad.vector_size();\n\n    let index_current = ih * output.shape(2) + iw;\n\n    let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(\n        ih as i32,\n        iw as i32,\n        grad.shape(1) as u32,\n        grad.shape(2) as u32,\n        args,\n        kernel_size_0,\n        kernel_size_1,\n    );\n\n    let mut grad_acc = Vector::zero();\n\n    let grad_idx_base = batch * grad.stride(0) + channel * grad.stride(3);\n    let ind_idx_base = batch * indices.stride(0) + channel * indices.stride(3);\n\n    for oh in oh_start..oh_end {\n        for ow in ow_start..ow_end {\n            let grad_index =\n                grad_idx_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);\n            let indices_index =\n                ind_idx_base + oh as usize * indices.stride(1) + ow as usize * indices.stride(2);\n            let index_max = Vector::<u32, N>::cast_from(indices[indices_index / vector_size]);\n\n            grad_acc += select_many(\n                index_max.equal(Vector::cast_from(index_current)),\n                grad[grad_index / vector_size],\n                Vector::zero(),\n            );\n        }\n    }\n\n    let index_output = batch * output.stride(0)\n        + ih * output.stride(1)\n        + iw * output.stride(2)\n        + channel * output.stride(3);\n\n    output[index_output / output.vector_size()] = grad_acc;\n}\n\n#[cube]\nfn loop_ranges(\n    ih: i32,\n    iw: i32,\n    grad_h: u32,\n    grad_w: u32,\n    args: &PoolBackwardArgs,\n    #[comptime] kernel_size_0: i32,\n    #[comptime] kernel_size_1: i32,\n) -> (u32, u32, u32, u32) {\n    let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0;\n    let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1;\n\n    let oh_start = clamp_min((ih + args.padding_0 - kms_0) / args.stride_0, 0) as u32;\n    let ow_start = clamp_min((iw + args.padding_1 - kms_1) / args.stride_1, 0) as u32;\n    let oh_end = clamp_max(clamp_min(kms_0, 0) as u32 + oh_start, grad_h - 1) + 1;\n    let ow_end = clamp_max(clamp_min(kms_1, 0) as u32 + ow_start, grad_w - 1) + 1;\n\n    (oh_start, oh_end, ow_start, ow_end)\n}\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn max_pool2d_with_indices_backward<R: CubeRuntime>(\n    x: CubeTensor<R>,\n    grad: CubeTensor<R>,\n    indices: CubeTensor<R>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    _ceil_mode: bool,\n) -> CubeTensor<R> {\n    let [batches, channels, height, width] = x.meta.shape().dims();\n\n    let grad = into_contiguous_aligned(permute_nchw_to_nhwc(grad));\n    let indices = into_contiguous_aligned(permute_nchw_to_nhwc(indices));\n\n    let vector_size = if grad.meta.strides()[3] == indices.meta.strides()[3] {\n        max_vector_size(&grad)\n    } else {\n        1\n    };\n\n    let out_shape = Shape::new([batches, height, width, channels]);\n    let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype);\n\n    let working_units = output.meta.num_elements() / vector_size as usize;\n    let cube_dim = CubeDim::new(&x.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim);\n    let indices_dtype = indices.dtype;\n    let x_dtype = x.dtype;\n\n    unsafe {\n        max_pool2d_with_indices_backward_kernel::launch_unchecked(\n            &output.client,\n            cube_count,\n            cube_dim,\n            address_type!(grad, indices, output),\n            vector_size,\n            grad.into_tensor_arg(),\n            indices.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n            shape_divmod(&output),\n            working_units,\n            PoolBackwardArgsLaunch::new(\n                stride[0] as i32,\n                stride[1] as i32,\n                dilation[0] as i32,\n                dilation[1] as i32,\n                padding[0] as i32,\n                padding[1] as i32,\n            ),\n            kernel_size[0] as i32,\n            kernel_size[1] as i32,\n            [x_dtype.into(), indices_dtype.into()],\n        )\n    };\n\n    permute_nhwc_to_nchw(output)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/mod.rs",
    "content": "mod adaptive_avg_pool2d;\nmod adaptive_avg_pool2d_backward;\nmod avg_pool2d;\nmod avg_pool2d_backward;\nmod max_pool2d;\nmod max_pool2d_backward;\n\npub(super) mod pool2d;\n\npub(crate) use adaptive_avg_pool2d::*;\npub(crate) use adaptive_avg_pool2d_backward::*;\npub(crate) use avg_pool2d::*;\npub(crate) use avg_pool2d_backward::*;\npub(crate) use max_pool2d::*;\npub(crate) use max_pool2d_backward::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/pool/pool2d.rs",
    "content": "use core::hash::Hash;\nuse cubecl::{\n    prelude::*,\n    std::{\n        FastDivmod,\n        tensor::{\n            View,\n            launch::ViewArg,\n            layout::fixed_dim::{FixedDimLayout, FixedDimLayoutLaunch},\n        },\n    },\n};\n\nuse crate::{CubeRuntime, kernel::utils::decompose_linear, tensor::CubeTensor};\n\npub trait Pool2dDirectStrategyFamily: Send + Sync + 'static {\n    type Indices<N: Size>: LaunchArg;\n    type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;\n    type Pool2d<T: Numeric, N: Size>: Pool2dDirectStrategy<T, N, Config = Self::Config, Indices = Self::Indices<N>>;\n}\n\npub(super) type Position = (usize, usize, usize, usize);\n\n#[cube]\npub(crate) trait Pool2dDirectStrategy<T: Numeric, N: Size>: Send + Sync + 'static {\n    type Accumulator: CubeType;\n    type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq;\n\n    type Indices: LaunchArg;\n\n    fn initialize(#[comptime] config: &Self::Config) -> Self::Accumulator;\n\n    fn accumulate(\n        #[comptime] config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        index: usize,\n        result: Vector<T, N>,\n    );\n\n    /// Count a position within the kernel window (for avg_pool count_include_pad).\n    /// Called for each position in the kernel window with the current ih/iw coordinates.\n    /// Only avg_pool uses this; max_pool implements as no-op.\n    fn count_position(\n        #[comptime] config: &Self::Config,\n        accumulator: &mut Self::Accumulator,\n        ih: u32,\n        iw: u32,\n    );\n\n    fn store(\n        #[comptime] config: &Self::Config,\n        position: Position,\n        output: &mut View<Vector<T, N>, Position, ReadWrite>,\n        output_indices: &mut Self::Indices,\n        accumulator: Self::Accumulator,\n    );\n}\n\n#[derive(CubeLaunch, CubeType)]\npub struct Pool2dDirectArgs {\n    pub strides_0: u32,\n    pub strides_1: u32,\n    pub dilation_0: u32,\n    pub dilation_1: u32,\n    pub padding_0: u32,\n    pub padding_1: u32,\n}\n\n#[cube(launch, address_type = \"dynamic\")]\npub fn pool2d_direct<E: Numeric, N: Size, S: Pool2dDirectStrategyFamily>(\n    input: &Tensor<Vector<E, N>>,\n    output: &mut View<Vector<E, N>, Position, ReadWrite>,\n    indices: &mut S::Indices<N>,\n    out_shape: Sequence<FastDivmod<usize>>,\n    working_units: usize,\n    args: &Pool2dDirectArgs,\n    #[comptime] kernel_size: (u32, u32),\n    #[comptime] config: &S::Config,\n    #[define(E)] _dtype: StorageType,\n) {\n    if ABSOLUTE_POS >= working_units {\n        terminate!();\n    }\n\n    let (_, pos) = decompose_linear(ABSOLUTE_POS * output.vector_size(), &out_shape);\n    let [b, oh, ow, c] = *pos else { unreachable!() };\n\n    let (in_stride_h, in_stride_w) = (input.stride(1), input.stride(2));\n    let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);\n\n    let mut accumulator = S::Pool2d::<E, N>::initialize(config);\n\n    let in_b_off = b * input.stride(0);\n    let in_c_off = c * input.stride(3);\n\n    let border_bottom = in_h + args.padding_0;\n    let border_right = in_w + args.padding_1;\n\n    for kh in 0..kernel_size.0 {\n        let ih = oh as u32 * args.strides_0 + kh * args.dilation_0;\n        let within_padding_h = ih >= args.padding_0 && ih < border_bottom;\n\n        for kw in 0..kernel_size.1 {\n            let iw = ow as u32 * args.strides_1 + kw * args.dilation_1;\n            let within_padding_w = iw >= args.padding_1 && iw < border_right;\n\n            // Let strategy handle position counting (only used by avg_pool)\n            S::Pool2d::<E, N>::count_position(config, &mut accumulator, ih, iw);\n\n            // Only accumulate values from valid input positions\n            if within_padding_h && within_padding_w {\n                let ih_pad = ih - args.padding_0;\n                let iw_pad = iw - args.padding_1;\n\n                let in_h_off = ih_pad as usize * in_stride_h;\n                let in_w_off = iw_pad as usize * in_stride_w;\n\n                let index_input = in_b_off + in_c_off + in_h_off + in_w_off;\n\n                S::Pool2d::<E, N>::accumulate(\n                    config,\n                    &mut accumulator,\n                    ih_pad as usize * in_w as usize + iw_pad as usize,\n                    input[index_input / input.vector_size()],\n                );\n            }\n        }\n    }\n\n    S::Pool2d::<E, N>::store(config, (b, oh, ow, c), output, indices, accumulator);\n}\n\npub(super) fn view4d<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    vector_size: VectorSize,\n) -> ViewArg<Position, R> {\n    let shape = tensor.meta.shape();\n    let shape = (shape[0], shape[1], shape[2], shape[3]);\n    let binding = tensor.binding();\n    let layout = FixedDimLayoutLaunch::<Position, R>::from_shape_handle_unchecked(\n        &binding,\n        shape,\n        vector_size,\n    );\n    let buffer = binding.into_tensor_arg();\n    ViewArg::new_tensor::<FixedDimLayout<Position>>(buffer, layout)\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/prng/bernoulli.rs",
    "content": "use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::{DType, Shape};\n\n/// Pseudo-random generator with bernoulli distribution\npub fn random_bernoulli<R: CubeRuntime>(\n    shape: Shape,\n    device: &R::Device,\n    probability: f32,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n    let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);\n\n    cubek::random::random_bernoulli(&client, probability, output.clone().binding(), dtype.into())\n        .expect(\"Kernel to never fail\");\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/prng/mod.rs",
    "content": "mod bernoulli;\nmod normal;\nmod uniform;\n\npub use bernoulli::*;\npub use normal::*;\npub use uniform::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/prng/normal.rs",
    "content": "use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::{DType, Shape};\n\n/// Pseudo-random generator with uniform distribution\npub fn random_normal<R: CubeRuntime>(\n    shape: Shape,\n    device: &R::Device,\n    mean: f32,\n    std: f32,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n    let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);\n\n    cubek::random::random_normal(&client, mean, std, output.clone().binding(), dtype.into())\n        .expect(\"Kernel to never fail\");\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/prng/uniform.rs",
    "content": "use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::{DType, Shape, TensorMetadata};\n\n/// Pseudo-random generator with uniform distribution\npub fn random_uniform<R: CubeRuntime>(\n    shape: Shape,\n    device: &R::Device,\n    lower_bound: f32,\n    upper_bound: f32,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n    let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);\n\n    cubek::random::random_uniform(\n        &client,\n        lower_bound,\n        upper_bound,\n        output.clone().binding(),\n        dtype.into(),\n    )\n    .expect(\"Kernel to never fail\");\n\n    output\n}\n\n/// Pseudo-random generator for uniform distribution, based on\n/// another tensor.\npub fn random_like_uniform<R: CubeRuntime>(\n    tensor: &CubeTensor<R>,\n    lower_bound: f32,\n    upper_bound: f32,\n    dtype: DType,\n) -> CubeTensor<R> {\n    random_uniform(\n        tensor.shape(),\n        &tensor.device,\n        lower_bound,\n        upper_bound,\n        dtype,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/quantization/dequantize.rs",
    "content": "use crate::tensor::CubeTensor;\nuse crate::{CubeRuntime, ops::numeric::empty_device_dtype};\nuse burn_backend::{DType, TensorMetadata};\n\n/// Convert the tensor back to a higher precision data type.\npub fn dequantize<R>(tensor: CubeTensor<R>, dtype: DType) -> CubeTensor<R>\nwhere\n    R: CubeRuntime,\n{\n    let scheme = match tensor.dtype {\n        DType::QFloat(scheme) => scheme,\n        _ => return tensor,\n    };\n\n    let output = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        tensor.shape(),\n        dtype,\n    );\n    let (values, params) = tensor.quantized_handles().unwrap();\n\n    cubek::quantization::dequantize::launch_ref(\n        &output.client,\n        values.binding(),\n        output.clone().binding(),\n        params.binding(),\n        &scheme,\n        dtype.into(),\n    )\n    .expect(\"Kernel to never fail\");\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/quantization/mod.rs",
    "content": "mod dequantize;\nmod quantize;\n\npub use dequantize::*;\npub use quantize::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/quantization/quantize.rs",
    "content": "use crate::CubeRuntime;\nuse crate::{ops::empty_qtensor_optimized, tensor::CubeTensor};\nuse burn_backend::{TensorMetadata, quantization::QuantScheme};\n\n/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.\npub fn quantize<R>(\n    tensor: CubeTensor<R>,\n    scheme: &QuantScheme,\n    scale: CubeTensor<R>,\n) -> CubeTensor<R>\nwhere\n    R: CubeRuntime,\n{\n    let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device);\n    let (out_values, out_params) = output.clone().quantized_handles().unwrap();\n    let dtype = tensor.dtype;\n\n    cubek::quantization::quantize::launch_ref(\n        &output.client,\n        tensor.binding(),\n        out_values.binding(),\n        scale.binding(),\n        out_params.binding(),\n        scheme,\n        dtype.into(),\n    )\n    .expect(\"Kernel to never fail\");\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/reduce/base.rs",
    "content": "#[cfg(feature = \"autotune\")]\nuse super::{autotune_reduce, autotune_sum};\nuse crate::{\n    CubeRuntime,\n    ops::numeric::{empty_device_contiguous_dtype, zeros_client},\n    tensor::CubeTensor,\n};\nuse burn_backend::{DType, TensorMetadata};\nuse burn_std::Metadata;\nuse cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType};\nuse cubek::reduce::{\n    ReduceDtypes, ReduceError, ReduceStrategy,\n    components::instructions::ReduceOperationConfig,\n    launch::{RoutineStrategy, VectorizationStrategy},\n    routines::{BlueprintStrategy, unit::UnitStrategy},\n    shared_sum,\n};\nuse serde::{Deserialize, Serialize};\n\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\n/// Autotune key representative of sum versions\npub struct SumAutotuneKey {\n    /// The type of the tensor\n    dtype: burn_backend::DType,\n    /// The anchored length of the tensor\n    #[autotune(anchor)]\n    length: usize,\n}\n\n/// Check if the client supports atomic add for the given element type.\nfn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {\n    client\n        .properties()\n        .type_usage(StorageType::Atomic(dtype.into()))\n        .contains(TypeUsage::AtomicAdd)\n}\n\n/// [Sum](sum) with fallback when `client` doesn't support atomic add for the type `E`.\npub fn sum_fallback<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    mut strategy: SumStrategy,\n) -> Result<CubeTensor<R>, ReduceError> {\n    // Early check before creating output and fallback\n    if matches!(strategy, SumStrategy::OneShot(_))\n        && !supports_atomic_add(&tensor.client, tensor.dtype)\n    {\n        strategy = SumStrategy::Chained(Default::default());\n    }\n    sum(tensor, strategy)\n}\n\n/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return\n/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`.\n///\n/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.\n///\n/// Return an error if the `client` doesn't support atomic add for the type `E`.\npub fn sum<Run: CubeRuntime>(\n    tensor: CubeTensor<Run>,\n    strategy: SumStrategy,\n) -> Result<CubeTensor<Run>, ReduceError> {\n    let client = tensor.client.clone();\n    let device = tensor.device.clone();\n\n    match strategy {\n        SumStrategy::OneShot(cube_count) => {\n            let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);\n            let dtype = tensor.dtype;\n\n            shared_sum::<Run>(\n                &client,\n                tensor.binding(),\n                output.clone().binding(),\n                cube_count,\n                dtype.into(),\n            )?;\n\n            Ok(output)\n        }\n        SumStrategy::Chained(strategy) => {\n            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)\n        }\n        #[cfg(feature = \"autotune\")]\n        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),\n    }\n}\n\n/// Select a strategy to perform a sum.\npub enum SumStrategy {\n    /// Run a single kernel with many cubes working in parallel to sum all elements.\n    /// The provided value is the number of elements summed per unit (up-to-rounding )\n    OneShot(u32),\n    /// Use multiple kernels\n    Chained(KernelReduceStrategy),\n    /// Use autotune to find the best cube count given the hardware and the input.\n    #[cfg(feature = \"autotune\")]\n    Autotune,\n}\n\nimpl Default for SumStrategy {\n    fn default() -> Self {\n        #[cfg(feature = \"autotune\")]\n        return Self::Autotune;\n\n        #[cfg(not(feature = \"autotune\"))]\n        return Self::OneShot(4);\n    }\n}\n\n/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).\n///\n/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.\n///\n/// If there is no error, the output is a tensor with decreasing strides\n/// where the shape of reduced dim is set to 1 but all shape are similar to the input.\npub fn reduce<Run: CubeRuntime>(\n    mut tensor: CubeTensor<Run>,\n    output_dtype: Option<DType>,\n    strategy: KernelReduceStrategy,\n    config: ReduceOperationConfig,\n) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {\n    // In practice, it looks like starting by the axis with the smallest shape\n    // and going in increasing order lead to the fastest calculation.\n    let sorted_axis = argsort(tensor.meta.shape());\n    for axis in sorted_axis {\n        tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;\n    }\n    // reshape to scalar tensor\n    *tensor.meta = Metadata::new([1], [1]);\n    Ok(tensor)\n}\n\nfn argsort(shape: &[usize]) -> Vec<usize> {\n    let mut indices = (0..shape.len()).collect::<Vec<_>>();\n    indices.sort_by_key(|&i| &shape[i]);\n    indices\n}\n\n/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).\n///\n/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.\n/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.\n///\n/// If there is no error, the output is a tensor with decreasing strides\n/// where the shape of reduced dim is set to 1 but all shape are similar to the input.\npub fn reduce_dim<Run: CubeRuntime>(\n    input: CubeTensor<Run>,\n    output_dtype: Option<DType>,\n    dim: usize,\n    strategy: KernelReduceStrategy,\n    config: ReduceOperationConfig,\n) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {\n    debug_assert!(\n        !matches!(\n            config,\n            ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin\n        ) || output_dtype.is_some(),\n        \"The `output_dtype` has to be `Some` only when the `config` is `ArgMax` or `ArgMin`.\n        \"\n    );\n\n    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));\n    let client = input.client.clone();\n    let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(\n        cubek::reduce::ReduceError::InvalidAxis {\n            axis: dim,\n            rank: input.meta.num_dims(),\n        },\n    )?;\n\n    let result = match strategy {\n        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(\n            &client,\n            input.binding(),\n            output.clone().binding(),\n            dim,\n            ReduceStrategy {\n                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),\n                vectorization: VectorizationStrategy {\n                    parallel_output_vectorization: false,\n                },\n            },\n            config,\n            dtypes,\n        ),\n        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(\n            &client,\n            input.binding(),\n            output.clone().binding(),\n            dim,\n            strategy,\n            config,\n            dtypes,\n        ),\n        #[cfg(feature = \"autotune\")]\n        KernelReduceStrategy::Autotune => {\n            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);\n            Ok(())\n        }\n    };\n    result.map(|_| output)\n}\n\n/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input`\n/// or return `None` if `axis` is out-of-bound.\npub fn init_reduce_output<Run: CubeRuntime>(\n    input: &CubeTensor<Run>,\n    dim: usize,\n    dtypes: &ReduceDtypes,\n) -> Option<CubeTensor<Run>> {\n    (dim < input.meta.num_dims()).then(|| {\n        let mut shape_out = input.shape();\n        shape_out[dim] = 1;\n        empty_device_contiguous_dtype(\n            input.client.clone(),\n            input.device.clone(),\n            shape_out,\n            dtypes.output.elem_type().into(),\n        )\n    })\n}\n\n/// Select a strategy to perform a reduction.\n#[derive(Clone, Debug)]\npub enum KernelReduceStrategy {\n    /// Use a best-effort strategy based on the hardware capacity.\n    /// This differs from Autotune as it doesn't try and compare many strategies to select the best.\n    Unspecified,\n    /// Fix the exact strategy for the reduction.\n    Specific(cubek::reduce::launch::ReduceStrategy),\n    /// Use autotune to find the best strategy given the hardware and the inputs.\n    #[cfg(feature = \"autotune\")]\n    Autotune,\n}\n\nimpl Default for KernelReduceStrategy {\n    fn default() -> Self {\n        #[cfg(feature = \"autotune\")]\n        return Self::Autotune;\n\n        #[cfg(not(feature = \"autotune\"))]\n        return Self::Unspecified;\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/reduce/mod.rs",
    "content": "mod base;\n#[cfg(feature = \"autotune\")]\nmod tune;\n\npub use base::*;\n#[cfg(feature = \"autotune\")]\npub use tune::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/reduce/tune.rs",
    "content": "#![allow(missing_docs)]\n\nuse super::SumAutotuneKey;\nuse crate::{CubeAutotuneKey, CubeRuntime, CubeTuneId, tensor::CubeTensor};\nuse cubecl::{\n    client::ComputeClient,\n    tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},\n};\nuse cubek::reduce::{\n    ReduceDtypes, ReduceStrategy,\n    components::instructions::ReduceOperationConfig,\n    launch::{RoutineStrategy, VectorizationStrategy, tune_key::ReduceAutotuneKey},\n    routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},\n};\n\n/// Executes autotune on reduce operations.\npub fn autotune_reduce<R: CubeRuntime>(\n    client: &ComputeClient<R>,\n    input: CubeTensor<R>,\n    output: CubeTensor<R>,\n    axis: usize,\n    config: ReduceOperationConfig,\n    dtypes: ReduceDtypes,\n) {\n    use reduce_ops::*;\n\n    static TUNER: LocalTuner<ReduceAutotuneKey, CubeTuneId> = local_tuner!(\"reduce-dim\");\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 2;\n        const PRIORITY_MIN: i8 = 1;\n        const PRIORITY_SKIP: i8 = -1;\n\n        let mut set = TunableSet::new(create_key::<R>, reduce_input_gen::<R>);\n\n        let default_group =\n            TuneGroup::<ReduceAutotuneKey>::new(\"default_reduce\", |_key| PRIORITY_MAX);\n        let vectorized_parallel_group =\n            TuneGroup::<ReduceAutotuneKey>::new(\"vectorized_parallel_reduce\", |key| {\n                if key.axis_is_contiguous {\n                    PRIORITY_MAX\n                } else {\n                    // We disable the tunable with the setting [vector_size.parallel_output_vectorization]\n                    // when the reduce isn't parallel, since it would duplicate tunables.\n                    PRIORITY_SKIP\n                }\n            });\n\n        enum ReduceProps {\n            GreatWithLowReduceCount,\n            GreatWithHighReduceCount,\n            Balanced,\n        }\n\n        for (vectorization, vector_size_ident) in [\n            (\n                VectorizationStrategy {\n                    parallel_output_vectorization: true,\n                },\n                \"_vectorized_parallel_reduce\",\n            ),\n            (\n                VectorizationStrategy {\n                    parallel_output_vectorization: false,\n                },\n                \"\",\n            ),\n        ] {\n            for (name, routine, props) in [\n                (\n                    \"unit\",\n                    RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),\n                    ReduceProps::GreatWithHighReduceCount,\n                ),\n                (\n                    \"plane\",\n                    RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {\n                        independent: true,\n                    })),\n                    ReduceProps::Balanced,\n                ),\n                (\n                    \"cube\",\n                    RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {\n                        use_planes: true,\n                    })),\n                    ReduceProps::GreatWithLowReduceCount,\n                ),\n            ] {\n                let name = format!(\"{name}{vector_size_ident}\");\n                let mut tunable = Tunable::new(\n                    name,\n                    move |(input, output, axis, config, dtypes): (\n                        CubeTensor<R>,\n                        CubeTensor<R>,\n                        usize,\n                        ReduceOperationConfig,\n                        ReduceDtypes,\n                    )| {\n                        let strategy = ReduceStrategy {\n                            routine: routine.clone(),\n                            vectorization,\n                        };\n                        cubek::reduce::reduce::<R>(\n                            &output.client,\n                            input.binding(),\n                            output.clone().binding(),\n                            axis,\n                            strategy,\n                            config,\n                            dtypes,\n                        )\n                        .map_err(|e| format!(\"{e}\"))\n                    },\n                );\n                if vectorization.parallel_output_vectorization {\n                    tunable = tunable.group(&vectorized_parallel_group, |_| PRIORITY_MAX);\n                }\n\n                tunable = tunable.group(&default_group, move |key| match props {\n                    ReduceProps::GreatWithLowReduceCount => {\n                        if key.vector_count < 128 {\n                            PRIORITY_MAX\n                        } else {\n                            // When you have a high level of vector to reduce, it is normally\n                            // better to use another routine.\n                            PRIORITY_MIN\n                        }\n                    }\n                    ReduceProps::GreatWithHighReduceCount => {\n                        if key.vector_count > 64 {\n                            PRIORITY_MAX\n                        } else {\n                            // Bellow 64 it is normally better to use another routine\n                            PRIORITY_MIN\n                        }\n                    }\n                    ReduceProps::Balanced => PRIORITY_MAX,\n                });\n                set = set.with(tunable);\n            }\n        }\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&input.client, &input.device),\n        client,\n        tunables,\n        (input, output, axis, config, dtypes),\n    );\n}\n\npub(crate) fn create_key<Run: CubeRuntime>(\n    input: &CubeTensor<Run>,\n    output: &CubeTensor<Run>,\n    axis: &usize,\n    _config: &ReduceOperationConfig,\n    dtypes: &ReduceDtypes,\n) -> ReduceAutotuneKey {\n    let elem_input = input.dtype.into();\n    let elem_output = output.dtype.into();\n    let elem_acc = dtypes.accumulation.elem_type();\n\n    ReduceAutotuneKey::generate(\n        elem_input,\n        elem_output,\n        elem_acc,\n        input.meta.shape(),\n        input.meta.strides()[*axis] == 1,\n        *axis,\n    )\n}\n\nmod reduce_ops {\n    #![allow(missing_docs)]\n\n    use cubek::reduce::ReduceDtypes;\n\n    use super::*;\n\n    pub(crate) fn reduce_input_gen<Run: CubeRuntime>(\n        _key: &ReduceAutotuneKey,\n        input: &CubeTensor<Run>,\n        output: &CubeTensor<Run>,\n        dim: &usize,\n        config: &ReduceOperationConfig,\n        dtypes: &ReduceDtypes,\n    ) -> (\n        CubeTensor<Run>,\n        CubeTensor<Run>,\n        usize,\n        ReduceOperationConfig,\n        ReduceDtypes,\n    ) {\n        (input.clone(), output.copy(), *dim, *config, *dtypes)\n    }\n}\n\n/// Executes autotune on reduce operations.\n#[cfg(feature = \"autotune\")]\npub fn autotune_sum<R: CubeRuntime>(\n    client: &ComputeClient<R>,\n    input: CubeTensor<R>,\n) -> CubeTensor<R> {\n    use sum_ops::*;\n\n    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!(\"autotune-sum\");\n\n    let tunables = TUNER.init(|| {\n        TunableSet::new(create_key_sum::<R>, sum_input_gen::<R>)\n            .with(Tunable::new(\"sum_chained\", sum_chained::<R>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 1>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 2>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 4>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 8>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 16>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 32>))\n            .with(Tunable::new(\"sum_one_shot\", sum_one_shot::<R, 64>))\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&input.client, &input.device),\n        client,\n        tunables,\n        input,\n    )\n}\n\npub(crate) fn create_key_sum<Run: CubeRuntime>(input: &CubeTensor<Run>) -> CubeAutotuneKey {\n    CubeAutotuneKey::Sum(SumAutotuneKey::generate(input))\n}\n\nimpl SumAutotuneKey {\n    #[allow(unused)]\n    pub(crate) fn generate<Run: CubeRuntime>(input: &CubeTensor<Run>) -> Self {\n        let dtype = input.dtype;\n        let length = input.meta.num_elements();\n        Self::new(dtype, length)\n    }\n}\nmod sum_ops {\n    #![allow(missing_docs)]\n    use crate::ops::numeric::zeros_client;\n\n    use super::*;\n\n    pub(crate) fn sum_input_gen<Run: CubeRuntime>(\n        _key: &CubeAutotuneKey,\n        input: &CubeTensor<Run>,\n    ) -> CubeTensor<Run> {\n        input.clone()\n    }\n\n    pub(crate) fn sum_one_shot<Run: CubeRuntime, const C: u32>(\n        input: CubeTensor<Run>,\n    ) -> Result<CubeTensor<Run>, String> {\n        let client = input.client.clone();\n        let device = input.device.clone();\n        let output = zeros_client(client.clone(), device, [1].into(), input.dtype);\n        let dtype = input.dtype;\n\n        cubek::reduce::shared_sum::<Run>(\n            &output.client,\n            input.binding(),\n            output.clone().binding(),\n            C,\n            dtype.into(),\n        )\n        .map_err(|e| e.to_string())\n        .map(|_| output)\n    }\n\n    #[cfg(feature = \"autotune\")]\n    pub(crate) fn sum_chained<Run: CubeRuntime>(\n        input: CubeTensor<Run>,\n    ) -> Result<CubeTensor<Run>, String> {\n        crate::kernel::reduce::reduce::<Run>(\n            input,\n            None,\n            crate::kernel::reduce::KernelReduceStrategy::Autotune,\n            cubek::reduce::components::instructions::ReduceOperationConfig::Sum,\n        )\n        .map_err(|e| e.to_string())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/unary_float.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\npub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync {\n    type Options: LaunchArg;\n    type Unary<F: Float, N: Size>: FloatUnaryOp<F, N, Options = Self::Options>;\n}\n\n#[cube]\npub(crate) trait FloatUnaryOp<F: Float, N: Size>: 'static + Send + Sync {\n    type Options: LaunchArg;\n\n    fn execute(input: Vector<F, N>, options: &Self::Options) -> Vector<F, N>;\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn unary_float<F: Float, N: Size, O: FloatUnaryOpFamily>(\n    input: &LinearView<Vector<F, N>>,\n    output: &mut LinearView<Vector<F, N>, ReadWrite>,\n    options: &O::Options,\n    #[define(F)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = O::Unary::<F, N>::execute(input[ABSOLUTE_POS], options);\n}\n\npub(crate) fn launch_unary_float<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>\nwhere\n    // Magic fix for lifetime, the closure is supposed to capture everything required to create the\n    // argument.\n    for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<O::Options, R>,\n    R: CubeRuntime,\n    O: FloatUnaryOpFamily,\n{\n    let vector_size = max_vector_size(&tensor);\n\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n    let dtype = tensor.dtype;\n\n    unsafe {\n        if tensor.can_mut() && tensor.is_nonoverlapping() {\n            unary_float::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                tensor.as_linear_view_alias(0),\n                args(&()),\n                dtype.into(),\n            );\n\n            tensor\n        } else {\n            let output = empty_device_dtype(\n                tensor.client.clone(),\n                tensor.device.clone(),\n                tensor.shape(),\n                tensor.dtype,\n            );\n\n            unary_float::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                output.clone().into_linear_view(),\n                args(&()),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n\n/// Use comptime enum to implement all unary operations that don't have any input argument in the\n/// kernel definition.\npub(crate) mod unary_basic {\n    use cubecl::num_traits::{One, Zero};\n\n    use super::*;\n\n    pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>\n    where\n        R: CubeRuntime,\n        for<'a> Args: FnOnce(&'a ()) -> BasicFloatUnaryKind,\n    {\n        launch_unary_float::<R, BasicFloatUnary, _>(tensor, |input| {\n            BasicFloatUnaryOptionsLaunch::new(args(input))\n        })\n    }\n\n    #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]\n    pub enum BasicFloatUnaryKind {\n        Exp,\n        Log,\n        Log1p,\n        Sqrt,\n        Abs,\n        Sign,\n        ArcCos,\n        ArcCosh,\n        ArcSin,\n        ArcSinh,\n        ArcTan,\n        ArcTanh,\n        Cos,\n        Cosh,\n        Sin,\n        Sinh,\n        Tan,\n        Tanh,\n        Round,\n        Floor,\n        Ceil,\n        Trunc,\n        Erf,\n        Recip,\n    }\n\n    #[derive(CubeLaunch, CubeType)]\n    struct BasicFloatUnaryOptions {\n        #[cube(comptime)]\n        kind: BasicFloatUnaryKind,\n    }\n    struct BasicFloatUnary;\n\n    #[cube]\n    impl<F: Float, N: Size> FloatUnaryOp<F, N> for BasicFloatUnary {\n        type Options = BasicFloatUnaryOptions;\n\n        fn execute(input: Vector<F, N>, options: &Self::Options) -> Vector<F, N> {\n            match comptime![options.kind] {\n                BasicFloatUnaryKind::Exp => Vector::exp(input),\n                BasicFloatUnaryKind::Log => Vector::ln(input),\n                BasicFloatUnaryKind::Log1p => Vector::log1p(input),\n                BasicFloatUnaryKind::Sqrt => Vector::sqrt(input),\n                BasicFloatUnaryKind::Abs => Vector::abs(input),\n                BasicFloatUnaryKind::Sign => {\n                    let zero = Vector::zero();\n                    let one = Vector::one();\n                    let minus_one = Vector::new(F::new(-1.0));\n\n                    let is_positive = input.greater_than(zero);\n                    let is_negative = input.less_than(zero);\n                    let sign = select_many(is_negative, minus_one, zero);\n\n                    select_many(is_positive, one, sign)\n                }\n                BasicFloatUnaryKind::Cos => Vector::cos(input),\n                BasicFloatUnaryKind::Sin => Vector::sin(input),\n                BasicFloatUnaryKind::Tan => Vector::tan(input),\n                BasicFloatUnaryKind::Cosh => Vector::cosh(input),\n                BasicFloatUnaryKind::Sinh => Vector::sinh(input),\n                BasicFloatUnaryKind::Tanh => Vector::tanh(input),\n                BasicFloatUnaryKind::Round => Vector::round(input),\n                BasicFloatUnaryKind::Floor => Vector::floor(input),\n                BasicFloatUnaryKind::Ceil => Vector::ceil(input),\n                BasicFloatUnaryKind::Trunc => Vector::trunc(input),\n                BasicFloatUnaryKind::Erf => Vector::erf(input),\n                BasicFloatUnaryKind::Recip => Vector::recip(input),\n                BasicFloatUnaryKind::ArcCos => Vector::acos(input),\n                BasicFloatUnaryKind::ArcCosh => Vector::acosh(input),\n                BasicFloatUnaryKind::ArcSin => Vector::asin(input),\n                BasicFloatUnaryKind::ArcSinh => Vector::asinh(input),\n                BasicFloatUnaryKind::ArcTan => Vector::atan(input),\n                BasicFloatUnaryKind::ArcTanh => Vector::atanh(input),\n            }\n        }\n    }\n\n    impl FloatUnaryOpFamily for BasicFloatUnary {\n        type Options = BasicFloatUnaryOptions;\n        type Unary<F: Float, N: Size> = Self;\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/unary_int.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\npub(crate) trait IntUnaryOpFamily: 'static + Send + Sync {\n    type Options: LaunchArg;\n    type Unary<I: Int, N: Size>: IntUnaryOp<I, N, Options = Self::Options>;\n}\n\n#[cube]\npub(crate) trait IntUnaryOp<I: Scalar, N: Size>: 'static + Send + Sync {\n    type Options: LaunchArg;\n\n    fn execute(input: Vector<I, N>, options: &Self::Options) -> Vector<I, N>;\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn unary_int<I: Int, N: Size, O: IntUnaryOpFamily>(\n    input: &LinearView<Vector<I, N>>,\n    output: &mut LinearView<Vector<I, N>, ReadWrite>,\n    options: &O::Options,\n    #[define(I)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = O::Unary::<I, N>::execute(input[ABSOLUTE_POS], options);\n}\n\npub(crate) fn launch_unary_int<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>\nwhere\n    for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<O::Options, R>,\n    R: CubeRuntime,\n    O: IntUnaryOpFamily,\n{\n    let vector_size = max_vector_size(&tensor);\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n    let dtype = tensor.dtype;\n\n    unsafe {\n        if tensor.can_mut() && tensor.is_nonoverlapping() {\n            unary_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                tensor.as_linear_view_alias(0),\n                args(&()),\n                dtype.into(),\n            );\n\n            tensor\n        } else {\n            let output = empty_device_dtype(\n                tensor.client.clone(),\n                tensor.device.clone(),\n                tensor.shape(),\n                tensor.dtype,\n            );\n\n            unary_int::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                output.clone().into_linear_view(),\n                args(&()),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n\npub(crate) mod unary_basic_int {\n\n    use cubecl::num_traits::{One, Zero};\n\n    use super::*;\n\n    pub(crate) fn launch<R, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>\n    where\n        R: CubeRuntime,\n        for<'a> Args: FnOnce(&'a ()) -> BasicIntUnaryKind,\n    {\n        launch_unary_int::<R, BasicIntUnary, _>(tensor, |input| {\n            BasicIntUnaryOptionsLaunch::new(args(input))\n        })\n    }\n\n    #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]\n    pub enum BasicIntUnaryKind {\n        BitwiseNot,\n        Sign,\n    }\n\n    #[derive(CubeLaunch, CubeType)]\n    struct BasicIntUnaryOptions {\n        #[cube(comptime)]\n        kind: BasicIntUnaryKind,\n    }\n    struct BasicIntUnary;\n\n    #[cube]\n    impl<I: Int, N: Size> IntUnaryOp<I, N> for BasicIntUnary {\n        type Options = BasicIntUnaryOptions;\n\n        fn execute(input: Vector<I, N>, options: &Self::Options) -> Vector<I, N> {\n            match comptime![options.kind] {\n                BasicIntUnaryKind::BitwiseNot => !input,\n                BasicIntUnaryKind::Sign => {\n                    let zero = Vector::zero();\n                    let one = Vector::one();\n                    let minus_one = Vector::new(I::new(-1));\n\n                    let is_positive = input.greater_than(zero);\n                    let is_negative = input.less_than(zero);\n                    let sign = select_many(is_negative, minus_one, zero);\n\n                    select_many(is_positive, one, sign)\n                }\n            }\n        }\n    }\n\n    impl IntUnaryOpFamily for BasicIntUnary {\n        type Options = BasicIntUnaryOptions;\n        type Unary<I: Int, N: Size> = Self;\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/unary_numeric.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::address_type,\n    ops::{max_vector_size, numeric::empty_device_dtype},\n    tensor::CubeTensor,\n};\nuse burn_backend::TensorMetadata;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};\n\npub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync {\n    type Options: LaunchArg;\n    type Unary<T: Numeric, N: Size>: NumericUnaryOp<T, N, Options = Self::Options>;\n}\n\n#[cube]\npub(crate) trait NumericUnaryOp<T: Scalar, N: Size>: 'static + Send + Sync {\n    type Options: LaunchArg;\n\n    fn execute(input: Vector<T, N>, options: &Self::Options) -> Vector<T, N>;\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub(crate) fn unary_numeric<T: Numeric, N: Size, O: NumericUnaryOpFamily>(\n    input: &LinearView<Vector<T, N>>,\n    output: &mut LinearView<Vector<T, N>, ReadWrite>,\n    options: &O::Options,\n    #[define(T)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    output[ABSOLUTE_POS] = O::Unary::<T, N>::execute(input[ABSOLUTE_POS], options);\n}\n\npub(crate) fn launch_unary_numeric<R, O, Args>(tensor: CubeTensor<R>, args: Args) -> CubeTensor<R>\nwhere\n    // Magic fix for lifetime, the closure is supposed to capture everything required to create the\n    // argument.\n    for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<O::Options, R>,\n    R: CubeRuntime,\n    O: NumericUnaryOpFamily,\n{\n    let vector_size = max_vector_size(&tensor);\n    let client = tensor.client.clone();\n    let num_elems = tensor.meta.num_elements();\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&tensor.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);\n    let dtype = tensor.dtype;\n\n    unsafe {\n        if tensor.can_mut() && tensor.is_nonoverlapping() {\n            unary_numeric::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor),\n                vector_size,\n                tensor.clone().into_linear_view(),\n                tensor.as_linear_view_alias(0),\n                args(&()),\n                dtype.into(),\n            );\n\n            tensor\n        } else {\n            let output = empty_device_dtype(\n                tensor.client.clone(),\n                tensor.device.clone(),\n                tensor.shape(),\n                tensor.dtype,\n            );\n\n            unary_numeric::launch_unchecked::<O, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                address_type!(tensor, output),\n                vector_size,\n                tensor.into_linear_view(),\n                output.clone().into_linear_view(),\n                args(&()),\n                dtype.into(),\n            );\n\n            output\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/kernel/utils.rs",
    "content": "use burn_backend::Shape;\nuse cubecl::prelude::SequenceArg;\nuse cubecl::{\n    ir::{UIntKind, VectorSize},\n    prelude::*,\n    std::{\n        FastDivmod, FastDivmodInt,\n        tensor::layout::linear::{LinearLayoutLaunch, LinearViewLayoutLaunch},\n    },\n};\n\nuse crate::{CubeRuntime, tensor::CubeTensor};\n\npub fn shape_divmod<R: CubeRuntime>(tensor: &CubeTensor<R>) -> SequenceArg<R, FastDivmod<usize>> {\n    let mut arg = SequenceArg::new();\n    for dim in tensor.meta.shape().iter() {\n        arg.push(*dim);\n    }\n    arg\n}\n\npub fn linear_layout<R: CubeRuntime>(\n    tensor: &CubeTensor<R>,\n    vector_size: VectorSize,\n) -> LinearLayoutLaunch<R> {\n    LinearLayoutLaunch::from_shape_strides(\n        tensor.meta.shape().clone(),\n        tensor.meta.strides().clone(),\n        // Don't care about type size, only vector size\n        Type::new(UIntKind::U32.into()).with_vector_size(vector_size),\n        LinearViewLayoutLaunch::new(),\n    )\n}\n\npub fn split_dim<R: CubeRuntime>(\n    mut tensor: CubeTensor<R>,\n    dim: usize,\n    shape: &[usize],\n) -> CubeTensor<R> {\n    let mut stride = tensor.meta.strides()[dim];\n    tensor.meta.remove(dim);\n\n    for size in shape.iter().rev() {\n        tensor.meta.insert(dim, *size, stride);\n        stride *= size;\n    }\n\n    tensor\n}\n\npub fn broadcast_shape<R: CubeRuntime>(tensors: &[&CubeTensor<R>]) -> Shape {\n    let rank = tensors[0].meta.num_dims();\n    debug_assert!(\n        tensors.iter().all(|it| it.meta.num_dims() == rank),\n        \"Broadcast tensors must have the same rank\"\n    );\n\n    let dims = (0..rank).map(|dim| {\n        let max = tensors.iter().map(|it| it.meta.shape()[dim]).max();\n        let max = max.unwrap_or(1);\n        debug_assert!(\n            tensors\n                .iter()\n                .all(|it| it.meta.shape()[dim] == max || it.meta.shape()[dim] == 1),\n            \"Broadcast dims must be size 1\"\n        );\n        max\n    });\n\n    Shape::from(dims)\n}\n\npub fn broadcast_strides<R: CubeRuntime>(\n    reference: &CubeTensor<R>,\n    tensor: &CubeTensor<R>,\n) -> SequenceArg<R, usize> {\n    if reference.meta.shape() != tensor.meta.shape() {\n        tensor\n            .meta\n            .strides()\n            .iter()\n            .zip(\n                tensor\n                    .meta\n                    .shape()\n                    .iter()\n                    .zip(reference.meta.shape().iter()),\n            )\n            .map(|(stride, (shape, ref_shape))| if *shape == *ref_shape { *stride } else { 0 })\n            .collect()\n    } else {\n        tensor.meta.strides().iter().copied().collect()\n    }\n}\n\n#[cube]\npub(crate) fn decompose_linear<I: FastDivmodInt>(\n    pos: I,\n    shape: &Sequence<FastDivmod<I>>,\n) -> (I, Sequence<I>) {\n    let rank = comptime![shape.len()];\n    let mut offs = pos;\n    let mut out = Sequence::new();\n\n    #[unroll]\n    for i in 0..rank {\n        let dim = comptime![rank - i - 1];\n        let (rem, offs_local) = shape.index(dim).div_mod(offs);\n        out.push(offs_local);\n        offs = rem;\n    }\n\n    (offs, out.rev())\n}\n\npub(crate) trait RequiredAddrType {\n    fn required_address_type(&self) -> AddressType;\n}\n\nimpl<R: CubeRuntime> RequiredAddrType for CubeTensor<R> {\n    fn required_address_type(&self) -> AddressType {\n        self.required_address_type()\n    }\n}\nimpl<R: CubeRuntime> RequiredAddrType for Option<CubeTensor<R>> {\n    fn required_address_type(&self) -> AddressType {\n        self.as_ref()\n            .map(|it| it.required_address_type())\n            .unwrap_or_default()\n    }\n}\n\nmacro_rules! address_type {\n    ($($tensor: tt),*) => {\n        [$($crate::kernel::utils::RequiredAddrType::required_address_type(&$tensor)),*]\n        .into_iter()\n        .max()\n        .unwrap_or_default()\n    };\n}\npub(crate) use address_type;\n"
  },
  {
    "path": "crates/burn-cubecl/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! Burn JIT Backend\n\n#[macro_use]\nextern crate derive_new;\nextern crate alloc;\n\n/// Utilities for implementing JIT kernels\npub mod ops;\n\n/// Kernel module\npub mod kernel;\n/// Tensor module.\npub mod tensor;\n\n/// Elements for JIT backend\npub mod element;\n\nuse cubecl::{CubeTask, Runtime};\npub use element::{BoolElement, CubeElement, FloatElement, IntElement};\n\nmod backend;\n\npub use backend::*;\n\n// Re-export cubecl.\npub use cubecl;\n\nmod tune_key;\npub use tune_key::CubeAutotuneKey;\n\n#[cfg(any(feature = \"fusion\", test))]\n/// Module for interacting with fusion\npub mod fusion;\n\n#[cfg(feature = \"template\")]\n/// Module for compiling custom non-jit kernels\npub mod template;\n\n/// Just-in-Time runtime extending the [cube runtime](Runtime).\npub trait CubeRuntime: Runtime<Device = Self::CubeDevice, Server = Self::CubeServer> {\n    /// The device that should also implement [burn_backend::backend::DeviceOps].\n    type CubeDevice: burn_backend::DeviceOps;\n    /// The cube server with the [CubeAutotuneKey].\n    type CubeServer: cubecl::server::ComputeServer<Kernel = Box<dyn CubeTask<Self::Compiler>>>;\n}\n\npub use cubecl::CubeTuneId;\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/activation.rs",
    "content": "use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};\nuse burn_backend::ops::ActivationOps;\n\nimpl<R, F, I, BT> ActivationOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/base.rs",
    "content": "use crate::{CubeRuntime, kernel, ops::numeric::empty_device_dtype, tensor::CubeTensor};\nuse burn_backend::{\n    DType, ExecutionError, QTensorPrimitive, Shape, TensorData,\n    quantization::{QuantLevel, QuantStore, params_shape},\n};\nuse burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};\nuse burn_std::{\n    Metadata, strides,\n    tensor::{ReshapeAction, contiguous_strides, reshape_action},\n};\nuse cubecl::{ir::VectorSize, server::CopyDescriptor};\nuse cubecl::{quant::scheme::BlockSize, tensor_vector_size_parallel};\n\npub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {\n    let client = R::client(device);\n    let alloc = client.create_tensor(data.bytes, data.shape.clone(), data.dtype.size());\n    let shape: Shape = (&data.shape).into();\n    CubeTensor::new(\n        client,\n        alloc.memory,\n        Metadata::new(shape, alloc.strides),\n        device.clone(),\n        data.dtype,\n    )\n}\n\npub(crate) async fn into_data<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n) -> Result<TensorData, ExecutionError> {\n    let tensor = kernel::into_contiguous_aligned(tensor);\n\n    let elem_size = tensor.elem_size();\n    let shape = tensor.meta.shape().clone();\n    let strides = tensor.meta.strides().clone();\n    let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size);\n    let bytes = tensor\n        .client\n        .read_one_tensor_async(binding)\n        .await\n        .map_err(|err| ExecutionError::WithContext {\n            reason: format!(\"{err}\"),\n        })?;\n\n    Ok(TensorData::from_bytes(\n        bytes,\n        tensor.meta.shape.clone(),\n        tensor.dtype,\n    ))\n}\n\n/// Read data from a `CubeTensor` synchronously\n#[allow(unused, reason = \"useful for debugging kernels\")]\npub fn into_data_sync<R: CubeRuntime>(tensor: CubeTensor<R>) -> TensorData {\n    burn_std::future::block_on(into_data(tensor)).unwrap()\n}\n\n#[cfg_attr(\n    feature = \"tracing\",\n    tracing::instrument(level = \"trace\", skip(tensor, device))\n)]\npub(crate) fn to_device<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    device: &R::Device,\n) -> CubeTensor<R> {\n    if &tensor.device == device {\n        return tensor;\n    }\n\n    let tensor = kernel::into_contiguous_aligned(tensor);\n    let client = R::client(device);\n    tensor.to_client(client, device.clone())\n}\n\npub(crate) fn empty<R: CubeRuntime>(\n    shape: Shape,\n    device: &R::Device,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n    let alloc = client.empty_tensor(shape.clone(), dtype.size());\n\n    CubeTensor::new(\n        client,\n        alloc.memory,\n        Metadata::new(shape, alloc.strides),\n        device.clone(),\n        dtype,\n    )\n}\n\npub(crate) fn swap_dims<R: CubeRuntime>(\n    mut tensor: CubeTensor<R>,\n    dim1: usize,\n    dim2: usize,\n) -> CubeTensor<R> {\n    tensor.meta.swap(dim1, dim2);\n\n    if let DType::QFloat(scheme) = tensor.dtype\n        && let QuantLevel::Block(block_size) = scheme.level\n    {\n        let rank = tensor.rank();\n        let qparams = tensor.qparams.as_mut().unwrap();\n        let mut block_size = block_size.to_dim_vec(rank);\n        block_size.swap(dim1, dim2);\n\n        // Truncate unit dims from the start\n        let block_size = BlockSize::new_trim(block_size);\n        if block_size.len() > BlockSize::MAX_DIMS {\n            panic!(\"Swapped block size would exceed max dims\");\n        }\n\n        qparams.scales.metadata.swap(dim1, dim2);\n\n        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size)))\n    }\n\n    if let DType::QFloat(scheme) = &mut tensor.dtype\n        && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) =\n            &mut scheme.store\n    {\n        let rank = tensor.meta.num_dims();\n\n        if *packed_dim == rank - dim1 - 1 {\n            *packed_dim = rank - dim2 - 1;\n        } else if *packed_dim == rank - dim2 - 1 {\n            *packed_dim = rank - dim1 - 1;\n        }\n    }\n\n    tensor\n}\n\n/// Permute a tensor's dimensions\npub fn permute<R: CubeRuntime>(mut tensor: CubeTensor<R>, axes: &[usize]) -> CubeTensor<R> {\n    tensor.meta.permute(axes).unwrap();\n\n    if let DType::QFloat(scheme) = tensor.dtype\n        && let QuantLevel::Block(block_size) = scheme.level\n    {\n        let rank = tensor.rank();\n        let qparams = tensor.qparams.as_mut().unwrap();\n\n        let mut block_size = block_size.to_dim_vec(rank);\n        block_size = axes.iter().map(|i| block_size[*i]).collect();\n\n        // Truncate unit dims from the start\n        let block_size = block_size\n            .into_iter()\n            .skip_while(|it| *it == 1)\n            .collect::<Vec<_>>();\n        if block_size.len() > BlockSize::MAX_DIMS {\n            panic!(\"Swapped block size would exceed max dims\");\n        }\n\n        qparams.scales.metadata.permute(axes).unwrap();\n\n        tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size)))\n    }\n\n    if let DType::QFloat(scheme) = &mut tensor.dtype\n        && let QuantStore::PackedU32(packed_dim) = &mut scheme.store\n    {\n        let rank = tensor.meta.num_dims();\n        let new_pos = axes\n            .iter()\n            .position(|axis| *axis == rank - *packed_dim - 1)\n            .unwrap_or(0);\n        *packed_dim = rank - new_pos - 1;\n    }\n\n    tensor\n}\n\n/// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent\npub fn permute_nchw_to_nhwc<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {\n    let rank = tensor.meta.num_dims();\n    let c_dim = 1;\n\n    let mut dims = vec![0];\n    dims.extend(2..rank);\n    dims.push(c_dim);\n\n    permute(tensor, &dims)\n}\n\n/// Permute a shape's dimensions from NCHW to NHWC, or the N-dimensional equivalent\npub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape {\n    let rank = shape.num_dims();\n    let c_dim = 1;\n\n    let mut dims = vec![0];\n    dims.extend(2..rank);\n    dims.push(c_dim);\n\n    shape.permuted(&dims).expect(\"Shape permute should succeed\")\n}\n\n/// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent\npub fn permute_nhwc_to_nchw<R: CubeRuntime>(tensor: CubeTensor<R>) -> CubeTensor<R> {\n    let rank = tensor.meta.num_dims();\n    let c_dim = rank - 1;\n\n    let mut dims = vec![0];\n    dims.push(c_dim);\n    dims.extend(1..c_dim);\n\n    permute(tensor, &dims)\n}\n\n/// Permute a shape's dimensions from NHWC to NCHW, or the N-dimensional equivalent\npub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape {\n    let rank = shape.num_dims();\n    let c_dim = rank - 1;\n\n    let mut dims = vec![0];\n    dims.push(c_dim);\n    dims.extend(1..c_dim);\n\n    shape.permuted(&dims).expect(\"Shape permute should succeed\")\n}\n\npub(crate) fn expand<R: CubeRuntime>(tensor: CubeTensor<R>, target_shape: Shape) -> CubeTensor<R> {\n    let ndims_in = tensor.meta.shape().num_dims();\n    let ndims_out = target_shape.num_dims();\n\n    // Initialize new strides with zeros\n    let mut new_strides = strides![0usize; ndims_out];\n\n    // Calculate the difference in dimensions\n    let dim_diff = ndims_out.saturating_sub(ndims_in);\n\n    // Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones\n    let mut tensor_dim_iter = tensor.meta.shape().iter().rev();\n    for i in (0..ndims_out).rev() {\n        if i >= dim_diff {\n            if let Some(&tensor_dim) = tensor_dim_iter.next() {\n                if tensor_dim == target_shape[i] || tensor_dim == 1 {\n                    // Copy stride for non-broadcast dimensions or set to 0 for broadcast ones\n                    new_strides[i] = if tensor_dim == target_shape[i] {\n                        tensor.meta.strides()[i - dim_diff]\n                    } else {\n                        0\n                    };\n                } else {\n                    // Error handling: Dimension mismatch for broadcasting\n                    panic!(\n                        \"Dimension mismatch: cannot broadcast dimension {tensor_dim} of tensor to target shape\"\n                    );\n                }\n            } else {\n                // If the input tensor has fewer dimensions, treat missing dimensions as 1\n                // and set stride to 0 (broadcasting)\n                new_strides[i] = 0;\n            }\n        } else {\n            // For extra dimensions in the target shape, set stride to 0 (broadcasting)\n            new_strides[i] = 0;\n        }\n    }\n\n    // Extra check to ensure block scales must be properly handled once they're added\n    if tensor.qparams.is_some() {\n        match tensor.scheme().level {\n            QuantLevel::Tensor => {}\n            QuantLevel::Block(_) => todo!(),\n        }\n    }\n\n    CubeTensor {\n        client: tensor.client.clone(),\n        device: tensor.device.clone(),\n        meta: Box::new(Metadata::new(target_shape, new_strides)),\n        handle: tensor.handle.clone(),\n        dtype: tensor.dtype,\n        qparams: tensor.qparams.clone(),\n    }\n}\n\n/// Reshape a jit tensor to a new shape\npub fn reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {\n    let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape);\n\n    match analysis {\n        ReshapeAction::UpdateStrides { strides } => {\n            *tensor.meta = Metadata::new(shape, strides);\n            return tensor;\n        }\n        ReshapeAction::NoChange => return tensor,\n        ReshapeAction::Recompute => (),\n    }\n\n    let out = empty_device_dtype(\n        tensor.client.clone(),\n        tensor.device.clone(),\n        shape,\n        tensor.dtype,\n    );\n\n    cubecl::std::tensor::copy_into(\n        &out.client,\n        tensor.binding(),\n        out.clone().binding(),\n        out.dtype.into(),\n    );\n\n    out\n}\n\n/// Reshape a jit tensor to a new shape\npub fn q_reshape<R: CubeRuntime>(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {\n    let scheme = *tensor.scheme();\n\n    let shape_values = {\n        let rank = shape.num_dims();\n        let mut shape = shape.clone();\n        shape[rank - 1] = shape[rank - 1].div_ceil(scheme.num_quants());\n        shape\n    };\n    let shape_scales = params_shape(&shape, scheme.level);\n    let (values, scales) = tensor.quantized_handles().unwrap();\n\n    let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values);\n    let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales);\n\n    match (analysis_values, analysis_scales) {\n        (\n            ReshapeAction::UpdateStrides { strides },\n            ReshapeAction::UpdateStrides {\n                strides: scales_strides,\n            },\n        ) => {\n            let qparams = tensor.qparams.as_mut().unwrap();\n\n            *tensor.meta = Metadata::new(shape, strides);\n            qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);\n        }\n        (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => {\n            *tensor.meta = Metadata::new(shape, strides);\n        }\n        (\n            ReshapeAction::NoChange,\n            ReshapeAction::UpdateStrides {\n                strides: scales_strides,\n            },\n        ) => {\n            let qparams = tensor.qparams.as_mut().unwrap();\n\n            qparams.scales.metadata = Metadata::new(shape_scales, scales_strides);\n        }\n        (ReshapeAction::NoChange, ReshapeAction::NoChange) => {}\n        _ => {\n            tensor = kernel::into_contiguous(tensor);\n            *tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values));\n\n            let qparams = tensor.qparams.as_mut().unwrap();\n\n            let strides = contiguous_strides(&shape_scales);\n            qparams.scales.metadata = Metadata::new(shape_scales, strides);\n        }\n    }\n\n    tensor\n}\n\npub(crate) fn max_vector_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> VectorSize {\n    tensor_vector_size_parallel(\n        tensor.client.io_optimized_vector_sizes(tensor.dtype.size()),\n        tensor.meta.shape(),\n        tensor.meta.strides(),\n        tensor.meta.num_dims() - 1,\n    )\n}\n\npub(crate) fn max_vector_size_many<R: CubeRuntime>(\n    tensors: &[&CubeTensor<R>],\n    axis: usize,\n) -> VectorSize {\n    let vec = tensors\n        .iter()\n        .map(|tensor| {\n            tensor_vector_size_parallel(\n                tensor.client.io_optimized_vector_sizes(tensor.dtype.size()),\n                tensor.meta.shape(),\n                tensor.meta.strides(),\n                axis,\n            )\n        })\n        .min();\n\n    vec.unwrap_or(0)\n}\n\n/// Unfold windows along a dimension.\n///\n/// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n/// where windows are advanced by `step` at each index.\n///\n/// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n///\n/// The new view will have the unfolded dimension replaced by two dimensions;\n/// one in the position of the original dimension, with size equal to the number of windows,\n/// and one appended to the right-most position, with size equal to `size`.\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n/// * `dim` - the dimension to unfold.\n/// * `size` - the size of each unfolded window.\n/// * `step` - the step between each window.\n///\n/// # Returns\n///\n/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.\npub fn unfold<R: CubeRuntime>(\n    tensor: CubeTensor<R>,\n    dim: usize,\n    size: usize,\n    step: usize,\n) -> CubeTensor<R> {\n    let shape = calculate_unfold_shape(tensor.shape(), dim, size, step);\n\n    let d_stride = tensor.meta.strides()[dim];\n    let mut strides = tensor.meta.strides.clone();\n    strides[dim] = step * d_stride;\n    strides.push(d_stride);\n\n    CubeTensor {\n        meta: Box::new(Metadata::new(shape, strides)),\n        client: tensor.client.clone(),\n        handle: tensor.handle.clone(),\n        device: tensor.device.clone(),\n        dtype: tensor.dtype,\n        qparams: tensor.qparams.clone(),\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/bool_tensor.rs",
    "content": "use crate::{\n    CubeBackend, CubeRuntime, FloatElement, IntElement,\n    element::{BoolElement, bool_dtype},\n    kernel::{self, AndOp, OrOp},\n};\nuse burn_backend::{\n    ExecutionError, Slice,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, Device, FloatTensor, IntTensor},\n};\nuse burn_backend::{Scalar, Shape, TensorData};\nuse burn_std::{BoolStore, DType};\nuse cubecl::prelude::InputScalar;\nuse std::ops::Range;\n\nuse super::{expand, numeric, permute, unfold};\n\nimpl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        super::empty(shape, device, bool_dtype::<BT>())\n    }\n\n    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        numeric::zeros(device.clone(), shape, bool_dtype::<BT>())\n    }\n\n    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        numeric::ones(device.clone(), shape, bool_dtype::<BT>())\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {\n        super::into_data(tensor).await\n    }\n\n    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {\n        let bool_dtype = bool_dtype::<BT>();\n        // TODO: remove once backends no longer rely on generics for default elem types\n        let data = match (data.dtype, bool_dtype) {\n            (DType::U8, DType::Bool(BoolStore::U8)) | (DType::U32, DType::Bool(BoolStore::U32)) => {\n                // No-op, but change dtype to bool w/ storage type\n                data.convert_dtype(bool_dtype)\n            }\n            (DType::U8, DType::U8) | (DType::U32, DType::U32) => data,\n            other => unimplemented!(\"Unsupported dtype for `bool_from_data` {other:?}\"),\n        };\n        super::from_data(data, device)\n    }\n\n    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        kernel::bool_cast::<R, I>(tensor)\n    }\n\n    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {\n        tensor.device.clone()\n    }\n\n    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {\n        super::to_device(tensor, device)\n    }\n\n    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        super::reshape(tensor, shape)\n    }\n\n    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {\n        // Check if all steps are 1\n        let all_steps_one = slices.iter().all(|info| info.step == 1);\n\n        if all_steps_one {\n            // Use optimized slice for step=1\n            let simple_ranges: Vec<Range<usize>> = slices\n                .iter()\n                .enumerate()\n                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))\n                .collect();\n\n            kernel::slice(tensor, &simple_ranges)\n        } else {\n            // Use slice with steps kernel\n            kernel::slice_with_steps(tensor, slices)\n        }\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        ranges: &[Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::slice_assign(tensor, ranges, value)\n    }\n\n    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        kernel::equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        kernel::equal_elem(\n            tensor,\n            InputScalar::new(BT::false_val(), bool_dtype::<BT>()),\n            bool_dtype::<BT>(),\n        )\n    }\n\n    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        kernel::launch_binop::<R, AndOp>(lhs, rhs)\n    }\n\n    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        kernel::launch_binop::<R, OrOp>(lhs, rhs)\n    }\n\n    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {\n        kernel::bool_cast::<R, F>(tensor)\n    }\n\n    fn bool_swap_dims(mut tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {\n        tensor.meta.swap(dim1, dim2);\n\n        tensor\n    }\n\n    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {\n        kernel::repeat_dim(tensor, dim, times)\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        permute(tensor, axes)\n    }\n\n    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn bool_select(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::select(tensor, dim, indices)\n    }\n\n    fn bool_select_or(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::select_assign(tensor, dim, indices, value, true)\n    }\n\n    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        kernel::flip(tensor, axes, bool_dtype::<BT>())\n    }\n\n    fn bool_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::mask_where_auto(tensor, mask, value, bool_dtype::<BT>())\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        let dtype = tensor.dtype;\n        kernel::mask_fill_auto(tensor, mask, InputScalar::new(value, dtype), dtype)\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::gather(dim, tensor, indices)\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        kernel::scatter(dim, tensor, indices, value, true)\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), dtype)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/int_tensor.rs",
    "content": "use self::unary_basic_int::BasicIntUnaryKind;\n\nuse super::{expand, numeric, permute, unfold};\nuse crate::element::bool_dtype;\nuse crate::kernel::{\n    BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, launch_binop_int,\n    launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,\n};\nuse crate::{\n    CubeBackend, CubeRuntime, FloatElement, IntElement,\n    kernel::{\n        self,\n        matmul::{MatmulStrategy, matmul},\n    },\n};\nuse crate::{\n    element::BoolElement,\n    kernel::prng::{random_bernoulli, random_normal, random_uniform},\n};\nuse burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};\nuse burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};\nuse burn_backend::{Distribution, ElementConversion, Shape, TensorData};\nuse burn_backend::{ExecutionError, Scalar};\nuse cubecl::frontend::Numeric;\nuse cubecl::prelude::*;\nuse cubek::reduce::components::instructions::ReduceOperationConfig;\nuse std::ops::Range;\n\nimpl<R, F, I, BT> IntTensorOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let dtype = dtype.into();\n        super::empty(shape, device, dtype)\n    }\n\n    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {\n        super::into_data(tensor).await\n    }\n\n    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {\n        match data.dtype {\n            DType::I64\n            | DType::I32\n            | DType::I16\n            | DType::I8\n            | DType::U64\n            | DType::U32\n            | DType::U16\n            | DType::U8 => super::from_data(data, device),\n            _ => unimplemented!(\"Unsupported dtype for `int_from_data`\"),\n        }\n    }\n\n    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {\n        tensor.device.clone()\n    }\n\n    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {\n        super::to_device(tensor, device)\n    }\n\n    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        super::reshape(tensor, shape)\n    }\n\n    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {\n        // Check if all steps are 1\n        let all_steps_one = slices.iter().all(|info| info.step == 1);\n\n        if all_steps_one {\n            // Use optimized slice for step=1\n            let simple_ranges: Vec<Range<usize>> = slices\n                .iter()\n                .enumerate()\n                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))\n                .collect();\n\n            kernel::slice(tensor, &simple_ranges)\n        } else {\n            // Use slice with steps kernel\n            kernel::slice_with_steps(tensor, slices)\n        }\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<Self>,\n        ranges: &[Slice],\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::slice_assign(tensor, ranges, value)\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::mask_where_auto(tensor, mask, value, bool_dtype::<BT>())\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> IntTensor<Self> {\n        let dtype = tensor.dtype;\n        kernel::mask_fill_auto(\n            tensor,\n            mask,\n            InputScalar::new(value, dtype),\n            bool_dtype::<BT>(),\n        )\n    }\n\n    fn int_gather(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::gather(dim, tensor, indices)\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::scatter(dim, tensor, indices, value, false)\n    }\n\n    fn int_select(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::select(tensor, dim, indices)\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        kernel::select_assign(tensor, dim, indices, value, false)\n    }\n\n    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        kernel::equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        kernel::greater(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        kernel::greater_equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        kernel::lower(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        kernel::lower_equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::add(lhs, rhs)\n    }\n\n    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::sub(lhs, rhs)\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::mul(lhs, rhs)\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::div(lhs, rhs)\n    }\n\n    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::remainder(lhs, rhs)\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let dtype = dtype.into();\n        numeric::zeros(device.clone(), shape, dtype)\n    }\n\n    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let dtype = dtype.into();\n        numeric::ones(device.clone(), shape, dtype)\n    }\n\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<Self>,\n        dtype: IntDType,\n    ) -> IntTensor<Self> {\n        let dtype: DType = dtype.into();\n        let client = R::client(device);\n        numeric::full_device_dtype(\n            client,\n            shape,\n            device.clone(),\n            InputScalar::new(fill_value, dtype),\n            dtype,\n        )\n    }\n\n    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        reduce::sum_fallback(tensor, Default::default()).unwrap()\n    }\n\n    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Sum,\n        )\n        .unwrap()\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        reduce::reduce(\n            tensor,\n            None,\n            Default::default(),\n            ReduceOperationConfig::Prod,\n        )\n        .unwrap()\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Prod,\n        )\n        .unwrap()\n    }\n\n    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()\n    }\n\n    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Max,\n        )\n        .unwrap()\n    }\n\n    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        reduce::reduce(\n            tensor,\n            None,\n            Default::default(),\n            ReduceOperationConfig::MaxAbs,\n        )\n        .unwrap()\n    }\n\n    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::MaxAbs,\n        )\n        .unwrap()\n    }\n\n    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()\n    }\n\n    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Min,\n        )\n        .unwrap()\n    }\n\n    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Mean,\n        )\n        .unwrap()\n    }\n\n    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        numeric::cumsum(tensor, dim)\n    }\n\n    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        numeric::cumprod(tensor, dim)\n    }\n\n    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        numeric::cummin(tensor, dim)\n    }\n\n    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        numeric::cummax(tensor, dim)\n    }\n\n    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let dtype = tensor.dtype;\n        reduce::reduce_dim(\n            tensor,\n            Some(dtype),\n            dim,\n            Default::default(),\n            ReduceOperationConfig::ArgMax,\n        )\n        .unwrap()\n    }\n\n    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let dtype = tensor.dtype;\n        reduce::reduce_dim(\n            tensor,\n            Some(dtype),\n            dim,\n            Default::default(),\n            ReduceOperationConfig::ArgMin,\n        )\n        .unwrap()\n    }\n\n    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {\n        let dtype = tensor.dtype;\n        kernel::clamp(\n            tensor,\n            InputScalar::new(min, dtype),\n            InputScalar::new(max, dtype),\n        )\n    }\n\n    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        struct Abs;\n\n        #[cube]\n        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Abs {\n            type Options = ();\n\n            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {\n                Vector::abs(input)\n            }\n        }\n\n        impl NumericUnaryOpFamily for Abs {\n            type Options = ();\n            type Unary<T: Numeric, N: Size> = Self;\n        }\n\n        launch_unary_numeric::<R, Abs, _>(tensor, |_| ())\n    }\n\n    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::Sign)\n    }\n\n    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {\n        kernel::cast(tensor, F::dtype())\n    }\n\n    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        tensor.meta.swap(dim1, dim2);\n\n        tensor\n    }\n\n    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {\n        kernel::repeat_dim(tensor, dim, times)\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> IntTensor<Self> {\n        let dtype = IntElem::<Self>::dtype();\n        match distribution {\n            Distribution::Default => random_uniform(shape, device, 0., 255., dtype),\n            Distribution::Uniform(low, high) => {\n                random_uniform(shape, device, low.elem(), high.elem(), dtype)\n            }\n            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),\n            Distribution::Normal(mean, std) => {\n                random_normal(shape, device, mean.elem(), std.elem(), dtype)\n            }\n        }\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        permute(tensor, axes)\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        kernel::flip(tensor, axes, bool_dtype::<BT>())\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::bitwise_and(lhs, rhs)\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::bitwise_and_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::bitwise_or(lhs, rhs)\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::bitwise_or_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        numeric::bitwise_xor(lhs, rhs)\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::bitwise_xor_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_basic_int::launch::<R, _>(tensor, |_| BasicIntUnaryKind::BitwiseNot)\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        launch_binop_int::<R, kernel::BitwiseShlOp>(lhs, rhs)\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        launch_scalar_binop_int::<R, BitwiseShlOp>(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        launch_binop_int::<R, BitwiseShrOp>(lhs, rhs)\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let dtype = lhs.dtype;\n        launch_scalar_binop_int::<R, BitwiseShrOp>(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        kernel::cast(tensor, dtype.into())\n    }\n\n    fn int_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/mod.rs",
    "content": "mod activation;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\n\npub(crate) mod base;\npub use base::*;\npub use qtensor::*;\n\n/// Numeric utility functions for jit backends\npub mod numeric;\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/module.rs",
    "content": "use crate::{\n    CubeBackend, CubeRuntime, FloatElement, IntElement,\n    element::BoolElement,\n    kernel::{self, conv::ConvTranspose2dStrategy},\n};\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};\nuse burn_backend::{\n    TensorMetadata,\n    ops::{\n        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,\n        DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,\n    },\n};\n\nimpl<R, F, I, BT> ModuleOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    fn conv1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_forward::<R, 1>(x, weight, bias, options, Default::default()).unwrap()\n    }\n\n    fn conv1d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_data_backward(\n            output_grad,\n            weight,\n            x.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn conv1d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_weight_backward::<R, 1>(\n            x,\n            output_grad,\n            weight.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn conv2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_forward::<R, 2>(x, weight, bias, options, Default::default()).unwrap()\n    }\n\n    fn conv2d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_data_backward(\n            output_grad,\n            weight,\n            x.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn conv2d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_weight_backward::<R, 2>(\n            x,\n            output_grad,\n            weight.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::deform_conv2d(x, offset, weight, mask, bias, options).unwrap()\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward(\n            x,\n            offset,\n            weight,\n            mask,\n            bias,\n            output_grad,\n            options,\n        )\n        .unwrap();\n        DeformConv2dBackward::new(x, o, w, m, b)\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_forward::<R, 3>(x, weight, bias, options, Default::default()).unwrap()\n    }\n\n    fn conv3d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_data_backward(\n            output_grad,\n            weight,\n            x.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn conv3d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_weight_backward::<R, 3>(\n            x,\n            output_grad,\n            weight.shape(),\n            options,\n            Default::default(),\n        )\n        .unwrap()\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_transpose2d(x, weight, bias, options, ConvTranspose2dStrategy::default())\n            .unwrap()\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        kernel::conv::conv_transpose3d(x, weight, bias, options).expect(\"Kernel to never fail\")\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        kernel::pool::avg_pool2d(\n            x,\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n        )\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        kernel::pool::avg_pool2d_backward(\n            x,\n            grad,\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n        )\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        kernel::pool::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        let (output, indices) = kernel::pool::max_pool2d_with_indices(\n            x,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            I::dtype(),\n        );\n\n        MaxPool2dWithIndices::new(output, indices)\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool2dBackward<Self> {\n        MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward(\n            x,\n            output_grad,\n            indices,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n        ))\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        kernel::pool::adaptive_avg_pool2d(x, output_size)\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::pool::adaptive_avg_pool2d_backward(x, grad)\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        kernel::interpolate::interpolate(x, output_size, options)\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        kernel::interpolate::interpolate_backward(x, grad, output_size, options)\n    }\n\n    fn attention(\n        query: FloatTensor<Self>,\n        key: FloatTensor<Self>,\n        value: FloatTensor<Self>,\n        mask: Option<BoolTensor<Self>>,\n        attn_bias: Option<FloatTensor<Self>>,\n        options: AttentionModuleOptions,\n    ) -> FloatTensor<Self> {\n        // Fall back to naive attention for features the flash kernel doesn't support.\n        if attn_bias.is_some() || options.softcap.is_some() || options.scale.is_some() {\n            return burn_backend::ops::attention::attention_fallback::<Self>(\n                query, key, value, mask, attn_bias, options,\n            );\n        }\n\n        kernel::attention::attention(\n            query,\n            key,\n            value,\n            mask,\n            attn_bias,\n            options,\n            Default::default(),\n            None,\n        )\n        .expect(\"Kernel to never fail\")\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/numeric.rs",
    "content": "use crate::{\n    CubeRuntime,\n    kernel::utils::{address_type, shape_divmod},\n};\nuse crate::{element::CubeElement, tensor::CubeTensor};\nuse crate::{\n    kernel::{\n        AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,\n        launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,\n    },\n    ops::max_vector_size,\n};\nuse burn_backend::{DType, Shape, TensorMetadata};\nuse burn_std::Metadata;\nuse cubecl::{calculate_cube_count_elemwise, prelude::*};\nuse cubecl::{client::ComputeClient, server::MemoryLayout};\nuse cubecl::{\n    server::MemoryLayoutDescriptor,\n    std::{FastDivmod, tensor::layout::linear::LinearView},\n};\n\n/// Creates a tensor filled with `value`\npub fn full<R: CubeRuntime, E: CubeElement>(\n    shape: Shape,\n    device: &R::Device,\n    value: E,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n\n    full_client::<R, E>(client, shape, device.clone(), value)\n}\n\n/// Creates a tensor filled with `value`\npub fn full_client<R: CubeRuntime, E: CubeElement>(\n    client: ComputeClient<R>,\n    shape: Shape,\n    device: R::Device,\n    value: E,\n) -> CubeTensor<R> {\n    let dtype = E::dtype();\n    full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype)\n}\n\n/// Creates a tensor filled with `value`\npub fn full_device_dtype<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    shape: Shape,\n    device: R::Device,\n    value: InputScalar,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let empty = empty_device_dtype(client, device, shape, dtype);\n\n    #[cube(launch_unchecked, address_type = \"dynamic\")]\n    pub fn full_kernel<C: Numeric, N: Size>(\n        tensor: &mut LinearView<Vector<C, N>, ReadWrite>,\n        value: InputScalar,\n        #[define(C)] _dtype: StorageType,\n    ) {\n        if !tensor.is_in_bounds(ABSOLUTE_POS) {\n            terminate!();\n        }\n\n        tensor[ABSOLUTE_POS] = Vector::new(value.get::<C>());\n    }\n\n    let num_elems = empty.meta.num_elements();\n    let vector_size = max_vector_size(&empty);\n\n    let working_units = num_elems / vector_size as usize;\n    let cube_dim = CubeDim::new(&empty.client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);\n\n    unsafe {\n        full_kernel::launch_unchecked(\n            &empty.client,\n            cube_count,\n            cube_dim,\n            address_type!(empty),\n            vector_size,\n            empty.clone().into_linear_view(),\n            value,\n            empty.dtype.into(),\n        );\n    }\n\n    empty\n}\n\n/// Creates a tensor filled with zeros\npub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {\n    let client = R::client(&device);\n    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)\n}\n\n/// Creates a tensor filled with ones\npub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {\n    let client = R::client(&device);\n    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)\n}\n\n/// Creates a tensor filled with zeros\npub fn zeros_client<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    shape: Shape,\n    dtype: DType,\n) -> CubeTensor<R> {\n    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)\n}\n\n/// Creates a tensor filled with ones\npub fn ones_client<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    shape: Shape,\n    dtype: DType,\n) -> CubeTensor<R> {\n    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)\n}\n\n/// Create a tensor with uninitialized memory\npub fn empty_device<R: CubeRuntime, E: CubeElement>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    shape: Shape,\n) -> CubeTensor<R> {\n    let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), size_of::<E>());\n\n    CubeTensor::new(\n        client,\n        memory,\n        Metadata::new(shape, strides),\n        device,\n        E::dtype(),\n    )\n}\n\n/// Create a tensor with uninitialized memory\npub fn empty_device_dtype<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    shape: Shape,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), dtype.size());\n\n    CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)\n}\n\n/// Create a contiguous tensor with uninitialized memory\npub fn empty_device_contiguous_dtype<R: CubeRuntime>(\n    client: ComputeClient<R>,\n    device: R::Device,\n    shape: Shape,\n    dtype: DType,\n) -> CubeTensor<R> {\n    let descriptor = MemoryLayoutDescriptor::contiguous(shape.clone(), dtype.size());\n    let MemoryLayout { memory, strides } = client.empty_tensors(vec![descriptor]).remove(0);\n\n    CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)\n}\n\n/// Add two tensors\npub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, AddOp>(lhs, rhs)\n}\n\n/// Add a tensor and a scalar\npub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop::<R, AddOp>(lhs, rhs)\n}\n\n/// Subtract two tensors\npub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, SubOp>(lhs, rhs)\n}\n\n/// Subtract a tensor and a scalar\npub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop::<R, SubOp>(lhs, rhs)\n}\n\n/// Multiply two tensors\npub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, MulOp>(lhs, rhs)\n}\n\n/// Multiply a tensor and a scalar\npub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop::<R, MulOp>(lhs, rhs)\n}\n\n/// Divide two tensors\npub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, DivOp>(lhs, rhs)\n}\n\n/// Divide a tensor by a scalar\npub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop::<R, DivOp>(lhs, rhs)\n}\n\n/// Calculate remainder of two tensors\npub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, RemainderOp>(lhs, rhs)\n}\n\n/// Calculate the remainder of a tensor with a scalar\npub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)\n}\n\n/// Calculate the power of two tensors\npub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop::<R, PowOp>(lhs, rhs)\n}\n\n/// Bitwise and two tensors\npub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)\n}\n\n/// Bitwise and with a scalar\npub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)\n}\n\n/// Bitwise or two tensors\npub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)\n}\n\n/// Bitwise or with a scalar\npub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)\n}\n\n/// Bitwise xor two tensors\npub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {\n    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)\n}\n\n/// Bitwise xor with a scalar\npub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {\n    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)\n}\n\n/// Operation family trait for cumulative operations\npub(crate) trait CumulativeOpFamily: Send + Sync + 'static {\n    type CumulativeOp<C: Numeric>: CumulativeOp<C>;\n}\n\n/// Trait for cumulative operations\n#[cube]\npub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {\n    /// Execute a cumulative operation\n    fn execute(lhs: C, rhs: C) -> C;\n\n    /// Get the initial value for the accumulator\n    fn init_value(first_element: C) -> C;\n}\n\n// Operation types\nstruct SumOp;\nstruct ProdOp;\nstruct MaxOp;\nstruct MinOp;\n\n// Implement CumulativeOpFamily for each operation\nimpl CumulativeOpFamily for SumOp {\n    type CumulativeOp<C: Numeric> = Self;\n}\n\nimpl CumulativeOpFamily for ProdOp {\n    type CumulativeOp<C: Numeric> = Self;\n}\n\nimpl CumulativeOpFamily for MaxOp {\n    type CumulativeOp<C: Numeric> = Self;\n}\n\nimpl CumulativeOpFamily for MinOp {\n    type CumulativeOp<C: Numeric> = Self;\n}\n\n// Implement CumulativeOp for each operation type\n#[cube]\nimpl<N: Numeric> CumulativeOp<N> for SumOp {\n    fn execute(lhs: N, rhs: N) -> N {\n        lhs + rhs\n    }\n\n    fn init_value(_first_element: N) -> N {\n        N::zero()\n    }\n}\n\n#[cube]\nimpl<N: Numeric> CumulativeOp<N> for ProdOp {\n    fn execute(lhs: N, rhs: N) -> N {\n        lhs * rhs\n    }\n\n    fn init_value(_first_element: N) -> N {\n        N::from_int(1)\n    }\n}\n\n#[cube]\nimpl<N: Numeric> CumulativeOp<N> for MaxOp {\n    fn execute(lhs: N, rhs: N) -> N {\n        max(lhs, rhs)\n    }\n\n    fn init_value(first_element: N) -> N {\n        first_element\n    }\n}\n\n#[cube]\nimpl<N: Numeric> CumulativeOp<N> for MinOp {\n    fn execute(lhs: N, rhs: N) -> N {\n        min(lhs, rhs)\n    }\n\n    fn init_value(first_element: N) -> N {\n        first_element\n    }\n}\n\n/// Generic cumulative operation kernel\n///\n/// # Limitations\n///\n/// This is a **naive sequential implementation** along the cumulative dimension:\n/// - Each output element sequentially reads all previous elements along the dimension\n/// - Computational complexity: O(n^2) memory reads where n is the size of the cumulative dimension\n/// - **Performance:** Suitable for small tensors or small dimensions. For large tensors,\n///   performance will degrade significantly compared to an optimized parallel scan algorithm.\n///\n/// # TODO\n///\n/// Implement an efficient GPU-optimized parallel scan algorithm.\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(\n    input: &Tensor<C>,\n    output: &mut LinearView<C, ReadWrite>,\n    shape: Sequence<FastDivmod<usize>>,\n    #[comptime] dim: usize,\n    #[define(C)] _dtype: StorageType,\n) {\n    if !output.is_in_bounds(ABSOLUTE_POS) {\n        terminate!();\n    }\n\n    let rank = comptime![shape.len()];\n    let dim_stride = input.stride(dim);\n\n    let mut remainder = ABSOLUTE_POS;\n    let mut offset = 0;\n    let mut dim_idx = 0;\n\n    #[unroll]\n    for i in 0..shape.len() {\n        let i = comptime![rank - i - 1];\n        let (rem, local_idx) = shape.index(i).div_mod(remainder);\n        remainder = rem;\n        if i == dim {\n            dim_idx = local_idx;\n        } else {\n            offset += local_idx * input.stride(i);\n        }\n    }\n\n    // Read first element\n    let first_read_idx = offset + dim_idx * dim_stride;\n    let first_elem = input[first_read_idx];\n\n    // Initialize accumulator\n    let mut result = O::CumulativeOp::<C>::init_value(first_elem);\n\n    // Accumulate values\n    for i in 0..=dim_idx {\n        let read_idx = offset + i * dim_stride;\n        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);\n    }\n    output[ABSOLUTE_POS] = result;\n}\n\n/// Compute the cumulative sum along a dimension\npub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {\n    cumulative_op::<R, SumOp>(input, dim)\n}\n\n/// Compute the cumulative product along a dimension\npub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {\n    cumulative_op::<R, ProdOp>(input, dim)\n}\n\n/// Compute the cumulative minimum along a dimension\npub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {\n    cumulative_op::<R, MinOp>(input, dim)\n}\n\n/// Compute the cumulative maximum along a dimension\npub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {\n    cumulative_op::<R, MaxOp>(input, dim)\n}\n\n/// Generic cumulative operation function\nfn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(\n    input: CubeTensor<R>,\n    dim: usize,\n) -> CubeTensor<R> {\n    let client = input.client.clone();\n    let device = input.device.clone();\n\n    let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype);\n\n    let num_elems = output.meta.num_elements();\n    let working_units = num_elems;\n    let cube_dim = CubeDim::new(&client, working_units);\n    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);\n    let shape = shape_divmod(&input);\n\n    unsafe {\n        cumulative_kernel::launch_unchecked::<O, R>(\n            &client,\n            cube_count,\n            cube_dim,\n            address_type!(input, output),\n            input.into_tensor_arg(),\n            output.clone().into_linear_view(),\n            shape,\n            dim,\n            output.dtype.into(),\n        );\n    }\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorMetadata,\n    TensorPrimitive,\n    ops::QTensorOps,\n    quantization::{\n        QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue,\n        QuantizationParametersPrimitive, params_shape,\n    },\n    tensor::{Device, FloatElem, FloatTensor, IntTensor, QuantizedTensor},\n};\nuse burn_std::Metadata;\nuse cubecl::server::{MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutStrategy};\nuse cubecl::{e2m1x2, quant::scheme::QuantStore};\n\nuse crate::{\n    CubeBackend, CubeRuntime, FloatElement, IntElement,\n    element::BoolElement,\n    kernel::{self, matmul::MatmulStrategy},\n    tensor::{CubeTensor, QParams},\n};\n\nuse super::{into_data, permute, swap_dims};\n\n/// Create a quantized tensor with packed values (u32).\nfn new_qtensor_optimized<R: CubeRuntime>(\n    data: Bytes,\n    shape: impl Into<Shape>,\n    scheme: QuantScheme,\n    device: &R::Device,\n) -> CubeTensor<R> {\n    new_qtensor(data, shape, scheme, device, MemoryLayoutStrategy::Optimized)\n}\n\n/// Create a quantized tensor with packed values (u32).\nfn new_qtensor<R: CubeRuntime>(\n    data: Bytes,\n    shape: impl Into<Shape>,\n    scheme: QuantScheme,\n    device: &R::Device,\n    kind: MemoryLayoutStrategy,\n) -> CubeTensor<R> {\n    new_quantized(shape, scheme, device, Some(data), kind)\n}\n\n/// Create an empty quantized tensor.\npub fn empty_qtensor_optimized<R: CubeRuntime>(\n    shape: impl Into<Shape>,\n    scheme: QuantScheme,\n    device: &R::Device,\n) -> CubeTensor<R> {\n    empty_qtensor(shape, scheme, device, MemoryLayoutStrategy::Optimized)\n}\n\n/// Create an empty quantized tensor.\npub fn empty_qtensor<R: CubeRuntime>(\n    shape: impl Into<Shape>,\n    scheme: QuantScheme,\n    device: &R::Device,\n    kind: MemoryLayoutStrategy,\n) -> CubeTensor<R> {\n    new_quantized(shape, scheme, device, None, kind)\n}\n\nfn new_quantized<R: CubeRuntime>(\n    shape: impl Into<Shape>,\n    scheme: QuantScheme,\n    device: &R::Device,\n    data: Option<Bytes>,\n    alloc_kind: MemoryLayoutStrategy,\n) -> CubeTensor<R> {\n    let client = R::client(device);\n    let shape: Shape = shape.into();\n    let mut shape_value: Shape = shape.clone();\n\n    let rank = shape.rank();\n    let shape_last = shape[rank - 1];\n    let num_quants = scheme.num_quants();\n\n    let data_size = match scheme.store {\n        QuantStore::PackedU32(_) => {\n            if !shape_last.is_multiple_of(num_quants) {\n                panic!(\"Can't store in u32\")\n            }\n            shape_value[rank - 1] = shape_last.div_ceil(num_quants);\n            size_of::<u32>()\n        }\n        QuantStore::Native => match scheme.value {\n            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {\n                size_of::<i8>()\n            }\n            QuantValue::Q4F\n            | QuantValue::Q4S\n            | QuantValue::Q2F\n            | QuantValue::Q2S\n            | QuantValue::E2M1 => {\n                panic!(\"Can't store native sub-byte values\")\n            }\n        },\n        QuantStore::PackedNative(_) => match scheme.value {\n            QuantValue::E2M1 => size_of::<e2m1x2>(),\n            other => panic!(\"{other:?} doesn't support native packing\"),\n        },\n    };\n\n    let scales_dtype = match scheme.param {\n        QuantParam::F32 => DType::F32,\n        QuantParam::F16 => DType::F16,\n        QuantParam::BF16 => DType::BF16,\n        // Represented by U8 and reinterpreted in the kernel\n        QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,\n    };\n\n    let scales_shape = params_shape(&shape, scheme.level);\n    let data_desc = MemoryLayoutDescriptor::new(alloc_kind, shape_value.clone(), data_size);\n    let scales_desc =\n        MemoryLayoutDescriptor::new(alloc_kind, scales_shape.clone(), scales_dtype.size());\n\n    let mut tensors = match data {\n        Some(data) => {\n            let num_bytes = shape_value.num_elements() * data_size;\n\n            match data.split(num_bytes) {\n                Ok((bytes_data, bytes_scales)) => client\n                    .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]),\n                Err((data, _)) => client.create_tensors_from_slices(vec![\n                    (data_desc, &data[..num_bytes]),\n                    (scales_desc, &data[num_bytes..]),\n                ]),\n            }\n        }\n        None => client.empty_tensors(vec![data_desc, scales_desc]),\n    };\n    let MemoryLayout {\n        memory: scales_handle,\n        strides: scales_strides,\n    } = tensors.remove(1);\n    let MemoryLayout { memory, strides } = tensors.remove(0);\n\n    let scales = QParamTensor {\n        offset_start: scales_handle.offset_start.unwrap_or(0) as usize,\n        offset_end: scales_handle.offset_end.unwrap_or(0) as usize,\n        metadata: Metadata::new(scales_shape, scales_strides),\n        dtype: scales_dtype,\n    };\n    let qparams = QParams { scales };\n\n    CubeTensor::new_quantized(\n        client,\n        memory,\n        shape,\n        device.clone(),\n        strides,\n        DType::QFloat(scheme),\n        qparams,\n    )\n}\n\nimpl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {\n        match data.dtype {\n            DType::QFloat(scheme) => match scheme {\n                QuantScheme {\n                    level: QuantLevel::Tensor | QuantLevel::Block(_),\n                    mode: QuantMode::Symmetric,\n                    value:\n                        QuantValue::Q8F\n                        | QuantValue::Q8S\n                        | QuantValue::Q4F\n                        | QuantValue::Q4S\n                        | QuantValue::Q2F\n                        | QuantValue::Q2S\n                        | QuantValue::E4M3\n                        | QuantValue::E5M2\n                        | QuantValue::E2M1,\n                    ..\n                } => {\n                    // TensorData quantized representation is the same, with multiple quantized values\n                    // packed into u32 and quantization parameters appended to the bytes\n                    new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device)\n                }\n            },\n            _ => panic!(\n                \"Invalid dtype (expected DType::QFloat, got {:?})\",\n                data.dtype\n            ),\n        }\n    }\n\n    // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor)\n\n    fn quantize(\n        tensor: FloatTensor<Self>,\n        scheme: &QuantScheme,\n        qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        kernel::quantization::quantize(tensor, scheme, qparams.scales)\n    }\n\n    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        kernel::quantization::dequantize(tensor, FloatElem::<Self>::dtype())\n    }\n\n    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {\n        tensor.device.clone()\n    }\n\n    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {\n        super::to_device(tensor, device)\n    }\n\n    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        super::q_reshape(tensor, shape)\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        if tensor.qparams.is_none() {\n            return into_data(tensor).await;\n        }\n\n        let (shape, dtype) = (tensor.shape(), tensor.dtype);\n        let (values, params) = tensor.quantized_handles().unwrap();\n\n        let mut data_values = into_data(values).await?;\n        let data_params = into_data(params).await?;\n\n        data_values.bytes.extend_from_byte_slice(&data_params.bytes);\n\n        Ok(TensorData {\n            bytes: data_values.bytes,\n            shape,\n            dtype,\n        })\n    }\n\n    fn q_swap_dims(\n        tensor: QuantizedTensor<Self>,\n        dim1: usize,\n        dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        swap_dims(tensor, dim1, dim2)\n    }\n\n    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        permute(tensor, axes)\n    }\n\n    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_gather(\n        _dim: usize,\n        _tensor: QuantizedTensor<Self>,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_select(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {\n        let (propagation, scheme) = match (&lhs, &rhs) {\n            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),\n            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),\n            _ => unreachable!(),\n        };\n\n        // Inherit precision for mixed inputs, default to `FloatElem` for fully quantized.\n        let out_dtype = match (&lhs, &rhs) {\n            (TensorPrimitive::Float(lhs), _) => lhs.dtype,\n            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,\n            _ => F::dtype(),\n        };\n\n        let (_lhs_dtype, lhs) = match lhs {\n            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),\n            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),\n        };\n        let (_rhs_dtype, rhs) = match rhs {\n            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),\n            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),\n        };\n\n        let out =\n            kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap();\n\n        match propagation {\n            QuantPropagation::Propagate => {\n                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))\n            }\n            QuantPropagation::Inhibit => TensorPrimitive::Float(out),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/tensor.rs",
    "content": "use super::{expand, numeric, permute, unfold};\nuse crate::CubeBackend;\nuse crate::element::bool_dtype;\nuse crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};\nuse crate::kernel::unary_basic::BasicFloatUnaryKind;\nuse crate::kernel::{\n    self, FloatUnaryOp, FloatUnaryOpFamily, launch_unary_float, reduce, unary_basic,\n};\nuse crate::{CubeRuntime, FloatElement, IntElement};\nuse crate::{\n    element::BoolElement,\n    kernel::matmul::{MatmulStrategy, matmul},\n};\nuse burn_backend::ops::GridSampleOptions;\nuse burn_backend::tensor::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor};\nuse burn_backend::{Backend, ExecutionError, Scalar};\nuse burn_backend::{DType, ElementConversion, FloatDType, Slice};\nuse burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};\nuse cubecl::prelude::*;\nuse cubek::reduce::components::instructions::ReduceOperationConfig;\nuse std::ops::Range;\n\nimpl<R, F, I, BT> FloatTensorOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(data),\n        fields(?data.shape, ?data.dtype)\n    ))]\n    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {\n        match data.dtype {\n            DType::F64 | DType::F32 | DType::F16 | DType::BF16 => super::from_data(data, device),\n            _ => unimplemented!(\"Unsupported dtype for `float_from_data`\"),\n        }\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> FloatTensor<Self> {\n        let dtype = FloatElem::<Self>::dtype();\n        match distribution {\n            Distribution::Default => random_uniform(shape, device, 0., 1., dtype),\n            Distribution::Uniform(low, high) => {\n                random_uniform(shape, device, low.elem(), high.elem(), dtype)\n            }\n            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob as f32, dtype),\n            Distribution::Normal(mean, std) => {\n                random_normal(shape, device, mean.elem(), std.elem(), dtype)\n            }\n        }\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)\n    ))]\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        super::into_data(tensor).await\n    }\n\n    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {\n        tensor.device.clone()\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype)\n    ))]\n    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {\n        super::to_device(tensor, device)\n    }\n\n    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let dtype = dtype.into();\n        super::empty(shape, device, dtype)\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::add(lhs, rhs)\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::add_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let dtype = dtype.into();\n        numeric::zeros(device.clone(), shape, dtype)\n    }\n\n    fn float_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &R::Device,\n        dtype: FloatDType,\n    ) -> FloatTensor<Self> {\n        let dtype: DType = dtype.into();\n        let client = R::client(device);\n        numeric::full_device_dtype(\n            client,\n            shape,\n            device.clone(),\n            InputScalar::new(fill_value, dtype),\n            dtype,\n        )\n    }\n\n    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let dtype = dtype.into();\n        numeric::ones(device.clone(), shape, dtype)\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::sub(lhs, rhs)\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::sub_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::mul(lhs, rhs)\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::mul_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::div(lhs, rhs)\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::div_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::remainder(lhs, rhs)\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        numeric::remainder_scalar(lhs, InputScalar::new(rhs, dtype))\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let dtype = lhs.dtype;\n        matmul(lhs, rhs, None, MatmulStrategy::default(), dtype).unwrap()\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        kernel::cross(lhs, rhs, dim)\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        super::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        super::reshape(tensor, shape)\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::gather(dim, tensor, indices)\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::scatter(dim, tensor, indices, value, false)\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::select(tensor, dim, indices)\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::select_assign(tensor, dim, indices, value, false)\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        // Check if all steps are 1\n        let all_steps_one = slices.iter().all(|info| info.step == 1);\n\n        if all_steps_one {\n            // Use optimized slice for step=1\n            let simple_ranges: Vec<Range<usize>> = slices\n                .iter()\n                .enumerate()\n                .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))\n                .collect();\n\n            kernel::slice(tensor, &simple_ranges)\n        } else {\n            // Use slice with steps kernel\n            kernel::slice_with_steps(tensor, slices)\n        }\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        ranges: &[Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::slice_assign(tensor, ranges, value)\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        kernel::mask_where_auto(tensor, mask, value, bool_dtype::<BT>())\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        let dtype = tensor.dtype;\n        kernel::mask_fill_auto(\n            tensor,\n            mask,\n            InputScalar::new(value, dtype),\n            bool_dtype::<BT>(),\n        )\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::greater(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::greater_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::greater_equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::greater_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::lower(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::lower_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::lower_equal(lhs, rhs, bool_dtype::<BT>())\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let dtype = lhs.dtype;\n        kernel::lower_equal_elem(lhs, InputScalar::new(rhs, dtype), bool_dtype::<BT>())\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        reduce::sum_fallback(tensor, Default::default()).unwrap()\n    }\n\n    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Max).unwrap()\n    }\n\n    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Max,\n        )\n        .unwrap()\n    }\n\n    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        reduce::reduce(tensor, None, Default::default(), ReduceOperationConfig::Min).unwrap()\n    }\n\n    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Min,\n        )\n        .unwrap()\n    }\n\n    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        reduce::reduce(\n            tensor,\n            None,\n            Default::default(),\n            ReduceOperationConfig::MaxAbs,\n        )\n        .unwrap()\n    }\n\n    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::MaxAbs,\n        )\n        .unwrap()\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Sum,\n        )\n        .unwrap()\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Mean,\n        )\n        .unwrap()\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        numeric::cumsum(tensor, dim)\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        numeric::cumprod(tensor, dim)\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        numeric::cummin(tensor, dim)\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        numeric::cummax(tensor, dim)\n    }\n\n    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        reduce::reduce(\n            tensor,\n            None,\n            Default::default(),\n            ReduceOperationConfig::Prod,\n        )\n        .unwrap()\n    }\n\n    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            None,\n            dim,\n            Default::default(),\n            ReduceOperationConfig::Prod,\n        )\n        .unwrap()\n    }\n\n    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Exp)\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log)\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Log1p)\n    }\n\n    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        struct Powf;\n\n        #[cube]\n        impl<F: Float, N: Size> FloatUnaryOp<F, N> for Powf {\n            type Options = InputScalar;\n\n            fn execute(input: Vector<F, N>, options: &Self::Options) -> Vector<F, N> {\n                Vector::powf(input, Vector::new(options.get::<F>()))\n            }\n        }\n\n        impl FloatUnaryOpFamily for Powf {\n            type Options = InputScalar;\n            type Unary<F: Float, N: Size> = Self;\n        }\n\n        let dtype = lhs.dtype;\n        launch_unary_float::<R, Powf, _>(lhs, |_| InputScalar::new(rhs, dtype))\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sqrt)\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Abs)\n    }\n\n    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sign)\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cos)\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sin)\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tan)\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Cosh)\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Sinh)\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Tanh)\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCos)\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcCosh)\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSin)\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcSinh)\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTan)\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::ArcTanh)\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        crate::kernel::atan2::<R>(lhs, rhs)\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Round)\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Floor)\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Ceil)\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Trunc)\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Erf)\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            Some(<Self as Backend>::IntElem::dtype()),\n            dim,\n            Default::default(),\n            ReduceOperationConfig::ArgMax,\n        )\n        .unwrap()\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce::reduce_dim(\n            tensor,\n            Some(<Self as Backend>::IntElem::dtype()),\n            dim,\n            Default::default(),\n            ReduceOperationConfig::ArgMin,\n        )\n        .unwrap()\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {\n        kernel::cast(tensor, I::dtype())\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        let dtype = tensor.dtype;\n        kernel::clamp(\n            tensor,\n            InputScalar::new(min, dtype),\n            InputScalar::new(max, dtype),\n        )\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_basic::launch::<R, _>(tensor, |_| BasicFloatUnaryKind::Recip)\n    }\n\n    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {\n        kernel::repeat_dim(tensor, dim, times)\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        numeric::pow(lhs, rhs)\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        permute(tensor, axes)\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        expand(tensor, shape)\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        kernel::flip(tensor, axes, bool_dtype::<BT>())\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        kernel::cast(tensor, dtype.into())\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        unfold(tensor, dim, size, step)\n    }\n\n    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::is_nan(tensor, bool_dtype::<BT>())\n    }\n\n    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        kernel::is_inf(tensor, bool_dtype::<BT>())\n    }\n\n    fn float_grid_sample_2d(\n        tensor: FloatTensor<Self>,\n        grid: FloatTensor<Self>,\n        options: GridSampleOptions,\n    ) -> FloatTensor<Self> {\n        kernel::grid_sample::grid_sample(tensor, grid, options)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/ops/transaction.rs",
    "content": "use burn_backend::{\n    DType, TensorData,\n    backend::ExecutionError,\n    ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},\n};\nuse burn_std::{Shape, Strides};\nuse cubecl::server::{CopyDescriptor, Handle};\n\nuse crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};\n\nimpl<R, F, I, BT> TransactionOps<Self> for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    async fn tr_execute(\n        transaction: TransactionPrimitive<Self>,\n    ) -> Result<TransactionPrimitiveData, ExecutionError> {\n        let mut client = None;\n\n        enum Kind {\n            Float,\n            Int,\n            Bool,\n        }\n\n        #[derive(new)]\n        struct BindingData {\n            index: usize,\n            kind: Kind,\n            handle: Option<Handle>,\n            shape: Shape,\n            strides: Strides,\n            dtype: DType,\n        }\n\n        let mut num_bindings = 0;\n\n        let mut kinds = Vec::new();\n\n        for t in transaction.read_floats.into_iter() {\n            if client.is_none() {\n                client = Some(t.client.clone());\n            }\n\n            let t = crate::kernel::into_contiguous_aligned(t);\n            let binding = BindingData::new(\n                num_bindings,\n                Kind::Float,\n                Some(t.handle.clone()),\n                t.meta.shape.clone(),\n                t.meta.strides.clone(),\n                t.dtype,\n            );\n\n            kinds.push(binding);\n            num_bindings += 1;\n        }\n        for t in transaction.read_ints.into_iter() {\n            if client.is_none() {\n                client = Some(t.client.clone());\n            }\n\n            let t = crate::kernel::into_contiguous_aligned(t);\n            let binding = BindingData::new(\n                num_bindings,\n                Kind::Int,\n                Some(t.handle.clone()),\n                t.meta.shape.clone(),\n                t.meta.strides.clone(),\n                t.dtype,\n            );\n\n            kinds.push(binding);\n            num_bindings += 1;\n        }\n        for t in transaction.read_bools.into_iter() {\n            if client.is_none() {\n                client = Some(t.client.clone());\n            }\n\n            let t = crate::kernel::into_contiguous_aligned(t);\n            let binding = BindingData::new(\n                num_bindings,\n                Kind::Bool,\n                Some(t.handle.clone()),\n                t.meta.shape.clone(),\n                t.meta.strides.clone(),\n                t.dtype,\n            );\n\n            kinds.push(binding);\n            num_bindings += 1;\n        }\n\n        let client = client.unwrap();\n\n        let bindings = kinds\n            .iter_mut()\n            .map(|b| {\n                CopyDescriptor::new(\n                    b.handle.take().unwrap().binding(),\n                    b.shape.clone(),\n                    b.strides.clone(),\n                    b.dtype.size(),\n                )\n            })\n            .collect();\n\n        let mut data: Vec<Option<_>> = client\n            .read_tensor_async(bindings)\n            .await\n            .map_err(|err| ExecutionError::WithContext {\n                reason: format!(\"{err:?}\"),\n            })?\n            .into_iter()\n            .map(Some)\n            .collect::<Vec<Option<_>>>();\n\n        let mut result = TransactionPrimitiveData::default();\n\n        for binding in kinds {\n            let bytes = data.get_mut(binding.index).unwrap().take().unwrap();\n            let t_data = TensorData::from_bytes(bytes, binding.shape, binding.dtype);\n\n            match binding.kind {\n                Kind::Float => {\n                    result.read_floats.push(t_data);\n                }\n                Kind::Int => {\n                    result.read_ints.push(t_data);\n                }\n                Kind::Bool => {\n                    result.read_bools.push(t_data);\n                }\n            }\n        }\n\n        Ok(result)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/template/base.rs",
    "content": "use super::SourceTemplate;\nuse crate::{CubeRuntime, element::CubeElement, tensor::CubeTensor};\nuse cubecl::{CompilationError, Compiler, CubeTask, prelude::*};\n\n/// Kernel source to create a [source](SourceTemplate)\npub trait KernelSource: Send + 'static + Sync {\n    /// Convert to [source](SourceTemplate)\n    fn source(&self) -> SourceTemplate;\n    /// Identifier for the kernel, used for caching kernel compilation.\n    fn id(&self) -> KernelId;\n}\n\n#[derive(new)]\n/// Wraps a [kernel source](KernelSource) into a [cube task](CubeTask).\npub struct SourceKernel<K> {\n    kernel_source: K,\n    cube_dim: CubeDim,\n}\n\nimpl<C: Compiler, K: KernelSource> CubeTask<C> for SourceKernel<K> {\n    fn compile(\n        &self,\n        _compiler: &mut C,\n        _options: &C::CompilationOptions,\n        _mode: ExecutionMode,\n        _address_type: StorageType,\n    ) -> Result<CompiledKernel<C>, CompilationError> {\n        let source_template = self.kernel_source.source();\n        let source = source_template.complete();\n\n        Ok(CompiledKernel {\n            entrypoint_name: \"main\".to_string(),\n            debug_name: Some(core::any::type_name::<K>()),\n            source,\n            cube_dim: self.cube_dim,\n            debug_info: None,\n            repr: None,\n        })\n    }\n}\n\nimpl<K: KernelSource> KernelMetadata for SourceKernel<K> {\n    fn id(&self) -> KernelId {\n        self.kernel_source.id()\n    }\n\n    fn address_type(&self) -> StorageType {\n        u32::as_type_native_unchecked().storage_type()\n    }\n}\n\n/// Generates kernel source code by replacing some information using templating.\n#[macro_export]\nmacro_rules! kernel_source {\n    (\n        $struct:ident,\n        $file:expr\n    ) => {\n        /// Generated kernel from a source file.\n        #[derive(new)]\n        pub struct $struct;\n\n        impl $struct {\n            fn source(&self) -> $crate::template::SourceTemplate {\n                $crate::template::SourceTemplate::new(include_str!($file))\n            }\n        }\n    };\n}\n\n/// Create a vector containing the dimension, strides and shape of tensors.\n///\n/// # Example\n///\n/// With two tensors (lhs, rhs)\n///\n/// | Indexes                  | Value       |\n/// |:------------------------:|:-----------:|\n/// |           0..1           | D           |\n/// |           1..D + 1       | lhs strides |\n/// |     (D + 1)..(2 * D + 1) | rhs strides |\n/// | (2 * D + 1)..(3 * D + 1) | lhs shape   |\n/// | (3 * D + 1)..(4 * D + 1) | rhs shape   |\npub fn build_info<R: CubeRuntime, E: CubeElement>(tensors: &[&CubeTensor<R>]) -> Vec<u32> {\n    let ndims = tensors[0].meta.num_dims();\n    let mut info: Vec<u32> = vec![0; tensors.len() * 2 * ndims + 1];\n    info[0] = ndims as u32;\n\n    let mut current = 1;\n    for tensor in tensors.iter() {\n        for d in 0..ndims {\n            info[current] = tensor.meta.strides()[d] as u32;\n            current += 1;\n        }\n    }\n    for tensor in tensors.iter() {\n        for d in 0..ndims {\n            info[current] = tensor.meta.shape()[d] as u32;\n            current += 1;\n        }\n    }\n    info\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/template/mod.rs",
    "content": "mod base;\npub use base::*;\n\nmod source;\npub use source::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/template/source.rs",
    "content": "use std::collections::HashMap;\n\n/// Kernel source code abstraction allowing for templating.\n///\n/// The templates can have text placeholders in the form {{ label }}.\n/// They will be updated with their proper value when `generate` is called.\n#[derive(Debug)]\npub struct SourceTemplate {\n    items: HashMap<String, String>,\n    templates: Vec<String>,\n}\n\nimpl SourceTemplate {\n    /// Create a new source template.\n    pub fn new<S>(template: S) -> Self\n    where\n        S: Into<String>,\n    {\n        Self {\n            items: HashMap::new(),\n            templates: vec![template.into()],\n        }\n    }\n\n    /// Register the value for a placeholder item.\n    ///\n    /// # Notes\n    ///\n    /// The value can't have placeholders, since it would require recursive templating with\n    /// possibly circular dependencies. If you want to add a value that has some\n    /// placeholders, consider adding a new template to the source using\n    /// [add_template](SourceTemplate::add_template). The added template can be a function, and you can\n    /// register the function call instead.\n    pub fn register<Name, Value>(mut self, name: Name, value: Value) -> Self\n    where\n        Name: Into<String>,\n        Value: Into<String>,\n    {\n        self.items.insert(name.into(), value.into());\n        self\n    }\n\n    /// Add a new template.\n    pub fn add_template<S>(mut self, template: S) -> Self\n    where\n        S: Into<String>,\n    {\n        self.templates.push(template.into());\n        self\n    }\n\n    /// Complete the template and returns the source code.\n    pub fn complete(mut self) -> String {\n        let mut source = self.templates.remove(0);\n\n        for s in self.templates.into_iter() {\n            source.push_str(&s);\n        }\n\n        let template = text_placeholder::Template::new(&source);\n        let mut context = HashMap::new();\n\n        for (key, value) in self.items.iter() {\n            context.insert(key.as_str(), value.as_str());\n        }\n\n        template.fill_with_hashmap(&context)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/tensor/base.rs",
    "content": "use crate::CubeRuntime;\nuse crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};\nuse burn_backend::quantization::QuantScheme;\nuse burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata};\nuse burn_std::{Metadata, strides, tensor::is_contiguous};\nuse cubecl::server::Handle;\nuse cubecl::std::tensor::TensorHandle;\nuse cubecl::{client::ComputeClient, std::tensor::layout::linear::LinearViewLaunch};\nuse cubecl::{frontend::Numeric, std::tensor::layout::linear::LinearViewLayoutLaunch};\nuse cubecl::{\n    prelude::{TensorBinding, *},\n    std::tensor::layout::linear::LinearViewLayout,\n};\nuse std::marker::PhantomData;\n\nuse super::QParams;\n\n/// The basic tensor primitive struct.\npub struct CubeTensor<R: CubeRuntime> {\n    /// Compute client for the [runtime](CubeRuntime).\n    pub client: ComputeClient<R>,\n    /// The buffer where the data are stored.\n    pub handle: Handle,\n    /// The metadata of the tensor.\n    pub meta: Box<Metadata>,\n    /// The device of the tensor.\n    pub device: R::Device,\n    /// The datatype of the tensor.\n    pub dtype: DType,\n    /// Runtime quantization parameters, if applicable\n    pub qparams: Option<QParams>,\n}\n\nimpl<R: CubeRuntime> From<CubeTensor<R>> for TensorHandle<R> {\n    fn from(val: CubeTensor<R>) -> Self {\n        TensorHandle::new(\n            val.handle.clone(),\n            val.meta.shape().clone(),\n            val.meta.strides().clone(),\n            val.dtype,\n        )\n    }\n}\n\nimpl<R: CubeRuntime> cubecl::tune::AutotuneOutput for CubeTensor<R> {\n    #[cfg(feature = \"autotune-checks\")]\n    fn check_equivalence(&self, other: Self) {\n        use crate::ops::into_data_sync;\n        use burn_backend::Tolerance;\n\n        let expected = into_data_sync::<R>(self.clone());\n        let actual = into_data_sync::<R>(other);\n        expected.assert_approx_eq::<f32>(&actual, Tolerance::permissive());\n    }\n}\n\n// TODO: Needed to cleanup leaves tensor.\n//\n// Maybe not needed when fusion is activated, since we have a detector there.\n// We could rely on basic GC strategy when not using fusion.\n//\n// impl<R: CubeRuntime> Drop for CubeTensor<R> {\n//     fn drop(&mut self) {\n//         todo!()\n//     }\n// }\n\nimpl<R> core::fmt::Debug for CubeTensor<R>\nwhere\n    R: CubeRuntime,\n{\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_fmt(format_args!(\n            \"CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}\",\n            self.meta.shape(),\n            self.device,\n            self.meta.strides(),\n            self.dtype.name(),\n            R::name(&self.client),\n        ))\n    }\n}\n\nimpl<R> Clone for CubeTensor<R>\nwhere\n    R: CubeRuntime,\n{\n    fn clone(&self) -> Self {\n        Self {\n            client: self.client.clone(),\n            handle: self.handle.clone(),\n            meta: self.meta.clone(),\n            device: self.device.clone(),\n            dtype: self.dtype,\n            qparams: self.qparams.clone(),\n        }\n    }\n}\n\nimpl<R: CubeRuntime> TensorMetadata for CubeTensor<R> {\n    fn dtype(&self) -> DType {\n        self.dtype\n    }\n\n    fn shape(&self) -> Shape {\n        self.meta.shape().clone()\n    }\n\n    fn rank(&self) -> usize {\n        self.meta.rank()\n    }\n}\n\nimpl<R: CubeRuntime> QTensorPrimitive for CubeTensor<R> {\n    fn scheme(&self) -> &QuantScheme {\n        if let DType::QFloat(scheme) = &self.dtype {\n            scheme\n        } else {\n            panic!(\n                \"Quantization scheme is not valid for dtype {:?}\",\n                self.dtype,\n            )\n        }\n    }\n}\n\nimpl<R> CubeTensor<R>\nwhere\n    R: CubeRuntime,\n{\n    /// Create a new standard tensor\n    pub fn new(\n        client: ComputeClient<R>,\n        handle: Handle,\n        metadata: Metadata,\n        device: R::Device,\n        dtype: DType,\n    ) -> Self {\n        CubeTensor {\n            client,\n            handle,\n            meta: Box::new(metadata),\n            device,\n            dtype,\n            qparams: None,\n        }\n    }\n\n    /// Create a new tensor with a contiguous memory layout.\n    pub fn new_contiguous(\n        client: ComputeClient<R>,\n        device: R::Device,\n        shape: Shape,\n        handle: Handle,\n        dtype: DType,\n    ) -> Self {\n        let ndims = shape.num_dims();\n        let mut strides = strides![0; ndims];\n        let mut current = 1;\n\n        shape.iter().enumerate().rev().for_each(|(index, val)| {\n            strides[index] = current;\n            current *= val;\n        });\n\n        Self {\n            client,\n            handle,\n            meta: Box::new(Metadata::new(shape, strides)),\n            device,\n            dtype,\n            qparams: None,\n        }\n    }\n\n    /// Change the context of the current tensor and return the newly transferred tensor.\n    pub fn to_client(&self, client: ComputeClient<R>, device: R::Device) -> Self {\n        let desc = self.handle.clone().copy_descriptor(\n            self.meta.shape().clone(),\n            self.meta.strides().clone(),\n            self.elem_size(),\n        );\n        let handle = self.client.to_client_tensor(desc, &client);\n\n        Self {\n            client,\n            handle,\n            meta: Box::new(Metadata::new(self.shape(), self.meta.strides().clone())),\n            device,\n            dtype: self.dtype,\n            qparams: self.qparams.clone(),\n        }\n    }\n\n    /// Return the reference to a tensor handle.\n    pub fn binding(self) -> TensorBinding<R> {\n        TensorBinding {\n            handle: self.handle.binding(),\n            strides: self.meta.strides,\n            shape: self.meta.shape,\n            runtime: PhantomData,\n        }\n    }\n\n    /// Returns the element size of this tensor\n    pub fn elem_size(&self) -> usize {\n        self.dtype.size()\n    }\n\n    /// Return the reference to a tensor argument.\n    pub fn into_tensor_arg(self) -> TensorArg<R> {\n        self.binding().into_tensor_arg()\n    }\n\n    /// Return the reference to an array argument.\n    pub fn into_array_arg(self) -> ArrayArg<R> {\n        self.into_tensor_arg().into_array_arg()\n    }\n\n    /// Returns a reference to the aliased tensor argument.\n    pub fn as_tensor_alias(&self, input_pos: usize) -> TensorArg<R> {\n        TensorArg::Alias {\n            input_pos,\n            strides: self.meta.strides().clone(),\n            shape: self.meta.shape().clone(),\n        }\n    }\n\n    /// Return a linear view of this tensor.\n    pub fn into_linear_view(self) -> LinearViewLaunch<R> {\n        let layout = LinearViewLayoutLaunch::new();\n        let buffer = self.into_tensor_arg();\n        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)\n    }\n\n    /// Return an aliased linear view of this tensor\n    pub fn as_linear_view_alias(&self, input_pos: usize) -> LinearViewLaunch<R> {\n        let layout = LinearViewLayoutLaunch::new();\n        let buffer = self.as_tensor_alias(input_pos);\n        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)\n    }\n\n    /// Return a linear view broadcast to the reference tensor's shape\n    pub fn into_linear_view_like(self, reference: &Self) -> LinearViewLaunch<R> {\n        let layout = LinearViewLayoutLaunch::from_reference_shape(reference.shape());\n        let buffer = self.into_tensor_arg();\n        LinearViewLaunch::new_tensor::<LinearViewLayout>(buffer, layout)\n    }\n\n    /// Returns the address type required to index this tensor\n    pub fn required_address_type(&self) -> AddressType {\n        match self.try_scheme() {\n            Some(scheme) => {\n                let len = self.handle.size() as usize * 8 / scheme.size_bits_value();\n                AddressType::from_len(len)\n            }\n            None => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),\n        }\n    }\n\n    /// Return the `QuantScheme` if present\n    pub fn try_scheme(&self) -> Option<&QuantScheme> {\n        match &self.dtype {\n            DType::QFloat(scheme) => Some(scheme),\n            _ => None,\n        }\n    }\n\n    pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool {\n        if !self.handle.can_mut() || !self.is_nonoverlapping() {\n            return false;\n        }\n        let ndims = self.meta.num_dims();\n\n        for i in 0..ndims {\n            let shape_lhs = self.meta.shape()[i];\n            let shape_rhs = rhs.meta.shape()[i];\n\n            // Output tensor will be different from the mutable tensor.\n            if shape_lhs < shape_rhs {\n                return false;\n            }\n        }\n\n        true\n    }\n\n    /// Copy the current tensor.\n    pub fn copy(&self) -> Self {\n        struct Copy;\n\n        #[cube]\n        impl<T: Numeric, N: Size> NumericUnaryOp<T, N> for Copy {\n            type Options = ();\n\n            fn execute(input: Vector<T, N>, _options: &Self::Options) -> Vector<T, N> {\n                input\n            }\n        }\n\n        impl NumericUnaryOpFamily for Copy {\n            type Options = ();\n            type Unary<T: Numeric, N: Size> = Self;\n        }\n\n        let tensor = self.clone();\n        launch_unary_numeric::<R, Copy, _>(tensor, |_| ())\n    }\n\n    /// Check if the tensor is safe to mutate.\n    pub fn can_mut(&self) -> bool {\n        self.handle.can_mut()\n    }\n\n    /// Assert that both tensors are on the same device.\n    pub fn assert_is_on_same_device(&self, other: &Self) {\n        if self.device != other.device {\n            panic!(\n                \"Both tensors should be on the same device {:?} != {:?}\",\n                self.device, other.device\n            );\n        }\n    }\n\n    /// Check if the current tensor is contiguous.\n    ///\n    /// A tensor is contiguous if the elements are stored in memory\n    /// if the strides in non-increasing order and the\n    /// strides at position k is equal to the product of the shapes\n    /// at all positions greater than k. However, all axes with a shape of 1 are ignored.\n    pub fn is_contiguous(&self) -> bool {\n        is_contiguous(self.meta.shape(), self.meta.strides())\n    }\n\n    /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory\n    /// regions within the shape).\n    pub fn is_contiguous_buffer(&self) -> bool {\n        self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize\n    }\n\n    /// Checks if the tensor is non-overlapping (can be safely written to).\n    pub fn is_nonoverlapping(&self) -> bool {\n        let shape = self.meta.shape();\n        let strides = self.meta.strides();\n\n        if strides.contains(&0) {\n            return false;\n        }\n        let rank = self.rank();\n        if rank > 1 {\n            let mut dims = shape.iter().zip(strides.iter()).collect::<Vec<_>>();\n            dims.sort_by_key(|(_, stride)| **stride);\n\n            let mut max_offset = 0;\n            for (shape, stride) in dims.into_iter() {\n                if *stride <= max_offset && *shape != 1 {\n                    return false;\n                }\n\n                max_offset += (*shape - 1) * *stride;\n            }\n        }\n        true\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn is_contiguous_non_increasing() {\n        assert!(is_contiguous(&[3, 1], &[1, 1]));\n    }\n\n    #[test]\n    fn is_contiguous_basic() {\n        assert!(is_contiguous(&[32, 32], &[32, 1]));\n    }\n\n    #[test]\n    fn is_contiguous_permuted() {\n        assert!(!is_contiguous(&[32, 32], &[1, 32]));\n    }\n\n    #[test]\n    fn is_contiguous_slice() {\n        assert!(!is_contiguous(&[32, 1, 64], &[32, 64, 1]));\n    }\n\n    #[test]\n    fn is_contiguous_4d_positive() {\n        assert!(is_contiguous(&[8, 256, 32, 32], &[262144, 1024, 32, 1]));\n    }\n\n    #[test]\n    fn is_contiguous_4d_negative() {\n        assert!(!is_contiguous(&[256, 8, 32, 32], &[1024, 262144, 32, 1]));\n    }\n\n    /// Based on a bug encountered in interpolate_1d\n    #[test]\n    fn is_contiguous_4d_unit_shape() {\n        assert!(!is_contiguous(&[1, 1, 1, 9], &[72, 1, 72, 8]));\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/tensor/mod.rs",
    "content": "mod base;\nmod quantization;\n\npub use base::*;\npub use quantization::*;\n"
  },
  {
    "path": "crates/burn-cubecl/src/tensor/quantization.rs",
    "content": "use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor};\nuse burn_std::{Metadata, Strides};\nuse cubecl::quant::scheme::{QuantStore, QuantValue};\nuse cubecl::{client::ComputeClient, server::Handle};\n\nuse crate::CubeRuntime;\n\nuse super::CubeTensor;\n\n/// Runtime parameters for quantization. Can be used to construct a scales handle from the base\n/// tensor handle.\npub type QParams = burn_backend::quantization::QParams<QParamTensor>;\n\nimpl<R: CubeRuntime> CubeTensor<R> {\n    /// Create a new quantized tensor\n    pub fn new_quantized(\n        client: ComputeClient<R>,\n        handle: Handle,\n        shape: Shape,\n        device: R::Device,\n        strides: Strides,\n        dtype: DType,\n        qparams: QParams,\n    ) -> Self {\n        CubeTensor {\n            client,\n            handle,\n            meta: Box::new(Metadata::new(shape, strides)),\n            device,\n            dtype,\n            qparams: Some(qparams),\n        }\n    }\n\n    /// Returns the two tensors: (values, params) for a quantized tensor.\n    /// For the values, native types that aren't supported as a normal `DType` will be returned\n    /// as an unsigned integer tensor representing the bits. Should be reconstructed using `from_bits`\n    /// in kernels.\n    pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {\n        let params = self.scales()?;\n        let scheme = match self.dtype {\n            DType::QFloat(sc) => sc,\n            _ => return None,\n        };\n        let values = match scheme.store {\n            QuantStore::Native => match scheme.value {\n                QuantValue::Q8F | QuantValue::Q8S => CubeTensor {\n                    client: self.client.clone(),\n                    handle: self.handle.clone(),\n                    meta: self.meta.clone(),\n                    device: self.device.clone(),\n                    dtype: DType::I8,\n                    qparams: None,\n                },\n                QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor {\n                    client: self.client.clone(),\n                    handle: self.handle.clone(),\n                    meta: self.meta.clone(),\n                    device: self.device.clone(),\n                    dtype: DType::U8,\n                    qparams: None,\n                },\n                QuantValue::Q4F\n                | QuantValue::Q4S\n                | QuantValue::Q2F\n                | QuantValue::Q2S\n                | QuantValue::E2M1 => {\n                    panic!(\"Can't store native sub-byte values\")\n                }\n            },\n            QuantStore::PackedU32(packed_dim) => {\n                let packed_dim = self.rank() - packed_dim - 1;\n                let mut shape = self.shape();\n                shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());\n\n                CubeTensor {\n                    client: self.client.clone(),\n                    handle: self.handle.clone(),\n                    meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),\n                    device: self.device.clone(),\n                    dtype: DType::U32,\n                    qparams: None,\n                }\n            }\n            QuantStore::PackedNative(packed_dim) => match scheme.value {\n                QuantValue::E2M1 => {\n                    let packed_dim = self.rank() - packed_dim - 1;\n                    let mut shape = self.shape();\n                    shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());\n\n                    CubeTensor {\n                        client: self.client.clone(),\n                        handle: self.handle.clone(),\n                        meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),\n                        device: self.device.clone(),\n                        dtype: DType::U8,\n                        qparams: None,\n                    }\n                }\n                other => panic!(\"{other:?} doesn't support native packing\"),\n            },\n        };\n\n        Some((values, params))\n    }\n\n    /// Construct a separate tensor for the quantization scales, if present\n    pub fn scales(&self) -> Option<CubeTensor<R>> {\n        let qparams = self.qparams.as_ref()?;\n        let mut handle = self.handle.clone();\n        handle.offset_start = Some(qparams.scales.offset_start as u64);\n        handle.offset_end = Some(qparams.scales.offset_end as u64);\n\n        Some(CubeTensor::new(\n            self.client.clone(),\n            handle,\n            qparams.scales.metadata.clone(),\n            self.device.clone(),\n            qparams.scales.dtype,\n        ))\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl/src/tune_key.rs",
    "content": "use crate::kernel::{\n    conv::{ConvAutotuneKey, ConvTranspose2dAutotuneKey},\n    reduce::SumAutotuneKey,\n};\nuse cubecl::tune::AutotuneKey;\nuse serde::{Deserialize, Serialize};\nuse std::fmt::Display;\n\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]\n/// Key for all autotune-enabled operations\npub enum CubeAutotuneKey {\n    /// Key for sum operations\n    Sum(SumAutotuneKey),\n    /// Key for convolution operations\n    Conv(ConvAutotuneKey),\n    /// Key for transpose convolution operations\n    ConvTranspose(ConvTranspose2dAutotuneKey),\n}\n\nimpl Display for CubeAutotuneKey {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            CubeAutotuneKey::Sum(reduce_key) => std::fmt::Debug::fmt(&reduce_key, f),\n            CubeAutotuneKey::Conv(conv_key) => std::fmt::Debug::fmt(&conv_key, f),\n            CubeAutotuneKey::ConvTranspose(conv_key) => std::fmt::Debug::fmt(&conv_key, f),\n        }\n    }\n}\n\nimpl AutotuneKey for CubeAutotuneKey {}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Provide optimizations that can be used with cubecl based backends.\"\ndocumentation = \"https://docs.rs/burn-cubecl-fusion\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\"]\nlicense.workspace = true\nname = \"burn-cubecl-fusion\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-cubecl-fusion\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"autotune\", \"std\", \"cubecl/default\", \"burn-fusion/default\"]\n\nautotune = []\nautotune-checks = [\"cubecl/autotune-checks\", \"burn-backend\", \"half\"]\ndoc = [\"default\"]\nstd = [\"cubecl/std\", \"burn-backend?/std\", \"burn-fusion/std\"]\ntracing = [\n    \"cubecl/tracing\",\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n    \"burn-fusion/tracing\",\n]\n\n[dependencies]\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", features = [\n    \"cubecl\",\n] }\ncubecl = { workspace = true }\ncubek = { workspace = true, features = [\n    \"matmul\",\n    \"reduce\",\n    \"quantization\",\n    \"stdlib\",\n] }\nhalf = { workspace = true, optional = true }\n\n# Only for `TensorData` with autotune-checks\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\n\nderive-new = { workspace = true }\nserde = { workspace = true }\n\n[dev-dependencies]\ncubecl = { workspace = true, features = [\"test-runtime\"] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/README.md",
    "content": "# Burn CubeCl Fusion\n\nProvide optimizations that can be used with [cubecl](../burn-cubecl) based backends.\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/base.rs",
    "content": "use burn_fusion::stream::Context;\nuse burn_std::{DType, Shape, Strides, quantization::QParamTensor, strides};\nuse cubecl::quant::scheme::{QuantParam, QuantScheme};\nuse cubecl::{\n    Runtime,\n    client::ComputeClient,\n    ir::AddressType,\n    prelude::{TensorArg, TensorBinding},\n};\nuse std::marker::PhantomData;\n\n/// Defines a fallback operation when fusion isn't possible.\npub trait FallbackOperation<R: Runtime>: Send + Sync {\n    /// Executes the fallback procedure.\n    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);\n}\n\n/// Runtime parameters for quantization. Can be used to construct a scales handle from the base\n/// tensor handle.\npub type QParams = burn_std::quantization::QParams<QParamTensor>;\n\n/// Handle to be used when fusing operations.\npub struct CubeFusionHandle<R: Runtime> {\n    /// Compute client for jit.\n    pub client: ComputeClient<R>,\n    /// The buffer where the data are stored.\n    pub handle: cubecl::server::Handle,\n    /// The device of the current tensor.\n    pub device: R::Device,\n    /// The element type of the tensor.\n    pub dtype: DType,\n    /// The strides of the tensor.\n    pub strides: Strides,\n    /// Quantization runtime parameters, if applicable\n    pub qparams: Option<QParams>,\n}\n\nimpl<R: Runtime> core::fmt::Debug for CubeFusionHandle<R> {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_fmt(format_args!(\n            \"CubeFusionHandle {{ device: {:?}, runtime: {}}}\",\n            self.device,\n            R::name(&self.client),\n        ))\n    }\n}\n\nimpl<R: Runtime> Clone for CubeFusionHandle<R> {\n    fn clone(&self) -> Self {\n        Self {\n            client: self.client.clone(),\n            handle: self.handle.clone(),\n            device: self.device.clone(),\n            strides: self.strides.clone(),\n            dtype: self.dtype,\n            qparams: self.qparams.clone(),\n        }\n    }\n}\n\nunsafe impl<R: Runtime> Send for CubeFusionHandle<R> {}\nunsafe impl<R: Runtime> Sync for CubeFusionHandle<R> {}\n\nimpl<R: Runtime> CubeFusionHandle<R> {\n    /// Return the reference to a tensor handle.\n    pub fn binding(self, shape: Shape) -> TensorBinding<R> {\n        TensorBinding {\n            handle: self.handle.binding(),\n            strides: self.strides.clone(),\n            shape,\n            runtime: PhantomData,\n        }\n    }\n\n    pub fn required_address_type(&self) -> AddressType {\n        match self.dtype {\n            DType::QFloat(scheme) => {\n                let len = self.handle.size() as usize * 8 / scheme.size_bits_value();\n                AddressType::from_len(len)\n            }\n            _ => AddressType::from_len(self.handle.size() as usize / self.dtype.size()),\n        }\n    }\n\n    /// Return the reference to a tensor argument.\n    pub fn into_tensor_arg(self, shape: Shape) -> TensorArg<R> {\n        let handle = self.binding(shape);\n        handle.into_tensor_arg()\n    }\n\n    /// Construct a separate tensor for the quantization scales, if present\n    pub fn params(&self, scheme: QuantScheme) -> Option<Self> {\n        let qparams = self.qparams.as_ref()?;\n        let mut handle = self.handle.clone();\n        handle.offset_start = Some(qparams.scales.offset_start as u64);\n        handle.offset_end = Some(qparams.scales.offset_end as u64);\n\n        Some(Self {\n            client: self.client.clone(),\n            handle,\n            device: self.device.clone(),\n            dtype: match scheme.param {\n                QuantParam::F32 => DType::F32,\n                QuantParam::F16 => DType::F16,\n                QuantParam::BF16 => DType::BF16,\n                QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!(\"Not yet supported\"),\n            },\n            strides: qparams.scales.metadata.strides().clone(),\n            qparams: None,\n        })\n    }\n}\n\npub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides {\n    let mut strides = strides![0; shape.len()];\n\n    let mut current = 1;\n    shape.iter().enumerate().rev().for_each(|(index, val)| {\n        strides[index] = current;\n        current *= val;\n    });\n\n    strides\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/base.rs",
    "content": "use cubecl::{define_scalar, define_size};\n\ndefine_scalar!(pub DynElem);\ndefine_size!(pub DynSize);\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/io.rs",
    "content": "//! This module declares input-output primitives to read and write values during kernel expansion.\nuse crate::engine::codegen::{DynElem, DynSize};\n\nuse super::{ir::*, tensor::GlobalTensor};\nuse burn_std::quantization::QuantScheme;\nuse cubecl::quant::scheme::QuantLevel;\nuse cubecl::{\n    intrinsic,\n    ir::{ManagedVariable, Variable},\n    prelude::*,\n    std::{FastDivmod, tensor::View},\n};\nuse cubek::quantization::layout::{BlockScaledLayout, PerTensorLayout, ScalesLayout};\nuse serde::{Deserialize, Serialize};\n\n/// Define how a tensor might be transformed at runtime.\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]\npub enum Transform {\n    /// A reshape operation has been registered on a tensor.\n    ///\n    /// This enum entry contains a sequence of [arguments](FuseArg) that points to global scalars representing the\n    /// new shape for the current tensor.\n    Reshape(Vec<FuseArg>),\n    /// Two axes have been swapped on a tensor.\n    ///\n    /// The enum entry contains those two axes.\n    SwapDims(usize, usize),\n}\n\n/// Reads the value from the [arg](FuseArg) and cast it to the generic cube primitive.\n///\n/// # Notes\n///\n/// The [global arguments](GlobalArgs) for both inputs and outputs as well as the\n/// [local arguments](LocalArgs) need to be passed to this function.\n///\n/// This is because the [argument](FuseArg) might point to a global input, output or local variable\n/// created during kernel expansion.\n#[cube]\npub fn read<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &GlobalArgs,\n    locals: &LocalArgs,\n    ref_pos: usize,\n    #[comptime] arg: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) -> Vector<C, N> {\n    set_polyfill_typed::<Vector<C, N>, DynElem, DynSize>();\n    match arg {\n        FuseArg::Input(pos, _precision, layout) => {\n            let global = inputs.tensors.index(pos);\n            let vector_size = global.tensor.vector_size();\n\n            if comptime![!global.broadcasted && vector_size != config.width] {\n                read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None)\n            } else {\n                read_input(inputs, locals, pos, ref_pos, layout, config, None)\n            }\n        }\n        FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {\n            Vector::cast_from(outputs.variables.read(key))\n        }\n        FuseArg::Output(pos, _precision, layout) => {\n            read_output(inputs, outputs, locals, pos, ref_pos, layout, config)\n        }\n        FuseArg::BlockLocal { pos, ty } => match comptime![ty] {\n            FuseType::F64 => Vector::cast_from(locals.l_f64.find(pos)),\n            FuseType::F32 | FuseType::Flex32 => Vector::cast_from(locals.l_f32.find(pos)),\n            FuseType::F16 => Vector::cast_from(locals.l_f16.find(pos)),\n            FuseType::BF16 => Vector::cast_from(locals.l_bf16.find(pos)),\n            FuseType::U64 => Vector::cast_from(locals.l_u64.find(pos)),\n            FuseType::U32 => Vector::cast_from(locals.l_u32.find(pos)),\n            FuseType::U16 => Vector::cast_from(locals.l_u16.find(pos)),\n            FuseType::U8 => Vector::cast_from(locals.l_u8.find(pos)),\n            FuseType::I64 => Vector::cast_from(locals.l_i64.find(pos)),\n            FuseType::I32 => Vector::cast_from(locals.l_i32.find(pos)),\n            FuseType::I16 => Vector::cast_from(locals.l_i16.find(pos)),\n            FuseType::I8 => Vector::cast_from(locals.l_i8.find(pos)),\n        },\n        FuseArg::Scalar(..) => {\n            let scalar = read_scalar::<C>(inputs, arg);\n            Vector::new(scalar)\n        }\n        FuseArg::ScalarShape(_) => {\n            let scalar = read_scalar_shape(inputs, arg);\n            Vector::cast_from(scalar)\n        }\n        FuseArg::Literal(val, _precision) => Vector::new(from_const_int::<C>(val)),\n        FuseArg::InputReshaped {\n            original,\n            shape,\n            broadcasted,\n        } => match comptime![original.as_ref().clone()] {\n            FuseArg::Input(pos, _precision, layout) => {\n                let global = inputs.tensors.index(pos);\n                let vector_size = global.tensor.vector_size();\n\n                if comptime![!broadcasted && vector_size != config.width] {\n                    read_input_aligned(\n                        inputs,\n                        locals,\n                        pos,\n                        ref_pos,\n                        layout,\n                        config,\n                        comptime![Some(Transform::Reshape(shape))],\n                    )\n                } else {\n                    read_input(\n                        inputs,\n                        locals,\n                        pos,\n                        ref_pos,\n                        layout,\n                        config,\n                        comptime![Some(Transform::Reshape(shape))],\n                    )\n                }\n            }\n            _ => comptime![panic!(\"Only input can be reshaped\")],\n        },\n        FuseArg::InputSwapDims {\n            original,\n            dims,\n            broadcasted,\n        } => match comptime![original.as_ref().clone()] {\n            FuseArg::Input(pos, _precision, layout) => {\n                let global = inputs.tensors.index(pos);\n                let vector_size = global.tensor.vector_size();\n\n                if comptime![!broadcasted && vector_size != config.width] {\n                    read_input_aligned(\n                        inputs,\n                        locals,\n                        pos,\n                        ref_pos,\n                        layout,\n                        config,\n                        comptime![Some(Transform::SwapDims(dims.0, dims.1))],\n                    )\n                } else {\n                    read_input(\n                        inputs,\n                        locals,\n                        pos,\n                        ref_pos,\n                        layout,\n                        config,\n                        comptime![Some(Transform::SwapDims(dims.0, dims.1))],\n                    )\n                }\n            }\n            _ => comptime![panic!(\"Only input can be swapped dims\")],\n        },\n    }\n}\n\n/// Computes the offset for the current global tensor with a quantized layout.\n///\n/// The offset can be used to fetch the correct data from the quantized tensor as if it was in a\n/// linear contiguous format.\n#[cube]\nfn index_offset_with_quant_layout(\n    tensor: &GlobalTensor,\n    locals: &LocalArgs,\n    index: usize,\n    #[comptime] rank: usize,\n    #[comptime] scheme: QuantScheme,\n) -> usize {\n    let (start, end) = (0, rank - 1);\n    let num_quants = scheme.num_quants();\n\n    let offset_ref = index * locals.ref_vector_size;\n    let mut offset = 0;\n\n    #[unroll]\n    for i in start..end {\n        let ogwl = offset_ref / locals.ref_strides[i];\n        offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);\n    }\n\n    // Handle packed representation in last dim\n    let ogwl = offset_ref / locals.ref_strides[end];\n    let shape_last = tensor.tensor.shape(end).div_ceil(num_quants);\n    let stride_last = tensor.tensor.stride(end);\n    offset += (ogwl.div_ceil(num_quants)) % shape_last * stride_last;\n\n    offset / tensor.tensor.vector_size()\n}\n\n/// Reads a global quantized tensor at the given position.\n///\n/// # Notes\n///\n/// The values returned in the [Vector] are not dequantized.\n#[cube]\npub fn read_quantized<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    ref_pos: usize,\n    #[comptime] arg: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n    #[comptime] scheme: QuantScheme,\n) -> Vector<C, N> {\n    match arg {\n        FuseArg::Input(pos, _precision, _layout) => {\n            set_polyfill_typed::<Vector<C, N>, DynElem, DynSize>();\n            let global = inputs.tensors.index(pos);\n\n            let offset =\n                index_offset_with_quant_layout(global, locals, ref_pos, config.rank, scheme);\n            let val = global.tensor[offset];\n            Vector::cast_from(val)\n        }\n        _ => panic!(\"Not supported\"),\n    }\n}\n\n/// Reads a global scalar.\n#[cube]\npub fn read_scalar<C: Scalar>(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> C {\n    match arg {\n        FuseArg::Scalar(pos, _precision) => {\n            let scalar = inputs.scalars.index(pos);\n            scalar.get::<C>()\n        }\n        _ => comptime![panic!(\"Not a scalar\")],\n    }\n}\n\n/// Reads a global scalar that is used as a reshape position.\n#[cube]\npub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize {\n    match arg {\n        FuseArg::ScalarShape(pos) => inputs.reshapes[pos],\n        _ => comptime![panic!(\"Not a scalar shape\")],\n    }\n}\n\n/// Reads an input tensor.\n#[cube]\npub fn read_input<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    #[comptime] pos: usize,\n    ref_pos: usize,\n    #[comptime] layout: LayoutInfo,\n    #[comptime] config: &FuseBlockConfig,\n    #[comptime] transform: Option<Transform>,\n) -> Vector<C, N> {\n    set_polyfill_typed::<Vector<C, N>, DynElem, DynSize>();\n    let tensor = inputs.tensors.index(pos);\n    let offset = match layout {\n        LayoutInfo::SameAsRef => ref_pos,\n        LayoutInfo::IsRef => ref_pos,\n        LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, transform),\n    };\n    Vector::cast_from(tensor.tensor[offset])\n}\n\n/// Returns a slice of data in the asked precision of the input tensor at the given position.\n#[cube]\npub fn read_input_window<C: CubePrimitive>(\n    inputs: &GlobalArgs,\n    #[comptime] pos: usize,\n    start: usize,\n    end: usize,\n) -> Slice<C> {\n    set_polyfill_typed::<C, DynElem, DynSize>();\n    let tensor = inputs.tensors.index(pos);\n    let slice = tensor.tensor.slice(start, end);\n    slice.downcast()\n}\n\n/// Returns the input as a slice.\n#[cube]\npub fn input_as_slice<C: CubePrimitive>(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice<C> {\n    set_polyfill_typed::<C, DynElem, DynSize>();\n    let tensor = inputs.tensors.index(pos);\n    let slice = tensor.tensor.to_slice();\n    slice.downcast()\n}\n\n/// Returns the input tensor as a quantized scale view.\n#[cube]\npub fn input_as_scales_view<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    #[comptime] pos: usize,\n    #[comptime] tensor_pos: usize,\n    #[comptime] level: QuantLevel,\n    #[comptime] config: &FuseBlockConfig,\n) -> View<C, usize> {\n    set_polyfill_typed::<Vector<C, N>, DynElem, DynSize>();\n    let tensor = inputs.tensors.index(tensor_pos);\n    let scales = inputs.tensors.index(pos);\n    let tensor_len = tensor.tensor.len();\n    let rank = config.rank;\n    let layout = match level {\n        QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)),\n        QuantLevel::Block(block_size) => {\n            let block_size = comptime![block_size.to_dim_vec(rank)];\n            let mut tensor_shape = Sequence::new();\n            let mut scales_strides = Sequence::new();\n            #[unroll]\n            for i in 0..rank {\n                tensor_shape.push(FastDivmod::new_Fallback(tensor.tensor.shape(i)));\n                scales_strides.push(scales.tensor.stride(i));\n            }\n            let vector_size = scales.tensor.vector_size();\n            let layout = BlockScaledLayout::new(\n                tensor_shape,\n                tensor_len,\n                scales_strides,\n                block_size,\n                vector_size,\n            );\n            ScalesLayout::new_BlockScaled(layout)\n        }\n    };\n    View::new::<Slice<C>, usize>(&scales.tensor.to_slice().downcast(), layout)\n}\n\n/// Reads the input tensor aligned.\n#[cube]\npub fn read_input_aligned<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    #[comptime] pos: usize,\n    ref_pos: usize,\n    #[comptime] layout: LayoutInfo,\n    #[comptime] config: &FuseBlockConfig,\n    #[comptime] transform: Option<Transform>,\n) -> Vector<C, N> {\n    let mut result = Vector::<C, N>::empty();\n    let tensor = inputs.tensors.index(pos);\n\n    match transform.clone() {\n        Some(Transform::Reshape(shape)) => {\n            // Very brute force, not really efficient, but not easy to optimize and not a very\n            // frequent workflow.\n            let ref_pos = ref_pos * config.width;\n            #[unroll]\n            for i in 0..config.width {\n                let index = reshaped_index(\n                    inputs,\n                    locals,\n                    ref_pos + i,\n                    config.rank,\n                    comptime![shape.clone()],\n                );\n                let index = reshaped_index_to_original_index(&tensor.tensor, index, config.rank);\n                result[i] = C::cast_from(tensor.tensor[index][0])\n            }\n        }\n        Some(Transform::SwapDims(dim1, dim2)) => {\n            let offset =\n                get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);\n            let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))];\n            let stride = tensor.tensor.stride(i);\n\n            #[unroll]\n            for i in 0..config.width {\n                let index = offset + i * stride;\n                result[i] = C::cast_from(tensor.tensor[index][0])\n            }\n        }\n        None => {\n            let offset =\n                get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform);\n            let stride = tensor.tensor.stride(config.rank - 1);\n            #[unroll]\n            for i in 0..config.width {\n                let index = offset + i * stride;\n                result[i] = C::cast_from(tensor.tensor[index][0])\n            }\n        }\n    }\n\n    result\n}\n\n/// Computes the offset of the given [GlobalTensor] at on the reference position with a linear\n/// layout.\n#[cube]\npub fn get_offset_aligned(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    tensor: &GlobalTensor,\n    ref_pos: usize,\n    #[comptime] layout: LayoutInfo,\n    #[comptime] config: &FuseBlockConfig,\n    #[comptime] transform: Option<Transform>,\n) -> usize {\n    match layout {\n        LayoutInfo::SameAsRef | LayoutInfo::IsRef => {\n            (ref_pos * locals.ref_vector_size) / tensor.tensor.vector_size()\n        }\n        LayoutInfo::Unknown => get_offset(\n            inputs,\n            locals,\n            tensor,\n            ref_pos,\n            None,\n            config,\n            comptime!(transform.clone()),\n        ),\n    }\n}\n\n/// Reads an output tensor.\n#[cube]\npub fn read_output<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &GlobalArgs,\n    locals: &LocalArgs,\n    #[comptime] pos: usize,\n    ref_pos: usize,\n    #[comptime] layout: LayoutInfo,\n    #[comptime] config: &FuseBlockConfig,\n) -> Vector<C, N> {\n    let tensor = outputs.tensors.index(pos);\n    let offset = match layout {\n        LayoutInfo::SameAsRef => ref_pos,\n        LayoutInfo::IsRef => ref_pos,\n        LayoutInfo::Unknown => get_offset(inputs, locals, tensor, ref_pos, None, config, None),\n    };\n    Vector::cast_from(tensor.tensor[offset])\n}\n\n#[cube]\n/// Write the given value at the [arg](Arg) position.\npub fn write<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    ref_pos: usize,\n    value: Vector<C, N>,\n    #[comptime] arg: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    set_polyfill_typed::<Vector<C, N>, DynElem, DynSize>();\n\n    match arg {\n        FuseArg::Output(pos, _, layout) => {\n            let tensor = outputs.tensors.index(pos);\n            let offset = match layout {\n                LayoutInfo::SameAsRef => ref_pos,\n                LayoutInfo::IsRef => ref_pos,\n                LayoutInfo::Unknown => {\n                    get_offset(inputs, locals, tensor, ref_pos, None, config, None)\n                }\n            };\n            let tensor = outputs.tensors.index_mut(pos);\n\n            let value = Vector::cast_from(value);\n\n            tensor.tensor[offset] = value;\n        }\n        FuseArg::BlockLocal { .. } => write_scalar::<C, N>(locals, value, arg),\n        FuseArg::MultiBlockLocal(key, _) | FuseArg::MultiBlockGlobal(key, _) => {\n            outputs.variables.write(key, Vector::cast_from(value))\n        }\n        _ => comptime![panic!(\"Can't write into inputs and scalars\")],\n    }\n}\n\n#[cube]\n/// Write the given value at the [arg](Arg) position.\npub fn write_scalar<C: Scalar, N: Size>(\n    locals: &mut LocalArgs,\n    value: Vector<C, N>,\n    #[comptime] arg: FuseArg,\n) {\n    match arg {\n        FuseArg::BlockLocal { pos, ty } => match comptime![ty] {\n            FuseType::F64 => locals.l_f64.insert(pos, Vector::cast_from(value)),\n            FuseType::F32 | FuseType::Flex32 => locals.l_f32.insert(pos, Vector::cast_from(value)),\n            FuseType::F16 => locals.l_f16.insert(pos, Vector::cast_from(value)),\n            FuseType::BF16 => locals.l_bf16.insert(pos, Vector::cast_from(value)),\n            FuseType::U64 => locals.l_u64.insert(pos, Vector::cast_from(value)),\n            FuseType::U32 => locals.l_u32.insert(pos, Vector::cast_from(value)),\n            FuseType::U16 => locals.l_u16.insert(pos, Vector::cast_from(value)),\n            FuseType::U8 => locals.l_u8.insert(pos, Vector::cast_from(value)),\n            FuseType::I64 => locals.l_i64.insert(pos, Vector::cast_from(value)),\n            FuseType::I32 => locals.l_i32.insert(pos, Vector::cast_from(value)),\n            FuseType::I16 => locals.l_i16.insert(pos, Vector::cast_from(value)),\n            FuseType::I8 => locals.l_i8.insert(pos, Vector::cast_from(value)),\n        },\n        _ => comptime![panic!(\"Can't write into something else than scalars\")],\n    }\n}\n\n#[cube]\npub(crate) fn global_offset(\n    inputs: &GlobalArgs,\n    outputs: &GlobalArgs,\n    locals: &LocalArgs,\n    index: usize,\n    #[comptime] arg: FuseArg,\n    #[comptime] range: Option<(usize, usize)>,\n    #[comptime] config: &FuseBlockConfig,\n) -> usize {\n    match arg {\n        FuseArg::Input(pos, _precision, _layout) => {\n            let tensor = inputs.tensors.index(pos);\n            get_offset(inputs, locals, tensor, index, range, config, None)\n        }\n        FuseArg::Output(pos, _precision, _layout) => {\n            let tensor = outputs.tensors.index(pos);\n            get_offset(inputs, locals, tensor, index, range, config, None)\n        }\n        _ => panic!(\"Only input and output tensors have global offset.\"),\n    }\n}\n\n#[cube]\nfn get_offset(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    tensor: &GlobalTensor,\n    ref_pos: usize,\n    #[comptime] range: Option<(usize, usize)>,\n    #[comptime] config: &FuseBlockConfig,\n    #[comptime] transform: Option<Transform>,\n) -> usize {\n    index_offset_with_layout(\n        inputs,\n        tensor,\n        locals,\n        ref_pos,\n        range,\n        config.rank,\n        transform,\n    )\n}\n\n#[cube]\n/// Gets the vector size for a global tensor.\npub fn global_vector_size(\n    global: &GlobalArgs,\n    #[comptime] pos: usize,\n) -> comptime_type!(VectorSize) {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.vector_size()\n}\n\n#[cube]\n/// Gets the rank for a global tensor.\npub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.rank()\n}\n\n#[cube]\n/// Gets the length for a global tensor.\npub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.len()\n}\n\n#[cube]\n/// Gets the buffer length for a global tensor.\npub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.buffer_len()\n}\n\n#[cube]\n/// Gets the reference tensor length.\npub fn ref_len(\n    inputs: &GlobalArgs,\n    outputs: &GlobalArgs,\n    locals: &LocalArgs,\n    #[comptime] config: &FuseBlockConfig,\n) -> usize {\n    match config.ref_layout.clone() {\n        RefLayout::Concrete(arg) => match comptime![arg] {\n            FuseArg::Input(index, _, _) => global_len(inputs, index),\n            FuseArg::Output(index, _, _) => global_len(outputs, index),\n            _ => panic!(\"Invalid concrete ref layout.\"),\n        },\n        RefLayout::Virtual(..) => num_elements(locals, config),\n    }\n}\n\n#[cube]\n/// Gets the reference buffer tensor length.\npub fn ref_buffer_len(\n    inputs: &GlobalArgs,\n    outputs: &GlobalArgs,\n    locals: &LocalArgs,\n    #[comptime] config: &FuseBlockConfig,\n) -> usize {\n    match config.ref_layout.clone() {\n        RefLayout::Concrete(arg) => match comptime![arg] {\n            FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),\n            FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),\n            _ => panic!(\"Invalid concrete ref layout.\"),\n        },\n        RefLayout::Virtual(VirtualLayout::SwapDims(arg, ..)) => match arg {\n            FuseArg::Input(index, _, _) => global_buffer_len(inputs, index),\n            FuseArg::Output(index, _, _) => global_buffer_len(outputs, index),\n            _ => panic!(\"Invalid concrete ref layout.\"),\n        },\n        RefLayout::Virtual(VirtualLayout::Reshaped { .. }) => num_elements(locals, config),\n        RefLayout::Virtual(VirtualLayout::Shape(..)) => num_elements(locals, config),\n        RefLayout::Virtual(VirtualLayout::Runtime { .. }) => num_elements(locals, config),\n    }\n}\n\n#[cube]\n/// Gets the reference number of elements.\npub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize {\n    let mut length = 1;\n\n    for i in 0..config.rank {\n        length *= locals.ref_shape[i];\n    }\n\n    length\n}\n\n#[cube]\n/// Gets the reference axis shape.\npub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize {\n    locals.ref_shape[axis]\n}\n\n#[cube]\n/// Gets the reference axis stride.\npub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize {\n    locals.ref_strides[axis]\n}\n\n#[cube]\n/// Gets the reference vector size.\npub fn ref_vector_size(locals: &LocalArgs) -> comptime_type!(VectorSize) {\n    comptime![locals.ref_vector_size]\n}\n\n#[cube]\n/// Gets the given tensor axis shape.\npub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.shape(axis)\n}\n\n#[cube]\n/// Gets the given tensor axis stride.\npub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize {\n    let tensor = global.tensors.index(pos);\n    tensor.tensor.stride(dim)\n}\n\n#[cube]\nfn index_offset_with_layout(\n    inputs: &GlobalArgs,\n    tensor: &GlobalTensor,\n    locals: &LocalArgs,\n    index: usize,\n    #[comptime] range: Option<(usize, usize)>,\n    #[comptime] rank: usize,\n    #[comptime] transform: Option<Transform>,\n) -> usize {\n    match comptime![transform.clone()] {\n        Some(Transform::Reshape(shape)) => {\n            comptime![assert!(\n                range.is_none(),\n                \"Can't get a range on a reshaped tensor.\"\n            )];\n\n            let index = index * locals.ref_vector_size;\n            let index = reshaped_index(inputs, locals, index, rank, shape);\n            reshaped_index_to_original_index(&tensor.tensor, index, rank)\n        }\n        Some(Transform::SwapDims(dim1, dim2)) => {\n            let (start, end) = comptime! {match range {\n                Some(range) => range,\n                None => (0, rank),\n            }};\n\n            let offset_ref = index * locals.ref_vector_size;\n            let mut offset = 0;\n\n            #[unroll]\n            for i in start..end {\n                let index = comptime![swap_dims_transform(i, (dim1, dim2))];\n                let ogwl = offset_ref / locals.ref_strides[i];\n                offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index);\n            }\n\n            offset / tensor.tensor.vector_size()\n        }\n        None => {\n            let (start, end) = comptime! {match range {\n                Some(range) => range,\n                None => (0, rank),\n            }};\n\n            let offset_ref = index * locals.ref_vector_size;\n            let mut offset = 0;\n\n            #[unroll]\n            for i in start..end {\n                let ogwl = offset_ref / locals.ref_strides[i];\n                offset += ogwl % tensor.tensor.shape(i) * tensor.tensor.stride(i);\n            }\n\n            offset / tensor.tensor.vector_size()\n        }\n    }\n}\n\npub(crate) fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize {\n    if i == dims.0 {\n        dims.1\n    } else if i == dims.1 {\n        dims.0\n    } else {\n        i\n    }\n}\n\n#[cube]\n#[allow(clippy::clone_on_copy)]\n/// The index the input tensor would be at if it was contiguous.\nfn reshaped_index(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    index: usize,\n    #[comptime] rank: usize,\n    #[comptime] shape: Vec<FuseArg>,\n) -> usize {\n    let mut offset = 0;\n    let mut stride_curr = 1;\n\n    #[unroll]\n    for r in 0..rank {\n        let i = reverse_index(rank, r).comptime();\n        let arg = shape[i].clone();\n        let shape_i = read_scalar_shape(inputs, arg);\n        let ogwl = index / locals.ref_strides[i];\n\n        offset += ogwl % shape_i * stride_curr;\n\n        stride_curr *= shape_i;\n    }\n\n    offset\n}\n\n#[allow(unreachable_code)]\n#[cube]\n#[allow(clippy::clone_on_copy)]\nfn reshaped_index_to_original_index<C: Scalar, N: Size>(\n    original: &Tensor<Vector<C, N>>,\n    index_reshaped: usize,\n    #[comptime] rank: usize,\n) -> usize {\n    let mut remaining = index_reshaped;\n    let mut offset = 0;\n\n    #[unroll]\n    for r in 0..rank {\n        let i = reverse_index(rank, r);\n        let shape = original.shape(i);\n        let stride = original.stride(i);\n\n        let coordinate = remaining % shape;\n\n        remaining /= shape;\n        offset += coordinate * stride;\n    }\n\n    offset / original.vector_size()\n}\n\n#[cube]\n#[allow(unused_variables)]\npub(crate) fn reverse_index(\n    #[comptime] rank: usize,\n    #[comptime] iter: usize,\n) -> comptime_type!(usize) {\n    rank - iter - 1\n}\n\n/// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion.\n#[allow(unused_variables)]\n#[cube]\nfn from_const_int<C: CubePrimitive>(#[comptime] value: usize) -> C {\n    intrinsic!(|scope| {\n        ManagedVariable::Plain(Variable::constant(value.into(), C::as_type(scope))).into()\n    })\n}\n\n#[cube]\n#[allow(clippy::extra_unused_type_parameters)]\npub(crate) fn set_polyfill_typed<C: CubePrimitive, Dyn: Scalar, DynSize: Size>() {\n    intrinsic!(|scope| {\n        let elem_type = C::as_type(scope);\n        set_polyfill::expand::<Dyn, DynSize>(scope, elem_type);\n    })\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/ir.rs",
    "content": "use super::tensor::GlobalTensor;\nuse crate::engine::codegen::{DynElem, DynSize};\nuse burn_std::{\n    BoolStore, DType, Shape, Strides, bf16, f16,\n    quantization::{QuantScheme, QuantStore, QuantValue},\n    strides,\n};\nuse core::fmt::Display;\nuse cubecl::{\n    ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind},\n    prelude::*,\n};\nuse serde::{Deserialize, Serialize};\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]\n/// Argument to a [fuse operation](FuseOp).\npub enum FuseArg {\n    /// A readonly input tensor.\n    Input(usize, FuseType, LayoutInfo),\n    /// A readwrite output tensor.\n    Output(usize, FuseType, LayoutInfo),\n    /// A temporary local variable within a single [block](FuseBlockConfig).\n    BlockLocal {\n        /// The position of the current variable relative to all local variables within a single block.\n        pos: usize,\n        /// The type of the current variable.\n        ty: FuseType,\n    },\n    /// A variable shared between multiple [block](FuseBlockConfig) that must have a compatible\n    /// scope.\n    MultiBlockLocal(MultiBlockPos, FuseType),\n    /// A variable shared between multiple [blocks](FuseBlockConfig) within a global accessible\n    /// scope.\n    MultiBlockGlobal(MultiBlockPos, FuseType),\n    /// A global scalar.\n    Scalar(usize, FuseType),\n    /// A global scalar used in a reshape operation.\n    ///\n    /// This is not a scalar defined by a user for computation, but a scalar defined as part of\n    /// a reshape operation.\n    ScalarShape(usize),\n    /// Only constant that can be encoded into an u32 can be used as literal.\n    Literal(usize, FuseType),\n    /// A readonly input tensor that is reshaped.\n    InputReshaped {\n        original: Box<FuseArg>,\n        shape: Vec<FuseArg>,\n        broadcasted: bool,\n    },\n    /// A readonly input tensor with swapped dimensions.\n    InputSwapDims {\n        original: Box<FuseArg>,\n        dims: (usize, usize),\n        broadcasted: bool,\n    },\n}\n\n/// Metadata of a variable shared between blocks.\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]\npub struct MultiBlockPos {\n    /// The block position in all blocks included in a fused trace.\n    pub block_pos: usize,\n    /// The [FuseArg::BlockLocal] position in the block where the variable is first initialized.\n    pub block_local_pos: usize,\n}\n\n#[derive(\n    CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,\n)]\n/// Layout information.\npub enum LayoutInfo {\n    /// The layout if the same as the reference.\n    SameAsRef,\n    /// The reference layout.\n    IsRef,\n    /// The layout if unknown.\n    Unknown,\n}\n\nimpl FuseArg {\n    pub fn precision(&self) -> FuseType {\n        *match self {\n            FuseArg::Input(_, p, _) => p,\n            FuseArg::BlockLocal { ty, .. } => ty,\n            FuseArg::MultiBlockLocal(_, p) => p,\n            FuseArg::MultiBlockGlobal(_, p) => p,\n            FuseArg::Output(_, p, _) => p,\n            FuseArg::Scalar(_, p) => p,\n            FuseArg::Literal(_, p) => p,\n            FuseArg::ScalarShape(_) => return FuseType::U32,\n            FuseArg::InputReshaped { original, .. } => return original.precision(),\n            FuseArg::InputSwapDims { original, .. } => return original.precision(),\n        }\n    }\n}\n\nimpl CubeType for FuseArg {\n    type ExpandType = Self;\n}\n\nimpl IntoMut for FuseArg {\n    fn into_mut(self, _context: &mut Scope) -> Self {\n        self\n    }\n}\n\nimpl IntoRuntime for FuseArg {\n    fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType {\n        self\n    }\n}\n\nimpl CubeDebug for FuseArg {}\n\n#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// Operations that can be executed and fused automatically using a fuse-on-read and/or\n/// fuse-on-write strategy.\npub enum FuseOp {\n    Add(BinaryFuseArgs),\n    Sub(BinaryFuseArgs),\n    Mul(BinaryFuseArgs),\n    Div(BinaryFuseArgs),\n    Powf(BinaryFuseArgs),\n    Abs(UnaryFuseArgs),\n    Exp(UnaryFuseArgs),\n    Log(UnaryFuseArgs),\n    Log1p(UnaryFuseArgs),\n    Cos(UnaryFuseArgs),\n    Sin(UnaryFuseArgs),\n    Tanh(UnaryFuseArgs),\n    Erf(UnaryFuseArgs),\n    Sqrt(UnaryFuseArgs),\n    Recip(UnaryFuseArgs),\n    Assign(UnaryFuseArgs),\n    Equal(BinaryFuseArgs),\n    Lower(BinaryFuseArgs),\n    Greater(BinaryFuseArgs),\n    LowerEqual(BinaryFuseArgs),\n    Rem(BinaryFuseArgs),\n    GreaterEqual(BinaryFuseArgs),\n    Clamp {\n        input: FuseArg,\n        min: FuseArg,\n        max: FuseArg,\n        out: FuseArg,\n    },\n    ConditionalAssign {\n        cond: FuseArg,\n        lhs: FuseArg,\n        rhs: FuseArg,\n        out: FuseArg,\n    },\n    Gather {\n        input: FuseArg,\n        indices: FuseArg,\n        output: FuseArg,\n        dim: usize,\n    },\n    Select {\n        input: FuseArg,\n        indices: FuseArg,\n        output: FuseArg,\n        dim: usize,\n    },\n    Dequantize {\n        values: FuseArg,\n        params: FuseArg,\n        output: FuseArg,\n        scheme: QuantSchemeFuse,\n    },\n}\n\nimpl Display for FuseOp {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            FuseOp::Add(args) => write!(f, \"{} = {} + {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Sub(args) => write!(f, \"{} = {} - {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Mul(args) => write!(f, \"{} = {} * {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Div(args) => write!(f, \"{} = {} / {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Powf(args) => write!(f, \"{} = powf({}, {})\", args.out, args.lhs, args.rhs),\n            FuseOp::Abs(args) => write!(f, \"{} = abs({})\", args.out, args.input),\n            FuseOp::Exp(args) => write!(f, \"{} = exp({})\", args.out, args.input),\n            FuseOp::Log(args) => write!(f, \"{} = log({})\", args.out, args.input),\n            FuseOp::Log1p(args) => write!(f, \"{} = log1p({})\", args.out, args.input),\n            FuseOp::Cos(args) => write!(f, \"{} = cos({})\", args.out, args.input),\n            FuseOp::Sin(args) => write!(f, \"{} = sin({})\", args.out, args.input),\n            FuseOp::Tanh(args) => write!(f, \"{} = tanh({})\", args.out, args.input),\n            FuseOp::Erf(args) => write!(f, \"{} = erf({})\", args.out, args.input),\n            FuseOp::Sqrt(args) => write!(f, \"{} = sqrt({})\", args.out, args.input),\n            FuseOp::Recip(args) => write!(f, \"{} = recip({})\", args.out, args.input),\n            FuseOp::Assign(args) => write!(f, \"{} = {}\", args.out, args.input),\n            FuseOp::Equal(args) => write!(f, \"{} = {} == {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Lower(args) => write!(f, \"{} = {} < {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Greater(args) => write!(f, \"{} = {} > {}\", args.out, args.lhs, args.rhs),\n            FuseOp::LowerEqual(args) => write!(f, \"{} = {} <= {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Rem(args) => write!(f, \"{} = {} % {}\", args.out, args.lhs, args.rhs),\n            FuseOp::GreaterEqual(args) => write!(f, \"{} = {} >= {}\", args.out, args.lhs, args.rhs),\n            FuseOp::Clamp {\n                input,\n                min,\n                max,\n                out,\n            } => write!(f, \"{} = clamp({}, min={}, max={})\", out, input, min, max),\n            FuseOp::ConditionalAssign {\n                cond,\n                lhs,\n                rhs,\n                out,\n            } => write!(\n                f,\n                \"{} = select(cond={}, lhs={}, rhs={})\",\n                out, cond, lhs, rhs\n            ),\n            FuseOp::Gather {\n                input,\n                indices,\n                output,\n                dim,\n            } => write!(\n                f,\n                \"{} = gather(input={}, indices={}, dim={})\",\n                output, input, indices, dim\n            ),\n            FuseOp::Select {\n                input,\n                indices,\n                output,\n                dim,\n            } => write!(\n                f,\n                \"{} = select(input={}, indices={}, dim={})\",\n                output, input, indices, dim\n            ),\n            FuseOp::Dequantize {\n                values,\n                params,\n                output,\n                scheme: _,\n            } => write!(\n                f,\n                \"{} = dequantize(values={}, params={})\",\n                output, values, params\n            ),\n        }\n    }\n}\n\n#[derive(\n    CubeType, CubeLaunch, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord,\n)]\npub struct QuantSchemeFuse {\n    #[cube(comptime)]\n    pub(crate) scheme: QuantScheme,\n}\n\nimpl FuseOp {\n    /// Element type used for the computation.\n    pub(crate) fn cmp_elem(&self) -> ElemType {\n        match self {\n            FuseOp::Add(op) => op.lhs.precision().into_elem(),\n            FuseOp::Sub(op) => op.lhs.precision().into_elem(),\n            FuseOp::Mul(op) => op.lhs.precision().into_elem(),\n            FuseOp::Div(op) => op.lhs.precision().into_elem(),\n            FuseOp::Powf(op) => op.lhs.precision().into_elem(),\n            FuseOp::Abs(op) => op.out.precision().into_elem(),\n            FuseOp::Exp(op) => op.out.precision().into_elem(),\n            FuseOp::Log(op) => op.out.precision().into_elem(),\n            FuseOp::Log1p(op) => op.out.precision().into_elem(),\n            FuseOp::Cos(op) => op.out.precision().into_elem(),\n            FuseOp::Sin(op) => op.out.precision().into_elem(),\n            FuseOp::Tanh(op) => op.out.precision().into_elem(),\n            FuseOp::Erf(op) => op.out.precision().into_elem(),\n            FuseOp::Recip(op) => op.out.precision().into_elem(),\n            FuseOp::Sqrt(op) => op.out.precision().into_elem(),\n            FuseOp::Assign(op) => op.out.precision().into_elem(),\n            FuseOp::Equal(op) => op.lhs.precision().into_elem(),\n            FuseOp::Lower(op) => op.lhs.precision().into_elem(),\n            FuseOp::Greater(op) => op.lhs.precision().into_elem(),\n            FuseOp::LowerEqual(op) => op.lhs.precision().into_elem(),\n            FuseOp::GreaterEqual(op) => op.lhs.precision().into_elem(),\n            FuseOp::ConditionalAssign { out, .. } => out.precision().into_elem(),\n            FuseOp::Gather { output, .. } => output.precision().into_elem(),\n            FuseOp::Select { output, .. } => output.precision().into_elem(),\n            FuseOp::Dequantize { output, .. } => output.precision().into_elem(),\n            FuseOp::Rem(op) => op.out.precision().into_elem(),\n            FuseOp::Clamp { out, .. } => out.precision().into_elem(),\n        }\n    }\n\n    pub(crate) fn cmp_storage_ty(&self) -> StorageType {\n        self.cmp_elem().into()\n    }\n}\n\n#[derive(CubeType, CubeLaunch, Default, Clone)]\n/// Global arguments that are used for fusing [element wise operations](ElemTypewiseOp).\npub struct GlobalArgs {\n    /// Tensors that are stored in global memory.\n    pub tensors: Sequence<GlobalTensor>,\n    /// Scalars that are stored in global memory.\n    pub scalars: Sequence<InputScalar>,\n    /// To be used to perform reshape inside a fused kernel.\n    pub reshapes: Sequence<usize>,\n    /// When there are no metadata as a reference layout, we provide runtime shape/strides in this\n    /// sequence instead.\n    pub runtime_layouts: Sequence<usize>,\n    /// Variables shared between blocks.\n    pub variables: MultiBlockVariables,\n}\n\nimpl<R: Runtime> GlobalArgsLaunch<R> {\n    pub fn required_address_type(&self) -> AddressType {\n        self.tensors\n            .values\n            .iter()\n            .map(|it| it.address_type)\n            .max()\n            .unwrap_or_default()\n    }\n}\n\n/// Variables shared between blocks.\n#[derive(CubeType, Default, Clone)]\npub struct MultiBlockVariables {\n    variables: Registry<usize, Registry<usize, RuntimeCell<Vector<DynElem, DynSize>>>>,\n}\n\n#[cube]\nimpl MultiBlockVariables {\n    /// Initializes the variable with the given key and vector size.\n    ///\n    /// # Notes\n    ///\n    /// The type of [`NumericExpand<DYN_ELEM_ID>`] must be set before calling this function.\n    pub fn init(&mut self, #[comptime] key: MultiBlockPos) {\n        let mut registers = Registry::<\n            usize,\n            Registry<usize, RuntimeCell<Vector<DynElem, DynSize>>>,\n        >::find_or_default::<usize>(&mut self.variables, key.block_pos);\n        let cell = RuntimeCell::new(Vector::empty());\n        registers.insert(key.block_local_pos, cell);\n    }\n\n    /// Read the variable using the provided key.\n    ///\n    /// # Notes\n    ///\n    /// The variable must be initialized.\n    pub fn read(&self, #[comptime] key: MultiBlockPos) -> Vector<DynElem, DynSize> {\n        let registers = self.variables.find(key.block_pos);\n        let cell = registers.find(key.block_local_pos);\n        cell.read()\n    }\n\n    /// Write to the variable using the provided key and value.\n    ///\n    /// # Notes\n    ///\n    /// The variable must be initialized.\n    pub fn write(&mut self, #[comptime] key: MultiBlockPos, value: Vector<DynElem, DynSize>) {\n        let registers = self.variables.find(key.block_pos);\n        // Try find for local(visibility) registers.\n        let cell = registers.find(key.block_local_pos);\n        cell.store(value);\n    }\n}\n\n// Because we only create it DURING compilation, not as a real launch arg.\nunsafe impl Send for MultiBlockVariables {}\nunsafe impl Sync for MultiBlockVariables {}\n\nimpl LaunchArg for MultiBlockVariables {\n    type RuntimeArg<R: Runtime> = ();\n    type CompilationArg = ();\n\n    fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<R>) -> Self::CompilationArg {}\n\n    fn register<R: Runtime>(_arg: Self::RuntimeArg<R>, _launcher: &mut KernelLauncher<R>) {}\n\n    fn expand(\n        _arg: &Self::CompilationArg,\n        _builder: &mut KernelBuilder,\n    ) -> <Self as CubeType>::ExpandType {\n        MultiBlockVariablesExpand {\n            variables: Default::default(),\n        }\n    }\n}\n\nimpl<R: Runtime> Default for GlobalArgsLaunch<R> {\n    fn default() -> Self {\n        Self {\n            tensors: Default::default(),\n            scalars: Default::default(),\n            reshapes: Default::default(),\n            variables: Default::default(),\n            runtime_layouts: Default::default(),\n            _phantom_runtime: std::marker::PhantomData,\n        }\n    }\n}\n\nimpl<R: Runtime> core::fmt::Debug for GlobalArgsLaunch<R> {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        write!(f, \"({:?})\", self.tensors.values)\n    }\n}\n\nimpl<R: Runtime> GlobalArgsLaunch<R> {\n    /// Get the shape of the given [argument](Arg).\n    ///\n    /// # Panics\n    ///\n    /// If the argument doesn't have an handle.\n    pub fn shape(&self, arg: &FuseArg) -> Shape {\n        match self.resolve_arg(arg) {\n            TensorArg::Handle { handle, .. } => handle.shape.clone(),\n            TensorArg::Alias { .. } => panic!(\"Unsupported yet\"),\n        }\n    }\n\n    /// Shape used by the reference tensor.\n    pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape {\n        match ref_layout {\n            RefLayout::Concrete(arg) => self.shape(arg),\n            RefLayout::Virtual(layout) => match layout {\n                VirtualLayout::SwapDims(original, dims) => {\n                    let mut shape = self.shape(original);\n                    shape.swap(dims.0, dims.1);\n                    shape\n                }\n                VirtualLayout::Reshaped { reshape_pos, .. } => {\n                    let start = *reshape_pos * rank;\n                    let end = start + rank;\n                    self.reshapes.values[start..end].iter().copied().collect()\n                }\n                VirtualLayout::Shape(original, _) => self.shape(original),\n                VirtualLayout::Runtime { pos } => {\n                    let start = (*pos * 2) * rank;\n                    let end = start + rank;\n                    self.runtime_layouts.values[start..end]\n                        .iter()\n                        .copied()\n                        .collect()\n                }\n            },\n        }\n    }\n\n    /// Get the strides of the given [argument](Arg).\n    ///\n    /// # Panics\n    ///\n    /// If the argument doesn't have an handle.\n    pub fn strides(&self, arg: &FuseArg) -> Strides {\n        match self.resolve_arg(arg) {\n            TensorArg::Handle { handle, .. } => handle.strides.clone(),\n            TensorArg::Alias { .. } => panic!(\"Unsupported yet\"),\n        }\n    }\n\n    pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides {\n        match ref_layout {\n            RefLayout::Concrete(arg) => self.strides(arg),\n            // When not concrete, we operate on the contiguous layout.\n            _ => {\n                let shape = self.shape_ref(ref_layout, rank);\n                let mut strides = strides![0; shape.len()];\n\n                let mut current = 1;\n                shape.iter().enumerate().rev().for_each(|(index, val)| {\n                    strides[index] = current;\n                    current *= val;\n                });\n\n                strides\n            }\n        }\n    }\n\n    /// Get the vector size of the given [argument](Arg).\n    ///\n    /// # Panics\n    ///\n    /// If the argument doesn't have an handle.\n    pub fn vector_size(&self, arg: &FuseArg) -> VectorSize {\n        match arg {\n            FuseArg::Input(pos, _, _) => self.tensors.values[*pos].ty.vector_size(),\n            FuseArg::Output(pos, _, _) => self.tensors.values[*pos].ty.vector_size(),\n            other => panic!(\"Arg not found: {other:?}\"),\n        }\n    }\n\n    /// Resolve the [argument](Arg) to a [tensor argument](TensorArg).\n    ///\n    /// # Panics\n    ///\n    /// If the argument isn't a global input or output tensor.\n    pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg<R> {\n        match arg {\n            FuseArg::Input(pos, _, _) => &self.tensors.values[*pos].tensor,\n            FuseArg::Output(pos, _, _) => &self.tensors.values[*pos].tensor,\n            other => panic!(\"Arg not found: {other:?}\"),\n        }\n    }\n}\n\n#[derive(CubeType, Clone)]\n/// Keep track of all local variables that are used as argument in fused\n/// [element wise operations](ElemwiseOp).\npub struct LocalArgs {\n    pub l_f64: Registry<usize, Vector<f64, DynSize>>,\n    pub l_f32: Registry<usize, Vector<f32, DynSize>>,\n    pub l_f16: Registry<usize, Vector<f16, DynSize>>,\n    pub l_bf16: Registry<usize, Vector<bf16, DynSize>>,\n    pub l_i64: Registry<usize, Vector<i64, DynSize>>,\n    pub l_i32: Registry<usize, Vector<i32, DynSize>>,\n    pub l_i16: Registry<usize, Vector<i16, DynSize>>,\n    pub l_i8: Registry<usize, Vector<i8, DynSize>>,\n    pub l_u64: Registry<usize, Vector<u64, DynSize>>,\n    pub l_u32: Registry<usize, Vector<u32, DynSize>>,\n    pub l_u16: Registry<usize, Vector<u16, DynSize>>,\n    pub l_u8: Registry<usize, Vector<u8, DynSize>>,\n    pub ref_shape: Slice<usize>,\n    pub ref_strides: Slice<usize>,\n    #[cube(comptime)]\n    pub ref_vector_size: VectorSize,\n}\n\n#[cube]\nimpl LocalArgs {\n    /// Creates a new [LocalArgs] container.\n    pub fn new(\n        ref_shape: Slice<usize>,\n        ref_strides: Slice<usize>,\n        #[comptime] ref_vector_size: VectorSize,\n    ) -> LocalArgs {\n        LocalArgs {\n            l_f64: Registry::<usize, Vector<f64, DynSize>>::new(),\n            l_f32: Registry::<usize, Vector<f32, DynSize>>::new(),\n            l_f16: Registry::<usize, Vector<f16, DynSize>>::new(),\n            l_bf16: Registry::<usize, Vector<bf16, DynSize>>::new(),\n            l_i64: Registry::<usize, Vector<i64, DynSize>>::new(),\n            l_i32: Registry::<usize, Vector<i32, DynSize>>::new(),\n            l_i16: Registry::<usize, Vector<i16, DynSize>>::new(),\n            l_i8: Registry::<usize, Vector<i8, DynSize>>::new(),\n            l_u64: Registry::<usize, Vector<u64, DynSize>>::new(),\n            l_u32: Registry::<usize, Vector<u32, DynSize>>::new(),\n            l_u16: Registry::<usize, Vector<u16, DynSize>>::new(),\n            l_u8: Registry::<usize, Vector<u8, DynSize>>::new(),\n            ref_shape,\n            ref_strides,\n            ref_vector_size,\n        }\n    }\n}\n\n#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// Unary [element wise operation](ElemwiseOp) arguments.\npub struct UnaryFuseArgs {\n    pub input: FuseArg,\n    pub out: FuseArg,\n}\n\n#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// Binary [element wise operation](ElemwiseOp) arguments.\npub struct BinaryFuseArgs {\n    pub lhs: FuseArg,\n    pub rhs: FuseArg,\n    pub out: FuseArg,\n}\n\n#[derive(\n    CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,\n)]\n/// Precisions supported by [element wise operations](ElemwiseOp).\n///\n/// This is a custom type instead of [ElemType] so it can implement [CubeType]\n/// and restricts the supported types for fusion.\npub enum FuseType {\n    F64,\n    F32,\n    Flex32,\n    F16,\n    BF16,\n    I64,\n    I32,\n    I16,\n    I8,\n    U64,\n    U32,\n    U16,\n    U8,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// Configuration that encapsulates all comptime information necessary for element wise fusion.\npub struct FuseBlockConfig {\n    pub rank: usize,\n    pub ref_layout: RefLayout,\n    pub ops: Vec<FuseOp>,\n    pub width: VectorSize,\n}\n\nimpl FuseBlockConfig {\n    pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {\n        for op in self.ops.iter() {\n            op.multi_block_variables(registers);\n        }\n    }\n}\n\nimpl FuseArg {\n    pub fn multi_block_variable(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {\n        match self {\n            FuseArg::MultiBlockGlobal(arg, fuse_type)\n                // TODO: we need to init the multi-block local, but at some point we could avoid\n                // that for performance (easier for the underlying compiler).\n            | FuseArg::MultiBlockLocal(arg, fuse_type) => {\n                registers.push((arg.clone(), fuse_type.into_storage_type()))\n            }\n            _ => {}\n        };\n    }\n}\n\nimpl FuseOp {\n    pub fn multi_block_variables(&self, registers: &mut Vec<(MultiBlockPos, StorageType)>) {\n        match self {\n            FuseOp::Add(binary_fuse_args)\n            | FuseOp::Sub(binary_fuse_args)\n            | FuseOp::Mul(binary_fuse_args)\n            | FuseOp::Div(binary_fuse_args)\n            | FuseOp::Powf(binary_fuse_args)\n            | FuseOp::Equal(binary_fuse_args)\n            | FuseOp::Lower(binary_fuse_args)\n            | FuseOp::Greater(binary_fuse_args)\n            | FuseOp::LowerEqual(binary_fuse_args)\n            | FuseOp::Rem(binary_fuse_args)\n            | FuseOp::GreaterEqual(binary_fuse_args) => {\n                binary_fuse_args.lhs.multi_block_variable(registers);\n                binary_fuse_args.rhs.multi_block_variable(registers);\n                binary_fuse_args.out.multi_block_variable(registers);\n            }\n            FuseOp::Abs(unary_fuse_args)\n            | FuseOp::Exp(unary_fuse_args)\n            | FuseOp::Log(unary_fuse_args)\n            | FuseOp::Log1p(unary_fuse_args)\n            | FuseOp::Cos(unary_fuse_args)\n            | FuseOp::Sin(unary_fuse_args)\n            | FuseOp::Tanh(unary_fuse_args)\n            | FuseOp::Erf(unary_fuse_args)\n            | FuseOp::Sqrt(unary_fuse_args)\n            | FuseOp::Recip(unary_fuse_args)\n            | FuseOp::Assign(unary_fuse_args) => {\n                unary_fuse_args.input.multi_block_variable(registers);\n                unary_fuse_args.out.multi_block_variable(registers);\n            }\n            FuseOp::Clamp {\n                input,\n                min,\n                max,\n                out,\n            } => {\n                input.multi_block_variable(registers);\n                min.multi_block_variable(registers);\n                max.multi_block_variable(registers);\n                out.multi_block_variable(registers);\n            }\n            FuseOp::ConditionalAssign {\n                cond,\n                lhs,\n                rhs,\n                out,\n            } => {\n                cond.multi_block_variable(registers);\n                lhs.multi_block_variable(registers);\n                rhs.multi_block_variable(registers);\n                out.multi_block_variable(registers);\n            }\n            FuseOp::Gather {\n                input,\n                indices,\n                output,\n                dim: _,\n            } => {\n                input.multi_block_variable(registers);\n                indices.multi_block_variable(registers);\n                output.multi_block_variable(registers);\n            }\n            FuseOp::Select {\n                input,\n                indices,\n                output,\n                dim: _,\n            } => {\n                input.multi_block_variable(registers);\n                indices.multi_block_variable(registers);\n                output.multi_block_variable(registers);\n            }\n            FuseOp::Dequantize {\n                values,\n                params,\n                output,\n                scheme: _,\n            } => {\n                values.multi_block_variable(registers);\n                params.multi_block_variable(registers);\n                output.multi_block_variable(registers);\n            }\n        }\n    }\n}\n\n#[cube]\n/// Initializes block variables, both globals and locals.\npub fn multi_block_variables_init(\n    #[comptime] block: &FuseBlockConfig,\n    variables: &mut MultiBlockVariables,\n) {\n    let output = comptime! {\n        let mut output = Vec::<(MultiBlockPos, StorageType)>::new();\n        block.multi_block_variables(&mut output);\n        output\n    };\n\n    #[unroll]\n    for i in 0..comptime!(output.len()) {\n        let (key, dtype) = comptime!(output.get(i).unwrap().clone());\n        set_polyfill::<DynElem, DynSize>(comptime![Type::new(dtype).with_vector_size(block.width)]);\n        variables.init(key);\n    }\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// A reference layout determines how a fuse execution will access elements in tensors.\n///\n/// It can either follow the same layout as a concrete tensor, or follow a virtual layout.\npub enum RefLayout {\n    Concrete(FuseArg),\n    Virtual(VirtualLayout),\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// A virtual layout is always contiguous and retrieves its shape from either a reshaped tensor or a\n/// tensor with swap dimensions.\npub enum VirtualLayout {\n    /// Virtual tensor with the provided shape id and contiguous strides.\n    Reshaped {\n        reshape_pos: usize,\n        vector_size: VectorSize,\n    },\n    /// Virtual tensor with the same shape as the given input, but with swap dims and contiguous\n    /// strides.\n    SwapDims(FuseArg, (usize, usize)),\n    /// Virtual tensor with the same shape as the given input, but with contiguous strides.\n    Shape(FuseArg, usize),\n    /// We don't have access to global metadata, they are passed as runtime values.\n    Runtime { pos: usize },\n}\n\nimpl FuseArg {\n    /// Adds layout information.\n    ///\n    /// It's going to impact how the input or output is read and written to.\n    pub fn add_layout_info(&mut self, layout: LayoutInfo) {\n        match self {\n            FuseArg::Input(_, _, old) => {\n                *old = layout;\n            }\n            FuseArg::Output(_, _, old) => {\n                *old = layout;\n            }\n            _ => {}\n        }\n    }\n}\n\nimpl RegistryQuery<Self> for FuseArg {}\n\nimpl From<ElemType> for FuseType {\n    fn from(value: ElemType) -> Self {\n        match value {\n            ElemType::Float(kind) => match kind {\n                FloatKind::F16 => Self::F16,\n                FloatKind::BF16 => Self::BF16,\n                FloatKind::F32 => Self::F32,\n                FloatKind::Flex32 => Self::Flex32,\n                _ => panic!(\"Unsupported precision for fusion: {value}\"),\n            },\n            ElemType::Int(kind) => match kind {\n                IntKind::I64 => Self::I64,\n                IntKind::I32 => Self::I32,\n                IntKind::I16 => Self::I16,\n                IntKind::I8 => Self::I8,\n            },\n            ElemType::UInt(kind) => match kind {\n                UIntKind::U64 => Self::U64,\n                UIntKind::U32 => Self::U32,\n                UIntKind::U16 => Self::U16,\n                UIntKind::U8 => Self::U8,\n            },\n            ElemType::Bool => panic!(\"Bool should be encoded as u8 or u32\"),\n        }\n    }\n}\n\nimpl From<StorageType> for FuseType {\n    fn from(value: StorageType) -> Self {\n        value.elem_type().into()\n    }\n}\n\nimpl FuseType {\n    /// Converts the [fused element type](FuseType) into the [cubecl element type](ElemType).\n    pub fn into_elem(self) -> ElemType {\n        match self {\n            FuseType::F32 => ElemType::Float(FloatKind::F32),\n            FuseType::Flex32 => ElemType::Float(FloatKind::Flex32),\n            FuseType::F16 => ElemType::Float(FloatKind::F16),\n            FuseType::BF16 => ElemType::Float(FloatKind::BF16),\n            FuseType::I64 => ElemType::Int(IntKind::I64),\n            FuseType::I32 => ElemType::Int(IntKind::I32),\n            FuseType::I16 => ElemType::Int(IntKind::I16),\n            FuseType::I8 => ElemType::Int(IntKind::I8),\n            FuseType::U64 => ElemType::UInt(UIntKind::U64),\n            FuseType::U32 => ElemType::UInt(UIntKind::U32),\n            FuseType::U16 => ElemType::UInt(UIntKind::U16),\n            FuseType::U8 => ElemType::UInt(UIntKind::U8),\n            FuseType::F64 => ElemType::Float(FloatKind::F64),\n        }\n    }\n\n    /// Convert the [fused element type](FuseType) into the [cubecl storage type](StorageType).\n    pub fn into_storage_type(self) -> StorageType {\n        self.into_elem().into()\n    }\n\n    /// Convert the [fused element type](FuseType) into the [cubecl type](Type)\n    pub fn into_type(self, vector_size: VectorSize) -> Type {\n        Type::new(self.into_storage_type()).with_vector_size(vector_size)\n    }\n}\n\nimpl From<DType> for FuseType {\n    fn from(value: DType) -> Self {\n        match value {\n            DType::F32 => Self::F32,\n            DType::Flex32 => Self::Flex32,\n            DType::F16 => Self::F16,\n            DType::BF16 => Self::BF16,\n            DType::I64 => Self::I64,\n            DType::I32 => Self::I32,\n            DType::I16 => Self::I16,\n            DType::I8 => Self::I8,\n            DType::U64 => Self::U64,\n            DType::U32 => Self::U32,\n            DType::U16 => Self::U16,\n            DType::U8 => Self::U8,\n            DType::Bool(BoolStore::Native) => unimplemented!(\"Bool should be U8 or U32\"),\n            DType::Bool(BoolStore::U8) => Self::U8,\n            DType::Bool(BoolStore::U32) => Self::U32,\n            DType::F64 => Self::F64,\n            DType::QFloat(scheme) => match scheme.store {\n                QuantStore::Native => match scheme.value {\n                    QuantValue::Q8F | QuantValue::Q8S => Self::I8,\n                    QuantValue::E4M3 | QuantValue::E5M2 => {\n                        unimplemented!(\"Unsupported precision for fusion\")\n                    }\n                    QuantValue::Q4F\n                    | QuantValue::Q4S\n                    | QuantValue::Q2F\n                    | QuantValue::Q2S\n                    | QuantValue::E2M1 => {\n                        panic!(\"Can't store native sub-byte values\")\n                    }\n                },\n                QuantStore::PackedU32(_) => Self::U32,\n                QuantStore::PackedNative(_) => match scheme.value {\n                    QuantValue::E2M1 => unimplemented!(\"Unsupported precision for fusion\"),\n                    other => panic!(\"{other:?} doesn't support native packing\"),\n                },\n            },\n        }\n    }\n}\n\nimpl Display for FuseArg {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            FuseArg::Input(pos, ..) => write!(f, \"input({pos})\"),\n            FuseArg::Output(pos, ..) => write!(f, \"output({pos})\"),\n            FuseArg::BlockLocal { pos, ty } => write!(f, \"local({pos}, {ty:?})\"),\n            FuseArg::MultiBlockLocal(mbp, ..) => write!(f, \"{mbp}\"),\n            FuseArg::MultiBlockGlobal(mbp, ..) => write!(f, \"global_{mbp}\"),\n            FuseArg::Scalar(pos, ..) => write!(f, \"scalar({pos})\"),\n            FuseArg::ScalarShape(pos) => write!(f, \"scalar_shape({pos})\"),\n            FuseArg::Literal(val, ..) => write!(f, \"literal_{val}\"),\n            FuseArg::InputReshaped { original, .. } => write!(f, \"input_reshaped_{original}\"),\n            FuseArg::InputSwapDims { original, .. } => write!(f, \"input_swap_dims_{original}\"),\n        }\n    }\n}\n\nimpl Display for MultiBlockPos {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        write!(\n            f,\n            \"block_local({}-{})\",\n            self.block_pos, self.block_local_pos\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs",
    "content": "use super::{io::*, ir::*};\nuse burn_std::quantization::{QuantScheme, QuantStore, QuantValue};\nuse cubecl::{\n    ir::{ElemType, FloatKind, StorageType, UIntKind},\n    prelude::*,\n};\nuse cubek::quantization::{dequantize::dequantize_symmetric_packed_value_at, scheme::QuantMode};\n\n#[cube]\n/// Fuse element-wise operations at the given write position.\n///\n/// # Arguments\n///\n/// - `inputs`: Contains all readonly global kernel arguments.\n/// - `outputs`: Contains all readwrite global kernel arguments.\n/// - `locals`: Contains all local variables defined during kernel expansion.\n/// - `write_pos`: The logical position the values are written to.\n/// - `write_values`: The explicit values to write at the given position.\n/// - `write_args`: The arguments associated to the `writes_values`.\n/// - `config`: The current [fuse block configuration](FuseBlockConfig).\n///\n/// # Notes\n///\n/// The function will start by writing `write_values`.\npub fn fuse_on_write<E: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    write_values: Registry<FuseArg, Vector<E, N>>,\n    #[comptime] write_args: Vec<FuseArg>,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    comment!(\"Fuse on write begin\");\n    // Write the values given as arguments.\n    #[unroll]\n    for i in 0..write_args.len() {\n        let arg = comptime![write_args.get(i).unwrap().clone()];\n        let val = write_values.find(arg.clone());\n\n        write::<E, N>(inputs, outputs, locals, write_pos, val, arg, config);\n    }\n\n    fuse(inputs, outputs, locals, write_pos, config);\n    comment!(\"Fuse on write end\");\n}\n\n#[cube]\n/// Fuse element-wise operations at the given read position.\n///\n/// # Arguments\n///\n/// - `inputs`: Contains all readonly global kernel arguments.\n/// - `outputs`: Contains all readwrite global kernel arguments.\n/// - `locals`: Contains all local variables defined during kernel expansion.\n/// - `read_pos`: The logical position the values are read from.\n/// - `read_args`: The arguments associated to the `read_pos`.\n/// - `config`: The current [fuse block configuration](FuseBlockConfig).\n///\n/// # Returns\n///\n/// - A sequence of values associated to the given `read_args`.  \npub fn fuse_on_read<E: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    read_pos: usize,\n    #[comptime] read_args: Sequence<FuseArg>,\n    #[comptime] config: &FuseBlockConfig,\n) -> Sequence<Vector<E, N>> {\n    comment!(\"Fuse on read begin\");\n    fuse(inputs, outputs, locals, read_pos, config);\n\n    let mut output = Sequence::new();\n\n    #[unroll]\n    for i in 0..read_args.len() {\n        let arg = comptime![read_args.index(i).clone()];\n        let value = read::<E, N>(inputs, outputs, locals, read_pos, arg, config);\n\n        output.push(value);\n    }\n\n    comment!(\"Fuse on read end\");\n    output\n}\n\n#[cube]\n/// Initializes [LocalArgs] given the input and output [arguments](GlobalArgs) with the [FuseBlockConfig].\n///\n/// # Notes\n///\n/// The goal is to resolve and cache the reference shape and strides, as it is used in many\n/// different function during kernel expansion.\npub fn init_locals(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    #[comptime] config: &FuseBlockConfig,\n) -> LocalArgs {\n    comment!(\"Init locals begin\");\n    let mut ref_shape = Array::new(config.rank);\n    let mut ref_strides = Array::new(config.rank);\n\n    let locals = match config.ref_layout.clone() {\n        RefLayout::Concrete(arg) => match comptime![arg] {\n            FuseArg::Input(index, ..) => {\n                let layout = inputs.tensors.index(index);\n\n                #[unroll]\n                for i in 0..config.rank {\n                    ref_shape[i] = layout.tensor.shape(i);\n                    ref_strides[i] = layout.tensor.stride(i);\n                }\n\n                LocalArgs::new(\n                    ref_shape.to_slice(),\n                    ref_strides.to_slice(),\n                    layout.tensor.vector_size(),\n                )\n            }\n            FuseArg::Output(index, ..) => {\n                let layout = outputs.tensors.index(index);\n\n                #[unroll]\n                for i in 0..config.rank {\n                    ref_shape[i] = layout.tensor.shape(i);\n                    ref_strides[i] = layout.tensor.stride(i);\n                }\n\n                LocalArgs::new(\n                    ref_shape.to_slice(),\n                    ref_strides.to_slice(),\n                    layout.tensor.vector_size(),\n                )\n            }\n            _ => comptime![panic!(\"Invalid concrete ref layout.\")],\n        },\n        RefLayout::Virtual(layout) => match layout {\n            VirtualLayout::SwapDims(original, dims) => {\n                let layout = match original.clone() {\n                    FuseArg::Input(pos, ..) => inputs.tensors.index(pos),\n                    FuseArg::Output(pos, ..) => outputs.tensors.index(pos),\n                    _ => comptime![panic!(\"Unsupported\")],\n                };\n\n                let mut stride_curr = 1;\n\n                #[unroll]\n                #[allow(clippy::clone_on_copy)]\n                for i in 0..config.rank {\n                    let reverse = reverse_index(config.rank, i);\n                    let swap = comptime![swap_dims_transform(reverse, dims)];\n                    let shape = layout.tensor.shape(swap.clone());\n\n                    ref_shape[reverse] = shape;\n                    ref_strides[reverse] = stride_curr;\n\n                    stride_curr *= ref_shape[comptime![reverse]];\n                }\n\n                LocalArgs::new(\n                    ref_shape.to_slice(),\n                    ref_strides.to_slice(),\n                    layout.tensor.vector_size(),\n                )\n            }\n            VirtualLayout::Reshaped {\n                reshape_pos,\n                vector_size,\n            } => {\n                let mut stride_curr = 1;\n                let start = reshape_pos * config.rank;\n\n                #[unroll]\n                #[allow(clippy::clone_on_copy)]\n                for i in 0..config.rank {\n                    let reverse = reverse_index(config.rank, i);\n                    let arg = comptime![FuseArg::ScalarShape(start + reverse)];\n                    let shape = read_scalar_shape(inputs, arg.clone());\n\n                    ref_shape[comptime![reverse]] = shape;\n                    ref_strides[comptime![reverse]] = stride_curr;\n\n                    stride_curr *= ref_shape[comptime![reverse]];\n                }\n\n                LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), vector_size)\n            }\n            VirtualLayout::Runtime { pos } => {\n                let start_shape = (pos * 2) * config.rank;\n                let start_strides = start_shape + config.rank;\n\n                #[unroll]\n                for i in 0..config.rank {\n                    let shape_index = start_shape + i;\n                    let strides_index = start_strides + i;\n\n                    ref_shape[i] = *inputs.runtime_layouts.index(shape_index);\n                    ref_strides[i] = *inputs.runtime_layouts.index(strides_index);\n                }\n\n                LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), config.width)\n            }\n            VirtualLayout::Shape(original, vector_size) => {\n                let layout = match original.clone() {\n                    FuseArg::Input(pos, ..) => inputs.tensors.index(pos),\n                    FuseArg::Output(pos, ..) => outputs.tensors.index(pos),\n                    _ => comptime![panic!(\"Unsupported\")],\n                };\n                let mut stride_curr = 1;\n\n                #[unroll]\n                #[allow(clippy::clone_on_copy)]\n                for i in 0..config.rank {\n                    let reverse = reverse_index(config.rank, i);\n                    let shape = layout.tensor.shape(reverse);\n\n                    ref_shape[comptime![reverse]] = shape;\n                    ref_strides[comptime![reverse]] = stride_curr;\n\n                    stride_curr *= ref_shape[comptime![reverse]];\n                }\n\n                LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), vector_size)\n            }\n        },\n    };\n    comment!(\"Init locals end\");\n    locals\n}\n\n#[cube]\n/// Expands all [operations](FuseOp) registered in the [block config](FuseBlockConfig].\nfn fuse(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    pos: usize,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    #[unroll]\n    for index in 0..config.ops.len() {\n        let op = config.ops[index].clone();\n        let define!(E) = op.cmp_storage_ty();\n        let size!(N) = config.width;\n\n        match op {\n            FuseOp::Add(op) => add::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Div(op) => div::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Sub(op) => sub::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Mul(op) => mul::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Powf(op) => powf::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Erf(op) => erf::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Sqrt(op) => sqrt::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Abs(op) => abs::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Log(op) => log::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Log1p(op) => log1p::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Recip(op) => recip::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Assign(op) => assign::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Exp(op) => exp::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Cos(op) => cos::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Sin(op) => sin::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Tanh(op) => tanh::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Equal(op) => equal::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Greater(op) => greater::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::GreaterEqual(op) => {\n                greater_equal::<E, N>(inputs, outputs, locals, pos, op, config)\n            }\n            FuseOp::Lower(op) => lower::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::LowerEqual(op) => lower_equal::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::ConditionalAssign {\n                cond,\n                lhs,\n                rhs,\n                out,\n            } => conditional_assign::<E, N>(\n                inputs, outputs, locals, pos, cond, lhs, rhs, out, config,\n            ),\n            FuseOp::Gather {\n                input,\n                indices,\n                output,\n                dim,\n            } => gather::<E, N>(\n                inputs, outputs, locals, pos, dim, input, indices, output, config,\n            ),\n            FuseOp::Select {\n                input,\n                indices,\n                output,\n                dim,\n            } => select_indices::<E, N>(\n                inputs, outputs, locals, pos, dim, input, indices, output, config,\n            ),\n            FuseOp::Dequantize {\n                values,\n                params,\n                output,\n                scheme,\n            } => dequantize::<E, N>(\n                inputs,\n                outputs,\n                locals,\n                pos,\n                values,\n                params,\n                output,\n                scheme.scheme,\n                config,\n            ),\n            FuseOp::Rem(op) => rem::<E, N>(inputs, outputs, locals, pos, op, config),\n            FuseOp::Clamp {\n                input,\n                min,\n                max,\n                out,\n            } => clamp::<E, N>(inputs, outputs, locals, pos, input, min, max, out, config),\n        }\n    }\n}\n\nmacro_rules! binary_op {\n    ($ident:ident, $op:tt) => {\n        #[cube]\n        fn $ident<C: Numeric, N: Size>(\n            inputs: &GlobalArgs,\n            outputs: &mut GlobalArgs,\n            locals: &mut LocalArgs,\n            write_pos: usize,\n            #[comptime] op: BinaryFuseArgs,\n            #[comptime] config: &FuseBlockConfig,\n        ) {\n            let lhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.lhs, config);\n            let rhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.rhs, config);\n            let result = lhs $op rhs;\n\n            write::<C, N>(inputs, outputs, locals, write_pos, result, op.out, config);\n        }\n    };\n}\n\nmacro_rules! binary_func {\n    ($ident:ident, $func:expr, $c:tt) => {\n        #[cube]\n        fn $ident<C: $c, N: Size>(\n            inputs: &GlobalArgs,\n            outputs: &mut GlobalArgs,\n            locals: &mut LocalArgs,\n            write_pos: usize,\n            #[comptime] op: BinaryFuseArgs,\n            #[comptime] config: &FuseBlockConfig,\n        ) {\n            let lhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.lhs, config);\n            let rhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.rhs, config);\n            let result = $func(lhs, rhs);\n\n            write::<C, N>(inputs, outputs, locals, write_pos, result, op.out, config);\n        }\n    };\n}\n\nmacro_rules! comparison_op {\n    ($ident:ident, $op:tt) => {\n        #[cube]\n        fn $ident<C: Scalar + core::cmp::PartialOrd, N: Size>(\n            inputs: &GlobalArgs,\n            outputs: &mut GlobalArgs,\n            locals: &mut LocalArgs,\n            write_pos: usize,\n            #[comptime] op: BinaryFuseArgs,\n            #[comptime] config: &FuseBlockConfig,\n        ) {\n            let lhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.lhs, config);\n            let rhs = read::<C, N>(inputs, outputs, &locals, write_pos, op.rhs, config);\n            let result = Vector::new(lhs $op rhs);\n\n            write::<bool, N>(inputs, outputs, locals, write_pos, result, op.out, config);\n        }\n    };\n}\n\nmacro_rules! unary_func {\n    ($ident:ident, $func:expr, $c:tt) => {\n        #[cube]\n        fn $ident<C: $c, N: Size>(\n            inputs: &GlobalArgs,\n            outputs: &mut GlobalArgs,\n            locals: &mut LocalArgs,\n            write_pos: usize,\n            #[comptime] op: UnaryFuseArgs,\n            #[comptime] config: &FuseBlockConfig,\n        ) {\n            let input = read::<C, N>(inputs, outputs, &locals, write_pos, op.input, config);\n            let result = $func(input);\n\n            write::<C, N>(inputs, outputs, locals, write_pos, result, op.out, config);\n        }\n    };\n}\n\n#[cube]\nfn assign<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] op: UnaryFuseArgs,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    let input = read::<C, N>(inputs, outputs, locals, write_pos, op.input, config);\n\n    write::<C, N>(inputs, outputs, locals, write_pos, input, op.out, config);\n}\n\n#[cube]\nfn gather<C: Numeric, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] dim: usize,\n    #[comptime] input: FuseArg,\n    #[comptime] indices: FuseArg,\n    #[comptime] output: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    let vector_size = locals.ref_vector_size;\n\n    let pos_input = comptime! {\n        match input {\n            FuseArg::Input(pos, ..) => pos,\n            _ => panic!(\"Input tensor isn't an input\"),\n        }\n    };\n    let pos_indices = comptime! {\n        match indices {\n            FuseArg::Input(pos, ..) => pos,\n            _ => panic!(\"Indices tensor isn't an input\"),\n        }\n    };\n\n    let stride_input_dim = global_stride(inputs, dim, pos_input);\n\n    let mut index = 0;\n    let mut result = Vector::<C, N>::empty();\n\n    if comptime![dim > 0] {\n        let index_before = global_offset(\n            inputs,\n            outputs,\n            locals,\n            write_pos,\n            input.clone(),\n            comptime![Some((0, dim))],\n            config,\n        );\n        index += index_before;\n    }\n\n    if comptime![dim + 1 < config.rank] {\n        let index_after = global_offset(\n            inputs,\n            outputs,\n            locals,\n            write_pos,\n            input,\n            comptime![Some((dim + 1, config.rank))],\n            config,\n        );\n        index += index_after;\n    }\n\n    let index_offset = global_offset(\n        inputs,\n        outputs,\n        locals,\n        write_pos,\n        indices,\n        comptime![Some((0, config.rank))],\n        config,\n    );\n\n    if comptime![dim == config.rank - 1] {\n        // Per-element indexing (along the dimension)\n        #[unroll]\n        for i in 0..vector_size {\n            let offset = read_input::<u32, Const<1>>(\n                inputs,\n                locals,\n                pos_indices,\n                index_offset + i,\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n\n            let input = read_input::<C, Const<1>>(\n                inputs,\n                locals,\n                pos_input,\n                index + (offset[0] as usize * stride_input_dim),\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n\n            result[i] = input[0];\n        }\n    } else {\n        // Shared index for whole vector\n        let stride_input_vector = global_stride(inputs, config.rank - 1, pos_input);\n\n        let offset = read_input::<u32, Const<1>>(\n            inputs,\n            locals,\n            pos_indices,\n            index_offset,\n            LayoutInfo::IsRef,\n            config,\n            None,\n        );\n\n        index += offset[0] as usize * stride_input_dim;\n\n        #[unroll]\n        for i in 0..vector_size {\n            let input = read_input::<C, Const<1>>(\n                inputs,\n                locals,\n                pos_input,\n                index + i * stride_input_vector,\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n\n            result[i] = input[0];\n        }\n    }\n\n    write::<C, N>(inputs, outputs, locals, write_pos, result, output, config);\n}\n\n#[cube]\nfn select_indices<C: Numeric, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] dim: usize,\n    #[comptime] input: FuseArg,\n    #[comptime] indices: FuseArg,\n    #[comptime] output: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    let (vector_size_ref, stride_dim_ref, shape_dim_ref) = (\n        locals.ref_vector_size,\n        locals.ref_strides[dim],\n        locals.ref_shape[dim],\n    );\n\n    let pos_input = comptime! {\n        match input {\n            FuseArg::Input(pos, ..) => pos,\n            _ => panic!(\"Input tensor isn't an input\"),\n        }\n    };\n    let pos_indices = match indices {\n        FuseArg::Input(pos, ..) => pos,\n        _ => panic!(\"Indices tensor isn't an input\"),\n    };\n\n    let stride_input_dim = global_stride(inputs, dim, pos_input);\n\n    let mut index = 0;\n    let mut result = Vector::empty();\n\n    if comptime![dim != config.rank - 1] {\n        // In this scenario the select is actually broadcasted along the axis we're working on.\n        //\n        // Therefore the same indices are used to fetch multiple entries in the input tensor.\n\n        if comptime![dim > 0] {\n            let index_before = global_offset(\n                inputs,\n                outputs,\n                locals,\n                write_pos,\n                input.clone(),\n                comptime![Some((0, dim))],\n                config,\n            );\n            index += index_before;\n        }\n\n        if comptime![dim + 1 < config.rank] {\n            let index_after = global_offset(\n                inputs,\n                outputs,\n                locals,\n                write_pos,\n                input.clone(),\n                comptime![Some((dim + 1, config.rank))],\n                config,\n            );\n            index += index_after;\n        }\n\n        let stride_input_vector = global_stride(inputs, comptime![config.rank - 1], pos_input);\n        let write_pos_input = write_pos * vector_size_ref;\n        let coordinate_dim = write_pos_input / stride_dim_ref % shape_dim_ref;\n        let offset_dim = read_input::<u32, Const<1>>(\n            inputs,\n            locals,\n            pos_indices,\n            coordinate_dim,\n            LayoutInfo::IsRef,\n            config,\n            None,\n        );\n\n        index += offset_dim[0] as usize * stride_input_dim;\n\n        #[unroll]\n        for i in 0..vector_size_ref {\n            let input = read_input::<C, Const<1>>(\n                inputs,\n                locals,\n                pos_input,\n                index + i * stride_input_vector,\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n            result[i] = input[0];\n        }\n    } else {\n        // In this scenario the select is actually performed on the last dimension we're working on.\n        //\n        // Therefore we need to fetch multiple indices that correspond to different entries in the\n        // input tensor.\n\n        if comptime![dim > 0] {\n            let index_before = global_offset(\n                inputs,\n                outputs,\n                locals,\n                write_pos,\n                input.clone(),\n                comptime![Some((0, dim))],\n                config,\n            );\n            index += index_before;\n        }\n\n        if comptime![dim + 1 < config.rank] {\n            let index_after = global_offset(\n                inputs,\n                outputs,\n                locals,\n                write_pos,\n                input,\n                comptime![Some((dim + 1, config.rank))],\n                config,\n            );\n            index += index_after;\n        }\n\n        let write_pos_indices = write_pos * vector_size_ref;\n\n        #[unroll]\n        for i in 0..vector_size_ref {\n            let coordinate_dim = (write_pos_indices + i) / stride_dim_ref % shape_dim_ref;\n            let offset_dim = read_input::<u32, Const<1>>(\n                inputs,\n                locals,\n                pos_indices,\n                coordinate_dim,\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n\n            let input = read_input::<C, Const<1>>(\n                inputs,\n                locals,\n                pos_input,\n                index + (offset_dim[0] as usize * stride_input_dim),\n                LayoutInfo::IsRef,\n                config,\n                None,\n            );\n            result[i] = input[0];\n        }\n    }\n\n    write::<C, N>(inputs, outputs, locals, write_pos, result, output, config);\n}\n\n#[cube]\nfn conditional_assign<C: Scalar, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] cond: FuseArg,\n    #[comptime] lhs: FuseArg,\n    #[comptime] rhs: FuseArg,\n    #[comptime] out: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    let cond = read::<bool, N>(inputs, outputs, locals, write_pos, cond, config);\n    let lhs = read::<C, N>(inputs, outputs, locals, write_pos, lhs, config);\n    let rhs = read::<C, N>(inputs, outputs, locals, write_pos, rhs, config);\n    let result = select_many(cond, lhs, rhs);\n\n    write::<C, N>(inputs, outputs, locals, write_pos, result, out, config);\n}\n\n#[cube]\nfn clamp<C: Numeric, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] input: FuseArg,\n    #[comptime] min: FuseArg,\n    #[comptime] max: FuseArg,\n    #[comptime] out: FuseArg,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    let input = read::<C, N>(inputs, outputs, locals, write_pos, input, config);\n    let min = read::<C, N>(inputs, outputs, locals, write_pos, min, config);\n    let max = read::<C, N>(inputs, outputs, locals, write_pos, max, config);\n    let result = cubecl::prelude::clamp(input, min, max);\n\n    write::<C, N>(inputs, outputs, locals, write_pos, result, out, config);\n}\n\n#[cube]\n#[allow(clippy::explicit_counter_loop)]\nfn dequantize<C: Float, N: Size>(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    locals: &mut LocalArgs,\n    write_pos: usize,\n    #[comptime] input: FuseArg,\n    #[comptime] scales: FuseArg,\n    #[comptime] output: FuseArg,\n    #[comptime] scheme: QuantScheme,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    comptime!(assert_eq!(\n        scheme.mode,\n        QuantMode::Symmetric,\n        \"Only symmetric quantization mode is supported.\"\n    ));\n\n    let quant_ty = comptime![match scheme.store {\n        QuantStore::Native => match scheme.value {\n            QuantValue::Q8F | QuantValue::Q8S => StorageType::Scalar(ElemType::UInt(UIntKind::U8)),\n            QuantValue::E4M3 => StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),\n            QuantValue::E5M2 => StorageType::Scalar(ElemType::Float(FloatKind::E5M2)),\n            QuantValue::Q4F\n            | QuantValue::Q4S\n            | QuantValue::Q2F\n            | QuantValue::Q2S\n            | QuantValue::E2M1 => unreachable!(\"Can't store native sub-byte values\"),\n        },\n        QuantStore::PackedU32(_) => ElemType::UInt(UIntKind::U32).into(),\n        QuantStore::PackedNative(_) => match scheme.value {\n            QuantValue::E2M1 => StorageType::Packed(ElemType::Float(FloatKind::E4M3), 2),\n            other => panic!(\"{other:?} doesn't support native packing\"),\n        },\n    }];\n    let param_ty = comptime![match scheme.param {\n        cubecl::quant::scheme::QuantParam::F32 =>\n            StorageType::Scalar(ElemType::Float(FloatKind::F32)),\n        cubecl::quant::scheme::QuantParam::F16 =>\n            StorageType::Scalar(ElemType::Float(FloatKind::F16)),\n        cubecl::quant::scheme::QuantParam::BF16 =>\n            StorageType::Scalar(ElemType::Float(FloatKind::BF16)),\n        cubecl::quant::scheme::QuantParam::UE8M0 =>\n            StorageType::Scalar(ElemType::Float(FloatKind::UE8M0)),\n        cubecl::quant::scheme::QuantParam::UE4M3 =>\n            StorageType::Scalar(ElemType::Float(FloatKind::E4M3)),\n    }];\n    let q_vector_size = N::value().comptime() / scheme.num_quants();\n\n    let define!(QStoreType) = quant_ty;\n    let size!(QStoreSize) = q_vector_size;\n\n    let define!(QParamType) = param_ty;\n\n    let tensor_pos = comptime!(match input {\n        FuseArg::Input(pos, _, _) => pos,\n        _ => panic!(\"Not supported\"),\n    });\n    let pos = comptime!(match scales {\n        FuseArg::Input(pos, ..) => pos,\n        _ => unreachable!(\"\"),\n    });\n    let input =\n        read_quantized::<QStoreType, QStoreSize>(inputs, locals, write_pos, input, config, scheme);\n\n    let num_quants = scheme.num_quants();\n\n    let scales =\n        input_as_scales_view::<QParamType, Const<1>>(inputs, pos, tensor_pos, scheme.level, config);\n    let result = dequantize_symmetric_packed_value_at::<C, N, QParamType, QStoreType, QStoreSize>(\n        write_pos * num_quants,\n        input,\n        &scales,\n        scheme,\n    );\n\n    let vector = if comptime!(q_vector_size == 1) {\n        result[0]\n    } else {\n        let mut vector = Vector::empty();\n\n        #[unroll]\n        for i in 0..q_vector_size {\n            let value = result[i];\n\n            #[unroll]\n            for j in 0..num_quants {\n                let index = i * num_quants + j;\n                vector[index] = value[j];\n            }\n        }\n\n        vector\n    };\n\n    write::<C, N>(inputs, outputs, locals, write_pos, vector, output, config);\n}\n\nbinary_op!(add, +);\nbinary_op!(mul, *);\nbinary_op!(div, /);\nbinary_op!(sub, -);\n\ncomparison_op!(equal, ==);\ncomparison_op!(greater, >);\ncomparison_op!(greater_equal, >=);\ncomparison_op!(lower, <);\ncomparison_op!(lower_equal, <=);\n\nbinary_func!(powf, Vector::<C, N>::powf, Float);\nbinary_func!(rem, Vector::<C, N>::rem, Float);\n\nunary_func!(exp, Vector::<C, N>::exp, Float);\nunary_func!(log, Vector::<C, N>::ln, Float);\nunary_func!(log1p, Vector::<C, N>::log1p, Float);\nunary_func!(sqrt, Vector::<C, N>::sqrt, Float);\nunary_func!(cos, Vector::<C, N>::cos, Float);\nunary_func!(sin, Vector::<C, N>::sin, Float);\nunary_func!(tanh, Vector::<C, N>::tanh, Float);\nunary_func!(erf, Vector::<C, N>::erf, Float);\nunary_func!(recip, Vector::<C, N>::recip, Float);\nunary_func!(abs, Vector::<C, N>::abs, Numeric);\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/mod.rs",
    "content": "pub(crate) mod io;\npub(crate) mod ir;\npub(crate) mod kernel;\npub(crate) mod tensor;\npub(crate) mod view;\n\nmod base;\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/tensor.rs",
    "content": "use crate::engine::codegen::{DynElem, DynSize};\n\nuse cubecl::{ir::Type, prelude::*};\nuse serde::{Deserialize, Serialize};\nuse std::hash::Hash;\n\n/// Represents a global tensor with the given [element type](ElemType).\n///\n/// # Warning\n///\n/// The `tensor` field type [Vector<NumericExpand<DYN_ELEM_ID>>] must be set using polyfill before\n/// use.\n#[derive(CubeType, Clone)]\npub struct GlobalTensor {\n    /// The global tensor type.\n    pub tensor: Tensor<Vector<DynElem, DynSize>>,\n    /// The element type of the tensor.\n    #[cube(comptime)]\n    pub ty: Type,\n    /// Whether the current tensor is logically broadcasted.\n    #[cube(comptime)]\n    pub broadcasted: bool,\n}\n\n// Everything below is to implement [LaunchArg].\n\n#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]\npub struct GlobalTensorCompilationArg {\n    tensor: TensorCompilationArg,\n    ty: Type,\n    broadcasted: bool,\n}\n\n#[derive(new, Debug)]\npub struct GlobalTensorArg<R: Runtime> {\n    pub tensor: <Tensor<Vector<DynElem, DynSize>> as LaunchArg>::RuntimeArg<R>,\n    pub ty: Type,\n    pub broadcasted: bool,\n    pub address_type: AddressType,\n}\n\nimpl LaunchArg for GlobalTensor {\n    type RuntimeArg<R: Runtime> = GlobalTensorArg<R>;\n    type CompilationArg = GlobalTensorCompilationArg;\n\n    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<R>) -> Self::CompilationArg {\n        let tensor =\n            <Tensor<Vector<DynElem, DynSize>> as LaunchArg>::compilation_arg(&runtime_arg.tensor);\n        GlobalTensorCompilationArg {\n            tensor,\n            ty: runtime_arg.ty,\n            broadcasted: runtime_arg.broadcasted,\n        }\n    }\n\n    fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {\n        launcher.register_tensor(arg.tensor, arg.ty);\n    }\n\n    fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> GlobalTensorExpand {\n        let tensor = builder.input_tensor(arg.ty);\n\n        GlobalTensorExpand {\n            tensor: tensor.into(),\n            ty: arg.ty,\n            broadcasted: arg.broadcasted,\n        }\n    }\n    fn expand_output(\n        arg: &Self::CompilationArg,\n        builder: &mut KernelBuilder,\n    ) -> GlobalTensorExpand {\n        let tensor = match arg.tensor.inplace {\n            Some(id) => builder.inplace_output(id),\n            None => builder.output_tensor(arg.ty),\n        };\n        GlobalTensorExpand {\n            tensor: tensor.into(),\n            ty: arg.ty,\n            broadcasted: arg.broadcasted,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/codegen/view.rs",
    "content": "use crate::engine::codegen::{DynElem, DynSize, io::set_polyfill_typed};\n\nuse super::{\n    io::{\n        Transform, global_buffer_len, global_vector_size, input_as_slice, read_input,\n        read_input_window, ref_buffer_len, ref_len,\n    },\n    ir::{FuseArg, FuseBlockConfig, GlobalArgs, LayoutInfo, LocalArgs},\n    kernel::fuse_on_write,\n};\nuse cubecl::{\n    CubeType,\n    io::read_masked,\n    ir::StorageType,\n    prelude::{barrier::BarrierExpand, *},\n    std::tensor::{\n        ViewOperations, ViewOperationsExpand, ViewOperationsMut, ViewOperationsMutExpand,\n        layout::Coords1d,\n    },\n};\n\n#[allow(dead_code, reason = \"only used in expand\")]\n#[derive(CubeType)]\npub struct GlobalInput {\n    inputs: GlobalArgs,\n    locals: LocalArgs,\n    #[cube(comptime)]\n    pos: usize,\n    #[cube(comptime)]\n    ty: StorageType,\n    #[cube(comptime)]\n    layout: LayoutInfo,\n    #[cube(comptime)]\n    config: FuseBlockConfig,\n    #[cube(comptime)]\n    transform: Option<Transform>,\n}\n\n#[cube]\nimpl GlobalInput {\n    pub fn new(\n        inputs: &GlobalArgs,\n        locals: &LocalArgs,\n        #[comptime] arg: FuseArg,\n        #[comptime] config: FuseBlockConfig,\n        #[comptime] transform: Option<Transform>,\n    ) -> GlobalInput {\n        let (pos, ty, layout) = comptime![match arg {\n            FuseArg::Input(pos, prec, layout) => (pos, prec.into_storage_type(), layout),\n            _ => unreachable!(\"Must be concrete input\"),\n        }];\n\n        GlobalInput {\n            inputs: inputs.clone(),\n            locals: locals.clone(),\n            pos,\n            ty,\n            layout,\n            config,\n            transform,\n        }\n    }\n}\n\nimpl<E: CubePrimitive> ViewOperations<E, Coords1d> for GlobalInput {}\nimpl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand {\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        ViewOperationsExpand::<E, Coords1d>::__expand_read_unchecked_method(self, scope, pos)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_checked_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        let zero = E::__expand_cast_from(scope, 0.into());\n        ViewOperationsExpand::<E, Coords1d>::__expand_read_masked_method(self, scope, pos, zero)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_masked_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n        value: <E as CubeType>::ExpandType,\n    ) -> <E as CubeType>::ExpandType {\n        let in_bounds = ViewOperationsExpand::<E, Coords1d>::__expand_is_in_bounds_method(\n            self,\n            scope,\n            pos.clone(),\n        );\n        set_polyfill_typed::expand::<E, DynElem, DynSize>(scope);\n        let slice = input_as_slice::expand(scope, self.inputs.clone(), self.pos);\n        read_masked::expand::<E>(scope, in_bounds, slice, pos, value)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_unchecked_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        set_polyfill_typed::expand::<E, DynElem, DynSize>(scope);\n        let value = read_input::expand::<E::Scalar, E::Size>(\n            scope,\n            self.inputs.clone(),\n            self.locals.clone(),\n            self.pos,\n            pos,\n            self.layout,\n            self.config.clone(),\n            self.transform.clone(),\n        );\n        E::__expand_cast_from(scope, value)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_to_linear_slice_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n        end: NativeExpand<usize>,\n    ) -> SliceExpand<E, ReadOnly> {\n        set_polyfill_typed::expand::<E, DynElem, DynSize>(scope);\n        let end = add::expand(scope, end.clone(), 1.into());\n        read_input_window::expand(scope, self.inputs.clone(), self.pos, pos, end)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_tensor_map_load_method(\n        &self,\n        _scope: &mut Scope,\n        _barrier: BarrierExpand,\n        _shared_memory: SliceExpand<E, ReadWrite>,\n        _pos: NativeExpand<usize>,\n    ) {\n        panic!(\"Not a tensor map\")\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_shape_method(&self, scope: &mut Scope) -> NativeExpand<usize> {\n        global_buffer_len::expand(scope, self.inputs.clone(), self.pos)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_is_in_bounds_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n    ) -> NativeExpand<bool> {\n        let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos);\n        lt::expand(scope, pos, buffer_len)\n    }\n}\n\nimpl Vectorized for GlobalInput {}\nimpl VectorizedExpand for GlobalInputExpand {\n    fn vector_size(&self) -> VectorSize {\n        let mut temp_scope = Scope::root(false);\n        global_vector_size::expand(&mut temp_scope, self.inputs.clone(), self.pos)\n    }\n}\n\n#[allow(dead_code, reason = \"only used in expand\")]\n#[derive(CubeType)]\npub struct FusedOutput {\n    inputs: GlobalArgs,\n    outputs: GlobalArgs,\n    locals: LocalArgs,\n    arg: FuseArg,\n    #[cube(comptime)]\n    config: FuseBlockConfig,\n}\n\n#[cube]\nimpl FusedOutput {\n    pub fn new(\n        inputs: &GlobalArgs,\n        outputs: &mut GlobalArgs,\n        locals: &mut LocalArgs,\n        arg: FuseArg,\n        #[comptime] config: FuseBlockConfig,\n    ) -> Self {\n        FusedOutput {\n            inputs: inputs.clone(),\n            outputs: outputs.clone(),\n            locals: locals.clone(),\n            arg,\n            config,\n        }\n    }\n}\n\nimpl<E: CubePrimitive> ViewOperations<E, Coords1d> for FusedOutput {}\nimpl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for FusedOutputExpand {\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        todo!()\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_checked_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        todo!()\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_masked_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n        _value: <E as CubeType>::ExpandType,\n    ) -> <E as CubeType>::ExpandType {\n        todo!()\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_read_unchecked_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n    ) -> <E as CubeType>::ExpandType {\n        todo!()\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_to_linear_slice_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n        _size: NativeExpand<usize>,\n    ) -> SliceExpand<E, ReadOnly> {\n        todo!()\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_tensor_map_load_method(\n        &self,\n        _scope: &mut Scope,\n        _barrier: BarrierExpand,\n        _shared_memory: SliceExpand<E, ReadWrite>,\n        _pos: NativeExpand<usize>,\n    ) {\n        panic!(\"Not a tensor map\")\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_shape_method(&self, scope: &mut Scope) -> NativeExpand<usize> {\n        ref_len::expand(\n            scope,\n            self.inputs.clone(),\n            self.outputs.clone(),\n            self.locals.clone(),\n            self.config.clone(),\n        )\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_is_in_bounds_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n    ) -> NativeExpand<bool> {\n        let buffer_len = ref_buffer_len::expand(\n            scope,\n            self.inputs.clone(),\n            self.outputs.clone(),\n            self.locals.clone(),\n            self.config.clone(),\n        );\n        lt::expand(scope, pos, buffer_len)\n    }\n}\n\nimpl<E: CubePrimitive> ViewOperationsMut<E, Coords1d> for FusedOutput {}\nimpl<E: CubePrimitive> ViewOperationsMutExpand<E, Coords1d> for FusedOutputExpand {\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_write_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n        value: <E as CubeType>::ExpandType,\n    ) {\n        let values = Registry::<FuseArg, Vector<E::Scalar, E::Size>>::__expand_new(scope);\n        let mut args = comptime![Vec::<FuseArg>::new()];\n\n        let value = Vector::__expand_cast_from(scope, value);\n        values\n            .clone()\n            .__expand_insert_method(scope, comptime![self.arg.clone()], value);\n        comptime![args.push(self.arg.clone())];\n\n        fuse_on_write::expand(\n            scope,\n            self.inputs.clone(),\n            self.outputs.clone(),\n            self.locals.clone(),\n            pos,\n            values,\n            args,\n            self.config.clone(),\n        );\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_write_checked_method(\n        &self,\n        scope: &mut Scope,\n        pos: NativeExpand<usize>,\n        value: <E as CubeType>::ExpandType,\n    ) {\n        let in_bounds = ViewOperationsExpand::<E, Coords1d>::__expand_is_in_bounds_method(\n            self,\n            scope,\n            pos.clone(),\n        );\n        if_expand(scope, in_bounds, |scope| {\n            ViewOperationsMutExpand::<E, Coords1d>::__expand_write_method(self, scope, pos, value);\n        })\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_to_linear_slice_mut_method(\n        &self,\n        _scope: &mut Scope,\n        _pos: NativeExpand<usize>,\n        _size: NativeExpand<usize>,\n    ) -> SliceExpand<E, ReadWrite> {\n        todo!(\"Not yet supported\")\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn __expand_tensor_map_store_method(\n        &self,\n        _scope: &mut Scope,\n        _shared_memory: SliceExpand<E, ReadOnly>,\n        _pos: NativeExpand<usize>,\n    ) {\n        panic!(\"Not a tensor map\")\n    }\n}\n\nimpl Vectorized for FusedOutput {}\nimpl VectorizedExpand for FusedOutputExpand {\n    fn vector_size(&self) -> VectorSize {\n        self.locals.ref_vector_size\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/fuser.rs",
    "content": "use super::{\n    codegen::ir::{BinaryFuseArgs, FuseArg, FuseOp, UnaryFuseArgs},\n    settings::FuseSettings,\n    trace::{FuseTrace, TraceFuser, block::QuantInput},\n};\nuse crate::engine::{codegen::ir::QuantSchemeFuse, scoring::Scoring};\nuse burn_fusion::{FuserProperties, FuserStatus, OperationFuser};\nuse burn_ir::{\n    BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr,\n    TensorIr, UnaryOpIr,\n};\nuse burn_std::{DType, Shape};\nuse cubecl::ir::ElemType;\n\n/// The base operation fuser that can be used to fuse [all supported fuse operations](FuseOp).\n///\n///\n/// This fuser doesn't create a ready-to-execute kernel, but rather generates a\n/// [trace](FuseTrace) that be used with a [runner](super::trace::TraceRunner).\n///\n/// Since this fuser supports fusing multiple blocks, you can fuse any compute-bound operations\n/// with the combination of fuse-on-read and fuse-on-write strategy.\n///\n/// # Notes\n///\n/// It is responsible to translate [OperationIr] into [FuseOp] and it uses the [TraceFuser]\n/// to actually fuse the [FuseOp] when possible.\n#[derive(Debug, Clone)]\npub(crate) struct TraceOperationFuser {\n    fuser: TryTraceFuser,\n    scoring: Scoring,\n    pub(crate) settings: FuseSettings,\n    pub(crate) current_output_shape: Shape,\n    status: FuserStatus,\n    pub(crate) num_ops: usize,\n    pub(crate) num_views: usize,\n    pub(crate) max_bindings: u32,\n}\n\nimpl TraceOperationFuser {\n    /// Checks if the [operation](OperationIr) can be fused with the current fuser.\n    pub(crate) fn can_fuse(&self, op: &OperationIr) -> bool {\n        let len_previous = self.len();\n        let mut fuser_cloned = self.clone();\n\n        fuser_cloned.fuse(op);\n        let len_after = fuser_cloned.len();\n\n        len_after > len_previous\n    }\n}\n\nimpl OperationFuser<FuseTrace> for TraceOperationFuser {\n    fn fuse(&mut self, op: &OperationIr) {\n        if let FuserStatus::Closed = self.status {\n            return;\n        }\n\n        match op {\n            OperationIr::Drop(tensor) => {\n                if self.num_ops == 0 {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n\n                self.fuser.fuser.fuse_dropped(tensor);\n            }\n            OperationIr::BaseFloat(ops) => {\n                if !self.fuse_base(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            OperationIr::BaseInt(ops) => {\n                if !self.fuse_base(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            OperationIr::Float(_dtype, ops) => {\n                if !self.fuse_float(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            OperationIr::NumericFloat(_dtype, ops) => {\n                if !self.fuse_numeric(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            OperationIr::NumericInt(_dtype, ops) => {\n                if !self.fuse_numeric(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            OperationIr::BaseBool(ops) => {\n                if !self.fuse_base(ops) {\n                    self.status = FuserStatus::Closed;\n                    return;\n                }\n            }\n            _ => {\n                self.status = FuserStatus::Closed;\n                return;\n            }\n        };\n\n        self.status = FuserStatus::Open;\n        self.scoring.register(op);\n        self.num_ops += 1;\n    }\n\n    fn finish(&mut self) -> FuseTrace {\n        self.fuser.finish(self.current_output_shape.clone())\n    }\n\n    fn len(&self) -> usize {\n        self.num_ops\n    }\n\n    fn reset(&mut self) {\n        self.num_ops = 0;\n        self.scoring.reset();\n        self.num_views = 0;\n        self.status = FuserStatus::Open;\n        self.fuser = TryTraceFuser::new(self.max_bindings, self.settings);\n        self.current_output_shape = Shape::new([]);\n    }\n\n    fn status(&self) -> FuserStatus {\n        self.status\n    }\n\n    fn properties(&self) -> FuserProperties {\n        let ready = self.num_ops > 0;\n        let score = self\n            .scoring\n            .evaluate(&self.fuser.clone().finish(self.current_output_shape.clone()));\n\n        FuserProperties { ready, score }\n    }\n\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<FuseTrace>> {\n        Box::new(self.clone())\n    }\n}\n\nimpl TraceOperationFuser {\n    /// Creates a new fuser.\n    pub fn new(max_bindings: u32, settings: FuseSettings) -> Self {\n        Self {\n            fuser: TryTraceFuser::new(max_bindings, settings),\n            settings,\n            scoring: Scoring::default(),\n            num_ops: 0,\n            num_views: 0,\n            max_bindings,\n            current_output_shape: Shape::new([]),\n            status: FuserStatus::Open,\n        }\n    }\n\n    /// Closes the fuser.\n    pub fn close(&mut self) {\n        self.status = FuserStatus::Closed;\n    }\n\n    /// Declares an input tensor argument where the kernel is responsible to load.\n    ///\n    /// # Returns\n    ///\n    /// - The argument that maps to the tensor to be used during kernel expansion.\n    pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {\n        self.fuser.fuser.input_unhandled(tensor)\n    }\n\n    /// Declares an input quantized tensor argument where the kernel is responsible to load.\n    ///\n    /// # Returns\n    ///\n    /// None if it's not possible to fuse a quantized tensor. Otherwise:\n    ///\n    /// - The argument that maps to the tensor values to be used during kernel expansion.\n    /// - The argument that maps to the tensor params to be used during kernel expansion.\n    pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {\n        self.fuser.fuser.input_quantized_unhandled(tensor)\n    }\n\n    /// Declares an output tensor argument where the kernel is responsible to write values.\n    ///\n    /// # Notes\n    ///\n    /// Normally you don't have to declare outputs explicitly before they are going to be\n    /// fused based on the operations [fused](Self::fuse).\n    ///\n    /// # Returns\n    ///\n    /// - The argument that maps to the tensor to be used during kernel expansion.\n    pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {\n        if self.current_output_shape.is_empty() {\n            self.current_output_shape = tensor.shape.clone();\n        } else if self.current_output_shape.iter().sum::<usize>() < tensor.shape.iter().sum() {\n            // The larguest shape win.\n            self.current_output_shape = tensor.shape.clone();\n        }\n\n        self.fuser.fuser.output_unhandled(tensor)\n    }\n\n    /// Closes the previous block and declares a new one.\n    ///\n    /// # Arguments\n    ///\n    /// - arguments: Tensors that are logical outputs of the current block and inputs of the following blocks.\n    /// - settings: [FuseSettings] to be used by the next block.\n    ///\n    /// # Returns\n    ///\n    /// None if it's impossible to create a next block with the given arguments. Otherwise, the\n    /// corresponding [arguments](Arg) to the given tensors are returned.\n    pub fn next_block<const N: usize>(\n        &mut self,\n        arguments: [&TensorIr; N],\n        settings: FuseSettings,\n        global: bool,\n    ) -> [FuseArg; N] {\n        let block_pos = self.fuser.fuser.num_previous_blocks();\n        let current_output_shape =\n            core::mem::replace(&mut self.current_output_shape, Shape::new([]));\n\n        self.fuser.fuser.next_block(current_output_shape, settings);\n\n        self.settings = settings;\n        self.status = FuserStatus::Open;\n\n        arguments.map(|arg| self.fuser.fuser.block_local_input(arg, block_pos, global))\n    }\n\n    /// Tag the [tensor](TensorIr) as received from a previous block.\n    ///\n    /// This will avoid reading the input again and instead use le local version when possible.\n    pub fn block_local_input(&mut self, tensor: &TensorIr, block_pos: usize, global: bool) {\n        self.fuser\n            .fuser\n            .block_local_input(tensor, block_pos, global);\n    }\n\n    fn fuse_base(&mut self, ops: &BaseOperationIr) -> bool {\n        match ops {\n            BaseOperationIr::Equal(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            BaseOperationIr::EqualElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Equal(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            BaseOperationIr::Cast(desc) => {\n                self.fuse_unary_op(&desc.input, &desc.out, |input, out| {\n                    FuseOp::Assign(UnaryFuseArgs { input, out })\n                })\n            }\n            BaseOperationIr::SwapDims(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                if self.fuser.fuse(|fuser| {\n                    fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?;\n\n                    Some(())\n                }) {\n                    self.num_views += 1;\n                    true\n                } else {\n                    false\n                }\n            }\n            BaseOperationIr::Reshape(desc) => {\n                if desc.input.shape == desc.out.shape {\n                    return self.fuse_unary_op(&desc.input, &desc.out, |input, out| {\n                        FuseOp::Assign(UnaryFuseArgs { input, out })\n                    });\n                }\n\n                if desc.input.shape.rank() > desc.out.shape.rank() {\n                    // Not yet supported.\n                    return false;\n                }\n\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                if self.fuser.fuse(|fuser| {\n                    fuser.input_reshaped(&desc.input, &desc.out)?;\n                    Some(())\n                }) {\n                    self.num_views += 1;\n                    true\n                } else {\n                    false\n                }\n            }\n            BaseOperationIr::Ones(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                let elem: ElemType = desc.out.dtype.into();\n                let precision = elem.into();\n                let input = FuseArg::Literal(1, precision);\n\n                self.fuser.fuse(|fuser| {\n                    let out = fuser.output(&desc.out)?;\n\n                    fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));\n\n                    Some(())\n                })\n            }\n            BaseOperationIr::Zeros(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                let elem: ElemType = desc.out.dtype.into();\n                let precision = elem.into();\n                let input = FuseArg::Literal(0, precision);\n\n                self.fuser.fuse(|fuser| {\n                    let out = fuser.output(&desc.out)?;\n\n                    fuser.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));\n\n                    Some(())\n                })\n            }\n            BaseOperationIr::Gather(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let input = build.input_indexed(&desc.tensor)?;\n                    let indices = build.input_indexed(&desc.indices)?;\n                    let output = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::Gather {\n                        input,\n                        indices,\n                        output,\n                        dim: desc.dim,\n                    });\n\n                    Some(())\n                })\n            }\n            BaseOperationIr::Select(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let input = build.input_indexed(&desc.tensor)?;\n                    let indices = build.input_indexed(&desc.indices)?;\n                    let output = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::Select {\n                        input,\n                        indices,\n                        output,\n                        dim: desc.dim,\n                    });\n\n                    Some(())\n                })\n            }\n            BaseOperationIr::MaskWhere(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let cond = build.input(&desc.mask)?;\n                    let rhs = build.input(&desc.tensor)?;\n                    let lhs = build.input(&desc.value)?;\n                    let out = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::ConditionalAssign {\n                        cond,\n                        lhs,\n                        rhs,\n                        out,\n                    });\n\n                    Some(())\n                })\n            }\n            BaseOperationIr::MaskFill(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let cond = build.input(&desc.mask)?;\n                    let lhs = build.scalar(&desc.value, desc.out.dtype);\n                    let rhs = build.input(&desc.tensor)?;\n                    let out = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::ConditionalAssign {\n                        cond,\n                        lhs,\n                        rhs,\n                        out,\n                    });\n\n                    Some(())\n                })\n            }\n            _ => false,\n        }\n    }\n\n    fn fuse_float(&mut self, ops: &FloatOperationIr) -> bool {\n        match ops {\n            FloatOperationIr::Exp(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Exp(UnaryFuseArgs { input, out }))\n            }\n            FloatOperationIr::Log(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Log(UnaryFuseArgs { input, out }))\n            }\n            FloatOperationIr::Powf(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            FloatOperationIr::Log1p(desc) => self.fuse_unary_ops(desc, |input, out| {\n                FuseOp::Log1p(UnaryFuseArgs { input, out })\n            }),\n            FloatOperationIr::Cos(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Cos(UnaryFuseArgs { input, out }))\n            }\n            FloatOperationIr::Sin(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Sin(UnaryFuseArgs { input, out }))\n            }\n            FloatOperationIr::PowfScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Powf(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            FloatOperationIr::Tanh(desc) => self.fuse_unary_ops(desc, |input, out| {\n                FuseOp::Tanh(UnaryFuseArgs { input, out })\n            }),\n            FloatOperationIr::Erf(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Erf(UnaryFuseArgs { input, out }))\n            }\n            FloatOperationIr::Sqrt(desc) => self.fuse_unary_ops(desc, |input, out| {\n                FuseOp::Sqrt(UnaryFuseArgs { input, out })\n            }),\n            FloatOperationIr::Recip(desc) => self.fuse_unary_ops(desc, |input, out| {\n                FuseOp::Recip(UnaryFuseArgs { input, out })\n            }),\n            FloatOperationIr::Dequantize(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let qinput = build.input_quantized(&desc.input)?;\n                    let out = build.output(&desc.out)?;\n\n                    match qinput {\n                        QuantInput::AlreadyDequantized { local } => {\n                            build.fuse_operation(FuseOp::Assign(UnaryFuseArgs {\n                                input: local,\n                                out,\n                            }));\n                        }\n                        QuantInput::Quantized { values, params } => {\n                            build.fuse_operation(FuseOp::Dequantize {\n                                values,\n                                params,\n                                output: out,\n                                scheme: match desc.input.dtype {\n                                    DType::QFloat(scheme) => QuantSchemeFuse { scheme },\n                                    _ => unreachable!(\"Should be a quant tensor.\"),\n                                },\n                            });\n                        }\n                    }\n\n                    Some(())\n                })\n            }\n            _ => false,\n        }\n    }\n\n    fn fuse_numeric(&mut self, op: &NumericOperationIr) -> bool {\n        match op {\n            NumericOperationIr::Add(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::AddScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Add(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Sub(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::SubScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Sub(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Mul(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::MulScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Mul(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Div(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::DivScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Div(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Abs(desc) => {\n                self.fuse_unary_ops(desc, |input, out| FuseOp::Abs(UnaryFuseArgs { input, out }))\n            }\n            NumericOperationIr::Lower(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::LowerElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Lower(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Greater(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::GreaterElem(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Greater(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::LowerEqual(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::LowerEqualElem(desc) => self\n                .fuse_scalar_ops(desc, |lhs, rhs, out| {\n                    FuseOp::LowerEqual(BinaryFuseArgs { lhs, rhs, out })\n                }),\n            NumericOperationIr::GreaterEqual(desc) => self\n                .fuse_binary_ops(desc, |lhs, rhs, out| {\n                    FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })\n                }),\n            NumericOperationIr::GreaterEqualElem(desc) => self\n                .fuse_scalar_ops(desc, |lhs, rhs, out| {\n                    FuseOp::GreaterEqual(BinaryFuseArgs { lhs, rhs, out })\n                }),\n            NumericOperationIr::Full(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let input = build.scalar(&desc.value, desc.out.dtype);\n                    let out = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::Assign(UnaryFuseArgs { input, out }));\n\n                    Some(())\n                })\n            }\n            NumericOperationIr::Rem(desc) => self.fuse_binary_ops(desc, |lhs, rhs, out| {\n                FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::RemScalar(desc) => self.fuse_scalar_ops(desc, |lhs, rhs, out| {\n                FuseOp::Rem(BinaryFuseArgs { lhs, rhs, out })\n            }),\n            NumericOperationIr::Clamp(desc) => {\n                if !self.output_is_compatible(&desc.out) {\n                    return false;\n                }\n\n                self.fuser.fuse(|build| {\n                    let input = build.input(&desc.tensor)?;\n                    let min = build.scalar(&desc.min, desc.out.dtype);\n                    let max = build.scalar(&desc.max, desc.out.dtype);\n                    let out = build.output(&desc.out)?;\n\n                    build.fuse_operation(FuseOp::Clamp {\n                        input,\n                        min,\n                        max,\n                        out,\n                    });\n\n                    Some(())\n                })\n            }\n            _ => false,\n        }\n    }\n\n    fn fuse_binary_ops<Func>(&mut self, desc: &BinaryOpIr, func: Func) -> bool\n    where\n        Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,\n    {\n        if !self.output_is_compatible(&desc.out) {\n            return false;\n        }\n\n        self.fuser.fuse(|build| {\n            let lhs = build.input(&desc.lhs)?;\n            let rhs = build.input(&desc.rhs)?;\n            let out = build.output(&desc.out)?;\n\n            build.fuse_operation(func(lhs, rhs, out));\n\n            Some(())\n        })\n    }\n\n    fn fuse_unary_ops<Func>(&mut self, desc: &UnaryOpIr, func: Func) -> bool\n    where\n        Func: Fn(FuseArg, FuseArg) -> FuseOp,\n    {\n        self.fuse_unary_op(&desc.input, &desc.out, func)\n    }\n\n    fn fuse_unary_op<Func>(&mut self, input: &TensorIr, out: &TensorIr, func: Func) -> bool\n    where\n        Func: Fn(FuseArg, FuseArg) -> FuseOp,\n    {\n        if !self.output_is_compatible(out) {\n            return false;\n        }\n\n        self.fuser.fuse(|build| {\n            let input = build.input(input)?;\n            let out = build.output(out)?;\n            build.fuse_operation(func(input, out));\n            Some(())\n        })\n    }\n\n    fn fuse_scalar_ops<Func>(&mut self, desc: &ScalarOpIr, func: Func) -> bool\n    where\n        Func: Fn(FuseArg, FuseArg, FuseArg) -> FuseOp,\n    {\n        if !self.output_is_compatible(&desc.out) {\n            return false;\n        }\n\n        self.fuser.fuse(|build| {\n            let elem = desc.lhs.dtype;\n            let lhs = build.input(&desc.lhs)?;\n            let rhs = build.scalar(&desc.rhs, elem);\n            let out = build.output(&desc.out)?;\n\n            build.fuse_operation(func(lhs, rhs, out));\n\n            Some(())\n        })\n    }\n\n    fn output_is_compatible(&mut self, out: &TensorIr) -> bool {\n        if self.current_output_shape.is_empty() {\n            self.current_output_shape.clone_from(&out.shape);\n            return true;\n        }\n\n        let rank = self.current_output_shape.len();\n\n        // Rank should be equal.\n        if rank != out.shape.num_dims() {\n            return false;\n        }\n\n        let mut updated = self.current_output_shape.clone();\n        let mut should_update = false;\n\n        #[allow(clippy::needless_range_loop)]\n        for i in 0..rank {\n            let curr = self.current_output_shape[i];\n            let new = out.shape[i];\n\n            if curr == new {\n                continue;\n            }\n\n            // Broadcast not enabled.\n            if !self.settings.broadcast {\n                return false;\n            }\n\n            // Broadcasted on new dim.\n            if new == 0 {\n                continue;\n            }\n\n            // Broadcasted on curr dim - update reference output shape.\n            if curr == 0 && self.settings.output_shape_updates {\n                should_update = true;\n                updated[i] = new;\n                continue;\n            }\n\n            return false;\n        }\n\n        if should_update {\n            // For now forced to have exact shape.\n            if updated != out.shape {\n                return false;\n            }\n\n            self.current_output_shape.clone_from_slice(&out.shape);\n        }\n\n        true\n    }\n}\n\n#[derive(Debug, Clone)]\n/// Builder wrapper to limit the number of bindings in generated kernels.\nstruct TryTraceFuser {\n    fuser: TraceFuser,\n    max_bindings: u32,\n    max_ops: u32,\n    added_ops: bool,\n}\n\nimpl TryTraceFuser {\n    fn new(max_bindings: u32, settings: FuseSettings) -> Self {\n        Self {\n            fuser: TraceFuser::new(settings),\n            max_bindings,\n            // A good default, avoid errors with for loops over only memory\n            // bound operations.\n            max_ops: 64,\n            added_ops: false,\n        }\n    }\n\n    fn fuse(&mut self, add_ops: impl FnOnce(&mut TraceFuser) -> Option<()>) -> bool {\n        if self.fuser.num_ops_fused() > self.max_ops {\n            return false;\n        }\n\n        // Always allow the first operation to be added.\n        if !self.added_ops {\n            self.added_ops = true;\n\n            if add_ops(&mut self.fuser).is_none() {\n                return false;\n            }\n            return true;\n        }\n\n        let mut cloned = self.fuser.clone();\n        if add_ops(&mut cloned).is_none() {\n            return false;\n        }\n\n        if cloned.estimate_bindings() > self.max_bindings {\n            return false;\n        }\n\n        self.fuser = cloned;\n        true\n    }\n\n    fn finish(&mut self, shape: Shape) -> FuseTrace {\n        self.fuser.finish(shape)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/base.rs",
    "content": "use crate::{\n    CubeFusionHandle,\n    engine::{\n        launch::{\n            HandleInput, HandleOutput, LaunchPlan, executor::LaunchPlanExecutor,\n            input::InputPlanner, output::OutputPlanner, runner::TraceRunner,\n            vectorization::VectorizationPlanner,\n        },\n        trace::{FuseTrace, TraceError, TuneOutput},\n    },\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{Runtime, client::ComputeClient};\nuse std::marker::PhantomData;\n\n/// The launcher is responsible to launch a fused kernel using the [TraceRunner] and a [FuseTrace].\n///\n/// TODO: We can reuse the same launcher between runs and avoid a lot of allocation, by simply\n/// resetting the state.\npub struct FuseTraceLauncher<'a, R: Runtime, Runner: TraceRunner<R>> {\n    trace: &'a FuseTrace,\n    runner: &'a Runner,\n    _runtime: PhantomData<R>,\n}\n\nimpl<'a, R: Runtime, Runner: TraceRunner<R>> FuseTraceLauncher<'a, R, Runner> {\n    /// Creates a new launcher.\n    pub fn new(trace: &'a FuseTrace, runner: &'a Runner) -> Self {\n        Self {\n            trace,\n            runner,\n            _runtime: PhantomData,\n        }\n    }\n    /// Launches the fuse kernel on the given device modifying the context.\n    pub fn launch(\n        &self,\n        client: &ComputeClient<R>,\n        device: &R::Device,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n    ) -> Result<TuneOutput<R>, TraceError<Runner::Error>> {\n        let mut plan = LaunchPlan::new(&self.trace.blocks);\n\n        InputPlanner::new(&self.trace.resources, &self.trace.blocks).run(context, &mut plan);\n\n        OutputPlanner::new(&self.trace.resources, &self.trace.blocks)\n            .run(client, device, context, &mut plan);\n\n        VectorizationPlanner::new(&self.trace.resources, &self.trace.blocks).run(\n            client,\n            self.runner,\n            context,\n            &mut plan,\n        );\n\n        match LaunchPlanExecutor::new(&self.trace.resources, &self.trace.blocks).execute::<_>(\n            client,\n            self.runner,\n            context,\n            plan,\n        ) {\n            Err(err) => {\n                self.rollback(context, err.handles_input, err.handles_output);\n                Err(err.error)\n            }\n            Ok(val) => Ok(val),\n        }\n    }\n\n    fn rollback(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        handle_inputs: Vec<HandleInput<R>>,\n        handle_outputs: Vec<HandleOutput<R>>,\n    ) {\n        for input in handle_inputs {\n            match input {\n                HandleInput::Normal(input) => {\n                    context\n                        .handles\n                        .register_handle(input.global_ir.id, input.handle_rollback());\n                }\n                HandleInput::QuantValues(input) => {\n                    context\n                        .handles\n                        .register_handle(input.global_ir.id, input.handle);\n                }\n                HandleInput::QuantParams(_) => {\n                    // The scales are part of the quant data handle.\n                }\n            };\n        }\n        for output in handle_outputs {\n            if let HandleOutput::Owned {\n                global_id, handle, ..\n            } = output\n            {\n                context.handles.register_handle(global_id, handle);\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/executor.rs",
    "content": "use super::{HandleInput, HandleOutput, LaunchPlan, ReferenceSelection};\nuse crate::engine::launch::runner::TraceRunner;\nuse crate::engine::trace::{FuseResources, TensorView, TraceError, TuneOutput, block::FuseBlock};\nuse crate::{\n    CubeFusionHandle,\n    engine::{\n        codegen::ir::{\n            FuseBlockConfig, FuseOp, FuseType, GlobalArgsLaunch, RefLayout, VirtualLayout,\n        },\n        codegen::tensor::GlobalTensorArg,\n    },\n};\nuse burn_fusion::stream::{Context, ScalarId};\nuse burn_ir::ScalarIr;\nuse cubecl::{\n    Runtime,\n    client::ComputeClient,\n    ir::{AddressType, Type},\n    prelude::{InputScalar, TensorArg},\n};\nuse std::marker::PhantomData;\n\n/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context).\npub struct LaunchPlanExecutor<'a, R: Runtime> {\n    resources: &'a FuseResources,\n    blocks: &'a Vec<FuseBlock>,\n    _r: PhantomData<R>,\n}\n\n#[derive(new, Debug)]\npub struct ExecutionError<R: Runtime, Runner: TraceRunner<R>> {\n    pub error: TraceError<Runner::Error>,\n    pub handles_input: Vec<HandleInput<R>>,\n    pub handles_output: Vec<HandleOutput<R>>,\n}\n\nimpl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {\n    pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {\n        Self {\n            resources,\n            blocks,\n            _r: PhantomData,\n        }\n    }\n\n    pub fn execute<Runner: TraceRunner<R>>(\n        self,\n        client: &ComputeClient<R>,\n        runner: &Runner,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: LaunchPlan<'a, R>,\n    ) -> Result<TuneOutput<R>, ExecutionError<R, Runner>> {\n        let mut num_writes = 0;\n        for b in plan.blocks.iter() {\n            for writes in b.writes.values() {\n                num_writes += writes.len();\n            }\n        }\n\n        #[cfg(feature = \"autotune-checks\")]\n        let mut tune_output = TuneOutput::Checked {\n            handles: std::collections::HashMap::new(),\n        };\n\n        #[cfg(not(feature = \"autotune-checks\"))]\n        let mut tune_output = TuneOutput::UnChecked(PhantomData);\n\n        if num_writes == 0 {\n            // Nothing to write, can skip execution.\n            return Ok(tune_output);\n        }\n\n        let mut inputs = GlobalArgsLaunch::default();\n        let mut outputs = GlobalArgsLaunch::default();\n\n        register_inputs(plan.handle_inputs.clone(), &mut inputs);\n        register_scalars(\n            self.resources.scalars.iter(),\n            self.resources.views.iter(),\n            context,\n            &mut inputs,\n        );\n        register_outputs::<R>(plan.handle_outputs.clone(), &mut outputs, &mut tune_output);\n\n        for layout in plan.runtime_layouts {\n            for s in layout.shape.iter() {\n                inputs.runtime_layouts.push(*s);\n            }\n            for s in layout.strides.iter() {\n                inputs.runtime_layouts.push(*s);\n            }\n        }\n\n        let mut configs = Vec::with_capacity(plan.blocks.len());\n\n        for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) {\n            let reference = match block_plan.reference {\n                ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout),\n                ReferenceSelection::VirtualShape { original, .. } => {\n                    RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width))\n                }\n                ReferenceSelection::SwapDims { original, dims } => {\n                    RefLayout::Virtual(VirtualLayout::SwapDims(original, dims))\n                }\n                ReferenceSelection::Reshaped { reshape_pos } => {\n                    RefLayout::Virtual(VirtualLayout::Reshaped {\n                        reshape_pos,\n                        vector_size: block_plan.width,\n                    })\n                }\n                ReferenceSelection::Runtime { pos } => {\n                    RefLayout::Virtual(VirtualLayout::Runtime { pos })\n                }\n                ReferenceSelection::Searching => {\n                    return Err(ExecutionError::new(\n                        TraceError::ReferenceNotFound,\n                        plan.handle_inputs,\n                        plan.handle_outputs,\n                    ));\n                }\n            };\n\n            let mut ops = Vec::<FuseOp>::new();\n\n            for read_ops in block_plan.reads.into_values() {\n                for op in read_ops {\n                    ops.push(op);\n                }\n            }\n\n            for op in block.ops.iter() {\n                ops.push(op.clone());\n            }\n\n            for opsw in block_plan.writes.into_values() {\n                for op in opsw {\n                    ops.push(op);\n                }\n            }\n\n            let config = FuseBlockConfig {\n                rank: plan.rank,\n                ref_layout: reference,\n                ops,\n                width: block_plan.width,\n            };\n            configs.push(config);\n        }\n\n        Runner::run(runner, client, inputs, outputs, &configs).map_err(|err| {\n            ExecutionError::new(\n                TraceError::RunnerError(err),\n                plan.handle_inputs,\n                plan.handle_outputs,\n            )\n        })?;\n\n        Ok(tune_output)\n    }\n}\n\nfn register_inputs<R: Runtime>(\n    handle_inputs: Vec<HandleInput<R>>,\n    inputs: &mut GlobalArgsLaunch<R>,\n) {\n    for hi in handle_inputs {\n        match hi {\n            HandleInput::Normal(hi) => {\n                let at = hi.handle.required_address_type();\n                let arg = hi.handle.into_tensor_arg(hi.global_ir.shape.clone());\n                inputs.tensors.push(GlobalTensorArg::new(\n                    arg,\n                    hi.precision.into_type(hi.vector_size),\n                    hi.broadcated,\n                    at,\n                ));\n            }\n            HandleInput::QuantValues(hi) => {\n                let at = hi.handle.required_address_type();\n                let arg = hi.handle.into_tensor_arg(hi.global_ir.shape.clone());\n                inputs.tensors.push(GlobalTensorArg::new(\n                    arg,\n                    hi.precision.into_type(hi.vector_size),\n                    false,\n                    at,\n                ));\n            }\n            HandleInput::QuantParams(hi) => {\n                let at = hi.handle.required_address_type();\n                let arg = hi.handle.into_tensor_arg(hi.shape.clone());\n                inputs.tensors.push(GlobalTensorArg::new(\n                    arg,\n                    hi.precision.into_type(1),\n                    false,\n                    at,\n                ));\n            }\n        }\n    }\n}\n\nfn register_outputs<R: Runtime>(\n    handle_outputs: Vec<HandleOutput<R>>,\n    outputs: &mut GlobalArgsLaunch<R>,\n    #[allow(unused_variables)] tune_output: &mut TuneOutput<R>,\n) {\n    for item in handle_outputs {\n        match item {\n            HandleOutput::Alias {\n                input_pos,\n                precision,\n                global_shape,\n                strides,\n                #[cfg(feature = \"autotune-checks\")]\n                debug_info,\n            } => {\n                outputs.tensors.push(GlobalTensorArg::new(\n                    TensorArg::Alias {\n                        input_pos,\n                        strides,\n                        shape: global_shape,\n                    },\n                    precision.into_type(1),\n                    false,\n                    AddressType::default(),\n                ));\n\n                #[cfg(feature = \"autotune-checks\")]\n                if let TuneOutput::Checked { handles, .. } = tune_output {\n                    handles.insert(\n                        debug_info.relative_id,\n                        (debug_info.global_shape.clone(), debug_info.handle.clone()),\n                    );\n                }\n            }\n            HandleOutput::Owned {\n                precision,\n                handle,\n                global_shape,\n                vectorization: vector_size,\n                #[cfg(feature = \"autotune-checks\")]\n                relative_id,\n                ..\n            } => {\n                let at = handle.required_address_type();\n                let arg = handle.into_tensor_arg(global_shape.clone());\n\n                let elem = precision.into_elem();\n                let ty = Type::new(elem.into()).with_vector_size(vector_size);\n\n                #[cfg(feature = \"autotune-checks\")]\n                if let TuneOutput::Checked { handles, .. } = tune_output {\n                    handles.insert(*relative_id, (global_shape.clone(), handle.clone()));\n                }\n\n                outputs\n                    .tensors\n                    .push(GlobalTensorArg::new(arg, ty, false, at));\n            }\n        }\n    }\n}\n\nfn register_scalars<'h, R: Runtime>(\n    scalars: impl Iterator<Item = &'h (FuseType, u64)>,\n    views: impl DoubleEndedIterator<Item = &'h TensorView>,\n    context: &mut Context<'_, CubeFusionHandle<R>>,\n    inputs: &mut GlobalArgsLaunch<R>,\n) {\n    for (precision, id) in scalars {\n        let dtype = precision.into_storage_type();\n        match context.scalars.get(&ScalarId { value: *id }) {\n            Some(scalar) => match scalar {\n                ScalarIr::Float(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),\n                ScalarIr::Int(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),\n                ScalarIr::UInt(val) => inputs.scalars.push(InputScalar::new(*val, dtype)),\n                ScalarIr::Bool(val) => inputs.scalars.push(InputScalar::new(*val as u8, dtype)),\n            },\n            None => panic!(\"Scalar ID not found\"),\n        }\n    }\n\n    for relative in views {\n        if let TensorView::Reshape { reshaped, .. } = relative {\n            let global = context.tensors.get(reshaped).unwrap();\n\n            for shape in global.shape.iter() {\n                inputs.reshapes.push(*shape);\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/input.rs",
    "content": "use super::{BlockPlan, HandleInput, InputReference};\nuse super::{LaunchPlan, NormalHandleInput, PotentialInplace};\nuse crate::CubeFusionHandle;\nuse crate::engine::launch::{QuantParamsHandleInput, QuantValuesHandleInput};\nuse crate::engine::trace::block::FuseBlock;\nuse crate::engine::trace::{FuseResources, RegisterTensor, TensorView};\nuse burn_fusion::stream::Context;\nuse burn_ir::{TensorIr, TensorStatus};\nuse burn_std::quantization::params_shape;\nuse cubecl::Runtime;\nuse std::marker::PhantomData;\n\n/// Fetch and register [input handles](HandleInput). Also identifies potential inputs that\n/// can be used inplace and/or as the [reference layout](super::super::ir::RefLayout).\npub struct InputPlanner<'a, R: Runtime> {\n    resources: &'a FuseResources,\n    blocks: &'a Vec<FuseBlock>,\n    _r: PhantomData<R>,\n}\n\nimpl<'a, R: Runtime> InputPlanner<'a, R> {\n    pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {\n        Self {\n            resources,\n            blocks,\n            _r: PhantomData,\n        }\n    }\n\n    pub fn run(self, context: &mut Context<'_, CubeFusionHandle<R>>, plan: &mut LaunchPlan<'a, R>) {\n        for (pos, input) in self.resources.inputs.iter().enumerate() {\n            match input {\n                RegisterTensor::Normal(tensor_relative, precision) => {\n                    let mut tensor_global =\n                        context.tensors.get(&tensor_relative.id).unwrap().clone();\n                    let handle = context\n                        .handles\n                        .get_handle(&tensor_global.id, &TensorStatus::ReadOnly);\n\n                    if let TensorStatus::ReadWrite = tensor_relative.status {\n                        plan.cleared.push(tensor_global.id);\n                    }\n\n                    let mut new_strides = handle.strides.clone();\n\n                    self.analyze(plan, pos, tensor_relative, &handle);\n\n                    if tensor_global.shape.rank() < plan.rank {\n                        let num_elem: usize = tensor_global.shape.iter().product();\n                        for _ in 0..(plan.rank - tensor_global.shape.rank()) {\n                            tensor_global.shape.insert(0, 1);\n                            new_strides.insert(0, num_elem);\n                        }\n                    }\n\n                    plan.handle_inputs\n                        .push(HandleInput::Normal(NormalHandleInput::new(\n                            tensor_global,\n                            tensor_relative,\n                            *precision,\n                            handle,\n                            new_strides,\n                        )));\n                }\n                RegisterTensor::QuantValues(tensor_relative) => {\n                    let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone();\n                    let handle = context\n                        .handles\n                        .get_handle(&tensor_global.id, &TensorStatus::ReadOnly);\n\n                    let scheme = match tensor_relative.dtype {\n                        burn_std::DType::QFloat(scheme) => scheme,\n                        _ => unreachable!(\"Can't have quant data without QFloat\"),\n                    };\n                    let params = handle.params(scheme).unwrap();\n                    let precision = tensor_relative.dtype.into();\n                    let precision_scales = params.dtype.into();\n\n                    let global_shape = tensor_global.shape.clone();\n                    let shape_params = params_shape(&global_shape, scheme.level);\n                    plan.handle_inputs\n                        .push(HandleInput::QuantValues(QuantValuesHandleInput {\n                            relative_id: tensor_relative.id,\n                            global_ir: tensor_global,\n                            precision,\n                            handle,\n                            vector_size: 1,\n                        }));\n\n                    plan.handle_inputs\n                        .push(HandleInput::QuantParams(QuantParamsHandleInput {\n                            precision: precision_scales,\n                            handle: params,\n                            shape: shape_params,\n                        }));\n                }\n                RegisterTensor::QuantParams(_) => {\n                    // It is registered at the same time as quant data.\n                    // The order is important and the index in the vector as well, so that's why we\n                    // have QuantParams.\n                }\n            }\n        }\n    }\n\n    fn analyze(\n        &self,\n        plan: &mut LaunchPlan<'a, R>,\n        pos: usize,\n        tensor_relative: &'a TensorIr,\n        handle: &CubeFusionHandle<R>,\n    ) {\n        if !self\n            .resources\n            .inputs_unhandled\n            .contains(&tensor_relative.id)\n        {\n            let mut is_a_view = false;\n            // For each view we try to see if it's not possible to set it as a reference input.\n            for view in self.resources.views.iter() {\n                for (block_plan, block) in plan.blocks.iter_mut().zip(self.blocks) {\n                    is_a_view = is_a_view\n                        || Self::analyze_view(pos, tensor_relative, block, block_plan, view);\n                }\n            }\n\n            if !is_a_view {\n                self.analyze_normal(plan, pos, tensor_relative, handle);\n            }\n        }\n    }\n\n    /// Analyzes if the given tensor can be used inplace in one of the block.\n    fn analyze_normal(\n        &self,\n        plan: &mut LaunchPlan<'a, R>,\n        pos: usize,\n        tensor_relative: &'a TensorIr,\n        handle: &CubeFusionHandle<R>,\n    ) {\n        enum BlockInplaceSelection {\n            Notinit,\n            /// The block reads the input, and therefore can use it for inplace.\n            Selected(usize),\n            /// The same input is used in multiple blocks.\n            Unavailable,\n        }\n\n        let mut block_inplace_selection = BlockInplaceSelection::Notinit;\n\n        for (idx, block) in plan.blocks.iter().enumerate() {\n            if block.reads.contains_key(&tensor_relative.id) {\n                match block_inplace_selection {\n                    BlockInplaceSelection::Notinit => {\n                        block_inplace_selection = BlockInplaceSelection::Selected(idx);\n                    }\n                    BlockInplaceSelection::Selected(_) => {\n                        block_inplace_selection = BlockInplaceSelection::Unavailable;\n                    }\n                    BlockInplaceSelection::Unavailable => {}\n                }\n            }\n        }\n\n        if let BlockInplaceSelection::Selected(idx) = block_inplace_selection {\n            if self.blocks[idx].shape_ref != tensor_relative.shape {\n                return;\n            }\n\n            let block_plan = &mut plan.blocks[idx];\n            if tensor_relative.status == TensorStatus::ReadWrite {\n                if self.blocks[idx].settings.inplace && handle.handle.can_mut() {\n                    block_plan.potential_inplaces.push(PotentialInplace {\n                        input_pos: pos,\n                        tensor_relative,\n                        strides: handle.strides.clone(),\n                    });\n                }\n                // Inplace tensors are normally really good as the reference layout, since\n                // it's normally better to be based on writes rather than on reads.\n                block_plan.potential_reference_input =\n                    Some(InputReference::Normal { input_pos: pos });\n            } else {\n                block_plan.potential_reference_input =\n                    Some(InputReference::Normal { input_pos: pos });\n            }\n        }\n    }\n\n    /// Analyzes if the given tensor is also the view provided, and check if it can be used as the reference layout\n    /// for the given block.\n    fn analyze_view(\n        pos: usize,\n        tensor_relative: &'a TensorIr,\n        block: &FuseBlock,\n        block_plan: &mut BlockPlan<'a>,\n        view: &TensorView,\n    ) -> bool {\n        match view {\n            TensorView::Reshape {\n                reshaped,\n                original,\n                reshape_pos,\n                shape_relative,\n            } => {\n                if original == &tensor_relative.id || reshaped == &tensor_relative.id {\n                    if block_plan.potential_reference_input.is_none()\n                        && shape_relative == &block.shape_ref\n                    {\n                        block_plan.potential_reference_input = Some(InputReference::Reshaped {\n                            reshape_pos: *reshape_pos,\n                        });\n                    }\n                    return true;\n                }\n            }\n            TensorView::SwapDims {\n                swapped,\n                original,\n                dims,\n                ..\n            } => {\n                if swapped == &tensor_relative.id {\n                    return true;\n                }\n\n                if original == &tensor_relative.id {\n                    let shape = tensor_relative\n                        .shape\n                        .clone()\n                        .swapped(dims.0, dims.1)\n                        .unwrap();\n\n                    if block_plan.potential_reference_input.is_none() && shape == block.shape_ref {\n                        block_plan.potential_reference_input = Some(InputReference::SwapDims {\n                            original_pos: pos,\n                            dims: *dims,\n                        });\n                    }\n                    return true;\n                }\n            }\n        };\n\n        false\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/mod.rs",
    "content": "pub(crate) mod executor;\npub(crate) mod input;\npub(crate) mod output;\npub(crate) mod runner;\npub(crate) mod vectorization;\n\npub(crate) mod plan;\npub use plan::*;\n\nmod base;\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/output.rs",
    "content": "use super::{\n    super::codegen::ir::FuseType, BlockPlan, HandleOutput, InputReference, LaunchPlan,\n    NormalHandleInput, ReferenceSelection,\n};\nuse crate::{\n    CubeFusionHandle,\n    engine::{\n        codegen::ir::{FuseArg, FuseOp, LayoutInfo},\n        launch::HandleInput,\n        settings::RefLayoutSetting,\n        trace::{FuseResources, RegisterTensor, RuntimeLayout, TensorView, block::FuseBlock},\n    },\n    strides_dyn_rank,\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::{TensorId, TensorIr};\nuse burn_std::Shape;\nuse burn_std::{\n    Strides,\n    tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action},\n};\nuse cubecl::{Runtime, client::ComputeClient, ir::StorageType};\n\n/// Create or reuse handles for the outputs.\n///\n/// It is also responsible to select the reference tensor.\npub struct OutputPlanner<'a, R: Runtime> {\n    resources: &'a FuseResources,\n    outputs_sorted: Vec<OutputSorted<'a>>,\n    handles: Vec<Option<HandleOutput<R>>>,\n    globals: Vec<Option<TensorIr>>,\n    blocks: &'a Vec<FuseBlock>,\n}\n\n#[derive(Debug)]\nstruct OutputSorted<'a> {\n    pos_original: usize,\n    precision: FuseType,\n    tensor_relative: &'a TensorIr,\n}\n\n#[derive(Debug)]\nenum OutputKind {\n    Normal,\n    Inplace {\n        /// The position in the potential inplace vector\n        input_pos: usize,\n    },\n    Transform(TensorView),\n}\n\nimpl<'a, R: Runtime> OutputPlanner<'a, R> {\n    pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {\n        let mut outputs_sorted: Vec<_> = resources\n            .outputs\n            .iter()\n            .enumerate()\n            .filter_map(|(pos, entry)| match entry {\n                RegisterTensor::Normal(ir, p) => Some((pos, ir, p)),\n                RegisterTensor::QuantValues(_) => None,\n                RegisterTensor::QuantParams(_) => None,\n            })\n            .map(|(pos, tensor, precision)| OutputSorted {\n                pos_original: pos,\n                precision: *precision,\n                tensor_relative: tensor,\n            })\n            .collect();\n\n        outputs_sorted.sort_by(|a, b| {\n            let a_val: usize = a.tensor_relative.shape.iter().sum();\n            let b_val: usize = b.tensor_relative.shape.iter().sum();\n\n            b_val.cmp(&a_val)\n        });\n\n        let mut handles = Vec::with_capacity(resources.outputs.len());\n        let mut globals = Vec::with_capacity(resources.outputs.len());\n\n        for _ in 0..resources.outputs.len() {\n            handles.push(None);\n            globals.push(None);\n        }\n\n        Self {\n            resources,\n            outputs_sorted,\n            handles,\n            globals,\n            blocks,\n        }\n    }\n\n    pub fn run(\n        mut self,\n        client: &ComputeClient<R>,\n        device: &R::Device,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n    ) {\n        // So that we can borrow self during the iteration.\n        let mut outputs = Vec::new();\n        core::mem::swap(&mut outputs, &mut self.outputs_sorted);\n\n        for output in outputs.into_iter() {\n            let tensor_global = context\n                .tensors\n                .get(&output.tensor_relative.id)\n                .unwrap()\n                .clone();\n            let strides = strides_dyn_rank(&tensor_global.shape);\n            let (kind, block_idx) = self.output_kind(plan, &tensor_global, &output, &strides);\n\n            match kind {\n                OutputKind::Inplace { input_pos } => {\n                    self.inplace_output(\n                        context,\n                        plan,\n                        output,\n                        tensor_global,\n                        strides,\n                        input_pos,\n                        block_idx,\n                    );\n                }\n                OutputKind::Normal => {\n                    self.normal_output(\n                        client,\n                        device,\n                        context,\n                        plan,\n                        output,\n                        tensor_global,\n                        strides,\n                        block_idx,\n                    );\n                }\n                OutputKind::Transform(TensorView::Reshape { original, .. }) => {\n                    self.reshaped_output(\n                        client,\n                        device,\n                        context,\n                        plan,\n                        output,\n                        tensor_global,\n                        strides,\n                        original,\n                        block_idx,\n                    );\n                }\n                OutputKind::Transform(TensorView::SwapDims { original, dims, .. }) => {\n                    self.swapped_dims_output(\n                        client,\n                        device,\n                        context,\n                        plan,\n                        output,\n                        tensor_global,\n                        original,\n                        dims,\n                        block_idx,\n                    );\n                }\n            }\n        }\n\n        for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) {\n            plan.handle_outputs.push(handle.unwrap());\n            plan.global_outputs.push(global.unwrap());\n        }\n\n        for i in 0..plan.blocks.len() {\n            if !plan.blocks[i].reference.is_found() {\n                match self.blocks[i].settings.ref_layout {\n                    RefLayoutSetting::SameAsBlock { block_pos } => {\n                        plan.blocks[i].reference =\n                            plan.blocks[block_pos as usize].reference.clone();\n                    }\n                    _ => {\n                        let new_runtime = Self::select_reference_from_inputs(\n                            &self.blocks[i],\n                            &mut plan.blocks[i],\n                            &plan.handle_inputs,\n                        );\n\n                        if let Some(shape) = new_runtime {\n                            let pos = plan.runtime_layouts.len();\n                            let mut shape_global = shape.clone();\n                            for (i, s) in shape.iter().enumerate() {\n                                shape_global[i] = *context.shapes_relative2global.get(s).unwrap();\n                            }\n\n                            let strides = strides_dyn_rank(&shape_global);\n\n                            plan.blocks[i].reference = ReferenceSelection::Runtime { pos };\n                            plan.runtime_layouts.push(RuntimeLayout {\n                                shape: shape_global,\n                                strides,\n                            });\n                        }\n                    }\n                };\n            } else {\n                Self::add_layout_info_inputs(&mut plan.blocks[i], &plan.handle_inputs);\n            }\n        }\n\n        // Make sure dropped are correctly executed.\n        for id in self.resources.dropped.iter() {\n            if let Some(tensor_global) = context.tensors.get(id) {\n                context.handles.remove_handle(tensor_global.id);\n            }\n        }\n        for id in plan.cleared.drain(..) {\n            context.handles.remove_handle(id);\n        }\n    }\n\n    fn select_reference_from_inputs(\n        block: &FuseBlock,\n        block_plan: &mut BlockPlan<'_>,\n        handle_inputs: &[HandleInput<R>],\n    ) -> Option<Shape> {\n        if let Some(input_ref) = block_plan.potential_reference_input.take() {\n            match input_ref {\n                InputReference::Normal { input_pos } => {\n                    let reference = handle_inputs\n                        .get(input_pos)\n                        .unwrap()\n                        .as_normal()\n                        .expect(\"Quant can't be used as inplace\");\n\n                    let set_ref_as_concrete = |block: &mut BlockPlan<'_>| {\n                        block.reference = ReferenceSelection::Concrete {\n                            layout: FuseArg::Input(\n                                input_pos,\n                                reference.precision,\n                                LayoutInfo::IsRef,\n                            ),\n                            shape: reference.global_ir.shape.clone(),\n                            strides: reference.handle.strides.clone(),\n                        };\n                    };\n\n                    let set_ref_as_virtual = |block: &mut BlockPlan<'_>| {\n                        block.reference = ReferenceSelection::VirtualShape {\n                            original: FuseArg::Input(\n                                input_pos,\n                                reference.precision,\n                                LayoutInfo::Unknown,\n                            ),\n                            shape: reference.global_ir.shape.clone(),\n                            strides: contiguous_strides(&reference.global_ir.shape),\n                        };\n                    };\n\n                    match block.settings.ref_layout {\n                        RefLayoutSetting::Any => set_ref_as_concrete(block_plan),\n                        RefLayoutSetting::SameAsBlock { .. } => {\n                            // Skip set ref.\n                        }\n                        RefLayoutSetting::OnlyContiguous => {\n                            if is_contiguous(&reference.global_ir.shape, &reference.handle.strides)\n                            {\n                                set_ref_as_concrete(block_plan)\n                            } else {\n                                set_ref_as_virtual(block_plan)\n                            }\n                        }\n                    }\n\n                    Self::add_layout_info_inputs(block_plan, handle_inputs);\n                }\n                InputReference::SwapDims { original_pos, dims } => {\n                    let reference = handle_inputs\n                        .get(original_pos)\n                        .unwrap()\n                        .as_normal()\n                        .expect(\"Quant can't be used in swap dims operation\");\n                    block_plan.reference = ReferenceSelection::SwapDims {\n                        original: FuseArg::Input(\n                            original_pos,\n                            reference.precision,\n                            LayoutInfo::Unknown,\n                        ),\n                        dims,\n                    };\n                }\n                InputReference::Reshaped { reshape_pos } => {\n                    block_plan.reference = ReferenceSelection::Reshaped { reshape_pos };\n                }\n            };\n            None\n        } else {\n            Some(block.shape_ref.clone())\n        }\n    }\n\n    fn add_layout_info_inputs(block: &mut BlockPlan<'_>, handle_inputs: &[HandleInput<R>]) {\n        for hi in handle_inputs.iter().filter_map(|h| match h {\n            HandleInput::Normal(input) => Some(input),\n            _ => None,\n        }) {\n            let (strides, shape) = match &block.reference {\n                ReferenceSelection::Concrete { strides, shape, .. }\n                | ReferenceSelection::VirtualShape { strides, shape, .. } => (strides, shape),\n                _ => continue,\n            };\n\n            if strides == &hi.handle.strides\n                && shape == &hi.global_ir.shape\n                && let Some(ops) = block.reads.get_mut(&hi.relative_id)\n            {\n                for op in ops.iter_mut() {\n                    if let FuseOp::Assign(op) = op {\n                        op.input.add_layout_info(LayoutInfo::SameAsRef);\n                    }\n                }\n            }\n        }\n    }\n\n    fn output_kind(\n        &self,\n        plan: &mut LaunchPlan<'a, R>,\n        tensor_global: &TensorIr,\n        output: &OutputSorted,\n        strides: &[usize],\n    ) -> (OutputKind, usize) {\n        let mut block_idx = None;\n        for (i, block) in plan.blocks.iter().enumerate() {\n            if block.writes.contains_key(&output.tensor_relative.id) {\n                block_idx = Some(i);\n                break;\n            }\n        }\n        let block_idx = block_idx.unwrap();\n\n        if let Some(transform) = self.resources.views.iter().find(|v| match v {\n            TensorView::Reshape { reshaped, .. } => reshaped == &output.tensor_relative.id,\n            TensorView::SwapDims { swapped, .. } => swapped == &output.tensor_relative.id,\n        }) {\n            return (OutputKind::Transform(transform.clone()), block_idx);\n        }\n\n        let block = &plan.blocks[block_idx];\n        let kind = block\n            .potential_inplaces\n            .iter()\n            .enumerate()\n            .find(|(_pos, pi)| {\n                pi.tensor_relative.dtype == tensor_global.dtype\n                    && pi.tensor_relative.shape == output.tensor_relative.shape\n                    && &*pi.strides == strides\n                    && block.reference.compatible_strides_for_inplace(strides)\n            })\n            .map(|(pos, _)| OutputKind::Inplace { input_pos: pos })\n            .unwrap_or(OutputKind::Normal);\n\n        (kind, block_idx)\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn inplace_output(\n        &mut self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n        output: OutputSorted,\n        tensor_global: TensorIr,\n        strides: Strides,\n        input_index: usize,\n        block_idx: usize,\n    ) {\n        let block = &mut plan.blocks[block_idx];\n        let potential_inplace = block.potential_inplaces.remove(input_index);\n        let handle_input = match plan.handle_inputs.get(potential_inplace.input_pos).unwrap() {\n            HandleInput::Normal(handle) => handle,\n            _ => {\n                unreachable!(\"Quant tensor handle can't be used inplace yet.\")\n            }\n        };\n\n        if !block.reference.is_found()\n            && !matches!(\n                self.blocks[block_idx].settings.ref_layout,\n                RefLayoutSetting::SameAsBlock { .. }\n            )\n        {\n            let index_input = self\n                .resources\n                .inputs\n                .get_index(potential_inplace.tensor_relative.id)\n                .unwrap();\n\n            block.reference = ReferenceSelection::Concrete {\n                layout: FuseArg::Input(index_input, output.precision, LayoutInfo::IsRef),\n                shape: tensor_global.shape.clone(),\n                strides: handle_input.handle.strides.clone(),\n            };\n\n            if let Some(ops) = block.reads.get_mut(&handle_input.relative_id) {\n                for op in ops.iter_mut() {\n                    if let FuseOp::Assign(op) = op {\n                        op.input.add_layout_info(LayoutInfo::IsRef);\n                        break;\n                    };\n                }\n            }\n\n            if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {\n                for op in ops {\n                    if let FuseOp::Assign(op) = op {\n                        op.out.add_layout_info(LayoutInfo::IsRef);\n                        break;\n                    }\n                }\n            };\n        } else {\n            // Already validated, necessary for correctness.\n            if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {\n                for op in ops {\n                    if let FuseOp::Assign(op) = op {\n                        op.out.add_layout_info(LayoutInfo::SameAsRef);\n                        break;\n                    }\n                }\n            };\n        }\n\n        context\n            .handles\n            .register_handle(tensor_global.id, handle_input.handle.clone());\n\n        self.handles[output.pos_original] = Some(HandleOutput::Alias {\n            input_pos: potential_inplace.input_pos,\n            precision: output.precision,\n            global_shape: tensor_global.shape.clone(),\n            strides,\n            #[cfg(feature = \"autotune-checks\")]\n            debug_info: super::HandleOutputAliasDebugInfo {\n                relative_id: output.tensor_relative.id,\n                handle: handle_input.handle.clone(),\n                global_shape: tensor_global.shape.dims.clone(),\n            },\n        });\n        self.globals[output.pos_original] = Some(tensor_global);\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn normal_output(\n        &mut self,\n        client: &ComputeClient<R>,\n        device: &R::Device,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n        output: OutputSorted,\n        tensor_global: TensorIr,\n        strides: Strides,\n        block_idx: usize,\n    ) {\n        let block = &mut plan.blocks[block_idx];\n\n        if !block.reference.is_found()\n            && self.blocks[block_idx].shape_ref == output.tensor_relative.shape\n            && !matches!(\n                self.blocks[block_idx].settings.ref_layout,\n                RefLayoutSetting::SameAsBlock { .. }\n            )\n        {\n            block.reference = ReferenceSelection::Concrete {\n                layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef),\n                shape: tensor_global.shape.clone(),\n                strides: strides.clone(),\n            };\n\n            // Sometimes outputs that are manually handled don't have any write registered.\n            if let Some(ops) = block.writes.get_mut(&output.tensor_relative.id) {\n                for op in ops {\n                    if let FuseOp::Assign(op) = op {\n                        op.out.add_layout_info(LayoutInfo::IsRef);\n                        break;\n                    }\n                }\n            };\n        } else if let ReferenceSelection::Concrete {\n            shape: ref_shape,\n            strides: ref_strides,\n            ..\n        } = &block.reference\n            && ref_strides == &strides\n            && ref_shape == &tensor_global.shape\n            && let Some(ops) = block.writes.get_mut(&output.tensor_relative.id)\n        {\n            for op in ops {\n                if let FuseOp::Assign(op) = op {\n                    op.out.add_layout_info(LayoutInfo::SameAsRef);\n                    break;\n                }\n            }\n        };\n\n        let dtype = tensor_global.dtype;\n        let size = tensor_global.shape.iter().product::<usize>() * StorageType::from(dtype).size();\n\n        let handle = CubeFusionHandle {\n            client: client.clone(),\n            handle: client.empty(size),\n            device: device.clone(),\n            strides,\n            dtype,\n            qparams: None,\n        };\n\n        plan.rank = usize::max(tensor_global.shape.rank(), plan.rank);\n        context\n            .handles\n            .register_handle(tensor_global.id, handle.clone());\n\n        self.handles[output.pos_original] = Some(HandleOutput::Owned {\n            precision: output.precision,\n            handle,\n            global_shape: tensor_global.shape.clone(),\n            global_id: tensor_global.id,\n            relative_id: output.tensor_relative.id,\n            vectorization: 1,\n        });\n        self.globals[output.pos_original] = Some(tensor_global);\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn reshaped_output(\n        &mut self,\n        client: &ComputeClient<R>,\n        device: &R::Device,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n        output: OutputSorted,\n        tensor_global: TensorIr,\n        strides: Strides,\n        original: TensorId,\n        block_idx: usize,\n    ) {\n        let block = &mut plan.blocks[block_idx];\n\n        let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);\n\n        let dtype = tensor_global.dtype;\n\n        let action = reshape_action(\n            &original_handle.global_ir.shape,\n            &original_handle.handle.strides,\n            &tensor_global.shape,\n        );\n\n        let update = match action {\n            ReshapeAction::UpdateStrides { strides } => Some(strides),\n            ReshapeAction::NoChange => Some(original_handle.handle.strides.clone()),\n            ReshapeAction::Recompute => None,\n        };\n\n        match update {\n            Some(strides) => {\n                // We modify the metadata instead.\n                remove_concrete_write(block, output.tensor_relative.id, output.pos_original);\n\n                let handle = CubeFusionHandle {\n                    client: client.clone(),\n                    handle: original_handle.handle.handle.clone(),\n                    device: device.clone(),\n                    strides,\n                    dtype,\n                    qparams: original_handle.handle.qparams.clone(),\n                };\n                context\n                    .handles\n                    .register_handle(tensor_global.id, handle.clone());\n\n                // IT will never be access, just a way to keep the original position working.\n                self.handles[output.pos_original] = Some(HandleOutput::Alias {\n                    input_pos: pos_input,\n                    precision: output.precision,\n                    global_shape: tensor_global.shape.clone(),\n                    strides: handle.strides.clone(),\n                    #[cfg(feature = \"autotune-checks\")]\n                    debug_info: super::HandleOutputAliasDebugInfo {\n                        relative_id: output.tensor_relative.id,\n                        handle: handle.clone(),\n                        global_shape: tensor_global.shape.dims.clone(),\n                    },\n                });\n                self.globals[output.pos_original] = Some(tensor_global);\n            }\n            None => {\n                self.normal_output(\n                    client,\n                    device,\n                    context,\n                    plan,\n                    output,\n                    tensor_global,\n                    strides,\n                    block_idx,\n                );\n            }\n        }\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    fn swapped_dims_output(\n        &mut self,\n        client: &ComputeClient<R>,\n        device: &R::Device,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n        output: OutputSorted,\n        tensor_global: TensorIr,\n        original: TensorId,\n        dims: (usize, usize),\n        block_idx: usize,\n    ) {\n        let block = &mut plan.blocks[block_idx];\n        let (pos_input, original_handle) = Self::find_child_input(&plan.handle_inputs, original);\n\n        let dtype = tensor_global.dtype;\n\n        // TODO: Check if we can also remove the read, if we have a dead partial graph.\n        //\n        // We modify the metadata instead.\n        remove_concrete_write(block, output.tensor_relative.id, output.pos_original);\n\n        let strides = original_handle.handle.strides.clone();\n\n        let mut handle = CubeFusionHandle {\n            client: client.clone(),\n            handle: original_handle.handle.handle.clone(),\n            device: device.clone(),\n            strides,\n            dtype,\n            qparams: original_handle.handle.qparams.clone(),\n        };\n        handle.strides.swap(dims.0, dims.1);\n\n        context\n            .handles\n            .register_handle(tensor_global.id, handle.clone());\n\n        // IT will never be access, just a way to keep the original position working.\n        self.handles[output.pos_original] = Some(HandleOutput::Alias {\n            input_pos: pos_input,\n            precision: output.precision,\n            global_shape: tensor_global.shape.clone(),\n            strides: handle.strides.clone(),\n            #[cfg(feature = \"autotune-checks\")]\n            debug_info: super::HandleOutputAliasDebugInfo {\n                relative_id: output.tensor_relative.id,\n                handle: handle.clone(),\n                global_shape: tensor_global.shape.dims.clone(),\n            },\n        });\n        self.globals[output.pos_original] = Some(tensor_global);\n    }\n\n    fn find_child_input(\n        handle_inputs: &[HandleInput<R>],\n        original: TensorId,\n    ) -> (usize, &NormalHandleInput<R>) {\n        handle_inputs\n            .iter()\n            .enumerate()\n            .find_map(|(pi, handle)| match handle {\n                HandleInput::Normal(handle) => match handle.relative_id == original {\n                    true => Some((pi, handle)),\n                    false => None,\n                },\n                _ => None, // Quant tensor can't be reshaped.\n            })\n            .unwrap()\n    }\n}\n\nfn remove_concrete_write(block: &mut BlockPlan, id: TensorId, output_pos: usize) {\n    let ops = block.writes.remove(&id);\n\n    if let Some(ops) = ops {\n        let mut keep = Vec::with_capacity(ops.len());\n\n        for op in ops {\n            if let FuseOp::Assign(args) = &op {\n                if let FuseArg::Output(pos, ..) = args.out {\n                    if pos != output_pos {\n                        keep.push(op);\n                    }\n                } else {\n                    keep.push(op);\n                }\n            }\n        }\n        block.writes.insert(id, keep);\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/plan.rs",
    "content": "use crate::{\n    CubeFusionHandle,\n    engine::{\n        codegen::ir::{FuseArg, FuseOp, FuseType},\n        launch::vectorization::Vect,\n        trace::{RuntimeLayout, block::FuseBlock},\n    },\n};\nuse burn_ir::{TensorId, TensorIr};\nuse burn_std::{Shape, Strides};\nuse cubecl::{Runtime, ir::VectorSize};\nuse std::collections::BTreeMap;\n\n/// The `LaunchPlan` is responsible for aggregating all runtime information required\n/// to dispatch a fused kernel.\n///\n/// It maps abstract IR tensors to memory handles, manages vectorization\n/// strategies, and tracks layout transformations.\n#[derive(Debug)]\npub struct LaunchPlan<'a, R: Runtime> {\n    /// The IR representation of tensors that are results of the fusion.\n    pub global_outputs: Vec<TensorIr>,\n    /// Memory handles and metadata for all input tensors.\n    pub handle_inputs: Vec<HandleInput<R>>,\n    /// Memory handles and metadata for all output tensors, including aliased inputs.\n    pub handle_outputs: Vec<HandleOutput<R>>,\n    /// The rank across all tensors in the plan.\n    ///\n    /// Smaller tensors are unsqueezed during launch.\n    pub rank: usize,\n    /// Detailed planning for each individual computation block within the fusion.\n    pub blocks: Vec<BlockPlan<'a>>,\n    /// Mapping of tensor IDs to their specific vectorization factors.\n    pub vectorizations: BTreeMap<TensorId, Vect>,\n    /// Tensors that can be cleared or deallocated after this plan executes.\n    pub cleared: Vec<TensorId>,\n    /// Metadata for shapes and strides passed from the host when they cannot be\n    /// inferred from input tensors (e.g., complex deep fusions).\n    pub runtime_layouts: Vec<RuntimeLayout>,\n}\n\n/// Information regarding the execution of a specific block of operations within a fusion.\n#[derive(Debug)]\npub struct BlockPlan<'a> {\n    /// List of inputs that are candidates for in-place memory reuse within this block.\n    pub potential_inplaces: Vec<PotentialInplace<'a>>,\n    /// The input tensor chosen to define the iteration space, if any.\n    pub potential_reference_input: Option<InputReference>,\n    /// How the master layout is determined for this block.\n    pub reference: ReferenceSelection,\n    /// Mapping of tensor IDs to the read operations performed on them.\n    pub reads: BTreeMap<TensorId, Vec<FuseOp>>,\n    /// Mapping of tensor IDs to the write operations performed on them.\n    pub writes: BTreeMap<TensorId, Vec<FuseOp>>,\n    /// The width for the operations in this block.\n    pub width: VectorSize,\n}\n\n/// Metadata for an input tensor being used as a reference for a block's layout.\n#[derive(Debug)]\npub enum InputReference {\n    /// Standard input at the specified position.\n    Normal { input_pos: usize },\n    /// Input that has an axis swapped.\n    SwapDims {\n        original_pos: usize,\n        dims: (usize, usize),\n    },\n    /// Input that has been reshaped.\n    Reshaped { reshape_pos: usize },\n}\n\n/// Strategies for selecting the reference layout of a fused block.\n///\n/// The reference layout determines how global indices are mapped to tensor coordinates.\n#[derive(Clone, Debug)]\npub enum ReferenceSelection {\n    /// The engine is still calculating the optimal reference.\n    Searching,\n    /// Layout from a normal tensor.\n    Concrete {\n        layout: FuseArg,\n        shape: Shape,\n        strides: Strides,\n    },\n    /// Layout from a swapped dim tensor.\n    SwapDims {\n        original: FuseArg,\n        dims: (usize, usize),\n    },\n    /// Layout from a reshaped tensor.\n    Reshaped { reshape_pos: usize },\n    /// Layout that has the shape of an input, but not its strides.\n    VirtualShape {\n        original: FuseArg,\n        shape: Shape,\n        strides: Strides,\n    },\n    /// The layout is provided dynamically by the host at runtime.\n    Runtime { pos: usize },\n}\n\nimpl<R: Runtime> LaunchPlan<'_, R> {\n    /// Creates a new `LaunchPlan` from a slice of fusion blocks.\n    ///\n    /// Initializes blocks with default \"Searching\" references and calculates\n    /// the initial max rank.\n    pub fn new(fuse_blocks: &[FuseBlock]) -> Self {\n        let mut rank = 0;\n        let mut blocks = Vec::with_capacity(fuse_blocks.len());\n\n        for b in fuse_blocks.iter() {\n            rank = usize::max(b.shape_ref.len(), rank);\n            let block = BlockPlan {\n                reference: ReferenceSelection::Searching,\n                reads: b.reads.clone(),\n                writes: b.writes.clone(),\n                width: 0,\n                potential_inplaces: Vec::new(),\n                potential_reference_input: None,\n            };\n            blocks.push(block);\n        }\n\n        LaunchPlan {\n            global_outputs: Vec::new(),\n            handle_inputs: Vec::new(),\n            handle_outputs: Vec::new(),\n            rank,\n            blocks,\n            vectorizations: Default::default(),\n            cleared: Default::default(),\n            runtime_layouts: Default::default(),\n        }\n    }\n}\n\n/// Debugging information for aliased handles when `autotune-checks` is enabled.\n#[cfg(feature = \"autotune-checks\")]\n#[derive(Debug)]\npub struct HandleOutputAliasDebugInfo<R: Runtime> {\n    pub handle: CubeFusionHandle<R>,\n    pub relative_id: TensorId,\n    pub global_shape: Shape,\n}\n\n/// Represents the output of a fused kernel execution.\n#[derive(Debug, Clone)]\n#[allow(clippy::large_enum_variant)]\npub enum HandleOutput<R: Runtime> {\n    /// An output that reuses the memory of an input tensor (In-place).\n    Alias {\n        /// Index of the input handle being aliased.\n        input_pos: usize,\n        /// Data type precision.\n        precision: FuseType,\n        global_shape: Shape,\n        strides: Strides,\n        #[cfg(feature = \"autotune-checks\")]\n        debug_info: HandleOutputAliasDebugInfo<R>,\n    },\n    /// An output that requires a newly allocated memory buffer.\n    Owned {\n        global_id: TensorId,\n        relative_id: TensorId,\n        precision: FuseType,\n        handle: CubeFusionHandle<R>,\n        global_shape: Shape,\n        vectorization: VectorSize,\n    },\n}\n\n/// A standard input handle with associated layout and vectorization metadata.\n#[derive(Debug, Clone)]\npub struct NormalHandleInput<R: Runtime> {\n    pub relative_id: TensorId,\n    pub global_ir: TensorIr,\n    pub precision: FuseType,\n    pub handle: CubeFusionHandle<R>,\n    pub vector_size: VectorSize,\n    pub broadcated: bool,\n    /// Stores the original strides of the handle for restoration during plan rollback.\n    pub orig_strides: Strides,\n}\n\n/// An input handle containing values for a quantized tensor.\n#[derive(Debug, Clone)]\npub struct QuantValuesHandleInput<R: Runtime> {\n    pub relative_id: TensorId,\n    pub global_ir: TensorIr,\n    pub precision: FuseType,\n    pub handle: CubeFusionHandle<R>,\n    pub vector_size: VectorSize,\n}\n\n/// An input handle containing parameters (scales/offsets) for quantization.\n#[derive(Debug, Clone)]\npub struct QuantParamsHandleInput<R: Runtime> {\n    pub precision: FuseType,\n    pub handle: CubeFusionHandle<R>,\n    pub shape: Shape,\n}\n\n/// Different types of inputs that can be passed to a fused kernel.\n#[derive(Debug, Clone)]\npub enum HandleInput<R: Runtime> {\n    Normal(NormalHandleInput<R>),\n    QuantValues(QuantValuesHandleInput<R>),\n    QuantParams(QuantParamsHandleInput<R>),\n}\n\nimpl<R: Runtime> HandleInput<R> {\n    /// Returns a reference to the inner `NormalHandleInput` if the variant matches.\n    pub fn as_normal(&self) -> Option<&NormalHandleInput<R>> {\n        match self {\n            HandleInput::Normal(normal) => Some(normal),\n            _ => None,\n        }\n    }\n}\n\nimpl<R: Runtime> NormalHandleInput<R> {\n    /// Creates a new `NormalHandleInput` tracking original strides.\n    pub fn new(\n        tensor_global: TensorIr,\n        tensor_relative: &TensorIr,\n        precision: FuseType,\n        mut handle: CubeFusionHandle<R>,\n        mut strides: Strides,\n    ) -> Self {\n        // Swap current handle strides with provided strides to track the original state for rollback.\n        core::mem::swap(&mut handle.strides, &mut strides);\n        Self {\n            precision,\n            handle,\n            relative_id: tensor_relative.id,\n            global_ir: tensor_global,\n            vector_size: 1,\n            broadcated: false,\n            orig_strides: strides,\n        }\n    }\n\n    /// Restores the handle's original strides and returns the handle.\n    ///\n    /// Used when a plan is invalidated or needs to be rolled back.\n    pub fn handle_rollback(mut self) -> CubeFusionHandle<R> {\n        core::mem::swap(&mut self.handle.strides, &mut self.orig_strides);\n        self.handle\n    }\n}\n\n/// A candidate for in-place optimization.\n#[derive(Debug)]\npub struct PotentialInplace<'a> {\n    /// Position of the input handle in the `handle_inputs` vector.\n    pub input_pos: usize,\n    /// Reference to the IR of the relative tensor.\n    pub tensor_relative: &'a TensorIr,\n    /// Current strides of the potential in-place candidate.\n    pub strides: Strides,\n}\n\nimpl ReferenceSelection {\n    pub fn is_found(&self) -> bool {\n        !matches!(self, Self::Searching)\n    }\n\n    pub fn compatible_strides_for_inplace(&self, strides_inplace: &[usize]) -> bool {\n        match self {\n            ReferenceSelection::Concrete { strides, .. } => &**strides == strides_inplace,\n            _ => false,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/runner.rs",
    "content": "use super::super::codegen::ir::{FuseBlockConfig, GlobalArgsLaunch};\nuse crate::{\n    CubeFusionHandle,\n    engine::launch::{\n        LaunchPlan,\n        vectorization::{Vect, vectorization_default},\n    },\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::{TensorId, TensorIr};\nuse cubecl::prelude::*;\nuse std::collections::{BTreeMap, HashMap};\n\n/// A trace runner is responsible for determining the vectorization factor as well as launching\n/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch)\n/// with provided [fuse block configs](FuseBlockConfig).\npub trait TraceRunner<R: Runtime>: Vectorization<R> {\n    /// The error that might happen while running the trace.\n    type Error;\n\n    /// Run the trace with the given inputs and outputs.\n    ///\n    /// There is one [fuse config](FuseBlockConfig) for each [block](super::block::FuseBlock) registered\n    /// in the [optimization builder](burn_fusion::OptimizationBuilder).\n    fn run<'a>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        configs: &'a [FuseBlockConfig],\n    ) -> Result<(), Self::Error>;\n}\n\npub enum VectorizationHandle<'a, R: Runtime> {\n    NormalInput(&'a CubeFusionHandle<R>, &'a TensorIr),\n    QuantValues(&'a CubeFusionHandle<R>, &'a TensorIr),\n    QuantParams,\n}\n\nimpl<'a, R: Runtime> VectorizationHandle<'a, R> {\n    /// Returns if the current vectorization handle is from the given tensor id.\n    pub fn is_from_tensor(&self, id: TensorId) -> bool {\n        match self {\n            VectorizationHandle::NormalInput(_, tensor_ir) => tensor_ir.id == id,\n            VectorizationHandle::QuantValues(_, tensor_ir) => tensor_ir.id == id,\n            VectorizationHandle::QuantParams => false,\n        }\n    }\n}\n\n#[derive(Default)]\npub struct VectorizationAxis {\n    axis: HashMap<TensorId, usize>,\n}\n\nimpl VectorizationAxis {\n    pub fn get<F: FnOnce() -> usize>(&self, id: TensorId, default: F) -> usize {\n        self.axis.get(&id).copied().unwrap_or_else(default)\n    }\n    pub fn insert(&mut self, id: TensorId, axis: usize) {\n        self.axis.insert(id, axis);\n    }\n}\n\npub trait Vectorization<R: Runtime> {\n    /// Returns the vectorization options.\n    fn axis(&self, _plan: &LaunchPlan<'_, R>) -> VectorizationAxis {\n        VectorizationAxis::default()\n    }\n    /// The vectorization factor for all inputs and outputs.\n    #[allow(clippy::too_many_arguments)]\n    fn vectorization<'a>(\n        &self,\n        _context: &Context<'_, CubeFusionHandle<R>>,\n        vectorizations: &mut BTreeMap<TensorId, Vect>,\n        inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,\n        outputs: impl Iterator<Item = &'a TensorIr>,\n        reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,\n        swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,\n        vector_sizes: &[VectorSize],\n        max: VectorSize,\n        axis: VectorizationAxis,\n    ) {\n        vectorization_default(\n            vectorizations,\n            inputs,\n            outputs,\n            reshaped,\n            swapped,\n            vector_sizes,\n            &Default::default(),\n            max,\n            &axis,\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs",
    "content": "use crate::{\n    CubeFusionHandle,\n    engine::launch::runner::{VectorizationAxis, VectorizationHandle},\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::{TensorId, TensorIr};\nuse cubecl::{Runtime, ir::VectorSize};\nuse serde::{Deserialize, Serialize};\nuse std::collections::BTreeMap;\n\n#[derive(Debug, Clone, Copy)]\npub enum Vect {\n    Broadcasted,\n    Aligned(VectorSize),\n}\n\nimpl Vect {\n    pub fn vector_size(&self) -> VectorSize {\n        match self {\n            Vect::Broadcasted => 1,\n            Vect::Aligned(val) => *val,\n        }\n    }\n\n    pub fn is_broadcast(&self) -> bool {\n        matches!(self, Vect::Broadcasted)\n    }\n}\n\n#[derive(Default, Clone, Serialize, Deserialize, Debug)]\npub struct VectorSizeOverrides {\n    state: Option<BTreeMap<TensorId, Vec<VectorSize>>>,\n    default: Option<Vec<VectorSize>>,\n}\n\n#[allow(unused)]\nimpl VectorSizeOverrides {\n    pub fn overrides(&mut self, tensor_id: &TensorId, vector_sizes: Vec<VectorSize>) {\n        let map = match &mut self.state {\n            Some(val) => val,\n            None => {\n                self.state = Some(BTreeMap::new());\n                self.state.as_mut().unwrap()\n            }\n        };\n\n        map.insert(*tensor_id, vector_sizes);\n    }\n    pub fn overrides_default(&mut self, vector_sizes: Vec<VectorSize>) {\n        self.default = Some(vector_sizes);\n    }\n\n    pub fn mapping<R: Runtime>(&self, context: &Context<'_, CubeFusionHandle<R>>) -> Self {\n        match &self.state {\n            Some(state) => {\n                let mut state_new = BTreeMap::new();\n\n                for (k, v) in state.iter() {\n                    let global = context.tensors.get(k).unwrap();\n                    state_new.insert(global.id, v.clone());\n                }\n\n                Self {\n                    state: Some(state_new),\n                    default: self.default.clone(),\n                }\n            }\n            None => Self {\n                state: None,\n                default: self.default.clone(),\n            },\n        }\n    }\n\n    pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec<VectorSize>> {\n        let map = match &self.state {\n            Some(val) => val,\n            None => match &self.default {\n                Some(val) => return Some(val),\n                None => return None,\n            },\n        };\n\n        match map.get(tensor_id) {\n            Some(val) => Some(val),\n            None => match &self.default {\n                Some(val) => Some(val),\n                None => None,\n            },\n        }\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn vectorization_default<'a, R: Runtime>(\n    vectorizations: &mut BTreeMap<TensorId, Vect>,\n    inputs: impl Iterator<Item = VectorizationHandle<'a, R>>,\n    outputs: impl Iterator<Item = &'a TensorIr>,\n    reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,\n    swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,\n    vector_sizes: &[VectorSize],\n    overrides: &VectorSizeOverrides,\n    max: VectorSize,\n    axis: &VectorizationAxis,\n) {\n    let swapped: Vec<_> = swapped.collect();\n\n    for input in inputs {\n        if let Some((s, o, mr, dims)) = swapped\n            .iter()\n            .find(|(_s, o, _mr, _dims)| input.is_from_tensor(o.id))\n        {\n            let (handle, id) = match input {\n                VectorizationHandle::NormalInput(handle, tensor_ir) => (handle, &tensor_ir.id),\n                VectorizationHandle::QuantValues(..) => panic!(\"Can't be swapped\"),\n                VectorizationHandle::QuantParams => panic!(\"Can't be swapped\"),\n            };\n            let val = vectorization_swapped(\n                handle,\n                s,\n                o,\n                *mr,\n                dims,\n                max,\n                axis,\n                vector_sizes,\n                overrides.tensor(id),\n            );\n            multi_reads_vectorization_update(vectorizations, o.id, val);\n        } else {\n            match input {\n                VectorizationHandle::NormalInput(handle, tensor_ir) => {\n                    let val = vectorization_input(\n                        handle,\n                        tensor_ir,\n                        axis,\n                        vector_sizes,\n                        overrides.tensor(&tensor_ir.id),\n                    );\n                    vectorizations.insert(tensor_ir.id, val);\n                }\n                VectorizationHandle::QuantValues(handle, tensor_ir) => {\n                    let val = vectorization_input(\n                        handle,\n                        tensor_ir,\n                        axis,\n                        vector_sizes,\n                        overrides.tensor(&tensor_ir.id),\n                    );\n                    let num_quants = match tensor_ir.dtype {\n                        burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(),\n                        _ => panic!(\"\"),\n                    };\n                    let val = match val {\n                        Vect::Broadcasted => Vect::Aligned(1),\n                        Vect::Aligned(val) => Vect::Aligned(val.div_ceil(num_quants)),\n                    };\n                    vectorizations.insert(tensor_ir.id, val);\n                }\n                VectorizationHandle::QuantParams => {\n                    // Doesn't have vectorization for now.\n                }\n            };\n        }\n    }\n\n    for (reshaped, original, multi_reads) in reshaped {\n        let val = vectorization_reshape(\n            reshaped,\n            original,\n            multi_reads,\n            axis,\n            vector_sizes,\n            max,\n            overrides.tensor(&original.id),\n        );\n        multi_reads_vectorization_update(vectorizations, original.id, val);\n    }\n\n    for tensor in outputs {\n        let val = vectorization_output(\n            tensor,\n            axis,\n            vector_sizes,\n            max,\n            overrides.tensor(&tensor.id),\n        );\n        vectorizations.insert(tensor.id, val);\n    }\n}\n\nfn multi_reads_vectorization_update(\n    vectorizations: &mut BTreeMap<TensorId, Vect>,\n    original: TensorId,\n    vect: Vect,\n) {\n    if let Some(ori_vect) = vectorizations.get(&original).cloned() {\n        match ori_vect {\n            Vect::Broadcasted => {\n                // keep the original as is.\n            }\n            Vect::Aligned(ori) => match vect {\n                Vect::Broadcasted => {\n                    vectorizations.insert(original, Vect::Aligned(1));\n                }\n                Vect::Aligned(new) => {\n                    let val = if new != ori { 1 } else { new };\n                    vectorizations.insert(original, Vect::Aligned(val));\n                }\n            },\n        };\n    } else {\n        vectorizations.insert(original, vect);\n    }\n}\n\n// The default version uses the last dimension as vectorization axis and assumes a\n// perpendicular contiguous vector.\nfn vectorization_input<R: Runtime>(\n    handle: &CubeFusionHandle<R>,\n    desc: &TensorIr,\n    axis: &VectorizationAxis,\n    vector_sizes: &[VectorSize],\n    overrides: Option<&Vec<VectorSize>>,\n) -> Vect {\n    let axis = axis.get(desc.id, || handle.strides.len() - 1);\n    let shape_axis = desc.shape[axis];\n\n    if shape_axis == 1 {\n        return Vect::Broadcasted;\n    }\n\n    // Last dimension strides should be 1, otherwise vecX won't be contiguous.\n    if handle.strides[axis] != 1 {\n        return Vect::Aligned(1);\n    }\n\n    let inner = |s: VectorSize| {\n        // The last dimension should be a multiple of the vector size or broadcated.\n        if shape_axis.is_multiple_of(s) {\n            return Some(Vect::Aligned(s));\n        }\n        None\n    };\n\n    match overrides {\n        Some(vals) => {\n            for s in vals {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n        None => {\n            for s in vector_sizes {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n    }\n\n    Vect::Aligned(1)\n}\n\nfn vectorization_output(\n    desc: &TensorIr,\n    axis: &VectorizationAxis,\n    vector_sizes: &[VectorSize],\n    max: VectorSize,\n    overrides: Option<&Vec<VectorSize>>,\n) -> Vect {\n    let axis = axis.get(desc.id, || desc.shape.rank() - 1);\n\n    let inner = |s: VectorSize| {\n        // The dimension should be a multiple of the vector size.\n        if desc.shape[axis].is_multiple_of(s) && s <= max {\n            return Some(Vect::Aligned(s));\n        }\n\n        None\n    };\n    match overrides {\n        Some(val) => {\n            for s in val {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n        None => {\n            for s in vector_sizes {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n    }\n\n    Vect::Aligned(1)\n}\n\nfn vectorization_reshape(\n    reshaped: &TensorIr,\n    original: &TensorIr,\n    multi_reads: bool,\n    axis: &VectorizationAxis,\n    vector_sizes: &[VectorSize],\n    max: VectorSize,\n    overrides: Option<&Vec<VectorSize>>,\n) -> Vect {\n    let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1);\n    let reshape_shape_axis = reshaped.shape[axis];\n\n    if !multi_reads && reshape_shape_axis == 1 {\n        return Vect::Broadcasted;\n    }\n\n    // If the axis is not the last dim, didn't think of it, return Aligned(1) to be sure.\n    if axis != reshaped.shape.rank() - 1 {\n        return Vect::Aligned(1);\n    }\n\n    let original_shape_axis = original.shape[original.shape.rank() - 1];\n\n    if original_shape_axis != reshape_shape_axis {\n        return Vect::Aligned(1);\n    }\n\n    let inner = |s: VectorSize| {\n        if !multi_reads {\n            // The last dimension should be a multiple of the vector size or broadcated.\n            if reshape_shape_axis.is_multiple_of(s) && s <= max {\n                Some(Vect::Aligned(s))\n            } else {\n                None\n            }\n        } else {\n            // Since the original tensor must share the same vectorization factor as the\n            // reshaped tensor, they must have compatible shapes when both are access\n            // independently.\n            if reshape_shape_axis.is_multiple_of(s)\n                && original_shape_axis.is_multiple_of(s)\n                && s <= max\n            {\n                Some(Vect::Aligned(s))\n            } else {\n                None\n            }\n        }\n    };\n\n    match overrides {\n        Some(val) => {\n            for i in val {\n                if let Some(vect) = inner(*i) {\n                    return vect;\n                }\n            }\n        }\n        None => {\n            for s in vector_sizes {\n                if let Some(vect) = inner(*s) {\n                    return vect;\n                }\n            }\n        }\n    }\n\n    Vect::Aligned(1)\n}\n\n#[allow(clippy::too_many_arguments)]\nfn vectorization_swapped<R: Runtime>(\n    handle: &CubeFusionHandle<R>,\n    swapped: &TensorIr,\n    original: &TensorIr,\n    multi_reads: bool,\n    dims: &(usize, usize),\n    max: VectorSize,\n    axis: &VectorizationAxis,\n    vector_sizes: &[VectorSize],\n    overrides: Option<&Vec<VectorSize>>,\n) -> Vect {\n    let axis = axis.get(swapped.id, || swapped.shape.rank() - 1);\n\n    let swapped_axis = swapped.shape[axis];\n    let shape_axis = original.shape[axis];\n\n    let axis_index = axis;\n    let dim_index = if dims.0 == axis_index {\n        dims.1\n    } else if dims.1 == axis_index {\n        dims.0\n    } else {\n        axis_index\n    };\n\n    // Last dimension strides should be 1, otherwise vecX won't be contiguous.\n    if multi_reads {\n        if handle.strides[axis_index] != 1 {\n            return Vect::Aligned(1);\n        }\n        if handle.strides[dim_index] != 1 {\n            return Vect::Aligned(1);\n        }\n    } else if handle.strides[dim_index] != 1 {\n        return Vect::Aligned(1);\n    }\n\n    if !multi_reads && swapped_axis == 1 {\n        return Vect::Broadcasted;\n    }\n\n    let inner = |s: VectorSize| {\n        // The last dimension should be a multiple of the vector size or broadcated.\n        if multi_reads {\n            if swapped_axis.is_multiple_of(s) && s <= max {\n                return Some(Vect::Aligned(s));\n            }\n        } else if swapped_axis.is_multiple_of(s) && shape_axis.is_multiple_of(s) && s <= max {\n            return Some(Vect::Aligned(s));\n        }\n        None\n    };\n\n    match overrides {\n        Some(val) => {\n            for s in val {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n        None => {\n            for s in vector_sizes {\n                if let Some(val) = inner(*s) {\n                    return val;\n                }\n            }\n        }\n    }\n\n    Vect::Aligned(1)\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/vectorization/mod.rs",
    "content": "mod base;\nmod planner;\n\npub use base::*;\npub use planner::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs",
    "content": "use super::{\n    super::{BlockPlan, HandleOutput, LaunchPlan},\n    Vect,\n};\nuse crate::{\n    CubeFusionHandle,\n    engine::{\n        launch::{\n            HandleInput,\n            runner::{Vectorization, VectorizationHandle},\n        },\n        settings::VectorizationSetting,\n        trace::{FuseResources, TensorView, block::FuseBlock},\n    },\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::TensorId;\nuse cubecl::{\n    Runtime,\n    client::ComputeClient,\n    ir::{ElemType, StorageType, UIntKind},\n};\nuse cubecl::{\n    ir::VectorSize,\n    quant::scheme::{QuantScheme, QuantStore, QuantValue},\n};\nuse std::marker::PhantomData;\n\n/// Select the best vectorization factor for each tensor handle.\npub struct VectorizationPlanner<'a, R: Runtime> {\n    resources: &'a FuseResources,\n    blocks: &'a Vec<FuseBlock>,\n    _r: PhantomData<R>,\n}\n\nimpl<'a, R: Runtime> VectorizationPlanner<'a, R> {\n    pub fn new(resources: &'a FuseResources, blocks: &'a Vec<FuseBlock>) -> Self {\n        Self {\n            resources,\n            blocks,\n            _r: PhantomData,\n        }\n    }\n    pub fn run<Runner: Vectorization<R>>(\n        self,\n        client: &ComputeClient<R>,\n        runner: &Runner,\n        context: &Context<'_, CubeFusionHandle<R>>,\n        plan: &mut LaunchPlan<'a, R>,\n    ) {\n        let has_multiple_read = |tensor: &TensorId| {\n            let mut read_count = 0;\n            for block in plan.blocks.iter() {\n                read_count += block.reads.get(tensor).map(|a| a.len()).unwrap_or(0);\n            }\n            read_count > 1\n        };\n        let tensors_reshaped = self.resources.views.iter().filter_map(|view| match view {\n            TensorView::Reshape {\n                reshaped, original, ..\n            } => Some((\n                context.tensors.get(reshaped).unwrap(),\n                context.tensors.get(original).unwrap(),\n                has_multiple_read(original),\n            )),\n            TensorView::SwapDims { .. } => None,\n        });\n        let tensors_swapped = self.resources.views.iter().filter_map(|view| match view {\n            TensorView::SwapDims {\n                swapped,\n                original,\n                dims,\n                ..\n            } => Some((\n                context.tensors.get(swapped).unwrap(),\n                context.tensors.get(original).unwrap(),\n                has_multiple_read(original),\n                dims,\n            )),\n            TensorView::Reshape { .. } => None,\n        });\n\n        let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8);\n        let mut quants_vector_sizes: Option<Vec<VectorSize>> = None;\n\n        for input in plan.handle_inputs.iter() {\n            let elem: StorageType = match input {\n                HandleInput::Normal(h) => h.global_ir.dtype.into(),\n                HandleInput::QuantValues(handle) => match handle.global_ir.dtype {\n                    burn_std::DType::QFloat(scheme) => {\n                        vector_sizes_quants(client, &mut quants_vector_sizes, scheme);\n                        continue;\n                    }\n                    _ => panic!(\"Unable to retrieve the scheme for quantized values.\"),\n                },\n                HandleInput::QuantParams(..) => continue,\n            };\n            let elem_size = elem.size();\n\n            if ref_elem.1 >= elem_size {\n                ref_elem = (elem, elem_size);\n            }\n        }\n        for r in plan.global_outputs.iter() {\n            let elem: StorageType = r.dtype.into();\n            let elem_size = elem.size();\n\n            if ref_elem.1 >= elem_size {\n                ref_elem = (elem, elem_size);\n            }\n        }\n\n        let filtered = plan\n            .handle_inputs\n            .iter()\n            .map(|item| {\n                item.as_normal()\n                    // Filter out indexed resources.\n                    .map(|item| !self.resources.indexed.contains_key(&item.relative_id))\n                    .unwrap_or(true)\n            })\n            .collect::<Vec<_>>();\n\n        let vector_sizes = match quants_vector_sizes {\n            // Quantization normally triggers higher vectorization than anything else, no need to\n            // compare to ref elem.\n            Some(vector_sizes) => vector_sizes,\n            None => client\n                .io_optimized_vector_sizes(ref_elem.0.size())\n                .collect::<Vec<_>>(),\n        };\n\n        let vectorization_axis = runner.axis(plan);\n\n        runner.vectorization(\n            context,\n            &mut plan.vectorizations,\n            plan.handle_inputs\n                .iter()\n                .enumerate()\n                .filter_map(|(i, item)| {\n                    if filtered[i] {\n                        Some(match item {\n                            HandleInput::Normal(h) => {\n                                VectorizationHandle::NormalInput(&h.handle, &h.global_ir)\n                            }\n                            HandleInput::QuantValues(h) => {\n                                VectorizationHandle::QuantValues(&h.handle, &h.global_ir)\n                            }\n                            HandleInput::QuantParams(_) => VectorizationHandle::QuantParams,\n                        })\n                    } else {\n                        None\n                    }\n                }),\n            plan.global_outputs.iter(),\n            tensors_reshaped,\n            tensors_swapped,\n            &vector_sizes,\n            u8::MAX as usize,\n            vectorization_axis,\n        );\n\n        for tensor in self.resources.indexed.keys() {\n            let global = context.tensors.get(tensor).unwrap();\n            plan.vectorizations.insert(global.id, Vect::Aligned(1));\n        }\n\n        let mut block_vectorization = Vec::with_capacity(self.blocks.len());\n        for _ in 0..self.blocks.len() {\n            block_vectorization.push(Vec::new());\n        }\n\n        for (input_pos, handle) in plan.handle_inputs.iter_mut().enumerate() {\n            let (global_ir, relative_id) = match handle {\n                HandleInput::Normal(h) => (&h.global_ir, &h.relative_id),\n                HandleInput::QuantValues(h) => (&h.global_ir, &h.relative_id),\n                HandleInput::QuantParams(_) => continue,\n            };\n            let (vect, br) = match plan.vectorizations.get(&global_ir.id) {\n                Some(v) => (v.vector_size(), v.is_broadcast()),\n                None => panic!(\"No vectorization factor found for {:?}\", global_ir.id),\n            };\n\n            for (block_pos, block_plan) in plan.blocks.iter().enumerate() {\n                if block_plan.reads.contains_key(relative_id) {\n                    block_vectorization[block_pos].push(BlockVectorization {\n                        action: VectorizationAction::Input(input_pos),\n                        potential: vect,\n                        broadcasted: br,\n                    });\n                }\n            }\n        }\n\n        for (output_pos, handle) in plan.handle_outputs.iter().enumerate() {\n            if let HandleOutput::Owned {\n                global_id,\n                relative_id,\n                ..\n            } = handle\n            {\n                for (block_pos, block_plan) in plan.blocks.iter().enumerate() {\n                    if block_plan.writes.contains_key(relative_id) {\n                        let vectorization =\n                            plan.vectorizations.get(global_id).unwrap().vector_size();\n                        block_vectorization[block_pos].push(BlockVectorization {\n                            action: VectorizationAction::Output(output_pos),\n                            potential: vectorization,\n                            broadcasted: false,\n                        });\n                    }\n                }\n            }\n        }\n\n        let mut previous_widths = Vec::with_capacity(block_vectorization.len());\n\n        // Unhandled inputs might not get included in any fused blocks for now.\n        //\n        // So we ensure they are vectorized by setting their vectorization before we set the\n        // vectorizations in blocks.\n        //\n        // Unhandled Outputs are correctly vectorized, so this is only necessary for inputs.\n        for input in self.resources.inputs_unhandled.iter() {\n            let pos = self\n                .resources\n                .inputs\n                .get_index(*input)\n                .unwrap_or_else(|| self.resources.inputs.get_index_quant(*input).unwrap());\n            let input_global = context.tensors.get(input).unwrap();\n\n            match plan.vectorizations.get(&input_global.id).unwrap() {\n                Vect::Aligned(vect) => {\n                    let handle = &mut plan.handle_inputs[pos];\n                    match handle {\n                        HandleInput::Normal(handle) => {\n                            handle.vector_size = *vect;\n                        }\n                        HandleInput::QuantValues(handle) => {\n                            handle.vector_size = *vect;\n                        }\n                        HandleInput::QuantParams(_) => {}\n                    }\n                }\n                Vect::Broadcasted => {}\n            }\n        }\n\n        for ((tmp, block_plan), block) in block_vectorization\n            .into_iter()\n            .zip(plan.blocks.iter_mut())\n            .zip(self.blocks)\n        {\n            match block.settings.vectorization {\n                VectorizationSetting::Activated => {\n                    apply_vectorization_block(\n                        tmp,\n                        &mut plan.handle_inputs,\n                        &mut plan.handle_outputs,\n                        block_plan,\n                        u8::MAX as usize,\n                    );\n                }\n                VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos } => {\n                    apply_vectorization_block(\n                        tmp,\n                        &mut plan.handle_inputs,\n                        &mut plan.handle_outputs,\n                        block_plan,\n                        previous_widths[block_pos],\n                    );\n                    if block_plan.width == 0 {\n                        block_plan.width = previous_widths[block_pos];\n                    }\n                }\n                VectorizationSetting::EqualThanPreviousBlock { block_pos } => {\n                    apply_vectorization_block(\n                        tmp,\n                        &mut plan.handle_inputs,\n                        &mut plan.handle_outputs,\n                        block_plan,\n                        previous_widths[block_pos],\n                    );\n                    // Enforces the width.\n                    block_plan.width = previous_widths[block_pos];\n                }\n                VectorizationSetting::Deactivated => {\n                    apply_vectorization_block(\n                        tmp,\n                        &mut plan.handle_inputs,\n                        &mut plan.handle_outputs,\n                        block_plan,\n                        1,\n                    );\n                    block_plan.width = 1;\n                }\n            }\n\n            // When only virtual inputs/outputs are present for a block, we need to set a width.\n            if block_plan.width == 0 {\n                if let Some(w) = previous_widths.last() {\n                    block_plan.width = *w;\n                } else {\n                    block_plan.width = 1;\n                }\n            }\n\n            previous_widths.push(block_plan.width);\n        }\n    }\n}\n\n#[derive(Debug)]\nenum VectorizationAction {\n    Input(usize),\n    Output(usize),\n}\n\n#[derive(Debug)]\nstruct BlockVectorization {\n    action: VectorizationAction,\n    potential: VectorSize,\n    broadcasted: bool,\n}\n\nfn apply_vectorization_block<R: Runtime>(\n    block_vectorization: Vec<BlockVectorization>,\n    inputs: &mut [HandleInput<R>],\n    outputs: &mut [HandleOutput<R>],\n    block_plan: &mut BlockPlan,\n    max: VectorSize,\n) {\n    for item in block_vectorization {\n        match item.action {\n            VectorizationAction::Input(pos) => {\n                let (vect, br) = if item.potential <= max {\n                    (item.potential, item.broadcasted)\n                } else {\n                    (1, false)\n                };\n\n                match &mut inputs[pos] {\n                    HandleInput::Normal(input) => {\n                        input.vector_size = vect;\n                        input.broadcated = br;\n                    }\n                    HandleInput::QuantValues(input) => {\n                        input.vector_size = vect;\n                    }\n                    HandleInput::QuantParams(_) => {\n                        // Not vectorized\n                    }\n                }\n\n                if block_plan.width < vect {\n                    block_plan.width = vect;\n                }\n            }\n            VectorizationAction::Output(pos) => {\n                if let HandleOutput::Owned { vectorization, .. } = &mut outputs[pos] {\n                    let vect = if item.potential <= max {\n                        item.potential\n                    } else {\n                        1\n                    };\n                    *vectorization = vect;\n\n                    if block_plan.width < vect {\n                        block_plan.width = vect;\n                    }\n                }\n            }\n        }\n    }\n}\n\nfn vector_sizes_quants<R: Runtime>(\n    client: &ComputeClient<R>,\n    quants_vector_sizes: &mut Option<Vec<VectorSize>>,\n    scheme: QuantScheme,\n) {\n    match scheme.store {\n        QuantStore::Native => match scheme.value {\n            // Type sizes are the same so just treat fp8/fp4x2 as i8\n            QuantValue::Q8F\n            | QuantValue::Q8S\n            | QuantValue::E4M3\n            | QuantValue::E5M2\n            | QuantValue::E2M1 => {\n                let vector_sizes = client\n                    .io_optimized_vector_sizes(size_of::<i8>())\n                    .collect::<Vec<_>>();\n\n                match &quants_vector_sizes {\n                    Some(sizes) => {\n                        if sizes[0] < vector_sizes[0] {\n                            *quants_vector_sizes = Some(vector_sizes);\n                        }\n                    }\n                    None => {\n                        *quants_vector_sizes = Some(vector_sizes);\n                    }\n                }\n            }\n            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {\n                unreachable!(\"Can't store native sub-byte values\")\n            }\n        },\n        QuantStore::PackedU32(_) => {\n            let mut vector_sizes = client\n                .io_optimized_vector_sizes(size_of::<u32>())\n                .collect::<Vec<_>>();\n            for val in vector_sizes.iter_mut() {\n                *val *= scheme.num_quants();\n            }\n\n            match &quants_vector_sizes {\n                Some(sizes) => {\n                    if sizes[0] < vector_sizes[0] {\n                        let mut min = *vector_sizes.last().unwrap();\n\n                        while min > 1 {\n                            min /= 2;\n                            vector_sizes.push(min);\n                        }\n                        *quants_vector_sizes = Some(vector_sizes);\n                    }\n                }\n                None => {\n                    *quants_vector_sizes = Some(vector_sizes);\n                }\n            }\n        }\n        QuantStore::PackedNative(_) => {\n            panic!(\"Not yet supported\")\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/mod.rs",
    "content": "pub(crate) mod codegen;\npub(crate) mod fuser;\npub(crate) mod launch;\npub(crate) mod scoring;\npub(crate) mod settings;\n\npub mod trace;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/scoring.rs",
    "content": "use crate::engine::{\n    codegen::ir::{FuseArg, FuseOp, UnaryFuseArgs},\n    trace::FuseTrace,\n};\nuse burn_ir::OperationIr;\n\n#[derive(Debug, Clone, Default)]\n/// Tracks and evaluates the efficiency of operation fusion.\npub struct Scoring {\n    num_writes: usize,\n    num_reads: usize,\n    num_ops: usize,\n}\n\nimpl Scoring {\n    /// Resets the internal O counters.\n    pub fn reset(&mut self) {\n        self.num_writes = 0;\n        self.num_reads = 0;\n        self.num_ops = 0;\n    }\n\n    /// Registers an unfused operation to the score, counting its total potential I/O.\n    pub fn register(&mut self, op: &OperationIr) {\n        self.num_writes += op.outputs().count();\n        self.num_reads += op.inputs().count();\n        self.num_ops += 1;\n    }\n\n    /// Evaluates the efficiency of a fused trace by comparing its actual I/O\n    /// against the registered unfused I/O. Returns the number of saved I/O operations.\n    pub fn evaluate(&self, trace: &FuseTrace) -> u64 {\n        let mut num_reads_fused = 0;\n        let mut num_writes_fused = 0;\n        let mut num_penalty = 0;\n\n        for b in trace.blocks.iter() {\n            // Count reads in block\n            for (_, ops) in b.reads.iter() {\n                let result = self.count_fused_io(ops, |args| &args.input);\n                num_reads_fused += result.0;\n                num_penalty += result.1;\n            }\n            // Count writes in block\n            for (_, ops) in b.writes.iter() {\n                let result = self.count_fused_io(ops, |args| &args.out);\n                num_writes_fused += result.0;\n                num_penalty += result.1;\n            }\n        }\n\n        self.calculate_score(num_reads_fused, num_writes_fused, num_penalty)\n    }\n\n    fn calculate_score(&self, reads_fused: usize, writes_fused: usize, num_penalty: usize) -> u64 {\n        // Those could be tweaked eventually.\n\n        const FACTOR_IO: u64 = 100;\n        const FACTOR_LAUNCH: u64 = 10;\n        const FACTOR_PENALTY: u64 = 50;\n\n        let num_fused = reads_fused + writes_fused;\n        let num_unfused = self.num_reads + self.num_writes;\n\n        let score_io = match num_fused >= num_unfused {\n            true => 0,\n            false => (num_unfused - num_fused) as u64 * FACTOR_IO,\n        };\n\n        // We minus 1 since at least one kernel launch is necessary.\n        let score_launch = self.num_ops.saturating_sub(1) as u64 * FACTOR_LAUNCH;\n\n        let score_penalty = num_penalty as u64 * FACTOR_PENALTY;\n\n        (score_io + score_launch).saturating_sub(score_penalty)\n    }\n\n    fn count_fused_io<F>(&self, ops: &[FuseOp], arg_extractor: F) -> (usize, usize)\n    where\n        F: Fn(&UnaryFuseArgs) -> &FuseArg,\n    {\n        let mut num_io = 0;\n        let mut penalty = 0;\n\n        for op in ops.iter() {\n            let FuseOp::Assign(args) = op else {\n                unreachable!()\n            };\n            let count_normal = matches!(\n                arg_extractor(args),\n                FuseArg::Input(..) | FuseArg::Output(..)\n            ) as usize;\n            let count_view = matches!(\n                arg_extractor(args),\n                FuseArg::InputReshaped { .. } | FuseArg::InputSwapDims { .. }\n            ) as usize;\n            num_io += count_normal + count_view;\n            penalty += count_view;\n        }\n\n        (num_io, penalty)\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::field_reassign_with_default)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_scoring_io_savings() {\n        let mut scoring = Scoring::default();\n        scoring.num_reads = 2;\n        scoring.num_writes = 2;\n        scoring.num_ops = 2;\n\n        let score = scoring.calculate_score(1, 1, 0);\n        assert_eq!(score, 210);\n    }\n\n    #[test]\n    fn test_scoring_with_penalties() {\n        let mut scoring = Scoring::default();\n        scoring.num_reads = 2;\n        scoring.num_writes = 2;\n        scoring.num_ops = 2;\n\n        let score = scoring.calculate_score(1, 1, 1);\n        assert_eq!(score, 160);\n    }\n\n    #[test]\n    fn test_penalty_outweighs_benefit() {\n        let mut scoring = Scoring::default();\n        scoring.num_reads = 1;\n        scoring.num_writes = 1;\n        scoring.num_ops = 2;\n\n        let score = scoring.calculate_score(1, 1, 1);\n        assert_eq!(score, 0);\n    }\n\n    #[test]\n    fn test_scoring_no_ops() {\n        let scoring = Scoring::default();\n        let score = scoring.calculate_score(0, 0, 0);\n        assert_eq!(score, 0);\n    }\n\n    #[test]\n    fn test_reset() {\n        let mut scoring = Scoring {\n            num_writes: 10,\n            num_reads: 10,\n            num_ops: 10,\n        };\n        scoring.reset();\n        assert_eq!(scoring.num_writes, 0);\n        assert_eq!(scoring.num_reads, 0);\n        assert_eq!(scoring.num_ops, 0);\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/settings.rs",
    "content": "use serde::{Deserialize, Serialize};\n\n/// Controls which operations can be fused.\n#[derive(Clone, Copy, Debug, Serialize, Deserialize)]\npub struct FuseSettings {\n    /// Enables broadcasting of shapes.\n    pub broadcast: bool,\n    /// Enables output shape updates.\n    ///\n    /// When broadcast is enabled, the output shape can become bigger after a fusion,\n    /// therefore an update is needed.\n    pub output_shape_updates: bool,\n    /// Enables the reuse of input buffers.\n    pub inplace: bool,\n    /// Whether vectorization is enabled.\n    pub vectorization: VectorizationSetting,\n    /// How [reference layout](super::ir::RefLayout) selection is done.\n    pub ref_layout: RefLayoutSetting,\n}\n\nimpl Default for FuseSettings {\n    fn default() -> Self {\n        Self {\n            broadcast: true,\n            output_shape_updates: true,\n            inplace: true,\n            vectorization: VectorizationSetting::Activated,\n            ref_layout: RefLayoutSetting::Any,\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, Serialize, Deserialize)]\n/// How vectorization is handled during fusion.\npub enum VectorizationSetting {\n    /// The biggest vector_size possible will be used.\n    Activated,\n    /// Equivalent to using vector_size of one.\n    Deactivated,\n    /// This is a good setting when a block processes values calculated from a previous block.\n    SmallerOrEqualThanPreviousBlock { block_pos: usize },\n    /// This is a good setting when a block processes values calculated from a previous block.\n    EqualThanPreviousBlock { block_pos: usize },\n}\n\n#[derive(Clone, Copy, Debug, Serialize, Deserialize)]\n/// Influence how the [reference layout](super::ir::RefLayout) selection is done.\npub enum RefLayoutSetting {\n    /// Any reference layout is allowed.\n    Any,\n    /// Only contiguous reference layout is allowed.\n    ///\n    /// Note that forcing a contiguous reference layout might reduce the opportunity of inplace\n    /// fusion.\n    OnlyContiguous,\n    SameAsBlock {\n        block_pos: u32,\n    },\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/trace/base.rs",
    "content": "use crate::engine::{\n    codegen::ir::{FuseArg, FuseType},\n    trace::block::FuseBlock,\n};\nuse burn_ir::{TensorId, TensorIr};\nuse burn_std::{Shape, Strides};\nuse cubecl::prelude::*;\nuse serde::{Deserialize, Serialize};\nuse std::{\n    collections::{BTreeMap, HashSet},\n    marker::PhantomData,\n};\n\n#[cfg(feature = \"autotune-checks\")]\nuse crate::CubeFusionHandle;\n#[cfg(feature = \"autotune-checks\")]\nuse burn_backend::TensorData;\n#[cfg(feature = \"autotune-checks\")]\nuse std::collections::HashMap;\n\n#[derive(Clone, Serialize, Deserialize, Debug)]\n/// A trace contains all [blocks](FuseBlock) and the [resources](FuseResources) used by the\n/// kernel.\npub struct FuseTrace {\n    pub blocks: Vec<FuseBlock>,\n    pub resources: FuseResources,\n}\n\nimpl core::fmt::Display for FuseTrace {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        writeln!(f, \"FuseTrace\")?;\n        for b in self.blocks.iter() {\n            writeln!(f, \" - Block shape={:?}\", b.shape_ref)?;\n            for (tensor, ops) in b.reads.iter() {\n                for op in ops.iter() {\n                    writeln!(f, \"   - {op} <== {tensor}\")?;\n                }\n            }\n            for op in b.ops.iter() {\n                writeln!(f, \"   - {op}\")?;\n            }\n            for (tensor, ops) in b.writes.iter() {\n                for op in ops.iter() {\n                    writeln!(f, \"   - {op} <== {tensor}\")?;\n                }\n            }\n        }\n\n        Ok(())\n    }\n}\n\npub enum TuneOutput<R: Runtime> {\n    UnChecked(PhantomData<R>),\n    #[cfg(feature = \"autotune-checks\")]\n    Checked {\n        handles: HashMap<TensorId, (Vec<usize>, CubeFusionHandle<R>)>,\n    },\n}\n\nimpl<R: Runtime> TuneOutput<R> {\n    #[allow(unused_variables)]\n    pub fn merge(self, other: Self) -> Self {\n        let mut result = self;\n\n        match &mut result {\n            TuneOutput::UnChecked(..) => {}\n            #[cfg(feature = \"autotune-checks\")]\n            TuneOutput::Checked { handles } => match other {\n                TuneOutput::UnChecked(..) => {}\n                TuneOutput::Checked { handles: o } => {\n                    for (k, v) in o.into_iter() {\n                        handles.insert(k, v);\n                    }\n                }\n            },\n        }\n\n        result\n    }\n}\n\nimpl<R: Runtime> cubecl::tune::AutotuneOutput for TuneOutput<R> {\n    #[cfg(feature = \"autotune-checks\")]\n    fn check_equivalence(&self, other: Self) {\n        use burn_backend::Tolerance;\n        use burn_std::DType;\n\n        if let (\n            TuneOutput::Checked {\n                handles: handles_ref,\n            },\n            TuneOutput::Checked { handles },\n        ) = (self, &other)\n        {\n            let mut num_checked = 0;\n            let mut num_handles = 0;\n            for (id, (shape, handle)) in handles_ref.iter() {\n                num_handles += 1;\n                if let Some((shape_other, other)) = handles.get(id) {\n                    use burn_std::is_contiguous;\n                    use cubecl::std::tensor::into_contiguous_ref;\n\n                    let current_handle = if !is_contiguous(&shape, &handle.strides) {\n                        into_contiguous_ref::<R>(\n                            &handle.client,\n                            &handle.as_handle_ref(&shape),\n                            handle.dtype.into(),\n                        )\n                        .unwrap()\n                        .handle\n                    } else {\n                        handle.handle.clone()\n                    };\n                    let other_handle = if !is_contiguous(&shape, &other.strides) {\n                        into_contiguous_ref::<R>(\n                            &other.client,\n                            &other.as_handle_ref(&shape),\n                            other.dtype.into(),\n                        )\n                        .unwrap()\n                        .handle\n                    } else {\n                        other.handle.clone()\n                    };\n\n                    let data_ref = handle.client.read_one(current_handle);\n                    let data_other = other.client.read_one(other_handle);\n                    let data_ref = TensorData::from_bytes(data_ref, shape.clone(), handle.dtype);\n                    let data_other =\n                        TensorData::from_bytes(data_other, shape_other.clone(), handle.dtype);\n\n                    match handle.dtype {\n                        DType::F64 => {\n                            data_ref.assert_approx_eq::<f64>(&data_other, Tolerance::permissive())\n                        }\n                        DType::F32 => {\n                            data_ref.assert_approx_eq::<f32>(&data_other, Tolerance::permissive())\n                        }\n                        DType::F16 => data_ref\n                            .assert_approx_eq::<half::f16>(&data_other, Tolerance::permissive()),\n                        DType::BF16 => data_ref\n                            .assert_approx_eq::<half::bf16>(&data_other, Tolerance::permissive()),\n                        _ => data_ref.assert_eq(&data_other, true),\n                    }\n                    num_checked += 1;\n                } else {\n                    // Debug info for the tests.\n                    println!(\"No tensor found for {id:?}=>{shape:?}\");\n                }\n            }\n\n            // At least one check is needed per output when there is an output.\n            //\n            // Some optimizations might write more outputs than needed, so it might be fined if\n            // the number of handles is different, but at least one is required.\n            //\n            // An optimization might not create outputs if its dead code detection is triggered,\n            // therefore avoiding useless computation.\n            if num_handles > 0 {\n                assert!(num_checked >= 1);\n            }\n        }\n    }\n}\n\n#[derive(Clone, Serialize, Deserialize, Debug, Default)]\n/// Declare all resources used by the kernel, and potentially multiple [blocks](FuseBlock).\n///\n/// # Notes\n///\n/// Each block can't contain their own resources, since they are shared between blocks. The\n/// vectorization factor of one input tensor must be the same for all blocks.\npub struct FuseResources {\n    pub outputs: RegisteredTensors,\n    pub inputs: RegisteredTensors,\n    pub scalars: Vec<(FuseType, u64)>,\n    // TODO: Making put a map of global registers.\n    pub views: Vec<TensorView>,\n    pub indexed: BTreeMap<TensorId, FuseArg>,\n    pub inputs_unhandled: Vec<TensorId>,\n    pub outputs_unhandled: Vec<FuseArg>,\n    pub num_reshaped: usize,\n    /// Necessary to remove some entries from the context.\n    pub dropped: HashSet<TensorId>,\n    /// We know during fusion that we have to have those buffers has global.\n    /// The pos here can be interpreted as GLOBAL pos where the output pos are locals.\n    pub buffers: RegisteredTensors,\n    /// Global registers available everywhere.\n    ///\n    /// TODO: Not all registers should be globals.\n    pub registers: BTreeMap<TensorId, FuseArg>,\n}\n\n#[derive(Clone, Serialize, Deserialize, Debug)]\npub struct RuntimeLayout {\n    pub shape: Shape,\n    pub strides: Strides,\n}\n\nimpl Default for RuntimeLayout {\n    fn default() -> Self {\n        Self {\n            shape: Shape::new([]),\n            strides: Strides::new(&[]),\n        }\n    }\n}\n\n#[derive(Debug)]\npub enum TraceError<Err> {\n    ReferenceNotFound,\n    RunnerError(Err),\n}\n\n#[derive(Clone, Serialize, Deserialize, Debug)]\npub enum TensorView {\n    Reshape {\n        reshaped: TensorId,\n        original: TensorId,\n        reshape_pos: usize,\n        shape_relative: Shape,\n    },\n    SwapDims {\n        swapped: TensorId,\n        original: TensorId,\n        dims: (usize, usize),\n    },\n}\n\n#[derive(Default, Clone, Serialize, Deserialize, Debug)]\npub struct RegisteredTensors {\n    tensors: Vec<RegisterTensor>,\n}\n\n#[derive(Clone, Serialize, Deserialize, Debug)]\npub enum RegisterTensor {\n    Normal(TensorIr, FuseType),\n    QuantValues(TensorIr),\n    QuantParams(TensorId),\n}\n\nimpl RegisterTensor {\n    pub fn as_normal_tensor(&self) -> Option<(&TensorIr, &FuseType)> {\n        match self {\n            RegisterTensor::Normal(tensor_ir, precision) => Some((tensor_ir, precision)),\n            RegisterTensor::QuantValues(_) => None,\n            RegisterTensor::QuantParams(_) => None,\n        }\n    }\n}\n\nimpl RegisteredTensors {\n    /// Iterate over all the registered tensors.\n    pub fn iter(&self) -> impl Iterator<Item = &RegisterTensor> {\n        self.tensors.iter()\n    }\n\n    /// Consumes and iterate over all the registered tensors.\n    pub fn into_iter(self) -> impl Iterator<Item = RegisterTensor> {\n        self.tensors.into_iter()\n    }\n\n    /// Returns the number of tensors registered.\n    pub fn len(&self) -> usize {\n        self.tensors.len()\n    }\n\n    /// Retrieve the [tensor id](TensorId) at the given index.\n    pub fn get_id(&self, index: usize) -> Option<TensorId> {\n        self.tensors.get(index).map(|entry| match entry {\n            RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id,\n            RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id,\n            RegisterTensor::QuantParams(tensor_id) => *tensor_id,\n        })\n    }\n\n    /// Doesn't return quantized tensor.\n    pub fn get_index(&self, tensor_id: TensorId) -> Option<usize> {\n        self.tensors\n            .iter()\n            .enumerate()\n            .find(|(_pos, entry)| match entry {\n                RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,\n                RegisterTensor::QuantValues(_) => false,\n                RegisterTensor::QuantParams(_) => false,\n            })\n            .map(|(pos, _)| pos)\n    }\n\n    /// Get the index of a quantized tensor.\n    pub fn get_index_quant(&self, tensor_id: TensorId) -> Option<usize> {\n        self.tensors\n            .iter()\n            .enumerate()\n            .find(|(_pos, entry)| match entry {\n                RegisterTensor::Normal(..) => false,\n                RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id,\n                RegisterTensor::QuantParams(_) => false,\n            })\n            .map(|(pos, _)| pos)\n    }\n\n    /// Doesn't return quantized tensor.\n    pub fn get(&self, tensor_id: TensorId) -> Option<(&TensorIr, &FuseType)> {\n        self.tensors\n            .iter()\n            .find(|entry| match entry {\n                RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,\n                RegisterTensor::QuantValues(_) => false,\n                RegisterTensor::QuantParams(_) => false,\n            })\n            .and_then(|entry| match entry {\n                RegisterTensor::Normal(tensor_ir, fuse_precision) => {\n                    Some((tensor_ir, fuse_precision))\n                }\n                RegisterTensor::QuantValues(_) => None,\n                RegisterTensor::QuantParams(_) => None,\n            })\n    }\n\n    /// Insert a quantized tensor.\n    ///\n    /// It will return the positions for both the value tensor and param tensor.\n    pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) {\n        if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {\n            RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor,\n            _ => false,\n        }) {\n            let values = old.0;\n            let params = values + 1;\n            return (values, params);\n        }\n\n        let params = RegisterTensor::QuantParams(tensor.id);\n        let values = RegisterTensor::QuantValues(tensor);\n        let pos_values = self.len();\n        self.tensors.push(values);\n\n        let pos_params = self.len();\n        self.tensors.push(params);\n\n        (pos_values, pos_params)\n    }\n\n    /// Insert a normal tensor with the given [precision](FusePrecision) in the current block.\n    pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize {\n        if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {\n            RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,\n            _ => false,\n        }) {\n            return old.0;\n        }\n\n        let value = RegisterTensor::Normal(tensor, precision);\n        let pos = self.len();\n\n        self.tensors.push(value);\n\n        pos\n    }\n\n    /// Update the already registered tensor with the given [tensor ir](TensorIr).\n    ///\n    /// # Notes\n    ///\n    /// This function only works with normal tensors, not quantized tensors.\n    pub fn update(&mut self, tensor: &TensorIr) {\n        if let Some(entry) = self.tensors.iter_mut().find(|entry| match entry {\n            RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor.id,\n            _ => false,\n        }) && let RegisterTensor::Normal(tensor_ir, _) = entry\n        {\n            tensor_ir.status = tensor.status\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/trace/block.rs",
    "content": "use super::{FuseResources, RegisteredTensors, TensorView};\nuse crate::engine::{\n    codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo, MultiBlockPos, UnaryFuseArgs},\n    settings::FuseSettings,\n};\nuse burn_ir::{TensorId, TensorIr, TensorStatus};\nuse burn_std::{DType, Shape, quantization::QuantParam};\nuse serde::{Deserialize, Serialize};\nuse std::collections::{BTreeMap, btree_map::Entry};\n\n#[derive(Clone, Serialize, Deserialize, Debug)]\n/// A block containing all [operations](FuseOp) as well as reads and writes for each tensor along\n/// with the [fusion settings](FuseSettings).\npub struct FuseBlock {\n    /// Contains the [fusion settings](FuseSettings) associated to the current block.\n    pub settings: FuseSettings,\n    /// Contains all the [operations](FuseOp) registered in the current block.\n    pub ops: Vec<FuseOp>,\n    /// The reference shape of the current block.\n    pub shape_ref: Shape,\n    /// Contains all tensor inputs of the current block except for manually handled tensors.\n    ///\n    /// # Notes\n    ///\n    /// Some reads might not have read operations registered, such as dequantization, but it's\n    /// important to be registered here for vectorization. Input tensors that are not\n    /// registered here must be vectorized manually.\n    pub reads: BTreeMap<TensorId, Vec<FuseOp>>,\n    /// Contains all tensor outputs of the current block except for manually handled tensors.\n    /// We can have multiple writes when the same variable is reused after in another block.\n    pub writes: BTreeMap<TensorId, Vec<FuseOp>>,\n}\n\n#[derive(Clone, Debug)]\n/// It is responsible to build a [trace](FuseBlock).\npub struct FuseBlockBuilder {\n    pub settings: FuseSettings,\n    locals: LocalVariablePool,\n    pub ops: Vec<FuseOp>,\n    reads: BTreeMap<TensorId, Vec<FuseOp>>,\n    // Only for global registers.\n    writes: BTreeMap<TensorId, Vec<FuseOp>>,\n    // Output declared in this block alone.\n    outputs: RegisteredTensors,\n    pub outputs_unhandled: Vec<FuseArg>,\n    pub local_inputs: BTreeMap<TensorId, FuseArg>,\n    /// The reference shape used by this block.\n    pub shape_ref: Shape,\n}\n\n#[derive(Debug)]\n/// How a quantized input can be read.\npub enum QuantInput {\n    /// If already dequantized, we cache the dequantization and returns the local variable\n    /// corresponding to the float value.\n    AlreadyDequantized { local: FuseArg },\n    /// Otherwise we return the information necessary to dequantize the tensor.\n    Quantized { values: FuseArg, params: FuseArg },\n}\n\nimpl FuseBlockBuilder {\n    pub fn new(settings: FuseSettings) -> Self {\n        Self {\n            settings,\n            locals: Default::default(),\n            ops: Default::default(),\n            reads: Default::default(),\n            writes: Default::default(),\n            outputs: Default::default(),\n            outputs_unhandled: Default::default(),\n            local_inputs: Default::default(),\n            shape_ref: Shape::new([]),\n        }\n    }\n\n    /// Register an output tensor.\n    pub fn output(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {\n        if resources.indexed.contains_key(&tensor.id) {\n            return None;\n        }\n        if matches!(tensor.dtype, DType::QFloat(..)) {\n            return None;\n        }\n        let precision = tensor.dtype.into();\n\n        let out = match self.locals.get(precision, tensor.id) {\n            Some(local) => local,\n            None => {\n                let out = self.locals.create(precision, tensor.id);\n\n                self.outputs.insert(precision, tensor.clone());\n                resources.outputs.insert(precision, tensor.clone());\n\n                out\n            }\n        };\n\n        Some(out)\n    }\n\n    /// Register an input tensor.\n    pub fn multi_block_variable(\n        &mut self,\n        block_pos: usize,\n        tensor: &TensorIr,\n        global: bool,\n    ) -> Option<FuseArg> {\n        let precision = tensor.dtype.into();\n\n        if let Some(val) = self.local_inputs.get(&tensor.id) {\n            return Some(val.clone());\n        }\n\n        let val = match self.locals.get(precision, tensor.id) {\n            Some(val) => val,\n            None => {\n                return None;\n            }\n        };\n\n        let arg = if global {\n            FuseArg::MultiBlockGlobal(\n                MultiBlockPos {\n                    block_pos,\n                    block_local_pos: self.writes.len(),\n                },\n                val.precision(),\n            )\n        } else {\n            FuseArg::MultiBlockLocal(\n                MultiBlockPos {\n                    block_pos,\n                    block_local_pos: self.writes.len(),\n                },\n                val.precision(),\n            )\n        };\n\n        let ops = match self.writes.get_mut(&tensor.id) {\n            Some(ops) => ops,\n            None => {\n                self.writes.insert(tensor.id, Vec::new());\n                self.writes.get_mut(&tensor.id).unwrap()\n            }\n        };\n        ops.push(FuseOp::Assign(UnaryFuseArgs {\n            input: val,\n            out: arg.clone(),\n        }));\n\n        Some(arg)\n    }\n\n    /// Register an input tensor.\n    pub fn input(&mut self, tensor: &TensorIr, resources: &mut FuseResources) -> Option<FuseArg> {\n        if resources.indexed.contains_key(&tensor.id) {\n            return None;\n        }\n\n        if matches!(tensor.dtype, DType::QFloat(..)) {\n            return None;\n        }\n        let precision = tensor.dtype.into();\n\n        if let Some(val) = self.local_inputs.get(&tensor.id) {\n            return Some(val.clone());\n        }\n\n        let arg = match self.locals.get(precision, tensor.id) {\n            Some(local) => {\n                resources.inputs.update(tensor);\n\n                local\n            }\n            None => {\n                let input = if resources.outputs.get_index(tensor.id).is_some() {\n                    if let Some(val) = resources.registers.get(&tensor.id) {\n                        return Some(val.clone());\n                    };\n\n                    let pos = resources.buffers.insert(precision, tensor.clone());\n                    FuseArg::Output(pos, precision, LayoutInfo::Unknown)\n                } else {\n                    let pos = resources.inputs.insert(precision, tensor.clone());\n                    FuseArg::Input(pos, precision, LayoutInfo::Unknown)\n                };\n\n                let out = self.locals.create(precision, tensor.id);\n\n                let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {\n                    e.insert(Vec::with_capacity(1));\n                    self.reads.get_mut(&tensor.id).unwrap()\n                } else {\n                    self.reads.get_mut(&tensor.id).unwrap()\n                };\n\n                reads.push(FuseOp::Assign(UnaryFuseArgs {\n                    input,\n                    out: out.clone(),\n                }));\n\n                out\n            }\n        };\n\n        Some(arg)\n    }\n\n    /// Register an input quantized tensor.\n    pub fn input_quant(\n        &mut self,\n        tensor: &TensorIr,\n        resources: &mut FuseResources,\n    ) -> Option<QuantInput> {\n        if resources.indexed.contains_key(&tensor.id) {\n            return None;\n        }\n\n        let precision = tensor.dtype.into();\n        let precision_scales = match tensor.dtype {\n            DType::QFloat(scheme) => match scheme.param {\n                QuantParam::F32 => FuseType::F32,\n                QuantParam::F16 => FuseType::F16,\n                QuantParam::BF16 => FuseType::BF16,\n                QuantParam::UE8M0 | QuantParam::UE4M3 => {\n                    unimplemented!(\"Unsupported fuse precision\");\n                }\n            },\n            _ => return None,\n        };\n\n        let arg = match self.locals.get(precision, tensor.id) {\n            Some(local) => {\n                resources.inputs.update(tensor);\n                QuantInput::AlreadyDequantized { local }\n            }\n            None => {\n                let (new_input, q_index) = resources.inputs.insert_quant(tensor.clone());\n                let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown);\n                let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown);\n\n                // Important to flag that there is a read, even if no operation is registered.\n                if let Entry::Vacant(e) = self.reads.entry(tensor.id) {\n                    e.insert(Vec::new());\n                };\n\n                QuantInput::Quantized {\n                    values: input,\n                    params: scales,\n                }\n            }\n        };\n\n        Some(arg)\n    }\n\n    /// Register an input with swapped dims.\n    pub fn input_swap_dims(\n        &mut self,\n        tensor: &TensorIr,\n        output: &TensorIr,\n        dims: (usize, usize),\n        resources: &mut FuseResources,\n    ) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(..)) {\n            return None;\n        }\n        let precision = tensor.dtype.into();\n\n        let input_index = match self.locals.get(precision, tensor.id) {\n            Some(_) => {\n                // Can't fused an already fused input.\n                if resources.outputs.get(tensor.id).is_some() {\n                    return None;\n                }\n\n                match resources.inputs.get_index(tensor.id) {\n                    Some(index) => {\n                        resources.inputs.update(tensor);\n                        index\n                    }\n                    None => {\n                        return None;\n                    }\n                }\n            }\n            None => resources.inputs.insert(precision, tensor.clone()),\n        };\n\n        let out = self.output(output, resources)?;\n        let original = FuseArg::Input(input_index, precision, LayoutInfo::Unknown);\n\n        let broadcasted = output.shape[output.shape.rank() - 1] == 0;\n\n        resources.views.push(TensorView::SwapDims {\n            swapped: output.id,\n            original: tensor.id,\n            dims,\n        });\n\n        let input = FuseArg::InputSwapDims {\n            original: Box::new(original),\n            dims,\n            broadcasted,\n        };\n\n        let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {\n            e.insert(Vec::with_capacity(1));\n            self.reads.get_mut(&tensor.id).unwrap()\n        } else {\n            self.reads.get_mut(&tensor.id).unwrap()\n        };\n\n        reads.push(FuseOp::Assign(UnaryFuseArgs {\n            input,\n            out: out.clone(),\n        }));\n\n        Some(out)\n    }\n\n    /// Register an input that is reshaped.\n    pub fn input_reshaped(\n        &mut self,\n        tensor: &TensorIr,\n        output: &TensorIr,\n        resources: &mut FuseResources,\n    ) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(..)) {\n            return None;\n        }\n        let precision = tensor.dtype.into();\n\n        let input_index = match self.locals.get(precision, tensor.id) {\n            Some(_) => {\n                // Can't fused an already fused input.\n                if resources.outputs.get(tensor.id).is_some() {\n                    return None;\n                }\n\n                match resources.inputs.get_index(tensor.id) {\n                    Some(index) => {\n                        resources.inputs.update(tensor);\n                        index\n                    }\n                    None => {\n                        return None;\n                    }\n                }\n            }\n            None => resources.inputs.insert(precision, tensor.clone()),\n        };\n\n        let out = self.output(output, resources)?;\n        let original = FuseArg::Input(input_index, precision, LayoutInfo::Unknown);\n\n        let mut shape = Vec::new();\n\n        let index = resources.num_reshaped;\n        resources.num_reshaped += 1;\n\n        let rank = output.shape.rank();\n\n        for i in 0..output.shape.rank() {\n            let id = index * rank + i;\n            shape.push(FuseArg::ScalarShape(id));\n        }\n\n        resources.views.push(TensorView::Reshape {\n            reshaped: output.id,\n            original: tensor.id,\n            reshape_pos: index,\n            shape_relative: output.shape.clone(),\n        });\n\n        let input = FuseArg::InputReshaped {\n            original: Box::new(original),\n            shape,\n            broadcasted: output.shape[rank - 1] == 0,\n        };\n\n        let reads = if let Entry::Vacant(e) = self.reads.entry(tensor.id) {\n            e.insert(Vec::with_capacity(1));\n            self.reads.get_mut(&tensor.id).unwrap()\n        } else {\n            self.reads.get_mut(&tensor.id).unwrap()\n        };\n\n        reads.push(FuseOp::Assign(UnaryFuseArgs {\n            input,\n            out: out.clone(),\n        }));\n\n        Some(out)\n    }\n\n    /// Build into a fuse block.\n    pub fn build(\n        &self,\n        resources: &FuseResources,\n        outputs: &mut RegisteredTensors,\n        buffers: &mut Vec<TensorId>,\n    ) -> FuseBlock {\n        let ops = self.ops.clone();\n        let reads = self.reads.clone();\n        let tensor_writes = self.tensor_writes(resources, buffers);\n\n        let mut writes = self.writes.clone();\n\n        for (tensor, precision) in tensor_writes\n            .iter()\n            .filter_map(|entry| entry.as_normal_tensor())\n        {\n            if let Some(local) = self.locals.get_any_precision(tensor.id) {\n                let out_index = outputs.insert(*precision, tensor.clone());\n\n                let ops = match writes.get_mut(&tensor.id) {\n                    Some(ops) => ops,\n                    None => {\n                        writes.insert(tensor.id, Vec::new());\n                        writes.get_mut(&tensor.id).unwrap()\n                    }\n                };\n\n                ops.push(FuseOp::Assign(UnaryFuseArgs {\n                    input: local,\n                    out: FuseArg::Output(out_index, *precision, LayoutInfo::Unknown),\n                }));\n            }\n        }\n\n        FuseBlock {\n            settings: self.settings,\n            ops,\n            shape_ref: self.shape_ref.clone(),\n            reads,\n            writes,\n        }\n    }\n\n    /// Return the tensor that needs to be written to.\n    ///\n    /// # Notes\n    ///\n    /// The buffers vector passed as input is only to track the intermediary buffer writes needed\n    /// during execution.\n    pub fn tensor_writes(\n        &self,\n        resources: &FuseResources,\n        buffers: &mut Vec<TensorId>,\n    ) -> RegisteredTensors {\n        let mut result = RegisteredTensors::default();\n\n        // All tensors where their latest representation is not read write should be written to since they\n        // are going to be used after the fused kernel by other operations.\n        for output in self.outputs.iter() {\n            if let Some((tensor, _precision)) = output.as_normal_tensor() {\n                // We get the latest representation from the resources, not just this block.\n                if let Some((tensor, precision)) = resources.outputs.get(tensor.id) {\n                    if !matches!(tensor.status, TensorStatus::ReadWrite) {\n                        result.insert(*precision, tensor.clone());\n                    } else if resources.buffers.get(tensor.id).is_some()\n                        && !buffers.contains(&tensor.id)\n                    {\n                        result.insert(*precision, tensor.clone());\n                        // We make sure we don't write multiple time in the same buffer, only the\n                        // earliest possible.\n                        buffers.push(tensor.id);\n                    }\n                }\n            }\n        }\n\n        result\n    }\n}\n\n#[derive(Default, Clone, Debug)]\npub struct LocalVariablePool {\n    values: BTreeMap<FuseType, BTreeMap<TensorId, usize>>,\n}\n\nimpl LocalVariablePool {\n    fn get(&self, precision: FuseType, tensor_id: TensorId) -> Option<FuseArg> {\n        if let Some(indexes) = self.values.get(&precision)\n            && let Some(index) = indexes.get(&tensor_id)\n        {\n            return Some(FuseArg::BlockLocal {\n                pos: *index,\n                ty: precision,\n            });\n        }\n\n        None\n    }\n\n    fn get_any_precision(&self, tensor_id: TensorId) -> Option<FuseArg> {\n        for (precision, indexes) in self.values.iter() {\n            if let Some(index) = indexes.get(&tensor_id) {\n                return Some(FuseArg::BlockLocal {\n                    pos: *index,\n                    ty: *precision,\n                });\n            }\n        }\n\n        None\n    }\n\n    fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg {\n        if let Some(indexes) = self.values.get_mut(&precision) {\n            let new_index = indexes.len();\n            indexes.insert(tensor_id, new_index);\n            return FuseArg::BlockLocal {\n                pos: new_index,\n                ty: precision,\n            };\n        }\n\n        let new_index = 0;\n        self.values\n            .insert(precision, BTreeMap::from_iter([(tensor_id, new_index)]));\n\n        FuseArg::BlockLocal {\n            pos: new_index,\n            ty: precision,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/trace/fuser.rs",
    "content": "use super::{\n    super::{\n        codegen::ir::{FuseArg, FuseOp, FuseType, LayoutInfo},\n        settings::FuseSettings,\n    },\n    FuseResources,\n    block::FuseBlockBuilder,\n};\nuse super::{FuseTrace, RegisteredTensors};\nuse crate::engine::trace::block::QuantInput;\nuse burn_fusion::stream::ScalarId;\nuse burn_ir::{ScalarIr, TensorIr};\nuse burn_std::{DType, Shape};\nuse cubecl::quant::scheme::QuantParam;\n\n#[derive(Clone, Debug)]\n/// It is responsible to create a [trace](FuseTrace) composed of multiple [blocks](super::block::FuseBlock).\n///\n/// It mostly handles the [resources](KernelResources) needed by the generated fused kernel, and\n/// delegates most of the work to the [block builder](FuseBlockBuilder).\npub struct TraceFuser {\n    settings: FuseSettings,\n    // The tensors returned by the block that don't need to be written to global memory.\n    block_current: FuseBlockBuilder,\n    blocks_previous: Vec<FuseBlockBuilder>,\n    resources: FuseResources,\n}\n\nimpl TraceFuser {\n    /// Create a new trace builder with the given bool precision and [fuse settings](FuseSettings).\n    pub fn new(settings: FuseSettings) -> Self {\n        Self {\n            settings,\n            block_current: FuseBlockBuilder::new(settings),\n            blocks_previous: Default::default(),\n            resources: Default::default(),\n        }\n    }\n\n    /// Get the number of blocks that are closed.\n    pub fn num_previous_blocks(&self) -> usize {\n        self.blocks_previous.len()\n    }\n\n    /// Tag a tensor as dropped.\n    pub fn fuse_dropped(&mut self, tensor: &TensorIr) {\n        self.resources.outputs.update(tensor);\n        self.resources.inputs.update(tensor);\n        self.resources.dropped.insert(tensor.id);\n    }\n\n    /// Register an operation.\n    pub fn fuse_operation(&mut self, op: FuseOp) {\n        self.block_current.ops.push(op);\n    }\n\n    /// The number of operations fused.\n    pub fn num_ops_fused(&self) -> u32 {\n        let mut num_ops_fused = 0;\n\n        for block in self.blocks_previous.iter() {\n            num_ops_fused += block.ops.len();\n        }\n\n        num_ops_fused += self.block_current.ops.len();\n        num_ops_fused as u32\n    }\n\n    /// Close the current block with the given reference shape and creates a new one with new [fusion settings](FuseSettings).\n    pub fn next_block(&mut self, shape_ref: Shape, settings: FuseSettings) {\n        let mut block_new = FuseBlockBuilder::new(settings);\n        core::mem::swap(&mut self.block_current, &mut block_new);\n        block_new.shape_ref = shape_ref;\n        self.blocks_previous.push(block_new);\n        self.settings = settings;\n    }\n\n    // Estimate how many bindings are in use right now. This can return more than the actual number\n    // but should never return less.\n    pub fn estimate_bindings(&self) -> u32 {\n        let mut buffers = Vec::new();\n        let mut estimation = 1; // Metadata takes one.\n\n        // We assume we are not going to write multiple times in the same output buffer.\n        for b in self.blocks_previous.iter() {\n            estimation += b.tensor_writes(&self.resources, &mut buffers).len() as u32;\n        }\n\n        estimation += self\n            .block_current\n            .tensor_writes(&self.resources, &mut buffers)\n            .len() as u32;\n        estimation += self.resources.inputs.len() as u32;\n        // One buffer per scalar type for now.\n        estimation += self.resources.scalars.len() as u32;\n\n        estimation\n    }\n\n    /// Tag the [tensor](TensorIr) as received from a previous block.\n    ///\n    /// This will avoid reading the input again and instead use le local version when possible.\n    pub fn block_local_input(\n        &mut self,\n        tensor: &TensorIr,\n        block_pos: usize,\n        global: bool,\n    ) -> FuseArg {\n        let block = &mut self.blocks_previous[block_pos];\n\n        let src_arg = match block.multi_block_variable(block_pos, tensor, global) {\n            Some(val) => val,\n            None => {\n                // We try to read the input if not present.\n                block.input(tensor, &mut self.resources);\n                block\n                    .multi_block_variable(block_pos, tensor, global)\n                    .unwrap()\n            }\n        };\n\n        self.resources.outputs.update(tensor);\n\n        if global {\n            self.resources.registers.insert(tensor.id, src_arg.clone());\n        }\n\n        self.block_current\n            .local_inputs\n            .insert(tensor.id, src_arg.clone());\n        src_arg\n    }\n\n    /// Register an output tensor that won't be automatically synced into global memory.\n    ///\n    /// It is therefore the responsibility of the operation to write the result to given tensor.\n    pub fn output_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {\n        let arg = self\n            .output(tensor)\n            .expect(\"Can't add a new output that is already used in an index operation\");\n\n        self.resources.outputs_unhandled.push(arg.clone());\n        self.block_current.outputs_unhandled.push(arg.clone());\n        arg\n    }\n\n    /// Register an input tensor that won't be automatically read into a local variable.\n    ///\n    /// It is therefore the responsibility of the operation to read the given tensor.\n    pub fn input_unhandled(&mut self, tensor: &TensorIr) -> FuseArg {\n        if self.resources.indexed.contains_key(&tensor.id) {\n            panic!(\"Can't add a new input that is already used in an index operation\");\n        }\n\n        self.resources.outputs.update(tensor);\n\n        let precision = tensor.dtype.into();\n        let new_input = self.resources.inputs.insert(precision, tensor.clone());\n        let arg = FuseArg::Input(new_input, precision, LayoutInfo::Unknown);\n\n        self.resources.inputs_unhandled.push(tensor.id);\n        arg\n    }\n\n    /// Register an input tensor.\n    pub fn input_quantized_unhandled(&mut self, tensor: &TensorIr) -> Option<(FuseArg, FuseArg)> {\n        if self.resources.indexed.contains_key(&tensor.id) {\n            panic!(\"Can't add a new input that is already used in an index operation\");\n        }\n        self.resources.outputs.update(tensor);\n\n        let precision = tensor.dtype.into();\n        let precision_scales = match tensor.dtype {\n            DType::QFloat(scheme) => match scheme.param {\n                QuantParam::F32 => FuseType::F32,\n                QuantParam::F16 => FuseType::F16,\n                QuantParam::BF16 => FuseType::BF16,\n                QuantParam::UE8M0 | QuantParam::UE4M3 => {\n                    unimplemented!(\"Unsupported fuse precision\");\n                }\n            },\n            _ => return None,\n        };\n\n        let (new_input, q_index) = self.resources.inputs.insert_quant(tensor.clone());\n        let input = FuseArg::Input(new_input, precision, LayoutInfo::Unknown);\n        let scales = FuseArg::Input(q_index, precision_scales, LayoutInfo::Unknown);\n\n        self.resources.inputs_unhandled.push(tensor.id);\n        Some((input, scales))\n    }\n\n    /// Register an input tensor.\n    pub fn input(&mut self, tensor: &TensorIr) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(_)) {\n            return None;\n        }\n\n        self.resources.outputs.update(tensor);\n\n        self.block_current.input(tensor, &mut self.resources)\n    }\n\n    /// Register an input tensor.\n    pub fn input_quantized(&mut self, tensor: &TensorIr) -> Option<QuantInput> {\n        self.resources.outputs.update(tensor);\n        self.block_current.input_quant(tensor, &mut self.resources)\n    }\n\n    /// Register an output tensor.\n    pub fn output(&mut self, tensor: &TensorIr) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(_)) {\n            return None;\n        }\n        self.block_current.output(tensor, &mut self.resources)\n    }\n\n    /// Register an input that will be accessed using custom indexing with no vectorization.\n    pub fn input_indexed(&mut self, tensor: &TensorIr) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(_)) {\n            return None;\n        }\n\n        if let Some(val) = self.resources.indexed.get(&tensor.id) {\n            self.resources.outputs.update(tensor);\n            return Some(val.clone());\n        };\n\n        if self.resources.inputs.get(tensor.id).is_some() {\n            return None;\n        }\n\n        if self.resources.outputs.get(tensor.id).is_some() {\n            return None;\n        }\n\n        let input = self.input_unhandled(tensor);\n        self.resources.indexed.insert(tensor.id, input.clone());\n\n        Some(input)\n    }\n\n    /// Register an input with swapped dims.\n    pub fn input_swap_dims(\n        &mut self,\n        tensor: &TensorIr,\n        output: &TensorIr,\n        dims: (usize, usize),\n    ) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(_)) {\n            return None;\n        }\n\n        self.resources.outputs.update(tensor);\n        self.block_current\n            .input_swap_dims(tensor, output, dims, &mut self.resources)\n    }\n\n    /// Register an input that is reshaped.\n    pub fn input_reshaped(&mut self, tensor: &TensorIr, output: &TensorIr) -> Option<FuseArg> {\n        if matches!(tensor.dtype, DType::QFloat(_)) {\n            return None;\n        }\n\n        self.resources.outputs.update(tensor);\n        self.block_current\n            .input_reshaped(tensor, output, &mut self.resources)\n    }\n\n    /// Register a scalar value.\n    pub fn scalar(&mut self, elem: &ScalarIr, dtype: DType) -> FuseArg {\n        let precision = dtype.into();\n        let id = if let ScalarIr::UInt(value) = elem {\n            ScalarId { value: *value }\n        } else {\n            unreachable!() // should always be u64\n        };\n\n        let new_index = self.resources.scalars.len();\n\n        self.resources.scalars.push((precision, id.value));\n        FuseArg::Scalar(new_index, precision)\n    }\n\n    /// Finish fusing and returns the created trace.\n    pub fn finish(&mut self, shape_ref: Shape) -> FuseTrace {\n        let mut resources = self.resources.clone();\n        let mut outputs = RegisteredTensors::default();\n        let mut buffers = Vec::new();\n\n        for tensor in resources.buffers.iter() {\n            let (tensor, ty) = tensor.as_normal_tensor().unwrap();\n            outputs.insert(*ty, tensor.clone());\n        }\n\n        let mut blocks = Vec::new();\n\n        let mut register_block = |block: &FuseBlockBuilder| {\n            let block = block.build(&self.resources, &mut outputs, &mut buffers);\n            blocks.push(block);\n        };\n\n        for block in self.blocks_previous.iter() {\n            register_block(block);\n        }\n        self.block_current.shape_ref = shape_ref;\n        register_block(&self.block_current);\n\n        // We update the output tensors registered to be the ones that are written to in global\n        // memory.\n        resources.outputs = outputs;\n\n        FuseTrace { blocks, resources }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/engine/trace/mod.rs",
    "content": "pub(crate) mod block;\n\nmod base;\nmod fuser;\n\npub use base::*;\npub use fuser::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/lib.rs",
    "content": "#[macro_use]\nextern crate derive_new;\n\npub mod optim;\n\nmod base;\n\npub(crate) mod engine;\npub(crate) mod tune;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/base.rs",
    "content": "use crate::optim::{\n    elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},\n    matmul::{MatmulOptimization, MatmulOptimizationState},\n    reduce::{ReduceOptimization, ReduceOptimizationState},\n    reduce_broadcasted::{ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationState},\n};\nuse cubecl::Runtime;\nuse serde::{Deserialize, Serialize};\n\n/// Fusion optimization type for cubecl.\n///\n/// More optimization variants should be added here.\n#[allow(clippy::large_enum_variant)]\npub enum CubeOptimization<R: Runtime> {\n    ElementWise(ElemwiseOptimization<R>),\n    Matmul(MatmulOptimization<R>),\n    Reduce(ReduceOptimization<R>),\n    ReduceBroadcasted(ReduceBroadcastedOptimization<R>),\n}\n\nimpl<R: Runtime> core::fmt::Debug for CubeOptimization<R> {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        let value = self.to_opt_state();\n        f.write_fmt(format_args!(\"{value:?}\"))\n    }\n}\n\nimpl<R: Runtime> CubeOptimization<R> {\n    /// Serializes the current optimization to its state.\n    pub fn to_opt_state(&self) -> CubeOptimizationState {\n        match self {\n            Self::ElementWise(value) => CubeOptimizationState::ElementWise(value.to_state()),\n            Self::Matmul(value) => CubeOptimizationState::Matmul(value.to_state()),\n            Self::Reduce(value) => CubeOptimizationState::Reduce(value.to_state()),\n            Self::ReduceBroadcasted(value) => {\n                CubeOptimizationState::ReduceBroadcasted(value.to_state())\n            }\n        }\n    }\n}\n\nimpl<R: Runtime> burn_fusion::NumOperations for CubeOptimization<R> {\n    fn len(&self) -> usize {\n        match self {\n            Self::ElementWise(op) => op.num_ops_fused(),\n            Self::Matmul(op) => op.num_ops_fused(),\n            Self::Reduce(op) => op.num_ops_fused(),\n            Self::ReduceBroadcasted(op) => op.num_ops_fused(),\n        }\n    }\n}\n\n/// Fusion optimization state type for cubecl.\n///\n/// More optimization variants should be added here.\n#[allow(clippy::large_enum_variant)]\n#[derive(Serialize, Deserialize, Debug)]\npub enum CubeOptimizationState {\n    ElementWise(ElemwiseOptimizationState),\n    Matmul(MatmulOptimizationState),\n    Reduce(ReduceOptimizationState),\n    ReduceBroadcasted(ReduceBroadcastedOptimizationState),\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/elemwise/fuser.rs",
    "content": "use super::optimization::ElemwiseOptimization;\nuse crate::{\n    engine::{\n        fuser::TraceOperationFuser,\n        settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},\n    },\n    optim::CubeOptimization,\n};\nuse burn_fusion::OperationFuser;\nuse burn_std::Shape;\nuse cubecl::Runtime;\n\n/// Fuses element wise operations.\npub struct ElementWiseFuser<R: Runtime> {\n    fuser: TraceOperationFuser,\n    device: R::Device,\n}\n\nimpl<R: Runtime> Clone for ElementWiseFuser<R> {\n    fn clone(&self) -> Self {\n        Self {\n            fuser: self.fuser.clone(),\n            device: self.device.clone(),\n        }\n    }\n}\n\nimpl<R: Runtime> ElementWiseFuser<R> {\n    pub fn shape_id(&self) -> Shape {\n        self.fuser.current_output_shape.clone()\n    }\n    pub fn new(device: R::Device) -> Self {\n        let client = R::client(&device);\n        let props = client.properties();\n        let max_bindings = props.hardware.max_bindings;\n\n        Self {\n            fuser: TraceOperationFuser::new(\n                max_bindings,\n                FuseSettings {\n                    broadcast: true,\n                    output_shape_updates: true,\n                    inplace: true,\n                    vectorization: VectorizationSetting::Activated,\n                    ref_layout: RefLayoutSetting::Any,\n                },\n            ),\n            device,\n        }\n    }\n}\n\nimpl<R: Runtime> OperationFuser<CubeOptimization<R>> for ElementWiseFuser<R> {\n    fn fuse(&mut self, operation: &burn_ir::OperationIr) {\n        self.fuser.fuse(operation);\n    }\n\n    fn finish(&mut self) -> CubeOptimization<R> {\n        let client = R::client(&self.device);\n        let trace = self.fuser.finish();\n        let elementwise = ElemwiseOptimization::new(trace, client, self.device.clone(), self.len());\n\n        CubeOptimization::ElementWise(elementwise)\n    }\n\n    fn reset(&mut self) {\n        self.fuser.reset()\n    }\n\n    fn status(&self) -> burn_fusion::FuserStatus {\n        self.fuser.status()\n    }\n\n    fn properties(&self) -> burn_fusion::FuserProperties {\n        self.fuser.properties()\n    }\n\n    fn len(&self) -> usize {\n        self.fuser.len()\n    }\n\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {\n        Box::new(self.clone())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/elemwise/mod.rs",
    "content": "mod fuser;\nmod optimization;\n\npub use fuser::*;\npub use optimization::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs",
    "content": "use crate::{\n    CubeFusionHandle,\n    engine::{\n        codegen::{\n            DynSize,\n            io::ref_len,\n            ir::{\n                FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsLaunch, RefLayout,\n                multi_block_variables_init,\n            },\n            kernel::{fuse_on_write, init_locals},\n        },\n        launch::{\n            FuseTraceLauncher,\n            runner::{TraceRunner, Vectorization},\n        },\n        trace::FuseTrace,\n    },\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{CubeDim, calculate_cube_count_elemwise, client::ComputeClient, prelude::*};\nuse serde::{Deserialize, Serialize};\n\n#[derive(new)]\n/// Fuse element wise operations into a single kernel.\npub struct ElemwiseOptimization<R: Runtime> {\n    pub(crate) trace: FuseTrace,\n    client: ComputeClient<R>,\n    device: R::Device,\n    len: usize,\n}\n\n#[derive(Serialize, Deserialize, Debug)]\n/// State for the [elemwise optimization](ElemwiseOptimization).\npub struct ElemwiseOptimizationState {\n    trace: FuseTrace,\n    len: usize,\n}\n\nimpl<R: Runtime> ElemwiseOptimization<R> {\n    /// Execute the optimization.\n    pub fn execute(&self, context: &mut Context<'_, CubeFusionHandle<R>>) {\n        let launcher = FuseTraceLauncher::new(&self.trace, &ElemwiseRunner);\n\n        match launcher.launch(&self.client, &self.device, context) {\n            Ok(_) => (),\n            Err(err) => {\n                panic!(\"{err:?} - {:?}\", self.trace);\n            }\n        }\n    }\n\n    /// Number of element wise operations fused.\n    pub fn num_ops_fused(&self) -> usize {\n        self.len\n    }\n\n    /// Create an optimization from its [state](ElemwiseOptimizationState).\n    pub fn from_state(device: &R::Device, state: ElemwiseOptimizationState) -> Self {\n        Self {\n            trace: state.trace,\n            len: state.len,\n            client: R::client(device),\n            device: device.clone(),\n        }\n    }\n\n    /// Convert the optimization to its [state](ElemwiseOptimizationState).\n    pub fn to_state(&self) -> ElemwiseOptimizationState {\n        ElemwiseOptimizationState {\n            trace: self.trace.clone(),\n            len: self.len,\n        }\n    }\n}\n\npub struct ElemwiseRunner;\n\nimpl<R: Runtime> Vectorization<R> for ElemwiseRunner {}\nimpl<R: Runtime> TraceRunner<R> for ElemwiseRunner {\n    type Error = LaunchError; // No error possible\n\n    fn run<'a>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        configs: &[FuseBlockConfig],\n    ) -> Result<(), Self::Error> {\n        let config = &configs[0];\n        let shape = match &config.ref_layout {\n            RefLayout::Concrete(arg) => match arg {\n                FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank),\n                FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank),\n                _ => panic!(\"Invalid concreate ref layout\"),\n            },\n            RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank),\n        };\n        let working_units = shape.iter().product::<usize>() / config.width;\n        let cube_dim = CubeDim::new(client, working_units);\n        let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);\n        let address_type = inputs\n            .required_address_type()\n            .max(outputs.required_address_type());\n\n        unsafe {\n            elemwise_fuse::launch_unchecked(\n                client,\n                cube_count,\n                cube_dim,\n                address_type,\n                inputs,\n                outputs,\n                config.clone(),\n            );\n        };\n\n        Ok(())\n    }\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\nfn elemwise_fuse(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    #[comptime] config: &FuseBlockConfig,\n) {\n    // We write no values for this fusion.\n    let values = Registry::<FuseArg, Vector<f32, DynSize>>::new();\n    let args = comptime![Vec::<FuseArg>::new()];\n    let pos = ABSOLUTE_POS;\n\n    multi_block_variables_init(config, &mut outputs.variables);\n\n    let mut locals = init_locals(inputs, outputs, config);\n    let length = ref_len(inputs, outputs, &locals, config);\n\n    if pos < length {\n        fuse_on_write::<f32, DynSize>(inputs, outputs, &mut locals, pos, values, args, config)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/matmul/args.rs",
    "content": "use crate::engine::codegen::{\n    io::ref_vector_size,\n    ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, LocalArgs, multi_block_variables_init},\n    kernel::init_locals,\n    view::{FusedOutput, GlobalInput, GlobalInputExpand},\n};\nuse cubecl::{\n    intrinsic,\n    prelude::*,\n    quant::scheme::{QuantLevel, QuantScheme},\n    std::{\n        FastDivmod,\n        quant::{\n            RunWithQuantType,\n            view::{QuantizedView, run_with_quant_type},\n        },\n        tensor::{\n            View, ViewExpand,\n            layout::{Coords1d, Coords2d, VirtualLayout},\n        },\n    },\n};\nuse cubek::{\n    matmul::{\n        components::global::memory::{\n            BatchLayout, BlockScaledLayout, GlobalLayout, GlobalLayoutConfig, GlobalLayoutExpand,\n            GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout,\n        },\n        launch::{BatchedCoords, MatmulArgs},\n    },\n    std::MatrixLayout,\n};\nuse serde::{Deserialize, Serialize};\nuse std::marker::PhantomData;\n\n#[derive(Clone)]\npub struct FusedMatmulArgs;\n\n#[derive(CubeLaunch, CubeType)]\npub struct FusedMatmulInput {\n    global: GlobalArgs,\n    #[cube(comptime)]\n    config: FuseBlockConfig,\n    #[cube(comptime)]\n    a: MatmulArg,\n    #[cube(comptime)]\n    b: MatmulArg,\n    #[cube(comptime)]\n    c: Option<MatmulArg>,\n    #[cube(comptime)]\n    out: FuseArg,\n}\n\n#[cube]\nimpl MatmulArgs for FusedMatmulArgs {\n    type Output<EO: CubePrimitive> = GlobalArgs;\n    type Input<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> = FusedMatmulInput;\n    type State<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive> = FusedMatmulState;\n    type Config = ();\n\n    fn init_state<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        inputs: &Self::Input<Lhs, Rhs, EO>,\n        outputs: &mut Self::Output<EO>,\n        _config: (),\n        #[comptime] lhs_layout_config: GlobalLayoutConfig,\n        #[comptime] rhs_layout_config: GlobalLayoutConfig,\n        #[comptime] out_layout_config: GlobalLayoutConfig,\n    ) -> Self::State<Lhs, Rhs, EO> {\n        multi_block_variables_init(&inputs.config, &mut outputs.variables);\n\n        let mut locals = init_locals(&inputs.global, outputs, &inputs.config);\n        let rank = comptime![inputs.config.rank];\n\n        let mut batch_shape = Sequence::new();\n        let mut batch_strides_out = Sequence::new();\n\n        #[unroll]\n        for i in 0..rank - 2 {\n            batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32));\n            batch_strides_out.push(locals.ref_strides[i]);\n        }\n\n        let batch_lhs = input_batch_layout(\n            &inputs.global,\n            &batch_shape,\n            comptime![inputs.a.clone()],\n            comptime![inputs.config.clone()],\n        );\n        let batch_rhs = input_batch_layout(\n            &inputs.global,\n            &batch_shape,\n            comptime![inputs.b.clone()],\n            comptime![inputs.config.clone()],\n        );\n        let batch_acc = match comptime![inputs.c.clone()] {\n            Some(c) => ComptimeOption::Some(input_batch_layout(\n                &inputs.global,\n                &batch_shape,\n                comptime![c],\n                comptime![inputs.config.clone()],\n            )),\n            None => ComptimeOption::new_None(),\n        };\n        let batch_out = BatchLayout::new(batch_strides_out, batch_shape.clone());\n\n        FusedMatmulState::new(\n            inputs,\n            outputs,\n            &mut locals,\n            batch_lhs,\n            batch_rhs,\n            batch_acc,\n            VirtualLayout::new::<BatchLayout>(batch_out),\n            batch_shape,\n            &inputs.config,\n            lhs_layout_config,\n            rhs_layout_config,\n            out_layout_config,\n        )\n    }\n\n    fn view_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n    ) -> View<Lhs, BatchedCoords> {\n        global_view(\n            &state.inputs,\n            &state.locals,\n            &state.batch_shape,\n            comptime![state.a.clone()],\n            comptime![state.config.clone()],\n            state.lhs_layout_config,\n        )\n    }\n\n    fn batch_lhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n        batch: usize,\n    ) -> usize {\n        state.a_batch.to_source_pos(batch)\n    }\n\n    fn view_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n    ) -> View<Rhs, BatchedCoords> {\n        global_view(\n            &state.inputs,\n            &state.locals,\n            &state.batch_shape,\n            comptime![state.b.clone()],\n            comptime![state.config.clone()],\n            comptime![state.rhs_layout_config],\n        )\n    }\n\n    fn batch_rhs<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n        batch: usize,\n    ) -> usize {\n        state.b_batch.to_source_pos(batch)\n    }\n\n    fn view_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n    ) -> ComptimeOption<View<EO, BatchedCoords>> {\n        match comptime![state.c.clone()] {\n            Some(c) => {\n                let view = global_view(\n                    &state.inputs,\n                    &state.locals,\n                    &state.batch_shape,\n                    c,\n                    comptime![state.config.clone()],\n                    comptime![state.out_layout_config],\n                );\n                ComptimeOption::Some(view)\n            }\n            None => ComptimeOption::new_None(),\n        }\n    }\n\n    fn batch_acc<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n        batch: usize,\n    ) -> usize {\n        #[comptime]\n        match state.c_batch {\n            ComptimeOption::Some(c_batch) => c_batch.to_source_pos(batch),\n            ComptimeOption::None => batch,\n        }\n    }\n\n    fn view_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &mut Self::State<Lhs, Rhs, EO>,\n    ) -> View<EO, BatchedCoords, ReadWrite> {\n        let rank = comptime![state.config.rank];\n\n        let shape_row = state.locals.ref_shape[rank - 2] as u32;\n        let shape_col = state.locals.ref_shape[rank - 1] as u32;\n\n        let stride_row = state.locals.ref_strides[rank - 2];\n        let stride_col = state.locals.ref_strides[rank - 1];\n\n        let layout = GlobalLayout::new(\n            VirtualLayout::new::<NoopLayout>(NoopLayout::new()),\n            shape_row,\n            shape_col,\n            stride_row,\n            stride_col,\n            ref_vector_size(&state.locals),\n            1u32,\n            state.out_layout_config,\n        );\n        let mut buffer = FusedOutput::new(\n            &state.inputs,\n            &mut state.outputs,\n            &mut state.locals,\n            comptime![state.out.clone()],\n            comptime![state.config.clone()],\n        );\n        View::new_mut::<FusedOutput, Coords1d>(&mut buffer, layout)\n    }\n\n    fn batch_out<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        state: &Self::State<Lhs, Rhs, EO>,\n        batch: usize,\n    ) -> usize {\n        state.out_batch.to_source_pos(batch)\n    }\n\n    fn runtime_config<Lhs: CubePrimitive, Rhs: CubePrimitive, EO: CubePrimitive>(\n        _state: &Self::State<Lhs, Rhs, EO>,\n    ) {\n    }\n}\n\n#[cube]\n#[allow(clippy::missing_transmute_annotations)]\nfn global_view<E: CubePrimitive>(\n    inputs: &GlobalArgs,\n    locals: &LocalArgs,\n    batch_shape: &Sequence<FastDivmod<u32>>,\n    #[comptime] arg: MatmulArg,\n    #[comptime] config: FuseBlockConfig,\n    #[comptime] layout_config: GlobalLayoutConfig,\n) -> View<E, BatchedCoords> {\n    let rank = comptime![config.rank];\n    let data = comptime![arg.data().clone()];\n    let data_tensor = match comptime![data.clone()] {\n        FuseArg::Input(pos, ..) => inputs.tensors.index(pos),\n        _ => panic!(\"Input must be concrete\"),\n    };\n\n    let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32;\n    let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32;\n    let mut packing = comptime![1];\n\n    if arg.scheme().is_some() {\n        let scheme = arg.scheme().unwrap();\n        let num_quants = scheme.num_quants() as u32;\n        comptime![packing = num_quants];\n        match comptime![layout_config.matrix_layout] {\n            MatrixLayout::RowMajor => shape_col *= num_quants,\n            MatrixLayout::ColMajor => shape_row *= num_quants,\n        };\n    }\n\n    let shape = (shape_row, shape_col);\n\n    // Noop for normal inputs because batch offset is cached, quantized uses logical batches\n    let batch_layout = match comptime![arg.clone()] {\n        MatmulArg::Normal(_) => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),\n        MatmulArg::Quantized { data, .. } => {\n            let data_arg = comptime![MatmulArg::Normal(data)];\n            input_batch_layout(inputs, batch_shape, data_arg, comptime![config.clone()])\n        }\n    };\n\n    let data_layout = global_layout(\n        inputs,\n        shape,\n        batch_layout,\n        arg.data().clone(),\n        config.clone(),\n        data_tensor.tensor.vector_size(),\n        layout_config,\n        packing,\n    );\n    let data_buf = GlobalInput::new(inputs, locals, data, comptime![config.clone()], None);\n\n    match comptime![arg.clone()] {\n        MatmulArg::Normal(_) => View::new::<GlobalInput, Coords1d>(&data_buf, data_layout),\n        MatmulArg::Quantized { scales, scheme, .. } => {\n            let scales_layout = match comptime![scheme.level] {\n                QuantLevel::Tensor => GlobalScaleLayout::new_PerTensor(shape),\n                QuantLevel::Block(block_size) => {\n                    let block_size = comptime![block_size.as_dim::<2>()];\n\n                    let scales_arg = comptime![MatmulArg::Normal(scales.clone())];\n                    let batch_layout = input_batch_layout(\n                        inputs,\n                        batch_shape,\n                        scales_arg,\n                        comptime![config.clone()],\n                    );\n\n                    let scales_layout = global_layout(\n                        inputs,\n                        shape,\n                        batch_layout,\n                        comptime![scales.clone()],\n                        comptime![config.clone()],\n                        1usize,\n                        layout_config,\n                        1u32,\n                    );\n                    GlobalScaleLayout::new_BlockScaled(BlockScaledLayout::new(\n                        shape,\n                        scales_layout,\n                        comptime![(block_size[0] as u32, block_size[1] as u32)],\n                    ))\n                }\n            };\n            let scales_buf = GlobalInput::new(inputs, locals, scales, config, None);\n\n            // Redefine because of `Numeric` bound, kinda hacky but I can't figure out a way to\n            // assert `Vector<T: Numeric>::Scalar: Numeric`\n            let define!(T) = storage_type_of::<E::Scalar>();\n            let view = create_quant_view_dynamic::<T, E::Size>(\n                data_buf,\n                data_layout,\n                scales_buf,\n                scales_layout,\n                scheme,\n            );\n            // Safety: should be fine since `Vector<E::Scalar, N>` is guaranteed equal to `E`\n            comptime![unsafe { core::mem::transmute(view) }]\n        }\n    }\n}\n\n#[cube]\nfn input_batch_layout(\n    inputs: &GlobalArgs,\n    batch_shape: &Sequence<FastDivmod<u32>>,\n    #[comptime] arg: MatmulArg,\n    #[comptime] config: FuseBlockConfig,\n) -> VirtualLayout<usize, usize> {\n    let rank = comptime![config.rank];\n    match comptime![arg.clone()] {\n        MatmulArg::Normal(arg) => {\n            let data_tensor = match comptime![arg.clone()] {\n                FuseArg::Input(pos, ..) => inputs.tensors.index(pos),\n                _ => panic!(\"Input must be concrete\"),\n            };\n\n            let mut batch_strides = Sequence::new();\n            #[unroll]\n            for i in 0..rank - 2 {\n                let shape = data_tensor.tensor.shape(i);\n                let stride = select(shape == 1, 0, data_tensor.tensor.stride(i));\n                batch_strides.push(stride);\n            }\n\n            VirtualLayout::new::<BatchLayout>(BatchLayout::new(batch_strides, batch_shape.clone()))\n        }\n        MatmulArg::Quantized { .. } => VirtualLayout::new::<NoopLayout>(NoopLayout::new()),\n    }\n}\n\n#[cube]\nfn global_layout(\n    inputs: &GlobalArgs,\n    shape: Coords2d,\n    batch_layout: VirtualLayout<usize, usize>,\n    #[comptime] arg: FuseArg,\n    #[comptime] config: FuseBlockConfig,\n    #[comptime] vector_size: VectorSize,\n    #[comptime] layout_config: GlobalLayoutConfig,\n    #[comptime] packing: u32,\n) -> GlobalLayout {\n    let rank = comptime![config.rank];\n    let data_tensor = match comptime![arg.clone()] {\n        FuseArg::Input(pos, ..) => inputs.tensors.index(pos),\n        _ => panic!(\"Input must be concrete\"),\n    };\n\n    let (shape_row, shape_col) = shape;\n\n    let stride_row = data_tensor.tensor.stride(rank - 2);\n    let stride_col = data_tensor.tensor.stride(rank - 1);\n\n    GlobalLayout::new(\n        batch_layout,\n        shape_row,\n        shape_col,\n        stride_row,\n        stride_col,\n        vector_size,\n        packing,\n        layout_config,\n    )\n}\n\nstruct CreateQuantView<'a, E: Numeric, N: Size> {\n    scope: &'a mut Scope,\n    data_buf: GlobalInputExpand,\n    data_layout: GlobalLayoutExpand,\n    scales_buf: GlobalInputExpand,\n    scales_layout: GlobalScaleLayoutExpand,\n    scheme: QuantScheme,\n    _ty: PhantomData<(E, N)>,\n}\n\nimpl<'a, E: Numeric, N: Size> RunWithQuantType for CreateQuantView<'a, E, N> {\n    type Output = ViewExpand<Vector<E, N>, BatchedCoords>;\n\n    fn execute<Q: Scalar, S: Scalar>(self) -> Self::Output {\n        create_quant_view::expand::<E, N, Q, S>(\n            self.scope,\n            self.data_buf,\n            self.data_layout,\n            self.scales_buf,\n            self.scales_layout,\n            self.scheme,\n        )\n    }\n}\n\n#[cube]\n#[allow(unused)]\nfn create_quant_view_dynamic<E: Numeric, N: Size>(\n    data_buf: GlobalInput,\n    data_layout: GlobalLayout,\n    scales_buf: GlobalInput,\n    scales_layout: GlobalScaleLayout,\n    #[comptime] scheme: QuantScheme,\n) -> View<Vector<E, N>, BatchedCoords> {\n    intrinsic!(|scope| {\n        let func = CreateQuantView {\n            scope,\n            data_buf,\n            data_layout,\n            scales_buf,\n            scales_layout,\n            scheme,\n            _ty: PhantomData,\n        };\n        run_with_quant_type(func, scheme)\n    })\n}\n\n#[cube]\nfn create_quant_view<E: Numeric, N: Size, Q: Scalar, S: Scalar>(\n    data_buf: GlobalInput,\n    data_layout: GlobalLayout,\n    scales_buf: GlobalInput,\n    scales_layout: GlobalScaleLayout,\n    #[comptime] scheme: QuantScheme,\n) -> View<Vector<E, N>, BatchedCoords> {\n    let size!(NQ) = N::value().comptime() / scheme.num_quants();\n\n    let data_view: View<Vector<Q, NQ>, BatchedCoords> =\n        View::new::<GlobalInput, Coords1d>(&data_buf, data_layout);\n    let scales_view: View<S, BatchedCoords> =\n        View::new::<GlobalInput, Coords1d>(&scales_buf, scales_layout);\n    QuantizedView::new(data_view, scales_view, scheme).view()\n}\n\n#[derive(CubeType)]\npub struct FusedMatmulState {\n    inputs: GlobalArgs,\n    outputs: GlobalArgs,\n    locals: LocalArgs,\n    a_batch: VirtualLayout<Coords1d, Coords1d>,\n    b_batch: VirtualLayout<Coords1d, Coords1d>,\n    c_batch: ComptimeOption<VirtualLayout<Coords1d, Coords1d>>,\n    out_batch: VirtualLayout<Coords1d, Coords1d>,\n    #[cube(comptime)]\n    config: FuseBlockConfig,\n    #[cube(comptime)]\n    a: MatmulArg,\n    #[cube(comptime)]\n    b: MatmulArg,\n    #[cube(comptime)]\n    c: Option<MatmulArg>,\n    #[cube(comptime)]\n    out: FuseArg,\n    #[cube(comptime)]\n    lhs_layout_config: GlobalLayoutConfig,\n    #[cube(comptime)]\n    rhs_layout_config: GlobalLayoutConfig,\n    #[cube(comptime)]\n    out_layout_config: GlobalLayoutConfig,\n    batch_shape: Sequence<FastDivmod<u32>>,\n}\n\n#[cube]\nimpl FusedMatmulState {\n    #[allow(clippy::too_many_arguments)]\n    pub fn new(\n        inputs: &FusedMatmulInput,\n        outputs: &mut GlobalArgs,\n        locals: &mut LocalArgs,\n        a_batch: VirtualLayout<usize, usize>,\n        b_batch: VirtualLayout<usize, usize>,\n        c_batch: ComptimeOption<VirtualLayout<usize, usize>>,\n        out_batch: VirtualLayout<usize, usize>,\n        batch_shape: Sequence<FastDivmod<u32>>,\n        #[comptime] config: &FuseBlockConfig,\n        #[comptime] lhs_layout_config: GlobalLayoutConfig,\n        #[comptime] rhs_layout_config: GlobalLayoutConfig,\n        #[comptime] out_layout_config: GlobalLayoutConfig,\n    ) -> FusedMatmulState {\n        FusedMatmulState {\n            inputs: inputs.global.clone(),\n            outputs: outputs.clone(),\n            config: comptime![config.clone()],\n            locals: locals.clone(),\n            a_batch,\n            b_batch,\n            c_batch,\n            out_batch,\n            a: comptime![inputs.a.clone()],\n            b: comptime![inputs.b.clone()],\n            c: comptime![inputs.c.clone()],\n            out: comptime![inputs.out.clone()],\n            lhs_layout_config,\n            rhs_layout_config,\n            out_layout_config,\n            batch_shape,\n        }\n    }\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]\n/// Argument to a matmul operation.\npub enum MatmulArg {\n    Normal(FuseArg),\n    Quantized {\n        data: FuseArg,\n        scales: FuseArg,\n        precision: FuseType,\n        scheme: QuantScheme,\n    },\n}\n\nimpl MatmulArg {\n    pub fn data(&self) -> &FuseArg {\n        match self {\n            MatmulArg::Normal(arg) => arg,\n            MatmulArg::Quantized { data, .. } => data,\n        }\n    }\n\n    pub fn scheme(&self) -> Option<&QuantScheme> {\n        match self {\n            MatmulArg::Normal(_) => None,\n            MatmulArg::Quantized { scheme, .. } => Some(scheme),\n        }\n    }\n\n    pub fn precision(&self) -> FuseType {\n        match self {\n            MatmulArg::Normal(arg) => arg.precision(),\n            MatmulArg::Quantized { precision, .. } => *precision,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/matmul/fuser.rs",
    "content": "use super::optimization::{FusedMatmul, MatmulOptimization};\nuse crate::{\n    engine::{fuser::TraceOperationFuser, settings::FuseSettings},\n    optim::CubeOptimization,\n    optim::matmul::args::MatmulArg,\n};\nuse burn_fusion::{FuserStatus, OperationFuser};\nuse burn_ir::{FloatOperationIr, OperationIr};\nuse burn_std::DType;\nuse cubecl::Runtime;\n\n/// Fused element wise operations that are normally memory bound.\npub struct MatmulFuser<R: Runtime> {\n    fuser: TraceOperationFuser,\n    fuser_fallback: TraceOperationFuser,\n    device: R::Device,\n    matmul: Option<FusedMatmul>,\n}\n\nimpl<R: Runtime> Clone for MatmulFuser<R> {\n    fn clone(&self) -> Self {\n        Self {\n            fuser: self.fuser.clone(),\n            fuser_fallback: self.fuser_fallback.clone(),\n            device: self.device.clone(),\n            matmul: self.matmul.clone(),\n        }\n    }\n}\n\nimpl<R: Runtime> MatmulFuser<R> {\n    pub fn new(device: R::Device) -> Self {\n        let client = R::client(&device);\n        let props = client.properties();\n        let max_bindings = props.hardware.max_bindings;\n        let settings_matmul = FuseSettings {\n            output_shape_updates: false,\n            ..Default::default()\n        };\n        let settings_fallback = FuseSettings::default();\n\n        Self {\n            fuser: TraceOperationFuser::new(max_bindings, settings_matmul),\n            fuser_fallback: TraceOperationFuser::new(max_bindings, settings_fallback),\n            device,\n            matmul: None,\n        }\n    }\n}\n\nimpl<R: Runtime> OperationFuser<CubeOptimization<R>> for MatmulFuser<R> {\n    fn fuse(&mut self, operation: &OperationIr) {\n        if let FuserStatus::Closed = self.fuser.status() {\n            return;\n        }\n\n        if self.matmul.is_none() {\n            if let OperationIr::Float(_, FloatOperationIr::Matmul(op)) = operation {\n                // Precision shouldn't be hardcoded but I don't know how to get float precision of the backend\n                let lhs = match op.lhs.dtype {\n                    DType::QFloat(scheme) => {\n                        let (data, scales) = self.fuser.input_quantized_unhandled(&op.lhs).unwrap();\n                        MatmulArg::Quantized {\n                            data,\n                            scales,\n                            precision: op.out.dtype.into(),\n                            scheme,\n                        }\n                    }\n                    _ => MatmulArg::Normal(self.fuser.input_unhandled(&op.lhs)),\n                };\n                let rhs = match op.rhs.dtype {\n                    DType::QFloat(scheme) => {\n                        let (data, scales) = self.fuser.input_quantized_unhandled(&op.rhs).unwrap();\n                        MatmulArg::Quantized {\n                            data,\n                            scales,\n                            precision: op.out.dtype.into(),\n                            scheme,\n                        }\n                    }\n                    _ => MatmulArg::Normal(self.fuser.input_unhandled(&op.rhs)),\n                };\n\n                let out = self.fuser.output_unhandled(&op.out);\n\n                self.matmul = Some(FusedMatmul::new(\n                    lhs,\n                    rhs,\n                    out,\n                    op.clone().into(),\n                    Default::default(),\n                ));\n            } else {\n                self.fuser.close();\n                self.fuser_fallback.close();\n            }\n        } else {\n            let can_register =\n                self.fuser.can_fuse(operation) && self.fuser_fallback.can_fuse(operation);\n\n            match can_register {\n                true => {\n                    self.fuser.fuse(operation);\n                    self.fuser_fallback.fuse(operation);\n                }\n                false => {\n                    self.fuser.close();\n                    self.fuser_fallback.close();\n                }\n            };\n        }\n    }\n\n    fn finish(&mut self) -> CubeOptimization<R> {\n        let client = R::client(&self.device);\n        let trace = self.fuser.finish();\n        let trace_fallback = self.fuser_fallback.finish();\n\n        let matmul = MatmulOptimization::new(\n            trace,\n            trace_fallback,\n            client,\n            self.device.clone(),\n            self.len(),\n            self.matmul.as_ref().unwrap().clone(),\n        );\n\n        CubeOptimization::Matmul(matmul)\n    }\n\n    fn reset(&mut self) {\n        self.fuser.reset();\n        self.fuser_fallback.reset();\n        self.matmul = None;\n    }\n\n    fn status(&self) -> burn_fusion::FuserStatus {\n        self.fuser.status()\n    }\n\n    fn properties(&self) -> burn_fusion::FuserProperties {\n        self.fuser.properties()\n    }\n\n    fn len(&self) -> usize {\n        // Matmul operation isn't registered in the fuser\n        self.fuser.len() + 1\n    }\n\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {\n        Box::new(self.clone())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/matmul/mod.rs",
    "content": "mod fuser;\nmod optimization;\n\npub(crate) mod args;\npub(crate) mod tune;\n\npub use fuser::*;\npub use optimization::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs",
    "content": "use super::args::FusedMatmulInputLaunch;\n#[cfg(feature = \"autotune\")]\nuse super::tune::fused_matmul_autotune;\nuse crate::{\n    CubeFusionHandle, FallbackOperation,\n    engine::{\n        codegen::ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout},\n        launch::{\n            FuseTraceLauncher, HandleInput, LaunchPlan,\n            runner::{TraceRunner, Vectorization, VectorizationAxis},\n        },\n        trace::{FuseTrace, TraceError, TuneOutput},\n    },\n    optim::{\n        elemwise::ElemwiseRunner,\n        matmul::args::{FusedMatmulArgs, MatmulArg},\n    },\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::BinaryOpIr;\nuse cubecl::{\n    client::ComputeClient,\n    prelude::*,\n    std::tensor::{MatrixBatchLayout, matrix_batch_layout},\n};\nuse cubek::{\n    matmul::{\n        components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},\n        definition::{\n            MatmulElems, MatmulGlobalElems, MatmulProblem, MatmulSetupError, MatmulVectorSizes,\n        },\n        launch::launch_kernel_virtual,\n        routines::{\n            BlueprintStrategy, Routine,\n            double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs},\n            double_unit::DoubleUnitAlgorithm,\n            ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs},\n            simple::{SimpleAlgorithm, SimpleArgs},\n            simple_unit::SimpleUnitAlgorithm,\n            vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},\n        },\n    },\n    std::MatrixLayout,\n};\nuse serde::{Deserialize, Serialize};\nuse std::sync::Arc;\n\n/// Fuse matmul operation followed by elemwise operations into a single kernel.\npub struct MatmulOptimization<R: Runtime> {\n    pub(crate) info: Arc<MatmulOptimizationInfo<R>>,\n}\n\npub struct MatmulOptimizationTuneArg<R: Runtime> {\n    pub(crate) info: Arc<MatmulOptimizationInfo<R>>,\n    pub(crate) fallback: Box<dyn FallbackOperation<R>>,\n}\n\npub(crate) struct MatmulOptimizationInfo<R: Runtime> {\n    trace: FuseTrace,\n    trace_fallback: FuseTrace,\n    pub(crate) client: ComputeClient<R>,\n    pub(crate) device: R::Device,\n    pub(crate) len: usize,\n    pub(crate) matmul: FusedMatmul,\n}\n\n#[derive(Serialize, Deserialize, Debug)]\n/// State for the [matrix optimization](MatmulOptimizationState).\npub struct MatmulOptimizationState {\n    trace: FuseTrace,\n    trace_fallback: FuseTrace,\n    matmul: FusedMatmul,\n    len: usize,\n}\n\nimpl<R: Runtime> MatmulOptimizationInfo<R> {\n    /// Returns the number of output buffers added by fusion.\n    pub fn num_output_buffers(&self) -> usize {\n        self.trace_fallback.resources.outputs.len()\n    }\n\n    /// Number of operations fused.\n    pub fn num_ops_fused(&self) -> usize {\n        self.len\n    }\n}\n\nimpl<R: Runtime> MatmulOptimizationTuneArg<R> {\n    pub(crate) fn execute_fused(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        selector: FusedMatmulSelector,\n    ) -> Result<TuneOutput<R>, TraceError<FusedMatmulError>> {\n        let launch = FusedMatmulLaunch::new(&self.info.matmul, selector);\n        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);\n\n        launcher.launch(&self.info.client, &self.info.device, context)\n    }\n\n    pub fn execute_fallback(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n    ) -> TuneOutput<R> {\n        self.fallback.run(context);\n\n        #[cfg(feature = \"autotune-checks\")]\n        let mut output = TuneOutput::Checked {\n            handles: Default::default(),\n        };\n        #[cfg(not(feature = \"autotune-checks\"))]\n        let output = TuneOutput::UnChecked(core::marker::PhantomData);\n\n        #[cfg(feature = \"autotune-checks\")]\n        if let TuneOutput::Checked { handles } = &mut output {\n            let out_desc = context.tensors.get(&self.info.matmul.op.out.id).unwrap();\n            let handle_out = context\n                .handles\n                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);\n\n            handles.insert(\n                self.info.matmul.op.out.id,\n                (out_desc.shape.dims.clone(), handle_out.clone()),\n            );\n        }\n\n        let launcher = FuseTraceLauncher::new(&self.info.trace_fallback, &ElemwiseRunner);\n        let output_write = launcher\n            .launch(&self.info.client, &self.info.device, context)\n            .unwrap();\n\n        output.merge(output_write)\n    }\n}\n\nimpl<R: Runtime> MatmulOptimization<R> {\n    pub fn new(\n        trace: FuseTrace,\n        trace_fallback: FuseTrace,\n        client: ComputeClient<R>,\n        device: R::Device,\n        len: usize,\n        matmul: FusedMatmul,\n    ) -> Self {\n        let info = MatmulOptimizationInfo {\n            trace,\n            trace_fallback,\n            client,\n            device,\n            len,\n            matmul,\n        };\n\n        Self {\n            info: Arc::new(info),\n        }\n    }\n    /// Execute the optimization.\n    pub fn execute(\n        &mut self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,\n    ) {\n        // The index of the fallback matmul is always 0.\n        let fallback = fallback(0);\n        let arg = MatmulOptimizationTuneArg {\n            info: self.info.clone(),\n            fallback,\n        };\n\n        #[cfg(feature = \"autotune\")]\n        fused_matmul_autotune::<R>(arg, context);\n\n        #[cfg(not(feature = \"autotune\"))]\n        if arg\n            .execute_fused(context, FusedMatmulSelector::default())\n            .is_err()\n        {\n            arg.execute_fallback(context);\n        }\n    }\n\n    /// Number of operations fused.\n    pub fn num_ops_fused(&self) -> usize {\n        self.info.num_ops_fused()\n    }\n\n    /// Create an optimization from its [state](MatmulOptimizationState).\n    pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self {\n        let info = MatmulOptimizationInfo {\n            trace: state.trace,\n            trace_fallback: state.trace_fallback,\n            len: state.len,\n            client: R::client(device),\n            device: device.clone(),\n            matmul: state.matmul.clone(),\n        };\n\n        Self {\n            info: Arc::new(info),\n        }\n    }\n\n    /// Convert the optimization to its [state](MatmulOptimizationState).\n    pub fn to_state(&self) -> MatmulOptimizationState {\n        MatmulOptimizationState {\n            trace: self.info.trace.clone(),\n            trace_fallback: self.info.trace_fallback.clone(),\n            matmul: self.info.matmul.clone(),\n            len: self.info.len,\n        }\n    }\n}\n\n#[derive(Clone, Copy, Serialize, Deserialize, Debug)]\npub enum FusedMatmulSelector {\n    Simple {\n        multi_rows: bool,\n        tile_matmul: AcceleratedTileKind,\n    },\n    DoubleBuffering {\n        specialized: bool,\n        tile_matmul: AcceleratedTileKind,\n    },\n    OrderedDoubleBuffering {\n        tile_matmul: AcceleratedTileKind,\n    },\n    SimpleVecMat,\n    DoubleVecMat,\n    SimpleUnit,\n    DoubleUnit,\n}\n\nimpl FusedMatmulSelector {\n    /// Not efficient, but only called once when initializing the tunables.\n    pub fn name(&self) -> String {\n        let name = match self {\n            FusedMatmulSelector::Simple {\n                multi_rows,\n                tile_matmul,\n            } => match multi_rows {\n                false => format!(\"simple_{tile_matmul:?}\"),\n                true => format!(\"simple_multirows_{tile_matmul:?}\"),\n            },\n            FusedMatmulSelector::DoubleBuffering {\n                specialized,\n                tile_matmul,\n            } => match specialized {\n                false => format!(\"double_buffering_{tile_matmul:?}\"),\n                true => format!(\"double_buffering_specialized_{tile_matmul:?}\"),\n            },\n            FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {\n                format!(\"double_buffering_ordered_{tile_matmul:?}\").to_lowercase()\n            }\n            FusedMatmulSelector::SimpleVecMat => \"simple_vec_mat\".into(),\n            FusedMatmulSelector::DoubleVecMat => \"double_buffering_vec_mat\".into(),\n            FusedMatmulSelector::SimpleUnit => \"simple_unit\".into(),\n            FusedMatmulSelector::DoubleUnit => \"double_buffering_unit\".into(),\n        };\n\n        format!(\"fused_{name}\")\n    }\n}\n\nimpl Default for FusedMatmulSelector {\n    fn default() -> Self {\n        FusedMatmulSelector::Simple {\n            multi_rows: false,\n            tile_matmul: AcceleratedTileKind::Cmma,\n        }\n    }\n}\n\n#[derive(new, Clone, Serialize, Deserialize, Debug)]\npub struct FusedMatmul {\n    pub(crate) lhs: MatmulArg,\n    pub(crate) rhs: MatmulArg,\n    out: FuseArg,\n    pub(crate) op: BinaryOpIr,\n    pub(crate) selector: FusedMatmulSelector,\n}\n\n#[derive(new)]\npub struct FusedMatmulLaunch<'a> {\n    pub(crate) matmul: &'a FusedMatmul,\n    pub(crate) selector: FusedMatmulSelector,\n}\n\n#[derive(Debug)]\npub enum FusedMatmulError {\n    LaunchError(MatmulSetupError),\n    InvalidInput(&'static str),\n}\n\nimpl From<MatmulSetupError> for FusedMatmulError {\n    fn from(value: MatmulSetupError) -> Self {\n        Self::LaunchError(value)\n    }\n}\n\nimpl<'a, R: Runtime> Vectorization<R> for FusedMatmulLaunch<'a> {\n    fn axis(&self, plan: &LaunchPlan<'_, R>) -> VectorizationAxis {\n        let lhs_id = self.matmul.op.lhs.id;\n        let rhs_id = self.matmul.op.rhs.id;\n\n        let mut tensor_lhs = None;\n        let mut tensor_rhs = None;\n\n        for input in plan.handle_inputs.iter() {\n            match input {\n                HandleInput::Normal(input) => {\n                    if input.relative_id == lhs_id {\n                        tensor_lhs = Some((input.global_ir.id, &input.handle.strides));\n                    }\n                    if input.relative_id == rhs_id {\n                        tensor_rhs = Some((input.global_ir.id, &input.handle.strides));\n                    }\n                }\n                HandleInput::QuantValues(input) => {\n                    if input.relative_id == lhs_id {\n                        tensor_lhs = Some((input.global_ir.id, &input.handle.strides));\n                    }\n                    if input.relative_id == rhs_id {\n                        tensor_rhs = Some((input.global_ir.id, &input.handle.strides));\n                    }\n                }\n                HandleInput::QuantParams(_) => {}\n            }\n        }\n\n        let (lhs_id_global, lhs_strides) = tensor_lhs.unwrap();\n        let (rhs_id_global, rhs_strides) = tensor_rhs.unwrap();\n\n        let mut axis = VectorizationAxis::default();\n\n        if let MatrixBatchLayout::MildlyPermuted { transposed, .. } =\n            matrix_batch_layout(lhs_strides, self.matmul.lhs.scheme())\n            && transposed\n        {\n            axis.insert(lhs_id_global, lhs_strides.len() - 2);\n        }\n\n        if let MatrixBatchLayout::MildlyPermuted { transposed, .. } =\n            matrix_batch_layout(rhs_strides, self.matmul.rhs.scheme())\n            && transposed\n        {\n            axis.insert(rhs_id_global, rhs_strides.len() - 2);\n        }\n\n        axis\n    }\n}\n\nimpl<R: Runtime> TraceRunner<R> for FusedMatmulLaunch<'_> {\n    type Error = FusedMatmulError;\n\n    fn run<'a>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        configs: &'a [FuseBlockConfig],\n    ) -> Result<(), FusedMatmulError> {\n        let global_elems = MatmulGlobalElems {\n            lhs: self.matmul.lhs.precision().into_storage_type(),\n            rhs: self.matmul.rhs.precision().into_storage_type(),\n            out: self.matmul.out.precision().into_storage_type(),\n        };\n        let dtypes = MatmulElems::from_globals(&global_elems);\n        self.matmul_fused(client, inputs, outputs, &configs[0], dtypes)\n    }\n}\n\n#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]\n/// Which tile matmul to use for accelerated algorithms\npub enum AcceleratedTileKind {\n    #[default]\n    Cmma,\n    Mma,\n}\n\nmacro_rules! with_tile_kind {\n    ($kind: expr, $T: ident, $launch: expr) => {\n        match $kind {\n            AcceleratedTileKind::Cmma => {\n                type $T = CmmaMatmul;\n                ($launch)()\n            }\n            AcceleratedTileKind::Mma => {\n                type $T = MmaMatmul;\n                ($launch)()\n            }\n        }\n    };\n}\n\nimpl FusedMatmulLaunch<'_> {\n    fn matmul_fused<'a, R: Runtime>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        config: &'a FuseBlockConfig,\n        dtypes: MatmulElems,\n    ) -> Result<(), FusedMatmulError> {\n        let lhs_shape = inputs.shape(self.matmul.lhs.data());\n        let rhs_shape = inputs.shape(self.matmul.rhs.data());\n        let out_shape = outputs.shape_ref(&config.ref_layout, config.rank);\n\n        let lhs_strides = inputs.strides(self.matmul.lhs.data());\n        let lhs_scheme = self.matmul.lhs.scheme();\n        let rhs_strides = inputs.strides(self.matmul.rhs.data());\n        let rhs_scheme = self.matmul.rhs.scheme();\n\n        if matrix_batch_layout(&lhs_strides, lhs_scheme) == MatrixBatchLayout::HighlyPermuted {\n            return Err(FusedMatmulError::InvalidInput(\n                \"Lhs needs to be contiguous, but can't when fusing.\",\n            ));\n        }\n        if matrix_batch_layout(&rhs_strides, rhs_scheme) == MatrixBatchLayout::HighlyPermuted {\n            return Err(FusedMatmulError::InvalidInput(\n                \"Rhs needs to be contiguous, but can't when fusing.\",\n            ));\n        }\n\n        let mut vector_sizes = MatmulVectorSizes {\n            lhs: inputs.vector_size(self.matmul.lhs.data()),\n            rhs: inputs.vector_size(self.matmul.rhs.data()),\n            out: match &config.ref_layout {\n                RefLayout::Concrete(arg) => match arg {\n                    FuseArg::Input(..) => inputs.vector_size(arg),\n                    FuseArg::Output(..) => outputs.vector_size(arg),\n                    _ => panic!(\"Invalid ref layout\"),\n                },\n                RefLayout::Virtual(_) => 1,\n            },\n        };\n\n        let address_type = inputs\n            .required_address_type()\n            .max(outputs.required_address_type());\n\n        if vector_sizes.out == 1 && (vector_sizes.lhs > 1 || vector_sizes.rhs > 1) {\n            return Err(FusedMatmulError::InvalidInput(\n                \"Output vector size of 1 removes the gain from fusion\",\n            ));\n        }\n\n        if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs {\n            vector_sizes.lhs *= scheme.num_quants();\n        }\n        if let MatmulArg::Quantized { scheme, .. } = self.matmul.rhs {\n            vector_sizes.rhs *= scheme.num_quants();\n        }\n\n        let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape);\n        let problem = MatmulProblem::from_shapes_and_strides(\n            lhs_shape,\n            rhs_shape,\n            out_shape,\n            lhs_strides,\n            rhs_strides,\n            out_strides,\n            dtypes.as_global_elems(),\n            address_type,\n            self.matmul.lhs.scheme(),\n            self.matmul.rhs.scheme(),\n        )?;\n\n        match self.selector {\n            FusedMatmulSelector::Simple {\n                multi_rows,\n                tile_matmul,\n            } => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<\n                R,\n                SimpleAlgorithm<Accelerated>,\n            >(\n                client,\n                FusedMatmulInputLaunch::new(\n                    inputs,\n                    config.clone(),\n                    self.matmul.lhs.clone(),\n                    self.matmul.rhs.clone(),\n                    None,\n                    self.matmul.out.clone(),\n                ),\n                outputs,\n                problem,\n                vector_sizes,\n                &BlueprintStrategy::Inferred(SimpleArgs { multi_rows }),\n            ) {\n                Ok(_) => Ok(()),\n                Err(err) => Err(FusedMatmulError::LaunchError(err)),\n            }),\n            FusedMatmulSelector::DoubleBuffering {\n                specialized,\n                tile_matmul,\n            } => with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<\n                R,\n                CyclicDoubleBufferingAlgorithm<Accelerated>,\n            >(\n                client,\n                FusedMatmulInputLaunch::new(\n                    inputs,\n                    config.clone(),\n                    self.matmul.lhs.clone(),\n                    self.matmul.rhs.clone(),\n                    None,\n                    self.matmul.out.clone(),\n                ),\n                outputs,\n                problem,\n                vector_sizes,\n                &BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized }),\n            ) {\n                Ok(_) => Ok(()),\n                Err(err) => Err(FusedMatmulError::LaunchError(err)),\n            }),\n            FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul } => {\n                let row_count = match self.matmul.lhs.precision() {\n                    FuseType::F16 | FuseType::BF16 => 8,\n                    _ => 4,\n                };\n\n                with_tile_kind!(tile_matmul, Accelerated, || match launch_inner_fix_dtype::<\n                    R,\n                    OrderedDoubleBufferingAlgorithm<Accelerated>,\n                >(\n                    client,\n                    FusedMatmulInputLaunch::new(\n                        inputs,\n                        config.clone(),\n                        self.matmul.lhs.clone(),\n                        self.matmul.rhs.clone(),\n                        None,\n                        self.matmul.out.clone(),\n                    ),\n                    outputs,\n                    problem,\n                    vector_sizes,\n                    &BlueprintStrategy::Inferred(OrderedSelectionArgs {\n                        row_count: Some(row_count),\n                        rows_per_plane: Some(2),\n                        partition_k: Some(2),\n                    }),\n                ) {\n                    Ok(_) => Ok(()),\n                    Err(err) => Err(FusedMatmulError::LaunchError(err)),\n                })\n            }\n            FusedMatmulSelector::SimpleUnit => {\n                match launch_inner_fix_dtype::<R, SimpleUnitAlgorithm>(\n                    client,\n                    FusedMatmulInputLaunch::new(\n                        inputs,\n                        config.clone(),\n                        self.matmul.lhs.clone(),\n                        self.matmul.rhs.clone(),\n                        None,\n                        self.matmul.out.clone(),\n                    ),\n                    outputs,\n                    problem,\n                    vector_sizes,\n                    &Default::default(),\n                ) {\n                    Ok(_) => Ok(()),\n                    Err(err) => Err(FusedMatmulError::LaunchError(err)),\n                }\n            }\n            FusedMatmulSelector::DoubleUnit => {\n                match launch_inner_fix_dtype::<R, DoubleUnitAlgorithm>(\n                    client,\n                    FusedMatmulInputLaunch::new(\n                        inputs,\n                        config.clone(),\n                        self.matmul.lhs.clone(),\n                        self.matmul.rhs.clone(),\n                        None,\n                        self.matmul.out.clone(),\n                    ),\n                    outputs,\n                    problem,\n                    vector_sizes,\n                    &Default::default(),\n                ) {\n                    Ok(_) => Ok(()),\n                    Err(err) => Err(FusedMatmulError::LaunchError(err)),\n                }\n            }\n            FusedMatmulSelector::SimpleVecMat => {\n                match launch_inner_fix_dtype::<R, SimpleVecMatAlgorithm>(\n                    client,\n                    FusedMatmulInputLaunch::new(\n                        inputs,\n                        config.clone(),\n                        self.matmul.lhs.clone(),\n                        self.matmul.rhs.clone(),\n                        None,\n                        self.matmul.out.clone(),\n                    ),\n                    outputs,\n                    problem,\n                    vector_sizes,\n                    &Default::default(),\n                ) {\n                    Ok(_) => Ok(()),\n                    Err(err) => Err(FusedMatmulError::LaunchError(err)),\n                }\n            }\n            FusedMatmulSelector::DoubleVecMat => {\n                match launch_inner_fix_dtype::<R, DoubleVecMatAlgorithm>(\n                    client,\n                    FusedMatmulInputLaunch::new(\n                        inputs,\n                        config.clone(),\n                        self.matmul.lhs.clone(),\n                        self.matmul.rhs.clone(),\n                        None,\n                        self.matmul.out.clone(),\n                    ),\n                    outputs,\n                    problem,\n                    vector_sizes,\n                    &Default::default(),\n                ) {\n                    Ok(_) => Ok(()),\n                    Err(err) => Err(FusedMatmulError::LaunchError(err)),\n                }\n            }\n        }\n    }\n}\n\nfn launch_inner_fix_dtype<R: Runtime, A: Routine<()>>(\n    client: &ComputeClient<R>,\n    input: FusedMatmulInputLaunch<R>,\n    output: GlobalArgsLaunch<R>,\n    problem: MatmulProblem,\n    vector_sizes: MatmulVectorSizes,\n    blueprint_strategy: &BlueprintStrategy<(), A>,\n) -> Result<(), MatmulSetupError> {\n    launch_kernel_virtual::<FusedMatmulArgs, R, A>(\n        client,\n        input,\n        output,\n        (),\n        problem,\n        vector_sizes,\n        blueprint_strategy,\n    )\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/matmul/tune.rs",
    "content": "use super::optimization::MatmulOptimizationTuneArg;\nuse crate::{\n    CubeFusionHandle,\n    engine::trace::TuneOutput,\n    optim::matmul::{AcceleratedTileKind, FusedMatmulSelector},\n    tune::{TuneContext, TuneInput},\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{\n    AutotuneKey, CubeTuneId, Runtime,\n    tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},\n};\nuse cubek::matmul::{\n    definition::MatmulKind,\n    launch::{MatmulAutotuneKey, MatmulGlobalScale, should_tune_double_buffering},\n};\nuse serde::{Deserialize, Serialize};\n\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\npub struct FusedMatmulAutotuneKey {\n    matmul_key: MatmulAutotuneKey,\n    #[autotune(anchor)]\n    num_out_buffers: usize,\n    #[autotune(anchor)]\n    num_ops: usize,\n}\n\n/// Executes autotune on matmul operations\npub fn fused_matmul_autotune<R: Runtime>(\n    optimization: MatmulOptimizationTuneArg<R>,\n    context: &mut Context<CubeFusionHandle<R>>,\n) {\n    static TUNER: LocalTuner<FusedMatmulAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 3;\n        const PRIORITY_HIGH: i8 = 2;\n        const PRIORITY_MEDIUM: i8 = 1;\n        const PRIORITY_MIN: i8 = 0;\n\n        let cmma = TuneGroup::<FusedMatmulAutotuneKey>::new(\"cmma\", |key| {\n            if matches!(\n                key.matmul_key.analysis.kind,\n                MatmulKind::General\n                // Those variants are just because the unit alternatives aren't very good yet.\n                | MatmulKind::VecMat | MatmulKind::MatVec\n            ) {\n                PRIORITY_MAX\n            } else {\n                PRIORITY_MEDIUM\n            }\n        });\n\n        let mma = TuneGroup::<FusedMatmulAutotuneKey>::new(\"mma\", |key| {\n            if matches!(\n                key.matmul_key.analysis.kind,\n                // General is usually bad, but I think shapes like 16x8196 would be classed as\n                // general and are very good with MMA\n                // Should highly degenerated matrices that aren't VecMat have their own class?\n                MatmulKind::General | MatmulKind::VecMat | MatmulKind::MatVec\n            ) {\n                PRIORITY_MAX\n            } else {\n                PRIORITY_MEDIUM\n            }\n        });\n\n        let odd = TuneGroup::<FusedMatmulAutotuneKey>::new(\"odd\", |key| {\n            if key.matmul_key.definition.lhs_pow2_factor == 0\n                || key.matmul_key.definition.rhs_pow2_factor == 0\n            {\n                PRIORITY_MAX\n            } else {\n                PRIORITY_MIN\n            }\n        });\n\n        let unit = TuneGroup::<FusedMatmulAutotuneKey>::new(\"unit\", |key| {\n            if !matches!(key.matmul_key.analysis.kind, MatmulKind::General)\n                || matches!(\n                    key.matmul_key.analysis.scale_global,\n                    MatmulGlobalScale::Small\n                )\n            {\n                PRIORITY_MAX\n            } else {\n                PRIORITY_MIN\n            }\n        });\n\n        fn double_buffering_priority(key: &FusedMatmulAutotuneKey, max: i8, min: i8) -> i8 {\n            if should_tune_double_buffering(key.num_out_buffers > 1, &key.matmul_key) {\n                max\n            } else {\n                min\n            }\n        }\n\n        let mut set = TunableSet::new(create_key::<R>, input_gen::<R>)\n            .with(Tunable::new(\"fused_matmul_fallback\", tune_fallback::<R>)); // First one should always work.\n\n        // Unit matmuls\n        for (selector, double_buf) in [\n            (FusedMatmulSelector::SimpleUnit, false),\n            (FusedMatmulSelector::DoubleUnit, true),\n            (FusedMatmulSelector::SimpleVecMat, false),\n            (FusedMatmulSelector::DoubleVecMat, true),\n        ] {\n            set = set.with(\n                Tunable::new(selector.name(), move |input| {\n                    tune_fused::<R>(input, selector)\n                })\n                .group(&unit, move |key| match double_buf {\n                    true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),\n                    false => PRIORITY_MAX,\n                }),\n            );\n        }\n\n        // Accelerated matmuls\n        for (tile_matmul, group) in [\n            (AcceleratedTileKind::Cmma, &cmma),\n            (AcceleratedTileKind::Mma, &mma),\n        ] {\n            for (selector, double_buf, extra_group) in [\n                (\n                    FusedMatmulSelector::Simple {\n                        multi_rows: false,\n                        tile_matmul,\n                    },\n                    false,\n                    None,\n                ),\n                (\n                    FusedMatmulSelector::Simple {\n                        multi_rows: true,\n                        tile_matmul,\n                    },\n                    false,\n                    None,\n                ),\n                (\n                    FusedMatmulSelector::OrderedDoubleBuffering { tile_matmul },\n                    true,\n                    None,\n                ),\n                (\n                    FusedMatmulSelector::DoubleBuffering {\n                        specialized: false,\n                        tile_matmul,\n                    },\n                    true,\n                    None,\n                ),\n                (\n                    FusedMatmulSelector::DoubleBuffering {\n                        specialized: true,\n                        tile_matmul,\n                    },\n                    true,\n                    Some(&odd),\n                ),\n            ] {\n                let mut tunable = Tunable::new(selector.name(), move |input| {\n                    tune_fused::<R>(input, selector)\n                })\n                .group(group, move |key| match double_buf {\n                    true => double_buffering_priority(key, PRIORITY_MAX, PRIORITY_HIGH),\n                    false => PRIORITY_MAX,\n                });\n                if let Some(group) = extra_group {\n                    tunable = tunable.group(group, |_| PRIORITY_MAX);\n                }\n                set = set.with(tunable);\n            }\n        }\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&optimization.info.client, &optimization.info.device),\n        &optimization.info.client.clone(),\n        tunables,\n        TuneInput::new(context, optimization),\n    );\n}\n\npub(crate) fn create_key<R: Runtime>(\n    input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,\n) -> FusedMatmulAutotuneKey {\n    let opt = input.optimization();\n    let context = match input.context() {\n        TuneContext::Original(context) => context,\n        TuneContext::Fork(_) => panic!(\"Not supported when generating key\"),\n    };\n\n    let lhs = context.tensors.get(&opt.info.matmul.op.lhs.id).unwrap();\n    let rhs = context.tensors.get(&opt.info.matmul.op.rhs.id).unwrap();\n    let out = context.tensors.get(&opt.info.matmul.op.out.id).unwrap();\n\n    let lhs_strides = context\n        .handles\n        .get_handle(&lhs.id, &burn_ir::TensorStatus::ReadOnly)\n        .strides\n        .clone();\n    let rhs_strides = context\n        .handles\n        .get_handle(&rhs.id, &burn_ir::TensorStatus::ReadOnly)\n        .strides\n        .clone();\n\n    let key = MatmulAutotuneKey::generate(\n        &opt.info.client,\n        &lhs.shape,\n        &rhs.shape,\n        &lhs_strides,\n        &rhs_strides,\n        lhs.dtype.into(),\n        rhs.dtype.into(),\n        out.dtype.into(),\n        opt.info.matmul.lhs.scheme(),\n        opt.info.matmul.rhs.scheme(),\n    );\n    FusedMatmulAutotuneKey::new(key, opt.info.num_output_buffers(), opt.info.num_ops_fused())\n}\n\nfn input_gen<R: Runtime>(\n    _key: &FusedMatmulAutotuneKey,\n    input: &TuneInput<R, MatmulOptimizationTuneArg<R>>,\n) -> TuneInput<R, MatmulOptimizationTuneArg<R>> {\n    input.clone()\n}\n\nfn tune_fused<R: Runtime>(\n    input: TuneInput<R, MatmulOptimizationTuneArg<R>>,\n    selector: FusedMatmulSelector,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n    let context = input.context();\n\n    match context {\n        TuneContext::Original(context) => match optimization.execute_fused(context, selector) {\n            Ok(out) => Ok(out),\n            Err(_) => {\n                return tune_fallback::<R>(input);\n            }\n        },\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fused(&mut context_owned.as_context(), selector)\n        }\n    }\n    .map_err(|e| format!(\"{e:?}\"))\n}\n\nfn tune_fallback<R: Runtime>(\n    input: TuneInput<R, MatmulOptimizationTuneArg<R>>,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n    let context = input.context();\n\n    Ok(match context {\n        TuneContext::Original(context) => optimization.execute_fallback(context),\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fallback(&mut context_owned.as_context())\n        }\n    })\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/mod.rs",
    "content": "pub mod elemwise;\npub mod matmul;\npub mod reduce;\npub mod reduce_broadcasted;\n\nmod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce/args.rs",
    "content": "use crate::engine::codegen::{\n    io::{ref_buffer_len, ref_len, ref_shape, ref_stride, ref_vector_size},\n    ir::{FuseArg, FuseBlockConfig, GlobalArgs, GlobalArgsExpand, LocalArgs, LocalArgsExpand},\n    kernel::{fuse_on_read, fuse_on_write, init_locals},\n};\nuse cubecl::prelude::*;\nuse cubek::reduce::components::args::{ReduceArgs, ReduceDType};\n\n#[derive(Clone)]\npub struct FusedReduceArgs;\n\n#[derive(CubeType, CubeLaunch)]\npub struct FusedReduceInput {\n    pub global: GlobalArgs,\n    #[cube(comptime)]\n    pub config: FuseBlockConfig,\n    #[cube(comptime)]\n    pub arg: FuseArg,\n}\n\n#[derive(CubeType, CubeLaunch)]\npub struct FusedReduceOutput {\n    pub global: GlobalArgs,\n    #[cube(comptime)]\n    pub config: FuseBlockConfig,\n    #[cube(comptime)]\n    pub arg: FuseArg,\n}\n\npub struct FusedReduceState {\n    inputs: *const GlobalArgs,\n    outputs: *mut GlobalArgs,\n    locals_on_read: *mut LocalArgs,\n    locals_on_write: *mut LocalArgs,\n    config_on_read: FuseBlockConfig,\n    config_on_write: FuseBlockConfig,\n    // TODO: Should be a list when multiple blocks are there.\n    input: FuseArg,\n    out: FuseArg,\n}\n\n#[derive(Clone)]\npub struct FusedReduceStateExpand {\n    inputs: GlobalArgsExpand,\n    outputs: GlobalArgsExpand,\n    locals_on_read: LocalArgsExpand,\n    locals_on_write: LocalArgsExpand,\n    config_on_read: FuseBlockConfig,\n    config_on_write: FuseBlockConfig,\n    input: FuseArg,\n    out: FuseArg,\n}\n\n#[cube]\nimpl ReduceArgs for FusedReduceArgs {\n    type Input<E: Numeric, S: Size> = FusedReduceInput;\n    type Output<E: Numeric, S: Size> = FusedReduceOutput;\n    type State<P: ReduceDType> = FusedReduceState;\n\n    fn init_state<P: ReduceDType>(\n        input: &Self::Input<P::In, P::SizeIn>,\n        output: &mut Self::Output<P::Out, P::SizeOut>,\n    ) -> Self::State<P> {\n        let mut locals_read = init_locals(&input.global, &mut output.global, &input.config);\n        let mut locals_write = init_locals(&input.global, &mut output.global, &output.config);\n        // TODO Add stuff from previous blocks to the local of each block.\n        FusedReduceState::new(input, output, &mut locals_read, &mut locals_write)\n    }\n\n    fn read_input<P: ReduceDType>(\n        state: &Self::State<P>,\n        index: usize,\n    ) -> Vector<P::In, P::SizeIn> {\n        let value = fuse_on_read::<P::In, P::SizeIn>(\n            unsafe { &(*state.inputs) },\n            unsafe { &mut (*state.outputs) },\n            unsafe { &mut (*state.locals_on_read) },\n            index,\n            comptime! {\n                let mut sequence = Sequence::new();\n                // TODO: Register local arguments from previous blocks.\n                sequence.push(state.input.clone());\n                sequence\n            },\n            &state.config_on_read,\n        )[0];\n        value\n    }\n\n    fn read_output<P: ReduceDType>(\n        _state: &Self::State<P>,\n        _index: usize,\n    ) -> Vector<P::Out, P::SizeOut> {\n        Vector::empty()\n    }\n\n    fn write_output<P: ReduceDType>(\n        state: &mut Self::State<P>,\n        index: usize,\n        value: Vector<P::Out, P::SizeOut>,\n    ) {\n        let mut values = Registry::<FuseArg, Vector<P::Out, P::SizeOut>>::new();\n        let mut args = comptime![Vec::<FuseArg>::new()];\n\n        values.insert(comptime![state.out.clone()], value);\n        comptime![args.push(state.out.clone())];\n        fuse_on_write(\n            unsafe { &(*state.inputs) },\n            unsafe { &mut (*state.outputs) },\n            unsafe { &mut (*state.locals_on_write) },\n            index,\n            values,\n            args,\n            &state.config_on_write,\n        );\n    }\n\n    fn len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        ref_len(\n            unsafe { &(*state.inputs) },\n            unsafe { &(*state.outputs) },\n            unsafe { &(*state.locals_on_read) },\n            &state.config_on_read,\n        )\n    }\n\n    fn len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        ref_len(\n            unsafe { &(*state.inputs) },\n            unsafe { &(*state.outputs) },\n            unsafe { &(*state.locals_on_write) },\n            &state.config_on_write,\n        )\n    }\n\n    fn buffer_len_input<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        ref_buffer_len(\n            unsafe { &(*state.inputs) },\n            unsafe { &(*state.outputs) },\n            unsafe { &(*state.locals_on_read) },\n            &state.config_on_read,\n        )\n    }\n\n    fn buffer_len_output<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        ref_buffer_len(\n            unsafe { &(*state.inputs) },\n            unsafe { &(*state.outputs) },\n            unsafe { &(*state.locals_on_write) },\n            &state.config_on_write,\n        )\n    }\n\n    fn rank_input<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        state.config_on_read.rank.runtime()\n    }\n\n    fn rank_output<P: ReduceDType>(state: &Self::State<P>) -> usize {\n        state.config_on_write.rank.runtime()\n    }\n\n    fn shape_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {\n        ref_shape(unsafe { &(*state.locals_on_read) }, dim)\n    }\n\n    fn shape_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {\n        ref_shape(unsafe { &(*state.locals_on_write) }, dim)\n    }\n\n    fn stride_input<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {\n        ref_stride(unsafe { &(*state.locals_on_read) }, dim)\n    }\n\n    fn stride_output<P: ReduceDType>(state: &Self::State<P>, dim: usize) -> usize {\n        ref_stride(unsafe { &(*state.locals_on_write) }, dim)\n    }\n\n    fn vector_size_input<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(VectorSize) {\n        ref_vector_size(unsafe { &(*state.locals_on_read) })\n    }\n\n    fn vector_size_output<P: ReduceDType>(state: &Self::State<P>) -> comptime_type!(VectorSize) {\n        ref_vector_size(unsafe { &(*state.locals_on_write) })\n    }\n}\n\n#[cube]\nimpl FusedReduceState {\n    pub fn new(\n        inputs: &FusedReduceInput,\n        outputs: &mut FusedReduceOutput,\n        locals_on_read: &mut LocalArgs,\n        locals_on_write: &mut LocalArgs,\n    ) -> FusedReduceState {\n        FusedReduceState {\n            inputs: &inputs.global,\n            outputs: &mut outputs.global,\n            locals_on_read,\n            locals_on_write,\n            config_on_read: comptime![inputs.config.clone()],\n            config_on_write: comptime![outputs.config.clone()],\n            input: comptime![inputs.arg.clone()],\n            out: comptime![outputs.arg.clone()],\n        }\n    }\n}\n\nimpl CubeType for FusedReduceState {\n    type ExpandType = FusedReduceStateExpand;\n}\n\nimpl IntoMut for FusedReduceStateExpand {\n    fn into_mut(self, _context: &mut Scope) -> Self {\n        self\n    }\n}\n\nimpl CubeDebug for FusedReduceStateExpand {}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce/fuser.rs",
    "content": "use super::{\n    ReduceSettings,\n    optimization::{FusedReduce, ReduceInstruction, ReduceOptimization},\n};\nuse crate::{\n    engine::{\n        codegen::ir::FuseType,\n        fuser::TraceOperationFuser,\n        settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},\n    },\n    optim::CubeOptimization,\n};\nuse burn_fusion::{FuserStatus, OperationFuser};\nuse burn_ir::{NumericOperationIr, OperationIr, ReduceDimOpIr};\nuse burn_std::Shape;\nuse cubecl::Runtime;\n\n/// Fuses element wise operations around a reduce operation.\npub struct ReduceFuser<R: Runtime> {\n    pub(crate) fuser: TraceOperationFuser,\n    pub(crate) fuser_read_fallback: TraceOperationFuser,\n    fuser_write_fallback: TraceOperationFuser,\n    settings_write: FuseSettings,\n    pub(crate) device: R::Device,\n    pub(crate) reduce: Option<FusedReduce>,\n    settings: ReduceSettings,\n}\n\nimpl<R: Runtime> Clone for ReduceFuser<R> {\n    fn clone(&self) -> Self {\n        Self {\n            fuser: self.fuser.clone(),\n            fuser_read_fallback: self.fuser_read_fallback.clone(),\n            fuser_write_fallback: self.fuser_write_fallback.clone(),\n            settings_write: self.settings_write,\n            device: self.device.clone(),\n            reduce: self.reduce.clone(),\n            settings: self.settings,\n        }\n    }\n}\n\n#[derive(Debug)]\npub enum ReduceFuserInfo {\n    FusedReduce { shape_input_id: Shape, axis: usize },\n    FusedElemwise { shape_id: Shape },\n}\n\nimpl<R: Runtime> ReduceFuser<R> {\n    pub fn new(device: R::Device, settings: ReduceSettings) -> Self {\n        let client = R::client(&device);\n        let props = client.properties();\n        let max_bindings = props.hardware.max_bindings;\n        let settings_read = FuseSettings {\n            // Inplace would work, but not when we have a concrete output to write too.\n            inplace: true,\n            ref_layout: RefLayoutSetting::OnlyContiguous,\n            broadcast: false,\n            output_shape_updates: true,\n            vectorization: VectorizationSetting::Activated,\n        };\n        let settings_write = FuseSettings {\n            inplace: false,\n            output_shape_updates: false,\n            vectorization: VectorizationSetting::SmallerOrEqualThanPreviousBlock { block_pos: 0 },\n            broadcast: false,\n            ref_layout: RefLayoutSetting::OnlyContiguous,\n        };\n        let settings_fallback = FuseSettings::default();\n\n        Self {\n            fuser: TraceOperationFuser::new(max_bindings, settings_read),\n            fuser_read_fallback: TraceOperationFuser::new(max_bindings, settings_fallback),\n            fuser_write_fallback: TraceOperationFuser::new(max_bindings, settings_fallback),\n            settings_write,\n            device,\n            reduce: None,\n            settings,\n        }\n    }\n\n    pub fn reduce_info(&self) -> ReduceFuserInfo {\n        match &self.reduce {\n            Some(reduce) => {\n                let shape_input_id = reduce.op.input.shape.clone();\n                let axis = reduce.axis;\n\n                ReduceFuserInfo::FusedReduce {\n                    shape_input_id,\n                    axis,\n                }\n            }\n            None => {\n                let shape_id = self.fuser_read_fallback.current_output_shape.clone();\n                ReduceFuserInfo::FusedElemwise { shape_id }\n            }\n        }\n    }\n    fn on_reduce(&mut self, op: &ReduceDimOpIr, inst: ReduceInstruction) {\n        // TODO: Fix: we need to have fuse-on-read with an identity block.\n        //\n        // if self.fuser.num_ops == 0 && false {\n        //     self.fuser.current_output_shape = op.input.shape.dims.clone();\n        // } else if self.fuser.current_output_shape != op.input.shape.dims {\n\n        if self.fuser.current_output_shape != op.input.shape {\n            self.fuser.close();\n            self.fuser_read_fallback.close();\n            return;\n        }\n\n        let [input] = self\n            .fuser\n            .next_block([&op.input], self.settings_write, false);\n\n        let output = self.fuser.output_unhandled(&op.out);\n        let axis = op.axis;\n\n        let fuse_on_write_activated = match self.settings {\n            ReduceSettings::Always => true,\n            // We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise\n            // vectorization is impossible. Only [VectorizationMode::Perpendicular] supports vectorization.\n            //\n            // We could still fuse some output operations, but it would probably lead to worse performance.\n            ReduceSettings::OnlyParallel => axis != op.input.shape.rank() - 1,\n            ReduceSettings::Never => false,\n        };\n\n        if !fuse_on_write_activated {\n            self.fuser.close();\n        }\n\n        let acc = match inst {\n            ReduceInstruction::Mean | ReduceInstruction::Prod | ReduceInstruction::Sum => {\n                match input.precision() {\n                    FuseType::F16 | FuseType::BF16 => FuseType::F32,\n                    FuseType::I16 | FuseType::I8 => FuseType::I32,\n                    FuseType::U16 | FuseType::U8 => FuseType::U32,\n                    _ => input.precision(),\n                }\n            }\n            _ => input.precision(),\n        };\n\n        self.reduce = Some(FusedReduce {\n            input,\n            output,\n            acc,\n            axis,\n            op: op.clone(),\n            use_planes: false,\n            shared: false,\n            inst,\n        });\n\n        self.fuser_read_fallback.close();\n    }\n\n    fn on_elemwise_read(&mut self, operation: &OperationIr) {\n        let can_register =\n            self.fuser.can_fuse(operation) && self.fuser_read_fallback.can_fuse(operation);\n\n        match can_register {\n            true => {\n                self.fuser.fuse(operation);\n                self.fuser_read_fallback.fuse(operation);\n            }\n            false => {\n                self.fuser.close();\n                self.fuser_read_fallback.close();\n            }\n        };\n    }\n\n    fn on_elemwise_write(&mut self, operation: &OperationIr) {\n        let can_register =\n            self.fuser.can_fuse(operation) && self.fuser_write_fallback.can_fuse(operation);\n\n        match can_register {\n            true => {\n                self.fuser.fuse(operation);\n                self.fuser_write_fallback.fuse(operation);\n            }\n            false => {\n                self.fuser.close();\n                self.fuser_write_fallback.close();\n            }\n        };\n    }\n}\n\nimpl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceFuser<R> {\n    fn fuse(&mut self, operation: &OperationIr) {\n        if let FuserStatus::Closed = self.fuser.status() {\n            return;\n        }\n\n        if self.reduce.is_none() {\n            if let OperationIr::NumericFloat(_, op) = operation {\n                match op {\n                    NumericOperationIr::SumDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Sum);\n                    }\n                    NumericOperationIr::MeanDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Mean);\n                    }\n                    NumericOperationIr::ProdDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Prod);\n                    }\n                    NumericOperationIr::ArgMax(op) => {\n                        self.on_reduce(op, ReduceInstruction::ArgMax);\n                    }\n                    NumericOperationIr::ArgMin(op) => {\n                        self.on_reduce(op, ReduceInstruction::ArgMin);\n                    }\n                    NumericOperationIr::MinDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Min);\n                    }\n                    NumericOperationIr::MaxDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Max);\n                    }\n                    NumericOperationIr::MaxAbsDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::MaxAbs);\n                    }\n                    _ => {\n                        self.on_elemwise_read(operation);\n                    }\n                };\n            } else if let OperationIr::NumericInt(_, op) = operation {\n                match op {\n                    NumericOperationIr::SumDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Sum);\n                    }\n                    NumericOperationIr::MeanDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Mean);\n                    }\n                    NumericOperationIr::ProdDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Prod);\n                    }\n                    NumericOperationIr::ArgMax(op) => {\n                        self.on_reduce(op, ReduceInstruction::ArgMax);\n                    }\n                    NumericOperationIr::ArgMin(op) => {\n                        self.on_reduce(op, ReduceInstruction::ArgMin);\n                    }\n                    NumericOperationIr::MinDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Min);\n                    }\n                    NumericOperationIr::MaxDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::Max);\n                    }\n                    NumericOperationIr::MaxAbsDim(op) => {\n                        self.on_reduce(op, ReduceInstruction::MaxAbs);\n                    }\n                    _ => {\n                        self.on_elemwise_read(operation);\n                    }\n                };\n            } else {\n                self.on_elemwise_read(operation);\n            }\n        } else {\n            self.on_elemwise_write(operation);\n        }\n    }\n\n    fn finish(&mut self) -> CubeOptimization<R> {\n        let client = R::client(&self.device);\n        let trace = self.fuser.finish();\n        let trace_read_fallback = self.fuser_read_fallback.finish();\n        let trace_write_fallback = self.fuser_write_fallback.finish();\n        let fuse_reduce = self.reduce.as_ref().unwrap();\n\n        let reduce = ReduceOptimization::new(\n            trace,\n            trace_read_fallback,\n            trace_write_fallback,\n            client,\n            self.device.clone(),\n            self.len(),\n            self.fuser_read_fallback.len(),\n            fuse_reduce.clone(),\n            self.settings,\n        );\n\n        CubeOptimization::Reduce(reduce)\n    }\n\n    fn reset(&mut self) {\n        self.fuser.reset();\n        self.fuser_read_fallback.reset();\n        self.fuser_write_fallback.reset();\n        self.reduce = None;\n    }\n\n    fn status(&self) -> burn_fusion::FuserStatus {\n        self.fuser.status()\n    }\n\n    fn properties(&self) -> burn_fusion::FuserProperties {\n        let mut properties = self.fuser.properties();\n        properties.ready = self.reduce.is_some();\n        properties\n    }\n\n    fn len(&self) -> usize {\n        self.fuser.len() + if self.reduce.is_some() { 1 } else { 0 }\n    }\n\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {\n        Box::new(self.clone())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce/mod.rs",
    "content": "mod fuser;\nmod optimization;\n\npub(crate) mod args;\npub(crate) mod tune;\n\npub use fuser::*;\npub use optimization::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs",
    "content": "use super::args::{\n    FusedReduceInput, FusedReduceInputLaunch, FusedReduceOutput, FusedReduceOutputLaunch,\n};\n#[cfg(feature = \"autotune\")]\nuse super::tune::fused_reduce_autotune;\nuse crate::{\n    CubeFusionHandle, FallbackOperation,\n    engine::{\n        codegen::ir::{\n            FuseArg, FuseBlockConfig, FuseType, GlobalArgsLaunch, RefLayout,\n            multi_block_variables_init,\n        },\n        launch::{\n            FuseTraceLauncher,\n            runner::{TraceRunner, Vectorization},\n        },\n        trace::{FuseTrace, TraceError, TuneOutput},\n    },\n    optim::{elemwise::ElemwiseRunner, reduce::args::FusedReduceArgs},\n};\nuse burn_fusion::stream::Context;\nuse burn_ir::ReduceDimOpIr;\nuse burn_std::DType;\nuse cubecl::{Runtime, client::ComputeClient, ir::StorageType, prelude::*};\nuse cubek::reduce::{\n    ReduceDtypes, ReduceError, VectorizationMode,\n    components::instructions::ReduceOperationConfig,\n    init_tensors,\n    launch::{RoutineStrategy, reduce_kernel_virtual},\n    routines::{\n        ReduceBlueprint, ReduceLaunchSettings, ReduceProblem, ReduceVectorSettings, Routine,\n        cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,\n    },\n};\nuse serde::{Deserialize, Serialize};\nuse std::sync::Arc;\n\n#[cfg(not(feature = \"autotune\"))]\nuse cubek::reduce::routines::{BlueprintStrategy, unit::UnitStrategy};\n\npub struct ReduceOptimization<R: Runtime> {\n    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,\n}\n\npub(crate) struct ReduceOptimizationInfo<R: Runtime> {\n    pub(crate) trace: FuseTrace,\n    trace_read_fallback: FuseTrace,\n    trace_write_fallback: FuseTrace,\n    pub(crate) client: ComputeClient<R>,\n    pub(crate) device: R::Device,\n    pub(crate) len: usize,\n    pub(crate) len_read: usize,\n    pub(crate) reduce: FusedReduce,\n    settings: ReduceSettings,\n}\n\nimpl<R: Runtime> ReduceOptimizationInfo<R> {\n    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {\n        let client = R::client(device);\n\n        Self {\n            trace: state.trace,\n            trace_read_fallback: state.trace_read_fallback,\n            trace_write_fallback: state.trace_write_fallback,\n            client,\n            device: device.clone(),\n            len: state.len,\n            len_read: state.len_read,\n            reduce: state.reduce,\n            settings: state.settings,\n        }\n    }\n    pub fn to_state(&self) -> ReduceOptimizationState {\n        ReduceOptimizationState {\n            trace: self.trace.clone(),\n            trace_read_fallback: self.trace_read_fallback.clone(),\n            trace_write_fallback: self.trace_write_fallback.clone(),\n            len: self.len,\n            len_read: self.len_read,\n            reduce: self.reduce.clone(),\n            settings: self.settings,\n        }\n    }\n}\n\n#[derive(Serialize, Deserialize, Copy, Clone)]\npub enum ReduceSettings {\n    Always,\n    /// We only activate fuse-on-write when the reduction isn't on the last dimension, otherwise\n    /// vectorization is impossible. Only [VectorizationMode::Perpendicular] supports vectorization.\n    ///\n    /// We could still fuse some output operations, but it would probably lead to worse performance.\n    OnlyParallel,\n    Never,\n}\n\npub(crate) struct ReduceOptimizationTuneArg<R: Runtime> {\n    pub(crate) info: Arc<ReduceOptimizationInfo<R>>,\n    pub(crate) fallback: Arc<Box<dyn FallbackOperation<R>>>,\n}\n\nimpl<R: Runtime> Clone for ReduceOptimizationTuneArg<R> {\n    fn clone(&self) -> Self {\n        Self {\n            info: self.info.clone(),\n            fallback: self.fallback.clone(),\n        }\n    }\n}\n\n#[derive(Clone, Copy, Serialize, Deserialize, Debug)]\npub enum ReduceInstruction {\n    ArgMax,\n    ArgMin,\n    Mean,\n    Prod,\n    Sum,\n    Max,\n    Min,\n    MaxAbs,\n}\n\npub trait ReduceFallbackFn<R: Runtime>: Send + Sync {\n    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);\n}\n\n#[derive(Serialize, Deserialize)]\npub struct ReduceOptimizationState {\n    pub(crate) trace: FuseTrace,\n    pub(crate) trace_read_fallback: FuseTrace,\n    pub(crate) trace_write_fallback: FuseTrace,\n    pub(crate) reduce: FusedReduce,\n    pub(crate) len: usize,\n    pub(crate) len_read: usize,\n    pub(crate) settings: ReduceSettings,\n}\n\nimpl core::fmt::Debug for ReduceOptimizationState {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_fmt(format_args!(\n            \"{{ len_read: {}, len_total: {} }}\",\n            self.len_read, self.len\n        ))\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct FusedReduce {\n    pub(crate) input: FuseArg,\n    pub(crate) output: FuseArg,\n    pub(crate) acc: FuseType,\n    pub(crate) axis: usize,\n    pub(crate) op: ReduceDimOpIr,\n    pub(crate) use_planes: bool,\n    pub(crate) shared: bool,\n    pub(crate) inst: ReduceInstruction,\n}\n\n#[derive(new)]\npub struct FusedReduceLaunch<'a> {\n    reduce: &'a FusedReduce,\n    strategy: RoutineStrategy,\n}\n\n#[derive(Debug)]\npub enum FusedReduceError {\n    Reduce(ReduceError),\n    InvalidSelection(Box<&'static str>),\n    InvalidInput,\n}\n\nimpl From<ReduceError> for FusedReduceError {\n    fn from(value: ReduceError) -> Self {\n        Self::Reduce(value)\n    }\n}\n\nimpl<R: Runtime> ReduceOptimizationTuneArg<R> {\n    pub fn execute_fused(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        strategy: RoutineStrategy,\n    ) -> Result<TuneOutput<R>, TraceError<FusedReduceError>> {\n        let launch = FusedReduceLaunch::new(&self.info.reduce, strategy);\n        let launcher = FuseTraceLauncher::new(&self.info.trace, &launch);\n        launcher.launch(&self.info.client, &self.info.device, context)\n    }\n\n    pub fn execute_fallback(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n    ) -> TuneOutput<R> {\n        let launcher = FuseTraceLauncher::new(&self.info.trace_read_fallback, &ElemwiseRunner);\n\n        #[allow(unused_mut)] // It is used when `autotune-checks` is activated.\n        let mut output_read = launcher\n            .launch(&self.info.client, &self.info.device, context)\n            .unwrap();\n\n        self.fallback.run(context);\n\n        #[cfg(feature = \"autotune-checks\")]\n        if let TuneOutput::Checked { handles } = &mut output_read {\n            let out_desc = context.tensors.get(&self.info.reduce.op.out.id).unwrap();\n            let handle_out = context\n                .handles\n                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);\n\n            handles.insert(\n                self.info.reduce.op.out.id,\n                (out_desc.shape.dims.clone(), handle_out.clone()),\n            );\n        }\n\n        let launcher = FuseTraceLauncher::new(&self.info.trace_write_fallback, &ElemwiseRunner);\n\n        let output_write = launcher\n            .launch(&self.info.client, &self.info.device, context)\n            .unwrap();\n\n        output_read.merge(output_write)\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\nimpl<R: Runtime> ReduceOptimization<R> {\n    pub fn new(\n        trace: FuseTrace,\n        trace_read_fallback: FuseTrace,\n        trace_write_fallback: FuseTrace,\n        client: ComputeClient<R>,\n        device: R::Device,\n        len: usize,\n        len_read: usize,\n        reduce: FusedReduce,\n        settings: ReduceSettings,\n    ) -> Self {\n        let info = ReduceOptimizationInfo {\n            trace,\n            trace_read_fallback,\n            trace_write_fallback,\n            client,\n            device,\n            len,\n            len_read,\n            reduce,\n            settings,\n        };\n\n        Self {\n            info: Arc::new(info),\n        }\n    }\n    /// Execute the optimization.\n    pub fn execute(\n        &mut self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,\n    ) {\n        // The index of the fallback reduce is the number of ops fused as read.\n        let fallback = fallback(self.info.len_read);\n        let arg = ReduceOptimizationTuneArg {\n            info: self.info.clone(),\n            fallback: Arc::new(fallback),\n        };\n\n        #[cfg(feature = \"autotune\")]\n        fused_reduce_autotune::<R>(arg, context);\n\n        #[cfg(not(feature = \"autotune\"))]\n        if arg\n            .execute_fused(\n                context,\n                RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),\n            )\n            .is_err()\n        {\n            arg.execute_fallback(context);\n        }\n    }\n\n    pub fn num_output_buffers(&self) -> usize {\n        self.info.trace_read_fallback.resources.outputs.len()\n    }\n\n    pub fn to_state(&self) -> ReduceOptimizationState {\n        ReduceOptimizationState {\n            trace: self.info.trace.clone(),\n            trace_read_fallback: self.info.trace_read_fallback.clone(),\n            trace_write_fallback: self.info.trace_write_fallback.clone(),\n            reduce: self.info.reduce.clone(),\n            len: self.info.len,\n            len_read: self.info.len_read,\n            settings: self.info.settings,\n        }\n    }\n\n    pub fn from_state(device: &R::Device, state: ReduceOptimizationState) -> Self {\n        let client = R::client(device);\n\n        let info = ReduceOptimizationInfo {\n            trace: state.trace,\n            trace_read_fallback: state.trace_read_fallback,\n            trace_write_fallback: state.trace_write_fallback,\n            reduce: state.reduce,\n            len: state.len,\n            len_read: state.len_read,\n            client,\n            device: device.clone(),\n            settings: state.settings,\n        };\n\n        Self {\n            info: Arc::new(info),\n        }\n    }\n\n    /// Returns the number of output buffers added by fusion.\n    pub fn num_ops_fused(&self) -> usize {\n        self.info.len\n    }\n}\n\n// TODO: Implement better vectorization here.\nimpl<R: Runtime> Vectorization<R> for FusedReduceLaunch<'_> {}\n\nimpl<R: Runtime> TraceRunner<R> for FusedReduceLaunch<'_> {\n    type Error = FusedReduceError;\n\n    fn run<'a>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        configs: &'a [FuseBlockConfig],\n    ) -> Result<(), FusedReduceError> {\n        let [config_read, config_write] = [&configs[0], &configs[1]];\n        let shape = match &config_read.ref_layout {\n            RefLayout::Concrete(FuseArg::Output(..)) => {\n                outputs.shape_ref(&config_read.ref_layout, config_read.rank)\n            }\n            _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank),\n        };\n        let reduce_count: usize = shape\n            .iter()\n            .enumerate()\n            .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s })\n            .product();\n\n        let vectorization_mode = match self.reduce.axis == config_read.rank - 1 {\n            true => VectorizationMode::Parallel,\n            false => VectorizationMode::Perpendicular,\n        };\n        let address_type = inputs\n            .required_address_type()\n            .max(outputs.required_address_type());\n\n        let settings = ReduceVectorSettings {\n            vectorization_mode,\n            vector_size_input: config_read.width,\n            vector_size_output: config_write.width,\n        };\n        let problem = ReduceProblem {\n            vector_size: shape[self.reduce.axis],\n            vector_count: reduce_count,\n            axis: self.reduce.axis,\n            dtypes: ReduceDtypes {\n                input: self.reduce.op.input.dtype.into(),\n                output: self.reduce.op.out.dtype.into(),\n                accumulation: self.reduce.acc.into_elem().into(),\n            },\n            address_type,\n        };\n\n        let (blueprint, settings) = match self.strategy.clone() {\n            RoutineStrategy::Unit(strategy) => {\n                let routine = UnitRoutine;\n                routine.prepare(client, problem, settings, strategy)?\n            }\n            RoutineStrategy::Plane(strategy) => {\n                let routine = PlaneRoutine;\n                routine.prepare(client, problem, settings, strategy)?\n            }\n            RoutineStrategy::Cube(strategy) => {\n                let routine = CubeRoutine;\n                routine.prepare(client, problem, settings, strategy)?\n            }\n        };\n\n        let kwargs = ReduceKwArgs {\n            client,\n            inputs,\n            outputs,\n            axis: self.reduce.axis,\n            config_fuse_read: config_read.clone(),\n            config_fuse_write: config_write.clone(),\n            input: self.reduce.input.clone(),\n            output: self.reduce.output.clone(),\n            blueprint,\n            settings,\n        };\n        let result = launch_reduce_mixed_precision(\n            kwargs,\n            self.reduce.inst,\n            self.reduce.op.input.dtype,\n            self.reduce.op.out.dtype,\n            DType::from(self.reduce.acc.into_elem()),\n        );\n\n        match result {\n            Ok(_) => Ok(()),\n            Err(err) => Err(FusedReduceError::Reduce(ReduceError::Launch(err))),\n        }\n    }\n}\n\nstruct ReduceKwArgs<'b, Run: Runtime> {\n    client: &'b ComputeClient<Run>,\n    inputs: GlobalArgsLaunch<Run>,\n    outputs: GlobalArgsLaunch<Run>,\n    axis: usize,\n    blueprint: ReduceBlueprint,\n    settings: ReduceLaunchSettings,\n    config_fuse_read: FuseBlockConfig,\n    config_fuse_write: FuseBlockConfig,\n    input: FuseArg,\n    output: FuseArg,\n}\n\nfn launch_reduce_mixed_precision<Run: Runtime>(\n    kwargs: ReduceKwArgs<'_, Run>,\n    instruction: ReduceInstruction,\n    dtype_input: DType,\n    dtype_output: DType,\n    dtype_acc: DType,\n) -> Result<(), LaunchError> {\n    let config = match instruction {\n        ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,\n        ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,\n        ReduceInstruction::Prod => ReduceOperationConfig::Prod,\n        ReduceInstruction::Mean => ReduceOperationConfig::Mean,\n        ReduceInstruction::Sum => ReduceOperationConfig::Sum,\n        ReduceInstruction::Max => ReduceOperationConfig::Max,\n        ReduceInstruction::Min => ReduceOperationConfig::Min,\n        ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,\n    };\n    launch_reduce::<Run>(kwargs, config, dtype_input, dtype_output, dtype_acc)\n}\n\nfn launch_reduce<Run: Runtime>(\n    kwargs: ReduceKwArgs<'_, Run>,\n    inst: ReduceOperationConfig,\n    dtype_input: DType,\n    dtype_output: DType,\n    dtype_acc: DType,\n) -> Result<(), LaunchError> {\n    unsafe {\n        reduce_kernel_fused::launch_unchecked::<Run>(\n            kwargs.client,\n            kwargs.settings.cube_count,\n            kwargs.settings.cube_dim,\n            kwargs.settings.address_type,\n            kwargs.config_fuse_read.width,\n            kwargs.config_fuse_write.width,\n            FusedReduceInputLaunch::new(kwargs.inputs, kwargs.config_fuse_read, kwargs.input),\n            FusedReduceOutputLaunch::new(kwargs.outputs, kwargs.config_fuse_write, kwargs.output),\n            kwargs.axis,\n            kwargs.blueprint,\n            inst,\n            dtype_input.into(),\n            dtype_output.into(),\n            dtype_acc.into(),\n        )\n    };\n\n    Ok(())\n}\n\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub fn reduce_kernel_fused<In: Numeric, SizeIn: Size, Out: Numeric, SizeOut: Size, Acc: Numeric>(\n    input: &FusedReduceInput,\n    output: &mut FusedReduceOutput,\n    axis_reduce: usize,\n    #[comptime] blueprint: ReduceBlueprint,\n    #[comptime] config: ReduceOperationConfig,\n    #[define(In)] _input_dtype: StorageType,\n    #[define(Out)] _output_dtype: StorageType,\n    #[define(Acc)] _acc_dtype: StorageType,\n) {\n    multi_block_variables_init(&input.config, &mut output.global.variables);\n    multi_block_variables_init(&output.config, &mut output.global.variables);\n\n    let (input, mut output) =\n        init_tensors::<FusedReduceArgs, In, SizeIn, Out, SizeOut>(input, output);\n\n    reduce_kernel_virtual::<In, SizeIn, Out, SizeOut, Acc>(\n        &input,\n        &mut output,\n        axis_reduce,\n        blueprint,\n        config,\n    );\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce/tune.rs",
    "content": "use super::optimization::ReduceOptimizationTuneArg;\nuse crate::{\n    CubeFusionHandle,\n    engine::trace::TuneOutput,\n    tune::{TuneContext, TuneInput},\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{\n    AutotuneKey, CubeTuneId, Runtime,\n    tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},\n};\nuse cubek::reduce::{\n    launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},\n    routines::{BlueprintStrategy, cube::CubeStrategy, plane::PlaneStrategy, unit::UnitStrategy},\n};\nuse serde::{Deserialize, Serialize};\n\n/// Autotune key for standard fused reduction operations.\n///\n/// Records metadata about the fusion graph (IO and ops) alongside\n/// the core reduction parameters to ensure stable kernel selection.\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\npub struct FusedReduceAutotuneKey {\n    reduce_key: ReduceAutotuneKey,\n    #[autotune(anchor)]\n    fuse_num_reads: usize,\n    #[autotune(anchor)]\n    fuse_num_writes: usize,\n    #[autotune(anchor)]\n    fuse_num_ops: usize,\n}\n\n/// Executes autotuning for fused reduction operations.\n///\n/// This tuner evaluates different hardware-specific strategies (Plane, Cube, Unit)\n/// and assigns priorities based on the `vector_count` of the reduction.\npub fn fused_reduce_autotune<R: Runtime>(\n    arg: ReduceOptimizationTuneArg<R>,\n    context: &mut Context<CubeFusionHandle<R>>,\n) {\n    static TUNER: LocalTuner<FusedReduceAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 2;\n        const PRIORITY_MIN: i8 = 1;\n\n        let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);\n        let group = TuneGroup::<FusedReduceAutotuneKey>::new(\"fused_reduce\", |_key| PRIORITY_MAX);\n\n        // Fallback implementation for robustness.\n        set = set.with(Tunable::new(\"fused_reduce_fallback\", tune_fallback::<R>));\n\n        // Define properties to categorize hardware strategies.\n        enum ReduceProps {\n            GreatWithLowReduceCount,\n            GreatWithHighReduceCount,\n            Balanced,\n        }\n\n        let strategies = [\n            (\n                \"fused_unit\",\n                RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),\n                ReduceProps::GreatWithHighReduceCount,\n            ),\n            (\n                \"fused_plane\",\n                RoutineStrategy::Plane(BlueprintStrategy::Inferred(PlaneStrategy {\n                    independent: true,\n                })),\n                ReduceProps::Balanced,\n            ),\n            (\n                \"fused_cube\",\n                RoutineStrategy::Cube(BlueprintStrategy::Inferred(CubeStrategy {\n                    // Two steps reduction doesn't work with fuse-on-write, we can't activate plane\n                    // when using the cube algo.\n                    use_planes: false,\n                })),\n                ReduceProps::GreatWithLowReduceCount,\n            ),\n        ];\n\n        for (name, strategy, props) in strategies {\n            let tunable = Tunable::new(name, move |input| tune_reduce::<R>(input, &strategy))\n                .group(&group, move |key| match props {\n                    ReduceProps::GreatWithLowReduceCount => {\n                        if key.reduce_key.vector_count < 128 {\n                            PRIORITY_MAX\n                        } else {\n                            PRIORITY_MIN\n                        }\n                    }\n                    ReduceProps::GreatWithHighReduceCount => {\n                        if key.reduce_key.vector_count > 64 {\n                            PRIORITY_MAX\n                        } else {\n                            PRIORITY_MIN\n                        }\n                    }\n                    ReduceProps::Balanced => PRIORITY_MAX,\n                });\n\n            set = set.with(tunable);\n        }\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&arg.info.client, &arg.info.device),\n        &arg.info.client.clone(),\n        tunables,\n        TuneInput::new(context, arg),\n    );\n}\n\n/// Creates the autotune key by extracting tensor metadata and fusion block statistics.\npub(crate) fn create_key<R: Runtime>(\n    input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,\n) -> FusedReduceAutotuneKey {\n    let opt = input.optimization();\n    let context = match input.context() {\n        TuneContext::Original(context) => context,\n        TuneContext::Fork(_) => panic!(\"Forked context not supported for key generation\"),\n    };\n\n    let input_tensor = context.tensors.get(&opt.info.reduce.op.input.id).unwrap();\n    let out_tensor = context.tensors.get(&opt.info.reduce.op.out.id).unwrap();\n    let acc = opt.info.reduce.acc.into_elem();\n\n    let key = ReduceAutotuneKey::generate(\n        input_tensor.dtype.into(),\n        out_tensor.dtype.into(),\n        acc,\n        &input_tensor.shape,\n        opt.info.reduce.axis == input_tensor.shape.rank() - 1,\n        opt.info.reduce.axis,\n    );\n\n    // Assume the fusion contains at least a read and a write block.\n    let read_block = &opt.info.trace.blocks[0];\n    let write_block = &opt.info.trace.blocks[1];\n\n    FusedReduceAutotuneKey::new(\n        key,\n        read_block.reads.len() + write_block.reads.len(),\n        read_block.writes.len() + write_block.writes.len(),\n        read_block.ops.len() + write_block.ops.len(),\n    )\n}\n\n/// Identity generator for tuning inputs.\nfn input_gen<R: Runtime>(\n    _key: &FusedReduceAutotuneKey,\n    input: &TuneInput<R, ReduceOptimizationTuneArg<R>>,\n) -> TuneInput<R, ReduceOptimizationTuneArg<R>> {\n    input.clone()\n}\n\n/// Executes a fused reduction optimization.\nfn tune_reduce<R: Runtime>(\n    input: TuneInput<R, ReduceOptimizationTuneArg<R>>,\n    strategy: &RoutineStrategy,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n\n    match input.context() {\n        TuneContext::Original(context) => optimization.execute_fused(context, strategy.clone()),\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fused(&mut context_owned.as_context(), strategy.clone())\n        }\n    }\n    .map_err(|e| format!(\"{e:?}\"))\n}\n\n/// Executes the fallback path for a reduction optimization.\nfn tune_fallback<R: Runtime>(\n    input: TuneInput<R, ReduceOptimizationTuneArg<R>>,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n\n    match input.context() {\n        TuneContext::Original(context) => optimization.execute_fallback(context),\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fallback(&mut context_owned.as_context())\n        }\n    };\n\n    Ok(TuneOutput::UnChecked(std::marker::PhantomData))\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/base.rs",
    "content": "use crate::optim::{\n    CubeOptimization,\n    reduce::{ReduceFuser, ReduceFuserInfo, ReduceSettings},\n    reduce_broadcasted::{\n        ReduceBroadcastedOptimization, ReduceBroadcastedOptimizationInfo,\n        fuser::{\n            block::{ReduceBlockFuser, ReduceBlockFusionAnalysis, ReduceBroadcastedStatus},\n            full::ReduceBroadcastedFullFuser,\n            full_analyzer::FullFuserAnalyzer,\n        },\n    },\n};\nuse burn_fusion::{FuserProperties, FuserStatus, OperationFuser};\nuse burn_ir::OperationIr;\nuse cubecl::Runtime;\nuse std::sync::Arc;\n\n/// Fuses element wise operations around a reduce operation.\npub struct ReduceBroadcastedFuser<R: Runtime> {\n    blocks: Vec<ReduceBlockFuser<R>>,\n    fuser_default: ReduceFuser<R>,\n    num_ops: usize,\n    state: ReduceBroadcastedStatus,\n    max_bindings: u32,\n}\n\nimpl<R: Runtime> Clone for ReduceBroadcastedFuser<R> {\n    fn clone(&self) -> Self {\n        Self {\n            blocks: self.blocks.clone(),\n            fuser_default: self.fuser_default.clone(),\n            num_ops: self.num_ops,\n            state: self.state.clone(),\n            max_bindings: self.max_bindings,\n        }\n    }\n}\n\nimpl<R: Runtime> ReduceBroadcastedFuser<R> {\n    pub fn new(device: R::Device) -> Self {\n        let fuser = ReduceFuser::new(device, ReduceSettings::Always);\n        let max_bindings = fuser.fuser.max_bindings;\n        let block = ReduceBlockFuser::new(fuser.clone());\n\n        Self {\n            blocks: vec![block],\n            fuser_default: fuser,\n            num_ops: 0,\n            state: ReduceBroadcastedStatus::Starting,\n            max_bindings,\n        }\n    }\n}\n\nimpl<R: Runtime> OperationFuser<CubeOptimization<R>> for ReduceBroadcastedFuser<R> {\n    fn fuse(&mut self, operation: &OperationIr) {\n        if matches!(\n            &self.state,\n            ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort\n        ) {\n            return;\n        }\n\n        let block = self.blocks.last_mut().unwrap();\n        let analyze = block.analyze(operation, &self.state, &self.fuser_default);\n\n        let info = match analyze {\n            ReduceBlockFusionAnalysis::Accept => {\n                block.fuse(operation);\n                self.num_ops += 1;\n                block.fuser.reduce_info()\n            }\n            ReduceBlockFusionAnalysis::Refuse => {\n                self.state = ReduceBroadcastedStatus::Closed;\n                return;\n            }\n            ReduceBlockFusionAnalysis::NewBlockRequired => {\n                let info = block.fuser.reduce_info();\n                let mut block = ReduceBlockFuser::new(self.fuser_default.clone());\n                block.fuse(operation);\n                self.num_ops += 1;\n                self.blocks.push(block);\n                info\n            }\n        };\n\n        match info {\n            ReduceFuserInfo::FusedReduce {\n                shape_input_id,\n                axis,\n            } => {\n                // Only support last axis for now.\n                if axis != shape_input_id.len() - 1 {\n                    self.state = ReduceBroadcastedStatus::Abort;\n                } else {\n                    self.state = ReduceBroadcastedStatus::Init {\n                        shape_id: shape_input_id,\n                        axis,\n                    };\n                }\n            }\n            ReduceFuserInfo::FusedElemwise { .. } => {}\n        }\n    }\n\n    fn finish(&mut self) -> CubeOptimization<R> {\n        let analyzer = FullFuserAnalyzer::new(&self.blocks);\n        let mut full = ReduceBroadcastedFullFuser::new(self.max_bindings, analyzer);\n        let mut num_ops = 0;\n        let fallbacks = self\n            .blocks\n            .iter_mut()\n            .map(|block| block.finish(&mut num_ops, &mut full))\n            .collect::<Vec<_>>();\n\n        let broadcasted = Arc::new(full.finish());\n        let info = Arc::new(ReduceBroadcastedOptimizationInfo {\n            fallbacks,\n            broadcasted,\n        });\n        CubeOptimization::ReduceBroadcasted(ReduceBroadcastedOptimization { info, num_ops })\n    }\n\n    fn reset(&mut self) {\n        let block = ReduceBlockFuser::new(self.fuser_default.clone());\n        self.blocks = vec![block];\n        self.num_ops = 0;\n        self.state = ReduceBroadcastedStatus::Starting;\n    }\n\n    fn status(&self) -> FuserStatus {\n        match self.state {\n            ReduceBroadcastedStatus::Closed | ReduceBroadcastedStatus::Abort => {\n                return FuserStatus::Closed;\n            }\n            _ => {}\n        };\n\n        let fuser = self.blocks.last().unwrap();\n        fuser.fuser.status()\n    }\n\n    fn properties(&self) -> FuserProperties {\n        let ready = match self.state {\n            ReduceBroadcastedStatus::Starting | ReduceBroadcastedStatus::Abort => false,\n            ReduceBroadcastedStatus::Closed => {\n                if self.blocks.len() == 1 {\n                    !self.blocks[0].is_elemwise()\n                } else {\n                    true\n                }\n            }\n            _ => true,\n        };\n        let mut props = FuserProperties { score: 0, ready };\n        for block in self.blocks.iter() {\n            let p = block.properties();\n            props.score += p.score;\n            props.ready = p.ready && props.ready;\n        }\n        props\n    }\n\n    fn len(&self) -> usize {\n        self.num_ops\n    }\n\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<CubeOptimization<R>>> {\n        Box::new(self.clone())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_ir::{\n        BaseOperationIr, BinaryOpIr, CreationOpIr, ReduceDimOpIr, TensorId, TensorIr, TensorStatus,\n    };\n    use burn_std::{DType, Shape};\n\n    use super::*;\n\n    type Run = cubecl::TestRuntime;\n\n    #[test]\n    fn reduce_broadcast_workflow_1() {\n        let device: <Run as Runtime>::Device = Default::default();\n        let mut fuser = ReduceBroadcastedFuser::<Run>::new(device);\n        let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite);\n        let (tensor2_out, tensor2) = tensor(1, &[1, 0], TensorStatus::ReadWrite);\n\n        fuser.fuse(&OperationIr::BaseFloat(BaseOperationIr::Ones(\n            CreationOpIr { out: tensor1_out },\n        )));\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {\n                input: tensor1,\n                out: tensor2_out,\n                axis: 1,\n            }),\n        ));\n\n        let status = fuser.status();\n        assert_eq!(2, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        // An existing tensor\n        let (_tensor3_out, tensor3) = tensor(2, &[1, 0], TensorStatus::ReadWrite);\n        // A new tensor\n        let (tensor4_out, tensor4) = tensor(3, &[1, 0], TensorStatus::ReadWrite);\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::Add(BinaryOpIr {\n                lhs: tensor2,\n                rhs: tensor3,\n                out: tensor4_out,\n            }),\n        ));\n\n        let status = fuser.status();\n        assert_eq!(3, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        // An existing tensor\n        let (_tensor5_out, tensor5) = tensor(4, &[1, 2], TensorStatus::ReadWrite);\n        // A new tensor\n        let (tensor6_out, tensor6) = tensor(5, &[1, 2], TensorStatus::ReadWrite);\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::Add(BinaryOpIr {\n                lhs: tensor4,\n                rhs: tensor5,\n                out: tensor6_out,\n            }),\n        ));\n\n        let status = fuser.status();\n        assert_eq!(4, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        let (tensor7_out, _tensor7) = tensor(6, &[1, 0], TensorStatus::ReadWrite);\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {\n                input: tensor6,\n                out: tensor7_out,\n                axis: 1,\n            }),\n        ));\n        assert_eq!(5, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        let _optimization = fuser.finish();\n    }\n\n    #[test]\n    fn reduce_broadcast_workflow_2() {\n        let device: <Run as Runtime>::Device = Default::default();\n        let mut fuser = ReduceBroadcastedFuser::<Run>::new(device);\n        let (tensor1_out, tensor1) = tensor(0, &[1, 2], TensorStatus::ReadWrite);\n        // An existing tensor\n        let (_tensor2_out, mut tensor2) = tensor(2, &[1, 2], TensorStatus::ReadOnly);\n        let (tensor3_out, tensor3) = tensor(3, &[1, 2], TensorStatus::ReadWrite);\n\n        // First reduce output\n        let (tensor4_out, tensor4) = tensor(1, &[1, 0], TensorStatus::ReadWrite);\n\n        fuser.fuse(&OperationIr::BaseFloat(BaseOperationIr::Ones(\n            CreationOpIr { out: tensor1_out },\n        )));\n\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::Add(BinaryOpIr {\n                lhs: tensor1,\n                rhs: tensor2.clone(),\n                out: tensor3_out,\n            }),\n        ));\n\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::SumDim(ReduceDimOpIr {\n                input: tensor3,\n                out: tensor4_out,\n                axis: 1,\n            }),\n        ));\n\n        let status = fuser.status();\n        assert_eq!(3, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        // A new tensor\n        let (tensor5_out, _tensor5) = tensor(5, &[1, 2], TensorStatus::ReadWrite);\n        // Last time we use tensor2.\n        tensor2.status = TensorStatus::ReadWrite;\n        fuser.fuse(&OperationIr::NumericFloat(\n            DType::F32,\n            burn_ir::NumericOperationIr::Add(BinaryOpIr {\n                lhs: tensor4,\n                rhs: tensor2,\n                out: tensor5_out,\n            }),\n        ));\n\n        let status = fuser.status();\n        assert_eq!(4, fuser.len());\n        assert_eq!(status, FuserStatus::Open);\n        assert!(fuser.properties().ready,);\n\n        let _optimization = fuser.finish();\n    }\n\n    fn tensor(id: u64, shape: &[usize], status: TensorStatus) -> (TensorIr, TensorIr) {\n        let tensor = TensorIr {\n            id: TensorId::new(id),\n            shape: Shape::from(shape),\n            status: TensorStatus::NotInit,\n            dtype: DType::F32,\n        };\n        let mut tensor_init = tensor.clone();\n        tensor_init.status = status;\n\n        (tensor, tensor_init)\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/block.rs",
    "content": "use crate::optim::{\n    CubeOptimization,\n    elemwise::ElemwiseOptimization,\n    reduce::{FusedReduce, ReduceFuser, ReduceFuserInfo},\n    reduce_broadcasted::{ReduceBlockOptimInfo, fuser::full::ReduceBroadcastedFullFuser},\n};\nuse burn_fusion::{FuserProperties, OperationFuser};\nuse burn_ir::OperationIr;\nuse burn_std::Shape;\nuse cubecl::Runtime;\nuse std::sync::Arc;\n\n/// Responsible for fusing a single reduce block or elementwise block.\n///\n/// When the block kind is reduce, it supports fuse-on-read and fuse-on-write fusion.\n/// Broadcasting isn't supported; another block should handle it instead.\npub struct ReduceBlockFuser<R: Runtime> {\n    /// We use [ReduceFuser] for both elementwise and reduce blocks, keeping only the\n    /// fuse-on-read trace if the block is tagged as elementwise.\n    ///\n    /// # Notes\n    ///\n    /// A single elementwise block can only exist at the end of a full [ReduceBlockFuser],\n    /// otherwise the optimization will be included in the reduce fusion block.\n    pub fuser: ReduceFuser<R>,\n    pub(crate) ops: Vec<OperationIr>,\n    pub(crate) kind: ReduceBlockKind,\n}\n\n/// The current state of the fusion process.\n#[derive(Debug, Clone)]\npub enum ReduceBroadcastedStatus {\n    /// Fusion is starting; no reduction has been fused yet.\n    Starting,\n    /// Fusion is initialized with at least one reduce operation.\n    ///\n    /// # Notes\n    ///\n    /// Subsequent reduce operations must be compatible with the previous reduction to fuse.\n    Init { shape_id: Shape, axis: usize },\n    /// No more operations can be fused.\n    Closed,\n    /// Invalid axis.\n    Abort,\n}\n\n/// The [ReduceBlockFuser] capacity to accept an [OperationIr].\n#[derive(Clone, Copy, Debug)]\npub enum ReduceBlockFusionAnalysis {\n    /// The operation can be fused; call [ReduceBlockFuser::fuse()].\n    Accept,\n    /// The operation cannot be fused; the optimization should close.\n    Refuse,\n    /// The operation can be fused, but requires a new block.\n    NewBlockRequired,\n}\n\nimpl<R: Runtime> ReduceBlockFuser<R> {\n    /// Creates a new block.\n    pub fn new(fuser: ReduceFuser<R>) -> Self {\n        Self {\n            fuser: fuser.clone(),\n            ops: Vec::new(),\n            kind: ReduceBlockKind::Elemwise,\n        }\n    }\n\n    /// Returns true if this is an elementwise fuser.\n    pub fn is_elemwise(&self) -> bool {\n        matches!(self.kind, ReduceBlockKind::Elemwise)\n    }\n\n    /// Analyzes if fusion is possible within this block.\n    pub fn analyze(\n        &self,\n        op: &OperationIr,\n        status: &ReduceBroadcastedStatus,\n        default_node: &ReduceFuser<R>,\n    ) -> ReduceBlockFusionAnalysis {\n        let mut fuser_try = self.fuser.clone();\n        let before = fuser_try.len();\n        fuser_try.fuse(op);\n        let after = fuser_try.len();\n\n        if after > before {\n            return ReduceBlockFusionAnalysis::Accept;\n        }\n\n        // Can't create a new block if the previous one was not a reduction.\n        if self.fuser.reduce.is_none() {\n            return ReduceBlockFusionAnalysis::Refuse;\n        }\n\n        let mut fuser_try = default_node.clone();\n        let before = fuser_try.len();\n        fuser_try.fuse(op);\n        let after = fuser_try.len();\n\n        if after > before {\n            let info = fuser_try.reduce_info();\n\n            return match (info, status) {\n                (\n                    ReduceFuserInfo::FusedReduce {\n                        shape_input_id,\n                        axis,\n                    },\n                    ReduceBroadcastedStatus::Init {\n                        shape_id,\n                        axis: axis_init,\n                    },\n                ) => {\n                    if shape_id == &shape_input_id && axis_init == &axis {\n                        ReduceBlockFusionAnalysis::NewBlockRequired\n                    } else {\n                        ReduceBlockFusionAnalysis::Refuse\n                    }\n                }\n                (\n                    ReduceFuserInfo::FusedElemwise { shape_id },\n                    ReduceBroadcastedStatus::Init {\n                        shape_id: shape_init,\n                        ..\n                    },\n                ) => {\n                    if &shape_id == shape_init {\n                        ReduceBlockFusionAnalysis::NewBlockRequired\n                    } else {\n                        ReduceBlockFusionAnalysis::Refuse\n                    }\n                }\n                _ => ReduceBlockFusionAnalysis::Refuse,\n            };\n        }\n\n        ReduceBlockFusionAnalysis::Refuse\n    }\n\n    /// Fuses an operation within this block.\n    ///\n    /// # Warning\n    ///\n    /// Ensure [Self::analyze()] is called before this function to confirm the operation is accepted.\n    pub fn fuse(&mut self, op: &OperationIr) {\n        self.fuser.fuse(op);\n        self.ops.push(op.clone());\n\n        // Update the kind if a reduction is introduced to an elementwise block.\n        if let (Some(reduce), ReduceBlockKind::Elemwise) = (&self.fuser.reduce, &self.kind) {\n            self.kind = ReduceBlockKind::Reduce {\n                ops_index: self.ops.len() - 1,\n                reduce: Box::new(reduce.clone()),\n            };\n        }\n    }\n\n    /// Computes the fuser properties.\n    pub fn properties(&self) -> FuserProperties {\n        let mut properties = self.fuser.properties();\n        if let ReduceBlockKind::Elemwise = &self.kind {\n            // Elementwise traces are always ready to run.\n            properties.ready = true;\n        }\n        properties\n    }\n\n    pub fn finish(\n        &mut self,\n        num_ops: &mut usize,\n        full: &mut ReduceBroadcastedFullFuser,\n    ) -> ReduceBlockOptimInfo<R> {\n        full.register(self);\n\n        match &self.kind {\n            ReduceBlockKind::Elemwise => {\n                let len = self.fuser.fuser_read_fallback.len();\n                let device = self.fuser.device.clone();\n                *num_ops += len;\n                let trace = self.fuser.fuser_read_fallback.finish();\n                let client = R::client(&device);\n                let elementwise = ElemwiseOptimization::new(trace, client, device, len);\n                ReduceBlockOptimInfo::Elemwise(Arc::new(elementwise))\n            }\n            ReduceBlockKind::Reduce { .. } => {\n                *num_ops += self.fuser.len();\n                let optim = self.fuser.finish();\n                let info = match optim {\n                    CubeOptimization::Reduce(optim) => optim.info,\n                    _ => unreachable!(\"Expected Reduce optimization\"),\n                };\n                ReduceBlockOptimInfo::Reduce(info)\n            }\n        }\n    }\n}\n\n#[derive(Clone, Debug)]\npub enum ReduceBlockKind {\n    Elemwise,\n    Reduce {\n        ops_index: usize,\n        reduce: Box<FusedReduce>,\n    },\n}\n\nimpl<R: Runtime> Clone for ReduceBlockFuser<R> {\n    fn clone(&self) -> Self {\n        Self {\n            fuser: self.fuser.clone(),\n            ops: self.ops.clone(),\n            kind: self.kind.clone(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/full.rs",
    "content": "use crate::{\n    engine::{\n        fuser::TraceOperationFuser,\n        settings::{FuseSettings, RefLayoutSetting, VectorizationSetting},\n    },\n    optim::{\n        reduce::{FusedReduce, ReduceInstruction},\n        reduce_broadcasted::{\n            ReduceBroadcastedInfo,\n            fuser::{\n                block::{ReduceBlockFuser, ReduceBlockKind},\n                full_analyzer::FullFuserAnalyzer,\n            },\n            launch::ReduceBroadcastedFuseBlock,\n        },\n    },\n};\nuse burn_fusion::OperationFuser;\nuse cubecl::Runtime;\nuse cubek::reduce::components::instructions::ReduceOperationConfig;\n\n/// Responsible for fusing a single trace for all operations involved in this optimization.\npub struct ReduceBroadcastedFullFuser {\n    pub(crate) fuser: TraceOperationFuser,\n    analyzer: FullFuserAnalyzer,\n    blocks: Vec<ReduceBlockKind>,\n    settings_read: FuseSettings,\n    settings_write: FuseSettings,\n}\n\nimpl ReduceBroadcastedFullFuser {\n    /// Creates a new fuser with the given settings.\n    pub fn new(max_bindings: u32, analyzer: FullFuserAnalyzer) -> Self {\n        let settings_read = FuseSettings {\n            output_shape_updates: true,\n            broadcast: true,\n            inplace: false,\n            ref_layout: RefLayoutSetting::OnlyContiguous,\n            vectorization: VectorizationSetting::Activated,\n        };\n        let settings_write = FuseSettings {\n            output_shape_updates: false,\n            inplace: false,\n            broadcast: false,\n            ref_layout: RefLayoutSetting::OnlyContiguous,\n            // Deactivated for now, but would be cool to support vectorization of the output.\n            vectorization: VectorizationSetting::Deactivated,\n        };\n        let fuser = TraceOperationFuser::new(max_bindings, settings_read);\n\n        Self {\n            fuser,\n            blocks: Vec::new(),\n            settings_write,\n            settings_read,\n            analyzer,\n        }\n    }\n\n    /// Finishes fusing all blocks.\n    pub fn finish(mut self) -> ReduceBroadcastedInfo {\n        let mut reduce_axis = 0;\n        let mut blocks = Vec::new();\n\n        for block in self.blocks.iter() {\n            match block {\n                ReduceBlockKind::Elemwise => {}\n                ReduceBlockKind::Reduce { reduce, .. } => {\n                    let config = match reduce.inst {\n                        ReduceInstruction::ArgMax => ReduceOperationConfig::ArgMax,\n                        ReduceInstruction::ArgMin => ReduceOperationConfig::ArgMin,\n                        ReduceInstruction::Prod => ReduceOperationConfig::Prod,\n                        ReduceInstruction::Mean => ReduceOperationConfig::Mean,\n                        ReduceInstruction::Sum => ReduceOperationConfig::Sum,\n                        ReduceInstruction::Max => ReduceOperationConfig::Max,\n                        ReduceInstruction::Min => ReduceOperationConfig::Min,\n                        ReduceInstruction::MaxAbs => ReduceOperationConfig::MaxAbs,\n                    };\n\n                    let block = ReduceBroadcastedFuseBlock {\n                        op: config,\n                        input: reduce.input.clone(),\n                        output: reduce.output.clone(),\n                    };\n                    reduce_axis = reduce.axis;\n                    blocks.push(block);\n                }\n            }\n        }\n\n        let trace = self.fuser.finish();\n\n        ReduceBroadcastedInfo {\n            blocks,\n            trace,\n            reduce_axis,\n        }\n    }\n\n    /// Registers a [ReduceBlockFuser] to build the trace.\n    pub fn register<R: Runtime>(&mut self, block: &ReduceBlockFuser<R>) {\n        // Helper to close previous blocks if necessary\n        if !self.fuser.is_empty() {\n            let mut settings = self.settings_read;\n            settings.vectorization = VectorizationSetting::EqualThanPreviousBlock { block_pos: 0 };\n            settings.ref_layout = RefLayoutSetting::SameAsBlock { block_pos: 0 };\n            self.fuser.next_block([], settings, false);\n\n            let analysis = self.analyzer.retrieve_next();\n\n            for (tensor, block_pos) in analysis.inputs {\n                self.fuser.block_local_input(&tensor, block_pos, false);\n            }\n        }\n\n        match &block.kind {\n            ReduceBlockKind::Elemwise => {\n                for op in &block.ops {\n                    self.fuser.fuse(op);\n                }\n                self.blocks.push(ReduceBlockKind::Elemwise);\n            }\n            ReduceBlockKind::Reduce { ops_index, reduce } => {\n                for op in &block.ops[0..*ops_index] {\n                    self.fuser.fuse(op);\n                }\n\n                let [input] = self\n                    .fuser\n                    .next_block([&reduce.op.input], self.settings_write, false);\n\n                let output = self.fuser.output_unhandled(&reduce.op.out);\n                let analysis = self.analyzer.retrieve_next();\n\n                // Can be broadcasted so the generated buffer can be global.\n                for (tensor, block_pos) in analysis.inputs {\n                    self.fuser.block_local_input(&tensor, block_pos, false);\n                }\n\n                let fused_reduce = FusedReduce {\n                    input,\n                    output,\n                    acc: reduce.acc,\n                    axis: reduce.axis,\n                    op: reduce.op.clone(),\n                    use_planes: reduce.use_planes,\n                    shared: reduce.shared,\n                    inst: reduce.inst,\n                };\n\n                self.blocks.push(ReduceBlockKind::Reduce {\n                    ops_index: *ops_index,\n                    reduce: Box::new(fused_reduce),\n                });\n\n                for op in &block.ops[*ops_index + 1..block.ops.len()] {\n                    self.fuser.fuse(op);\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/full_analyzer.rs",
    "content": "use super::block::ReduceBlockKind;\nuse crate::optim::reduce_broadcasted::fuser::block::ReduceBlockFuser;\nuse burn_ir::{TensorId, TensorIr};\nuse cubecl::Runtime;\nuse std::collections::BTreeMap;\n\n#[derive(Debug)]\npub struct FullFuserAnalyzer {\n    // We need to know the block id of which we can reuse the read local input.\n    analyses: Vec<Vec<(TensorIr, usize)>>,\n}\n\nimpl FullFuserAnalyzer {\n    pub fn new<R: Runtime>(blocks: &[ReduceBlockFuser<R>]) -> Self {\n        let mut state = AnalysisState::default();\n\n        for block in blocks.iter() {\n            for (pos, op) in block.ops.iter().enumerate() {\n                let potential_from_previous_blocks = op.inputs();\n                let potential_to_next_blocks = op.outputs();\n\n                match &block.kind {\n                    ReduceBlockKind::Elemwise => {\n                        state.register(\n                            potential_from_previous_blocks,\n                            potential_to_next_blocks,\n                            BlockKind::Full,\n                        );\n                    }\n                    ReduceBlockKind::Reduce { ops_index, .. } => {\n                        if pos < *ops_index {\n                            state.register(\n                                potential_from_previous_blocks,\n                                potential_to_next_blocks,\n                                BlockKind::Full,\n                            );\n                        } else if pos > *ops_index {\n                            state.register(\n                                potential_from_previous_blocks,\n                                potential_to_next_blocks,\n                                BlockKind::Single,\n                            );\n                        } else {\n                            state.next_block();\n                        }\n                    }\n                }\n            }\n            state.next_block();\n        }\n\n        // First one is never called.\n        state.analyses.remove(0);\n\n        Self {\n            analyses: state.analyses,\n        }\n    }\n\n    pub fn retrieve_next(&mut self) -> FullFuserAnalysis {\n        let inputs = self.analyses.remove(0);\n        FullFuserAnalysis { inputs }\n    }\n}\n\n#[derive(Debug)]\npub struct FullFuserAnalysis {\n    /// The tensor received from a previous block.\n    pub inputs: Vec<(TensorIr, usize)>,\n}\n\n#[derive(Default)]\nstruct AnalysisState {\n    /// That pool contains tensors that are available in the fuse-on-write part of a reduce, not\n    /// broadcasted.\n    available_from_previous_single: BTreeMap<TensorId, usize>,\n    /// That pool contains tensors that are available in the fuse-on-read of a reduce and the\n    /// element-wise broadcasted part\n    available_from_previous_full: BTreeMap<TensorId, usize>,\n    block_data: Vec<(TensorIr, usize)>,\n    analyses: Vec<Vec<(TensorIr, usize)>>,\n    current_full: Vec<TensorIr>,\n    current_single: Vec<TensorIr>,\n}\n\nenum BlockKind {\n    Full,\n    Single,\n}\n\nimpl AnalysisState {\n    fn next_block(&mut self) {\n        let block_pos = self.analyses.len();\n        let data = core::mem::take(&mut self.block_data);\n        self.analyses.push(data);\n\n        // Makes the current tensor reads available for the next block.\n        for p in self.current_single.drain(..) {\n            // We need to keep the earliest block position.\n            self.available_from_previous_single\n                .entry(p.id)\n                .or_insert(block_pos);\n        }\n        for p in self.current_full.drain(..) {\n            // We need to keep the earliest block position.\n            self.available_from_previous_full\n                .entry(p.id)\n                .or_insert(block_pos);\n        }\n    }\n\n    fn register<'a>(\n        &mut self,\n        potential_from_previous_blocks: impl Iterator<Item = &'a TensorIr>,\n        potential_to_next_blocks: impl Iterator<Item = &'a TensorIr>,\n        kind: BlockKind,\n    ) {\n        match kind {\n            BlockKind::Full => {\n                for potential in potential_from_previous_blocks {\n                    // We can't since it's not in the same scope.\n                    //\n                    // TODO: Find a way to merge multiple reduce loops.\n                    //\n                    // if let Some(block_pos) = self.available_from_previous_full.get(&potential.id) {\n                    //     self.block_data.push((potential.clone(), *block_pos));\n                    // }\n\n                    // We can since it's a broadcast.\n                    if let Some(block_pos) = self.available_from_previous_single.get(&potential.id)\n                    {\n                        self.block_data.push((potential.clone(), *block_pos));\n                    }\n\n                    // Can reuse the read.\n                    self.current_full.push(potential.clone());\n                }\n\n                for p in potential_to_next_blocks {\n                    self.current_full.push(p.clone());\n                }\n            }\n            BlockKind::Single => {\n                for potential in potential_from_previous_blocks {\n                    if let Some(block_pos) = self.available_from_previous_single.get(&potential.id)\n                    {\n                        self.block_data.push((potential.clone(), *block_pos));\n                    }\n                    // Can reuse the read.\n                    self.current_single.push(potential.clone());\n                }\n\n                for p in potential_to_next_blocks {\n                    self.current_single.push(p.clone());\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/fuser/mod.rs",
    "content": "mod base;\nmod block;\nmod full;\nmod full_analyzer;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/launch.rs",
    "content": "use crate::{\n    engine::{\n        codegen::ir::{FuseArg, FuseBlockConfig, GlobalArgsLaunch, RefLayout},\n        launch::runner::{TraceRunner, Vectorization},\n    },\n    optim::reduce_broadcasted::unit::{\n        ElemwiseFuseBlockLaunch, ReduceFuseBlockLaunch, reduce_kernel_broadcasted,\n    },\n};\nuse cubecl::{\n    Runtime,\n    ir::{ElemType, FloatKind, StorageType},\n    prelude::*,\n    server::LaunchError,\n};\nuse cubek::reduce::{\n    ReduceDtypes, VectorizationMode,\n    components::instructions::ReduceOperationConfig,\n    launch::RoutineStrategy,\n    routines::{\n        BlueprintStrategy, GlobalReduceBlueprint, ReduceProblem, ReduceVectorSettings, Routine,\n        unit::{UnitRoutine, UnitStrategy},\n    },\n};\nuse serde::{Deserialize, Serialize};\n\n#[derive(Serialize, Deserialize, Debug, Clone)]\npub struct ReduceBroadcastedFuseBlock {\n    pub(crate) op: ReduceOperationConfig,\n    pub(crate) input: FuseArg,\n    pub(crate) output: FuseArg,\n}\n\n#[derive(new)]\npub struct FusedReduceBroadcastedLaunch<'a> {\n    blocks: &'a Vec<ReduceBroadcastedFuseBlock>,\n    reduce_axis: usize,\n    // TODO: Support multiple strategies.\n    _strategy: RoutineStrategy,\n}\n\nimpl<R: Runtime> Vectorization<R> for FusedReduceBroadcastedLaunch<'_> {}\n\nimpl<R: Runtime> TraceRunner<R> for FusedReduceBroadcastedLaunch<'_> {\n    type Error = LaunchError;\n\n    fn run<'a>(\n        &'a self,\n        client: &'a ComputeClient<R>,\n        inputs: GlobalArgsLaunch<R>,\n        outputs: GlobalArgsLaunch<R>,\n        configs: &'a [FuseBlockConfig],\n    ) -> Result<(), Self::Error> {\n        let routine = UnitRoutine;\n        let first_config = &configs[0];\n\n        let shape = match &first_config.ref_layout {\n            RefLayout::Concrete(FuseArg::Output(..)) => {\n                outputs.shape_ref(&first_config.ref_layout, first_config.rank)\n            }\n            _ => inputs.shape_ref(&first_config.ref_layout, first_config.rank),\n        };\n\n        let vector_size = shape[self.reduce_axis];\n        let vector_count = shape.iter().product::<usize>() / vector_size;\n        let address_type = inputs\n            .required_address_type()\n            .max(outputs.required_address_type());\n\n        let (blueprint, settings) = routine\n            .prepare::<R>(\n                client,\n                ReduceProblem {\n                    vector_size,\n                    vector_count,\n                    axis: self.reduce_axis,\n                    dtypes: ReduceDtypes {\n                        input: StorageType::Scalar(ElemType::Float(FloatKind::F32)),\n                        output: StorageType::Scalar(ElemType::Float(FloatKind::F32)),\n                        accumulation: StorageType::Scalar(ElemType::Float(FloatKind::F32)),\n                    },\n                    address_type,\n                },\n                ReduceVectorSettings {\n                    vectorization_mode: VectorizationMode::Parallel,\n                    vector_size_input: first_config.width,\n                    vector_size_output: 1,\n                },\n                BlueprintStrategy::Inferred(UnitStrategy),\n            )\n            .unwrap();\n\n        assert_eq!(blueprint.vectorization_mode, VectorizationMode::Parallel);\n\n        let mut blocks = SequenceArg::new();\n        let mut index = 0;\n\n        for block in self.blocks {\n            let arg = ReduceFuseBlockLaunch::new(\n                block.op,\n                configs[index].clone(),\n                configs[index + 1].clone(),\n                block.input.clone(),\n                block.output.clone(),\n                match blueprint.global {\n                    GlobalReduceBlueprint::Unit(bpt) => bpt,\n                    _ => panic!(),\n                },\n            );\n            index += 2;\n            blocks.push(arg);\n        }\n\n        let block_end = match configs.len() > index {\n            true => ComptimeOptionArgs::Some(ElemwiseFuseBlockLaunch::new(\n                configs.last().cloned().unwrap(),\n            )),\n            false => ComptimeOptionArgs::None,\n        };\n\n        // TODO: Ensure parallel is selected.\n\n        unsafe {\n            reduce_kernel_broadcasted::launch_unchecked::<R>(\n                client,\n                settings.cube_count,\n                settings.cube_dim,\n                settings.address_type,\n                inputs,\n                outputs,\n                self.reduce_axis,\n                blocks,\n                block_end,\n            );\n        }\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/mod.rs",
    "content": "mod fuser;\nmod optimization;\n\npub(crate) mod launch;\npub(crate) mod tune;\npub(crate) mod unit;\n\npub use fuser::*;\npub use optimization::*;\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/optimization.rs",
    "content": "#[cfg(feature = \"autotune\")]\nuse crate::optim::reduce::tune::fused_reduce_autotune;\nuse crate::{\n    CubeFusionHandle, FallbackOperation,\n    engine::{\n        launch::FuseTraceLauncher,\n        trace::{FuseTrace, TraceError, TuneOutput},\n    },\n    optim::{\n        elemwise::{ElemwiseOptimization, ElemwiseOptimizationState},\n        reduce::{ReduceOptimizationInfo, ReduceOptimizationState, ReduceOptimizationTuneArg},\n        reduce_broadcasted::{\n            launch::{FusedReduceBroadcastedLaunch, ReduceBroadcastedFuseBlock},\n            tune::fused_broadcasted_reduce_autotune,\n        },\n    },\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{Runtime, prelude::*};\nuse cubek::reduce::launch::RoutineStrategy;\nuse serde::{Deserialize, Serialize};\nuse std::sync::Arc;\n\npub struct ReduceBroadcastedOptimization<R: Runtime> {\n    pub(crate) info: Arc<ReduceBroadcastedOptimizationInfo<R>>,\n    pub(crate) num_ops: usize,\n}\n\npub(crate) struct ReduceBroadcastedOptimizationInfo<R: Runtime> {\n    pub(crate) fallbacks: Vec<ReduceBlockOptimInfo<R>>,\n    pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,\n}\n\n#[derive(Serialize, Deserialize, Debug, Clone)]\npub(crate) struct ReduceBroadcastedInfo {\n    pub(crate) blocks: Vec<ReduceBroadcastedFuseBlock>,\n    pub(crate) trace: FuseTrace,\n    pub(crate) reduce_axis: usize,\n}\n\npub(crate) enum ReduceBlockOptimInfo<R: Runtime> {\n    Reduce(Arc<ReduceOptimizationInfo<R>>),\n    Elemwise(Arc<ElemwiseOptimization<R>>),\n}\n\nimpl<R: Runtime> ReduceBlockOptimInfo<R> {\n    pub fn from_state(device: &R::Device, state: ReduceBlockState) -> Self {\n        match state {\n            ReduceBlockState::Reduce(state) => {\n                Self::Reduce(Arc::new(ReduceOptimizationInfo::from_state(device, state)))\n            }\n            ReduceBlockState::Elemwise(state) => {\n                Self::Elemwise(Arc::new(ElemwiseOptimization::from_state(device, state)))\n            }\n        }\n    }\n    pub fn to_state(&self) -> ReduceBlockState {\n        match self {\n            Self::Reduce(info) => ReduceBlockState::Reduce(info.to_state()),\n            Self::Elemwise(info) => ReduceBlockState::Elemwise(info.to_state()),\n        }\n    }\n}\n\npub(crate) struct ReduceBroadcastedOptimizationTuneArg<R: Runtime> {\n    pub(crate) fallbacks: Vec<ReduceBlockOptimArg<R>>,\n    pub(crate) broadcasted: Arc<ReduceBroadcastedInfo>,\n    pub(crate) client: ComputeClient<R>,\n    pub(crate) device: R::Device,\n}\n\npub(crate) enum ReduceBlockOptimArg<R: Runtime> {\n    Reduce(ReduceOptimizationTuneArg<R>),\n    Elemwise(Arc<ElemwiseOptimization<R>>),\n}\n\nimpl<R: Runtime> ReduceBlockOptimArg<R> {\n    pub fn execute_fallback(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n    ) -> Option<TuneOutput<R>> {\n        match self {\n            ReduceBlockOptimArg::Reduce(reduce) => {\n                #[cfg(feature = \"autotune\")]\n                {\n                    fused_reduce_autotune::<R>(reduce.clone(), context);\n                    None\n                }\n                #[cfg(not(feature = \"autotune\"))]\n                Some(reduce.execute_fallback(context))\n            }\n            ReduceBlockOptimArg::Elemwise(elem) => {\n                elem.execute(context);\n                None\n            }\n        }\n    }\n}\n\n#[derive(Serialize, Deserialize, Debug)]\npub struct ReduceBroadcastedOptimizationState {\n    fallbacks: Vec<ReduceBlockState>,\n    broadcasted: ReduceBroadcastedInfo,\n    num_ops: usize,\n}\n\n#[derive(Serialize, Deserialize, Debug)]\n#[allow(clippy::large_enum_variant)] // Only for serialization.\npub enum ReduceBlockState {\n    Reduce(ReduceOptimizationState),\n    Elemwise(ElemwiseOptimizationState),\n}\n\nimpl<R: Runtime> ReduceBroadcastedOptimizationTuneArg<R> {\n    pub fn execute_fused(\n        &self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        strategy: RoutineStrategy,\n    ) -> Result<TuneOutput<R>, TraceError<String>> {\n        let launch = FusedReduceBroadcastedLaunch::new(\n            &self.broadcasted.blocks,\n            self.broadcasted.reduce_axis,\n            strategy,\n        );\n        let launcher = FuseTraceLauncher::new(&self.broadcasted.trace, &launch);\n\n        launcher\n            .launch(&self.client, &self.device, context)\n            .map_err(|err| TraceError::RunnerError(format!(\"{:?}\", err)))\n    }\n\n    pub fn execute_fallback(&self, context: &mut Context<'_, CubeFusionHandle<R>>) {\n        for fallback in self.fallbacks.iter() {\n            fallback.execute_fallback(context);\n        }\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\nimpl<R: Runtime> ReduceBroadcastedOptimization<R> {\n    /// Execute the optimization.\n    pub fn execute(\n        &mut self,\n        context: &mut Context<'_, CubeFusionHandle<R>>,\n        fallback: impl Fn(usize) -> Box<dyn FallbackOperation<R>>,\n    ) {\n        let mut current_index = 0;\n        let mut client = None;\n        let mut device = None;\n\n        let fallbacks = self\n            .info\n            .fallbacks\n            .iter()\n            .map(|info| {\n                match info {\n                    ReduceBlockOptimInfo::Reduce(info) => {\n                        // The index of the fallback reduce is the number of ops fused as read.\n                        let fallback = fallback(current_index + info.len_read);\n                        client = Some(info.client.clone());\n                        device = Some(info.device.clone());\n                        let arg = ReduceOptimizationTuneArg {\n                            info: info.clone(),\n                            fallback: Arc::new(fallback),\n                        };\n                        current_index += info.len;\n                        ReduceBlockOptimArg::Reduce(arg)\n                    }\n                    ReduceBlockOptimInfo::Elemwise(op) => ReduceBlockOptimArg::Elemwise(op.clone()),\n                }\n            })\n            .collect();\n\n        let arg = ReduceBroadcastedOptimizationTuneArg {\n            fallbacks,\n            client: client.unwrap(),\n            device: device.unwrap(),\n            broadcasted: self.info.broadcasted.clone(),\n        };\n\n        #[cfg(feature = \"autotune\")]\n        fused_broadcasted_reduce_autotune::<R>(arg, context);\n\n        #[cfg(not(feature = \"autotune\"))]\n        arg.execute_fallback(context);\n    }\n\n    pub fn to_state(&self) -> ReduceBroadcastedOptimizationState {\n        ReduceBroadcastedOptimizationState {\n            fallbacks: self\n                .info\n                .fallbacks\n                .iter()\n                .map(|info| info.to_state())\n                .collect(),\n            broadcasted: self.info.broadcasted.as_ref().clone(),\n            num_ops: self.num_ops,\n        }\n    }\n\n    pub fn from_state(device: &R::Device, state: ReduceBroadcastedOptimizationState) -> Self {\n        Self {\n            info: Arc::new(ReduceBroadcastedOptimizationInfo {\n                fallbacks: state\n                    .fallbacks\n                    .into_iter()\n                    .map(|state| ReduceBlockOptimInfo::from_state(device, state))\n                    .collect(),\n                broadcasted: Arc::new(state.broadcasted),\n            }),\n            num_ops: state.num_ops,\n        }\n    }\n\n    /// Returns the number of output buffers added by fusion.\n    pub fn num_ops_fused(&self) -> usize {\n        self.num_ops\n    }\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/tune.rs",
    "content": "use super::optimization::ReduceBroadcastedOptimizationTuneArg;\nuse crate::{\n    CubeFusionHandle,\n    engine::trace::TuneOutput,\n    optim::{reduce::ReduceOptimizationInfo, reduce_broadcasted::ReduceBlockOptimArg},\n    tune::{TuneContext, TuneInput},\n};\nuse burn_fusion::stream::Context;\nuse cubecl::{\n    AutotuneKey, CubeTuneId, Runtime,\n    tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner},\n};\nuse cubek::reduce::{\n    launch::{RoutineStrategy, tune_key::ReduceAutotuneKey},\n    routines::{BlueprintStrategy, unit::UnitStrategy},\n};\nuse serde::{Deserialize, Serialize};\n\n/// Autotune key for fused broadcasted reduction operations.\n///\n/// Captures the characteristics of the fusion (reads, writes, ops) to ensure\n/// the best kernel is selected for specific fused graph shapes.\n#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]\npub struct FusedBroadcastedReduceAutotuneKey {\n    reduce_key: ReduceAutotuneKey,\n    #[autotune(anchor)]\n    fuse_num_reads: usize,\n    #[autotune(anchor)]\n    fuse_num_writes: usize,\n    #[autotune(anchor)]\n    fuse_num_ops: usize,\n    fuse_num_blocks: usize,\n}\n\n/// Executes the autotuning process for fused reduction operations.\n///\n/// This function initializes a local tuner and attempts multiple strategies\n/// (fallback vs. unit strategy) to find the most efficient execution path.\npub fn fused_broadcasted_reduce_autotune<R: Runtime>(\n    arg: ReduceBroadcastedOptimizationTuneArg<R>,\n    context: &mut Context<CubeFusionHandle<R>>,\n) {\n    static TUNER: LocalTuner<FusedBroadcastedReduceAutotuneKey, CubeTuneId> = local_tuner!();\n\n    let tunables = TUNER.init(|| {\n        const PRIORITY_MAX: i8 = 2;\n        let mut set = TunableSet::new(create_key::<R>, input_gen::<R>);\n\n        let group = TuneGroup::<FusedBroadcastedReduceAutotuneKey>::new(\n            \"fused_reduce_broadcasted\",\n            |_key| PRIORITY_MAX,\n        );\n\n        // Standard fallback implementation - guaranteed to work.\n        set = set.with(Tunable::new(\n            \"fused_reduce_broadcasted_fallback\",\n            tune_fallback::<R>,\n        ));\n\n        // Specialized unit strategy for fused reductions.\n        set = set.with(\n            Tunable::new(\"fused_reduce_broadcasted_unit\", move |input| {\n                tune_reduce::<R>(\n                    input,\n                    &RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),\n                )\n            })\n            .group(&group, |_| PRIORITY_MAX),\n        );\n\n        set\n    });\n\n    TUNER.execute(\n        &CubeTuneId::new(&arg.client, &arg.device),\n        &arg.client.clone(),\n        tunables,\n        TuneInput::new(context, arg),\n    );\n}\n\n/// Generates the autotune key based on the current optimization context and trace blocks.\npub(crate) fn create_key<R: Runtime>(\n    input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,\n) -> FusedBroadcastedReduceAutotuneKey {\n    let opt = input.optimization();\n    let context = match input.context() {\n        TuneContext::Original(context) => context,\n        TuneContext::Fork(_) => unreachable!(\"Forked context not supported for key generation\"),\n    };\n\n    // The fusion must start with a reduction block to be valid here.\n    let info = match &opt.fallbacks[0] {\n        ReduceBlockOptimArg::Reduce(reduce) => &reduce.info,\n        ReduceBlockOptimArg::Elemwise(_) => {\n            unreachable!(\"Fusion must start with a reduction block\")\n        }\n    };\n\n    let key = generate_reduce_autotune_key(info, context);\n\n    // Sum up complexity metrics across all blocks in the fused trace.\n    let (mut num_reads, mut num_writes, mut num_ops) = (0, 0, 0);\n\n    for block in opt.broadcasted.trace.blocks.iter() {\n        num_reads += block.reads.len();\n        num_writes += block.writes.len();\n        num_ops += block.ops.len();\n    }\n\n    FusedBroadcastedReduceAutotuneKey::new(\n        key,\n        num_reads,\n        num_writes,\n        num_ops,\n        info.trace.blocks.len(),\n    )\n}\n\n/// Helper to generate the base reduction key (shapes, types, axes).\nfn generate_reduce_autotune_key<R: Runtime>(\n    info: &ReduceOptimizationInfo<R>,\n    context: &Context<CubeFusionHandle<R>>,\n) -> ReduceAutotuneKey {\n    let input = context.tensors.get(&info.reduce.op.input.id).unwrap();\n    let out = context.tensors.get(&info.reduce.op.out.id).unwrap();\n    let acc = info.reduce.acc.into_elem();\n\n    ReduceAutotuneKey::generate(\n        input.dtype.into(),\n        out.dtype.into(),\n        acc,\n        &input.shape,\n        info.reduce.axis == input.shape.rank() - 1, // Is it the last dimension?\n        info.reduce.axis,\n    )\n}\n\n/// Simple input generator that clones the input for the tuner.\nfn input_gen<R: Runtime>(\n    _key: &FusedBroadcastedReduceAutotuneKey,\n    input: &TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,\n) -> TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>> {\n    input.clone()\n}\n\n/// Executes a fused reduction using a specific routine strategy.\nfn tune_reduce<R: Runtime>(\n    input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,\n    strategy: &RoutineStrategy,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n\n    match input.context() {\n        TuneContext::Original(context) => optimization.execute_fused(context, strategy.clone()),\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fused(&mut context_owned.as_context(), strategy.clone())\n        }\n    }\n    .map_err(|e| format!(\"{e:?}\"))\n}\n\n/// Executes the fallback implementation for the reduction.\nfn tune_fallback<R: Runtime>(\n    input: TuneInput<R, ReduceBroadcastedOptimizationTuneArg<R>>,\n) -> Result<TuneOutput<R>, String> {\n    let optimization = input.optimization();\n\n    match input.context() {\n        TuneContext::Original(context) => optimization.execute_fallback(context),\n        TuneContext::Fork(mut context_owned) => {\n            optimization.execute_fallback(&mut context_owned.as_context())\n        }\n    };\n\n    // Fallback is often used as a baseline, returning unchecked output.\n    Ok(TuneOutput::UnChecked(std::marker::PhantomData))\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/optim/reduce_broadcasted/unit.rs",
    "content": "use crate::{\n    engine::codegen::{\n        ir::{FuseArg, FuseBlockConfig, FuseType, GlobalArgs, multi_block_variables_init},\n        kernel::{fuse_on_write, init_locals},\n    },\n    optim::reduce::args::{FusedReduceArgs, FusedReduceInput, FusedReduceOutput},\n};\nuse cubecl::{Runtime, define_size, prelude::*, std::tensor::r#virtual::VirtualTensor};\nuse cubek::reduce::{\n    ReduceInstruction, ReducePrecision, VectorizationMode,\n    components::{\n        args::NumericLine,\n        global::unit::GlobalFullUnitReduce,\n        instructions::{ReduceOperation, ReduceOperationConfig},\n    },\n    init_tensors,\n    routines::UnitReduceBlueprint,\n};\n\n/// A configuration block for a reduction operation within a fused kernel.\n///\n/// This struct holds all the compile-time information needed to perform a\n/// reduction, including the operation type (Sum, Max, etc.) and the layout\n/// configuration for both input and output.\n#[derive(CubeType, CubeLaunch, Clone)]\npub struct ReduceFuseBlock {\n    #[cube(comptime)]\n    op: ReduceOperationConfig,\n    #[cube(comptime)]\n    config_input: FuseBlockConfig,\n    #[cube(comptime)]\n    config_output: FuseBlockConfig,\n    #[cube(comptime)]\n    input: FuseArg,\n    #[cube(comptime)]\n    output: FuseArg,\n    #[cube(comptime)]\n    blueprint: UnitReduceBlueprint,\n}\n\n/// A configuration block for an elementwise operation that follows a reduction.\n#[derive(CubeType, CubeLaunch, Clone)]\npub struct ElemwiseFuseBlock {\n    #[cube(comptime)]\n    config: FuseBlockConfig,\n}\n\n/// The entry point for a broadcasted reduction kernel.\n///\n/// This kernel initializes local variables for multiple reduction blocks and then\n/// executes the reduction sequence.\n///\n/// # Arguments\n///\n/// * `inputs` - Global arguments containing input tensor handles.\n/// * `outputs` - Global arguments containing output tensor handles.\n/// * `reduce_axis` - The dimension along which the reduction is performed.\n/// * `blocks` - A sequence of reduction operations to execute.\n/// * `block_end` - An optional elementwise block to execute after reductions are complete.\n#[cube(launch_unchecked, address_type = \"dynamic\")]\npub fn reduce_kernel_broadcasted(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    reduce_axis: usize,\n    blocks: Sequence<ReduceFuseBlock>,\n    block_end: ComptimeOption<ElemwiseFuseBlock>,\n) {\n    #[unroll]\n    for i in 0..blocks.len() {\n        let block = blocks.index(i);\n        multi_block_variables_init(&block.config_input, &mut outputs.variables);\n        multi_block_variables_init(&block.config_output, &mut outputs.variables);\n    }\n\n    reduce_many(inputs, outputs, reduce_axis, blocks, block_end);\n}\n\ndefine_scalar!(In);\ndefine_scalar!(Acc);\ndefine_scalar!(Out);\n\ndefine_size!(InSize);\ndefine_size!(OutSize);\n\n/// Configures the precision polyfills for the reduction based on the block's `FuseType`.\n#[cube]\nfn set_polyfill_block(block: &ReduceFuseBlock) {\n    let input_precision = comptime!(block.input.precision());\n    let output_precision = comptime!(block.output.precision());\n    let acc_precision = comptime!(match input_precision {\n        FuseType::F64 => FuseType::F64,\n        FuseType::F32 => FuseType::F32,\n        FuseType::Flex32 => FuseType::F32,\n        FuseType::F16 => FuseType::F32,\n        FuseType::BF16 => FuseType::F32,\n        FuseType::I64 => FuseType::I64,\n        FuseType::I32 => FuseType::I32,\n        FuseType::I16 => FuseType::I32,\n        FuseType::I8 => FuseType::I32,\n        FuseType::U64 => FuseType::U64,\n        FuseType::U32 => FuseType::U32,\n        FuseType::U16 => FuseType::U32,\n        FuseType::U8 => FuseType::U32,\n    });\n\n    set_polyfill::<In, InSize>(comptime!(\n        input_precision.into_type(block.config_input.width)\n    ));\n    set_polyfill::<Out, OutSize>(comptime!(\n        output_precision.into_type(block.config_output.width)\n    ));\n    set_polyfill::<Acc, InSize>(comptime!(acc_precision.into_type(block.config_input.width)));\n}\n\n/// Internal logic for executing a sequence of reduction blocks followed by an optional\n/// trailing elementwise block.\n#[cube]\n#[allow(clippy::clone_on_copy)]\nfn reduce_many(\n    inputs: &GlobalArgs,\n    outputs: &mut GlobalArgs,\n    reduce_axis: usize,\n    blocks: Sequence<ReduceFuseBlock>,\n    block_end: ComptimeOption<ElemwiseFuseBlock>,\n) {\n    let mut axis_size = 0;\n\n    #[unroll]\n    for i in 0..blocks.len() {\n        let block = blocks.index(i);\n        let input = FusedReduceInput {\n            global: inputs.clone(),\n            config: comptime!(block.config_input.clone()),\n            arg: comptime!(block.input.clone()),\n        };\n        let global = outputs.clone();\n        let config = comptime!(block.config_output.clone());\n        let arg = comptime!(block.output.clone());\n        let mut output = FusedReduceOutput {\n            global,\n            config,\n            arg,\n        };\n\n        set_polyfill_block(block);\n        let (input, mut output) =\n            init_tensors::<FusedReduceArgs, In, InSize, Out, OutSize>(&input, &mut output);\n\n        axis_size = reduce_step::<(In, InSize, Acc), (Out, OutSize), ReduceOperation>(\n            &input,\n            &mut output,\n            reduce_axis,\n            block.op,\n            comptime!(block.blueprint.clone()),\n        );\n    }\n\n    #[comptime]\n    if let ComptimeOption::Some(block) = block_end {\n        let global_index = ABSOLUTE_POS;\n        let width = block.config.width;\n        let num_iter = axis_size / width;\n        let size!(N) = width;\n\n        for i in 0..num_iter {\n            // Register block local inputs.\n            let values = Registry::<FuseArg, Vector<f32, N>>::new();\n            let args = comptime![Vec::<FuseArg>::new()];\n            let index = global_index * num_iter + i;\n            let mut locals = init_locals(inputs, outputs, &block.config);\n\n            fuse_on_write::<f32, N>(\n                inputs,\n                outputs,\n                &mut locals,\n                index,\n                values,\n                args,\n                &block.config.clone(),\n            )\n        }\n    }\n}\n\n#[cube]\n/// Executes a single reduction step using a specified instruction and blueprint.\n///\n/// Returns the size of the axis that was reduced.\nfn reduce_step<P: ReducePrecision, Out: NumericLine, I: ReduceInstruction<P>>(\n    input: &VirtualTensor<P::EI, P::SI>,\n    output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,\n    reduce_axis: usize,\n    #[comptime] config: I::Config,\n    #[comptime] blueprint: UnitReduceBlueprint,\n) -> usize {\n    let inst = I::from_config(config);\n    let axis_size = input.shape(reduce_axis);\n\n    GlobalFullUnitReduce::execute::<P, Out, I>(\n        input,\n        output,\n        reduce_axis,\n        &inst,\n        VectorizationMode::Parallel,\n        comptime!(blueprint),\n    );\n    axis_size\n}\n"
  },
  {
    "path": "crates/burn-cubecl-fusion/src/tune.rs",
    "content": "use crate::CubeFusionHandle;\nuse burn_fusion::stream::{Context, ContextOwned};\nuse cubecl::Runtime;\nuse std::sync::Arc;\n\n/// Fusion context used when tuning kernels.\n///\n/// Either the original context is returned or a fork of the original.\n/// The fork is only given when performing autotuning, and not when actually performing the\n/// operation.\npub enum TuneContext<'a, R: Runtime> {\n    Original(&'a mut Context<'a, CubeFusionHandle<R>>),\n    Fork(Box<ContextOwned<CubeFusionHandle<R>>>),\n}\n\n/// Fusion input wrapper containing the context and the optimization.\n///\n/// # Safety\n///\n/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions\n/// are made based on its behavior.\npub struct TuneInput<R: Runtime, O> {\n    context: UnsafeTuneContext<R>,\n    optimization: Arc<O>,\n}\n\n/// Unsafe wrapper around the context.\n///\n/// # Safety\n///\n/// The wrapper removes the context lifetime.\n///\n/// For it to be correct, the context must not be used after the invocation of the\n/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are\n/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find\n/// the best kernel to use, which can be async.\nenum UnsafeTuneContext<R: Runtime> {\n    Original(*mut Context<'static, CubeFusionHandle<R>>),\n    Fork(Box<ContextOwned<CubeFusionHandle<R>>>),\n}\n\nunsafe impl<R: Runtime> Send for UnsafeTuneContext<R> {}\nunsafe impl<R: Runtime, O> Send for TuneInput<R, O> {}\n\nimpl<R: Runtime, O> TuneInput<R, O> {\n    /// Create a new autotune input from the [context](Context) and an optimization.\n    pub fn new(context: &mut Context<CubeFusionHandle<R>>, optimization: O) -> Self {\n        let context = UnsafeTuneContext::new(context);\n\n        Self {\n            context,\n            optimization: Arc::new(optimization),\n        }\n    }\n\n    /// Retrieve the [autotune context](TuneContext) for the current input.\n    pub fn context(&self) -> TuneContext<'static, R> {\n        self.context.get()\n    }\n\n    /// Retrieve the optimization for the current input.\n    pub fn optimization(&self) -> &O {\n        &self.optimization\n    }\n}\n\nimpl<R: Runtime> UnsafeTuneContext<R> {\n    fn new(context: &mut Context<'_, CubeFusionHandle<R>>) -> Self {\n        let ptr = core::ptr::from_mut(context);\n\n        // It is necessary for the lifetime.\n        #[allow(clippy::unnecessary_cast)]\n        Self::Original(ptr as *mut Context<'static, _>)\n    }\n\n    fn get(&self) -> TuneContext<'static, R> {\n        match self {\n            UnsafeTuneContext::Original(ptr) => {\n                TuneContext::Original(unsafe { ptr.as_mut().unwrap() })\n            }\n            UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())),\n        }\n    }\n}\n\nimpl<R: Runtime, O> Clone for TuneInput<R, O> {\n    fn clone(&self) -> Self {\n        Self {\n            context: self.context.clone(),\n            optimization: self.optimization.clone(),\n        }\n    }\n}\n\nimpl<R: Runtime> Clone for UnsafeTuneContext<R> {\n    fn clone(&self) -> Self {\n        let context = match self {\n            UnsafeTuneContext::Original(ptr) => {\n                let context: &mut Context<'static, CubeFusionHandle<R>> =\n                    unsafe { ptr.as_mut().unwrap() };\n                context.fork()\n            }\n            UnsafeTuneContext::Fork(context) => context.fork(),\n        };\n        UnsafeTuneContext::Fork(Box::new(context))\n    }\n}\n"
  },
  {
    "path": "crates/burn-cuda/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"CUDA backend for the Burn framework\"\ndocumentation = \"https://docs.rs/burn-cuda\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\", \"cuda\"]\nlicense.workspace = true\nname = \"burn-cuda\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\nautotune = [\"burn-cubecl/autotune\"]\nautotune-checks = [\"burn-cubecl/autotune-checks\"]\ndefault = [\"std\", \"fusion\", \"autotune\", \"burn-cubecl/default\", \"cubecl/default\"]\ndoc = [\"burn-cubecl/doc\"]\nfusion = [\"burn-fusion\", \"burn-cubecl/fusion\"]\nstd = [\"burn-cubecl/std\", \"cubecl/std\"]\ntracing = [\n    \"burn-backend/tracing\",\n    \"burn-cubecl/tracing\",\n    \"burn-fusion?/tracing\",\n    \"cubecl/tracing\",\n]\n\n[dependencies]\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"cubecl-cuda\",\n] }\ncubecl = { workspace = true, features = [\"cuda\"] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-cuda/README.md",
    "content": "# Burn CUDA Backend\n\n[Burn](https://github.com/tracel-ai/burn) CUDA backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-cuda.svg)](https://crates.io/crates/burn-cuda)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-cuda/blob/master/README.md)\n\nThis crate provides a CUDA backend for [Burn](https://github.com/tracel-ai/burn) using the\n[cubecl](https://github.com/tracel-ai/cubecl.git) and [cudarc](https://github.com/coreylowman/cudarc.git)\ncrates.\n\n## Usage Example\n\n```rust\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use burn_autodiff::Autodiff;\n    use burn_cuda::{Cuda, CudaDevice};\n    use mnist::training;\n\n    pub fn run() {\n        let device = CudaDevice::default();\n        training::run::<Autodiff<Cuda<f32, i32>>>(device);\n    }\n}\n```\n\n## Dependencies\n\nRequires CUDA 12.x to be installed and on the `PATH`."
  },
  {
    "path": "crates/burn-cuda/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n\nextern crate alloc;\n\nuse burn_cubecl::CubeBackend;\npub use cubecl::cuda::CudaDevice;\nuse cubecl::cuda::CudaRuntime;\n\n#[cfg(not(feature = \"fusion\"))]\npub type Cuda<F = f32, I = i32> = CubeBackend<CudaRuntime, F, I, u8>;\n\n#[cfg(feature = \"fusion\")]\npub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<CubeBackend<CudaRuntime, F, I, u8>>;\n\n#[cfg(all(test, not(target_os = \"macos\")))]\nmod tests {\n    use super::*;\n    use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive};\n    use burn_cubecl::tensor::CubeTensor;\n\n    #[test]\n    fn should_support_dtypes() {\n        type B = Cuda;\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::Flex32));\n        assert!(B::supports_dtype(&device, DType::F16));\n        assert!(B::supports_dtype(&device, DType::BF16));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::I16));\n        assert!(B::supports_dtype(&device, DType::I8));\n        assert!(B::supports_dtype(&device, DType::U64));\n        assert!(B::supports_dtype(&device, DType::U32));\n        assert!(B::supports_dtype(&device, DType::U16));\n        assert!(B::supports_dtype(&device, DType::U8));\n        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n        assert!(B::supports_dtype(\n            &device,\n            DType::QFloat(CubeTensor::<CudaRuntime>::default_scheme())\n        ));\n\n        // Currently not registered in supported types\n        assert!(!B::supports_dtype(&device, DType::F64));\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Library with simple dataset APIs for creating ML data pipelines\"\ndocumentation = \"https://docs.rs/burn-dataset\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-dataset\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-dataset\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"sqlite-bundled\"]\ndoc = [\"default\"]\ntracing = [\n    \"burn-std/tracing\",\n]\n\naudio = [\"hound\"]\nbuiltin-sources = [\"vision\", \"dep:tar\", \"nlp\"]\nfake = [\"dep:fake\"]\nnetwork = [\"dep:burn-std\"]\nsqlite = [\"__sqlite-shared\", \"dep:rusqlite\"]\nsqlite-bundled = [\"__sqlite-shared\", \"rusqlite/bundled\"]\nvision = [\"dep:flate2\", \"dep:globwalk\", \"dep:image\", \"network\"]\nnlp = [\"dep:zip\", \"dep:encoding_rs\"]\n# internal\n__sqlite-shared = [\n    \"dep:r2d2\",\n    \"dep:r2d2_sqlite\",\n    \"dep:serde_rusqlite\",\n    \"dep:image\",\n    \"dep:gix-tempfile\",\n]\ndataframe = [\"dep:polars\", \"dep:planus\"]\n\n[dependencies]\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", optional = true, features = [\n    \"network\",\n] }\ncsv = { workspace = true }\nderive-new = { workspace = true }\ndirs = { workspace = true }\nfake = { workspace = true, optional = true }\nflate2 = { workspace = true, optional = true }\ngix-tempfile = { workspace = true, optional = true }\nglobwalk = { workspace = true, optional = true }\nhound = { workspace = true, optional = true }\nimage = { workspace = true, optional = true }\nplanus = { workspace = true, optional = true }\nencoding_rs = { workspace = true, optional = true }\npolars = { workspace = true, optional = true }\nr2d2 = { workspace = true, optional = true }\nr2d2_sqlite = { workspace = true, optional = true }\nrand = { workspace = true, features = [\"std\", \"sys_rng\"] }\nzip = { workspace = true, optional = true }\nrmp-serde = { workspace = true }\nrusqlite = { workspace = true, optional = true }\nsanitize-filename = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\nserde_json = { workspace = true, features = [\"std\"] }\nserde_rusqlite = { workspace = true, optional = true }\nstrum = { workspace = true }\ntar = { workspace = true, optional = true }\ntempfile = { workspace = true }\nthiserror = { workspace = true }\n\n\n[dev-dependencies]\nfake = { workspace = true }\nrayon = { workspace = true }\nrstest = { workspace = true }\n\n[package.metadata.cargo-udeps.ignore]\nnormal = [\"strum\", \"strum_macros\"]\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-dataset/README.md",
    "content": "# Burn Dataset\n\n> [Burn](https://github.com/tracel-ai/burn) dataset library\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-dataset.svg)](https://crates.io/crates/burn-dataset)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-dataset/blob/master/README.md)\n\nThe Burn Dataset library is designed to streamline your machine learning (ML) data pipeline creation\nprocess. It offers a variety of dataset implementations, transformation functions, and data sources.\n\n## Feature Flags\n\n- `audio` - enables audio dataset (SpeechCommandsDataset). Run the following example to try it out:\n\n  ```shell\n  cargo run --example speech_commands --features audio\n  ```\n"
  },
  {
    "path": "crates/burn-dataset/examples/hf_dataset.rs",
    "content": "use burn_dataset::HuggingfaceDatasetLoader;\nuse burn_dataset::SqliteDataset;\nuse serde::Deserialize;\n\n#[derive(Deserialize, Debug, Clone)]\nstruct MnistItemRaw {\n    pub _image_bytes: Vec<u8>,\n    pub _label: usize,\n}\nfn main() {\n    // There are some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main that contains a script,\n    // In this cases you must enable trusting remote code execution if you want to use it.\n    let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new(\"mnist\")\n        .with_trust_remote_code(true)\n        .dataset(\"train\")\n        .unwrap();\n\n    // However not all dataset requires it https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main\n    let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new(\"Anthropic/hh-rlhf\")\n        .dataset(\"train\")\n        .unwrap();\n}\n"
  },
  {
    "path": "crates/burn-dataset/examples/speech_commands.rs",
    "content": "#[cfg(feature = \"audio\")]\nuse burn_dataset::{Dataset, audio::SpeechCommandsDataset};\n\n#[cfg(feature = \"audio\")]\nfn speech_command() {\n    let index: usize = 4835;\n    let test = SpeechCommandsDataset::test();\n    let item = test.get(index).unwrap();\n\n    println!(\"Item: {:?}\", item);\n    println!(\"Item Length: {:?}\", item.audio_samples.len());\n    println!(\"Label: {}\", item.label);\n\n    assert_eq!(test.len(), 4890);\n    assert_eq!(item.label.to_string(), \"Yes\");\n    assert_eq!(item.sample_rate, 16000);\n    assert_eq!(item.audio_samples.len(), 16000);\n}\n\nfn main() {\n    #[cfg(feature = \"audio\")]\n    speech_command()\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/audio/mod.rs",
    "content": "mod speech_commands;\n\npub use speech_commands::*;\n"
  },
  {
    "path": "crates/burn-dataset/src/audio/speech_commands.rs",
    "content": "use crate::{\n    Dataset, HuggingfaceDatasetLoader, SqliteDataset,\n    transform::{Mapper, MapperDataset},\n};\n\nuse hound::WavReader;\nuse serde::{Deserialize, Serialize};\nuse strum::{Display, EnumCount, FromRepr};\n\ntype MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples, SpeechItemRaw>;\n\n/// Enum representing speech command classes in the Speech Commands dataset.\n/// Class names are based on the Speech Commands dataset from Huggingface.\n/// See [speech_commands](https://huggingface.co/datasets/speech_commands)\n/// for more information.\n#[allow(missing_docs)]\n#[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)]\npub enum SpeechCommandClass {\n    // Target command words\n    Yes = 0,\n    No = 1,\n    Up = 2,\n    Down = 3,\n    Left = 4,\n    Right = 5,\n    On = 6,\n    Off = 7,\n    Stop = 8,\n    Go = 9,\n    Zero = 10,\n    One = 11,\n    Two = 12,\n    Three = 13,\n    Four = 14,\n    Five = 15,\n    Six = 16,\n    Seven = 17,\n    Eight = 18,\n    Nine = 19,\n\n    // Non-target words that can be grouped into \"Other\"\n    Bed = 20,\n    Bird = 21,\n    Cat = 22,\n    Dog = 23,\n    Happy = 24,\n    House = 25,\n    Marvin = 26,\n    Sheila = 27,\n    Tree = 28,\n    Wow = 29,\n\n    // Commands from v2 dataset, that can be grouped into \"Other\"\n    Backward = 30,\n    Forward = 31,\n    Follow = 32,\n    Learn = 33,\n    Visual = 34,\n\n    // Background noise\n    Silence = 35,\n\n    // Other miscellaneous words\n    Other = 36,\n}\n\n/// Struct containing raw speech data returned from a database.\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct SpeechItemRaw {\n    /// Audio file bytes.\n    pub audio_bytes: Vec<u8>,\n\n    /// Label index.\n    pub label: usize,\n\n    /// Indicates if the label is unknown.\n    pub is_unknown: bool,\n}\n\n/// Speech item with audio samples and label.\n///\n/// The audio samples are floats in the range [-1.0, 1.0].\n/// The sample rate is in Hz.\n/// The label is the class index (see [SpeechCommandClass]).\n/// To convert to usize simply use `as usize`. To convert label to string use `.to_string()`.\n///\n/// The original label is also stored in the `label_original` field for debugging and remapping if needed.\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct SpeechItem {\n    /// Audio samples in the range [-1.0, 1.0].\n    pub audio_samples: Vec<f32>,\n\n    /// The sample rate of the audio.\n    pub sample_rate: usize,\n\n    /// The label of the audio.\n    pub label: SpeechCommandClass,\n}\n\n/// Speech Commands dataset from Huggingface v0.02.\n/// See [Speech Commands dataset](https://huggingface.co/datasets/speech_commands).\n///\n/// The data is downloaded from Huggingface and stored in a SQLite database (3.0 GB).\n/// The dataset contains 99,720 audio samples of 2,607 people saying 35 different words.\n///\n/// NOTE: The most samples are under 1 second long but there are some with pure background noise that\n/// need splitting into shorter segmants.\n///\n/// The labels are 20 target words, silence and other words.\n///\n/// The dataset is split into 3 parts:\n/// - train: 84,848 audio files\n/// - test: 4,890 audio files\n/// - validation: 9,982 audio files\npub struct SpeechCommandsDataset {\n    dataset: MappedDataset,\n}\n\nimpl SpeechCommandsDataset {\n    /// Create a new dataset with the given split.\n    pub fn new(split: &str) -> Self {\n        let dataset: SqliteDataset<SpeechItemRaw> =\n            HuggingfaceDatasetLoader::new(\"speech_commands\")\n                .with_subset(\"v0.02\")\n                .dataset(split)\n                .unwrap();\n        let dataset = MapperDataset::new(dataset, ConvertSamples);\n        Self { dataset }\n    }\n\n    /// Create a new dataset with the train split.\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    /// Create a new dataset with the test split.\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    /// Create a new dataset with the validation split.\n    pub fn validation() -> Self {\n        Self::new(\"validation\")\n    }\n\n    /// Returns the number of classes in the dataset\n    pub fn num_classes() -> usize {\n        SpeechCommandClass::COUNT\n    }\n}\n\nimpl Dataset<SpeechItem> for SpeechCommandsDataset {\n    fn get(&self, index: usize) -> Option<SpeechItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\n/// Mapper converting audio bytes into audio samples and the label to enum class.\nstruct ConvertSamples;\n\nimpl ConvertSamples {\n    /// Convert label to enum class.\n    fn to_speechcommandclass(label: usize) -> SpeechCommandClass {\n        SpeechCommandClass::from_repr(label).unwrap()\n    }\n\n    /// Convert audio bytes into samples of floats [-1.0, 1.0].\n    fn to_audiosamples(bytes: &Vec<u8>) -> (Vec<f32>, usize) {\n        let reader = WavReader::new(bytes.as_slice()).unwrap();\n        let spec = reader.spec();\n\n        // Maximum value of the audio samples (using bit shift to raise 2 to the power of bits per sample).\n        let max_value = (1 << (spec.bits_per_sample - 1)) as f32;\n\n        // The sample rate of the audio.\n        let sample_rate = spec.sample_rate as usize;\n\n        // Convert the audio samples to floats [-1.0, 1.0].\n        let audio_samples: Vec<f32> = reader\n            .into_samples::<i32>()\n            .filter_map(Result::ok)\n            .map(|sample| sample as f32 / max_value)\n            .collect();\n\n        (audio_samples, sample_rate)\n    }\n}\n\nimpl Mapper<SpeechItemRaw, SpeechItem> for ConvertSamples {\n    /// Convert audio bytes into samples of floats [-1.0, 1.0]\n    /// and the label to enum class with the target word, other and silence classes.\n    fn map(&self, item: &SpeechItemRaw) -> SpeechItem {\n        let (audio_samples, sample_rate) = Self::to_audiosamples(&item.audio_bytes);\n\n        // Convert the label to enum class, with the target words, other and silence classes.\n        let label = Self::to_speechcommandclass(item.label);\n\n        SpeechItem {\n            audio_samples,\n            sample_rate,\n            label,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/base.rs",
    "content": "use std::sync::Arc;\n\nuse crate::DatasetIterator;\n\n/// The dataset trait defines a basic collection of items with a predefined size.\npub trait Dataset<I>: Send + Sync {\n    /// Gets the item at the given index.\n    fn get(&self, index: usize) -> Option<I>;\n\n    /// Gets the number of items in the dataset.\n    fn len(&self) -> usize;\n\n    /// Checks if the dataset is empty.\n    fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n\n    /// Returns an iterator over the dataset.\n    fn iter(&self) -> DatasetIterator<'_, I>\n    where\n        Self: Sized,\n    {\n        DatasetIterator::new(self)\n    }\n}\n\nimpl<D, I> Dataset<I> for Arc<D>\nwhere\n    D: Dataset<I>,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        self.as_ref().get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.as_ref().len()\n    }\n}\n\nimpl<I> Dataset<I> for Arc<dyn Dataset<I>> {\n    fn get(&self, index: usize) -> Option<I> {\n        self.as_ref().get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.as_ref().len()\n    }\n}\n\nimpl<D, I> Dataset<I> for Box<D>\nwhere\n    D: Dataset<I>,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        self.as_ref().get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.as_ref().len()\n    }\n}\n\nimpl<I> Dataset<I> for Box<dyn Dataset<I>> {\n    fn get(&self, index: usize) -> Option<I> {\n        self.as_ref().get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.as_ref().len()\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/dataframe.rs",
    "content": "use std::marker::PhantomData;\n\nuse crate::Dataset;\n\nuse polars::frame::row::Row;\nuse polars::prelude::*;\nuse serde::de::DeserializeSeed;\nuse serde::{\n    Deserialize,\n    de::{self, DeserializeOwned, Deserializer, SeqAccess, Visitor},\n    forward_to_deserialize_any,\n};\n\n/// Error type for DataframeDataset\n#[derive(thiserror::Error, Debug)]\npub enum DataframeDatasetError {\n    /// Error occurred during deserialization or other operations\n    #[error(\"{0}\")]\n    Other(String),\n}\n\nimpl de::Error for DataframeDatasetError {\n    fn custom<T: std::fmt::Display>(msg: T) -> Self {\n        DataframeDatasetError::Other(msg.to_string())\n    }\n}\n\n/// Dataset implementation for Polars DataFrame\n///\n/// This struct provides a way to access data from a Polars DataFrame\n/// as if it were a Dataset of type I.\npub struct DataframeDataset<I> {\n    df: DataFrame,\n    len: usize,\n    column_name_mapping: Vec<usize>,\n    phantom: PhantomData<I>,\n}\n\nimpl<I> DataframeDataset<I>\nwhere\n    I: Clone + Send + Sync + DeserializeOwned,\n{\n    /// Create a new DataframeDataset from a Polars DataFrame\n    ///\n    /// # Arguments\n    ///\n    /// * `df` - A Polars DataFrame\n    ///\n    /// # Returns\n    ///\n    /// A Result containing the new DataframeDataset or a DataframeDatasetError\n    pub fn new(df: DataFrame) -> Result<Self, DataframeDatasetError> {\n        let len = df.height();\n        let field_names = extract_field_names::<I>();\n\n        let column_name_mapping = field_names\n            .iter()\n            .map(|name| {\n                df.schema()\n                    .try_get_full(name)\n                    .expect(\"Corresponding column should exist in the DataFrame\")\n                    .0\n            })\n            .collect::<Vec<_>>();\n\n        Ok(DataframeDataset {\n            df,\n            len,\n            column_name_mapping,\n            phantom: PhantomData,\n        })\n    }\n}\n\nimpl<I> Dataset<I> for DataframeDataset<I>\nwhere\n    I: Clone + Send + Sync + DeserializeOwned,\n{\n    /// Get an item from the dataset at the specified index\n    ///\n    /// # Arguments\n    ///\n    /// * `index` - The index of the item to retrieve\n    ///\n    /// # Returns\n    ///\n    /// An Option containing the item if it exists, or None if it doesn't\n    fn get(&self, index: usize) -> Option<I> {\n        let row = self.df.get_row(index).ok()?;\n\n        let mut deserializer = RowDeserializer::new(&row, &self.column_name_mapping);\n        I::deserialize(&mut deserializer).ok()\n    }\n\n    /// Get the length of the dataset\n    fn len(&self) -> usize {\n        self.len\n    }\n\n    /// Check if the dataset is empty\n    fn is_empty(&self) -> bool {\n        self.len == 0\n    }\n}\n\n/// A deserializer for Polars DataFrame rows\nstruct RowDeserializer<'a> {\n    row: &'a Row<'a>,\n    column_name_mapping: &'a Vec<usize>,\n    index: usize,\n}\n\nimpl<'a> RowDeserializer<'a> {\n    /// Create a new RowDeserializer\n    ///\n    /// # Arguments\n    ///\n    /// * `row` - A reference to a Polars DataFrame row\n    /// * `column_name_mapping` - A reference to a vector mapping field names to column indices\n    fn new(row: &'a Row, column_name_mapping: &'a Vec<usize>) -> RowDeserializer<'a> {\n        RowDeserializer {\n            row,\n            column_name_mapping,\n            index: 0,\n        }\n    }\n}\n\nimpl<'de, 'a> Deserializer<'de> for &'a mut RowDeserializer<'a> {\n    type Error = DataframeDatasetError;\n\n    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, DataframeDatasetError>\n    where\n        V: Visitor<'de>,\n    {\n        let i = self.column_name_mapping[self.index];\n\n        let value = &self.row.0[i];\n        match value {\n            AnyValue::Null => visitor.visit_none(),\n            AnyValue::Boolean(b) => visitor.visit_bool(*b),\n            AnyValue::Int8(i) => visitor.visit_i8(*i),\n            AnyValue::Int16(i) => visitor.visit_i16(*i),\n            AnyValue::Int32(i) => visitor.visit_i32(*i),\n            AnyValue::Int64(i) => visitor.visit_i64(*i),\n            AnyValue::UInt8(i) => visitor.visit_u8(*i),\n            AnyValue::UInt16(i) => visitor.visit_u16(*i),\n            AnyValue::UInt32(i) => visitor.visit_u32(*i),\n            AnyValue::UInt64(i) => visitor.visit_u64(*i),\n            AnyValue::Float32(f) => visitor.visit_f32(*f),\n            AnyValue::Float64(f) => visitor.visit_f64(*f),\n            AnyValue::Date(i) => visitor.visit_i32(*i),\n            AnyValue::String(s) => visitor.visit_string(s.to_string()),\n            AnyValue::Binary(b) => {\n                visitor.visit_seq(de::value::SeqDeserializer::new(b.iter().copied()))\n            }\n            AnyValue::Time(t) => visitor.visit_i64(*t),\n            ty => Err(DataframeDatasetError::Other(\n                format!(\"Unsupported type: {ty:?}\").to_string(),\n            )),\n        }\n    }\n\n    fn deserialize_struct<V>(\n        self,\n        _name: &'static str,\n        _fields: &'static [&'static str],\n        visitor: V,\n    ) -> Result<V::Value, DataframeDatasetError>\n    where\n        V: Visitor<'de>,\n    {\n        visitor.visit_seq(self)\n    }\n\n    forward_to_deserialize_any! {\n        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string\n        bytes byte_buf option unit unit_struct newtype_struct seq tuple\n        tuple_struct map enum identifier ignored_any\n    }\n}\n\nimpl<'de, 'a> SeqAccess<'de> for RowDeserializer<'a> {\n    type Error = DataframeDatasetError;\n\n    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DataframeDatasetError>\n    where\n        T: DeserializeSeed<'de>,\n    {\n        if self.index >= self.row.0.len() {\n            return Ok(None);\n        }\n        let mut deserializer = RowDeserializer {\n            row: self.row,\n            column_name_mapping: self.column_name_mapping,\n            index: self.index,\n        };\n        self.index += 1;\n        seed.deserialize(&mut deserializer).map(Some)\n    }\n}\n\nstruct FieldExtractor {\n    fields: Vec<&'static str>,\n}\n\nimpl<'de> Deserializer<'de> for &mut FieldExtractor {\n    type Error = de::value::Error;\n\n    fn deserialize_any<V>(self, _visitor: V) -> core::result::Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        Err(de::Error::custom(\"Field extractor\"))\n    }\n\n    fn deserialize_struct<V>(\n        self,\n        _name: &'static str,\n        fields: &'static [&'static str],\n        _visitor: V,\n    ) -> core::result::Result<V::Value, Self::Error>\n    where\n        V: Visitor<'de>,\n    {\n        self.fields.extend_from_slice(fields);\n        Err(de::Error::custom(\"Field extractor\"))\n    }\n\n    forward_to_deserialize_any! {\n        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes\n        byte_buf option unit unit_struct newtype_struct seq tuple\n        tuple_struct map enum identifier ignored_any\n    }\n}\n\n/// Extract field names from a type T that implements Deserialize\n///\n/// # Returns\n///\n/// A vector of field names as static string slices\nfn extract_field_names<'de, T>() -> Vec<&'static str>\nwhere\n    T: Deserialize<'de>,\n{\n    let mut extractor = FieldExtractor { fields: Vec::new() };\n    let _ = T::deserialize(&mut extractor);\n    extractor.fields\n}\n\n#[cfg(test)]\nmod tests {\n    use polars::prelude::*;\n    use serde::Deserialize;\n\n    use super::*;\n    #[derive(Clone, Debug, Deserialize, PartialEq)]\n    struct TestData {\n        int32: i32,\n        bool: bool,\n        float64: f64,\n        string: String,\n        int16: i16,\n        uint32: u32,\n        uint64: u64,\n        float32: f32,\n        int64: i64,\n        int8: i8,\n        binary: Vec<u8>,\n    }\n\n    fn create_test_dataframe() -> DataFrame {\n        let s0 = Column::new(\"int32\".into(), &[1i32, 2i32, 3i32]);\n        let s1 = Column::new(\"bool\".into(), &[true, false, true]);\n        let s2 = Column::new(\"float64\".into(), &[1.1f64, 2.2f64, 3.3f64]);\n        let s3 = Column::new(\"string\".into(), &[\"Boo\", \"Boo2\", \"Boo3\"]);\n        let s6 = Column::new(\"int16\".into(), &[1i16, 2i16, 3i16]);\n        let s8 = Column::new(\"uint32\".into(), &[1u32, 2u32, 3u32]);\n        let s9 = Column::new(\"uint64\".into(), &[1u64, 2u64, 3u64]);\n        let s10 = Column::new(\"float32\".into(), &[1.1f32, 2.2f32, 3.3f32]);\n        let s11 = Column::new(\"int64\".into(), &[1i64, 2i64, 3i64]);\n        let s12 = Column::new(\"int8\".into(), &[1i8, 2i8, 3i8]);\n\n        let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];\n\n        let s13 = Column::new(\"binary\".into(), binary_data);\n        DataFrame::new_infer_height(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap()\n    }\n\n    #[test]\n    fn test_dataframe_dataset_creation() {\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<TestData>::new(df);\n        assert!(dataset.is_ok());\n    }\n\n    #[test]\n    fn test_dataframe_dataset_length() {\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<TestData>::new(df).unwrap();\n        assert_eq!(dataset.len(), 3);\n        assert!(!dataset.is_empty());\n    }\n\n    #[test]\n    fn test_dataframe_dataset_get() {\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<TestData>::new(df).unwrap();\n\n        let expected_items = vec![\n            TestData {\n                int32: 1,\n                bool: true,\n                float64: 1.1,\n                string: \"Boo\".to_string(),\n                int16: 1,\n                uint32: 1,\n                uint64: 1,\n                float32: 1.1,\n                int64: 1,\n                int8: 1,\n                binary: vec![1, 2, 3],\n            },\n            TestData {\n                int32: 2,\n                bool: false,\n                float64: 2.2,\n                string: \"Boo2\".to_string(),\n                int16: 2,\n                uint32: 2,\n                uint64: 2,\n                float32: 2.2,\n                int64: 2,\n                int8: 2,\n                binary: vec![4, 5, 6],\n            },\n            TestData {\n                int32: 3,\n                bool: true,\n                float64: 3.3,\n                string: \"Boo3\".to_string(),\n                int16: 3,\n                uint32: 3,\n                uint64: 3,\n                float32: 3.3,\n                int64: 3,\n                int8: 3,\n                binary: vec![7, 8, 9],\n            },\n        ];\n\n        for (index, expected_item) in expected_items.iter().enumerate() {\n            let item = dataset.get(index).unwrap();\n            assert_eq!(&item, expected_item);\n        }\n    }\n\n    #[test]\n    fn test_dataframe_dataset_out_of_bounds() {\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<TestData>::new(df).unwrap();\n        assert!(dataset.get(3).is_none());\n    }\n\n    #[test]\n    fn test_dataframe_dataset() {\n        let df = create_test_dataframe();\n        let dataset: DataframeDataset<TestData> = DataframeDataset::new(df).unwrap();\n\n        assert_eq!(dataset.len(), 3);\n        assert!(!dataset.is_empty());\n\n        let item = dataset.get(1).unwrap();\n        assert_eq!(\n            item,\n            TestData {\n                int32: 2,\n                bool: false,\n                float64: 2.2,\n                string: \"Boo2\".to_string(),\n                int16: 2,\n                uint32: 2,\n                uint64: 2,\n                float32: 2.2,\n                int64: 2,\n                int8: 2,\n                binary: vec![4, 5, 6],\n            }\n        );\n\n        let item = dataset.get(2).unwrap();\n\n        assert_eq!(\n            item,\n            TestData {\n                int32: 3,\n                bool: true,\n                float64: 3.3,\n                string: \"Boo3\".to_string(),\n                int16: 3,\n                uint32: 3,\n                uint64: 3,\n                float32: 3.3,\n                int64: 3,\n                int8: 3,\n                binary: vec![7, 8, 9],\n            }\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Corresponding column should exist in the DataFrame: SchemaFieldNotFound(ErrString(\\\"non_existent\\\"))\"]\n    fn test_non_existing_struct_fields() {\n        #[derive(Clone, Debug, Deserialize, PartialEq)]\n        struct PartialTestData {\n            int32: i32,\n            bool: bool,\n            non_existent: String,\n        }\n\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<PartialTestData>::new(df);\n\n        assert!(dataset.is_err());\n        if let Err(e) = dataset {\n            assert!(matches!(e, DataframeDatasetError::Other(_)));\n        }\n    }\n\n    #[test]\n    fn test_partial_table() {\n        #[derive(Clone, Debug, Deserialize, PartialEq)]\n        struct PartialTestData {\n            int32: i32,\n            bool: bool,\n            string: String,\n        }\n\n        let df = create_test_dataframe();\n        let dataset = DataframeDataset::<PartialTestData>::new(df).unwrap();\n\n        assert_eq!(dataset.len(), 3);\n        assert!(!dataset.is_empty());\n\n        let item = dataset.get(1).unwrap();\n        assert_eq!(\n            item,\n            PartialTestData {\n                int32: 2,\n                bool: false,\n                string: \"Boo2\".to_string(),\n            }\n        );\n\n        let item = dataset.get(2).unwrap();\n        assert_eq!(\n            item,\n            PartialTestData {\n                int32: 3,\n                bool: true,\n                string: \"Boo3\".to_string(),\n            }\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/fake.rs",
    "content": "use crate::{Dataset, DatasetIterator, InMemDataset};\nuse fake::{Dummy, Fake, Faker};\n\n/// Dataset filled with fake items generated from the [fake](fake) crate.\npub struct FakeDataset<I> {\n    dataset: InMemDataset<I>,\n}\n\nimpl<I: Dummy<Faker>> FakeDataset<I> {\n    /// Create a new fake dataset with the given size.\n    pub fn new(size: usize) -> Self {\n        let mut items = Vec::with_capacity(size);\n        for _ in 0..size {\n            items.push(Faker.fake());\n        }\n        let dataset = InMemDataset::new(items);\n\n        Self { dataset }\n    }\n}\n\nimpl<I: Send + Sync + Clone> Dataset<I> for FakeDataset<I> {\n    fn iter(&self) -> DatasetIterator<'_, I> {\n        DatasetIterator::new(self)\n    }\n\n    fn get(&self, index: usize) -> Option<I> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n\n    fn is_empty(&self) -> bool {\n        self.dataset.is_empty()\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/in_memory.rs",
    "content": "use std::{\n    fs::File,\n    io::{BufRead, BufReader},\n    path::Path,\n};\n\nuse serde::de::DeserializeOwned;\n\nuse crate::Dataset;\n\n/// Dataset where all items are stored in ram.\npub struct InMemDataset<I> {\n    items: Vec<I>,\n}\n\nimpl<I> InMemDataset<I> {\n    /// Creates a new in memory dataset from the given items.\n    pub fn new(items: Vec<I>) -> Self {\n        InMemDataset { items }\n    }\n}\n\nimpl<I> Dataset<I> for InMemDataset<I>\nwhere\n    I: Clone + Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        self.items.get(index).cloned()\n    }\n    fn len(&self) -> usize {\n        self.items.len()\n    }\n}\n\nimpl<I> InMemDataset<I>\nwhere\n    I: Clone + DeserializeOwned,\n{\n    /// Create from a dataset. All items are loaded in memory.\n    pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {\n        let items: Vec<I> = dataset.iter().collect();\n        Self::new(items)\n    }\n\n    /// Create from a json rows file (one json per line).\n    ///\n    /// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)\n    pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {\n        let file = File::open(path)?;\n        let reader = BufReader::new(file);\n        let mut items = Vec::new();\n\n        for line in reader.lines() {\n            let item = serde_json::from_str(line.unwrap().as_str()).unwrap();\n            items.push(item);\n        }\n\n        let dataset = Self::new(items);\n\n        Ok(dataset)\n    }\n\n    /// Create from a csv file.\n    ///\n    /// The provided `csv::ReaderBuilder` can be configured to fit your csv format.\n    ///\n    /// The supported field types are: String, integer, float, and bool.\n    ///\n    /// See:\n    /// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)\n    /// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)\n    pub fn from_csv<P: AsRef<Path>>(\n        path: P,\n        builder: &csv::ReaderBuilder,\n    ) -> Result<Self, std::io::Error> {\n        let mut rdr = builder.from_path(path)?;\n\n        let mut items = Vec::new();\n\n        for result in rdr.deserialize() {\n            let item: I = result?;\n            items.push(item);\n        }\n\n        let dataset = Self::new(items);\n\n        Ok(dataset)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n\n    use super::*;\n    use crate::{SqliteDataset, test_data};\n\n    use rstest::{fixture, rstest};\n    use serde::{Deserialize, Serialize};\n\n    const DB_FILE: &str = \"tests/data/sqlite-dataset.db\";\n    const JSON_FILE: &str = \"tests/data/dataset.json\";\n    const CSV_FILE: &str = \"tests/data/dataset.csv\";\n    const CSV_FMT_FILE: &str = \"tests/data/dataset-fmt.csv\";\n\n    type SqlDs = SqliteDataset<Sample>;\n\n    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\n    pub struct Sample {\n        column_str: String,\n        column_bytes: Vec<u8>,\n        column_int: i64,\n        column_bool: bool,\n        column_float: f64,\n    }\n\n    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\n    pub struct SampleCsv {\n        column_str: String,\n        column_int: i64,\n        column_bool: bool,\n        column_float: f64,\n    }\n\n    #[fixture]\n    fn train_dataset() -> SqlDs {\n        SqliteDataset::from_db_file(DB_FILE, \"train\").unwrap()\n    }\n\n    #[rstest]\n    pub fn from_dataset(train_dataset: SqlDs) {\n        let dataset = InMemDataset::from_dataset(&train_dataset);\n\n        let non_existing_record_index: usize = 10;\n        let record_index: usize = 0;\n\n        assert_eq!(train_dataset.get(non_existing_record_index), None);\n        assert_eq!(dataset.get(record_index).unwrap().column_str, \"HI1\");\n    }\n\n    #[test]\n    pub fn from_json_rows() {\n        let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();\n\n        let non_existing_record_index: usize = 10;\n        let record_index: usize = 1;\n\n        assert_eq!(dataset.get(non_existing_record_index), None);\n        assert_eq!(dataset.get(record_index).unwrap().column_str, \"HI2\");\n        assert!(!dataset.get(record_index).unwrap().column_bool);\n    }\n\n    #[test]\n    pub fn from_csv_rows() {\n        let rdr = csv::ReaderBuilder::new();\n        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();\n\n        let non_existing_record_index: usize = 10;\n        let record_index: usize = 1;\n\n        assert_eq!(dataset.get(non_existing_record_index), None);\n        assert_eq!(dataset.get(record_index).unwrap().column_str, \"HI2\");\n        assert_eq!(dataset.get(record_index).unwrap().column_int, 1);\n        assert!(!dataset.get(record_index).unwrap().column_bool);\n        assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);\n    }\n\n    #[test]\n    pub fn from_csv_rows_fmt() {\n        let mut rdr = csv::ReaderBuilder::new();\n        let rdr = rdr.delimiter(b' ').has_headers(false);\n        let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();\n\n        let non_existing_record_index: usize = 10;\n        let record_index: usize = 1;\n\n        assert_eq!(dataset.get(non_existing_record_index), None);\n        assert_eq!(dataset.get(record_index).unwrap().column_str, \"HI2\");\n        assert_eq!(dataset.get(record_index).unwrap().column_int, 1);\n        assert!(!dataset.get(record_index).unwrap().column_bool);\n        assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);\n    }\n\n    #[test]\n    pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {\n        let items_original = test_data::string_items();\n        let dataset = InMemDataset::new(items_original.clone());\n\n        let items: Vec<String> = dataset.iter().collect();\n\n        assert_eq!(items_original, items);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/iterator.rs",
    "content": "use crate::dataset::Dataset;\nuse std::iter::Iterator;\n\n/// Dataset iterator.\npub struct DatasetIterator<'a, I> {\n    current: usize,\n    dataset: &'a dyn Dataset<I>,\n}\n\nimpl<'a, I> DatasetIterator<'a, I> {\n    /// Creates a new dataset iterator.\n    pub fn new<D>(dataset: &'a D) -> Self\n    where\n        D: Dataset<I>,\n    {\n        DatasetIterator {\n            current: 0,\n            dataset,\n        }\n    }\n}\n\nimpl<I> Iterator for DatasetIterator<'_, I> {\n    type Item = I;\n\n    fn next(&mut self) -> Option<I> {\n        let item = self.dataset.get(self.current);\n        self.current += 1;\n        item\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/mod.rs",
    "content": "mod base;\nmod in_memory;\nmod iterator;\n\npub use base::*;\npub use in_memory::*;\npub use iterator::*;\n\n#[cfg(any(test, feature = \"fake\"))]\nmod fake;\n\n#[cfg(any(test, feature = \"fake\"))]\npub use self::fake::*;\n\n#[cfg(feature = \"dataframe\")]\nmod dataframe;\n\n#[cfg(feature = \"dataframe\")]\npub use dataframe::*;\n\n#[cfg(any(feature = \"sqlite\", feature = \"sqlite-bundled\"))]\npub use sqlite::*;\n\n#[cfg(any(feature = \"sqlite\", feature = \"sqlite-bundled\"))]\nmod sqlite;\n"
  },
  {
    "path": "crates/burn-dataset/src/dataset/sqlite.rs",
    "content": "use std::{\n    collections::HashSet,\n    fs, io,\n    marker::PhantomData,\n    path::{Path, PathBuf},\n    sync::{Arc, RwLock},\n};\n\nuse crate::Dataset;\n\nuse gix_tempfile::{\n    AutoRemove, ContainingDirectory, Handle,\n    handle::{Writable, persist},\n};\nuse r2d2::{Pool, PooledConnection};\nuse r2d2_sqlite::{\n    SqliteConnectionManager,\n    rusqlite::{OpenFlags, OptionalExtension},\n};\nuse sanitize_filename::sanitize;\nuse serde::{Serialize, de::DeserializeOwned};\nuse serde_rusqlite::{columns_from_statement, from_row_with_columns};\n\n/// Result type for the sqlite dataset.\npub type Result<T> = core::result::Result<T, SqliteDatasetError>;\n\n/// Sqlite dataset error.\n#[derive(thiserror::Error, Debug)]\npub enum SqliteDatasetError {\n    /// IO related error.\n    #[error(\"IO error: {0}\")]\n    Io(#[from] io::Error),\n\n    /// Sql related error.\n    #[error(\"Sql error: {0}\")]\n    Sql(#[from] serde_rusqlite::rusqlite::Error),\n\n    /// Serde related error.\n    #[error(\"Serde error: {0}\")]\n    Serde(#[from] rmp_serde::encode::Error),\n\n    /// The database file already exists error.\n    #[error(\"Overwrite flag is set to false and the database file already exists: {0}\")]\n    FileExists(PathBuf),\n\n    /// Error when creating the connection pool.\n    #[error(\"Failed to create connection pool: {0}\")]\n    ConnectionPool(#[from] r2d2::Error),\n\n    /// Error when persisting the temporary database file.\n    #[error(\"Could not persist the temporary database file: {0}\")]\n    PersistDbFile(#[from] persist::Error<Writable>),\n\n    /// Any other error.\n    #[error(\"{0}\")]\n    Other(&'static str),\n}\n\nimpl From<&'static str> for SqliteDatasetError {\n    fn from(s: &'static str) -> Self {\n        SqliteDatasetError::Other(s)\n    }\n}\n\n/// This struct represents a dataset where all items are stored in an SQLite database.\n/// Each instance of this struct corresponds to a specific table within the SQLite database,\n/// and allows for interaction with the data stored in the table in a structured and typed manner.\n///\n/// The SQLite database must contain a table with the same name as the `split` field. This table should\n/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`\n/// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1.\n///\n/// Table columns can be represented in two ways:\n///\n/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table\n///    should match the field names of the `I` struct. The field names can be a subset of column names and\n///    can be in any order.\n///\n/// For the supported field types, refer to:\n/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)\n/// - [SQLite data types](https://www.sqlite.org/datatype3.html)\n///\n/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table\n///    should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields\n///    that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using\n///    [MessagePack](https://msgpack.org/).\n///\n/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate\n/// method to read the data from the table.\n#[derive(Debug)]\npub struct SqliteDataset<I> {\n    db_file: PathBuf,\n    split: String,\n    conn_pool: Pool<SqliteConnectionManager>,\n    columns: Vec<String>,\n    len: usize,\n    select_statement: String,\n    row_serialized: bool,\n    phantom: PhantomData<I>,\n}\n\nimpl<I> SqliteDataset<I> {\n    /// Initializes a `SqliteDataset` from a SQLite database file and a split name.\n    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {\n        // Create a connection pool\n        let conn_pool = create_conn_pool(&db_file, false)?;\n\n        // Determine how the table is stored\n        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;\n\n        // Create a select statement and save it\n        let select_statement = if row_serialized {\n            format!(\"select item from {split} where row_id = ?\")\n        } else {\n            format!(\"select * from {split} where row_id = ?\")\n        };\n\n        // Save the column names and the number of rows\n        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;\n\n        Ok(SqliteDataset {\n            db_file: db_file.as_ref().to_path_buf(),\n            split: split.to_string(),\n            conn_pool,\n            columns,\n            len,\n            select_statement,\n            row_serialized,\n            phantom: PhantomData,\n        })\n    }\n\n    /// Returns true if table has two columns: row_id (integer) and item (blob).\n    ///\n    /// This is used to determine if the table is row serialized or not.\n    fn check_if_row_serialized(\n        conn_pool: &Pool<SqliteConnectionManager>,\n        split: &str,\n    ) -> Result<bool> {\n        // This struct is used to store the column name and type\n        struct Column {\n            name: String,\n            ty: String,\n        }\n\n        const COLUMN_NAME: usize = 1;\n        const COLUMN_TYPE: usize = 2;\n\n        let sql_statement = format!(\"PRAGMA table_info({split})\");\n\n        let conn = conn_pool.get()?;\n\n        let mut stmt = conn.prepare(sql_statement.as_str())?;\n        let column_iter = stmt.query_map([], |row| {\n            Ok(Column {\n                name: row\n                    .get::<usize, String>(COLUMN_NAME)\n                    .unwrap()\n                    .to_lowercase(),\n                ty: row\n                    .get::<usize, String>(COLUMN_TYPE)\n                    .unwrap()\n                    .to_lowercase(),\n            })\n        })?;\n\n        let mut columns: Vec<Column> = vec![];\n\n        for column in column_iter {\n            columns.push(column?);\n        }\n\n        if columns.len() != 2 {\n            Ok(false)\n        } else {\n            // Check if the column names and types match the expected values\n            Ok(columns[0].name == \"row_id\"\n                && columns[0].ty == \"integer\"\n                && columns[1].name == \"item\"\n                && columns[1].ty == \"blob\")\n        }\n    }\n\n    /// Get the database file name.\n    pub fn db_file(&self) -> PathBuf {\n        self.db_file.clone()\n    }\n\n    /// Get the split name.\n    pub fn split(&self) -> &str {\n        self.split.as_str()\n    }\n}\n\nimpl<I> Dataset<I> for SqliteDataset<I>\nwhere\n    I: Clone + Send + Sync + DeserializeOwned,\n{\n    /// Get an item from the dataset.\n    fn get(&self, index: usize) -> Option<I> {\n        // Row ids start with 1 (one) and index starts with 0 (zero)\n        let row_id = index + 1;\n\n        // Get a connection from the pool\n        let connection = self.conn_pool.get().unwrap();\n        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();\n\n        if self.row_serialized {\n            // Fetch with a single column `item` and deserialize it with MessagePack\n            statement\n                .query_row([row_id], |row| {\n                    // Deserialize item (blob) with MessagePack (rmp-serde)\n                    Ok(\n                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())\n                            .unwrap(),\n                    )\n                })\n                .optional() //Converts Error (not found) to None\n                .unwrap()\n        } else {\n            // Fetch a row with multiple columns and deserialize it serde_rusqlite\n            statement\n                .query_row([row_id], |row| {\n                    // Deserialize the row with serde_rusqlite\n                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())\n                })\n                .optional() //Converts Error (not found) to None\n                .unwrap()\n        }\n    }\n\n    /// Return the number of rows in the dataset.\n    fn len(&self) -> usize {\n        self.len\n    }\n}\n\n/// Fetch the column names and the number of rows from the database.\nfn fetch_columns_and_len(\n    conn_pool: &Pool<SqliteConnectionManager>,\n    select_statement: &str,\n    split: &str,\n) -> Result<(Vec<String>, usize)> {\n    // Save the column names\n    let connection = conn_pool.get()?;\n    let statement = connection.prepare(select_statement)?;\n    let columns = columns_from_statement(&statement);\n\n    // Count the number of rows and save it as len\n    //\n    // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables.\n    // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id,\n    // which corresponds to the number of rows in the table.\n    // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps.\n    // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index.\n    let mut statement =\n        connection.prepare(format!(\"select coalesce(max(row_id), 0) from {split}\").as_str())?;\n\n    let len = statement.query_row([], |row| {\n        let len: usize = row.get(0)?;\n        Ok(len)\n    })?;\n    Ok((columns, len))\n}\n\n/// Helper function to create a connection pool\nfn create_conn_pool<P: AsRef<Path>>(\n    db_file: P,\n    write: bool,\n) -> Result<Pool<SqliteConnectionManager>> {\n    let sqlite_flags = if write {\n        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE\n    } else {\n        OpenFlags::SQLITE_OPEN_READ_ONLY\n    };\n\n    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);\n    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)\n}\n\n/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.\n/// It consists of an optional name, a database file path, and a base directory for storage.\n#[derive(Clone, Debug)]\npub struct SqliteDatasetStorage {\n    name: Option<String>,\n    db_file: Option<PathBuf>,\n    base_dir: Option<PathBuf>,\n}\n\nimpl SqliteDatasetStorage {\n    /// Creates a new instance of `SqliteDatasetStorage` using a dataset name.\n    ///\n    /// # Arguments\n    ///\n    /// * `name` - A string slice that holds the name of the dataset.\n    pub fn from_name(name: &str) -> Self {\n        SqliteDatasetStorage {\n            name: Some(name.to_string()),\n            db_file: None,\n            base_dir: None,\n        }\n    }\n\n    /// Creates a new instance of `SqliteDatasetStorage` using a database file path.\n    ///\n    /// # Arguments\n    ///\n    /// * `db_file` - A reference to the Path that represents the database file path.\n    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {\n        SqliteDatasetStorage {\n            name: None,\n            db_file: Some(db_file.as_ref().to_path_buf()),\n            base_dir: None,\n        }\n    }\n\n    /// Sets the base directory for storing the dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `base_dir` - A string slice that represents the base directory.\n    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {\n        self.base_dir = Some(base_dir.as_ref().to_path_buf());\n        self\n    }\n\n    /// Checks if the database file exists in the given path.\n    ///\n    /// # Returns\n    ///\n    /// * A boolean value indicating whether the file exists or not.\n    pub fn exists(&self) -> bool {\n        self.db_file().exists()\n    }\n\n    /// Fetches the database file path.\n    ///\n    /// # Returns\n    ///\n    /// * A `PathBuf` instance representing the file path.\n    pub fn db_file(&self) -> PathBuf {\n        match &self.db_file {\n            Some(db_file) => db_file.clone(),\n            None => {\n                let name = sanitize(self.name.as_ref().expect(\"Name is not set\"));\n                Self::base_dir(self.base_dir.to_owned()).join(format!(\"{name}.db\"))\n            }\n        }\n    }\n\n    /// Determines the base directory for storing the dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.\n    ///\n    /// # Returns\n    ///\n    /// * A `PathBuf` instance representing the base directory.\n    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {\n        match base_dir {\n            Some(base_dir) => base_dir,\n            None => dirs::cache_dir()\n                .expect(\"Could not get cache directory\")\n                .join(\"burn-dataset\"),\n        }\n    }\n\n    /// Provides a writer instance for the SQLite dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.\n    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>\n    where\n        I: Clone + Send + Sync + Serialize + DeserializeOwned,\n    {\n        SqliteDatasetWriter::new(self.db_file(), overwrite)\n    }\n\n    /// Provides a reader instance for the SQLite dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `split` - A string slice that defines the data split for reading (e.g., \"train\", \"test\").\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.\n    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>\n    where\n        I: Clone + Send + Sync + Serialize + DeserializeOwned,\n    {\n        if !self.exists() {\n            panic!(\"The database file does not exist\");\n        }\n\n        SqliteDataset::from_db_file(self.db_file(), split)\n    }\n}\n\n/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.\n/// It retains the current writer's state and its database connection.\n///\n/// Being thread-safe, this writer can be concurrently used across multiple threads.\n///\n/// Typical applications include:\n///\n/// - Generation of a new dataset\n/// - Storage of preprocessed data or metadata\n/// - Enlargement of a dataset's item count post preprocessing\n#[derive(Debug)]\npub struct SqliteDatasetWriter<I> {\n    db_file: PathBuf,\n    db_file_tmp: Option<Handle<Writable>>,\n    splits: Arc<RwLock<HashSet<String>>>,\n    overwrite: bool,\n    conn_pool: Option<Pool<SqliteConnectionManager>>,\n    is_completed: Arc<RwLock<bool>>,\n    phantom: PhantomData<I>,\n}\n\nimpl<I> SqliteDatasetWriter<I>\nwhere\n    I: Clone + Send + Sync + Serialize + DeserializeOwned,\n{\n    /// Creates a new instance of `SqliteDatasetWriter`.\n    ///\n    /// # Arguments\n    ///\n    /// * `db_file` - A reference to the Path that represents the database file path.\n    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.\n    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {\n        let writer = Self {\n            db_file: db_file.as_ref().to_path_buf(),\n            db_file_tmp: None,\n            splits: Arc::new(RwLock::new(HashSet::new())),\n            overwrite,\n            conn_pool: None,\n            is_completed: Arc::new(RwLock::new(false)),\n            phantom: PhantomData,\n        };\n\n        writer.init()\n    }\n\n    /// Initializes the dataset writer by creating the database file, tables, and connection pool.\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.\n    fn init(mut self) -> Result<Self> {\n        // Remove the db file if it already exists\n        if self.db_file.exists() {\n            if self.overwrite {\n                fs::remove_file(&self.db_file)?;\n            } else {\n                return Err(SqliteDatasetError::FileExists(self.db_file));\n            }\n        }\n\n        // Create the database file directory if it does not exist\n        let db_file_dir = self\n            .db_file\n            .parent()\n            .ok_or(\"Unable to get parent directory\")?;\n\n        if !db_file_dir.exists() {\n            fs::create_dir_all(db_file_dir)?;\n        }\n\n        // Create a temp database file name as {base_dir}/{name}.db.tmp\n        let mut db_file_tmp = self.db_file.clone();\n        db_file_tmp.set_extension(\"db.tmp\");\n        if db_file_tmp.exists() {\n            fs::remove_file(&db_file_tmp)?;\n        }\n\n        // Create the temp database file and wrap it with a gix_tempfile::Handle\n        // This will ensure that the temp file is deleted when the writer is dropped\n        // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this)\n        gix_tempfile::signal::setup(Default::default());\n        self.db_file_tmp = Some(gix_tempfile::writable_at(\n            &db_file_tmp,\n            ContainingDirectory::Exists,\n            AutoRemove::Tempfile,\n        )?);\n\n        let conn_pool = create_conn_pool(db_file_tmp, true)?;\n        self.conn_pool = Some(conn_pool);\n\n        Ok(self)\n    }\n\n    /// Serializes and writes an item to the database. The item is written to the table for the\n    /// specified split. If the table does not exist, it is created. If the table exists, the item\n    /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/)\n    ///\n    /// # Arguments\n    ///\n    /// * `split` - A string slice that defines the data split for writing (e.g., \"train\", \"test\").\n    /// * `item` - A reference to the item to be written to the database.\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` containing the index of the inserted row if successful, an error otherwise.\n    pub fn write(&self, split: &str, item: &I) -> Result<usize> {\n        // Acquire the read lock (wont't block other reads)\n        let is_completed = self.is_completed.read().unwrap();\n\n        // If the writer is completed, return an error\n        if *is_completed {\n            return Err(SqliteDatasetError::Other(\n                \"Cannot save to a completed dataset writer\",\n            ));\n        }\n\n        // create the table for the split if it does not exist\n        if !self.splits.read().unwrap().contains(split) {\n            self.create_table(split)?;\n        }\n\n        // Get a connection from the pool\n        let conn_pool = self.conn_pool.as_ref().unwrap();\n        let conn = conn_pool.get()?;\n\n        // Serialize the item using MessagePack\n        let serialized_item = rmp_serde::to_vec(item)?;\n\n        // Turn off the synchronous and journal mode for speed up\n        // We are sacrificing durability for speed but it's okay because\n        // we always recreate the dataset if it is not completed.\n        pragma_update_with_error_handling(&conn, \"synchronous\", \"OFF\")?;\n        pragma_update_with_error_handling(&conn, \"journal_mode\", \"OFF\")?;\n\n        // Insert the serialized item into the database\n        let insert_statement = format!(\"insert into {split} (item) values (?)\");\n        conn.execute(insert_statement.as_str(), [serialized_item])?;\n\n        // Get the primary key of the last inserted row and convert to index (row_id-1)\n        let index = (conn.last_insert_rowid() - 1) as usize;\n\n        Ok(index)\n    }\n\n    /// Marks the dataset as completed and persists the temporary database file.\n    pub fn set_completed(&mut self) -> Result<()> {\n        let mut is_completed = self.is_completed.write().unwrap();\n\n        // Force close the connection pool\n        // This is required on Windows platform where the connection pool prevents\n        // from persisting the db by renaming the temp file.\n        if let Some(pool) = self.conn_pool.take() {\n            std::mem::drop(pool);\n        }\n\n        // Rename the database file from tmp to db\n        let _file_result = self\n            .db_file_tmp\n            .take() // take ownership of the temporary file and set to None\n            .unwrap() // unwrap the temporary file\n            .persist(&self.db_file)?\n            .ok_or(\"Unable to persist the database file\")?;\n\n        *is_completed = true;\n        Ok(())\n    }\n\n    /// Creates table for the data split.\n    ///\n    /// Note: call is idempotent and thread-safe.\n    ///\n    /// # Arguments\n    ///\n    /// * `split` - A string slice that defines the data split for the table (e.g., \"train\", \"test\").\n    ///\n    /// # Returns\n    ///\n    /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.\n    ///\n    /// TODO (@antimora): add support creating a table with columns corresponding to the item fields\n    fn create_table(&self, split: &str) -> Result<()> {\n        // Check if the split already exists\n        if self.splits.read().unwrap().contains(split) {\n            return Ok(());\n        }\n\n        let conn_pool = self.conn_pool.as_ref().unwrap();\n        let connection = conn_pool.get()?;\n        let create_table_statement = format!(\n            \"create table if not exists  {split} (row_id integer primary key autoincrement not \\\n             null, item blob not null)\"\n        );\n\n        connection.execute(create_table_statement.as_str(), [])?;\n\n        // Add the split to the splits\n        self.splits.write().unwrap().insert(split.to_string());\n\n        Ok(())\n    }\n}\n\n/// Runs a pragma update and ignores the `ExecuteReturnedResults` error.\n///\n/// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error\n/// and can be ignored. This function runs the pragma update and ignores the error if it is\n/// `ExecuteReturnedResults`.\nfn pragma_update_with_error_handling(\n    conn: &PooledConnection<SqliteConnectionManager>,\n    setting: &str,\n    value: &str,\n) -> Result<()> {\n    let result = conn.pragma_update(None, setting, value);\n    if let Err(error) = result\n        && error != rusqlite::Error::ExecuteReturnedResults\n    {\n        return Err(SqliteDatasetError::Sql(error));\n    }\n\n    Ok(())\n}\n\n#[cfg(test)]\nmod tests {\n    use rayon::prelude::*;\n    use rstest::{fixture, rstest};\n    use serde::{Deserialize, Serialize};\n    use tempfile::{NamedTempFile, TempDir, tempdir};\n\n    use super::*;\n\n    type SqlDs = SqliteDataset<Sample>;\n\n    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\n    pub struct Sample {\n        column_str: String,\n        column_bytes: Vec<u8>,\n        column_int: i64,\n        column_bool: bool,\n        column_float: f64,\n    }\n\n    #[fixture]\n    fn train_dataset() -> SqlDs {\n        SqliteDataset::<Sample>::from_db_file(\"tests/data/sqlite-dataset.db\", \"train\").unwrap()\n    }\n\n    #[rstest]\n    pub fn len(train_dataset: SqlDs) {\n        assert_eq!(train_dataset.len(), 2);\n    }\n\n    #[rstest]\n    pub fn get_some(train_dataset: SqlDs) {\n        let item = train_dataset.get(0).unwrap();\n        assert_eq!(item.column_str, \"HI1\");\n        assert_eq!(item.column_bytes, vec![55, 231, 159]);\n        assert_eq!(item.column_int, 1);\n        assert!(item.column_bool);\n        assert_eq!(item.column_float, 1.0);\n    }\n\n    #[rstest]\n    pub fn get_none(train_dataset: SqlDs) {\n        assert_eq!(train_dataset.get(10), None);\n    }\n\n    #[rstest]\n    pub fn multi_thread(train_dataset: SqlDs) {\n        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];\n        let results: Vec<Option<Sample>> =\n            indices.par_iter().map(|&i| train_dataset.get(i)).collect();\n\n        let mut match_count = 0;\n        for (_index, result) in indices.iter().zip(results.iter()) {\n            if let Some(_val) = result {\n                match_count += 1\n            }\n        }\n\n        assert_eq!(match_count, 5);\n    }\n\n    #[test]\n    fn sqlite_dataset_storage() {\n        // Test with non-existing file\n        let storage = SqliteDatasetStorage::from_file(\"non-existing.db\");\n        assert!(!storage.exists());\n\n        // Test with non-existing name\n        let storage = SqliteDatasetStorage::from_name(\"non-existing.db\");\n        assert!(!storage.exists());\n\n        // Test with existing file\n        let storage = SqliteDatasetStorage::from_file(\"tests/data/sqlite-dataset.db\");\n        assert!(storage.exists());\n        let result = storage.reader::<Sample>(\"train\");\n        assert!(result.is_ok());\n        let train = result.unwrap();\n        assert_eq!(train.len(), 2);\n\n        // Test get writer\n        let temp_file = NamedTempFile::new().unwrap();\n        let storage = SqliteDatasetStorage::from_file(temp_file.path());\n        assert!(storage.exists());\n        let result = storage.writer::<Sample>(true);\n        assert!(result.is_ok());\n    }\n\n    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\n    pub struct Complex {\n        column_str: String,\n        column_bytes: Vec<u8>,\n        column_int: i64,\n        column_bool: bool,\n        column_float: f64,\n        column_complex: Vec<Vec<Vec<[u8; 3]>>>,\n    }\n\n    /// Create a temporary directory.\n    #[fixture]\n    fn tmp_dir() -> TempDir {\n        // Create a TempDir. This object will be automatically\n        // deleted when it goes out of scope.\n        tempdir().unwrap()\n    }\n    type Writer = SqliteDatasetWriter<Complex>;\n\n    /// Create a SqliteDatasetWriter with a temporary directory.\n    /// Make sure to return the temporary directory so that it is not deleted.\n    #[fixture]\n    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {\n        let temp_dir_str = tmp_dir.path();\n        let storage = SqliteDatasetStorage::from_name(\"preprocessed\").with_base_dir(temp_dir_str);\n        let overwrite = true;\n        let result = storage.writer::<Complex>(overwrite);\n        assert!(result.is_ok());\n        let writer = result.unwrap();\n        (writer, tmp_dir)\n    }\n\n    #[test]\n    fn test_new() {\n        // Test that the constructor works with overwrite = true\n        let test_path = NamedTempFile::new().unwrap();\n        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();\n        assert!(!test_path.path().exists());\n\n        // Test that the constructor works with overwrite = false\n        let test_path = NamedTempFile::new().unwrap();\n        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);\n        assert!(result.is_err());\n\n        // Test that the constructor works with no existing file\n        let temp = NamedTempFile::new().unwrap();\n        let test_path = temp.path().to_path_buf();\n        assert!(temp.close().is_ok());\n        assert!(!test_path.exists());\n        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();\n        assert!(!test_path.exists());\n    }\n\n    #[rstest]\n    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {\n        // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)\n        let (writer, _tmp_dir) = writer_fixture;\n\n        assert!(writer.overwrite);\n        assert!(!writer.db_file.exists());\n\n        let new_item = Complex {\n            column_str: \"HI1\".to_string(),\n            column_bytes: vec![1_u8, 2, 3],\n            column_int: 0,\n            column_bool: true,\n            column_float: 1.0,\n            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],\n        };\n\n        let index = writer.write(\"train\", &new_item).unwrap();\n        assert_eq!(index, 0);\n\n        let mut writer = writer;\n\n        writer.set_completed().expect(\"Failed to set completed\");\n\n        assert!(writer.db_file.exists());\n        assert!(writer.db_file_tmp.is_none());\n\n        let result = writer.write(\"train\", &new_item);\n\n        // Should fail because the writer is completed\n        assert!(result.is_err());\n\n        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, \"train\").unwrap();\n\n        let fetched_item = dataset.get(0).unwrap();\n        assert_eq!(fetched_item, new_item);\n        assert_eq!(dataset.len(), 1);\n    }\n\n    #[rstest]\n    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {\n        // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)\n        let (writer, _tmp_dir) = writer_fixture;\n\n        let writer = Arc::new(writer);\n        let record_count = 20;\n\n        let splits = [\"train\", \"test\"];\n\n        (0..record_count).into_par_iter().for_each(|index: i64| {\n            let thread_id: std::thread::ThreadId = std::thread::current().id();\n            let sample = Complex {\n                column_str: format!(\"test_{thread_id:?}_{index}\"),\n                column_bytes: vec![index as u8, 2, 3],\n                column_int: index,\n                column_bool: true,\n                column_float: 1.0,\n                column_complex: vec![vec![vec![[1, index as u8, 3]]]],\n            };\n\n            // half for train and half for test\n            let split = splits[index as usize % 2];\n\n            let _index = writer.write(split, &sample).unwrap();\n        });\n\n        let mut writer = Arc::try_unwrap(writer).unwrap();\n\n        writer\n            .set_completed()\n            .expect(\"Should set completed successfully\");\n\n        let train =\n            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), \"train\").unwrap();\n        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, \"test\").unwrap();\n\n        assert_eq!(train.len(), record_count as usize / 2);\n        assert_eq!(test.len(), record_count as usize / 2);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! # Burn Dataset\n//!\n//! Burn Dataset is a library for creating and loading datasets.\n\n#[macro_use]\nextern crate derive_new;\n\nextern crate alloc;\nextern crate dirs;\n\n/// Sources for datasets.\npub mod source;\n\npub mod transform;\n\n/// Audio datasets.\n#[cfg(feature = \"audio\")]\npub mod audio;\n\n/// Vision datasets.\n#[cfg(feature = \"vision\")]\npub mod vision;\n\n/// Natural language processing datasets.\n#[cfg(feature = \"nlp\")]\npub mod nlp;\n\n/// Network dataset utilities.\n#[cfg(feature = \"network\")]\npub mod network {\n    pub use burn_std::network::*;\n}\n\nmod dataset;\npub use dataset::*;\n#[cfg(any(feature = \"sqlite\", feature = \"sqlite-bundled\"))]\npub use source::huggingface::downloader::*;\n\n#[cfg(test)]\nmod test_data {\n    pub fn string_items() -> Vec<String> {\n        vec![\n            \"1 Item\".to_string(),\n            \"2 Items\".to_string(),\n            \"3 Items\".to_string(),\n            \"4 Items\".to_string(),\n        ]\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/nlp/ag_news.rs",
    "content": "//! AG NEWS Dataset Module\n//!\n//! This module provides functionality for loading the AG NEWS text classification dataset.\n//! AG NEWS is a collection of news articles categorized into different topics.\n//! The dataset is split into training (120,000 articles) and test (7,600 articles) sets.\n//!\n//! ## Dataset Details\n//! - **Classes**: 4 categories (World, Sports, Business, Sci/Tech)\n//! - **AG NEWS mirror**: [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83)\n//! - **License**: [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE)\n//!\n//! ## Usage Example\n//! ```rust\n//! use burn_dataset::nlp::AgNewsDataset;\n//!\n//! // Create an AG NEWS dataset accessor\n//! let dataset = AgNewsDataset::new();\n//!\n//! // Access training and test sets\n//! let train_dataset = dataset.train();\n//! let test_dataset = dataset.test();\n//! ```\n\nuse std::{path::PathBuf, sync::Mutex};\n\nuse flate2::read::GzDecoder;\nuse serde::{Deserialize, Serialize};\nuse tar::Archive;\n\nuse crate::InMemDataset;\nuse crate::network::downloader;\n\n/// AG NEWS mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L83).\n/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\nconst AG_NEWS_URL: &str = \"https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz\";\n\n/// Represents an item in the AG NEWS dataset.\n///\n/// Each item contains a label, title, and content of a news article.\n#[derive(Deserialize, Serialize, Debug, Clone)]\npub struct AgNewsItem {\n    /// The category label of the news article.\n    pub label: String,\n    /// The title of the news article.\n    pub title: String,\n    /// The content/body of the news article.\n    pub content: String,\n}\n\n/// AG NEWS dataset accessor.\n///\n/// This struct provides convenient access to the AG NEWS text classification dataset.\n/// It automatically downloads (if not already downloaded), extracts, and loads the datasets.\n///\n/// The dataset is split into training (120,000 articles) and test (7,600 articles) sets.\npub struct AgNewsDataset {\n    agnews_dir: PathBuf,\n}\n\n/// AG NEWS dataset download lock.\n///\n/// This lock ensures that only one thread downloads the AG NEWS dataset at a time.\nstatic DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());\n\nimpl AgNewsDataset {\n    /// Creates a new AG NEWS dataset accessor.\n    ///\n    /// This will download and extract the dataset if it's not already present.\n    pub fn new() -> Self {\n        Self {\n            agnews_dir: Self::download(),\n        }\n    }\n\n    /// Downloads and extracts the AG NEWS dataset.\n    ///\n    /// # Returns\n    /// Path to the directory containing the extracted dataset.\n    fn download() -> PathBuf {\n        // Acquire the lock. This will block if another thread already holds the lock.\n        let _lock = DOWNLOAD_LOCK.lock().unwrap();\n\n        // Dataset files are stored in the burn-dataset cache directory\n        let cache_dir = dirs::cache_dir()\n            .expect(\"Could not get cache directory\")\n            .join(\"burn-dataset\");\n\n        // AG NEWS dataset directory\n        let agnews_dir = cache_dir.join(\"ag_news_csv\");\n\n        // AG NEWS dataset url\n        let url = AG_NEWS_URL;\n\n        // AG NEWS dataset archive filename\n        let filename = \"ag_news_csv.tgz\";\n\n        // Check for already downloaded content\n        if !agnews_dir.exists() {\n            // Download gzip file\n            let bytes = downloader::download_file_as_bytes(url, filename);\n\n            // Decode gzip file content and unpack archive\n            let gz_buffer = GzDecoder::new(&bytes[..]);\n            let mut archive = Archive::new(gz_buffer);\n            archive.unpack(cache_dir).unwrap();\n        }\n\n        agnews_dir\n    }\n\n    /// Parses a CSV file into an in-memory dataset.\n    ///\n    /// # Arguments\n    /// * `file_path` - Path to the CSV file to parse.\n    ///\n    /// # Returns\n    /// An `InMemDataset` containing the parsed data.\n    fn parse_csv(file_path: &str) -> InMemDataset<AgNewsItem> {\n        let mut rdr = csv::ReaderBuilder::new();\n        let rdr = rdr.has_headers(false);\n\n        InMemDataset::from_csv(file_path, &rdr).expect(\"Failed to parse CSV file\")\n    }\n\n    /// Gets the training dataset.\n    ///\n    /// # Returns\n    /// An `InMemDataset` instance containing 120,000 training articles.\n    pub fn train(&self) -> InMemDataset<AgNewsItem> {\n        let file_path = self.agnews_dir.join(\"train.csv\");\n        Self::parse_csv(file_path.to_str().unwrap())\n    }\n\n    /// Gets the test dataset.\n    ///\n    /// # Returns\n    /// An `InMemDataset` instance containing 7,600 test articles.\n    pub fn test(&self) -> InMemDataset<AgNewsItem> {\n        let file_path = self.agnews_dir.join(\"test.csv\");\n        Self::parse_csv(file_path.to_str().unwrap())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::Dataset;\n\n    // AG NEWS dataset train and test dataset lengths\n    const TRAIN_DATASET_LEN: usize = 120000;\n    const TEST_DATASET_LEN: usize = 7600;\n\n    #[test]\n    fn test_agnews_download() {\n        let agnews_dir = AgNewsDataset::download();\n        assert!(agnews_dir.exists());\n    }\n\n    #[test]\n    fn test_agnews_len() {\n        let agnews = AgNewsDataset::new();\n        let train_dataset = agnews.train();\n        let test_dataset = agnews.test();\n        assert_eq!(train_dataset.len(), TRAIN_DATASET_LEN);\n        assert_eq!(test_dataset.len(), TEST_DATASET_LEN);\n    }\n\n    #[test]\n    fn test_agnews_first_and_last_item() {\n        let agnews = AgNewsDataset::new();\n\n        // Test the first and the last item in training dataset\n        let train_dataset = agnews.train();\n        let first_item = train_dataset.get(0).unwrap();\n        let last_item = train_dataset.get(train_dataset.len() - 1).unwrap();\n        assert!(compare_item(&first_item, &(\"3\".to_string(), \"Wall St. Bears Claw Back Into the Black (Reuters)\".to_string(), \"Reuters - Short-sellers, Wall Street's dwindling\\\\band of ultra-cynics, are seeing green again.\".to_string())));\n        assert!(compare_item(\n            &last_item,\n            &(\n                \"2\".to_string(),\n                \"Nets get Carter from Raptors\".to_string(),\n                \"INDIANAPOLIS -- All-Star Vince Carter was traded by the Toronto Raptors to the New Jersey Nets for Alonzo Mourning, Eric Williams, Aaron Williams, and a pair of first-round draft picks yesterday.\".to_string()\n            )\n        ));\n\n        // Test the first and the last item in test dataset\n        let test_dataset = agnews.test();\n        let first_item = test_dataset.get(0).unwrap();\n        let last_item = test_dataset.get(test_dataset.len() - 1).unwrap();\n        assert!(compare_item(\n            &first_item,\n            &(\n                \"3\".to_string(),\n                \"Fears for T N pension after talks\".to_string(),\n                \"Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.\".to_string()\n            )\n        ));\n        assert!(compare_item(\n            &last_item,\n            &(\n                \"3\".to_string(),\n                \"EBay gets into rentals\".to_string(),\n                \"EBay plans to buy the apartment and home rental service Rent.com for \\\\$415 million, adding to its already exhaustive breadth of offerings.\".to_string()\n            )\n        ));\n    }\n\n    fn compare_item(item: &AgNewsItem, target: &(String, String, String)) -> bool {\n        item.label == target.0 && item.title == target.1 && item.content == target.2\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/nlp/mod.rs",
    "content": "#[cfg(feature = \"builtin-sources\")]\nmod ag_news;\nmod text_folder;\n\n#[cfg(feature = \"builtin-sources\")]\npub use ag_news::*;\npub use text_folder::*;\n"
  },
  {
    "path": "crates/burn-dataset/src/nlp/text_folder.rs",
    "content": "use crate::transform::{Mapper, MapperDataset};\nuse crate::{Dataset, InMemDataset};\n\nuse encoding_rs::{GB18030, GBK, UTF_8, UTF_16BE, UTF_16LE};\nuse globwalk::{self, DirEntry};\nuse std::collections::{HashMap, HashSet};\nuse std::fs;\nuse std::io::Read;\nuse std::path::{Path, PathBuf};\nuse thiserror::Error;\n\nconst SUPPORTED_FILES: [&str; 1] = [\"txt\"];\n\n/// Text data type.\n#[derive(Debug, Clone, PartialEq)]\npub struct TextData {\n    /// The text content.\n    pub text: String,\n\n    /// Original text source.\n    pub text_path: String,\n}\n\n/// Text dataset item.\n#[derive(Debug, Clone, PartialEq)]\npub struct TextDatasetItem {\n    /// Text content.\n    pub text: TextData,\n\n    /// Label for the text.\n    pub label: usize,\n}\n\n/// Raw text dataset item.\n#[derive(Debug, Clone)]\nstruct TextDatasetItemRaw {\n    /// Text path.\n    text_path: PathBuf,\n\n    /// Text label.\n    label: String,\n}\n\nimpl TextDatasetItemRaw {\n    fn new<P: AsRef<Path>>(text_path: P, label: String) -> TextDatasetItemRaw {\n        TextDatasetItemRaw {\n            text_path: text_path.as_ref().to_path_buf(),\n            label,\n        }\n    }\n}\n\nstruct PathToTextDatasetItem {\n    classes: HashMap<String, usize>,\n}\n\n/// Parse the text content from file with auto-detection of encoding.\nfn parse_text_content(text_path: &PathBuf) -> String {\n    // Read raw bytes from disk\n    let mut file = fs::File::open(text_path).unwrap();\n    let mut bytes = Vec::new();\n    file.read_to_end(&mut bytes).unwrap();\n\n    // Try to detect encoding and decode text\n    // First try UTF-8 with BOM\n    if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) && bytes.len() >= 3 {\n        let (result, _, had_errors) = UTF_8.decode(&bytes[3..]);\n        if !had_errors {\n            return result.into_owned();\n        }\n    }\n\n    // Try UTF-8 without BOM\n    let (result, _, had_errors) = UTF_8.decode(&bytes);\n    if !had_errors {\n        return result.into_owned();\n    }\n\n    // Try UTF-16LE with BOM\n    if bytes.starts_with(&[0xFF, 0xFE]) && bytes.len() >= 2 {\n        let (result, had_errors) = UTF_16LE.decode_with_bom_removal(&bytes[2..]);\n        if !had_errors {\n            return result.into_owned();\n        }\n    }\n\n    // Try UTF-16BE with BOM\n    if bytes.starts_with(&[0xFE, 0xFF]) && bytes.len() >= 2 {\n        let (result, had_errors) = UTF_16BE.decode_with_bom_removal(&bytes[2..]);\n        if !had_errors {\n            return result.into_owned();\n        }\n    }\n\n    // Try GB18030 encoding\n    let (result, _, had_errors) = GB18030.decode(&bytes);\n    if !had_errors {\n        return result.into_owned();\n    }\n\n    // Try GBK encoding\n    let (result, _, had_errors) = GBK.decode(&bytes);\n    if !had_errors {\n        return result.into_owned();\n    }\n\n    // Default fallback - use from_utf8_lossy for any remaining cases\n    String::from_utf8_lossy(&bytes).to_string()\n}\n\nimpl Mapper<TextDatasetItemRaw, TextDatasetItem> for PathToTextDatasetItem {\n    /// Convert a raw text dataset item (path-like) to text content with a target label.\n    fn map(&self, item: &TextDatasetItemRaw) -> TextDatasetItem {\n        let label = *self.classes.get(&item.label).unwrap();\n\n        // Load text from disk\n        let text_content = parse_text_content(&item.text_path);\n\n        let text_data = TextData {\n            text: text_content,\n            text_path: item.text_path.display().to_string(),\n        };\n\n        TextDatasetItem {\n            text: text_data,\n            label,\n        }\n    }\n}\n\n/// Error type for [TextFolderDataset](TextFolderDataset).\n#[derive(Error, Debug)]\npub enum TextLoaderError {\n    /// Unknown error.\n    #[error(\"unknown: `{0}`\")]\n    Unknown(String),\n\n    /// I/O operation error.\n    #[error(\"I/O error: `{0}`\")]\n    IOError(String),\n\n    /// Invalid file error.\n    #[error(\"Invalid file extension: `{0}`\")]\n    InvalidFileExtensionError(String),\n\n    /// Encoding error.\n    #[error(\"Encoding error: `{0}`\")]\n    EncodingError(String),\n}\n\ntype TextDatasetMapper =\n    MapperDataset<InMemDataset<TextDatasetItemRaw>, PathToTextDatasetItem, TextDatasetItemRaw>;\n\n/// A generic dataset to load texts from disk.\npub struct TextFolderDataset {\n    dataset: TextDatasetMapper,\n}\n\nimpl Dataset<TextDatasetItem> for TextFolderDataset {\n    fn get(&self, index: usize) -> Option<TextDatasetItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\nimpl TextFolderDataset {\n    /// Create a text classification dataset from the root folder.\n    ///\n    /// # Arguments\n    ///\n    /// * `root` - Dataset root folder.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification<P: AsRef<Path>>(root: P) -> Result<Self, TextLoaderError> {\n        // New dataset containing any of the supported file types\n        TextFolderDataset::new_classification_with(root, &SUPPORTED_FILES)\n    }\n\n    /// Create a text classification dataset from the root folder.\n    /// The included texts are filtered based on the provided extensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `root` - Dataset root folder.\n    /// * `extensions` - List of allowed extensions.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification_with<P, S>(root: P, extensions: &[S]) -> Result<Self, TextLoaderError>\n    where\n        P: AsRef<Path>,\n        S: AsRef<str>,\n    {\n        // Glob all texts with extensions\n        let walker = globwalk::GlobWalkerBuilder::from_patterns(\n            root.as_ref(),\n            &[format!(\n                \"*.{{{}}}\", // \"*.{ext1,ext2,ext3}\n                extensions\n                    .iter()\n                    .map(Self::check_extension)\n                    .collect::<Result<Vec<_>, _>>()?\n                    .join(\",\")\n            )],\n        )\n        .follow_links(true)\n        .sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path\n        .build()\n        .map_err(|err| TextLoaderError::Unknown(format!(\"{err:?}\")))?\n        .filter_map(Result::ok);\n\n        // Get all dataset items\n        let mut items = Vec::new();\n        let mut classes = HashSet::new();\n        for text in walker {\n            let text_path = text.path();\n\n            // Label name is represented by the parent folder name\n            let label = text_path\n                .parent()\n                .ok_or_else(|| {\n                    TextLoaderError::IOError(\"Could not resolve text parent folder\".to_string())\n                })?\n                .file_name()\n                .ok_or_else(|| {\n                    TextLoaderError::IOError(\n                        \"Could not resolve text parent folder name\".to_string(),\n                    )\n                })?\n                .to_string_lossy()\n                .into_owned();\n\n            classes.insert(label.clone());\n\n            items.push(TextDatasetItemRaw::new(text_path, label))\n        }\n\n        // Sort class names\n        let mut classes = classes.into_iter().collect::<Vec<_>>();\n        classes.sort();\n\n        Self::with_items(items, &classes)\n    }\n\n    /// Create a text classification dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - List of dataset items, each item represented by a tuple `(text path, label)`.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(\n        items: Vec<(P, String)>,\n        classes: &[S],\n    ) -> Result<Self, TextLoaderError> {\n        // Parse items and check valid text extension types\n        let items = items\n            .into_iter()\n            .map(|(path, label)| {\n                // Map text path and label\n                let path = path.as_ref();\n                let label = label;\n\n                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;\n\n                Ok(TextDatasetItemRaw::new(path, label))\n            })\n            .collect::<Result<Vec<_>, _>>()?;\n\n        Self::with_items(items, classes)\n    }\n\n    /// Create a text dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - Raw dataset items.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    fn with_items<S: AsRef<str>>(\n        items: Vec<TextDatasetItemRaw>,\n        classes: &[S],\n    ) -> Result<Self, TextLoaderError> {\n        // NOTE: right now we don't need to validate the supported text files since\n        // the method is private. We assume it's already validated.\n        let dataset = InMemDataset::new(items);\n\n        // Class names to index map\n        let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();\n        let classes_map: HashMap<_, _> = classes\n            .into_iter()\n            .enumerate()\n            .map(|(idx, cls)| (cls.to_string(), idx))\n            .collect();\n\n        let mapper = PathToTextDatasetItem {\n            classes: classes_map,\n        };\n        let dataset = MapperDataset::new(dataset, mapper);\n\n        Ok(Self { dataset })\n    }\n\n    /// Check if extension is supported.\n    fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, TextLoaderError> {\n        let extension = extension.as_ref();\n        if !SUPPORTED_FILES.contains(&extension) {\n            Err(TextLoaderError::InvalidFileExtensionError(\n                extension.to_string(),\n            ))\n        } else {\n            Ok(extension.to_string())\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use std::path::Path;\n\n    const TEXT_ROOT: &str = \"tests/data/text_folder\";\n\n    #[test]\n    fn test_text_folder_dataset() {\n        let dataset = TextFolderDataset::new_classification(TEXT_ROOT).unwrap();\n\n        // Dataset should have 4 elements (2 positive + 2 negative)\n        assert_eq!(dataset.len(), 4);\n        assert_eq!(dataset.get(4), None);\n\n        // Check that we have items from both classes\n        let mut found_positive = false;\n        let mut found_negative = false;\n\n        for i in 0..dataset.len() {\n            let item = dataset.get(i).unwrap();\n            if item.label == 0 {\n                found_negative = true;\n                // Check that the text content is loaded correctly\n                assert!(!item.text.text.is_empty());\n                assert!(item.text.text_path.contains(\"negative\"));\n            } else if item.label == 1 {\n                found_positive = true;\n                // Check that the text content is loaded correctly\n                assert!(!item.text.text.is_empty());\n                assert!(item.text.text_path.contains(\"positive\"));\n            }\n        }\n\n        // Verify we found items from both classes\n        assert!(found_positive);\n        assert!(found_negative);\n    }\n\n    #[test]\n    fn test_text_folder_dataset_with_invalid_extension() {\n        // Try to create a dataset with an unsupported extension\n        let result = TextFolderDataset::new_classification_with(TEXT_ROOT, &[\"invalid\"]);\n        assert!(result.is_err());\n    }\n\n    #[test]\n    fn test_text_folder_dataset_with_items() {\n        // Create the dataset\n        let root = Path::new(TEXT_ROOT);\n        let items = vec![\n            (\n                root.join(\"positive\").join(\"sample1.txt\"),\n                \"positive\".to_string(),\n            ),\n            (\n                root.join(\"negative\").join(\"sample2.txt\"),\n                \"negative\".to_string(),\n            ),\n        ];\n        let classes = vec![\"positive\", \"negative\"];\n        let dataset = TextFolderDataset::new_classification_with_items(items, &classes).unwrap();\n\n        // Dataset should have 2 elements\n        assert_eq!(dataset.len(), 2);\n        assert_eq!(dataset.get(2), None);\n\n        // Get items\n        let item0 = dataset.get(0).unwrap();\n        let item1 = dataset.get(1).unwrap();\n\n        // Check item0\n        assert!(compare_item(\n            &item0,\n            &(\n                \"This is a positive text sample for testing the text folder dataset functionality.\"\n                    .to_string(),\n                0\n            )\n        ));\n\n        // Check item1\n        assert_eq!(item1.label, 1);\n        assert!(item1.text.text_path.contains(\"negative\"));\n        assert!(compare_item(\n            &item1,\n            &(\n                \"另一个负面文本样本，用以确保数据集能够处理同一类别中的多个文件。\".to_string(),\n                1\n            )\n        ));\n    }\n\n    fn compare_item(item: &TextDatasetItem, target: &(String, usize)) -> bool {\n        item.text.text == target.0 && item.label == target.1\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/source/huggingface/downloader.rs",
    "content": "use std::fs::{self, create_dir_all};\nuse std::path::{Path, PathBuf};\nuse std::process::Command;\n\nuse crate::{SqliteDataset, SqliteDatasetError, SqliteDatasetStorage};\n\nuse sanitize_filename::sanitize;\nuse serde::de::DeserializeOwned;\nuse thiserror::Error;\n\nconst PYTHON_SOURCE: &str = include_str!(\"importer.py\");\n#[cfg(not(target_os = \"windows\"))]\nconst VENV_BIN_PYTHON: &str = \"bin/python3\";\n#[cfg(target_os = \"windows\")]\nconst VENV_BIN_PYTHON: &str = \"Scripts\\\\python\";\n\n/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).\n#[derive(Error, Debug)]\npub enum ImporterError {\n    /// Unknown error.\n    #[error(\"unknown: `{0}`\")]\n    Unknown(String),\n\n    /// Fail to download python dependencies.\n    #[error(\"fail to download python dependencies: `{0}`\")]\n    FailToDownloadPythonDependencies(String),\n\n    /// Fail to create sqlite dataset.\n    #[error(\"sqlite dataset: `{0}`\")]\n    SqliteDataset(#[from] SqliteDatasetError),\n\n    /// python3 is not installed.\n    #[error(\"python3 is not installed\")]\n    PythonNotInstalled,\n\n    /// venv environment is not initialized.\n    #[error(\"venv environment is not initialized\")]\n    VenvNotInitialized,\n}\n\n/// Load a dataset from [huggingface datasets](https://huggingface.co/datasets).\n///\n/// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)).\n///\n/// # Example\n/// ```no_run\n///  use burn_dataset::HuggingfaceDatasetLoader;\n///  use burn_dataset::SqliteDataset;\n///  use serde::{Deserialize, Serialize};\n///\n/// #[derive(Deserialize, Debug, Clone)]\n/// struct MnistItemRaw {\n///     pub image_bytes: Vec<u8>,\n///     pub label: usize,\n/// }\n///\n///  let train_ds:SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new(\"mnist\")\n///       .dataset(\"train\")\n///       .unwrap();\n/// ```\n///\n/// # Note\n/// This loader relies on the [`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index)\n/// to download datasets. This is a Python library, so you must have an existing Python installation.\npub struct HuggingfaceDatasetLoader {\n    name: String,\n    subset: Option<String>,\n    base_dir: Option<PathBuf>,\n    huggingface_token: Option<String>,\n    huggingface_cache_dir: Option<String>,\n    huggingface_data_dir: Option<String>,\n    trust_remote_code: bool,\n    use_python_venv: bool,\n}\n\nimpl HuggingfaceDatasetLoader {\n    /// Create a huggingface dataset loader.\n    pub fn new(name: &str) -> Self {\n        Self {\n            name: name.to_string(),\n            subset: None,\n            base_dir: None,\n            huggingface_token: None,\n            huggingface_cache_dir: None,\n            huggingface_data_dir: None,\n            trust_remote_code: false,\n            use_python_venv: true,\n        }\n    }\n\n    /// Create a huggingface dataset loader for a subset of the dataset.\n    ///\n    /// The subset name must be one of the subsets listed in the dataset page.\n    ///\n    /// If no subset names are listed, then do not use this method.\n    pub fn with_subset(mut self, subset: &str) -> Self {\n        self.subset = Some(subset.to_string());\n        self\n    }\n\n    /// Specify a base directory to store the dataset.\n    ///\n    /// If not specified, the dataset will be stored in the system cache directory under `burn-dataset`.\n    pub fn with_base_dir(mut self, base_dir: &str) -> Self {\n        self.base_dir = Some(base_dir.into());\n        self\n    }\n\n    /// Specify a huggingface token to download datasets behind authentication.\n    ///\n    /// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)\n    pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {\n        self.huggingface_token = Some(huggingface_token.to_string());\n        self\n    }\n\n    /// Specify a huggingface cache directory to store the downloaded datasets.\n    ///\n    /// If not specified, the dataset will be stored in the system cache directory under `huggingface/datasets`.\n    pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {\n        self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());\n        self\n    }\n\n    /// Specify a relative path to a subset of a dataset. This is used in some datasets for the\n    /// manual steps of dataset download process.\n    ///\n    /// Unless you've encountered a ManualDownloadError\n    /// when loading your dataset you probably don't have to worry about this setting.\n    pub fn with_huggingface_data_dir(mut self, huggingface_data_dir: &str) -> Self {\n        self.huggingface_data_dir = Some(huggingface_data_dir.to_string());\n        self\n    }\n\n    /// Specify whether or not to trust remote code.\n    ///\n    /// If not specified, trust remote code is set to true.\n    pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {\n        self.trust_remote_code = trust_remote_code;\n        self\n    }\n\n    /// Specify whether or not to use the burn-dataset Python\n    /// virtualenv for running the importer script. If false, local\n    /// `python3`'s environment is used.\n    ///\n    /// If not specified, the virtualenv is used.\n    pub fn with_use_python_venv(mut self, use_python_venv: bool) -> Self {\n        self.use_python_venv = use_python_venv;\n        self\n    }\n\n    /// Load the dataset.\n    pub fn dataset<I: DeserializeOwned + Clone>(\n        self,\n        split: &str,\n    ) -> Result<SqliteDataset<I>, ImporterError> {\n        let db_file = self.db_file()?;\n        let dataset = SqliteDataset::from_db_file(db_file, split)?;\n        Ok(dataset)\n    }\n\n    /// Get the path to the sqlite database file.\n    ///\n    /// If the database file does not exist, it will be downloaded and imported.\n    pub fn db_file(self) -> Result<PathBuf, ImporterError> {\n        // determine (and create if needed) the base directory\n        let base_dir = SqliteDatasetStorage::base_dir(self.base_dir);\n\n        if !base_dir.exists() {\n            create_dir_all(&base_dir).expect(\"Failed to create base directory\");\n        }\n\n        //sanitize the name and subset\n        let name = sanitize(self.name.as_str());\n\n        // create the db file path\n        let db_file_name = if let Some(subset) = self.subset.clone() {\n            format!(\"{name}-{}.db\", sanitize(subset.as_str()))\n        } else {\n            format!(\"{name}.db\")\n        };\n\n        let db_file = base_dir.join(db_file_name);\n\n        // import the dataset if needed\n        if !Path::new(&db_file).exists() {\n            import(\n                self.name,\n                self.subset,\n                db_file.clone(),\n                base_dir,\n                self.huggingface_token,\n                self.huggingface_cache_dir,\n                self.huggingface_data_dir,\n                self.trust_remote_code,\n                self.use_python_venv,\n            )?;\n        }\n\n        Ok(db_file)\n    }\n}\n\n/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.\n#[allow(clippy::too_many_arguments)]\nfn import(\n    name: String,\n    subset: Option<String>,\n    base_file: PathBuf,\n    base_dir: PathBuf,\n    huggingface_token: Option<String>,\n    huggingface_cache_dir: Option<String>,\n    huggingface_data_dir: Option<String>,\n    trust_remote_code: bool,\n    use_python_venv: bool,\n) -> Result<(), ImporterError> {\n    let python_path = if use_python_venv {\n        install_python_deps(&base_dir)?\n    } else {\n        get_python_name()?.into()\n    };\n\n    let mut command = Command::new(python_path);\n\n    command.arg(importer_script_path(&base_dir));\n\n    command.arg(\"--name\");\n    command.arg(name);\n\n    command.arg(\"--file\");\n    command.arg(base_file);\n\n    if let Some(subset) = subset {\n        command.arg(\"--subset\");\n        command.arg(subset);\n    }\n\n    if let Some(huggingface_token) = huggingface_token {\n        command.arg(\"--token\");\n        command.arg(huggingface_token);\n    }\n\n    if let Some(huggingface_cache_dir) = huggingface_cache_dir {\n        command.arg(\"--cache_dir\");\n        command.arg(huggingface_cache_dir);\n    }\n    if let Some(huggingface_data_dir) = huggingface_data_dir {\n        command.arg(\"--data_dir\");\n        command.arg(huggingface_data_dir);\n    }\n    if trust_remote_code {\n        command.arg(\"--trust_remote_code\");\n        command.arg(\"True\");\n    }\n    let mut handle = command.spawn().unwrap();\n\n    let exit_status = handle\n        .wait()\n        .map_err(|err| ImporterError::Unknown(format!(\"{err:?}\")))?;\n\n    if !exit_status.success() {\n        return Err(ImporterError::Unknown(format!(\"{exit_status}\")));\n    }\n\n    Ok(())\n}\n\n/// check python --version output is `Python 3.x.x`\nfn check_python_version_is_3(python: &str) -> bool {\n    let output = Command::new(python).arg(\"--version\").output();\n    match output {\n        Ok(output) => {\n            if output.status.success() {\n                let version_string = String::from_utf8_lossy(&output.stdout);\n                if let Some(index) = version_string.find(' ') {\n                    let version = &version_string[index + 1..];\n                    version.starts_with(\"3.\")\n                } else {\n                    false\n                }\n            } else {\n                false\n            }\n        }\n        Err(_error) => false,\n    }\n}\n\n/// get python3 name `python` `python3` or `py`\nfn get_python_name() -> Result<&'static str, ImporterError> {\n    let python_name_list = [\"python3\", \"python\", \"py\"];\n    for python_name in python_name_list.iter() {\n        if check_python_version_is_3(python_name) {\n            return Ok(python_name);\n        }\n    }\n    Err(ImporterError::PythonNotInstalled)\n}\n\nfn importer_script_path(base_dir: &Path) -> PathBuf {\n    let path_file = base_dir.join(\"importer.py\");\n\n    fs::write(&path_file, PYTHON_SOURCE).expect(\"Write python dataset downloader\");\n    path_file\n}\n\nfn install_python_deps(base_dir: &Path) -> Result<PathBuf, ImporterError> {\n    let venv_dir = base_dir.join(\"venv\");\n    let venv_python_path = venv_dir.join(VENV_BIN_PYTHON);\n    // If the venv environment is already initialized, skip the initialization.\n    if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {\n        let python_name = get_python_name()?;\n        let mut command = Command::new(python_name);\n        command.args([\n            \"-m\",\n            \"venv\",\n            venv_dir\n                .as_os_str()\n                .to_str()\n                .expect(\"Path utf8 conversion should not fail\"),\n        ]);\n\n        // Spawn the venv creation process and wait for it to complete.\n        let mut handle = command.spawn().unwrap();\n\n        handle.wait().map_err(|err| {\n            ImporterError::FailToDownloadPythonDependencies(format!(\" error: {err}\"))\n        })?;\n        // Check if the venv environment can be used successfully.\"\n        if !check_python_version_is_3(venv_python_path.to_str().unwrap()) {\n            return Err(ImporterError::VenvNotInitialized);\n        }\n    }\n\n    let mut ensurepip_cmd = Command::new(&venv_python_path);\n    ensurepip_cmd.args([\"-m\", \"ensurepip\", \"--upgrade\"]);\n    let status = ensurepip_cmd.status().map_err(|err| {\n        ImporterError::FailToDownloadPythonDependencies(format!(\"failed to run ensurepip: {err}\"))\n    })?;\n    if !status.success() {\n        return Err(ImporterError::FailToDownloadPythonDependencies(\n            \"ensurepip failed to initialize pip\".to_string(),\n        ));\n    }\n\n    let mut command = Command::new(&venv_python_path);\n    command.args([\n        \"-m\",\n        \"pip\",\n        \"--quiet\",\n        \"install\",\n        \"pyarrow\",\n        \"sqlalchemy\",\n        \"Pillow\",\n        \"soundfile\",\n        \"datasets\",\n    ]);\n\n    // Spawn the pip install process and wait for it to complete.\n    let mut handle = command.spawn().unwrap();\n    handle\n        .wait()\n        .map_err(|err| ImporterError::FailToDownloadPythonDependencies(format!(\" error: {err}\")))?;\n\n    Ok(venv_python_path)\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/source/huggingface/importer.py",
    "content": "import argparse\n\nimport pyarrow as pa\nfrom datasets import Audio, Image, load_dataset\nfrom sqlalchemy import Column, Integer, Table, create_engine, event, inspect\nfrom sqlalchemy.types import LargeBinary\n\n\ndef download_and_export(\n    name: str,\n    subset: str,\n    db_file: str,\n    token: str,\n    cache_dir: str,\n    data_dir: str | None,\n    trust_remote_code: bool,\n):\n    \"\"\"\n    Download a dataset from using HuggingFace dataset and export it to a sqlite database.\n    \"\"\"\n\n    # TODO For media columns (Image and Audio) sometimes when decode=False,\n    # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}\n    # We should handle this case, but unfortunately we did not come across this case yet to test it.\n\n    print(\"*\" * 80)\n    print(\"Starting huggingface dataset download and export\")\n    print(f\"Dataset Name: {name}\")\n    print(f\"Subset Name: {subset}\")\n    print(f\"Sqlite database file: {db_file}\")\n    print(f\"Trust remote code: {trust_remote_code}\")\n    if cache_dir is None:\n        print(f\"Custom cache dir: {cache_dir}\")\n    print(\"*\" * 80)\n\n    # Load the dataset\n    dataset_all = load_dataset(\n        name,\n        subset,\n        cache_dir=cache_dir,\n        data_dir=data_dir,\n        use_auth_token=token,\n        trust_remote_code=trust_remote_code,\n    )\n\n    print(f\"Dataset: {dataset_all}\")\n\n    # Create the database connection descriptor (sqlite)\n    engine = create_engine(f\"sqlite:///{db_file}\")\n\n    # Set some sqlite pragmas to speed up the database\n    event.listen(engine, \"connect\", set_sqlite_pragma)\n\n    # Add an row_id column to each table as primary key (datasets does not have API for this)\n    event.listen(Table, \"before_create\", add_pk_column)\n\n    # Export each split in the dataset\n    for key in dataset_all.keys():\n        dataset = dataset_all[key]\n\n        # Disable decoding for audio and image fields\n        dataset = disable_decoding(dataset)\n\n        # Flatten the dataset\n        dataset = dataset.flatten()\n\n        # Rename columns to remove dots from the names\n        dataset = rename_columns(dataset)\n\n        print(f\"Saving dataset: {name} - {key}\")\n        print(f\"Dataset features: {dataset.features}\")\n\n        # Save the dataset to a sqlite database\n        dataset.to_sql(\n            key,  # table name\n            engine,\n            # don't save the index, use row_id instead (index is not unique)\n            index=False,\n            dtype=blob_columns(dataset),  # save binary columns as blob\n        )\n\n    # Print the schema of the database so we can reference the columns in the rust code\n    print_table_info(engine)\n\n\ndef disable_decoding(dataset):\n    \"\"\"\n    Disable decoding for audio and image fields. The fields will be saved as raw file bytes.\n    \"\"\"\n    for k, v in dataset.features.items():\n        if isinstance(v, Audio):\n            dataset = dataset.cast_column(k, Audio(decode=False))\n        elif isinstance(v, Image):\n            dataset = dataset.cast_column(k, Image(decode=False))\n\n    return dataset\n\n\ndef rename_columns(dataset):\n    \"\"\"\n    Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.\n    Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.\n    This way there is an easy name mapping between the rust and sql columns.\n    \"\"\"\n\n    for name in dataset.features.keys():\n        if \".\" in name:\n            dataset = dataset.rename_column(name, name.replace(\".\", \"_\"))\n\n    return dataset\n\n\ndef blob_columns(dataset):\n    \"\"\"\n    Make sure all binary columns are blob columns in the database because\n    `to_sql` exports binary values as TEXT instead of BLOB.\n    \"\"\"\n    type_mapping = {}\n    for name, value in dataset.features.items():\n        if value.pa_type is not None and pa.types.is_binary(value.pa_type):\n            type_mapping[name] = LargeBinary\n    return type_mapping\n\n\ndef set_sqlite_pragma(dbapi_connection, connection_record):\n    \"\"\"\n    Set some sqlite pragmas to speed up the database\n    \"\"\"\n    cursor = dbapi_connection.cursor()\n    cursor.execute(\"PRAGMA synchronous = OFF\")\n    cursor.execute(\"PRAGMA journal_mode = OFF\")\n    cursor.close()\n\n\ndef add_pk_column(target, connection, **kw):\n    \"\"\"\n    Add an id column to each table.\n    \"\"\"\n    target.append_column(Column(\"row_id\", Integer, primary_key=True))\n\n\ndef print_table_info(engine):\n    \"\"\"\n    Print the schema of the database so we can reference the columns in the rust code\n    \"\"\"\n    print(f\"Printing table schema for sqlite3 db ({engine})\")\n    inspector = inspect(engine)\n    for table_name in inspector.get_table_names():\n        print(f\"Table: {table_name}\")\n        for column in inspector.get_columns(table_name):\n            print(f\"Column: {column['name']} - {column['type']}\")\n        print(\"\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Huggingface datasets downloader to use with burn-dataset\"\n    )\n    parser.add_argument(\n        \"--name\", type=str, help=\"Name of the dataset to download\", required=True\n    )\n    parser.add_argument(\n        \"--file\", type=str, help=\"Base file name where the data is saved\", required=True\n    )\n    parser.add_argument(\n        \"--subset\", type=str, help=\"Subset name\", required=False, default=None\n    )\n    parser.add_argument(\n        \"--token\",\n        type=str,\n        help=\"HuggingFace authentication token\",\n        required=False,\n        default=None,\n    )\n    parser.add_argument(\n        \"--cache_dir\", type=str, help=\"Cache directory\", required=False, default=None\n    )\n    parser.add_argument(\n            \"--data_dir\", type=str, help=\"Relative path to a specific subset of your dataset\", required=False, default=None\n    )\n    parser.add_argument(\n        \"--trust_remote_code\",\n        type=bool,\n        help=\"Trust remote code\",\n        required=False,\n        default=None,\n    )\n\n    return parser.parse_args()\n\n\ndef run():\n    args = parse_args()\n\n    download_and_export(\n        args.name,\n        args.subset,\n        args.file,\n        args.token,\n        args.data_dir,\n        args.cache_dir,\n        args.trust_remote_code,\n    )\n\n\nif __name__ == \"__main__\":\n    run()\n"
  },
  {
    "path": "crates/burn-dataset/src/source/huggingface/mod.rs",
    "content": "pub(crate) mod downloader;\n\npub use downloader::*;\n"
  },
  {
    "path": "crates/burn-dataset/src/source/mod.rs",
    "content": "/// Huggingface source\n#[cfg(any(feature = \"sqlite\", feature = \"sqlite-bundled\"))]\npub mod huggingface;\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/composed.rs",
    "content": "use crate::Dataset;\n\n/// Compose multiple datasets together to create a bigger one.\n#[derive(new)]\npub struct ComposedDataset<D> {\n    datasets: Vec<D>,\n}\n\nimpl<D, I> Dataset<I> for ComposedDataset<D>\nwhere\n    D: Dataset<I>,\n    I: Clone,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        let mut current_index = 0;\n        for dataset in self.datasets.iter() {\n            if index < dataset.len() + current_index {\n                return dataset.get(index - current_index);\n            }\n            current_index += dataset.len();\n        }\n        None\n    }\n    fn len(&self) -> usize {\n        let mut total = 0;\n        for dataset in self.datasets.iter() {\n            total += dataset.len();\n        }\n        total\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::FakeDataset;\n\n    #[test]\n    fn test_composed_dataset() {\n        let dataset1 = FakeDataset::<String>::new(10);\n        let dataset2 = FakeDataset::<String>::new(5);\n\n        let items1 = dataset1.iter().collect::<Vec<_>>();\n        let items2 = dataset2.iter().collect::<Vec<_>>();\n\n        let composed = ComposedDataset::new(vec![dataset1, dataset2]);\n\n        assert_eq!(composed.len(), 15);\n\n        let expected_items: Vec<String> = items1.iter().chain(items2.iter()).cloned().collect();\n\n        let items = composed.iter().collect::<Vec<_>>();\n\n        assert_eq!(items, expected_items);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/mapper.rs",
    "content": "use crate::Dataset;\nuse std::marker::PhantomData;\n\n/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).\npub trait Mapper<I, O>: Send + Sync {\n    /// Maps an item of type I to an item of type O.\n    fn map(&self, item: &I) -> O;\n}\n\n/// Dataset mapping each element in an inner dataset to another element type lazily.\n#[derive(new)]\npub struct MapperDataset<D, M, I> {\n    dataset: D,\n    mapper: M,\n    input: PhantomData<I>,\n}\n\nimpl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>\nwhere\n    D: Dataset<I>,\n    M: Mapper<I, O> + Send + Sync,\n    I: Send + Sync,\n    O: Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<O> {\n        let item = self.dataset.get(index);\n        item.map(|item| self.mapper.map(&item))\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{InMemDataset, test_data};\n\n    #[test]\n    pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {\n        struct StringToFirstChar;\n\n        impl Mapper<String, String> for StringToFirstChar {\n            fn map(&self, item: &String) -> String {\n                let mut item = item.clone();\n                item.truncate(1);\n                item\n            }\n        }\n\n        let items_original = test_data::string_items();\n        let dataset = InMemDataset::new(items_original);\n        let dataset = MapperDataset::new(dataset, StringToFirstChar);\n\n        let items: Vec<String> = dataset.iter().collect();\n\n        assert_eq!(vec![\"1\", \"2\", \"3\", \"4\"], items);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/mod.rs",
    "content": "//! # Dataset Transformations\n//!\n//! This module provides a collection of [`crate::Dataset`] composition wrappers;\n//! providing composition, subset selection, sampling, random shuffling, and windowing.\n//!\n//! * [`ComposedDataset`] - composes a list of datasets.\n//! * [`PartialDataset`] - selects a contiguous index range subset of a dataset.\n//! * [`ShuffledDataset`] - a randomly shuffled / mutably shuffle-able dataset;\n//!   a thin wrapper around [`SelectionDataset`].\n//! * [`SamplerDataset`] - samples a dataset; support for with/without replacement,\n//!   and under/oversampling.\n//! * [`SelectionDataset`] - selects a subset of a dataset via indices; support for shuffling.\n//! * [`WindowsDataset`] - creates a sliding window over a dataset.\nmod composed;\nmod mapper;\nmod options;\nmod partial;\nmod sampler;\nmod selection;\nmod shuffle;\nmod window;\n\npub use composed::*;\npub use mapper::*;\npub use options::*;\npub use partial::*;\npub use sampler::*;\npub use selection::*;\npub use shuffle::*;\npub use window::*;\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/options.rs",
    "content": "use rand::SeedableRng;\nuse rand::prelude::StdRng;\nuse rand::rngs::SysRng;\n\n/// Defines a source for a `StdRng`.\n///\n/// # Examples\n///\n/// ```rust,no_run\n/// use rand::rngs::StdRng;\n/// use rand::SeedableRng;\n/// use burn_dataset::transform::RngSource;\n///\n/// // Default via `StdRng::from_os_rng()` (`RngSource::Default`)\n/// let system: RngSource = RngSource::default();\n///\n/// // From a fixed seed (`RngSource::Seed`)\n/// let seeded: RngSource = 42.into();\n///\n/// // From an existing rng (`RngSource::Rng`)\n/// let rng = StdRng::seed_from_u64(123);\n/// let with_rng: RngSource = rng.into();\n///\n/// // Forks the parent RNG to derive an independent, deterministic child RNG.\n/// // The original `rng` is modified, and the resulting `RngSource` contains\n/// // a new RNG starting from a unique state.\n/// let mut rng = StdRng::seed_from_u64(123);\n/// let forked: RngSource = (&mut rng).into();\n/// ```\n#[derive(Debug, Default, PartialEq, Eq)]\n#[allow(clippy::large_enum_variant)]\npub enum RngSource {\n    /// Build a new rng from the system.\n    #[default]\n    Default,\n\n    /// The rng is passed as a seed.\n    Seed(u64),\n\n    /// The rng is passed as an option.\n    Rng(StdRng),\n}\n\nimpl From<RngSource> for StdRng {\n    fn from(source: RngSource) -> Self {\n        match source {\n            RngSource::Default => StdRng::try_from_rng(&mut SysRng).unwrap(),\n            RngSource::Rng(rng) => rng,\n            RngSource::Seed(seed) => StdRng::seed_from_u64(seed),\n        }\n    }\n}\n\nimpl From<u64> for RngSource {\n    fn from(seed: u64) -> Self {\n        Self::Seed(seed)\n    }\n}\n\nimpl From<StdRng> for RngSource {\n    fn from(rng: StdRng) -> Self {\n        Self::Rng(rng)\n    }\n}\n\n/// Derive an independent RNG from a mutable parent RNG.\n///\n/// This advances the parent RNG and creates a new RNG seeded from its output.\n/// The derived RNG is *not* a clone of the parent's state, but an independent\n/// stream (equivalent to `SeedableRng::fork`).\nimpl From<&mut StdRng> for RngSource {\n    fn from(rng: &mut StdRng) -> Self {\n        Self::Rng(rng.fork())\n    }\n}\n\n/// Helper option to describe the size of a wrapper, relative to a wrapped object.\n#[derive(Debug, Clone, Copy, Default, PartialEq)]\npub enum SizeConfig {\n    /// Use the size of the source dataset.\n    #[default]\n    Default,\n\n    /// Use the size as a ratio of the source dataset size.\n    ///\n    /// Must be >= 0.\n    Ratio(f64),\n\n    /// Use a fixed size.\n    Fixed(usize),\n}\n\nimpl SizeConfig {\n    /// Construct a source which will have the same size as the source dataset.\n    pub fn source() -> Self {\n        Self::Default\n    }\n\n    /// Resolve the effective size.\n    ///\n    /// ## Arguments\n    ///\n    /// - `source_size`: the size of the source dataset.\n    ///\n    /// ## Returns\n    ///\n    /// The resolved size of the wrapper dataset.\n    pub fn resolve(self, source_size: usize) -> usize {\n        match self {\n            SizeConfig::Default => source_size,\n            SizeConfig::Ratio(ratio) => {\n                assert!(ratio >= 0.0, \"Ratio must be positive: {ratio}\");\n                ((source_size as f64) * ratio) as usize\n            }\n            SizeConfig::Fixed(size) => size,\n        }\n    }\n}\n\nimpl From<usize> for SizeConfig {\n    fn from(size: usize) -> Self {\n        Self::Fixed(size)\n    }\n}\n\nimpl From<f64> for SizeConfig {\n    fn from(ratio: f64) -> Self {\n        Self::Ratio(ratio)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use rand::SeedableRng;\n\n    #[test]\n    fn test_rng_source_default() {\n        let rng_source: RngSource = Default::default();\n        assert_eq!(&rng_source, &RngSource::Default);\n        assert_eq!(&rng_source, &RngSource::default());\n\n        // Exercise the from_os_rng() call; but we don't know its seed;\n        let _rng: StdRng = rng_source.into();\n    }\n\n    #[test]\n    fn test_rng_source_seed() {\n        let rng_source = RngSource::from(42);\n        assert_eq!(&rng_source, &RngSource::Seed(42));\n\n        let rng: StdRng = rng_source.into();\n        let expected = StdRng::seed_from_u64(42);\n\n        assert_eq!(rng, expected);\n    }\n\n    #[test]\n    fn test_rng_source_rng() {\n        // From StdRng (owned).\n        {\n            let original = StdRng::seed_from_u64(42);\n\n            let rng_source = RngSource::from(original);\n            let rng: StdRng = rng_source.into();\n            // No longer clone, but from <> into should not have advanced the state\n            let original = StdRng::seed_from_u64(42);\n            assert_eq!(rng, original);\n        }\n\n        // From &mut StdRng (forks parent)\n        {\n            let mut original = StdRng::seed_from_u64(42);\n            let mut rng = StdRng::seed_from_u64(42);\n            let rng_forked = rng.fork();\n\n            let rng_source = RngSource::from(&mut original);\n\n            // Ensure the original was advanced\n            assert_eq!(original, rng);\n\n            // Ensure the sourced RNG matches the fork\n            let rng: StdRng = rng_source.into();\n            assert_eq!(rng, rng_forked);\n        }\n    }\n\n    #[test]\n    fn test_size_config() {\n        assert_eq!(SizeConfig::default(), SizeConfig::Default);\n\n        assert_eq!(SizeConfig::from(42), SizeConfig::Fixed(42));\n\n        assert_eq!(SizeConfig::from(1.5), SizeConfig::Ratio(1.5));\n\n        assert_eq!(SizeConfig::source(), SizeConfig::Default);\n        assert_eq!(SizeConfig::source().resolve(50), 50);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/partial.rs",
    "content": "use crate::Dataset;\nuse std::{marker::PhantomData, sync::Arc};\n\n/// Only use a fraction of an existing dataset lazily.\n#[derive(new, Clone)]\npub struct PartialDataset<D, I> {\n    dataset: D,\n    start_index: usize,\n    end_index: usize,\n    input: PhantomData<I>,\n}\n\nimpl<D, I> PartialDataset<D, I>\nwhere\n    D: Dataset<I>,\n{\n    /// Splits a dataset into multiple partial datasets.\n    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {\n        let dataset = Arc::new(dataset); // cheap cloning.\n\n        let mut current = 0;\n        let mut datasets = Vec::with_capacity(num);\n\n        let batch_size = dataset.len() / num;\n\n        for i in 0..num {\n            let start = current;\n            let mut end = current + batch_size;\n\n            if i == (num - 1) {\n                end = dataset.len();\n            }\n\n            let dataset = PartialDataset::new(dataset.clone(), start, end);\n\n            current += batch_size;\n            datasets.push(dataset);\n        }\n\n        datasets\n    }\n\n    /// Splits a dataset by distributing complete chunks/batches across multiple partial datasets.\n    pub fn split_chunks(\n        dataset: D,\n        num: usize,\n        batch_size: usize,\n    ) -> Vec<PartialDataset<Arc<D>, I>> {\n        let dataset = Arc::new(dataset); // cheap cloning.\n        let total_items = dataset.len();\n\n        // Total number of complete batches\n        let total_batches = total_items.div_ceil(batch_size);\n        let batches_per_split = total_batches / num;\n        let extra_batches = total_batches % num;\n\n        let mut datasets = Vec::with_capacity(num);\n        let mut current_batch = 0;\n\n        for i in 0..num {\n            // Extra batches distributed across first splits\n            let split_batches = if i < extra_batches {\n                batches_per_split + 1\n            } else {\n                batches_per_split\n            };\n\n            let start_batch = current_batch;\n            let end_batch = start_batch + split_batches;\n\n            let start_index = start_batch * batch_size;\n            let end_index = core::cmp::min(end_batch * batch_size, total_items);\n\n            if start_index < total_items {\n                datasets.push(PartialDataset::new(dataset.clone(), start_index, end_index));\n            }\n\n            current_batch = end_batch;\n        }\n\n        datasets\n    }\n}\n\nimpl<D, I> Dataset<I> for PartialDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        let index = index + self.start_index;\n        if index < self.start_index || index >= self.end_index {\n            return None;\n        }\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        usize::min(self.end_index - self.start_index, self.dataset.len())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::FakeDataset;\n    use std::collections::HashSet;\n\n    #[test]\n    fn test_start_from_beginning() {\n        let dataset_original = FakeDataset::<String>::new(27);\n        let mut items_original_1 = HashSet::new();\n        let mut items_original_2 = HashSet::new();\n        let mut items_partial = HashSet::new();\n        dataset_original.iter().enumerate().for_each(|(i, item)| {\n            match i >= 10 {\n                true => items_original_2.insert(item),\n                false => items_original_1.insert(item),\n            };\n        });\n\n        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);\n\n        for item in dataset_partial.iter() {\n            items_partial.insert(item);\n        }\n\n        assert_eq!(dataset_partial.len(), 10);\n        assert_eq!(items_original_1, items_partial);\n        for item in items_original_2 {\n            assert!(!items_partial.contains(&item));\n        }\n    }\n\n    #[test]\n    fn test_start_inside() {\n        let dataset_original = FakeDataset::<String>::new(27);\n        let mut items_original_1 = HashSet::new();\n        let mut items_original_2 = HashSet::new();\n        let mut items_partial = HashSet::new();\n\n        dataset_original.iter().enumerate().for_each(|(i, item)| {\n            match !(10..20).contains(&i) {\n                true => items_original_2.insert(item),\n                false => items_original_1.insert(item),\n            };\n        });\n\n        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);\n        for item in dataset_partial.iter() {\n            items_partial.insert(item);\n        }\n\n        assert_eq!(dataset_partial.len(), 10);\n        assert_eq!(items_original_1, items_partial);\n        for item in items_original_2 {\n            assert!(!items_partial.contains(&item));\n        }\n    }\n\n    #[test]\n    fn test_split_contains_all_items_without_duplicates() {\n        let dataset_original = FakeDataset::<String>::new(27);\n        let mut items_original = Vec::new();\n        let mut items_partial = Vec::new();\n        for item in dataset_original.iter() {\n            items_original.push(item);\n        }\n\n        let dataset_partials = PartialDataset::split(dataset_original, 4);\n        let expected_len = [6, 6, 6, 9];\n\n        for (i, dataset) in dataset_partials.iter().enumerate() {\n            assert_eq!(dataset.len(), expected_len[i]);\n            for item in dataset.iter() {\n                items_partial.push(item);\n            }\n        }\n\n        assert_eq!(items_original, items_partial);\n    }\n\n    #[test]\n    fn test_split_chunks_contains_all_items_without_duplicates() {\n        let dataset_original = FakeDataset::<String>::new(27);\n        let mut items_original = Vec::new();\n        let mut items_partial = Vec::new();\n        for item in dataset_original.iter() {\n            items_original.push(item);\n        }\n\n        let dataset_partials = PartialDataset::split_chunks(dataset_original, 4, 5);\n        // [(2 * 5), (2 * 5), 5, 2] -> 5 complete chunks + 1 incomplete with 2 remaining items\n        // OTOH, `split(dataset, 4)` would yield [6, 6, 6, 9] -> 4 incomplete chunks + 4 incomplete with [1, 1, 1, 4]\n        let expected_len = [10, 10, 5, 2];\n\n        for (i, dataset) in dataset_partials.iter().enumerate() {\n            assert_eq!(dataset.len(), expected_len[i]);\n            for item in dataset.iter() {\n                items_partial.push(item);\n            }\n        }\n\n        assert_eq!(items_original, items_partial);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/sampler.rs",
    "content": "use crate::Dataset;\nuse crate::transform::{RngSource, SizeConfig};\nuse rand::prelude::SliceRandom;\nuse rand::{RngExt, distr::Uniform, rngs::StdRng, seq::IteratorRandom};\nuse std::{marker::PhantomData, ops::DerefMut, sync::Mutex};\n\n/// Options to configure a [SamplerDataset].\n#[derive(Debug, PartialEq)]\npub struct SamplerDatasetOptions {\n    /// The sampling mode.\n    pub replace_samples: bool,\n\n    /// The size source of the wrapper relative to the dataset.\n    pub size_config: SizeConfig,\n\n    /// The source of the random number generator.\n    pub rng_source: RngSource,\n}\n\nimpl Default for SamplerDatasetOptions {\n    fn default() -> Self {\n        Self {\n            replace_samples: true,\n            size_config: SizeConfig::Default,\n            rng_source: RngSource::Default,\n        }\n    }\n}\n\nimpl<T> From<Option<T>> for SamplerDatasetOptions\nwhere\n    T: Into<SamplerDatasetOptions>,\n{\n    fn from(option: Option<T>) -> Self {\n        match option {\n            Some(option) => option.into(),\n            None => Self::default(),\n        }\n    }\n}\n\nimpl From<usize> for SamplerDatasetOptions {\n    fn from(size: usize) -> Self {\n        Self::default().with_replacement().with_fixed_size(size)\n    }\n}\n\nimpl SamplerDatasetOptions {\n    /// Set the replacement mode.\n    pub fn with_replace_samples(self, replace_samples: bool) -> Self {\n        Self {\n            replace_samples,\n            ..self\n        }\n    }\n\n    /// Set the replacement mode to WithReplacement.\n    pub fn with_replacement(self) -> Self {\n        self.with_replace_samples(true)\n    }\n\n    /// Set the replacement mode to WithoutReplacement.\n    pub fn without_replacement(self) -> Self {\n        self.with_replace_samples(false)\n    }\n\n    /// Set the size source.\n    pub fn with_size<S>(self, source: S) -> Self\n    where\n        S: Into<SizeConfig>,\n    {\n        Self {\n            size_config: source.into(),\n            ..self\n        }\n    }\n\n    /// Set the size to the size of the source.\n    pub fn with_source_size(self) -> Self {\n        self.with_size(SizeConfig::Default)\n    }\n\n    /// Set the size to a fixed size.\n    pub fn with_fixed_size(self, size: usize) -> Self {\n        self.with_size(size)\n    }\n\n    /// Set the size to be a multiple of the ration and the source size.\n    pub fn with_size_ratio(self, size_ratio: f64) -> Self {\n        self.with_size(size_ratio)\n    }\n\n    /// Set the `RngSource`.\n    pub fn with_rng<R>(self, rng: R) -> Self\n    where\n        R: Into<RngSource>,\n    {\n        Self {\n            rng_source: rng.into(),\n            ..self\n        }\n    }\n\n    /// Use the system rng.\n    pub fn with_system_rng(self) -> Self {\n        self.with_rng(RngSource::Default)\n    }\n\n    /// Use a rng, built from a seed.\n    pub fn with_seed(self, seed: u64) -> Self {\n        self.with_rng(seed)\n    }\n}\n\n/// Sample items from a dataset.\n///\n/// This is a convenient way of modeling a dataset as a probability distribution of a fixed size.\n/// You have multiple options to instantiate the dataset sampler.\n///\n/// * With replacement (Default): This is the most efficient way of using the sampler because no state is\n///   required to keep indices that have been selected.\n///\n/// * Without replacement: This has a similar effect to using a\n///   [shuffled dataset](crate::transform::ShuffledDataset), but with more flexibility since you can\n///   set the dataset to an arbitrary size. Once every item has been used, a new cycle is\n///   created with a new random suffle.\npub struct SamplerDataset<D, I> {\n    dataset: D,\n    size: usize,\n    state: Mutex<SamplerState>,\n    input: PhantomData<I>,\n}\nenum SamplerState {\n    WithReplacement(StdRng),\n    WithoutReplacement(StdRng, Vec<usize>),\n}\n\nimpl<D, I> SamplerDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Send + Sync,\n{\n    /// Creates a new sampler dataset with replacement.\n    ///\n    /// When the sample size is less than or equal to the source dataset size,\n    /// data will be sampled without replacement from the source dataset in\n    /// a uniformly shuffled order.\n    ///\n    /// When the sample size is greater than the source dataset size,\n    /// the entire source dataset will be sampled once for every multiple\n    /// of the size ratios; with the remaining samples taken without replacement\n    /// uniformly from the source. All samples will be returned uniformly shuffled.\n    ///\n    /// ## Arguments\n    ///\n    /// * `dataset`: the dataset to wrap.\n    /// * `options`: the options to configure the sampler dataset.\n    ///\n    /// ## Examples\n    /// ```rust,ignore\n    /// use burn_dataset::transform::{\n    ///   SamplerDataset,\n    ///   SamplerDatasetOptions,\n    /// };\n    ///\n    /// // Examples below assuming `dataset.len()` = `10`.\n    ///\n    /// // sample size: 5\n    /// // WithReplacement\n    /// // rng: StdRng::from_os_rng()\n    /// SamplerDataset::new(dataset, 5);\n    ///\n    /// // sample size: 10 (source)\n    /// // WithReplacement\n    /// // rng: StdRng::from_os_rng()\n    /// SamplerDataset::new(dataset, SamplerDatasetOptions::default());\n    ///\n    /// // sample size: 15\n    /// // WithoutReplacement\n    /// // rng: StdRng::seed_from_u64(42)\n    /// SamplerDataset::new(\n    ///   dataset,\n    ///   SamplerDatasetOptions::default()\n    ///     .with_size(1.5)\n    ///     .without_replacement()\n    ///     .with_rng(42),\n    /// );\n    /// ```\n    pub fn new<O>(dataset: D, options: O) -> Self\n    where\n        O: Into<SamplerDatasetOptions>,\n    {\n        let options = options.into();\n        let size = options.size_config.resolve(dataset.len());\n        let rng = options.rng_source.into();\n        Self {\n            dataset,\n            size,\n            state: Mutex::new(match options.replace_samples {\n                true => SamplerState::WithReplacement(rng),\n                false => SamplerState::WithoutReplacement(rng, Vec::with_capacity(size)),\n            }),\n            input: PhantomData,\n        }\n    }\n\n    /// Creates a new sampler dataset with replacement.\n    ///\n    /// # Arguments\n    ///\n    /// - `dataset`: the dataset to wrap.\n    /// - `size`: the effective size of the sampled dataset.\n    pub fn with_replacement(dataset: D, size: usize) -> Self {\n        Self::new(\n            dataset,\n            SamplerDatasetOptions::default()\n                .with_replacement()\n                .with_fixed_size(size),\n        )\n    }\n\n    /// Creates a new sampler dataset without replacement.\n    ///\n    /// When the sample size is less than or equal to the source dataset size,\n    /// data will be sampled without replacement from the source dataset in\n    /// a uniformly shuffled order.\n    ///\n    /// When the sample size is greater than the source dataset size,\n    /// the entire source dataset will be sampled once for every multiple\n    /// of the size ratios; with the remaining samples taken without replacement\n    /// uniformly from the source. All samples will be returned uniformly shuffled.\n    ///\n    /// # Arguments\n    /// - `dataset`: the dataset to wrap.\n    /// - `size`: the effective size of the sampled dataset.\n    pub fn without_replacement(dataset: D, size: usize) -> Self {\n        Self::new(\n            dataset,\n            SamplerDatasetOptions::default()\n                .without_replacement()\n                .with_fixed_size(size),\n        )\n    }\n\n    /// Determines if the sampler is using the \"with replacement\" strategy.\n    ///\n    /// # Returns\n    /// - `true`: If the sampler is configured to sample with replacement.\n    /// - `false`: If the sampler is configured to sample without replacement.\n    pub fn is_with_replacement(&self) -> bool {\n        match self.state.lock().unwrap().deref_mut() {\n            SamplerState::WithReplacement(_) => true,\n            SamplerState::WithoutReplacement(_, _) => false,\n        }\n    }\n\n    fn index(&self) -> usize {\n        match self.state.lock().unwrap().deref_mut() {\n            SamplerState::WithReplacement(rng) => {\n                rng.sample(Uniform::new(0, self.dataset.len()).unwrap())\n            }\n            SamplerState::WithoutReplacement(rng, indices) => {\n                if indices.is_empty() {\n                    // Refill the state.\n                    let idx_range = 0..self.dataset.len();\n                    for _ in 0..(self.size / self.dataset.len()) {\n                        // No need to `.choose_multiple` here because we're using\n                        // the entire source range; and `.choose_multiple` will\n                        // not return a random sample anyway.\n                        indices.extend(idx_range.clone())\n                    }\n\n                    // From `choose_multiple` documentation:\n                    // > Although the elements are selected randomly, the order of elements in\n                    // > the buffer is neither stable nor fully random. If random ordering is\n                    // > desired, shuffle the result.\n                    indices.extend(idx_range.sample(rng, self.size - indices.len()));\n\n                    // The real shuffling is done here.\n                    indices.shuffle(rng);\n                }\n\n                indices.pop().expect(\"Indices are refilled when empty.\")\n            }\n        }\n    }\n}\n\nimpl<D, I> Dataset<I> for SamplerDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        if index >= self.size {\n            return None;\n        }\n\n        self.dataset.get(self.index())\n    }\n\n    fn len(&self) -> usize {\n        self.size\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    #![allow(clippy::bool_assert_comparison)]\n\n    use super::*;\n    use crate::FakeDataset;\n    use rand::SeedableRng;\n    use std::collections::HashMap;\n\n    #[test]\n    fn test_samplerdataset_options() {\n        let options = SamplerDatasetOptions::default();\n        assert_eq!(options.replace_samples, true);\n        assert_eq!(options.size_config, SizeConfig::Default);\n        assert_eq!(options.rng_source, RngSource::Default);\n\n        // ReplacementMode\n        let options = options.with_replace_samples(false);\n        assert_eq!(options.replace_samples, false);\n        let options = options.with_replacement();\n        assert_eq!(options.replace_samples, true);\n        let options = options.without_replacement();\n        assert_eq!(options.replace_samples, false);\n\n        // SourceSize\n        let options = options.with_size(SizeConfig::Default);\n        assert_eq!(options.size_config, SizeConfig::Default);\n        let options = options.with_source_size();\n        assert_eq!(options.size_config, SizeConfig::Default);\n        let options = options.with_fixed_size(10);\n        assert_eq!(options.size_config, SizeConfig::Fixed(10));\n        let options = options.with_size_ratio(1.5);\n        assert_eq!(options.size_config, SizeConfig::Ratio(1.5));\n\n        // RngSource\n        let options = options.with_system_rng();\n        assert_eq!(options.rng_source, RngSource::Default);\n        let options = options.with_seed(42);\n        assert_eq!(options.rng_source, RngSource::Seed(42));\n        let rng = StdRng::seed_from_u64(9);\n        let options = options.with_rng(rng);\n        assert!(matches!(options.rng_source, RngSource::Rng(_)));\n    }\n\n    #[test]\n    fn sampler_dataset_constructors_test() {\n        let ds = SamplerDataset::new(FakeDataset::<u32>::new(10), 15);\n        assert_eq!(ds.len(), 15);\n        assert_eq!(ds.dataset.len(), 10);\n        assert!(ds.is_with_replacement());\n\n        let ds = SamplerDataset::with_replacement(FakeDataset::<u32>::new(10), 15);\n        assert_eq!(ds.len(), 15);\n        assert_eq!(ds.dataset.len(), 10);\n        assert!(ds.is_with_replacement());\n\n        let ds = SamplerDataset::without_replacement(FakeDataset::<u32>::new(10), 15);\n        assert_eq!(ds.len(), 15);\n        assert_eq!(ds.dataset.len(), 10);\n        assert!(!ds.is_with_replacement());\n    }\n\n    #[test]\n    fn sampler_dataset_with_replacement_iter() {\n        let factor = 3;\n        let len_original = 10;\n        let dataset_sampler = SamplerDataset::with_replacement(\n            FakeDataset::<String>::new(len_original),\n            len_original * factor,\n        );\n        let mut total = 0;\n\n        for _item in dataset_sampler.iter() {\n            total += 1;\n        }\n\n        assert_eq!(total, factor * len_original);\n    }\n\n    #[test]\n    fn sampler_dataset_without_replacement_bucket_test() {\n        let factor = 3;\n        let len_original = 10;\n\n        let dataset_sampler = SamplerDataset::new(\n            FakeDataset::<String>::new(len_original),\n            SamplerDatasetOptions::default()\n                .without_replacement()\n                .with_size_ratio(factor as f64),\n        );\n\n        let mut buckets = HashMap::new();\n\n        for item in dataset_sampler.iter() {\n            let count = match buckets.get(&item) {\n                Some(count) => count + 1,\n                None => 1,\n            };\n\n            buckets.insert(item, count);\n        }\n\n        let mut total = 0;\n        for count in buckets.into_values() {\n            assert_eq!(count, factor);\n            total += count;\n        }\n        assert_eq!(total, factor * len_original);\n    }\n\n    #[test]\n    fn sampler_dataset_without_replacement_uniform_order_test() {\n        // This is a reversion test on the indices.shuffle(rng) call in SamplerDataset::index().\n        let size = 1000;\n        let dataset_sampler =\n            SamplerDataset::without_replacement(FakeDataset::<i32>::new(size), size);\n\n        let indices: Vec<_> = (0..size).map(|_| dataset_sampler.index()).collect();\n        let mean_delta = indices\n            .windows(2)\n            .map(|pair| pair[1].abs_diff(pair[0]))\n            .sum::<usize>() as f64\n            / (size - 1) as f64;\n\n        let expected = (size + 2) as f64 / 3.0;\n\n        assert!(\n            (mean_delta - expected).abs() <= 0.25 * expected,\n            \"Sampled indices are not uniformly distributed: mean_delta: {mean_delta}, expected: {expected}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/selection.rs",
    "content": "use crate::Dataset;\nuse crate::transform::RngSource;\nuse rand::prelude::SliceRandom;\nuse rand::rngs::StdRng;\nuse std::marker::PhantomData;\nuse std::sync::Arc;\n\n/// Generates a vector of indices from 0 to size - 1.\n///\n/// # Arguments\n///\n/// * `size` - The size of the dataset.\n///\n/// # Returns\n///\n/// A vector containing indices from 0 to size - 1.\n#[inline(always)]\npub fn iota(size: usize) -> Vec<usize> {\n    (0..size).collect()\n}\n\n/// Generates a shuffled vector of indices up to a size.\n///\n/// # Arguments\n///\n/// * `size` - The size of the dataset to shuffle.\n///\n/// # Returns\n///\n/// A vector of shuffled indices.\n#[inline(always)]\npub fn shuffled_indices(size: usize, rng: &mut StdRng) -> Vec<usize> {\n    let mut indices = iota(size);\n    indices.shuffle(rng);\n    indices\n}\n\n/// A dataset that selects a subset of indices from an existing dataset.\n///\n/// Indices may appear multiple times, but they must be within the bounds of the original dataset.\n#[derive(Clone)]\npub struct SelectionDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    /// The wrapped dataset from which to select indices.\n    pub wrapped: Arc<D>,\n\n    /// The indices to select from the wrapped dataset.\n    pub indices: Vec<usize>,\n\n    input: PhantomData<I>,\n}\n\nimpl<D, I> SelectionDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    /// Creates a new selection dataset with the given dataset and indices.\n    ///\n    /// Checks that all indices are within the bounds of the dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    /// * `indices` - A slice of indices to select from the dataset.\n    ///   These indices must be within the bounds of the dataset.\n    ///\n    /// # Panics\n    ///\n    /// Panics if any index is out of bounds for the dataset.\n    pub fn from_indices_checked<S>(dataset: S, indices: Vec<usize>) -> Self\n    where\n        S: Into<Arc<D>>,\n    {\n        let dataset = dataset.into();\n\n        let size = dataset.len();\n        if let Some(idx) = indices.iter().find(|&i| *i >= size) {\n            panic!(\"Index out of bounds for wrapped dataset size: {idx} >= {size}\");\n        }\n\n        Self::from_indices_unchecked(dataset, indices)\n    }\n\n    /// Creates a new selection dataset with the given dataset and indices without checking bounds.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    /// * `indices` - A vector of indices to select from the dataset.\n    ///\n    /// # Safety\n    ///\n    /// This function does not check if the indices are within the bounds of the dataset.\n    pub fn from_indices_unchecked<S>(dataset: S, indices: Vec<usize>) -> Self\n    where\n        S: Into<Arc<D>>,\n    {\n        Self {\n            wrapped: dataset.into(),\n            indices,\n            input: PhantomData,\n        }\n    }\n\n    /// Creates a new selection dataset that selects all indices from the dataset.\n    ///\n    /// This allocates a 1-to-1 mapping of indices to the dataset size,\n    /// essentially functioning as a no-op selection. This is only useful\n    /// when the dataset will later be shuffled or transformed in place.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    ///\n    /// # Returns\n    ///\n    /// A new `SelectionDataset` that selects all indices from the dataset.\n    pub fn new_select_all<S>(dataset: S) -> Self\n    where\n        S: Into<Arc<D>>,\n    {\n        let dataset = dataset.into();\n        let size = dataset.len();\n        Self::from_indices_unchecked(dataset, iota(size))\n    }\n\n    /// Creates a new selection dataset with shuffled indices.\n    ///\n    /// Selects every index of the dataset and shuffles them\n    /// with randomness from the provided random number generator.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    /// * `rng` - A mutable reference to a random number generator.\n    ///\n    /// # Returns\n    ///\n    /// A new `SelectionDataset` with shuffled indices.\n    pub fn new_shuffled<S, R>(dataset: S, rng_source: R) -> Self\n    where\n        S: Into<Arc<D>>,\n        R: Into<RngSource>,\n    {\n        let mut this = Self::new_select_all(dataset);\n        this.shuffle(rng_source);\n        this\n    }\n\n    /// Shuffles the indices of the dataset using a mutable random number generator.\n    ///\n    /// This method modifies the dataset in place, shuffling the indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `rng` - A mutable reference to a random number generator.\n    pub fn shuffle<R>(&mut self, rng_source: R)\n    where\n        R: Into<RngSource>,\n    {\n        let mut rng: StdRng = rng_source.into().into();\n        self.indices.shuffle(&mut rng)\n    }\n\n    /// Creates a new dataset that is a slice of the current selection dataset.\n    ///\n    /// Slices the *selection indices* from ``[start..end]``.\n    ///\n    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.\n    ///\n    ///\n    /// # Arguments\n    ///\n    /// * `start` - The start of the range.\n    /// * `end` - The end of the range (exclusive).\n    // TODO: SliceArg in burn-tensor should be lifted to burn-std; this should use SliceArg.\n    pub fn slice(&self, start: usize, end: usize) -> Self {\n        Self::from_indices_unchecked(self.wrapped.clone(), self.indices[start..end].to_vec())\n    }\n\n    /// Split into `num` datasets by slicing the selection indices evenly.\n    ///\n    /// Split is done via `slice`, so the datasets share the same wrapped dataset.\n    ///\n    /// Independent of future shuffles on the parent, but shares the same wrapped dataset.\n    ///\n    /// # Arguments\n    ///\n    /// * `num` - The number of datasets to split into.\n    ///\n    /// # Returns\n    ///\n    /// A vector of `SelectionDataset` instances, each containing a subset of the indices.\n    pub fn split(&self, num: usize) -> Vec<Self> {\n        let n = self.indices.len();\n\n        let mut current = 0;\n        let mut datasets = Vec::with_capacity(num);\n\n        let batch_size = n / num;\n        for i in 0..num {\n            let start = current;\n            let mut end = current + batch_size;\n\n            if i == (num - 1) {\n                end = n;\n            }\n\n            let dataset = self.slice(start, end);\n\n            current += batch_size;\n            datasets.push(dataset);\n        }\n\n        datasets\n    }\n}\n\nimpl<D, I> Dataset<I> for SelectionDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        let index = self.indices.get(index)?;\n        self.wrapped.get(*index)\n    }\n\n    fn len(&self) -> usize {\n        self.indices.len()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::FakeDataset;\n    use rand::SeedableRng;\n\n    #[test]\n    fn test_iota() {\n        let size = 10;\n        let indices = iota(size);\n        assert_eq!(indices.len(), size);\n        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);\n    }\n\n    #[test]\n    fn test_shuffled_indices_same_seed_is_deterministic() {\n        let size = 10;\n\n        let mut rng1 = StdRng::seed_from_u64(10);\n        // `StdRng` is no longer `Clone`, so its internal state cannot be duplicated.\n        // To test determinism, we must explicitly create a second RNG from the same seed.\n        let mut rng2 = StdRng::seed_from_u64(10);\n\n        let mut expected = iota(size);\n        expected.shuffle(&mut rng1);\n\n        let indices = shuffled_indices(size, &mut rng2);\n\n        assert_eq!(indices, expected);\n    }\n\n    #[test]\n    fn test_shuffled_indices_forked_rngs_differ() {\n        let size = 10;\n\n        let mut rng1 = StdRng::seed_from_u64(10);\n        let mut rng2 = rng1.fork();\n\n        let mut a = iota(size);\n        let mut b = iota(size);\n\n        a.shuffle(&mut rng1);\n        b.shuffle(&mut rng2);\n\n        assert_ne!(a, b);\n    }\n\n    #[should_panic(expected = \"Index out of bounds for wrapped dataset size: 300 >= 27\")]\n    #[test]\n    fn test_from_indices_checked_panics() {\n        let source_dataset = FakeDataset::<String>::new(27);\n        let indices: Vec<usize> = vec![15, 1, 12, 300];\n        SelectionDataset::from_indices_checked(source_dataset, indices);\n    }\n\n    #[test]\n    fn test_checked_selection_dataset() {\n        let source_dataset = FakeDataset::<String>::new(27);\n\n        let indices: Vec<usize> = vec![15, 1, 12, 12];\n        let expected: Vec<String> = indices\n            .iter()\n            .map(|i| source_dataset.get(*i).unwrap())\n            .collect();\n\n        let selection = SelectionDataset::from_indices_checked(source_dataset, indices.clone());\n\n        assert_eq!(&selection.indices, &indices);\n\n        let items = selection.iter().collect::<Vec<_>>();\n\n        assert_eq!(items, expected);\n    }\n\n    #[test]\n    fn test_shuffled_dataset() {\n        let dataset = FakeDataset::<String>::new(27);\n        let source_items = dataset.iter().collect::<Vec<_>>();\n\n        let selection = SelectionDataset::new_shuffled(dataset, 42);\n\n        let indices = shuffled_indices(source_items.len(), &mut StdRng::seed_from_u64(42));\n\n        assert_eq!(&selection.indices, &indices);\n        assert_eq!(selection.len(), source_items.len());\n\n        let expected_items: Vec<_> = indices\n            .iter()\n            .map(|&i| source_items[i].to_string())\n            .collect();\n        assert_eq!(&selection.iter().collect::<Vec<_>>(), &expected_items);\n    }\n\n    #[test]\n    fn test_slice() {\n        let dataset = FakeDataset::<String>::new(27);\n        let source_items = dataset.iter().collect::<Vec<_>>();\n\n        let selection = SelectionDataset::new_select_all(dataset);\n\n        let start = 5;\n        let end = 15;\n        let sliced_selection = selection.slice(start, end);\n\n        assert_eq!(sliced_selection.len(), end - start);\n\n        #[allow(clippy::needless_range_loop)]\n        for i in start..end {\n            assert_eq!(\n                sliced_selection.get(i - start),\n                Some(source_items[i].to_string())\n            );\n        }\n    }\n\n    #[test]\n    fn test_split() {\n        let dataset = FakeDataset::<String>::new(28);\n        let source_items = dataset.iter().collect::<Vec<_>>();\n\n        let selection = SelectionDataset::new_select_all(dataset);\n\n        let split_contents: Vec<Vec<_>> = selection\n            .split(3)\n            .iter()\n            .map(|d| d.iter().collect::<Vec<_>>())\n            .collect();\n        assert_eq!(\n            split_contents,\n            vec![\n                source_items[0..9].to_vec(),\n                source_items[9..18].to_vec(),\n                source_items[18..28].to_vec(),\n            ]\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/shuffle.rs",
    "content": "use crate::Dataset;\nuse crate::transform::{RngSource, SelectionDataset};\n\n/// A Shuffled a dataset.\n///\n/// This is a thin wrapper around a [SelectionDataset] which selects and shuffles\n/// the full indices of the original dataset.\n///\n/// Consider using [SelectionDataset] if you are only interested in\n/// shuffling mechanisms.\n///\n/// Consider using [sampler dataset](crate::transform::SamplerDataset) if you\n/// want a probability distribution which is computed lazily.\npub struct ShuffledDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    wrapped: SelectionDataset<D, I>,\n}\n\nimpl<D, I> ShuffledDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    /// Creates a new selection dataset with shuffled indices.\n    ///\n    /// This is a thin wrapper around `SelectionDataset::new_shuffled`.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    /// * `rng_source` - The source of the random number generator.\n    ///\n    /// # Returns\n    ///\n    /// A new `ShuffledDataset`.\n    pub fn new<R>(dataset: D, rng_source: R) -> Self\n    where\n        R: Into<RngSource>,\n    {\n        Self {\n            wrapped: SelectionDataset::new_shuffled(dataset, rng_source),\n        }\n    }\n\n    /// Creates a new selection dataset with shuffled indices using a fixed seed.\n    ///\n    /// This is a thin wrapper around `SelectionDataset::new_shuffled_with_seed`.\n    ///\n    /// # Arguments\n    ///\n    /// * `dataset` - The original dataset to select from.\n    /// * `seed` - A fixed seed for the random number generator.\n    ///\n    /// # Returns\n    ///\n    /// A new `ShuffledDataset`.\n    #[deprecated(since = \"0.19.0\", note = \"Use `new(dataset, seed)` instead`\")]\n    pub fn with_seed(dataset: D, seed: u64) -> Self {\n        Self::new(dataset, seed)\n    }\n}\n\nimpl<D, I> Dataset<I> for ShuffledDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Clone + Send + Sync,\n{\n    fn get(&self, index: usize) -> Option<I> {\n        self.wrapped.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.wrapped.len()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::FakeDataset;\n    use crate::transform::selection::shuffled_indices;\n    use rand::SeedableRng;\n    use rand::prelude::StdRng;\n\n    #[test]\n    fn test_shuffled_dataset() {\n        let dataset = FakeDataset::<String>::new(27);\n        let source_items = dataset.iter().collect::<Vec<_>>();\n\n        let seed = 42;\n\n        #[allow(deprecated)]\n        let shuffled = ShuffledDataset::with_seed(dataset, seed);\n\n        let mut rng = StdRng::seed_from_u64(seed);\n        let indices = shuffled_indices(source_items.len(), &mut rng);\n\n        assert_eq!(shuffled.len(), source_items.len());\n\n        let expected_items: Vec<_> = indices\n            .iter()\n            .map(|&i| source_items[i].to_string())\n            .collect();\n        assert_eq!(&shuffled.iter().collect::<Vec<_>>(), &expected_items);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/transform/window.rs",
    "content": "use std::{cmp::max, marker::PhantomData, num::NonZeroUsize};\n\nuse crate::Dataset;\n\n/// Functionality to create a window.\npub trait Window<I> {\n    /// Creates a window of a collection.\n    ///\n    /// # Returns\n    ///\n    /// A `Vec<I>` representing the window.\n    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>>;\n}\n\nimpl<I, T: Dataset<I> + ?Sized> Window<I> for T {\n    fn window(&self, current: usize, size: NonZeroUsize) -> Option<Vec<I>> {\n        (current..current + size.get())\n            .map(|x| self.get(x))\n            .collect()\n    }\n}\n\n/// Functionality to create a `WindowsIterator`.\npub trait Windows<I> {\n    /// Creates and returns an iterator over all the windows of length `size`.\n    fn windows(&self, size: usize) -> WindowsIterator<'_, I>;\n}\n\nimpl<I, T: Dataset<I>> Windows<I> for T {\n    /// Is empty if the `Dataset` is shorter than `size`.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `size` is 0.    \n    ///\n    /// # Examples\n    ///\n    /// ```\n    /// use crate::burn_dataset::{\n    ///    transform::{Windows, WindowsDataset},\n    ///    Dataset, InMemDataset,\n    /// };\n    ///\n    /// let items = [1, 2, 3, 4].to_vec();\n    /// let dataset = InMemDataset::new(items.clone());\n    ///\n    /// for window in dataset.windows(2) {\n    ///  // do sth with window\n    /// }\n    /// ```\n    fn windows(&self, size: usize) -> WindowsIterator<'_, I> {\n        let size = NonZeroUsize::new(size).expect(\"window size must be non-zero\");\n        WindowsIterator::new(self, size)\n    }\n}\n\n/// Overlapping windows iterator.\npub struct WindowsIterator<'a, I> {\n    /// The size of the windows.\n    pub size: NonZeroUsize,\n    current: usize,\n    dataset: &'a dyn Dataset<I>,\n}\n\nimpl<'a, I> WindowsIterator<'a, I> {\n    /// Creates a new `WindowsIterator` instance. The windows overlap.\n    /// Is empty if the input `Dataset` is shorter than `size`.\n    ///\n    /// # Parameters\n    ///\n    /// - `dataset`: The dataset over which windows will be created.\n    /// - `size`: The size of the windows.\n    pub fn new(dataset: &'a dyn Dataset<I>, size: NonZeroUsize) -> Self {\n        WindowsIterator {\n            current: 0,\n            dataset,\n            size,\n        }\n    }\n}\n\nimpl<I> Iterator for WindowsIterator<'_, I> {\n    type Item = Vec<I>;\n\n    fn next(&mut self) -> Option<Vec<I>> {\n        self.current += 1;\n        self.dataset.window(self.current - 1, self.size)\n    }\n}\n\nimpl<I> Clone for WindowsIterator<'_, I> {\n    fn clone(&self) -> Self {\n        WindowsIterator {\n            size: self.size,\n            dataset: self.dataset,\n            current: self.current,\n        }\n    }\n}\n\n/// Dataset designed to work with overlapping windows of data.\npub struct WindowsDataset<D, I> {\n    /// The size of the windows.\n    pub size: NonZeroUsize,\n    dataset: D,\n    input: PhantomData<I>,\n}\n\nimpl<D, I> WindowsDataset<D, I>\nwhere\n    D: Dataset<I>,\n{\n    /// Creates a new `WindowsDataset` instance. The windows overlap.\n    /// Is empty if the input `Dataset` is shorter than `size`.\n    ///\n    /// # Parameters\n    ///\n    /// - `dataset`: The dataset over which windows will be created.\n    /// - `size`: The size of the windows.\n    pub fn new(dataset: D, size: usize) -> Self\n    where\n        D:,\n    {\n        let size = NonZeroUsize::new(size).expect(\"window size must be non-zero\");\n        WindowsDataset::<D, I> {\n            size,\n            dataset,\n            input: PhantomData,\n        }\n    }\n}\n\nimpl<D, I> Dataset<Vec<I>> for WindowsDataset<D, I>\nwhere\n    D: Dataset<I>,\n    I: Send + Sync,\n{\n    /// Retrieves a window of items from the dataset.\n    ///\n    /// # Parameters\n    ///\n    /// - `index`: The index of the window.\n    ///\n    /// # Returns\n    ///\n    /// A vector representing the window.\n    fn get(&self, index: usize) -> Option<Vec<I>> {\n        self.dataset.window(index, self.size)\n    }\n\n    /// Retrieves the number of windows in the dataset.\n    ///\n    /// # Returns\n    ///\n    /// A size representing the number of windows.\n    fn len(&self) -> usize {\n        let len = self.dataset.len() as isize - self.size.get() as isize + 1;\n        max(len, 0) as usize\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use rstest::rstest;\n\n    use crate::{\n        Dataset, InMemDataset,\n        transform::{Windows, WindowsDataset},\n    };\n\n    #[rstest]\n    pub fn windows_should_be_equal_to_vec_windows() {\n        let items = [1, 2, 3, 4, 5].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n        let expected = items\n            .windows(3)\n            .map(|x| x.to_vec())\n            .collect::<Vec<Vec<i32>>>();\n\n        let result = dataset.windows(3).collect::<Vec<Vec<i32>>>();\n\n        assert_eq!(result, expected);\n    }\n\n    #[rstest]\n    pub fn windows_dataset_should_be_equal_to_vec_windows() {\n        let items = [1, 2, 3, 4, 5].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n        let expected = items\n            .windows(3)\n            .map(|x| x.to_vec())\n            .collect::<Vec<Vec<i32>>>();\n\n        let result = WindowsDataset::new(dataset, 3)\n            .iter()\n            .collect::<Vec<Vec<i32>>>();\n\n        assert_eq!(result, expected);\n    }\n\n    #[rstest]\n    pub fn cloned_iterator_should_be_equal() {\n        let items = [1, 2, 3, 4, 5].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n        let original = dataset.windows(4);\n\n        let cloned = original.clone();\n\n        assert!(std::ptr::eq(cloned.dataset, original.dataset));\n        assert_eq!(cloned.size, original.size);\n        assert_eq!(cloned.current, original.current);\n    }\n\n    #[rstest]\n    pub fn cloned_iterator_should_be_unaffected() {\n        let items = [1, 2, 3, 4, 5].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n        let mut original = dataset.windows(4);\n\n        let cloned = original.clone();\n        original.current = 2;\n\n        assert_ne!(cloned.current, original.current);\n    }\n\n    #[rstest]\n    #[should_panic(expected = \"window size must be non-zero\")]\n    pub fn windows_should_panic() {\n        let items = [1, 2].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n\n        dataset.windows(0);\n    }\n\n    #[rstest]\n    #[should_panic(expected = \"window size must be non-zero\")]\n    pub fn new_window_dataset_should_panic() {\n        let items = [1, 2].to_vec();\n        let dataset = InMemDataset::new(items.clone());\n\n        WindowsDataset::new(dataset, 0);\n    }\n\n    #[rstest]\n    pub fn window_dataset_len_should_be_equal() {\n        let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());\n\n        let result = WindowsDataset::new(dataset, 2).len();\n\n        assert_eq!(result, 3);\n    }\n\n    #[rstest]\n    pub fn window_iterator_should_be_empty() {\n        let dataset = InMemDataset::new([1, 2].to_vec());\n        let mut peekable = dataset.windows(4).peekable();\n\n        let result = peekable.peek();\n\n        assert_eq!(result, None);\n    }\n\n    #[rstest]\n    pub fn window_dataset_len_should_be_zero() {\n        let dataset = InMemDataset::new([1, 2].to_vec());\n\n        let result = WindowsDataset::new(dataset, 4).len();\n\n        assert_eq!(result, 0);\n    }\n\n    #[rstest]\n    pub fn window_dataset_get_should_be_equal() {\n        let dataset = InMemDataset::new([1, 2, 3, 4].to_vec());\n        let expected = Some([1, 2, 3].to_vec());\n\n        let result = WindowsDataset::new(dataset, 3).get(0);\n\n        assert_eq!(result, expected);\n    }\n\n    #[rstest]\n    pub fn window_dataset_get_should_be_none() {\n        let dataset = InMemDataset::new([1, 2].to_vec());\n\n        let result = WindowsDataset::new(dataset, 4).get(0);\n\n        assert_eq!(result, None);\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/vision/cifar.rs",
    "content": "//! CIFAR Dataset Module\n//!\n//! This module provides functionality for loading the CIFAR-10 and CIFAR-100 image classification datasets.\n//! CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks in computer vision,\n//! consisting of 32×32 pixel color images split into training (50,000 images) and test (10,000 images) sets.\n//!\n//! ## Dataset Variants\n//! - **CIFAR-10**: Contains 10 distinct classes (e.g., airplane, automobile, bird, cat)\n//!     - CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).\n//!     - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\n//! - **CIFAR-100**: Contains 100 fine-grained classes (e.g., beaver, dolphin, oak tree)\n//!     - CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).\n//!     - Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\n//!\n//! ## Usage Example\n//! ```rust\n//! use burn_dataset::vision::CifarDataset;\n//! use burn_dataset::vision::CifarType;\n//!\n//! // Create a CIFAR-10 dataset accessor\n//! let dataset = CifarDataset::new(CifarType::Cifar10);\n//!\n//! // Access training and test sets\n//! let train_dataset = dataset.train();\n//! let test_dataset = dataset.test();\n//! ```\n//! ```rust\n//! use burn_dataset::vision::CifarDataset;\n//! use burn_dataset::vision::CifarType;\n//!\n//! // Create a CIFAR-100 dataset accessor\n//! let dataset = CifarDataset::new(CifarType::Cifar100);\n//!\n//! // Access training and test sets\n//! let train_dataset = dataset.train();\n//! let test_dataset = dataset.test();\n//! ```\n\nuse std::{path::PathBuf, sync::Mutex};\n\nuse flate2::read::GzDecoder;\nuse tar::Archive;\n\nuse crate::network::downloader;\nuse crate::vision::ImageFolderDataset;\n\n/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).\n/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\nconst CIFAR10_URL: &str = \"https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz\";\n\n/// CIFAR-100 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L75).\n/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\nconst CIFAR100_URL: &str = \"https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz\";\n\n/// Enum representing the types of CIFAR datasets available.\n///\n/// CIFAR (Canadian Institute For Advanced Research) datasets are widely used benchmarks for image classification.\n/// This enum provides support for the two main CIFAR datasets.\n#[derive(Debug, Clone, Copy)]\n#[allow(dead_code)]\npub enum CifarType {\n    /// CIFAR-10 dataset containing 10 classes with 60,000 images in total.\n    Cifar10,\n    /// CIFAR-100 dataset containing 100 classes with 60,000 images in total.\n    Cifar100,\n}\n\n/// CIFAR dataset accessor.\n///\n/// This struct provides convenient access to the CIFAR-10 and CIFAR-100 image classification datasets.\n/// It automatically downloads (if not already downloaded), extracts, and loads the datasets.\n///\n/// All images in CIFAR datasets are 32×32 pixel color images, with 50,000 images in the training set\n/// and 10,000 images in the test set.\n///\n/// ## Differences between datasets\n/// - **CIFAR-10**: Contains 10 mutually exclusive classes such as airplane, automobile, bird, cat, etc.\n/// - **CIFAR-100**: Contains 100 fine-grained classes such as beaver, dolphin, etc.\npub struct CifarDataset {\n    cifar_dir: PathBuf,\n}\n\nimpl CifarDataset {\n    /// Creates a new CIFAR dataset accessor.\n    ///\n    /// # Arguments\n    /// * `cifar_type` - Specifies whether to use CIFAR-10 or CIFAR-100 dataset\n    pub fn new(cifar_type: CifarType) -> Self {\n        Self {\n            cifar_dir: download(&cifar_type),\n        }\n    }\n\n    /// Gets the training dataset.\n    ///\n    /// # Returns\n    /// An `ImageFolderDataset` instance containing 50,000 training images\n    pub fn train(&self) -> ImageFolderDataset {\n        ImageFolderDataset::new_classification(self.cifar_dir.join(\"train\")).unwrap()\n    }\n\n    /// Gets the test dataset.\n    ///\n    /// # Returns\n    /// An `ImageFolderDataset` instance containing 10,000 test images\n    pub fn test(&self) -> ImageFolderDataset {\n        ImageFolderDataset::new_classification(self.cifar_dir.join(\"test\")).unwrap()\n    }\n}\n\n/// CIFAR dataset download lock.\n///\n/// This lock ensures that only one thread downloads the CIFAR dataset at a time.\nstatic DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());\n\nfn download(cifar_type: &CifarType) -> PathBuf {\n    // Acquire the lock. This will block if another thread already holds the lock.\n    let _lock = DOWNLOAD_LOCK.lock().unwrap();\n\n    // Dataset files are stored in the burn-dataset cache directory\n    let cache_dir = dirs::cache_dir()\n        .expect(\"Could not get cache directory\")\n        .join(\"burn-dataset\");\n\n    // Cifar store directory\n    let cifar_dir = match cifar_type {\n        CifarType::Cifar10 => cache_dir.join(\"cifar10\"),\n        CifarType::Cifar100 => cache_dir.join(\"cifar100\"),\n    };\n\n    // Cifar dataset url\n    let url = match cifar_type {\n        CifarType::Cifar10 => CIFAR10_URL,\n        CifarType::Cifar100 => CIFAR100_URL,\n    };\n\n    // Cifar dataset archive filename\n    let filename = match cifar_type {\n        CifarType::Cifar10 => \"cifar10.tgz\",\n        CifarType::Cifar100 => \"cifar100.tgz\",\n    };\n\n    // Check for already downloaded content\n    if !cifar_dir.exists() {\n        // Download gzip file\n        let bytes = downloader::download_file_as_bytes(url, filename);\n\n        // Decode gzip file content and unpack archive\n        let gz_buffer = GzDecoder::new(&bytes[..]);\n        let mut archive = Archive::new(gz_buffer);\n        archive.unpack(cache_dir).unwrap();\n    }\n\n    cifar_dir\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{Dataset, vision::Annotation};\n\n    /// CIFAR dataset length\n    const TRAINDATASET_LEN: usize = 50000;\n    const TESTDATASET_LEN: usize = 10000;\n\n    /// CIFAR-10 label range\n    const CIFAR10_LABEL_MIN: usize = 0;\n    const CIFAR10_LABEL_MAX: usize = 9;\n\n    /// CIFAR-100 label range\n    const CIFAR100_LABEL_MIN: usize = 0;\n    const CIFAR100_LABEL_MAX: usize = 99;\n\n    #[test]\n    fn test_cifar10_download() {\n        let cifar_dir = download(&CifarType::Cifar10);\n        assert!(cifar_dir.exists());\n    }\n\n    #[test]\n    fn test_cifar100_download() {\n        let cifar_dir = download(&CifarType::Cifar100);\n        assert!(cifar_dir.exists());\n    }\n\n    #[test]\n    fn test_cifar10_len() {\n        let dataset = CifarDataset::new(CifarType::Cifar10);\n        let train_dataset = dataset.train();\n        let test_dataset = dataset.test();\n        assert_eq!(train_dataset.len(), TRAINDATASET_LEN);\n        assert_eq!(test_dataset.len(), TESTDATASET_LEN);\n    }\n\n    #[test]\n    fn test_cifar100_len() {\n        let dataset = CifarDataset::new(CifarType::Cifar100);\n        let train_dataset = dataset.train();\n        let test_dataset = dataset.test();\n        assert_eq!(train_dataset.len(), TRAINDATASET_LEN);\n        assert_eq!(test_dataset.len(), TESTDATASET_LEN);\n    }\n\n    #[test]\n    fn test_cifar10_label_range() {\n        let dataset = CifarDataset::new(CifarType::Cifar10);\n        let test_dataset = dataset.test();\n        let (min, max) = get_label_range(&test_dataset);\n        assert_eq!(min, CIFAR10_LABEL_MIN);\n        assert_eq!(max, CIFAR10_LABEL_MAX);\n    }\n\n    #[test]\n    fn test_cifar100_label_range() {\n        let dataset = CifarDataset::new(CifarType::Cifar100);\n        let test_dataset = dataset.test();\n        let (min, max) = get_label_range(&test_dataset);\n        assert_eq!(min, CIFAR100_LABEL_MIN);\n        assert_eq!(max, CIFAR100_LABEL_MAX);\n    }\n\n    fn get_label_range(dataset: &ImageFolderDataset) -> (usize, usize) {\n        let labels: Vec<_> = dataset.iter().map(|item| item.annotation).collect();\n        let mut min = 128;\n        let mut max = 0;\n        for label in labels {\n            let index = match label {\n                Annotation::Label(index) => index,\n                _ => 0,\n            };\n            if index < min {\n                min = index;\n            }\n            if index > max {\n                max = index;\n            }\n        }\n\n        (min, max)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/vision/image_folder.rs",
    "content": "use crate::transform::{Mapper, MapperDataset};\nuse crate::{Dataset, InMemDataset};\n\nuse globwalk::{self, DirEntry};\nuse image::{self, ColorType};\nuse serde::{Deserialize, Serialize};\nuse serde_json::Value;\nuse std::collections::{HashMap, HashSet};\nuse std::fs;\nuse std::path::{Path, PathBuf};\nuse thiserror::Error;\n\nconst SUPPORTED_FILES: [&str; 4] = [\"bmp\", \"jpg\", \"jpeg\", \"png\"];\nconst BBOX_MIN_NUM_VALUES: usize = 4;\n\n/// Image data type.\n#[derive(Debug, Copy, Clone, PartialEq)]\npub enum PixelDepth {\n    /// 8-bit unsigned.\n    U8(u8),\n    /// 16-bit unsigned.\n    U16(u16),\n    /// 32-bit floating point.\n    F32(f32),\n}\n\nimpl TryFrom<PixelDepth> for u8 {\n    type Error = &'static str;\n\n    fn try_from(value: PixelDepth) -> Result<Self, Self::Error> {\n        if let PixelDepth::U8(v) = value {\n            Ok(v)\n        } else {\n            Err(\"Value is not u8\")\n        }\n    }\n}\n\nimpl TryFrom<PixelDepth> for u16 {\n    type Error = &'static str;\n\n    fn try_from(value: PixelDepth) -> Result<Self, Self::Error> {\n        if let PixelDepth::U16(v) = value {\n            Ok(v)\n        } else {\n            Err(\"Value is not u16\")\n        }\n    }\n}\n\nimpl TryFrom<PixelDepth> for f32 {\n    type Error = &'static str;\n\n    fn try_from(value: PixelDepth) -> Result<Self, Self::Error> {\n        if let PixelDepth::F32(v) = value {\n            Ok(v)\n        } else {\n            Err(\"Value is not f32\")\n        }\n    }\n}\n\n/// Annotation type for different tasks.\n#[derive(Debug, Clone, PartialEq)]\npub enum Annotation {\n    /// Image-level label.\n    Label(usize),\n    /// Multiple image-level labels.\n    MultiLabel(Vec<usize>),\n    /// Object bounding boxes.\n    BoundingBoxes(Vec<BoundingBox>),\n    /// Segmentation mask.\n    SegmentationMask(SegmentationMask),\n}\n\n/// Segmentation mask annotation.\n/// For semantic segmentation, a mask has a single channel (C = 1).\n/// For instance segmentation, there may be multiple masks per image (C >= 1).\n#[derive(Debug, Clone, PartialEq)]\npub struct SegmentationMask {\n    /// Segmentation mask.\n    pub mask: Vec<usize>,\n}\n\n/// Object detection bounding box annotation.\n#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]\npub struct BoundingBox {\n    /// Coordinates in [x_min, y_min, width, height] format.\n    pub coords: [f32; 4],\n\n    /// Box class label.\n    pub label: usize,\n}\n\n/// Image dataset item.\n#[derive(Debug, Clone, PartialEq)]\npub struct ImageDatasetItem {\n    /// Image as a vector with a valid image type.\n    pub image: Vec<PixelDepth>,\n\n    /// Original source image width.\n    pub image_width: usize,\n\n    /// Original source image height.\n    pub image_height: usize,\n\n    /// Annotation for the image.\n    pub annotation: Annotation,\n\n    /// Original image source.\n    pub image_path: String,\n}\n\n/// Raw annotation types.\n#[derive(Deserialize, Serialize, Debug, Clone)]\nenum AnnotationRaw {\n    Label(String),\n    MultiLabel(Vec<String>),\n    BoundingBoxes(Vec<BoundingBox>),\n    SegmentationMask(PathBuf),\n}\n\n#[derive(Deserialize, Serialize, Debug, Clone)]\nstruct ImageDatasetItemRaw {\n    /// Image path.\n    image_path: PathBuf,\n\n    /// Image annotation.\n    annotation: AnnotationRaw,\n}\n\nimpl ImageDatasetItemRaw {\n    fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {\n        ImageDatasetItemRaw {\n            image_path: image_path.as_ref().to_path_buf(),\n            annotation,\n        }\n    }\n}\n\nstruct PathToImageDatasetItem {\n    classes: HashMap<String, usize>,\n}\n\nfn segmentation_mask_to_vec_usize(mask_path: &PathBuf) -> Vec<usize> {\n    // Load image from disk\n    let image = image::open(mask_path).unwrap();\n\n    // Image as Vec<PixelDepth>\n    // if rgb8 or rgb16, keep only the first channel assuming all channels are the same\n\n    match image.color() {\n        ColorType::L8 => image.into_luma8().iter().map(|&x| x as usize).collect(),\n        ColorType::L16 => image.into_luma16().iter().map(|&x| x as usize).collect(),\n        ColorType::Rgb8 => image\n            .into_rgb8()\n            .iter()\n            .step_by(3)\n            .map(|&x| x as usize)\n            .collect(),\n        ColorType::Rgb16 => image\n            .into_rgb16()\n            .iter()\n            .step_by(3)\n            .map(|&x| x as usize)\n            .collect(),\n        _ => panic!(\"Unrecognized image color type\"),\n    }\n}\n\n/// Parse the image annotation to the corresponding type.\nfn parse_image_annotation(\n    annotation: &AnnotationRaw,\n    classes: &HashMap<String, usize>,\n) -> Annotation {\n    // TODO: add support for other annotations\n    // - [ ] Object bounding boxes\n    // - [x] Segmentation mask\n    // For now, only image classification labels and segmentation are supported.\n\n    // Map class string to label id\n    match annotation {\n        AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),\n        AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(\n            names\n                .iter()\n                .map(|name| *classes.get(name).unwrap())\n                .collect(),\n        ),\n        AnnotationRaw::SegmentationMask(mask_path) => {\n            Annotation::SegmentationMask(SegmentationMask {\n                mask: segmentation_mask_to_vec_usize(mask_path),\n            })\n        }\n        AnnotationRaw::BoundingBoxes(v) => Annotation::BoundingBoxes(v.clone()),\n    }\n}\n\n/// Retrieve all available classes from the COCO JSON\nfn parse_coco_classes(\n    json: &serde_json::Value,\n) -> Result<HashMap<String, usize>, ImageLoaderError> {\n    let mut classes = HashMap::new();\n\n    if let Some(json_classes) = json[\"categories\"].as_array() {\n        for class in json_classes {\n            let id = class[\"id\"]\n                .as_u64()\n                .ok_or_else(|| ImageLoaderError::ParsingError(\"Invalid class ID\".to_string()))\n                .and_then(|v| {\n                    usize::try_from(v).map_err(|_| {\n                        ImageLoaderError::ParsingError(\"Class ID out of usize range\".to_string())\n                    })\n                })?;\n\n            let name = class[\"name\"]\n                .as_str()\n                .filter(|&s| !s.is_empty())\n                .ok_or_else(|| ImageLoaderError::ParsingError(\"Invalid class name\".to_string()))?\n                .to_string();\n\n            classes.insert(name, id);\n        }\n    }\n\n    if classes.is_empty() {\n        return Err(ImageLoaderError::ParsingError(\n            \"No classes found in annotations\".to_string(),\n        ));\n    }\n\n    Ok(classes)\n}\n\n/// Retrieve annotations from COCO JSON\nfn parse_coco_bbox_annotations(\n    json: &serde_json::Value,\n) -> Result<HashMap<u64, AnnotationRaw>, ImageLoaderError> {\n    let mut annotations = HashMap::new();\n\n    if let Some(json_annotations) = json[\"annotations\"].as_array() {\n        for annotation in json_annotations {\n            let image_id = annotation[\"image_id\"].as_u64().ok_or_else(|| {\n                ImageLoaderError::ParsingError(\"Invalid image ID in annotation\".into())\n            })?;\n\n            let class_id = annotation[\"category_id\"]\n                .as_u64()\n                .ok_or_else(|| {\n                    ImageLoaderError::ParsingError(\"Invalid class ID in annotations\".to_string())\n                })\n                .and_then(|v| {\n                    usize::try_from(v).map_err(|_| {\n                        ImageLoaderError::ParsingError(\n                            \"Class ID in annotations out of usize range\".to_string(),\n                        )\n                    })\n                })?;\n\n            let bbox_coords = annotation[\"bbox\"]\n                .as_array()\n                .ok_or_else(|| ImageLoaderError::ParsingError(\"missing bbox array\".to_string()))?\n                .iter()\n                .map(|v| {\n                    v.as_f64()\n                        .ok_or_else(|| {\n                            ImageLoaderError::ParsingError(\"invalid bbox value\".to_string())\n                        })\n                        .map(|val| val as f32)\n                })\n                .collect::<Result<Vec<f32>, _>>()?;\n\n            if bbox_coords.len() < BBOX_MIN_NUM_VALUES {\n                return Err(ImageLoaderError::ParsingError(format!(\n                    \"not enough bounding box coordinates in annotation for image {image_id}\",\n                )));\n            }\n\n            let bbox = BoundingBox {\n                coords: [\n                    bbox_coords[0],\n                    bbox_coords[1],\n                    bbox_coords[2],\n                    bbox_coords[3],\n                ],\n                label: class_id,\n            };\n\n            annotations\n                .entry(image_id)\n                .and_modify(|entry| {\n                    if let AnnotationRaw::BoundingBoxes(bboxes) = entry {\n                        bboxes.push(bbox.clone());\n                    }\n                })\n                .or_insert_with(|| AnnotationRaw::BoundingBoxes(vec![bbox]));\n        }\n    }\n\n    if annotations.is_empty() {\n        return Err(ImageLoaderError::ParsingError(\n            \"no annotations found\".to_string(),\n        ));\n    }\n\n    Ok(annotations)\n}\n\n/// Retrieve all available images from the COCO JSON\nfn parse_coco_images<P: AsRef<Path>>(\n    images_path: &P,\n    mut annotations: HashMap<u64, AnnotationRaw>,\n    json: &serde_json::Value,\n) -> Result<Vec<ImageDatasetItemRaw>, ImageLoaderError> {\n    let mut images = Vec::new();\n    if let Some(json_images) = json[\"images\"].as_array() {\n        for image in json_images {\n            let image_id = image[\"id\"].as_u64().ok_or_else(|| {\n                ImageLoaderError::ParsingError(\"Invalid image ID in image list\".to_string())\n            })?;\n\n            let file_name = image[\"file_name\"]\n                .as_str()\n                .ok_or_else(|| ImageLoaderError::ParsingError(\"Invalid image ID\".to_string()))?\n                .to_string();\n\n            let mut image_path = images_path.as_ref().to_path_buf();\n            image_path.push(file_name);\n\n            if !image_path.exists() {\n                return Err(ImageLoaderError::IOError(format!(\n                    \"Image {} not found\",\n                    image_path.display()\n                )));\n            }\n\n            let annotation = annotations\n                .remove(&image_id)\n                .unwrap_or_else(|| AnnotationRaw::BoundingBoxes(Vec::new()));\n\n            images.push(ImageDatasetItemRaw {\n                annotation,\n                image_path,\n            });\n        }\n    }\n\n    if images.is_empty() {\n        return Err(ImageLoaderError::ParsingError(\n            \"No images found in annotations\".to_string(),\n        ));\n    }\n\n    Ok(images)\n}\n\nimpl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {\n    /// Convert a raw image dataset item (path-like) to a 3D image array with a target label.\n    fn map(&self, item: &ImageDatasetItemRaw) -> ImageDatasetItem {\n        let annotation = parse_image_annotation(&item.annotation, &self.classes);\n\n        // Load image from disk\n        let image = image::open(&item.image_path).unwrap();\n\n        // Save image dimensions for manipulation\n        let img_width = image.width() as usize;\n        let img_height = image.height() as usize;\n\n        // Image as Vec<PixelDepth>\n        let img_vec = match image.color() {\n            ColorType::L8 => image\n                .into_luma8()\n                .iter()\n                .map(|&x| PixelDepth::U8(x))\n                .collect(),\n            ColorType::La8 => image\n                .into_luma_alpha8()\n                .iter()\n                .map(|&x| PixelDepth::U8(x))\n                .collect(),\n            ColorType::L16 => image\n                .into_luma16()\n                .iter()\n                .map(|&x| PixelDepth::U16(x))\n                .collect(),\n            ColorType::La16 => image\n                .into_luma_alpha16()\n                .iter()\n                .map(|&x| PixelDepth::U16(x))\n                .collect(),\n            ColorType::Rgb8 => image\n                .into_rgb8()\n                .iter()\n                .map(|&x| PixelDepth::U8(x))\n                .collect(),\n            ColorType::Rgba8 => image\n                .into_rgba8()\n                .iter()\n                .map(|&x| PixelDepth::U8(x))\n                .collect(),\n            ColorType::Rgb16 => image\n                .into_rgb16()\n                .iter()\n                .map(|&x| PixelDepth::U16(x))\n                .collect(),\n            ColorType::Rgba16 => image\n                .into_rgba16()\n                .iter()\n                .map(|&x| PixelDepth::U16(x))\n                .collect(),\n            ColorType::Rgb32F => image\n                .into_rgb32f()\n                .iter()\n                .map(|&x| PixelDepth::F32(x))\n                .collect(),\n            ColorType::Rgba32F => image\n                .into_rgba32f()\n                .iter()\n                .map(|&x| PixelDepth::F32(x))\n                .collect(),\n            _ => panic!(\"Unrecognized image color type\"),\n        };\n\n        ImageDatasetItem {\n            image: img_vec,\n            image_width: img_width,\n            image_height: img_height,\n            annotation,\n            image_path: item.image_path.display().to_string(),\n        }\n    }\n}\n\n/// Error type for [ImageFolderDataset](ImageFolderDataset).\n#[derive(Error, Debug)]\npub enum ImageLoaderError {\n    /// Unknown error.\n    #[error(\"unknown: `{0}`\")]\n    Unknown(String),\n\n    /// I/O operation error.\n    #[error(\"I/O error: `{0}`\")]\n    IOError(String),\n\n    /// Invalid file error.\n    #[error(\"Invalid file extension: `{0}`\")]\n    InvalidFileExtensionError(String),\n\n    /// Parsing error.\n    #[error(\"Parsing error: `{0}`\")]\n    ParsingError(String),\n}\n\ntype ImageDatasetMapper =\n    MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;\n\n/// A generic dataset to load images from disk.\npub struct ImageFolderDataset {\n    dataset: ImageDatasetMapper,\n}\n\nimpl Dataset<ImageDatasetItem> for ImageFolderDataset {\n    fn get(&self, index: usize) -> Option<ImageDatasetItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\nimpl ImageFolderDataset {\n    /// Create an image classification dataset from the root folder.\n    ///\n    /// # Arguments\n    ///\n    /// * `root` - Dataset root folder.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification<P: AsRef<Path>>(root: P) -> Result<Self, ImageLoaderError> {\n        // New dataset containing any of the supported file types\n        ImageFolderDataset::new_classification_with(root, &SUPPORTED_FILES)\n    }\n\n    /// Create an image classification dataset from the root folder.\n    /// The included images are filtered based on the provided extensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `root` - Dataset root folder.\n    /// * `extensions` - List of allowed extensions.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification_with<P, S>(\n        root: P,\n        extensions: &[S],\n    ) -> Result<Self, ImageLoaderError>\n    where\n        P: AsRef<Path>,\n        S: AsRef<str>,\n    {\n        // Glob all images with extensions\n        let walker = globwalk::GlobWalkerBuilder::from_patterns(\n            root.as_ref(),\n            &[format!(\n                \"*.{{{}}}\", // \"*.{ext1,ext2,ext3}\n                extensions\n                    .iter()\n                    .map(Self::check_extension)\n                    .collect::<Result<Vec<_>, _>>()?\n                    .join(\",\")\n            )],\n        )\n        .follow_links(true)\n        .sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path())) // order by path\n        .build()\n        .map_err(|err| ImageLoaderError::Unknown(format!(\"{err:?}\")))?\n        .filter_map(Result::ok);\n\n        // Get all dataset items\n        let mut items = Vec::new();\n        let mut classes = HashSet::new();\n        for img in walker {\n            let image_path = img.path();\n\n            // Label name is represented by the parent folder name\n            let label = image_path\n                .parent()\n                .ok_or_else(|| {\n                    ImageLoaderError::IOError(\"Could not resolve image parent folder\".to_string())\n                })?\n                .file_name()\n                .ok_or_else(|| {\n                    ImageLoaderError::IOError(\n                        \"Could not resolve image parent folder name\".to_string(),\n                    )\n                })?\n                .to_string_lossy()\n                .into_owned();\n\n            classes.insert(label.clone());\n\n            items.push(ImageDatasetItemRaw::new(\n                image_path,\n                AnnotationRaw::Label(label),\n            ))\n        }\n\n        // Sort class names\n        let mut classes = classes.into_iter().collect::<Vec<_>>();\n        classes.sort();\n\n        Self::with_items(items, &classes)\n    }\n\n    /// Create an image classification dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(\n        items: Vec<(P, String)>,\n        classes: &[S],\n    ) -> Result<Self, ImageLoaderError> {\n        // Parse items and check valid image extension types\n        let items = items\n            .into_iter()\n            .map(|(path, label)| {\n                // Map image path and label\n                let path = path.as_ref();\n                let label = AnnotationRaw::Label(label);\n\n                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;\n\n                Ok(ImageDatasetItemRaw::new(path, label))\n            })\n            .collect::<Result<Vec<_>, _>>()?;\n\n        Self::with_items(items, classes)\n    }\n\n    /// Create a multi-label image classification dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(\n        items: Vec<(P, Vec<String>)>,\n        classes: &[S],\n    ) -> Result<Self, ImageLoaderError> {\n        // Parse items and check valid image extension types\n        let items = items\n            .into_iter()\n            .map(|(path, labels)| {\n                // Map image path and multi-label\n                let path = path.as_ref();\n                let labels = AnnotationRaw::MultiLabel(labels);\n\n                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;\n\n                Ok(ImageDatasetItemRaw::new(path, labels))\n            })\n            .collect::<Result<Vec<_>, _>>()?;\n\n        Self::with_items(items, classes)\n    }\n\n    /// Create an image segmentation dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - List of dataset items, each item represented by a tuple `(image path, annotation path)`.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_segmentation_with_items<P: AsRef<Path>, S: AsRef<str>>(\n        items: Vec<(P, P)>,\n        classes: &[S],\n    ) -> Result<Self, ImageLoaderError> {\n        // Parse items and check valid image extension types\n        let items = items\n            .into_iter()\n            .map(|(image_path, mask_path)| {\n                // Map image path and segmentation mask path\n                let image_path = image_path.as_ref();\n                let annotation = AnnotationRaw::SegmentationMask(mask_path.as_ref().to_path_buf());\n\n                Self::check_extension(&image_path.extension().unwrap().to_str().unwrap())?;\n\n                Ok(ImageDatasetItemRaw::new(image_path, annotation))\n            })\n            .collect::<Result<Vec<_>, _>>()?;\n\n        Self::with_items(items, classes)\n    }\n\n    /// Create a COCO detection dataset based on the annotations JSON and image directory.\n    ///\n    /// # Arguments\n    ///\n    /// * `annotations_json` - Path to the JSON file containing annotations in COCO format (for\n    ///   example instances_train2017.json).\n    ///\n    /// * `images_path` - Path containing the images matching the annotations JSON.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    pub fn new_coco_detection<A: AsRef<Path>, I: AsRef<Path>>(\n        annotations_json: A,\n        images_path: I,\n    ) -> Result<Self, ImageLoaderError> {\n        let file = fs::File::open(annotations_json)\n            .map_err(|e| ImageLoaderError::IOError(format!(\"Failed to open annotations: {e}\")))?;\n        let json: Value = serde_json::from_reader(file).map_err(|e| {\n            ImageLoaderError::ParsingError(format!(\"Failed to parse annotations: {e}\"))\n        })?;\n\n        let classes = parse_coco_classes(&json)?;\n        let annotations = parse_coco_bbox_annotations(&json)?;\n        let items = parse_coco_images(&images_path, annotations, &json)?;\n        let dataset = InMemDataset::new(items);\n        let mapper = PathToImageDatasetItem { classes };\n        let dataset = MapperDataset::new(dataset, mapper);\n\n        Ok(Self { dataset })\n    }\n\n    /// Create an image dataset with the specified items.\n    ///\n    /// # Arguments\n    ///\n    /// * `items` - Raw dataset items.\n    /// * `classes` - Dataset class names.\n    ///\n    /// # Returns\n    /// A new dataset instance.\n    fn with_items<S: AsRef<str>>(\n        items: Vec<ImageDatasetItemRaw>,\n        classes: &[S],\n    ) -> Result<Self, ImageLoaderError> {\n        // NOTE: right now we don't need to validate the supported image files since\n        // the method is private. We assume it's already validated.\n        let dataset = InMemDataset::new(items);\n\n        // Class names to index map\n        let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();\n        let classes_map: HashMap<_, _> = classes\n            .into_iter()\n            .enumerate()\n            .map(|(idx, cls)| (cls.to_string(), idx))\n            .collect();\n\n        let mapper = PathToImageDatasetItem {\n            classes: classes_map,\n        };\n        let dataset = MapperDataset::new(dataset, mapper);\n\n        Ok(Self { dataset })\n    }\n\n    /// Check if extension is supported.\n    fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {\n        let extension = extension.as_ref();\n        if !SUPPORTED_FILES.contains(&extension) {\n            Err(ImageLoaderError::InvalidFileExtensionError(\n                extension.to_string(),\n            ))\n        } else {\n            Ok(extension.to_string())\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    const DATASET_ROOT: &str = \"tests/data/image_folder\";\n    const SEGMASK_ROOT: &str = \"tests/data/segmask_folder\";\n    const COCO_JSON: &str = \"tests/data/dataset_coco.json\";\n    const COCO_IMAGES: &str = \"tests/data/image_folder_coco\";\n\n    #[test]\n    pub fn image_folder_dataset() {\n        let dataset = ImageFolderDataset::new_classification(DATASET_ROOT).unwrap();\n\n        // Dataset has 3 elements\n        assert_eq!(dataset.len(), 3);\n        assert_eq!(dataset.get(3), None);\n\n        // Dataset elements should be: orange (0), red (1), red (1)\n        assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));\n        assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));\n        assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));\n    }\n\n    #[test]\n    pub fn image_folder_dataset_filtered() {\n        let dataset = ImageFolderDataset::new_classification_with(DATASET_ROOT, &[\"jpg\"]).unwrap();\n\n        // Filtered dataset has 2 elements\n        assert_eq!(dataset.len(), 2);\n        assert_eq!(dataset.get(2), None);\n\n        // Dataset elements should be: orange (0), red (1)\n        assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));\n        assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));\n    }\n\n    #[test]\n    pub fn image_folder_dataset_with_items_sizes() {\n        let root = Path::new(DATASET_ROOT);\n        let items = vec![\n            (root.join(\"orange\").join(\"dot.jpg\"), \"orange\".to_string()),\n            (root.join(\"red\").join(\"dot.jpg\"), \"red\".to_string()),\n            (root.join(\"red\").join(\"dot.png\"), \"red\".to_string()),\n        ];\n        let dataset =\n            ImageFolderDataset::new_classification_with_items(items, &[\"orange\", \"red\"]).unwrap();\n\n        // Dataset has 3 elements\n        assert_eq!(dataset.len(), 3);\n        assert_eq!(dataset.get(3), None);\n\n        // Test item sizes\n\n        assert_eq!(\n            (\n                dataset.get(0).unwrap().image_width,\n                dataset.get(0).unwrap().image_height\n            ),\n            (1, 1)\n        );\n        assert_eq!(\n            (\n                dataset.get(1).unwrap().image_width,\n                dataset.get(1).unwrap().image_height\n            ),\n            (1, 1)\n        );\n        assert_eq!(\n            (\n                dataset.get(2).unwrap().image_width,\n                dataset.get(2).unwrap().image_height\n            ),\n            (1, 1)\n        );\n    }\n\n    #[test]\n    pub fn image_folder_dataset_with_items() {\n        let root = Path::new(DATASET_ROOT);\n        let items = vec![\n            (root.join(\"orange\").join(\"dot.jpg\"), \"orange\".to_string()),\n            (root.join(\"red\").join(\"dot.jpg\"), \"red\".to_string()),\n            (root.join(\"red\").join(\"dot.png\"), \"red\".to_string()),\n        ];\n        let dataset =\n            ImageFolderDataset::new_classification_with_items(items, &[\"orange\", \"red\"]).unwrap();\n\n        // Dataset has 3 elements\n        assert_eq!(dataset.len(), 3);\n        assert_eq!(dataset.get(3), None);\n\n        // Dataset elements should be: orange (0), red (1), red (1)\n        assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));\n        assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));\n        assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));\n    }\n\n    #[test]\n    pub fn image_folder_dataset_multilabel() {\n        let root = Path::new(DATASET_ROOT);\n        let items = vec![\n            (\n                root.join(\"orange\").join(\"dot.jpg\"),\n                vec![\"dot\".to_string(), \"orange\".to_string()],\n            ),\n            (\n                root.join(\"red\").join(\"dot.jpg\"),\n                vec![\"dot\".to_string(), \"red\".to_string()],\n            ),\n            (\n                root.join(\"red\").join(\"dot.png\"),\n                vec![\"dot\".to_string(), \"red\".to_string()],\n            ),\n        ];\n        let dataset = ImageFolderDataset::new_multilabel_classification_with_items(\n            items,\n            &[\"dot\", \"orange\", \"red\"],\n        )\n        .unwrap();\n\n        // Dataset has 3 elements\n        assert_eq!(dataset.len(), 3);\n        assert_eq!(dataset.get(3), None);\n\n        // Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)\n        assert_eq!(\n            dataset.get(0).unwrap().annotation,\n            Annotation::MultiLabel(vec![0, 1])\n        );\n        assert_eq!(\n            dataset.get(1).unwrap().annotation,\n            Annotation::MultiLabel(vec![0, 2])\n        );\n        assert_eq!(\n            dataset.get(2).unwrap().annotation,\n            Annotation::MultiLabel(vec![0, 2])\n        );\n    }\n\n    #[test]\n    #[should_panic]\n    pub fn image_folder_dataset_invalid_extension() {\n        // Some invalid file extension\n        let _ = ImageFolderDataset::new_classification_with(DATASET_ROOT, &[\"ico\"]).unwrap();\n    }\n\n    #[test]\n    pub fn pixel_depth_try_into_u8() {\n        let val = u8::MAX;\n        let pix: u8 = PixelDepth::U8(val).try_into().unwrap();\n        assert_eq!(pix, val);\n    }\n\n    #[test]\n    #[should_panic]\n    pub fn pixel_depth_try_into_u8_invalid() {\n        let _: u8 = PixelDepth::U16(u8::MAX as u16 + 1).try_into().unwrap();\n    }\n\n    #[test]\n    pub fn pixel_depth_try_into_u16() {\n        let val = u16::MAX;\n        let pix: u16 = PixelDepth::U16(val).try_into().unwrap();\n        assert_eq!(pix, val);\n    }\n\n    #[test]\n    #[should_panic]\n    pub fn pixel_depth_try_into_u16_invalid() {\n        let _: u16 = PixelDepth::F32(u16::MAX as f32).try_into().unwrap();\n    }\n\n    #[test]\n    pub fn pixel_depth_try_into_f32() {\n        let val = f32::MAX;\n        let pix: f32 = PixelDepth::F32(val).try_into().unwrap();\n        assert_eq!(pix, val);\n    }\n\n    #[test]\n    #[should_panic]\n    pub fn pixel_depth_try_into_f32_invalid() {\n        let _: f32 = PixelDepth::U16(u16::MAX).try_into().unwrap();\n    }\n\n    #[test]\n    pub fn parse_image_annotation_label_string() {\n        let classes = HashMap::from([(\"0\".to_string(), 0_usize), (\"1\".to_string(), 1_usize)]);\n        let anno = AnnotationRaw::Label(\"0\".to_string());\n        assert_eq!(\n            parse_image_annotation(&anno, &classes),\n            Annotation::Label(0)\n        );\n    }\n\n    #[test]\n    pub fn parse_image_annotation_multilabel_string() {\n        let classes = HashMap::from([\n            (\"0\".to_string(), 0_usize),\n            (\"1\".to_string(), 1_usize),\n            (\"2\".to_string(), 2_usize),\n        ]);\n        let anno = AnnotationRaw::MultiLabel(vec![\"0\".to_string(), \"2\".to_string()]);\n        assert_eq!(\n            parse_image_annotation(&anno, &classes),\n            Annotation::MultiLabel(vec![0, 2])\n        );\n    }\n\n    #[test]\n    pub fn segmask_image_path_to_vec_usize() {\n        let root = Path::new(SEGMASK_ROOT);\n\n        // checkerboard mask\n        const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [\n            1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2,\n            1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1,\n            2, 1, 2, 1, 2, 1,\n        ];\n        assert_eq!(\n            TEST_CHECKERBOARD_MASK_PATTERN\n                .iter()\n                .map(|&x| x as usize)\n                .collect::<Vec<usize>>(),\n            segmentation_mask_to_vec_usize(&root.join(\"annotations\").join(\"mask_checkerboard.png\")),\n        );\n\n        // random 2 colors mask\n        const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [\n            1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2,\n            2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1,\n            1, 1, 1, 1, 1, 1,\n        ];\n        assert_eq!(\n            TEST_RANDOM2COLORS_MASK_PATTERN\n                .iter()\n                .map(|&x| x as usize)\n                .collect::<Vec<usize>>(),\n            segmentation_mask_to_vec_usize(\n                &root.join(\"annotations\").join(\"mask_random_2colors.png\")\n            ),\n        );\n        // random 3 colors mask\n        const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [\n            3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3,\n            3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1,\n            3, 3, 2, 1, 2, 2,\n        ];\n        assert_eq!(\n            TEST_RANDOM3COLORS_MASK_PATTERN\n                .iter()\n                .map(|&x| x as usize)\n                .collect::<Vec<usize>>(),\n            segmentation_mask_to_vec_usize(\n                &root.join(\"annotations\").join(\"mask_random_3colors.png\")\n            ),\n        );\n    }\n\n    #[test]\n    pub fn segmask_folder_dataset() {\n        let root = Path::new(SEGMASK_ROOT);\n\n        let items = vec![\n            (\n                root.join(\"images\").join(\"image_checkerboard.png\"),\n                root.join(\"annotations\").join(\"mask_checkerboard.png\"),\n            ),\n            (\n                root.join(\"images\").join(\"image_random_2colors.png\"),\n                root.join(\"annotations\").join(\"mask_random_2colors.png\"),\n            ),\n            (\n                root.join(\"images\").join(\"image_random_3colors.png\"),\n                root.join(\"annotations\").join(\"mask_random_3colors.png\"),\n            ),\n        ];\n        let dataset = ImageFolderDataset::new_segmentation_with_items(\n            items,\n            &[\n                \"foo\", // 0\n                \"bar\", // 1\n                \"baz\", // 2\n                \"qux\", // 3\n            ],\n        )\n        .unwrap();\n\n        // Dataset has 3 elements; each (image, annotation) is a single item\n        assert_eq!(dataset.len(), 3);\n        assert_eq!(dataset.get(3), None);\n\n        // checkerboard mask\n        const TEST_CHECKERBOARD_MASK_PATTERN: [u8; 64] = [\n            1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2,\n            1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1,\n            2, 1, 2, 1, 2, 1,\n        ];\n        assert_eq!(\n            dataset.get(0).unwrap().annotation,\n            Annotation::SegmentationMask(SegmentationMask {\n                mask: TEST_CHECKERBOARD_MASK_PATTERN\n                    .iter()\n                    .map(|&x| x as usize)\n                    .collect()\n            })\n        );\n        // random 2 colors mask\n        const TEST_RANDOM2COLORS_MASK_PATTERN: [u8; 64] = [\n            1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2,\n            2, 1, 1, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1,\n            1, 1, 1, 1, 1, 1,\n        ];\n        assert_eq!(\n            dataset.get(1).unwrap().annotation,\n            Annotation::SegmentationMask(SegmentationMask {\n                mask: TEST_RANDOM2COLORS_MASK_PATTERN\n                    .iter()\n                    .map(|&x| x as usize)\n                    .collect()\n            })\n        );\n        // random 3 colors mask\n        const TEST_RANDOM3COLORS_MASK_PATTERN: [u8; 64] = [\n            3, 1, 3, 3, 1, 1, 3, 2, 3, 3, 3, 3, 1, 3, 2, 1, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 1, 3, 3,\n            3, 2, 3, 2, 2, 3, 2, 3, 3, 1, 3, 1, 3, 3, 1, 1, 3, 2, 1, 2, 2, 2, 1, 2, 1, 2, 3, 3, 1,\n            3, 3, 2, 1, 2, 2,\n        ];\n        assert_eq!(\n            dataset.get(2).unwrap().annotation,\n            Annotation::SegmentationMask(SegmentationMask {\n                mask: TEST_RANDOM3COLORS_MASK_PATTERN\n                    .iter()\n                    .map(|&x| x as usize)\n                    .collect()\n            })\n        );\n    }\n\n    #[test]\n    pub fn coco_detection_dataset() {\n        let dataset = ImageFolderDataset::new_coco_detection(COCO_JSON, COCO_IMAGES).unwrap();\n        assert_eq!(dataset.len(), 3); // we have only three images defined\n        assert_eq!(dataset.get(3), None);\n\n        const TWO_DOTS_AND_TRIANGLE_B1: BoundingBox = BoundingBox {\n            coords: [3.125_172, 18.090_784, 10.960_11, 10.740_027],\n            label: 0,\n        };\n\n        const TWO_DOTS_AND_TRIANGLE_B2: BoundingBox = BoundingBox {\n            coords: [3.257_221_5, 3.037_139, 10.563_961, 10.828_06],\n            label: 0,\n        };\n\n        const TWO_DOTS_AND_TRIANGLE_B3: BoundingBox = BoundingBox {\n            coords: [15.097_662, 3.389_271, 12.632_737, 11.180_193],\n            label: 1,\n        };\n\n        const DOTS_TRIANGLE_B1: BoundingBox = BoundingBox {\n            coords: [3.125_172, 17.914_719, 10.828_06, 11.004_127],\n            label: 0,\n        };\n\n        const DOTS_TRIANGLE_B2: BoundingBox = BoundingBox {\n            coords: [15.273_727, 3.301_238, 12.192_573, 11.708_39],\n            label: 1,\n        };\n\n        const ONE_DOT_B1: BoundingBox = BoundingBox {\n            coords: [10.079_78, 9.595_598, 10.960_11, 11.356_258],\n            label: 0,\n        };\n\n        for item in dataset.iter() {\n            let file_name = Path::new(&item.image_path).file_name().unwrap();\n            match item.annotation {\n                // check if the number of bounding boxes is correct\n                Annotation::BoundingBoxes(v) => {\n                    if file_name == \"two_dots_and_triangle.jpg\" {\n                        assert_eq!(v.len(), 3);\n                        assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B1));\n                        assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B2));\n                        assert!(v.contains(&TWO_DOTS_AND_TRIANGLE_B3));\n                    } else if file_name == \"dot_triangle.jpg\" {\n                        assert_eq!(v.len(), 2);\n                        assert!(v.contains(&DOTS_TRIANGLE_B1));\n                        assert!(v.contains(&DOTS_TRIANGLE_B2));\n                    } else if file_name == \"one_dot.jpg\" {\n                        assert_eq!(v.len(), 1);\n                        assert!(v.contains(&ONE_DOT_B1));\n                    } else {\n                        panic!(\"{}\", format!(\"unexpected image name: {}\", item.image_path));\n                    }\n                }\n                _ => panic!(\"unexpected annotation\"),\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/vision/mnist.rs",
    "content": "use std::fs::{File, create_dir_all};\nuse std::io::{Read, Seek, SeekFrom};\nuse std::path::{Path, PathBuf};\n\nuse flate2::read::GzDecoder;\nuse serde::{Deserialize, Serialize};\n\nuse crate::{\n    Dataset, InMemDataset,\n    transform::{Mapper, MapperDataset},\n};\n\nuse crate::network::downloader::download_file_as_bytes;\n\n// CVDF mirror of http://yann.lecun.com/exdb/mnist/\nconst URL: &str = \"https://storage.googleapis.com/cvdf-datasets/mnist/\";\nconst TRAIN_IMAGES: &str = \"train-images-idx3-ubyte\";\nconst TRAIN_LABELS: &str = \"train-labels-idx1-ubyte\";\nconst TEST_IMAGES: &str = \"t10k-images-idx3-ubyte\";\nconst TEST_LABELS: &str = \"t10k-labels-idx1-ubyte\";\n\nconst WIDTH: usize = 28;\nconst HEIGHT: usize = 28;\n\n/// MNIST item.\n#[derive(Deserialize, Serialize, Debug, Clone)]\npub struct MnistItem {\n    /// Image as a 2D array of floats.\n    pub image: [[f32; WIDTH]; HEIGHT],\n\n    /// Label of the image.\n    pub label: u8,\n}\n\n#[derive(Deserialize, Debug, Clone)]\nstruct MnistItemRaw {\n    pub image_bytes: Vec<u8>,\n    pub label: u8,\n}\n\nstruct BytesToImage;\n\nimpl Mapper<MnistItemRaw, MnistItem> for BytesToImage {\n    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).\n    fn map(&self, item: &MnistItemRaw) -> MnistItem {\n        // Ensure the image dimensions are correct.\n        debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);\n\n        // Convert the image to a 2D array of floats.\n        let mut image_array = [[0f32; WIDTH]; HEIGHT];\n        for (i, pixel) in item.image_bytes.iter().enumerate() {\n            let x = i % WIDTH;\n            let y = i / HEIGHT;\n            image_array[y][x] = *pixel as f32;\n        }\n\n        MnistItem {\n            image: image_array,\n            label: item.label,\n        }\n    }\n}\n\ntype MappedDataset = MapperDataset<InMemDataset<MnistItemRaw>, BytesToImage, MnistItemRaw>;\n\n/// The MNIST dataset consists of 70,000 28x28 black-and-white images in 10 classes (one for each digits), with 7,000\n/// images per class. There are 60,000 training images and 10,000 test images.\n///\n/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).\npub struct MnistDataset {\n    dataset: MappedDataset,\n}\n\nimpl Dataset<MnistItem> for MnistDataset {\n    fn get(&self, index: usize) -> Option<MnistItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\nimpl MnistDataset {\n    /// Creates a new train dataset.\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    /// Creates a new test dataset.\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    fn new(split: &str) -> Self {\n        // Download dataset\n        let root = MnistDataset::download(split);\n\n        // MNIST is tiny so we can load it in-memory\n        // Train images (u8): 28 * 28 * 60000 = 47.04Mb\n        // Test images (u8): 28 * 28 * 10000 = 7.84Mb\n        let images = MnistDataset::read_images(&root, split);\n        let labels = MnistDataset::read_labels(&root, split);\n\n        // Collect as vector of MnistItemRaw\n        let items: Vec<_> = images\n            .into_iter()\n            .zip(labels)\n            .map(|(image_bytes, label)| MnistItemRaw { image_bytes, label })\n            .collect();\n\n        let dataset = InMemDataset::new(items);\n        let dataset = MapperDataset::new(dataset, BytesToImage);\n\n        Self { dataset }\n    }\n\n    /// Download the MNIST dataset files from the web.\n    /// Panics if the download cannot be completed or the content of the file cannot be written to disk.\n    fn download(split: &str) -> PathBuf {\n        // Dataset files are stored in the burn-dataset cache directory\n        let cache_dir = dirs::cache_dir()\n            .expect(\"Could not get cache directory\")\n            .join(\"burn-dataset\");\n        let split_dir = cache_dir.join(\"mnist\").join(split);\n\n        if !split_dir.exists() {\n            create_dir_all(&split_dir).expect(\"Failed to create base directory\");\n        }\n\n        // Download split files\n        match split {\n            \"train\" => {\n                MnistDataset::download_file(TRAIN_IMAGES, &split_dir);\n                MnistDataset::download_file(TRAIN_LABELS, &split_dir);\n            }\n            \"test\" => {\n                MnistDataset::download_file(TEST_IMAGES, &split_dir);\n                MnistDataset::download_file(TEST_LABELS, &split_dir);\n            }\n            _ => panic!(\"Invalid split specified {split}\"),\n        };\n\n        split_dir\n    }\n\n    /// Download a file from the MNIST dataset URL to the destination directory.\n    /// File download progress is reported with the help of a [progress bar](indicatif).\n    fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {\n        // Output file name\n        let file_name = dest_dir.as_ref().join(name);\n\n        if !file_name.exists() {\n            // Download gzip file\n            let bytes = download_file_as_bytes(&format!(\"{URL}{name}.gz\"), name);\n\n            // Create file to write the downloaded content to\n            let mut output_file = File::create(&file_name).unwrap();\n\n            // Decode gzip file content and write to disk\n            let mut gz_buffer = GzDecoder::new(&bytes[..]);\n            std::io::copy(&mut gz_buffer, &mut output_file).unwrap();\n        }\n\n        file_name\n    }\n\n    /// Read images at the provided path for the specified split.\n    /// Each image is a vector of bytes.\n    fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {\n        let file_name = if split == \"train\" {\n            TRAIN_IMAGES\n        } else {\n            TEST_IMAGES\n        };\n        let file_name = root.as_ref().join(file_name);\n\n        // Read number of images from 16-byte header metadata\n        let mut f = File::open(file_name).unwrap();\n        let mut buf = [0u8; 4];\n        let _ = f.seek(SeekFrom::Start(4)).unwrap();\n        f.read_exact(&mut buf)\n            .expect(\"Should be able to read image file header\");\n        let size = u32::from_be_bytes(buf);\n\n        let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];\n        let _ = f.seek(SeekFrom::Start(16)).unwrap();\n        f.read_exact(&mut buf_images)\n            .expect(\"Should be able to read image file header\");\n\n        buf_images\n            .chunks(WIDTH * HEIGHT)\n            .map(|chunk| chunk.to_vec())\n            .collect()\n    }\n\n    /// Read labels at the provided path for the specified split.\n    fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {\n        let file_name = if split == \"train\" {\n            TRAIN_LABELS\n        } else {\n            TEST_LABELS\n        };\n        let file_name = root.as_ref().join(file_name);\n\n        // Read number of labels from 8-byte header metadata\n        let mut f = File::open(file_name).unwrap();\n        let mut buf = [0u8; 4];\n        let _ = f.seek(SeekFrom::Start(4)).unwrap();\n        f.read_exact(&mut buf)\n            .expect(\"Should be able to read label file header\");\n        let size = u32::from_be_bytes(buf);\n\n        let mut buf_labels: Vec<u8> = vec![0u8; size as usize];\n        let _ = f.seek(SeekFrom::Start(8)).unwrap();\n        f.read_exact(&mut buf_labels)\n            .expect(\"Should be able to read labels from file\");\n\n        buf_labels\n    }\n}\n"
  },
  {
    "path": "crates/burn-dataset/src/vision/mod.rs",
    "content": "#[cfg(feature = \"builtin-sources\")]\nmod cifar;\nmod image_folder;\nmod mnist;\n\n#[cfg(feature = \"builtin-sources\")]\npub use cifar::*;\npub use image_folder::*;\npub use mnist::*;\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/dataset-fmt.csv",
    "content": "HI1 1 true 1.0\r\nHI2 1 false 1.0\r\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/dataset.csv",
    "content": "column_str,column_int,column_bool,column_float\r\nHI1,1,true,1.0\r\nHI2,1,false,1.0\r\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/dataset.json",
    "content": "{\"column_str\":\"HI1\",\"column_bytes\":[1,2,3,3],\"column_int\":1,\"column_bool\":true,\"column_float\":1.0}\n{\"column_str\":\"HI2\",\"column_bytes\":[1,2,3,3],\"column_int\":1,\"column_bool\":false,\"column_float\":1.0}"
  },
  {
    "path": "crates/burn-dataset/tests/data/dataset_coco.json",
    "content": "{\n  \"images\": [\n    {\n      \"width\": 32,\n      \"height\": 32,\n      \"id\": 0,\n      \"file_name\": \"two_dots_and_triangle.jpg\"\n    },\n    {\n      \"width\": 32,\n      \"height\": 32,\n      \"id\": 1,\n      \"file_name\": \"dot_triangle.jpg\"\n    },\n    {\n      \"width\": 32,\n      \"height\": 32,\n      \"id\": 2,\n      \"file_name\": \"one_dot.jpg\"\n    }\n  ],\n  \"categories\": [\n    {\n      \"id\": 0,\n      \"name\": \"dot\"\n    },\n    {\n      \"id\": 1,\n      \"name\": \"triangle\"\n    }\n  ],\n  \"annotations\": [\n    {\n      \"id\": 0,\n      \"image_id\": 0,\n      \"category_id\": 0,\n      \"segmentation\": [],\n      \"bbox\": [\n        3.1251719394773056,\n        18.0907840440165,\n        10.96011004126548,\n        10.740027510316379\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 117.71188335928603\n    },\n    {\n      \"id\": 1,\n      \"image_id\": 0,\n      \"category_id\": 0,\n      \"segmentation\": [],\n      \"bbox\": [\n        3.2572214580467658,\n        3.0371389270976605,\n        10.563961485557085,\n        10.828060522696012\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 114.38721432504178\n    },\n    {\n      \"id\": 2,\n      \"image_id\": 0,\n      \"category_id\": 1,\n      \"segmentation\": [],\n      \"bbox\": [\n        15.097661623108666,\n        3.3892709766162312,\n        12.632737276478679,\n        11.18019257221458\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 141.23643546522516\n    },\n    {\n      \"id\": 3,\n      \"image_id\": 1,\n      \"category_id\": 0,\n      \"segmentation\": [],\n      \"bbox\": [\n        3.125171939477304,\n        17.914718019257222,\n        10.82806052269601,\n        11.004126547455297\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 119.15334825525184\n    },\n    {\n      \"id\": 4,\n      \"image_id\": 1,\n      \"category_id\": 1,\n      \"segmentation\": [],\n      \"bbox\": [\n        15.27372764786794,\n        3.301237964236589,\n        12.192572214580478,\n        11.708390646492433\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 142.7553984738776\n    },\n    {\n      \"id\": 5,\n      \"image_id\": 2,\n      \"category_id\": 0,\n      \"segmentation\": [],\n      \"bbox\": [\n        10.07977991746905,\n        9.59559834938102,\n        10.960110041265464,\n        11.356258596973863\n      ],\n      \"ignore\": 0,\n      \"iscrowd\": 0,\n      \"area\": 124.46584387990049\n    }\n  ],\n  \"info\": {\n    \"year\": 2024,\n    \"version\": \"1.0\",\n    \"description\": \"\",\n    \"contributor\": \"\",\n    \"url\": \"\",\n    \"date_created\": \"2024-12-11 22:16:31.823494\"\n  }\n}\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/segmask_folder/annotations/mask_checkerboard.txt",
    "content": "1 2 1 2 1 2 1 2\n2 1 2 1 2 1 2 1\n1 2 1 2 1 2 1 2\n2 1 2 1 2 1 2 1\n1 2 1 2 1 2 1 2\n2 1 2 1 2 1 2 1\n1 2 1 2 1 2 1 2\n2 1 2 1 2 1 2 1\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_2colors.txt",
    "content": "1 2 1 1 1 2 1 1\n1 2 1 1 1 1 2 1\n2 2 2 1 2 1 2 2\n2 2 2 2 2 2 1 1\n2 2 2 1 2 1 1 1\n1 1 2 2 2 2 2 1\n2 2 1 2 1 2 1 2\n2 1 1 1 1 1 1 1\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/segmask_folder/annotations/mask_random_3colors.txt",
    "content": "3 1 3 3 1 1 3 2\n3 3 3 3 1 3 2 1\n2 2 2 2 1 1 2 2\n1 1 1 3 3 3 2 3\n2 2 3 2 3 3 1 3\n1 3 3 1 1 3 2 1\n2 2 2 1 2 1 2 3\n3 1 3 3 2 1 2 2\n"
  },
  {
    "path": "crates/burn-dataset/tests/data/text_folder/negative/sample1.txt",
    "content": "This is a negative text sample for testing the text folder dataset functionality."
  },
  {
    "path": "crates/burn-dataset/tests/data/text_folder/negative/sample2.txt",
    "content": "另一个负面文本样本，用以确保数据集能够处理同一类别中的多个文件。"
  },
  {
    "path": "crates/burn-dataset/tests/data/text_folder/positive/sample1.txt",
    "content": "This is a positive text sample for testing the text folder dataset functionality."
  },
  {
    "path": "crates/burn-dataset/tests/data/text_folder/positive/sample2.txt",
    "content": "另一个正面文本样本，以确保数据集能够处理同一类别中的多个文件。"
  },
  {
    "path": "crates/burn-derive/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Derive crate for the Burn framework\"\nedition.workspace = true\nkeywords = []\nlicense.workspace = true\nname = \"burn-derive\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-derive\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[lib]\nproc-macro = true\n\n[dependencies]\nproc-macro2 = { workspace = true }\nquote = { workspace = true }\nsyn = { workspace = true }\nderive-new = { workspace = true }\n"
  },
  {
    "path": "crates/burn-derive/README.md",
    "content": "# Burn Derive\n\nThis crate should only be used with [burn](https://github.com/tracel-ai/burn).\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-derive.svg)](https://crates.io/crates/burn-derive)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-derive/blob/master/README.md)\n"
  },
  {
    "path": "crates/burn-derive/src/config/analyzer.rs",
    "content": "use super::ConfigEnumAnalyzer;\nuse crate::config::ConfigStructAnalyzer;\nuse crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};\nuse proc_macro2::TokenStream;\nuse quote::quote;\nuse syn::{Field, Ident};\n\npub struct ConfigAnalyzerFactory {}\n\npub trait ConfigAnalyzer {\n    fn gen_new_fn(&self) -> TokenStream {\n        quote! {}\n    }\n    fn gen_builder_fns(&self) -> TokenStream {\n        quote! {}\n    }\n    fn gen_serde_impl(&self) -> TokenStream;\n    fn gen_clone_impl(&self) -> TokenStream;\n    fn gen_display_impl(&self) -> TokenStream;\n    fn gen_config_impl(&self) -> TokenStream;\n}\n\nimpl ConfigAnalyzerFactory {\n    pub fn new() -> Self {\n        Self {}\n    }\n\n    pub fn create_analyzer(&self, item: &syn::DeriveInput) -> Box<dyn ConfigAnalyzer> {\n        let name = item.ident.clone();\n        let config_type = parse_asm(item);\n\n        match config_type {\n            ConfigType::Struct(data) => Box::new(self.create_struct_analyzer(name, data)),\n            ConfigType::Enum(data) => Box::new(self.create_enum_analyzer(name, data)),\n        }\n    }\n\n    fn create_struct_analyzer(&self, name: Ident, fields: Vec<Field>) -> ConfigStructAnalyzer {\n        let fields = fields.into_iter().map(FieldTypeAnalyzer::new);\n\n        let mut fields_required = Vec::new();\n        let mut fields_option = Vec::new();\n        let mut fields_default = Vec::new();\n\n        for field in fields {\n            let attributes: Vec<AttributeItem> = field\n                .attributes()\n                .filter(|attr| attr.has_name(\"config\"))\n                .map(|attr| attr.item())\n                .collect();\n\n            if !attributes.is_empty() {\n                let item = attributes.first().unwrap().clone();\n                fields_default.push((field.clone(), item));\n                continue;\n            }\n\n            if field.is_of_type(&[\"Option\"]) {\n                fields_option.push(field.clone());\n                continue;\n            }\n\n            fields_required.push(field.clone());\n        }\n\n        ConfigStructAnalyzer::new(name, fields_required, fields_option, fields_default)\n    }\n\n    fn create_enum_analyzer(&self, name: Ident, data: syn::DataEnum) -> ConfigEnumAnalyzer {\n        ConfigEnumAnalyzer::new(name, data)\n    }\n}\n\nenum ConfigType {\n    Struct(Vec<Field>),\n    Enum(syn::DataEnum),\n}\n\nfn parse_asm(ast: &syn::DeriveInput) -> ConfigType {\n    match &ast.data {\n        syn::Data::Struct(struct_data) => {\n            ConfigType::Struct(struct_data.fields.clone().into_iter().collect())\n        }\n        syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),\n        syn::Data::Union(_) => panic!(\"Only struct and enum can be derived\"),\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/config/analyzer_enum.rs",
    "content": "use crate::shared::enum_variant::map_enum_variant;\n\nuse super::ConfigAnalyzer;\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\n\npub struct ConfigEnumAnalyzer {\n    name: Ident,\n    data: syn::DataEnum,\n}\n\nimpl ConfigEnumAnalyzer {\n    pub fn new(name: Ident, data: syn::DataEnum) -> Self {\n        Self { name, data }\n    }\n\n    fn serde_enum_ident(&self) -> Ident {\n        Ident::new(&format!(\"{}Serde\", self.name), self.name.span())\n    }\n\n    fn gen_serde_enum(&self) -> TokenStream {\n        let enum_name = self.serde_enum_ident();\n        let data = &self.data.variants;\n\n        quote! {\n            #[derive(burn::serde::Serialize, burn::serde::Deserialize)]\n            #[serde(crate = \"burn::serde\")]\n            enum #enum_name {\n                #data\n            }\n\n        }\n    }\n\n    fn gen_serialize_fn(&self) -> TokenStream {\n        let enum_name = self.serde_enum_ident();\n        let variants = self.data.variants.iter().map(|variant| {\n            let variant_name = &variant.ident;\n            let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });\n\n            quote! { Self::#variant_name #inputs => #enum_name::#variant_name #outputs }\n        });\n\n        let name = &self.name;\n\n        quote! {\n            impl burn::serde::Serialize for #name {\n                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n                where\n                    S: burn::serde::Serializer {\n                    let serde_state = match self {\n                        #(#variants),*\n                    };\n                    serde_state.serialize(serializer)\n                }\n            }\n\n        }\n    }\n\n    fn gen_deserialize_fn(&self) -> TokenStream {\n        let enum_name = self.serde_enum_ident();\n        let variants = self.data.variants.iter().map(|variant| {\n            let variant_name = &variant.ident;\n            let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });\n\n            quote! { #enum_name::#variant_name #inputs => Self::#variant_name #outputs }\n        });\n        let name = &self.name;\n\n        quote! {\n            impl<'de> burn::serde::Deserialize<'de> for #name {\n                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>\n                where\n                    D: burn::serde::Deserializer<'de> {\n                    let serde_state = #enum_name::deserialize(deserializer)?;\n                    Ok(match serde_state {\n                        #(#variants),*\n                    })\n                }\n            }\n\n        }\n    }\n}\n\nimpl ConfigAnalyzer for ConfigEnumAnalyzer {\n    fn gen_serde_impl(&self) -> TokenStream {\n        let struct_gen = self.gen_serde_enum();\n        let serialize_gen = self.gen_serialize_fn();\n        let deserialize_gen = self.gen_deserialize_fn();\n\n        quote! {\n            #struct_gen\n            #serialize_gen\n            #deserialize_gen\n        }\n    }\n\n    fn gen_clone_impl(&self) -> TokenStream {\n        let variants = self.data.variants.iter().map(|variant| {\n            let variant_name = &variant.ident;\n            let (inputs, outputs) = map_enum_variant(variant, |ident| quote! { #ident.clone() });\n\n            quote! { Self::#variant_name #inputs => Self::#variant_name #outputs }\n        });\n        let name = &self.name;\n\n        quote! {\n            impl Clone for #name {\n                fn clone(&self) -> Self {\n                    match self {\n                        #(#variants),*\n                    }\n                }\n            }\n\n        }\n    }\n\n    fn gen_display_impl(&self) -> TokenStream {\n        let name = &self.name;\n\n        quote! {\n            impl core::fmt::Display for #name {\n                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n                    f.write_str(&burn::config::config_to_json(self))\n                }\n            }\n        }\n    }\n\n    fn gen_config_impl(&self) -> TokenStream {\n        let name = &self.name;\n\n        quote! {\n            impl burn::config::Config for #name {\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/config/analyzer_struct.rs",
    "content": "use super::ConfigAnalyzer;\nuse crate::shared::{attribute::AttributeItem, field::FieldTypeAnalyzer};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\n\npub struct ConfigStructAnalyzer {\n    name: Ident,\n    fields_required: Vec<FieldTypeAnalyzer>,\n    fields_option: Vec<FieldTypeAnalyzer>,\n    fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,\n}\n\nimpl ConfigStructAnalyzer {\n    pub fn new(\n        name: Ident,\n        fields_required: Vec<FieldTypeAnalyzer>,\n        fields_option: Vec<FieldTypeAnalyzer>,\n        fields_default: Vec<(FieldTypeAnalyzer, AttributeItem)>,\n    ) -> Self {\n        Self {\n            name,\n            fields_required,\n            fields_option,\n            fields_default,\n        }\n    }\n\n    fn wrap_impl_block(&self, tokens: TokenStream) -> TokenStream {\n        let name = &self.name;\n\n        quote! {\n            impl #name {\n                #tokens\n            }\n        }\n    }\n\n    fn names(&self) -> Vec<FieldTypeAnalyzer> {\n        let mut names = Vec::new();\n\n        for field in self.fields_required.iter() {\n            names.push(field.clone());\n        }\n\n        for field in self.fields_option.iter() {\n            names.push(field.clone());\n        }\n\n        for (field, _) in self.fields_default.iter() {\n            names.push(field.clone());\n        }\n\n        names\n    }\n\n    fn name_types(&self, names: &[FieldTypeAnalyzer]) -> Vec<TokenStream> {\n        let mut name_types = Vec::new();\n\n        for field in names.iter() {\n            let name = field.ident();\n            let ty = &field.field.ty;\n\n            name_types.push(quote! {\n                #name: #ty\n            });\n        }\n\n        name_types\n    }\n\n    fn serde_struct_ident(&self) -> Ident {\n        Ident::new(&format!(\"{}Serde\", self.name), self.name.span())\n    }\n\n    fn gen_serialize_fn(\n        &self,\n        struct_name: &Ident,\n        struct_gen: &TokenStream,\n        names: &[FieldTypeAnalyzer],\n    ) -> TokenStream {\n        let name = &self.name;\n        let names = names.iter().map(|name| {\n            let name = name.ident();\n            quote! { #name: self.#name.clone() }\n        });\n\n        quote! {\n            impl burn::serde::Serialize for #name {\n\n                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>\n                where\n                    S: burn::serde::Serializer {\n                    #[derive(burn::serde::Serialize)]\n                    #[serde(crate = \"burn::serde\")]\n                    #struct_gen\n\n                    let serde_state = #struct_name {\n                        #(#names),*\n                    };\n                    serde_state.serialize(serializer)\n                }\n            }\n\n        }\n    }\n\n    fn gen_deserialize_fn(\n        &self,\n        struct_name: &Ident,\n        struct_gen: &TokenStream,\n        names: &[FieldTypeAnalyzer],\n    ) -> TokenStream {\n        let name = &self.name;\n        let names = names.iter().map(|name| {\n            let name = name.ident();\n            quote! { #name: serde_state.#name }\n        });\n\n        quote! {\n            impl<'de> burn::serde::Deserialize<'de> for #name {\n                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>\n                where\n                    D: burn::serde::Deserializer<'de> {\n                    #[derive(burn::serde::Deserialize)]\n                    #[serde(crate = \"burn::serde\")]\n                    #struct_gen\n\n                    let serde_state = #struct_name::deserialize(deserializer)?;\n                    Ok(#name {\n                        #(#names),*\n                    })\n                }\n            }\n\n        }\n    }\n\n    fn gen_serde_struct(&self, names: &[TokenStream]) -> TokenStream {\n        let struct_name = self.serde_struct_ident();\n\n        quote! {\n            struct #struct_name {\n                #(#names),*\n            }\n\n        }\n    }\n}\n\nimpl ConfigAnalyzer for ConfigStructAnalyzer {\n    fn gen_new_fn(&self) -> TokenStream {\n        let mut body = quote! {};\n        let mut args = Vec::new();\n\n        let mut fn_docs = quote! {};\n        let mut has_field_docs = false;\n        let mut has_required_docs = false;\n        let mut has_option_docs = false;\n        let mut has_default_docs = false;\n        let mut docs_header = |fn_docs: &mut TokenStream,\n                               required_docs: bool,\n                               option_docs: bool,\n                               default_docs: bool| {\n            if !has_field_docs {\n                has_field_docs = true;\n                fn_docs.extend(quote! {\n                    #[doc = \"# Arguments\"]\n                });\n            }\n            if !has_required_docs && required_docs {\n                fn_docs.extend(quote! {\n                    #[doc = \"###### Required Arguments\"]\n                });\n                has_required_docs = true;\n            }\n            if !has_option_docs && option_docs {\n                fn_docs.extend(quote! {\n                    #[doc = \"###### Optional Arguments\"]\n                });\n                has_option_docs = true;\n            }\n            if !has_default_docs && default_docs {\n                fn_docs.extend(quote! {\n                    #[doc = \"###### Default Arguments\"]\n                });\n                has_default_docs = true;\n            }\n        };\n\n        for field in self.fields_required.iter() {\n            let name = field.ident();\n            let ty = &field.field.ty;\n            let docs = field.docs();\n\n            body.extend(quote! {\n                #name: #name,\n            });\n            args.push(quote! {\n                #name: #ty\n            });\n            docs_header(&mut fn_docs, true, false, false);\n            let doc_str = format!(\"###### `{}`\\n\\n\", quote!(#name));\n            fn_docs.extend(quote! {\n                #[doc = #doc_str]\n                #(#docs)*\n            });\n        }\n\n        for field in self.fields_option.iter() {\n            let name = field.ident();\n            let docs = field.docs();\n\n            body.extend(quote! {\n                #name: None,\n            });\n            docs_header(&mut fn_docs, false, true, false);\n            let default_doc = \"- Defaults to `None`\";\n            let doc_str = format!(\"###### `{}`\\n\", quote!(#name));\n            fn_docs.extend(quote! {\n                #[doc = #doc_str]\n                #(#docs)*\n                #[doc = #default_doc]\n            });\n        }\n\n        for (field, attribute) in self.fields_default.iter() {\n            let name = field.ident();\n            let value = &attribute.value;\n            let docs = field.docs();\n\n            match value {\n                syn::Lit::Str(value) => {\n                    let stream: proc_macro2::TokenStream = value.value().parse().unwrap();\n\n                    body.extend(quote! {\n                        #name: #stream,\n                    });\n                }\n                _ => {\n                    body.extend(quote! {\n                        #name: #value,\n                    });\n                }\n            };\n            docs_header(&mut fn_docs, false, false, true);\n            let default_doc = format!(\"- Defaults to `{}`\", quote!(#value));\n            let doc_str = format!(\"###### `{}`\\n\", quote!(#name));\n            fn_docs.extend(quote! {\n                #[doc = #doc_str]\n                #(#docs)*\n                #[doc = #default_doc]\n            });\n        }\n\n        let body = quote! {\n            #[doc = \"Create a new instance of the config.\"]\n            #fn_docs\n            #[allow(clippy::too_many_arguments)]\n            pub fn new(\n                #(#args),*\n            ) -> Self {\n                Self { #body }\n            }\n        };\n        self.wrap_impl_block(body)\n    }\n\n    fn gen_builder_fns(&self) -> TokenStream {\n        let mut body = quote! {};\n\n        for (field, attribute) in self.fields_default.iter() {\n            let name = field.ident();\n            let ty = &field.field.ty;\n            let value = &attribute.value;\n            let docs = field.docs();\n            let default_doc = format!(\"- Defaults to `{}`\", quote!(#value));\n            let doc_str = format!(\n                \"Sets the value for the field [`{}`](Self::{0}).\\n\\n\",\n                quote!(#name)\n            );\n            let fn_docs = quote! {\n                #[doc = #doc_str]\n                #(#docs)*\n                #[doc = #default_doc]\n            };\n            let fn_name = Ident::new(&format!(\"with_{name}\"), name.span());\n\n            body.extend(quote! {\n                #fn_docs\n                pub fn #fn_name(mut self, #name: #ty) -> Self {\n                    self.#name = #name;\n                    self\n                }\n            });\n        }\n\n        for field in self.fields_option.iter() {\n            let name = field.ident();\n            let ty = &field.field.ty;\n            let docs = field.docs();\n            let default_doc = \"- Defaults to `None`\";\n            let doc_str = format!(\n                \"Sets the value for the field [`{}`](Self::{0}).\\n\\n\",\n                quote!(#name)\n            );\n            let fn_docs = quote! {\n                #[doc = #doc_str]\n                #(#docs)*\n                #[doc = #default_doc]\n            };\n            let fn_name = Ident::new(&format!(\"with_{name}\"), name.span());\n\n            body.extend(quote! {\n                #fn_docs\n                pub fn #fn_name(mut self, #name: #ty) -> Self {\n                    self.#name = #name;\n                    self\n                }\n            });\n        }\n\n        self.wrap_impl_block(body)\n    }\n\n    fn gen_serde_impl(&self) -> TokenStream {\n        let names = self.names();\n\n        let struct_name = self.serde_struct_ident();\n        let name_types = self.name_types(&names);\n        let struct_gen = self.gen_serde_struct(&name_types);\n\n        let serialize_gen = self.gen_serialize_fn(&struct_name, &struct_gen, &names);\n        let deserialize_gen = self.gen_deserialize_fn(&struct_name, &struct_gen, &names);\n\n        quote! {\n            #serialize_gen\n            #deserialize_gen\n        }\n    }\n\n    fn gen_clone_impl(&self) -> TokenStream {\n        let name = &self.name;\n        let names = self.names().into_iter().map(|name| {\n            let name = name.ident();\n            quote! { #name: self.#name.clone() }\n        });\n\n        quote! {\n            impl Clone for #name {\n                fn clone(&self) -> Self {\n                    Self {\n                        #(#names),*\n                    }\n                }\n            }\n\n        }\n    }\n\n    fn gen_display_impl(&self) -> TokenStream {\n        let name = &self.name;\n\n        quote! {\n            impl core::fmt::Display for #name {\n                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n                    f.write_str(&burn::config::config_to_json(self))\n                }\n            }\n        }\n    }\n\n    fn gen_config_impl(&self) -> TokenStream {\n        let name = &self.name;\n\n        quote! {\n            impl burn::config::Config for #name {\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/config/base.rs",
    "content": "use super::ConfigAnalyzerFactory;\nuse quote::quote;\n\npub(crate) fn derive_impl(item: &syn::DeriveInput) -> proc_macro::TokenStream {\n    let factory = ConfigAnalyzerFactory::new();\n    let analyzer = factory.create_analyzer(item);\n\n    let constructor = analyzer.gen_new_fn();\n    let builders = analyzer.gen_builder_fns();\n    let serde = analyzer.gen_serde_impl();\n    let clone = analyzer.gen_clone_impl();\n    let display = analyzer.gen_display_impl();\n    let config_impl = analyzer.gen_config_impl();\n\n    quote! {\n        #config_impl\n        #constructor\n        #builders\n        #serde\n        #clone\n        #display\n    }\n    .into()\n}\n"
  },
  {
    "path": "crates/burn-derive/src/config/mod.rs",
    "content": "mod analyzer;\nmod analyzer_enum;\nmod analyzer_struct;\nmod base;\n\npub(crate) use analyzer::*;\npub(crate) use analyzer_enum::*;\npub(crate) use analyzer_struct::*;\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-derive/src/lib.rs",
    "content": "#![warn(missing_docs)]\n\n//! The derive crate of Burn.\n\n#[macro_use]\nextern crate derive_new;\n\nuse proc_macro::TokenStream;\n\npub(crate) mod config;\npub(crate) mod module;\npub(crate) mod record;\npub(crate) mod shared;\n\n/// Derive macro for the `Module` trait.\n///\n/// # Sub-modules\n///\n/// By default, the macro automatically detects sub-modules and parameters as module types.\n///\n/// Any field not recognized as a module type is assumed to be a non-module\n/// and is skipped by the module system (not persistent, not visited).\n///\n/// ## Generics\n///\n/// Generic type parameters (e.g., `field: M`) are assumed to be sub-modules by default.\n/// If a generic field represents some other runtime state or configuration, you can use\n/// the `#[module(skip)]` attribute to provide a hint.\n///\n/// # Field Attributes\n///\n/// ## `#[module(skip)]`\n///\n/// Explicitly marks a field to be ignored by the module derive.\n///\n/// Skipped fields are not parameters, not modules, and are not persistent.\n/// This is equivalent to the deprecated `Ignored<T>` wrapper.\n///\n/// ### Requirements\n///\n/// The field must implement: `Debug + Clone + Send`.\n///\n/// # Example\n///\n/// ```ignore\n/// #[derive(Module, Debug)]\n/// pub struct MyModule<B: Backend, M, N: NonModuleTrait> {\n///     /// A normal parameter.\n///     weights: Param<Tensor<B, 2>>,\n///     /// A field configured at runtime.\n///     dropout_prob: f64,\n///     /// A field that is recomputed at runtime.\n///     cached_mask: Option<Tensor<B, 2>>,\n///     /// A field that contains some debug state.\n///     debug_state: String,\n///     /// Treated as a module (default for generics).\n///     inner: M,\n///     /// Hint required: this generic is NOT a module.\n///     #[module(skip)]\n///     other: N,\n/// }\n/// ```\n#[proc_macro_derive(Module, attributes(module))]\npub fn module_derive(input: TokenStream) -> TokenStream {\n    let input = syn::parse(input).unwrap();\n    module::derive_impl(&input)\n}\n\n/// Derive macro for the record.\n#[proc_macro_derive(Record)]\npub fn record_derive(input: TokenStream) -> TokenStream {\n    let input = syn::parse(input).unwrap();\n    record::derive_impl(&input)\n}\n\n/// Derive macro for the config.\n#[proc_macro_derive(Config, attributes(config))]\npub fn config_derive(input: TokenStream) -> TokenStream {\n    let item = syn::parse(input).unwrap();\n    config::derive_impl(&item)\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/base.rs",
    "content": "use super::{\n    codegen::{generate_module_const, generate_module_standard},\n    codegen_enum::EnumModuleCodegen,\n    codegen_struct::StructModuleCodegen,\n};\nuse proc_macro::TokenStream;\n\npub(crate) fn derive_impl(ast: &syn::DeriveInput) -> TokenStream {\n    let has_backend = ast\n        .generics\n        .type_params()\n        .map(|param| param.ident == \"B\")\n        .reduce(|accum, is_backend| is_backend || accum)\n        .unwrap_or(false);\n\n    match &ast.data {\n        syn::Data::Struct(_) => match StructModuleCodegen::from_ast(ast) {\n            Ok(struct_codegen) => {\n                if has_backend {\n                    generate_module_standard(ast, struct_codegen)\n                } else {\n                    generate_module_const(ast)\n                }\n            }\n            Err(err) => err.to_compile_error(),\n        },\n        syn::Data::Enum(_data) => match EnumModuleCodegen::from_ast(ast) {\n            Ok(enum_codegen) => {\n                if has_backend {\n                    generate_module_standard(ast, enum_codegen)\n                } else {\n                    generate_module_const(ast)\n                }\n            }\n            Err(err) => err.to_compile_error(),\n        },\n        syn::Data::Union(_) => {\n            syn::Error::new_spanned(ast, \"Union modules aren't supported\").to_compile_error()\n        }\n    }\n    .into()\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/codegen.rs",
    "content": "use super::{display, record::ModuleRecordCodegen};\nuse crate::{\n    module::generics::{GenericKind, ModuleGenerics},\n    shared::generics::GenericsHelper,\n};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Attribute, Generics, parse_quote};\n\n/// Basic trait to be implemented for Module generation.\npub(crate) trait ModuleCodegen {\n    type RecordCodegen: ModuleRecordCodegen;\n\n    fn gen_num_params(&self) -> TokenStream;\n    fn gen_visit(&self) -> TokenStream;\n    fn gen_collect_devices(&self) -> TokenStream;\n    fn gen_to_device(&self) -> TokenStream;\n    fn gen_fork(&self) -> TokenStream;\n    fn gen_map(&self) -> TokenStream;\n    fn gen_valid(&self) -> TokenStream;\n    fn gen_from_inner(&self) -> TokenStream;\n    fn gen_into_record(&self) -> TokenStream;\n    fn gen_load_record(&self) -> TokenStream;\n    fn gen_clone(&self) -> TokenStream;\n\n    fn record_codegen(self) -> Self::RecordCodegen;\n\n    fn gen_display(&self) -> TokenStream;\n\n    fn module_generics(&self) -> &ModuleGenerics;\n}\n\npub(crate) fn generate_module_standard<Codegen: ModuleCodegen>(\n    ast: &syn::DeriveInput,\n    codegen: Codegen,\n) -> TokenStream {\n    let name = &ast.ident;\n\n    let generics = GenericsParser::from_ast(&ast.generics, codegen.module_generics());\n\n    let display_fn = display::display_fn(ast);\n    let attributes_fn = codegen.gen_display();\n    let num_params_fn = codegen.gen_num_params();\n    let visit = codegen.gen_visit();\n    let map_mut = codegen.gen_map();\n    let collect_devices = codegen.gen_collect_devices();\n    let to_device = codegen.gen_to_device();\n    let fork = codegen.gen_fork();\n    let valid_fn = codegen.gen_valid();\n    let from_inner_fn = codegen.gen_from_inner();\n    let into_record_fn = codegen.gen_into_record();\n    let load_record_fn = codegen.gen_load_record();\n    let clone_fn = codegen.gen_clone();\n\n    let record = codegen.record_codegen();\n    let record_name = Ident::new(format!(\"{name}Record\").as_str(), name.span());\n    let (record_type, record_generics) = record.gen_record_type(&record_name, &generics.module);\n\n    let (generics_module, generics_ty_module, generics_where_module) =\n        generics.module.split_for_impl();\n    let (generics_module_autodiff, generics_ty_module_autodiff, generics_where_module_autodiff) =\n        generics.module_autodiff.split_for_impl();\n    let (generics_module_has_autodiff, _generics_ty, generics_where_module_has_autodiff) =\n        generics.module_has_autodiff.split_for_impl();\n    let (_, generics_ty_record, _) = record_generics.split_for_impl();\n\n    let generics_ty_inner_module = generics.inner_module_ty;\n    let generics_ty_train_module = generics.train_module_ty;\n    let generics_ty_train_inner_module = generics.train_inner_ty;\n\n    let mut codegen = quote! {\n        impl #generics_module burn::module::Module<B> for #name #generics_ty_module #generics_where_module {\n            type Record = #record_name #generics_ty_record;\n\n            #load_record_fn\n            #into_record_fn\n\n            #num_params_fn\n\n            #visit\n            #map_mut\n\n            #collect_devices\n            #to_device\n            #fork\n\n        }\n\n        impl #generics_module_autodiff burn::module::AutodiffModule<B> for #name #generics_ty_module_autodiff #generics_where_module_autodiff\n        {\n            type InnerModule=#name<B::InnerBackend, #generics_ty_inner_module>;\n\n            #valid_fn\n\n            #from_inner_fn\n        }\n\n        impl #generics_module_has_autodiff burn::module::HasAutodiffModule<B> for #name<B::InnerBackend, #generics_ty_train_module> #generics_where_module_has_autodiff\n        {\n            type TrainModule=#name<B, #generics_ty_train_inner_module>;\n        }\n\n        impl #generics_module core::fmt::Display for #name #generics_ty_module #generics_where_module {\n            #display_fn\n        }\n\n\n        impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module {\n            #attributes_fn\n\n            fn num_params(&self) -> usize {\n                burn::module::Module::num_params(self)\n            }\n        }\n\n        impl #generics_module Clone for #name #generics_ty_module #generics_where_module {\n            #clone_fn\n        }\n\n        #record_type\n    };\n\n    if !has_custom_display(&ast.attrs) {\n        codegen.extend(quote! {\n            impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module {\n\n            }\n        });\n    }\n\n    codegen\n}\n\n// TODO: wait that means nothing is persistent... (empty!)\n\n// When there is no backend in the generic parameter, the type is considered as a constant.\npub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream {\n    let name = &ast.ident;\n    let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();\n\n    let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};\n    let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::AutodiffBackend >};\n\n    let mut generics_module = ast.generics.clone();\n    let mut generics_module_autodiff = ast.generics.clone();\n\n    for param in backend.params.into_iter() {\n        generics_module.params.push(param);\n    }\n    for param in backend_ad.params.into_iter() {\n        generics_module_autodiff.params.push(param);\n    }\n    let (generics_module, _, _) = generics_module.split_for_impl();\n    let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl();\n\n    let display_fn = display::display_fn(ast);\n    let attributes_fn = display::attributes_fn(ast);\n\n    let mut codegen = quote! {\n        impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {\n            burn::empty!(module);\n        }\n\n        impl #generics_module_ad burn::module::AutodiffModule<B>\n            for #name #generics_ty #generics_where {\n            burn::empty!(ad_module, #name #generics_ty);\n        }\n\n        impl #generics core::fmt::Display for #name #generics_ty #generics_where {\n            #display_fn\n        }\n\n\n        impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where {\n            #attributes_fn\n        }\n\n    };\n\n    if !has_custom_display(&ast.attrs) {\n        codegen.extend(quote! {\n            impl  #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where {\n\n            }\n        });\n    }\n\n    codegen\n}\n\nstruct GenericsParser {\n    module: Generics,\n    module_autodiff: Generics,\n    module_has_autodiff: Generics,\n    inner_module_ty: TokenStream,\n    train_module_ty: TokenStream,\n    train_inner_ty: TokenStream,\n}\n\nimpl GenericsParser {\n    fn from_ast(generics: &Generics, module_generics: &ModuleGenerics) -> Self {\n        let mut module = GenericsHelper::new(generics.clone());\n        let mut module_autodiff = GenericsHelper::new(generics.clone());\n        let mut module_has_autodiff = GenericsHelper::new(generics.clone());\n\n        let backend_trait = module.fetch_backend_trait();\n\n        module_autodiff.add_predicate(parse_quote! {\n                B: burn::tensor::backend::AutodiffBackend\n        });\n\n        module_autodiff.add_predicate(parse_quote! {\n                <B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait\n        });\n\n        module_has_autodiff.add_predicate(parse_quote! {\n                B: burn::tensor::backend::AutodiffBackend\n        });\n\n        module_has_autodiff.add_predicate(parse_quote! {\n                <B as burn::tensor::backend::AutodiffBackend>::InnerBackend: #backend_trait\n        });\n\n        let mut generics_names_except_backend = quote! {};\n        let mut train_generics_names_except_backend = quote! {};\n        let mut train_inner_generics_names_except_backend = quote! {};\n\n        module\n        .types()\n        .into_iter()\n        .filter(|ident| ident != \"B\")\n        .for_each(|ident| {\n            // By default, require module bound\n            let mut requires_module_bound = true;\n            let mut generic_kind = None;\n            if !module_generics.is_empty() {\n                generic_kind = module_generics.get_generic_kind(&ident);\n                let has_module_bound = matches!(generic_kind, Some(GenericKind::Module));\n                let is_unbounded = matches!(generic_kind, Some(GenericKind::Plain));\n\n                requires_module_bound = has_module_bound || is_unbounded;\n            }\n\n            if requires_module_bound {\n                module.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::Module<B>\n                    }\n                );\n\n                module.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::ModuleDisplay\n                    }\n                );\n\n                module_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::AutodiffModule<B>\n                    }\n                );\n\n                module_autodiff.add_predicate(\n                    parse_quote! {\n                        <#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::Module<B::InnerBackend>\n                    }\n                );\n\n                module_autodiff.add_predicate(\n                    parse_quote! {\n                        <#ident as burn::module::AutodiffModule<B>>::InnerModule: burn::module::ModuleDisplay\n                    }\n                );\n\n                generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule<B>>::InnerModule, });\n\n                module_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::ModuleDisplay\n                    }\n                );\n\n                module_has_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::Module<B::InnerBackend>\n                    }\n                );\n\n                module_has_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::ModuleDisplay\n                    }\n                );\n\n                module_has_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident: burn::module::HasAutodiffModule<B>\n                    }\n                );\n\n                module_has_autodiff.add_predicate(\n                    parse_quote! {\n                        #ident::TrainModule: burn::module::ModuleDisplay\n                    }\n                );\n                train_generics_names_except_backend.extend(quote! { #ident, });\n                train_inner_generics_names_except_backend.extend(quote! { #ident::TrainModule, });\n            }\n            else {\n                // Add required bounds to impl\n                if let Some(GenericKind::Skip) = generic_kind {\n                    module.add_predicate(\n                        parse_quote! {\n                            #ident: Clone + core::fmt::Debug + Send\n                        }\n                    );\n                    module_autodiff.add_predicate(\n                        parse_quote! {\n                            #ident: Clone + core::fmt::Debug + Send\n                        }\n                    );\n                    module_has_autodiff.add_predicate(\n                        parse_quote! {\n                            #ident: Clone + core::fmt::Debug + Send\n                        }\n                    );\n                }\n\n                // Pass through\n                generics_names_except_backend.extend(quote! { #ident, });\n                train_generics_names_except_backend.extend(quote! { #ident, });\n                train_inner_generics_names_except_backend.extend(quote! { #ident, });\n            }\n\n        });\n\n        module.consts().into_iter().for_each(|ident| {\n            generics_names_except_backend.extend(quote! { #ident, });\n            train_generics_names_except_backend.extend(quote! { #ident, });\n            train_inner_generics_names_except_backend.extend(quote! { #ident, });\n        });\n\n        Self {\n            module: module.generics,\n            module_autodiff: module_autodiff.generics,\n            module_has_autodiff: module_has_autodiff.generics,\n            inner_module_ty: generics_names_except_backend,\n            train_module_ty: train_generics_names_except_backend,\n            train_inner_ty: train_inner_generics_names_except_backend,\n        }\n    }\n}\n\nfn has_custom_display(attrs: &[Attribute]) -> bool {\n    attrs.iter().any(|attr| {\n        attr.path().is_ident(\"module\")\n            && attr\n                .parse_nested_meta(|meta| {\n                    if meta.path.is_ident(\"custom_display\") {\n                        Ok(())\n                    } else {\n                        Err(meta.error(\"unsupported attribute\"))\n                    }\n                })\n                .is_ok()\n    })\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/codegen_enum.rs",
    "content": "use super::{codegen::ModuleCodegen, record_enum::EnumModuleRecordCodegen};\nuse crate::{\n    module::generics::{ModuleGenerics, parse_module_generics},\n    shared::enum_variant::{EnumVariant, parse_variants},\n};\nuse proc_macro2::{Ident, Span, TokenStream};\nuse quote::quote;\nuse syn::Visibility;\n\npub(crate) struct EnumModuleCodegen {\n    pub name: Ident,\n    pub variants: Vec<EnumVariant>,\n    pub vis: Visibility,\n    pub generics: ModuleGenerics,\n}\n\nimpl ModuleCodegen for EnumModuleCodegen {\n    type RecordCodegen = EnumModuleRecordCodegen;\n\n    fn gen_num_params(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|_| {\n            quote! {\n                burn::module::Module::<B>::num_params(module)\n            }\n        });\n\n        quote! {\n            fn num_params(&self) -> usize {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_visit(&self) -> TokenStream {\n        let enum_name = self.name.to_string();\n        let container_type = format!(\"Enum:{}\", enum_name);\n        let match_body = self.gen_variants_match_fn(|variant_name| {\n            let variant_str = variant_name.to_string();\n            quote! {\n                {\n                    visitor.enter_module(#variant_str, #container_type);\n                    burn::module::Module::visit(module, visitor);\n                    visitor.exit_module(#variant_str, #container_type);\n                }\n            }\n        });\n\n        quote! {\n            fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_collect_devices(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|_| {\n            quote! {\n                burn::module::Module::<B>::collect_devices(module, devices)\n            }\n        });\n\n        quote! {\n            fn collect_devices(\n                &self,\n                devices: burn::module::Devices<B>\n            ) -> burn::module::Devices<B> {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_to_device(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                Self::#variant(burn::module::Module::<B>::to_device(module, device))\n            }\n        });\n\n        quote! {\n            fn to_device(self, device: &B::Device) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_fork(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                Self::#variant(burn::module::Module::<B>::fork(module, device))\n            }\n        });\n\n        quote! {\n            fn fork(self, device: &B::Device) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_map(&self) -> TokenStream {\n        let enum_name = self.name.to_string();\n        let container_type = format!(\"Enum:{}\", enum_name);\n        let match_body = self.gen_variants_match_fn(|variant| {\n            let variant_str = variant.to_string();\n            quote! {\n                {\n                    mapper.enter_module(#variant_str, #container_type);\n                    let result = burn::module::Module::<B>::map(module, mapper);\n                    mapper.exit_module(#variant_str, #container_type);\n                    Self::#variant(result)\n                }\n            }\n        });\n\n        quote! {\n            fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_valid(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                Self::InnerModule::#variant(burn::module::AutodiffModule::<B>::valid(module))\n            }\n        });\n\n        quote! {\n            fn valid(&self) -> Self::InnerModule {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_from_inner(&self) -> TokenStream {\n        let match_body =\n            self.gen_variants_match_fn_param(\"module\", \"Self::InnerModule::\", |variant| {\n                quote! {\n                    Self::#variant(burn::module::AutodiffModule::<B>::from_inner(module))\n                }\n            });\n\n        quote! {\n            fn from_inner(module: Self::InnerModule) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_into_record(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                Self::Record::#variant(burn::module::Module::<B>::into_record(module))\n            }\n        });\n\n        quote! {\n            fn into_record(self) -> Self::Record {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_load_record(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                {\n                    let Self::Record::#variant(r) = record else {panic!(\"Can't parse record from a different variant\");};\n                    Self::#variant(burn::module::Module::<B>::load_record(module, r))\n                }\n            }\n        });\n\n        quote! {\n            fn load_record(self, record: Self::Record) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn gen_clone(&self) -> TokenStream {\n        let match_body = self.gen_variants_match_fn(|variant| {\n            quote! {\n                Self::#variant(module.clone())\n            }\n        });\n\n        quote! {\n            fn clone(&self) -> Self {\n                #match_body\n            }\n        }\n    }\n\n    fn record_codegen(self) -> Self::RecordCodegen {\n        EnumModuleRecordCodegen::new(self.variants, self.vis)\n    }\n\n    fn module_generics(&self) -> &ModuleGenerics {\n        &self.generics\n    }\n\n    fn gen_display(&self) -> TokenStream {\n        // Only tuple enum variants with exactly one field are supported\n        let variant_prints = self.variants.iter().map(|variant| {\n            let variant_name = &variant.ident;\n            let field_names =\n                (0..1).map(|i| syn::Ident::new(&format!(\"_{i}\"), proc_macro2::Span::call_site()));\n\n            let field_prints = field_names.clone().map(|field_name| {\n                quote! { .add(stringify!(#field_name), #field_name) }\n            });\n            quote! {\n                Self::#variant_name(#(#field_names),*) => {\n                    content.set_top_level_type(&stringify!(#variant_name))\n                    #(#field_prints)*\n                    .optional()\n                }\n            }\n        });\n        quote! {\n            fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {\n                match self {\n                    #(#variant_prints)*\n                }\n            }\n        }\n    }\n}\n\nimpl EnumModuleCodegen {\n    pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {\n        Ok(Self {\n            name: ast.ident.clone(),\n            variants: parse_variants(ast)?,\n            vis: ast.vis.clone(),\n            generics: parse_module_generics(&ast.generics),\n        })\n    }\n\n    /// Generate the enum variants' match arms with the provided function\n    fn gen_variants_match_fn<F>(&self, func: F) -> TokenStream\n    where\n        F: Fn(Ident) -> TokenStream,\n    {\n        self.gen_variants_match_fn_param(\"self\", \"Self::\", func)\n    }\n\n    /// Generate a match expression over the given argument (e.g., `self`)\n    /// and using the provided prefix for variants (e.g., `Self::` or `Self::InnerModule::`)\n    fn gen_variants_match_fn_param<F>(&self, arg: &str, prefix: &str, func: F) -> TokenStream\n    where\n        F: Fn(Ident) -> TokenStream,\n    {\n        let match_arms = self.variants.iter().map(|variant| {\n            let name = &variant.ident;\n            let full_variant = syn::parse_str::<syn::Path>(&format!(\"{prefix}{name}\")).unwrap();\n            let arm_pattern = quote! { #full_variant(module) };\n            let arm_code = func(name.clone());\n            quote! { #arm_pattern => #arm_code, }\n        });\n\n        let arg = Ident::new(arg, Span::call_site());\n\n        quote! {\n            match #arg {\n                #(#match_arms)*\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/codegen_struct.rs",
    "content": "use std::collections::HashSet;\n\nuse crate::module::generics::{\n    GenericKind, ModuleGenerics, parse_module_generics, parse_ty_generics,\n};\n\nuse super::{codegen::ModuleCodegen, record_struct::StructModuleRecordCodegen};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::{ToTokens, quote};\nuse syn::{Field, Visibility};\n\npub(crate) struct StructModuleCodegen {\n    pub name: Ident,\n    pub fields: Vec<ModuleField>,\n    pub vis: Visibility,\n    pub generics: ModuleGenerics,\n}\n\nimpl ModuleCodegen for StructModuleCodegen {\n    type RecordCodegen = StructModuleRecordCodegen;\n\n    fn gen_num_params(&self) -> TokenStream {\n        let body = self.gen_fields_fn(|name, field_type| {\n            if field_type.is_parameter_module() || field_type.maybe_generic_module() {\n                quote! {\n                    num_params += burn::module::Module::<B>::num_params(&self.#name);\n                }\n            } else {\n                quote! {} // other fields have 0 params\n            }\n        });\n\n        quote! {\n            fn num_params(&self) -> usize {\n                let mut num_params = 0;\n                #body\n                num_params\n            }\n        }\n    }\n\n    fn gen_visit(&self) -> TokenStream {\n        let struct_name = self.name.to_string();\n        let container_type = format!(\"Struct:{}\", struct_name);\n        let body = self.gen_fields_fn(|name, field_type| {\n            if field_type.is_parameter_module() || field_type.maybe_generic_module() {\n                let name_str = name.to_string();\n                quote! {\n                    visitor.enter_module(#name_str, #container_type);\n                    burn::module::Module::visit(&self.#name, visitor);\n                    visitor.exit_module(#name_str, #container_type);\n                }\n            } else {\n                quote! {}\n            }\n        });\n\n        quote! {\n            fn visit<Visitor: burn::module::ModuleVisitor<B>>(&self, visitor: &mut Visitor) {\n                #body\n            }\n        }\n    }\n\n    fn gen_collect_devices(&self) -> TokenStream {\n        let body = self.gen_fields_fn(|name, field_type| {\n            if field_type.is_module || field_type.maybe_generic_module() {\n                quote! {\n                    let devices = burn::module::Module::<B>::collect_devices(&self.#name, devices);\n                }\n            } else {\n                quote! {}\n            }\n        });\n\n        quote! {\n            fn collect_devices(\n                &self,\n                devices: burn::module::Devices<B>\n            ) -> burn::module::Devices<B> {\n                #body\n                devices\n            }\n        }\n    }\n\n    fn gen_to_device(&self) -> TokenStream {\n        let (names, body) = self.gen_fields_fn_names(|name, field_type| {\n            if field_type.is_module || field_type.maybe_generic_module() {\n                quote! {\n                    let #name = burn::module::Module::<B>::to_device(self.#name, device);\n                }\n            } else {\n                quote! { let #name = self.#name; }\n            }\n        });\n\n        quote! {\n            fn to_device(self, device: &B::Device) -> Self {\n                #body\n                Self { #(#names),* }\n            }\n        }\n    }\n\n    fn gen_fork(&self) -> TokenStream {\n        let (names, body) = self.gen_fields_fn_names(|name, field_type| {\n            if field_type.is_module || field_type.maybe_generic_module() {\n                quote! {\n                    let #name = burn::module::Module::<B>::fork(self.#name, device);\n                }\n            } else {\n                quote! { let #name = self.#name; }\n            }\n        });\n\n        quote! {\n            fn fork(self, device: &B::Device) -> Self {\n                #body\n                Self { #(#names),* }\n            }\n        }\n    }\n\n    fn gen_map(&self) -> TokenStream {\n        let struct_name = self.name.to_string();\n        let container_type = format!(\"Struct:{}\", struct_name);\n        let (names, body) = self.gen_fields_fn_names(|name, field_type| {\n            if field_type.is_parameter_module() || field_type.maybe_generic_module() {\n                let name_str = name.to_string();\n                quote! {\n                    mapper.enter_module(#name_str, #container_type);\n                    let #name = burn::module::Module::<B>::map(self.#name, mapper);\n                    mapper.exit_module(#name_str, #container_type);\n                }\n            } else {\n                quote! { let #name = self.#name; }\n            }\n        });\n\n        quote! {\n            fn map<Mapper: burn::module::ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self {\n                #body\n                Self { #(#names),* }\n            }\n        }\n    }\n\n    fn gen_valid(&self) -> TokenStream {\n        let (names, body) = self.gen_fields_fn_names(|name, field_type| {\n            if field_type.is_module || field_type.maybe_generic_module() {\n                quote! {\n                    let #name = burn::module::AutodiffModule::<B>::valid(&self.#name);\n                }\n            } else {\n                quote! { let #name = self.#name.clone(); }\n            }\n        });\n\n        quote! {\n            fn valid(&self) -> Self::InnerModule {\n                #body\n                Self::InnerModule { #(#names),* }\n            }\n        }\n    }\n\n    fn gen_from_inner(&self) -> TokenStream {\n        let (names, body) = self.gen_fields_fn_names(|name, field_type| {\n            if field_type.is_module || field_type.maybe_generic_module() {\n                quote! {\n                    let #name = burn::module::AutodiffModule::<B>::from_inner(#name);\n                }\n            } else {\n                quote! { let #name = #name; }\n            }\n        });\n\n        let destructure = quote! {\n            let Self::InnerModule { #(#names),* } = module;\n        };\n\n        quote! {\n            fn from_inner(module: Self::InnerModule) -> Self {\n                #destructure\n                #body\n                Self { #(#names),* }\n            }\n        }\n    }\n\n    fn gen_into_record(&self) -> TokenStream {\n        let body = self.gen_fields_fn(|name, field_type| {\n            if field_type.is_persistent_module() || field_type.maybe_generic_module() {\n                quote! { #name: burn::module::Module::<B>::into_record(self.#name), }\n            } else {\n                match field_type.attr {\n                    // Default (None) gets skipped\n                    None | Some(ModuleFieldAttribute::Skip) => {\n                        quote! { #name: burn::module::EmptyRecord::new(), }\n                    }\n                }\n            }\n        });\n\n        quote! {\n            fn into_record(self) -> Self::Record {\n                Self::Record { #body }\n            }\n        }\n    }\n\n    fn gen_load_record(&self) -> TokenStream {\n        let body = self.gen_fields_fn(|name, field_type| {\n            if field_type.is_persistent_module() || field_type.maybe_generic_module() {\n                quote! { #name: burn::module::Module::<B>::load_record(self.#name, record.#name), }\n            } else {\n                match field_type.attr {\n                    // Default (None) gets skipped\n                    None | Some(ModuleFieldAttribute::Skip) => {\n                        quote! { #name: self.#name, }\n                    }\n                }\n            }\n        });\n\n        quote! {\n            fn load_record(self, record: Self::Record) -> Self {\n                Self { #body }\n            }\n        }\n    }\n\n    fn gen_clone(&self) -> TokenStream {\n        let (names, body) = self.gen_fields_fn_names(|name, _field_type| {\n            quote! {\n                let #name = self.#name.clone();\n            }\n        });\n\n        quote! {\n            fn clone(&self) -> Self {\n                #body\n                Self { #(#names),* }\n            }\n        }\n    }\n\n    fn record_codegen(self) -> Self::RecordCodegen {\n        StructModuleRecordCodegen::new(self.fields, self.vis)\n    }\n\n    fn module_generics(&self) -> &ModuleGenerics {\n        &self.generics\n    }\n\n    fn gen_display(&self) -> TokenStream {\n        let struct_name = self.name.to_string();\n        let field_prints = self.fields.iter().map(|field| {\n            let field_name = field.ident();\n            if field.field_type.is_module || field.field_type.maybe_generic_module() {\n                // Standard module type, use underlying `ModuleDisplay` impl\n                quote! { .add(stringify!(#field_name), &self.#field_name) }\n            } else {\n                // Not a module, use the debug implementation\n                quote! {\n                    .add_debug_attribute(stringify!(#field_name), &self.#field_name)\n                }\n            }\n        });\n        quote! {\n            fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {\n                content\n                    .set_top_level_type(&stringify!(#struct_name))\n                    #(#field_prints)*\n                    .optional()\n            }\n        }\n    }\n}\n\nimpl StructModuleCodegen {\n    pub fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {\n        let mut generics = parse_module_generics(&ast.generics);\n        Ok(Self {\n            name: ast.ident.clone(),\n            fields: parse_module_fields(ast, &mut generics)?,\n            vis: ast.vis.clone(),\n            generics,\n        })\n    }\n\n    fn gen_fields_fn_names<F>(&self, func: F) -> (Vec<Ident>, TokenStream)\n    where\n        F: Fn(Ident, &ModuleFieldType) -> TokenStream,\n    {\n        let mut body = quote! {};\n        let mut names = Vec::new();\n\n        for field in self.fields.iter() {\n            let name = field.ident();\n\n            names.push(name.clone());\n            body.extend(func(name, &field.field_type));\n        }\n\n        (names, body)\n    }\n\n    fn gen_fields_fn<F>(&self, func: F) -> TokenStream\n    where\n        F: Fn(Ident, &ModuleFieldType) -> TokenStream,\n    {\n        let mut body = quote! {};\n\n        for field in self.fields.iter() {\n            body.extend(func(field.ident(), &field.field_type));\n        }\n\n        body\n    }\n}\n\n#[derive(new)]\npub struct ModuleField {\n    pub field: Field,\n    pub field_type: ModuleFieldType,\n}\n\nimpl ModuleField {\n    pub fn ident(&self) -> Ident {\n        self.field.ident.clone().unwrap()\n    }\n}\n\n#[derive(Debug)]\npub enum ModuleFieldAttribute {\n    Skip,\n}\n\n#[derive(Default, Debug)]\npub struct ModuleFieldType {\n    pub is_module: bool,\n    pub attr: Option<ModuleFieldAttribute>,\n    pub generic_idents: HashSet<Ident>,\n}\n\nimpl ModuleFieldType {\n    /// Returns true if the field is a module with parameters\n    /// (i.e., a real module that is neither skipped nor constant).\n    pub fn is_parameter_module(&self) -> bool {\n        self.is_module && self.attr.is_none()\n    }\n\n    /// Returns true for modules that should be persisted, including constants.\n    pub fn is_persistent_module(&self) -> bool {\n        self.is_module && !matches!(self.attr, Some(ModuleFieldAttribute::Skip))\n    }\n\n    /// Returns true for generic fields that are assumed to be modules.\n    pub fn maybe_generic_module(&self) -> bool {\n        // We assumed it might be a module generic if the field is not marked\n        // by any attributes (skip or constant)\n        !self.generic_idents.is_empty() && self.attr.is_none()\n    }\n}\n\npub(crate) fn parse_module_fields(\n    ast: &syn::DeriveInput,\n    generics: &mut ModuleGenerics,\n) -> syn::Result<Vec<ModuleField>> {\n    let mut fields = Vec::new();\n\n    match &ast.data {\n        syn::Data::Struct(struct_data) => {\n            for field in struct_data.fields.iter() {\n                let field_type = parse_module_field_type(field, generics)?;\n                fields.push(ModuleField::new(field.clone(), field_type));\n            }\n        }\n        syn::Data::Enum(_) => panic!(\"Only struct can be derived\"),\n        syn::Data::Union(_) => panic!(\"Only struct can be derived\"),\n    };\n    Ok(fields)\n}\n\npub(crate) fn parse_module_field_type(\n    field: &Field,\n    generics: &mut ModuleGenerics,\n) -> syn::Result<ModuleFieldType> {\n    let mut field_type = ModuleFieldType::default();\n\n    // Check for generics\n    let mut has_backend = false;\n    let mut has_module_bound = false;\n    let field_generics = parse_ty_generics(&field.ty, generics)\n        .into_iter()\n        .filter_map(|ident| {\n            if ident == \"B\" {\n                has_backend = true;\n                None\n            } else {\n                has_module_bound = generics.is_bounded_module(&ident);\n                Some(ident)\n            }\n        })\n        .collect::<HashSet<_>>();\n\n    // Infer if a field is a module\n    let is_primitive = is_primitive_type(&field.ty);\n    let is_param = is_param_type(&field.ty);\n    let is_tensor = is_tensor_type(&field.ty);\n\n    let is_module = !is_primitive && (has_module_bound || is_param || is_tensor || has_backend);\n\n    for attr in &field.attrs {\n        if attr.path().is_ident(\"module\") {\n            attr.parse_nested_meta(|meta| {\n                if meta.path.is_ident(\"skip\") {\n                    // Mark field attribute and generic\n                    field_type.attr = Some(ModuleFieldAttribute::Skip);\n                    for ty in &field_generics {\n                        generics.update(ty, GenericKind::Skip);\n                    }\n                    Ok(())\n                } else {\n                    let path = meta.path.to_token_stream().to_string();\n                    Err(meta.error(format!(\"Unsupported module attribute: {}\", path)))\n                }?;\n\n                if is_param && field_type.attr.is_some() {\n                    Err(meta.error(\"Fields of type 'Param' should not be marked as 'skip'. Use a 'Tensor' instead.\"))\n                } else {\n                    Ok(())\n                }\n            })?;\n        }\n    }\n\n    field_type.is_module = is_module;\n    field_type.generic_idents = field_generics;\n\n    Ok(field_type)\n}\n\nfn type_matches_ident(ty: &syn::Type, idents: &[&str]) -> bool {\n    if let syn::Type::Path(type_path) = ty {\n        // Look at the last segment of the path (e.g., 'Param' in 'burn::module::Param')\n        if let Some(segment) = type_path.path.segments.last() {\n            return idents.contains(&segment.ident.to_string().as_str());\n        }\n    }\n    false\n}\n\nfn is_primitive_type(ty: &syn::Type) -> bool {\n    type_matches_ident(\n        ty,\n        &[\n            \"bool\", \"u8\", \"u16\", \"u32\", \"u64\", \"usize\", \"i8\", \"i16\", \"i32\", \"i64\", \"isize\", \"f32\",\n            \"f64\", \"String\",\n        ],\n    )\n}\n\nfn is_tensor_type(ty: &syn::Type) -> bool {\n    type_matches_ident(ty, &[\"Tensor\"])\n}\n\nfn is_param_type(ty: &syn::Type) -> bool {\n    type_matches_ident(ty, &[\"Param\"])\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/display.rs",
    "content": "use quote::quote;\n\nuse crate::module::{codegen_struct::parse_module_field_type, generics::parse_module_generics};\n\n// Only used for \"const\" modules\npub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {\n    let mut generics = parse_module_generics(&ast.generics);\n    match &ast.data {\n        syn::Data::Struct(data_struct) => {\n            let fields = match &data_struct.fields {\n                syn::Fields::Named(named_fields) => named_fields.named.iter().collect::<Vec<_>>(),\n                syn::Fields::Unit => Vec::new(),\n                _ => panic!(\"attributes_fn only supports structs with named or unit fields\"),\n            };\n            let field_prints = fields.iter().map(|field| {\n                let field_name = &field.ident;\n                let field_type = parse_module_field_type(field, &mut generics).unwrap();\n                if field_type.is_module || field_type.maybe_generic_module() {\n                    // Standard module type, use underlying `ModuleDisplay` impl\n                    quote! { .add(stringify!(#field_name), &self.#field_name) }\n                } else {\n                    // Not a module, use the debug implementation\n                    quote! {\n                        .add_debug_attribute(stringify!(#field_name), &self.#field_name)\n                    }\n                }\n            });\n            let struct_name = &ast.ident;\n            quote! {\n                fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {\n                    content\n                        .set_top_level_type(&stringify!(#struct_name))\n                        #(#field_prints)*\n                        .optional()\n                }\n            }\n        }\n        syn::Data::Enum(data_enum) => {\n            let variant_prints = data_enum.variants.iter().map(|variant| {\n                let variant_name = &variant.ident;\n                match &variant.fields {\n                    syn::Fields::Unit => {\n                        quote! {\n                            Self::#variant_name => {\n                                content.add_formatted(&stringify!(#variant_name).to_string())\n                                    .optional()\n\n                            }\n                        }\n                    }\n                    syn::Fields::Named(named_fields) => {\n                        let field_prints = named_fields.named.iter().map(|field| {\n                            let field_name = &field.ident;\n                            quote! { .add(stringify!(#field_name), &self.#field_name) }\n                        });\n\n                        let field_names = named_fields.named.iter().map(|field| {\n                            let field_name = &field.ident;\n                            quote! { #field_name }\n                        });\n\n                        quote! {\n                            Self::#variant_name { #(#field_names),* } => {\n                                content.set_top_level_type(&stringify!(#variant_name))\n                                #(#field_prints)*\n                                .optional()\n                            }\n                        }\n                    }\n                    syn::Fields::Unnamed(unnamed_fields) => {\n                        let field_names = (0..unnamed_fields.unnamed.len()).map(|i| {\n                            syn::Ident::new(&format!(\"_{i}\"), proc_macro2::Span::call_site())\n                        });\n\n                        let field_prints = field_names.clone().map(|field_name| {\n                            quote! { .add(stringify!(#field_name), #field_name) }\n                        });\n                        quote! {\n                            Self::#variant_name(#(#field_names),*) => {\n                                content.set_top_level_type(&stringify!(#variant_name))\n                                #(#field_prints)*\n                                .optional()\n                            }\n                        }\n                    }\n                }\n            });\n            quote! {\n                fn content(&self, mut content: burn::module::Content) -> Option<burn::module::Content> {\n                    match self {\n                        #(#variant_prints)*\n                    }\n                }\n            }\n        }\n        _ => panic!(\"attributes_fn only supports structs and enums\"),\n    }\n}\n\npub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream {\n    quote! {\n        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n            let formatted = burn::module::ModuleDisplay::format(self, Default::default());\n            write!(f, \"{}\", formatted)\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/generics.rs",
    "content": "use std::collections::{HashMap, HashSet};\n\nuse proc_macro2::Ident;\nuse syn::{GenericParam, Generics, Type, TypeParamBound, WherePredicate, visit::Visit};\n\n#[derive(Debug)]\npub enum GenericKind {\n    /// A generic with `Module<B>` bound.\n    Module,\n    /// A generic used in a field marked by `#[module(skip)]`.\n    Skip,\n    /// A plain generic that does not fit any of the above conditions.\n    Plain,\n}\n\n#[derive(Debug)]\npub struct ModuleGenerics {\n    kinds: HashMap<Ident, GenericKind>,\n}\n\nimpl ModuleGenerics {\n    pub fn is_empty(&self) -> bool {\n        self.kinds.is_empty()\n    }\n\n    pub fn get_generic_kind(&self, ident: &Ident) -> Option<&GenericKind> {\n        self.kinds.get(ident)\n    }\n\n    pub fn is_bounded_module(&self, ident: &Ident) -> bool {\n        self.kinds\n            .get(ident)\n            .map(|kind| matches!(kind, GenericKind::Module))\n            .unwrap_or(false)\n    }\n\n    pub fn update(&mut self, ident: &Ident, kind: GenericKind) {\n        self.kinds.insert(ident.clone(), kind);\n    }\n\n    pub fn contains(&self, ident: &Ident) -> bool {\n        self.kinds.contains_key(ident)\n    }\n}\n\npub fn parse_module_generics(generics: &Generics) -> ModuleGenerics {\n    let mut kinds = HashMap::new();\n\n    // Check inline bounds e.g. `M: Module<B>`\n    for param in &generics.params {\n        if let GenericParam::Type(type_param) = param {\n            let ident = &type_param.ident;\n            if ident != \"B\" {\n                if has_module_bound(&type_param.bounds) {\n                    kinds.insert(ident.clone(), GenericKind::Module);\n                } else {\n                    kinds.insert(ident.clone(), GenericKind::Plain);\n                }\n            }\n        }\n    }\n\n    // Check `where` clauses\n    if let Some(where_clause) = &generics.where_clause {\n        for predicate in &where_clause.predicates {\n            if let WherePredicate::Type(pt) = predicate {\n                // We only care if the bounded type is a simple identifier (like 'M')\n                if let Type::Path(p) = &pt.bounded_ty\n                    && let Some(ident) = p.path.get_ident()\n                    && ident != \"B\"\n                {\n                    if has_module_bound(&pt.bounds) {\n                        kinds.insert(ident.clone(), GenericKind::Module);\n                    } else {\n                        kinds.insert(ident.clone(), GenericKind::Plain);\n                    }\n                }\n            }\n        }\n    }\n\n    ModuleGenerics { kinds }\n}\n\n// TODO: remove special cases for `ident == \"B\"`, this could be used to check for `Backend` bound.\n\n/// Helper to check if a list of bounds contains \"Module\".\nfn has_module_bound(\n    bounds: &syn::punctuated::Punctuated<TypeParamBound, syn::token::Plus>,\n) -> bool {\n    has_bound(bounds, \"Module\")\n}\n\n/// Helper to check if a list of bounds contains the specified bound.\nfn has_bound(\n    bounds: &syn::punctuated::Punctuated<TypeParamBound, syn::token::Plus>,\n    ident: &str,\n) -> bool {\n    bounds.iter().any(|bound| {\n        if let TypeParamBound::Trait(trait_bound) = bound\n            && let Some(segment) = trait_bound.path.segments.last()\n        {\n            return segment.ident == ident;\n        }\n        false\n    })\n}\n\npub fn parse_ty_generics(ty: &Type, declared: &ModuleGenerics) -> HashSet<Ident> {\n    struct Collector<'a> {\n        generics: HashSet<Ident>,\n        declared: &'a ModuleGenerics,\n    }\n\n    impl<'ast, 'a> Visit<'ast> for Collector<'a> {\n        fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {\n            if type_path.qself.is_none()\n                && let Some(ident) = type_path.path.get_ident()\n                && (self.declared.contains(ident) || ident == \"B\")\n            {\n                self.generics.insert(ident.clone());\n            }\n\n            syn::visit::visit_type_path(self, type_path);\n        }\n    }\n\n    let mut collector = Collector {\n        generics: HashSet::new(),\n        declared,\n    };\n    collector.visit_type(ty);\n    collector.generics\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/mod.rs",
    "content": "pub(crate) mod codegen;\npub(crate) mod codegen_enum;\npub(crate) mod codegen_struct;\npub(crate) mod display;\npub(crate) mod generics;\npub(crate) mod record;\npub(crate) mod record_enum;\npub(crate) mod record_struct;\n\nmod base;\n\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-derive/src/module/record.rs",
    "content": "use proc_macro2::{Ident, TokenStream};\nuse syn::Generics;\n\n/// Basic trait to generate a record type based on the Module struct.\npub(crate) trait ModuleRecordCodegen {\n    /// Generate the record type (i.e a struct)\n    fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics);\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/record_enum.rs",
    "content": "use crate::shared::enum_variant::EnumVariant;\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Generics, Visibility};\n\nuse super::record::ModuleRecordCodegen;\n\n#[derive(new)]\npub(crate) struct EnumModuleRecordCodegen {\n    variants: Vec<EnumVariant>,\n    vis: Visibility,\n}\n\nimpl ModuleRecordCodegen for EnumModuleRecordCodegen {\n    fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics) {\n        let mut variants = quote! {};\n        let vis = &self.vis;\n\n        // Capture the Record enum variant types\n        for variant in self.variants.iter() {\n            let ty = &variant.ty;\n            let name = &variant.ident;\n\n            variants.extend(quote! {\n                /// The module record associative type.\n                #name(<#ty as burn::module::Module<B>>::Record),\n            });\n        }\n\n        let (impl_generics, _generics_ty, generics_where) = generics.split_for_impl();\n\n        (\n            quote! {\n\n                /// The record type for the module.\n                #[derive(burn::record::Record)]\n                #vis enum #record_name #impl_generics #generics_where {\n                    #variants\n                }\n            },\n            generics.clone(),\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/module/record_struct.rs",
    "content": "use std::collections::HashSet;\n\nuse crate::module::codegen_struct::{ModuleField, ModuleFieldAttribute};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Generics, Visibility};\n\nuse super::record::ModuleRecordCodegen;\n\n#[derive(new)]\npub(crate) struct StructModuleRecordCodegen {\n    fields: Vec<ModuleField>,\n    vis: Visibility,\n}\n\nimpl ModuleRecordCodegen for StructModuleRecordCodegen {\n    fn gen_record_type(&self, record_name: &Ident, generics: &Generics) -> (TokenStream, Generics) {\n        let mut fields = quote! {};\n        let vis = &self.vis;\n\n        let mut used_generics = HashSet::new();\n\n        for field in self.fields.iter() {\n            let ty = &field.field.ty;\n            let name = &field.field.ident;\n\n            if field.field_type.is_persistent_module() || field.field_type.maybe_generic_module() {\n                fields.extend(quote! {\n                    /// The module record associative type.\n                    #vis #name: <#ty as burn::module::Module<B>>::Record,\n                });\n\n                used_generics.extend(&field.field_type.generic_idents);\n            } else {\n                match field.field_type.attr {\n                    // Default (None) gets skipped\n                    None | Some(ModuleFieldAttribute::Skip) => {\n                        fields.extend(quote! {\n                            #[allow(missing_docs)]\n                            #vis #name: burn::module::EmptyRecord,\n                        });\n\n                        // Do not capture generics from this field since it produces an empty record\n                    }\n                }\n            }\n        }\n\n        let mut filtered_generics = generics.clone();\n        filtered_generics.params = generics\n            .params\n            .iter()\n            .filter(|param| match param {\n                syn::GenericParam::Type(ty) if ty.ident == \"B\" => true,\n                syn::GenericParam::Type(ty) => used_generics.contains(&ty.ident),\n                _ => true,\n            })\n            .cloned()\n            .collect();\n\n        if let Some(where_clause) = &mut filtered_generics.where_clause {\n            where_clause.predicates = where_clause\n                .predicates\n                .iter()\n                .filter(|pred| {\n                    match pred {\n                        syn::WherePredicate::Type(ty) => {\n                            // Check if the bounded type is one of our remaining generics\n                            if let syn::Type::Path(p) = &ty.bounded_ty\n                                && let Some(ident) = p.path.get_ident()\n                            {\n                                return ident == \"B\" || used_generics.contains(ident);\n                            }\n                            true\n                        }\n                        _ => true,\n                    }\n                })\n                .cloned()\n                .collect();\n\n            // Remove the where clause entirely\n            if where_clause.predicates.is_empty() {\n                filtered_generics.where_clause = None;\n            }\n        }\n\n        let (impl_generics, _generics_ty, generics_where) = filtered_generics.split_for_impl();\n\n        (\n            quote! {\n\n                /// The record type for the module.\n                #[derive(burn::record::Record)]\n                #vis struct #record_name #impl_generics #generics_where {\n                    #fields\n                }\n            },\n            filtered_generics,\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/base.rs",
    "content": "use super::{\n    codegen::generate_record,\n    item::{codegen_enum::EnumRecordItemCodegen, codegen_struct::StructRecordItemCodegen},\n};\n\npub(crate) fn derive_impl(ast: &syn::DeriveInput) -> proc_macro::TokenStream {\n    match &ast.data {\n        syn::Data::Struct(_) => generate_record::<StructRecordItemCodegen>(ast),\n        syn::Data::Enum(_) => generate_record::<EnumRecordItemCodegen>(ast),\n        syn::Data::Union(_) => panic!(\"Union modules aren't supported yet.\"),\n    }\n    .into()\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/codegen.rs",
    "content": "use proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Generics, parse_quote};\n\nuse crate::record::item::codegen::RecordItemCodegen;\n\npub(crate) fn generate_record<G: RecordItemCodegen>(ast: &syn::DeriveInput) -> TokenStream {\n    let record_gen: syn::Result<RecordCodegen<G>> = RecordCodegen::from_ast(ast);\n    match record_gen {\n        Ok(record_gen) => {\n            let item_type = record_gen.gen_record_type();\n            let record_impl = record_gen.gen_impl_record();\n\n            quote! {\n                #item_type\n                #record_impl\n            }\n        }\n        Err(err) => err.to_compile_error(),\n    }\n}\n\npub(crate) struct RecordCodegen<G: RecordItemCodegen> {\n    /// Record type info.\n    ty: RecordType,\n    /// Record item code gen.\n    codegen: G,\n}\n\nimpl<G: RecordItemCodegen> RecordCodegen<G> {\n    /// Generate the record type with the correct generics.\n    pub(crate) fn gen_record_type(&self) -> TokenStream {\n        // Add precision settings type bound\n        let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};\n        let mut generics = self.ty.generics.clone();\n\n        for param in param.params.into_iter() {\n            generics.params.push(param);\n        }\n\n        // Generate the record item definition\n        self.codegen\n            .gen_item_type(&self.ty.item, &generics, self.ty.has_backend)\n    }\n\n    /// Generate the implementation for the Record trait.\n    pub(crate) fn gen_impl_record(&self) -> TokenStream {\n        // Capture the record type's generics and bounds in where clauses\n        let item_generics = self.record_item_generics();\n        let (_, ty_generics_item, _) = item_generics.split_for_impl();\n        let (impl_generics, ty_generics, where_clause) = self.ty.generics.split_for_impl();\n\n        let impl_generics = if let Some(impl_generic) = self.impl_generics() {\n            impl_generic\n        } else {\n            quote! { #impl_generics }\n        };\n\n        let name_item = &self.ty.item;\n        let into_item_fn = self.codegen.gen_into_item(name_item);\n        let from_item_fn = self.codegen.gen_from_item();\n\n        // Return the generated stream of token trees (i.e., code to be generated)\n        let name = &self.ty.name;\n        quote! {\n            impl #impl_generics burn::record::Record<B> for #name #ty_generics #where_clause {\n                type Item<S: burn::record::PrecisionSettings> = #name_item #ty_generics_item;\n\n                #into_item_fn\n                #from_item_fn\n\n            }\n        }\n    }\n\n    /// Add backend generic type to the implementation block.\n    fn impl_generics(&self) -> Option<TokenStream> {\n        if self.ty.has_backend {\n            return None;\n        }\n\n        let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };\n        let mut generics = self.ty.generics.clone();\n        generics.params.push(syn::GenericParam::Type(param));\n\n        let (impl_generics, _ty_generics, _where_clause) = generics.split_for_impl();\n\n        Some(quote! {#impl_generics})\n    }\n\n    /// Get the generics attached to the record item type.\n    fn record_item_generics(&self) -> Generics {\n        let param: syn::Generics = parse_quote! { <S: burn::record::PrecisionSettings >};\n        let mut generics = self.ty.generics.clone();\n        for param in param.params.into_iter() {\n            generics.params.push(param);\n        }\n\n        if !self.ty.has_backend {\n            let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };\n            generics.params.push(syn::GenericParam::Type(param));\n        }\n\n        generics\n    }\n\n    pub(crate) fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {\n        Ok(Self {\n            ty: RecordType::from_ast(ast),\n            codegen: G::from_ast(ast)?,\n        })\n    }\n}\n\n/// Information about a record type.\nstruct RecordType {\n    /// Record type name.\n    name: Ident,\n    /// Record item type name.\n    item: Ident,\n    /// Lifetimes and type parameters attached to the record type declaration.\n    generics: Generics,\n    /// Whether or not the record type should specify a backend generic.\n    has_backend: bool,\n}\n\nimpl RecordType {\n    fn from_ast(ast: &syn::DeriveInput) -> Self {\n        let name = ast.ident.clone();\n        let item = Ident::new(format!(\"{name}Item\").as_str(), name.span());\n        let has_backend = ast\n            .generics\n            .type_params()\n            .map(|param| param.ident == \"B\")\n            .reduce(|accum, is_backend| is_backend || accum)\n            .unwrap_or(false);\n\n        Self {\n            name,\n            item,\n            generics: ast.generics.clone(),\n            has_backend,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/item/codegen.rs",
    "content": "use proc_macro2::{Ident, TokenStream};\nuse syn::Generics;\n\n/// Basic trait to be implemented for record generation.\npub(crate) trait RecordItemCodegen {\n    /// Initialize the record item.\n    fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self>\n    where\n        Self: Sized;\n    /// Generate the record item type.\n    fn gen_item_type(\n        &self,\n        item_name: &Ident,\n        generics: &Generics,\n        has_backend: bool,\n    ) -> TokenStream;\n    /// Generate the into_item function.\n    fn gen_into_item(&self, item_name: &Ident) -> TokenStream;\n    /// Generate the from item function.\n    fn gen_from_item(&self) -> TokenStream;\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/item/codegen_enum.rs",
    "content": "use crate::shared::enum_variant::{EnumVariant, parse_variants};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Generics, Visibility, parse_quote};\n\nuse super::codegen::RecordItemCodegen;\n\npub(crate) struct EnumRecordItemCodegen {\n    /// Enum variants.\n    variants: Vec<EnumVariant>,\n    vis: Visibility,\n}\n\nimpl RecordItemCodegen for EnumRecordItemCodegen {\n    fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {\n        Ok(Self {\n            variants: parse_variants(ast)?,\n            vis: ast.vis.clone(),\n        })\n    }\n\n    fn gen_item_type(\n        &self,\n        item_name: &Ident,\n        generics: &Generics,\n        has_backend: bool,\n    ) -> TokenStream {\n        let mut variants = quote! {};\n        let mut serde_bounds = quote! {};\n        let mut clone_bounds = vec![];\n        let mut clone_match_arms = quote! {};\n        let vis = &self.vis;\n\n        // Capture the Record enum variant types and names to transpose them in RecordItem\n        for variant in self.variants.iter() {\n            let ty = &variant.ty;\n            let name = &variant.ident;\n\n            variants.extend(quote! {\n                /// Variant to be serialized.\n                #name(<#ty as burn::record::Record<B>>::Item<S>),\n            });\n\n            // Item types must implement serialization/deserialization\n            serde_bounds.extend(quote! {\n                <#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,\n            });\n            clone_bounds.push(parse_quote! {\n                <#ty as burn::record::Record<B>>::Item<S>: Clone\n            });\n\n            clone_match_arms.extend(quote! {\n                Self::#name(inner) => Self::#name(inner.clone()),\n            });\n        }\n        let serde_bound = serde_bounds.to_string();\n\n        // Capture the type's generics and bounds in where clauses\n        let mut generics = generics.clone();\n        if !has_backend {\n            let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };\n            generics.params.push(syn::GenericParam::Type(param));\n        }\n        let (generics, type_generics, generics_where) = generics.split_for_impl();\n\n        let clone_bounds = generics_where.cloned().map(|mut where_clause| {\n            for predicate in clone_bounds {\n                where_clause.predicates.push(predicate);\n            }\n            where_clause\n        });\n\n        let clone_impl = quote! {\n            impl #generics Clone for #item_name #type_generics #clone_bounds {\n                fn clone(&self) -> Self {\n                    match self {\n                        #clone_match_arms\n                    }\n                }\n            }\n        };\n\n        // Return the generated stream of token trees (i.e., code to be generated)\n        quote! {\n\n            /// The record item type for the module.\n            #[derive(burn::serde::Serialize, burn::serde::Deserialize)]\n            #[serde(crate = \"burn::serde\")]\n            #[serde(bound = #serde_bound)]\n            #vis enum #item_name #generics #generics_where {\n                #variants\n            }\n\n            #clone_impl\n        }\n    }\n\n    fn gen_into_item(&self, _item_name: &Ident) -> TokenStream {\n        let mut into_item_match_arms = quote! {};\n\n        for variant in self.variants.iter() {\n            let name = &variant.ident;\n\n            into_item_match_arms.extend(quote! {\n                Self::#name(record) => Self::Item::#name(burn::record::Record::<B>::into_item::<S>(record)),\n            });\n        }\n\n        quote! {\n            fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {\n                match self {\n                    #into_item_match_arms\n                }\n            }\n        }\n    }\n\n    fn gen_from_item(&self) -> TokenStream {\n        let mut from_item_match_arms = quote! {};\n\n        for variant in self.variants.iter() {\n            let name = &variant.ident;\n\n            from_item_match_arms.extend(quote! {\n                Self::Item::#name(item) => Self::#name(burn::record::Record::<B>::from_item::<S>(item, device)),\n            });\n        }\n\n        quote! {\n            fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n                match item {\n                    #from_item_match_arms\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/item/codegen_struct.rs",
    "content": "use crate::shared::field::{FieldTypeAnalyzer, parse_fields};\nuse proc_macro2::{Ident, TokenStream};\nuse quote::quote;\nuse syn::{Generics, Visibility, parse_quote};\n\nuse super::codegen::RecordItemCodegen;\n\npub(crate) struct StructRecordItemCodegen {\n    fields: Vec<FieldTypeAnalyzer>,\n    vis: Visibility,\n}\n\nimpl RecordItemCodegen for StructRecordItemCodegen {\n    fn from_ast(ast: &syn::DeriveInput) -> syn::Result<Self> {\n        Ok(Self {\n            fields: parse_fields(ast)\n                .into_iter()\n                .map(FieldTypeAnalyzer::new)\n                .collect(),\n            vis: ast.vis.clone(),\n        })\n    }\n\n    fn gen_item_type(\n        &self,\n        item_name: &Ident,\n        generics: &Generics,\n        has_backend: bool,\n    ) -> TokenStream {\n        let mut fields = quote! {};\n        let mut serde_bounds = quote! {};\n        let mut clone_bounds = vec![];\n        let mut clone_delegate = quote! {};\n        let vis = &self.vis;\n\n        for field in self.fields.iter() {\n            let ty = &field.field.ty;\n            let name = &field.field.ident;\n\n            fields.extend(quote! {\n                /// Field to be serialized.\n                pub #name: <#ty as burn::record::Record<B>>::Item<S>,\n            });\n\n            serde_bounds.extend(quote! {\n                <#ty as burn::record::Record<B>>::Item<S>: burn::serde::Serialize + burn::serde::de::DeserializeOwned,\n            });\n\n            clone_bounds.push(parse_quote! {\n                <#ty as burn::record::Record<B>>::Item<S>: Clone\n            });\n\n            clone_delegate.extend(quote! {\n                #name: self.#name.clone(),\n            });\n        }\n        let serde_bound = serde_bounds.to_string();\n\n        let mut generics = generics.clone();\n        if !has_backend {\n            let param: syn::TypeParam = parse_quote! { B: burn::tensor::backend::Backend };\n            generics.params.push(syn::GenericParam::Type(param));\n        }\n        let (generics, type_generics, generics_where) = generics.split_for_impl();\n\n        let clone_bounds = generics_where.cloned().map(|mut where_clause| {\n            for predicate in clone_bounds {\n                where_clause.predicates.push(predicate);\n            }\n            where_clause\n        });\n\n        let clone_impl = quote! {\n            impl #generics Clone for #item_name #type_generics #clone_bounds {\n                fn clone(&self) -> Self {\n                    Self {\n                        #clone_delegate\n                    }\n                }\n            }\n        };\n\n        quote! {\n\n            /// The record item type for the module.\n            #[derive(burn::serde::Serialize, burn::serde::Deserialize)]\n            #[serde(crate = \"burn::serde\")]\n            #[serde(bound = #serde_bound)]\n            #vis struct #item_name #generics #generics_where {\n                #fields\n            }\n\n            #clone_impl\n        }\n    }\n\n    fn gen_into_item(&self, item_name: &Ident) -> TokenStream {\n        let mut body_into_item = quote! {};\n\n        for field in self.fields.iter() {\n            let name = &field.field.ident;\n\n            body_into_item.extend(quote! {\n                #name: burn::record::Record::<B>::into_item::<S>(self.#name),\n            });\n        }\n\n        quote! {\n            fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {\n                #item_name {\n                    #body_into_item\n                }\n            }\n        }\n    }\n\n    fn gen_from_item(&self) -> TokenStream {\n        let mut body_from_item = quote! {};\n\n        for field in self.fields.iter() {\n            let name = &field.field.ident;\n\n            body_from_item.extend(quote! {\n                #name: burn::record::Record::<B>::from_item::<S>(item.#name, device),\n            });\n        }\n\n        quote! {\n            fn from_item<S: burn::record::PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n                Self {\n                    #body_from_item\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/record/item/mod.rs",
    "content": "pub(crate) mod codegen;\npub(crate) mod codegen_enum;\npub(crate) mod codegen_struct;\n"
  },
  {
    "path": "crates/burn-derive/src/record/mod.rs",
    "content": "pub(crate) mod codegen;\npub(crate) mod item;\n\nmod base;\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-derive/src/shared/attribute.rs",
    "content": "use syn::{Attribute, Meta};\n\npub struct AttributeAnalyzer {\n    attr: Attribute,\n}\n\n#[derive(Clone)]\npub struct AttributeItem {\n    pub value: syn::Lit,\n}\n\nimpl AttributeAnalyzer {\n    pub fn new(attr: Attribute) -> Self {\n        Self { attr }\n    }\n\n    pub fn item(&self) -> AttributeItem {\n        let value = match &self.attr.meta {\n            Meta::List(val) => val.parse_args::<syn::MetaNameValue>().unwrap(),\n            Meta::NameValue(meta) => meta.clone(),\n            Meta::Path(_) => panic!(\"Path meta unsupported\"),\n        };\n\n        let lit = match value.value {\n            syn::Expr::Lit(lit) => lit.lit,\n            _ => panic!(\"Only literal is supported\"),\n        };\n\n        AttributeItem { value: lit }\n    }\n\n    pub fn has_name(&self, name: &str) -> bool {\n        Self::path_syn_name(self.attr.path()) == name\n    }\n\n    fn path_syn_name(path: &syn::Path) -> String {\n        let length = path.segments.len();\n        let mut name = String::new();\n        for (i, segment) in path.segments.iter().enumerate() {\n            if i == length - 1 {\n                name += segment.ident.to_string().as_str();\n            } else {\n                let tmp = segment.ident.to_string() + \"::\";\n                name += tmp.as_str();\n            }\n        }\n        name\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/shared/enum_variant.rs",
    "content": "use proc_macro2::{Ident, Span, TokenStream};\nuse quote::quote;\nuse syn::{FieldsNamed, Variant};\n\n/// Process a variant of an enum where the output is the result of the given mapper.\npub(crate) fn map_enum_variant<Mapper>(\n    variant: &Variant,\n    mapper: Mapper,\n) -> (TokenStream, TokenStream)\nwhere\n    Mapper: Fn(&Ident) -> TokenStream,\n{\n    let gen_fields_unnamed = |num: usize| {\n        let mut inputs = Vec::new();\n        let mut outputs = Vec::new();\n\n        for i in 0..num {\n            let arg_name = Ident::new(&format!(\"arg_{i}\"), Span::call_site());\n            let input = quote! { #arg_name };\n            let output = mapper(&arg_name);\n\n            inputs.push(input);\n            outputs.push(output);\n        }\n\n        (quote! (( #(#inputs),* )), quote! (( #(#outputs),* )))\n    };\n    let gen_fields_named = |fields: &FieldsNamed| {\n        let mut inputs = Vec::new();\n        let mut outputs = Vec::new();\n\n        fields.named.iter().for_each(|field| {\n            let ident = field.ident.as_ref().expect(\"Named field to have a name.\");\n            let input = quote! { #ident };\n            let output = mapper(ident);\n\n            inputs.push(input);\n            outputs.push(quote! {\n                #ident: #output\n            });\n        });\n\n        (quote! {{ #(#inputs),* }}, quote! {{ #(#outputs),* }})\n    };\n\n    match &variant.fields {\n        syn::Fields::Named(fields) => gen_fields_named(fields),\n        syn::Fields::Unnamed(_) => gen_fields_unnamed(variant.fields.len()),\n        syn::Fields::Unit => (quote! {}, quote! {}),\n    }\n}\n\n/// An enum variant (simplified).\npub(crate) struct EnumVariant {\n    pub ident: syn::Ident,\n    pub ty: syn::Type,\n}\npub(crate) fn parse_variants(ast: &syn::DeriveInput) -> syn::Result<Vec<EnumVariant>> {\n    let enum_data = match &ast.data {\n        syn::Data::Enum(data) => data,\n        _ => {\n            return Err(syn::Error::new_spanned(\n                ast,\n                \"Module can only be derived for enums.\",\n            ));\n        }\n    };\n\n    let mut variants = Vec::new();\n\n    for variant in enum_data.variants.iter() {\n        for attr in &variant.attrs {\n            if attr.path().is_ident(\"module\") {\n                Err(syn::Error::new_spanned(\n                    variant,\n                    \"Module attributes are not supported for enum variants.\",\n                ))?;\n            }\n        }\n\n        match &variant.fields {\n            syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {\n                let field = &fields.unnamed[0];\n\n                variants.push(EnumVariant {\n                    ident: variant.ident.clone(),\n                    ty: field.ty.clone(),\n                });\n            }\n            syn::Fields::Unnamed(_) => {\n                return Err(syn::Error::new_spanned(\n                    variant,\n                    \"Module derive only supports tuple enum variants with exactly one field.\",\n                ));\n            }\n            syn::Fields::Named(_) => {\n                return Err(syn::Error::new_spanned(\n                    variant,\n                    \"Module derive does not support struct enum variants.\",\n                ));\n            }\n            syn::Fields::Unit => {\n                return Err(syn::Error::new_spanned(\n                    variant,\n                    \"Module derive does not support unit enum variants.\",\n                ));\n            }\n        }\n    }\n\n    Ok(variants)\n}\n"
  },
  {
    "path": "crates/burn-derive/src/shared/field.rs",
    "content": "use super::attribute::AttributeAnalyzer;\nuse proc_macro2::Ident;\nuse syn::{Field, Type, TypePath};\n\n#[derive(Clone)]\npub struct FieldTypeAnalyzer {\n    pub field: Field,\n}\n\nimpl FieldTypeAnalyzer {\n    pub fn new(field: Field) -> Self {\n        FieldTypeAnalyzer { field }\n    }\n\n    pub fn ident(&self) -> Ident {\n        self.field.ident.clone().unwrap()\n    }\n\n    pub fn is_of_type(&self, paths: &[&str]) -> bool {\n        match &self.field.ty {\n            syn::Type::Path(path) => {\n                let name = Self::path_name(path);\n                paths.contains(&name.as_str())\n            }\n            _ => false,\n        }\n    }\n\n    #[allow(dead_code)]\n    pub fn first_generic_field(&self) -> TypePath {\n        let err = || panic!(\"Field {} as no generic\", self.field.ident.clone().unwrap());\n        match &self.field.ty {\n            syn::Type::Path(path) => Self::path_generic_argument(path),\n            _ => err(),\n        }\n    }\n    pub fn path_generic_argument(path: &TypePath) -> TypePath {\n        let segment = path.path.segments.last().unwrap();\n        let err = || panic!(\"Path segment {} has no generic\", segment.ident.clone(),);\n        match &segment.arguments {\n            syn::PathArguments::None => err(),\n            syn::PathArguments::AngleBracketed(param) => {\n                let first_param = param.args.first().unwrap();\n\n                if let syn::GenericArgument::Type(Type::Path(path)) = first_param {\n                    path.clone()\n                } else {\n                    err()\n                }\n            }\n            syn::PathArguments::Parenthesized(_) => err(),\n        }\n    }\n\n    fn path_name(path: &TypePath) -> String {\n        let length = path.path.segments.len();\n        let mut name = String::new();\n        for (i, segment) in path.path.segments.iter().enumerate() {\n            if i == length - 1 {\n                name += segment.ident.to_string().as_str();\n            } else {\n                let tmp = segment.ident.to_string() + \"::\";\n                name += tmp.as_str();\n            }\n        }\n        name\n    }\n\n    /// Returns the docs of the field.\n    pub fn docs(&self) -> impl Iterator<Item = &syn::Attribute> {\n        self.field\n            .attrs\n            .iter()\n            .filter(|attr| attr.path().is_ident(\"doc\"))\n    }\n\n    pub fn attributes(&self) -> impl Iterator<Item = AttributeAnalyzer> {\n        self.field\n            .attrs\n            .clone()\n            .into_iter()\n            .map(AttributeAnalyzer::new)\n    }\n}\n\npub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {\n    let mut fields = Vec::new();\n\n    match &ast.data {\n        syn::Data::Struct(struct_data) => {\n            for field in struct_data.fields.iter() {\n                fields.push(field.clone());\n            }\n        }\n        syn::Data::Enum(_) => panic!(\"Only struct can be derived\"),\n        syn::Data::Union(_) => panic!(\"Only struct can be derived\"),\n    };\n    fields\n}\n"
  },
  {
    "path": "crates/burn-derive/src/shared/generics.rs",
    "content": "use proc_macro2::Ident;\nuse quote::quote;\nuse syn::{Generics, WhereClause, WherePredicate, parse_quote};\n\n#[derive(new)]\npub struct GenericsHelper {\n    pub(crate) generics: Generics,\n}\n\nimpl GenericsHelper {\n    pub fn add_predicate(&mut self, predicate: WherePredicate) {\n        let where_clause: WhereClause = match &self.generics.where_clause {\n            Some(val) => parse_quote! {\n                #val\n                    #predicate,\n            },\n            None => parse_quote! {\n                where\n                    #predicate,\n            },\n        };\n        self.generics.where_clause = Some(where_clause);\n    }\n\n    pub fn consts(&self) -> Vec<Ident> {\n        self.generics\n            .const_params()\n            .map(|c| c.ident.clone())\n            .collect()\n    }\n\n    pub fn types(&self) -> Vec<Ident> {\n        self.generics\n            .type_params()\n            .map(|tp| tp.ident.clone())\n            .collect()\n    }\n\n    pub fn fetch_backend_trait(&self) -> proc_macro2::TokenStream {\n        static BACKEND_TRAIT_COMPILATION_ERROR_MSG: &str =\n            \"Modules should be generic over a backend.\n    - The generic argument named `B` should have its first trait bound being a backend trait.\n    - The default backend trait is `burn::tensor::backend::Backend`.\n    - Any backend trait is supported.\";\n\n        for param in self.generics.params.iter() {\n            if let syn::GenericParam::Type(ty) = &param\n                && ty.ident == \"B\"\n            {\n                let bound = ty\n                    .bounds\n                    .first()\n                    .expect(BACKEND_TRAIT_COMPILATION_ERROR_MSG);\n\n                return quote! {\n                    #bound\n                };\n            }\n        }\n\n        panic!(\"{BACKEND_TRAIT_COMPILATION_ERROR_MSG}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-derive/src/shared/mod.rs",
    "content": "pub(crate) mod attribute;\npub(crate) mod enum_variant;\npub(crate) mod field;\npub(crate) mod generics;\n"
  },
  {
    "path": "crates/burn-dispatch/Cargo.toml",
    "content": "[package]\nauthors = [\n    \"laggui <lagrange.guillaume.1@gmail.com>\",\n    \"nathanielsimard <nathaniel.simard.42@gmail.com>\",\n]\ncategories = [\"science\"]\ndescription = \"Backend dispatch for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-dispatch\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-dispatch\"\ndocumentation = \"https://docs.rs/burn-dispatch\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"std\",\n    \"ndarray\",\n    \"burn-autodiff?/default\",\n    \"burn-cpu?/default\",\n    \"burn-cuda?/default\",\n    \"burn-ndarray?/default\",\n    \"burn-rocm?/default\",\n    \"burn-tch?/default\",\n    \"burn-wgpu?/default\",\n]\ndoc = [\"default\"]\nstd = [\n    \"burn-backend/std\",\n    \"burn-std/std\",\n    \"burn-autodiff?/std\",\n    \"burn-cpu?/std\",\n    \"burn-cuda?/std\",\n    \"burn-ndarray?/std\",\n    \"burn-rocm?/std\",\n    \"burn-tch?/std\",\n    \"burn-wgpu?/std\",\n]\ntracing = [\n    \"burn-autodiff?/tracing\",\n    \"burn-cpu?/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-ndarray?/tracing\",\n    \"burn-rocm?/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-wgpu?/tracing\",\n]\n\n# Backends\ncuda = [\"burn-cuda\"]\nrocm = [\"burn-rocm\"]\nndarray = [\"burn-ndarray\"]\ntch = [\"burn-tch\"]\nvulkan = [\"wgpu\", \"burn-wgpu/vulkan\"]\nwebgpu = [\"wgpu\", \"burn-wgpu/webgpu\"]\nmetal = [\"wgpu\", \"burn-wgpu/metal\"]\nwgpu = [\"burn-wgpu\"]\ncpu = [\"burn-cpu\"]\nautodiff = [\"burn-autodiff\"]\n\n# Backend features\nautotune = [\n    \"burn-wgpu?/autotune\",\n    \"burn-cuda?/autotune\",\n    \"burn-rocm?/autotune\",\n    \"burn-cpu?/autotune\",\n]\nautotune-checks = [\n    \"burn-wgpu?/autotune-checks\",\n    \"burn-cuda?/autotune-checks\",\n    \"burn-rocm?/autotune-checks\",\n    \"burn-cpu?/autotune-checks\",\n]\nfusion = [\n    \"burn-wgpu?/fusion\",\n    \"burn-cuda?/fusion\",\n    \"burn-rocm?/fusion\",\n    \"burn-cpu?/fusion\",\n]\n\n[dependencies]\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\n\n# Backends\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-cpu = { path = \"../burn-cpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\n# Op macros with `.as_$inner_kind()`\npaste = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-dispatch/README.md",
    "content": "# Burn Backend Dispatch\n\nA multi-backend dispatch that forwards the tensor operations to the appropriate backend.\n"
  },
  {
    "path": "crates/burn-dispatch/build.rs",
    "content": "fn main() {\n    println!(\"cargo::rustc-check-cfg=cfg(wgpu_metal)\");\n    println!(\"cargo::rustc-check-cfg=cfg(wgpu_vulkan)\");\n    println!(\"cargo::rustc-check-cfg=cfg(wgpu_webgpu)\");\n\n    // Detect which single wgpu backend is enabled\n    let metal = cfg!(feature = \"metal\");\n    let vulkan = cfg!(feature = \"vulkan\");\n    let webgpu = cfg!(feature = \"webgpu\");\n    let enabled = [(metal, \"metal\"), (vulkan, \"vulkan\"), (webgpu, \"webgpu\")]\n        .iter()\n        .filter(|x| x.0)\n        .map(|x| x.1)\n        .collect::<Vec<_>>();\n\n    // WGPU features are mutually exclusive, but we don't want to workspace to throw a compile error.\n    // In workspace builds with multiple features, we emit a warning and disable all WGPU backends.\n    if enabled.len() > 1 {\n        println!(\n            \"cargo:warning=Only one WGPU backend can be enabled at once. Detected: [{}]. No WGPU backend will be available in this build. This is expected in workspace builds. For production, enable only one of: metal, vulkan, or webgpu.\",\n            enabled.join(\", \")\n        );\n        return;\n    }\n\n    if metal {\n        println!(\"cargo:rustc-cfg=wgpu_metal\");\n    }\n    if vulkan {\n        println!(\"cargo:rustc-cfg=wgpu_vulkan\");\n    }\n    if webgpu {\n        println!(\"cargo:rustc-cfg=wgpu_webgpu\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/backend.rs",
    "content": "use alloc::format;\nuse alloc::string::String;\n\nuse burn_backend::Backend;\nuse burn_backend::ExecutionError;\nuse burn_std::DType;\n\n#[cfg(feature = \"autodiff\")]\nuse burn_autodiff::grads::Gradients;\n#[cfg(feature = \"autodiff\")]\nuse burn_backend::AutodiffBackend;\n\nuse crate::DispatchTensorKind;\nuse crate::backends::*;\nuse crate::{DispatchDevice, DispatchTensor};\n\n/// The main execution backend in Burn.\n///\n/// [`Dispatch`] acts as a global backend that can manage multiple underlying\n/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).  \n/// It is responsible for:\n/// - Dispatching tensor operations to the appropriate backend.\n/// - Managing cross-backend tensor transfers.\n///\n/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations\n/// in a backend-agnostic way. It allows Burn to provide a unified, global backend\n/// for users while still leveraging multiple specialized backends under the hood.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn::Dispatch;\n/// use burn::DispatchDevice;\n///\n/// // Select the device to execute operations on\n/// let device = DispatchDevice::Cuda(Default::default());\n///\n/// // Create a tensor using the global backend\n/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);\n/// ```\n#[derive(Debug, Default, Clone)]\npub struct Dispatch;\n\nimpl Backend for Dispatch {\n    type Device = DispatchDevice;\n\n    type FloatTensorPrimitive = DispatchTensor;\n\n    // TODO: either allow default dtype generic or remove associated types entirely?\n    type FloatElem = f32;\n\n    type IntTensorPrimitive = DispatchTensor;\n\n    type IntElem = i32;\n\n    type BoolTensorPrimitive = DispatchTensor;\n\n    type BoolElem = u8;\n\n    type QuantizedTensorPrimitive = DispatchTensor;\n\n    fn name(device: &Self::Device) -> String {\n        let inner = dispatch_device!(device, |device| B::name(device));\n        format!(\"dispatch<{inner}>\")\n    }\n\n    fn seed(device: &Self::Device, seed: u64) {\n        dispatch_device!(device, |device| B::seed(device, seed))\n    }\n\n    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {\n        dispatch_device!(device, |device| B::sync(device))\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {\n        dispatch_device!(device, |device| B::dtype_usage(device, dtype))\n    }\n\n    fn ad_enabled(device: &Self::Device) -> bool {\n        match device {\n            #[cfg(feature = \"autodiff\")]\n            DispatchDevice::Autodiff(_) => true,\n            _ => false,\n        }\n    }\n}\n\n#[cfg(feature = \"autodiff\")]\nimpl AutodiffBackend for Dispatch {\n    type InnerBackend = Dispatch;\n\n    type Gradients = Gradients;\n\n    fn backward(tensor: DispatchTensor) -> Self::Gradients {\n        let DispatchTensor { kind, .. } = tensor;\n        match kind {\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(tensor) => match *tensor {\n                #[cfg(feature = \"cpu\")]\n                DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),\n                #[cfg(feature = \"cuda\")]\n                DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),\n                #[cfg(wgpu_metal)]\n                DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),\n                #[cfg(feature = \"rocm\")]\n                DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),\n                #[cfg(wgpu_vulkan)]\n                DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),\n                #[cfg(wgpu_webgpu)]\n                DispatchTensorKind::WebGpu(tensor) => tensor.autodiff().backward(),\n                #[cfg(feature = \"ndarray\")]\n                DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(),\n                DispatchTensorKind::Autodiff(_) => {\n                    panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n            _ => panic!(\"Requires autodiff tensor.\"),\n        }\n    }\n\n    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {\n        let DispatchTensor {\n            kind,\n            checkpointing,\n        } = tensor;\n        let grad = match &kind {\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {\n                #[cfg(feature = \"cpu\")]\n                DispatchTensorKind::Cpu(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"cuda\")]\n                DispatchTensorKind::Cuda(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_metal)]\n                DispatchTensorKind::Metal(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"rocm\")]\n                DispatchTensorKind::Rocm(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_vulkan)]\n                DispatchTensorKind::Vulkan(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_webgpu)]\n                DispatchTensorKind::WebGpu(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"ndarray\")]\n                DispatchTensorKind::NdArray(tensor) => tensor\n                    .as_autodiff()\n                    .grad(grads)\n                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),\n                DispatchTensorKind::Autodiff(_) => {\n                    panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n            _ => panic!(\"Requires autodiff tensor.\"),\n        };\n        grad.map(|kind| DispatchTensor {\n            kind,\n            checkpointing: *checkpointing,\n        })\n    }\n\n    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {\n        let DispatchTensor {\n            kind,\n            checkpointing,\n        } = tensor;\n        let grad = match &kind {\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {\n                #[cfg(feature = \"cpu\")]\n                DispatchTensorKind::Cpu(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"cuda\")]\n                DispatchTensorKind::Cuda(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_metal)]\n                DispatchTensorKind::Metal(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"rocm\")]\n                DispatchTensorKind::Rocm(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_vulkan)]\n                DispatchTensorKind::Vulkan(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),\n                #[cfg(wgpu_webgpu)]\n                DispatchTensorKind::WebGpu(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),\n                #[cfg(feature = \"ndarray\")]\n                DispatchTensorKind::NdArray(tensor) => tensor\n                    .as_autodiff()\n                    .grad_remove(grads)\n                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),\n                DispatchTensorKind::Autodiff(_) => {\n                    panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n            _ => panic!(\"Requires autodiff tensor.\"),\n        };\n        grad.map(|kind| DispatchTensor {\n            kind,\n            checkpointing: *checkpointing,\n        })\n    }\n\n    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {\n        let DispatchTensor {\n            kind,\n            checkpointing,\n        } = tensor;\n        let DispatchTensor {\n            kind: grad,\n            checkpointing: grad_ckp,\n        } = grad;\n        debug_assert_eq!(checkpointing, &grad_ckp);\n\n        match &kind {\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) {\n                #[cfg(feature = \"cpu\")]\n                (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(feature = \"cuda\")]\n                (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(wgpu_metal)]\n                (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(feature = \"rocm\")]\n                (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(wgpu_vulkan)]\n                (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(wgpu_webgpu)]\n                (DispatchTensorKind::WebGpu(tensor), DispatchTensorKind::WebGpu(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                #[cfg(feature = \"ndarray\")]\n                (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => {\n                    tensor.as_autodiff().grad_replace(grads, grad.float())\n                }\n                (DispatchTensorKind::Autodiff(_), _) => {\n                    panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n                (t, g) => panic!(\n                    \"The provided tensors are not on the same backend. Got backends {t:?} and {g:?}.\"\n                ),\n            },\n            _ => panic!(\"Requires autodiff tensor.\"),\n        }\n    }\n\n    fn inner(tensor: DispatchTensor) -> DispatchTensor {\n        let DispatchTensor {\n            kind,\n            checkpointing,\n        } = tensor;\n\n        let kind = match kind {\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind {\n                #[cfg(feature = \"cpu\")]\n                DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(feature = \"cuda\")]\n                DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(wgpu_metal)]\n                DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(feature = \"rocm\")]\n                DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(wgpu_vulkan)]\n                DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(wgpu_webgpu)]\n                DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::WebGpu(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                #[cfg(feature = \"ndarray\")]\n                DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray(\n                    crate::BackendTensor::Float(tensor.autodiff().primitive),\n                ),\n                DispatchTensorKind::Autodiff(_) => {\n                    panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n            _ => panic!(\"Requires autodiff tensor.\"),\n        };\n        DispatchTensor {\n            kind,\n            checkpointing,\n        }\n    }\n\n    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n\n    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n\n    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n\n    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {\n        let DispatchTensor {\n            kind,\n            checkpointing,\n        } = tensor;\n\n        let kind = match kind {\n            #[cfg(feature = \"cpu\")]\n            DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff(\n                    Autodiff::<Cpu<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(feature = \"cuda\")]\n            DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff(\n                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(wgpu_metal)]\n            DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::Metal(crate::BackendTensor::Autodiff(\n                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(feature = \"rocm\")]\n            DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff(\n                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(wgpu_vulkan)]\n            DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff(\n                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(wgpu_webgpu)]\n            DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::WebGpu(crate::BackendTensor::Autodiff(\n                    Autodiff::<WebGpu<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            #[cfg(feature = \"ndarray\")]\n            DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new(\n                DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff(\n                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),\n                )),\n            )),\n            DispatchTensorKind::Autodiff(_) => {\n                panic!(\"Autodiff should not wrap an autodiff tensor.\")\n            }\n        };\n        DispatchTensor {\n            kind,\n            checkpointing,\n        }\n    }\n\n    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n\n    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n\n    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {\n        tensor\n    }\n}\n\nimpl DispatchTensorKind {\n    pub(crate) fn device(&self) -> DispatchDevice {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),\n            #[cfg(feature = \"cuda\")]\n            DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),\n            #[cfg(wgpu_metal)]\n            DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),\n            #[cfg(feature = \"rocm\")]\n            DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),\n            #[cfg(wgpu_vulkan)]\n            DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),\n            #[cfg(wgpu_webgpu)]\n            DispatchTensorKind::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),\n            #[cfg(feature = \"ndarray\")]\n            DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),\n            #[cfg(feature = \"tch\")]\n            DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),\n            #[cfg(feature = \"autodiff\")]\n            DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),\n        }\n    }\n}\n\nimpl DispatchTensor {\n    pub(crate) fn device(&self) -> DispatchDevice {\n        self.kind.device()\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/device.rs",
    "content": "use burn_backend::{DeviceId, DeviceOps};\n\nuse crate::backends::*;\n\n/// Represents a device for the [`Dispatch`](crate::Dispatch).\n///\n/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn::DispatchDevice;\n///\n/// #[cfg(feature = \"cpu\")]\n/// let cpu_device = DispatchDevice::Cpu(Default::default());\n///\n/// #[cfg(feature = \"cuda\")]\n/// let cuda_device = DispatchDevice::Cuda(Default::default());\n/// ```\n#[derive(Clone, Eq)]\npub enum DispatchDevice {\n    /// The [CPU backend](Cpu) device.\n    #[cfg(feature = \"cpu\")]\n    Cpu(CpuDevice),\n\n    /// The [CUDA backend](Cuda) device.\n    #[cfg(feature = \"cuda\")]\n    Cuda(CudaDevice),\n\n    /// The [Metal backend](Metal) device (via WGPU runtime).\n    #[cfg(wgpu_metal)]\n    Metal(WgpuDevice),\n\n    /// The [ROCm backend](Rocm) device.\n    #[cfg(feature = \"rocm\")]\n    Rocm(RocmDevice),\n\n    /// The [Vulkan backend](Vulkan) device.\n    #[cfg(wgpu_vulkan)]\n    Vulkan(WgpuDevice),\n\n    /// The [WebGPU backend](WebGpu) device (via WGPU runtime).\n    #[cfg(wgpu_webgpu)]\n    WebGpu(WgpuDevice),\n\n    /// The [NdArray backend](NdArray) device (CPU-only).\n    #[cfg(feature = \"ndarray\")]\n    NdArray(NdArrayDevice),\n\n    /// The [LibTorch backend](LibTorch) device.\n    #[cfg(feature = \"tch\")]\n    LibTorch(LibTorchDevice),\n\n    /// The [autodiff enabled backend](Autodiff) device.\n    #[cfg(feature = \"autodiff\")]\n    Autodiff(AutodiffDevice),\n}\n\n#[cfg(feature = \"autodiff\")]\n// This tuple struct mainly restricts users from creating Autodiff(Autodiff) devices.\n/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].\n///\n/// Use [`DispatchDevice::autodiff`] to construct this type.\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct AutodiffDevice {\n    pub(crate) inner: Box<DispatchDevice>,\n    pub(crate) checkpointing: CheckpointingStrategy,\n}\n\n#[cfg(feature = \"autodiff\")]\nimpl AutodiffDevice {\n    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {\n        Self {\n            inner: Box::new(device),\n            checkpointing,\n        }\n    }\n}\n\n#[cfg(feature = \"autodiff\")]\n// Useful for match in dispatch macros\nimpl core::ops::Deref for AutodiffDevice {\n    type Target = DispatchDevice;\n\n    fn deref(&self) -> &Self::Target {\n        &self.inner\n    }\n}\n\n#[cfg(feature = \"autodiff\")]\n#[allow(missing_docs)]\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]\n/// Checkpointing strategy for autodiff.\n#[repr(u8)]\npub enum CheckpointingStrategy {\n    Balanced,\n    #[default]\n    None,\n}\n\n#[cfg(feature = \"autodiff\")]\npub(crate) fn validate_checkpointing(\n    lhs: crate::CheckpointingStrategy,\n    rhs: crate::CheckpointingStrategy,\n) -> crate::CheckpointingStrategy {\n    assert_eq!(\n        lhs, rhs,\n        \"Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy.\"\n    );\n    lhs\n}\n\nimpl core::fmt::Debug for DispatchDevice {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(device) => f.debug_tuple(\"Cpu\").field(device).finish(),\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(device) => f.debug_tuple(\"Cuda\").field(device).finish(),\n            #[cfg(wgpu_metal)]\n            Self::Metal(device) => f.debug_tuple(\"Metal\").field(device).finish(),\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(device) => f.debug_tuple(\"Rocm\").field(device).finish(),\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(device) => f.debug_tuple(\"Vulkan\").field(device).finish(),\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(device) => f.debug_tuple(\"WebGpu\").field(device).finish(),\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(device) => f.debug_tuple(\"NdArray\").field(device).finish(),\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(device) => f.debug_tuple(\"LibTorch\").field(device).finish(),\n            #[cfg(feature = \"autodiff\")]\n            // Format without `AutodiffDevice` wrapper\n            Self::Autodiff(device) => f.debug_tuple(\"Autodiff\").field(&device.inner).finish(),\n        }\n    }\n}\n\nimpl Default for DispatchDevice {\n    #[allow(unreachable_code)]\n    fn default() -> Self {\n        // TODO: which priority?\n\n        #[cfg(feature = \"cpu\")]\n        return Self::Cpu(CpuDevice);\n\n        #[cfg(feature = \"cuda\")]\n        return Self::Cuda(CudaDevice::default());\n\n        #[cfg(wgpu_metal)]\n        return Self::Metal(burn_wgpu::WgpuDevice::default());\n\n        #[cfg(feature = \"rocm\")]\n        return Self::Rocm(RocmDevice::default());\n\n        #[cfg(wgpu_vulkan)]\n        return Self::Vulkan(burn_wgpu::WgpuDevice::default());\n\n        #[cfg(wgpu_webgpu)]\n        return Self::WebGpu(burn_wgpu::WgpuDevice::default());\n\n        #[cfg(feature = \"ndarray\")]\n        return Self::NdArray(NdArrayDevice::default());\n\n        #[cfg(feature = \"tch\")]\n        return Self::LibTorch(LibTorchDevice::default());\n    }\n}\n\nimpl PartialEq for DispatchDevice {\n    fn eq(&self, other: &Self) -> bool {\n        match (self, other) {\n            // If both are Autodiff, compare the inner devices\n            #[cfg(feature = \"autodiff\")]\n            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,\n            // If one is Autodiff, compare it to the raw device\n            #[cfg(feature = \"autodiff\")]\n            (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b,\n            #[cfg(feature = \"autodiff\")]\n            (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(),\n            #[cfg(feature = \"cpu\")]\n            (Self::Cpu(a), Self::Cpu(b)) => a == b,\n            #[cfg(feature = \"cuda\")]\n            (Self::Cuda(a), Self::Cuda(b)) => a == b,\n            #[cfg(wgpu_metal)]\n            (Self::Metal(a), Self::Metal(b)) => a == b,\n            #[cfg(feature = \"rocm\")]\n            (Self::Rocm(a), Self::Rocm(b)) => a == b,\n            #[cfg(wgpu_vulkan)]\n            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,\n            #[cfg(wgpu_webgpu)]\n            (Self::WebGpu(a), Self::WebGpu(b)) => a == b,\n            #[cfg(feature = \"ndarray\")]\n            (Self::NdArray(a), Self::NdArray(b)) => a == b,\n            #[cfg(feature = \"tch\")]\n            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,\n            #[allow(unreachable_patterns)]\n            (_, _) => false,\n        }\n    }\n}\n\n/// Base multiplier to avoid type_id clashes between backends.\n/// Limits the number of device types per backend, but this is a sensible limit.\nconst TYPE_ID_BASE: u16 = 10;\n\nimpl DispatchDevice {\n    #[cfg(feature = \"autodiff\")]\n    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.\n    pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {\n        Self::autodiff_checkpointed(device, CheckpointingStrategy::None)\n    }\n    #[cfg(feature = \"autodiff\")]\n    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.\n    pub fn autodiff_checkpointed(\n        device: impl Into<DispatchDevice>,\n        checkpointing: CheckpointingStrategy,\n    ) -> DispatchDevice {\n        let device = device.into();\n        DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing))\n    }\n\n    /// Returns a unique number per variant to encode into type_id.\n    fn backend_id(&self) -> BackendId {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(_) => BackendId::Cpu,\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(_) => BackendId::Cuda,\n            #[cfg(wgpu_metal)]\n            Self::Metal(_) => BackendId::Metal,\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(_) => BackendId::Rocm,\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(_) => BackendId::Vulkan,\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(_) => BackendId::WebGpu,\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(_) => BackendId::NdArray,\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(_) => BackendId::LibTorch,\n            #[cfg(feature = \"autodiff\")]\n            Self::Autodiff(device) => device.inner.backend_id(),\n        }\n    }\n\n    /// Encode variant ID and backend type ID into a unique `type_id`.\n    fn encode_type_id(&self, backend_type_id: u16) -> u16 {\n        u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id\n    }\n\n    /// Decode an encoded `type_id` into variant ID and backend type ID.\n    fn decode_type_id(type_id: u16) -> (BackendId, u16) {\n        let variant = type_id / TYPE_ID_BASE;\n        let backend_type_id = type_id % TYPE_ID_BASE;\n        (\n            BackendId::try_from(variant).expect(\"Unknown DispatchDevice variant\"),\n            backend_type_id,\n        )\n    }\n}\n\n#[derive(Debug, Clone, Copy, PartialEq, Eq)]\n#[repr(u16)]\nenum BackendId {\n    #[cfg(feature = \"cpu\")]\n    Cpu = 0,\n    #[cfg(feature = \"cuda\")]\n    Cuda = 1,\n    #[cfg(wgpu_metal)]\n    Metal = 2,\n    #[cfg(feature = \"rocm\")]\n    Rocm = 3,\n    #[cfg(wgpu_vulkan)]\n    Vulkan = 4,\n    #[cfg(wgpu_webgpu)]\n    WebGpu = 5,\n    #[cfg(feature = \"ndarray\")]\n    NdArray = 6,\n    #[cfg(feature = \"tch\")]\n    LibTorch = 7,\n}\n\nimpl From<BackendId> for u16 {\n    fn from(variant: BackendId) -> Self {\n        variant as u16\n    }\n}\n\nimpl TryFrom<u16> for BackendId {\n    type Error = ();\n\n    fn try_from(value: u16) -> Result<Self, Self::Error> {\n        match value {\n            #[cfg(feature = \"cpu\")]\n            0 => Ok(Self::Cpu),\n            #[cfg(feature = \"cuda\")]\n            1 => Ok(Self::Cuda),\n            #[cfg(wgpu_metal)]\n            2 => Ok(Self::Metal),\n            #[cfg(feature = \"rocm\")]\n            3 => Ok(Self::Rocm),\n            #[cfg(wgpu_vulkan)]\n            4 => Ok(Self::Vulkan),\n            #[cfg(wgpu_webgpu)]\n            5 => Ok(Self::WebGpu),\n            #[cfg(feature = \"ndarray\")]\n            6 => Ok(Self::NdArray),\n            #[cfg(feature = \"tch\")]\n            7 => Ok(Self::LibTorch),\n            _ => Err(()),\n        }\n    }\n}\n\nimpl DeviceOps for DispatchDevice {\n    fn inner(&self) -> &Self {\n        match self {\n            #[cfg(feature = \"autodiff\")]\n            DispatchDevice::Autodiff(device) => &device.inner,\n            device => device,\n        }\n    }\n}\n\nimpl burn_std::device::Device for DispatchDevice {\n    fn from_id(mut device_id: DeviceId) -> Self {\n        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);\n        device_id.type_id = backend_type_id;\n\n        match dispatch_id {\n            #[cfg(feature = \"cpu\")]\n            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),\n            #[cfg(feature = \"cuda\")]\n            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),\n            #[cfg(wgpu_metal)]\n            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),\n            #[cfg(feature = \"rocm\")]\n            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),\n            #[cfg(wgpu_vulkan)]\n            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),\n            #[cfg(wgpu_webgpu)]\n            BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),\n            #[cfg(feature = \"ndarray\")]\n            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),\n            #[cfg(feature = \"tch\")]\n            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),\n        }\n    }\n\n    fn to_id(&self) -> DeviceId {\n        let mut device_id = match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(device) => device.to_id(),\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(device) => device.to_id(),\n            #[cfg(wgpu_metal)]\n            Self::Metal(device) => device.to_id(),\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(device) => device.to_id(),\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(device) => device.to_id(),\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(device) => device.to_id(),\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(device) => device.to_id(),\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(device) => device.to_id(),\n            #[cfg(feature = \"autodiff\")]\n            Self::Autodiff(device) => device.inner.to_id(),\n        };\n        device_id.type_id = self.encode_type_id(device_id.type_id);\n        device_id\n    }\n\n    fn device_count(type_id: u16) -> usize {\n        let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id);\n        match dispatch_id {\n            #[cfg(feature = \"cpu\")]\n            BackendId::Cpu => CpuDevice::device_count(backend_type_id),\n            #[cfg(feature = \"cuda\")]\n            BackendId::Cuda => CudaDevice::device_count(backend_type_id),\n            #[cfg(wgpu_metal)]\n            BackendId::Metal => WgpuDevice::device_count(backend_type_id),\n            #[cfg(feature = \"rocm\")]\n            BackendId::Rocm => RocmDevice::device_count(backend_type_id),\n            #[cfg(wgpu_vulkan)]\n            BackendId::Vulkan => WgpuDevice::device_count(backend_type_id),\n            #[cfg(wgpu_webgpu)]\n            BackendId::WebGpu => WgpuDevice::device_count(backend_type_id),\n            #[cfg(feature = \"ndarray\")]\n            BackendId::NdArray => NdArrayDevice::device_count(backend_type_id),\n            #[cfg(feature = \"tch\")]\n            BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id),\n        }\n    }\n}\n\n#[cfg(feature = \"cpu\")]\nimpl From<CpuDevice> for DispatchDevice {\n    fn from(device: CpuDevice) -> Self {\n        DispatchDevice::Cpu(device)\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nimpl From<CudaDevice> for DispatchDevice {\n    fn from(device: CudaDevice) -> Self {\n        DispatchDevice::Cuda(device)\n    }\n}\n\n#[cfg(wgpu_metal)]\nimpl From<WgpuDevice> for DispatchDevice {\n    fn from(device: WgpuDevice) -> Self {\n        DispatchDevice::Metal(device)\n    }\n}\n\n#[cfg(feature = \"rocm\")]\nimpl From<RocmDevice> for DispatchDevice {\n    fn from(device: RocmDevice) -> Self {\n        DispatchDevice::Rocm(device)\n    }\n}\n\n#[cfg(wgpu_vulkan)]\nimpl From<WgpuDevice> for DispatchDevice {\n    fn from(device: WgpuDevice) -> Self {\n        DispatchDevice::Vulkan(device)\n    }\n}\n\n#[cfg(wgpu_webgpu)]\nimpl From<WgpuDevice> for DispatchDevice {\n    fn from(device: WgpuDevice) -> Self {\n        DispatchDevice::WebGpu(device)\n    }\n}\n\n#[cfg(feature = \"ndarray\")]\nimpl From<NdArrayDevice> for DispatchDevice {\n    fn from(device: NdArrayDevice) -> Self {\n        DispatchDevice::NdArray(device)\n    }\n}\n\n#[cfg(feature = \"tch\")]\nimpl From<LibTorchDevice> for DispatchDevice {\n    fn from(device: LibTorchDevice) -> Self {\n        DispatchDevice::LibTorch(device)\n    }\n}\n\n#[cfg(feature = \"tch\")]\nimpl From<LibTorchDevice> for DispatchDevice {\n    fn from(device: LibTorchDevice) -> Self {\n        DispatchDevice::LibTorch(device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![recursion_limit = \"138\"]\n\n//! Burn multi-backend dispatch.\n//!\n//! # Available Backends\n//!\n//! The dispatch backend supports the following variants, each enabled via cargo features:\n//!\n//! | Backend    | Feature    | Description |\n//! |------------|------------|-------------|\n//! | `Cpu`      | `cpu`      | Rust CPU backend (MLIR + LLVM) |\n//! | `Cuda`     | `cuda`     | NVIDIA CUDA backend |\n//! | `Metal`    | `metal`    | Apple Metal backend via `wgpu` (MSL) |\n//! | `Rocm`     | `rocm`     | AMD ROCm backend |\n//! | `Vulkan`   | `vulkan`   | Vulkan backend via `wgpu` (SPIR-V) |\n//! | `WebGpu`   | `webgpu`   | WebGPU backend via `wgpu` (WGSL) |\n//! | `NdArray`  | `ndarray`  | Pure Rust CPU backend using `ndarray` |\n//! | `LibTorch` | `tch`      | Libtorch backend via `tch` |\n//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |\n//!\n//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.\n//! All other backends can be combined freely.\n//!\n//! ## WGPU Backend Exclusivity\n//!\n//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to\n//! the current automatic compile, which can only select one target at a time.\n//!\n//! Enable only **one** of these features in your `Cargo.toml`:\n//! - `metal`\n//! - `vulkan`\n//! - `webgpu`\n//!\n//! If multiple WGPU features are enabled, the build script will emit a warning and **disable all WGPU\n//! backends** to prevent unintended behavior.\n\n#[cfg(not(any(\n    feature = \"cpu\",\n    feature = \"cuda\",\n    wgpu_metal,\n    feature = \"rocm\",\n    wgpu_vulkan,\n    wgpu_webgpu,\n    feature = \"ndarray\",\n    feature = \"tch\",\n)))]\ncompile_error!(\"At least one backend feature must be enabled.\");\n\n#[macro_use]\nmod macros;\n\nmod backend;\nmod device;\nmod ops;\nmod tensor;\n\npub use backend::*;\npub use device::*;\npub use tensor::*;\n\nextern crate alloc;\n\n/// Backends and devices used.\npub(crate) mod backends {\n    #[cfg(feature = \"autodiff\")]\n    pub use burn_autodiff::Autodiff;\n\n    #[cfg(feature = \"cpu\")]\n    pub use burn_cpu::{Cpu, CpuDevice};\n    #[cfg(feature = \"cuda\")]\n    pub use burn_cuda::{Cuda, CudaDevice};\n    #[cfg(feature = \"rocm\")]\n    pub use burn_rocm::{Rocm, RocmDevice};\n    #[cfg(wgpu_metal)]\n    pub use burn_wgpu::Metal;\n    #[cfg(wgpu_vulkan)]\n    pub use burn_wgpu::Vulkan;\n    #[cfg(wgpu_webgpu)]\n    pub use burn_wgpu::WebGpu;\n    #[cfg(any(wgpu_metal, wgpu_vulkan, wgpu_webgpu))]\n    pub use burn_wgpu::WgpuDevice;\n\n    #[cfg(feature = \"ndarray\")]\n    pub use burn_ndarray::{NdArray, NdArrayDevice};\n    #[cfg(feature = \"tch\")]\n    pub use burn_tch::{LibTorch, LibTorchDevice};\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/macros.rs",
    "content": "/// Supplies a list of all supported backends and their corresponding feature flags\n/// to a callback macro. This centralizes the backend registry.\nmacro_rules! backend_list {\n    ($callback:ident, $($extra:tt)*) => {\n        $callback! {\n            $($extra)*;\n            [Cpu, feature = \"cpu\"],\n            [Cuda, feature = \"cuda\"],\n            [Metal, wgpu_metal],\n            [Rocm, feature = \"rocm\"],\n            [Vulkan, wgpu_vulkan],\n            [WebGpu, wgpu_webgpu],\n            [NdArray, feature = \"ndarray\"],\n            [LibTorch, feature = \"tch\"]\n        }\n    };\n}\n\n/// Supplies a matrix of cross-backend combinations. Used for operations where the source and destination backends may differ.\nmacro_rules! backend_matrix {\n    ($callback:ident, $($extra:tt)*) => {\n        $callback! {\n            $($extra)*;\n            [Cpu, feature = \"cpu\"] => [[Cuda, feature = \"cuda\"], [Metal, wgpu_metal], [Rocm, feature = \"rocm\"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [Cuda, feature = \"cuda\"] => [[Cpu, feature = \"cpu\"], [Metal, wgpu_metal], [Rocm, feature = \"rocm\"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [Metal, wgpu_metal] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Rocm, feature = \"rocm\"], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [Rocm, feature = \"rocm\"] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Metal, wgpu_metal], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [Vulkan, wgpu_vulkan] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Rocm, feature = \"rocm\"], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [WebGpu, wgpu_webgpu] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Rocm, feature = \"rocm\"], [NdArray, feature = \"ndarray\"], [LibTorch, feature = \"tch\"]];\n            [NdArray, feature = \"ndarray\"] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Metal, wgpu_metal], [Rocm, feature = \"rocm\"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [LibTorch, feature = \"tch\"]];\n            [LibTorch, feature = \"tch\"] => [[Cpu, feature = \"cpu\"], [Cuda, feature = \"cuda\"], [Metal, wgpu_metal], [Rocm, feature = \"rocm\"], [Vulkan, wgpu_vulkan], [WebGpu, wgpu_webgpu], [NdArray, feature = \"ndarray\"]]\n        }\n    };\n}\n\n/// Helper to map the runtime strategy to the compile-time Autodiff generic.\nmacro_rules! with_autodiff_backend {\n    ($Backend:ident, $checkpointing:expr, |$B:ident| $body:expr) => {\n        match $checkpointing {\n            $crate::CheckpointingStrategy::Balanced => {\n                type $B = Autodiff<\n                    $Backend<f32>,\n                    burn_autodiff::checkpoint::strategy::BalancedCheckpointing,\n                >;\n                $body\n            }\n            $crate::CheckpointingStrategy::None => {\n                type $B =\n                    Autodiff<$Backend<f32>, burn_autodiff::checkpoint::strategy::NoCheckpointing>;\n                $body\n            }\n        }\n    };\n}\n\n/// Match arm generator for `dispatch_device`.\n/// Maps each backend variant to a block where the specific backend type is bound to `B`.\nmacro_rules! dispatch_device_arms {\n    (\n        $device:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {\n        match $device {\n            // Autodiff arm first\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchDevice::Autodiff(inner) => {\n                // Recursively dispatch on inner\n                dispatch_device_arms!(\n                    @autodiff\n                    &**inner,\n                    |$inner| $body;\n                    $([$Backend, $cfg]),*\n                )\n            },\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchDevice::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    $body\n                }\n            )*\n        }\n    };\n    (\n        @autodiff\n        $device:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {\n        match $device {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchDevice::$Backend($inner) => {\n                    type B = Autodiff<$Backend<f32>>;\n                    $body\n                }\n            )*\n            $crate::DispatchDevice::Autodiff(_) => panic!(\"Autodiff should not wrap an autodiff device.\")\n        }\n    };\n}\n\n/// Dispatches an operation body based on the provided device.\nmacro_rules! dispatch_device {\n    ($device:expr, |$inner:ident| $body:expr) => {\n        backend_list!(dispatch_device_arms, $device, |$inner| $body)\n    };\n}\n\n/// Match arm generator for `to_device`.\n/// Handles the logic for same-backend transfers (fast path) and cross-backend\n/// transfers by generating a grid of all device combinations provided via `backend_matrix`.\nmacro_rules! to_device_arms {\n    (\n        $kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr;\n        $( [$B1:ident, $src_cfg:meta] => [ $( [$B2:ident, $dst_cfg:meta] ),+ ] );*\n    ) => {\n        match ($tensor.kind, $device) {\n            // --- Same backend to_device ---\n            $(\n                #[cfg($src_cfg)]\n                ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B1(d)) => {\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$B1($crate::BackendTensor::$kind(\n                            $B1::<f32>::$to_device(tensor.$inner_fn(), d)\n                        )),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing: $tensor.checkpointing,\n                    }\n                }\n            )*\n\n            // --- Cross backend arms ---\n            // This loop generates the grid of combinations\n            $(\n                $(\n                    #[cfg(all($src_cfg, $dst_cfg))]\n                    ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B2($device_ident)) => {\n                        type B1 = $B1<f32>;\n                        type B2 = $B2<f32>;\n                        let $inner = tensor.$inner_fn();\n\n                        $crate::DispatchTensor {\n                            kind: $crate::DispatchTensorKind::$B2(\n                                $crate::BackendTensor::$kind($body)\n                            ),\n                            #[cfg(feature = \"autodiff\")]\n                            checkpointing: $tensor.checkpointing,\n                        }\n                    }\n                )+\n            )*\n            #[cfg(feature = \"autodiff\")]\n            (_, $crate::DispatchDevice::Autodiff(_)) | ($crate::DispatchTensorKind::Autodiff(..), _) => panic!(\"Operation not marked for autodiff.\")\n        }\n    };\n}\n\n/// Handles tensor movement between devices, supporting both same-backend transfers\n/// and cross-backend dispatches.\nmacro_rules! to_device {\n    ($kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr) => {\n        backend_matrix!(\n            to_device_arms,\n            $kind,\n            $inner_fn,\n            $tensor,\n            $device,\n            $to_device,\n            |$inner, $device_ident| $body\n        )\n    };\n}\n\n/// Match arm generator for `float_to_device`.\n///\n/// Similar to `to_device_arms`, but float tensors are checked for autodiff support.\nmacro_rules! float_to_device_arms {\n    (\n        $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr;\n        $( [$B1:ident, $src_cfg:meta] => [ $( [$B2:ident, $dst_cfg:meta] ),+ ] );*\n    ) => {\n        match ($tensor.kind, $device) {\n            #[cfg(feature = \"autodiff\")]\n            ($crate::DispatchTensorKind::Autodiff(kind), $crate::DispatchDevice::Autodiff(device)) => {\n                let ckp = $tensor.checkpointing;\n                float_to_device_arms!(\n                    @autodiff\n                    *kind, &**device, ckp, $to_device;\n                    $([$B1, $src_cfg]);*\n                )\n\n            }\n            // --- Same backend to_device ---\n            $(\n                #[cfg($src_cfg)]\n                ($crate::DispatchTensorKind::$B1(kind), $crate::DispatchDevice::$B1(d)) => {\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$B1($crate::BackendTensor::Float(\n                            $B1::<f32>::$to_device(kind.float(), d)\n                        )),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing: $tensor.checkpointing,\n                    }\n                }\n            )*\n\n            // --- Cross backend arms ---\n            // This loop generates the grid of combinations\n            $(\n                $(\n                    #[cfg(all($src_cfg, $dst_cfg))]\n                    ($crate::DispatchTensorKind::$B1(kind), $crate::DispatchDevice::$B2($device_ident)) => {\n                        type B1 = $B1<f32>;\n                        type B2 = $B2<f32>;\n                        let $inner = kind.float();\n\n                        $crate::DispatchTensor {\n                            kind: $crate::DispatchTensorKind::$B2($crate::BackendTensor::Float($body)),\n                            #[cfg(feature = \"autodiff\")]\n                            checkpointing: $tensor.checkpointing,\n                        }\n                    }\n                )+\n            )*\n            #[cfg(feature = \"autodiff\")]\n            ($crate::DispatchTensorKind::Autodiff(..), _) | (_, $crate::DispatchDevice::Autodiff(_)) => panic!(\"Cannot move between autodiff and non-autodiff instances.\")\n        }\n    };\n\n    // Autodiff(DispatchTensor)\n    (\n        @autodiff\n        $tensor:expr, $device:expr, $ckp:expr, $to_device:ident;\n        $( [$B1:ident, $src_cfg:meta] );*\n    ) => {{\n        match ($tensor, $device) {\n            // --- Same backend to_device ---\n            $(\n                #[cfg($src_cfg)]\n                ($crate::DispatchTensorKind::$B1(tensor), $crate::DispatchDevice::$B1(d)) => {\n                    let kind = $crate::DispatchTensorKind::Autodiff(Box::new($crate::DispatchTensorKind::$B1($crate::BackendTensor::Autodiff(\n                        with_autodiff_backend!($B1, $ckp, |B| {\n                            B::$to_device(tensor.autodiff(), d)\n                        })\n                    ))));\n                    $crate::DispatchTensor {kind, checkpointing: $ckp}\n                }\n            )*\n            (_, _) => unimplemented!(\"Autodiff tensor cannot be moved between backends.\")\n        }\n    }};\n}\n\n/// Handles float tensor movement between devices (that might support autodiff).\nmacro_rules! float_to_device {\n    ($kind:ident, $inner_fn:ident, $tensor:expr, $device:expr, $to_device:ident, |$inner:ident, $device_ident:ident| $body:expr) => {\n        backend_matrix!(\n            float_to_device_arms,\n            $tensor,\n            $device,\n            $to_device,\n            |$inner, $device_ident| $body\n        )\n    };\n}\n\n/// Dispatches a tensor creation operation (e.g., zeros, ones) to the correct backend\n/// based on the provided device.\nmacro_rules! creation_op {\n    ($kind:ident, $device:expr, |$inner:ident| $body:expr) => {\n        backend_list!(creation_op_arms, $kind, $device, |$inner| $body)\n    };\n}\n\n/// Match arm generator for `creation_float`.\n///\n/// Similar to `creation_op_arms`, but float tensors are checked for autodiff support.\nmacro_rules! creation_op_arms {\n    (\n        $kind:ident,\n        $device:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match $device {\n            // Autodiff arm first\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchDevice::Autodiff(inner) => {\n                // Recursively dispatch on inner\n                creation_op_arms!(\n                    @autodiff\n                    $kind,\n                    &**inner,\n                    inner.checkpointing,\n                    |$inner| $body;\n                    $([$Backend, $cfg]),*\n                )\n            },\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchDevice::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend(\n                            $crate::BackendTensor::$kind($body)\n                        ),\n                        // TODO: hmmm should devices also carry the checkpointing all the time?\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing: $crate::CheckpointingStrategy::None,\n                    }\n                }\n            )*\n        }\n    }};\n\n    (\n        @autodiff\n        $kind:ident,\n        $device:expr,\n        $ckp:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match $device {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchDevice::$Backend($inner) => {\n                    with_autodiff_backend!($Backend, $ckp, |B| {\n                        wrap_float!(@wrap_autodiff $kind, $Backend, $ckp, { $body })\n                    })\n                }\n            )*\n            $crate::DispatchDevice::Autodiff(_) => panic!(\"Autodiff should not wrap an autodiff device.\")\n        }\n    }};\n}\n\n/// Wrap the result in the backend tensor kind, handling float -> autodiff.\n#[cfg(feature = \"autodiff\")]\nmacro_rules! wrap_float {\n    (\n        @wrap_autodiff Float,\n        $Backend:ident,\n        $ckp:expr,\n        $expr:expr\n    ) => {\n        $crate::DispatchTensor {\n            kind: $crate::DispatchTensorKind::Autodiff(Box::new(\n                $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Autodiff($expr)),\n            )),\n            checkpointing: $ckp,\n        }\n    };\n\n    (\n        @wrap_autodiff $other:ident,\n        $Backend:ident,\n        $ckp:expr,\n        $expr:expr\n    ) => {\n        $crate::DispatchTensor {\n            kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$other($expr)),\n            checkpointing: $ckp,\n        }\n    };\n}\n\n/// Match arm generator for `unary_op`.\n/// Unwraps the inner tensor primitive (e.g., `inner.float()`) and provides the backend type `B`\n/// for the operation.\n///\n/// When the return kind is provided, the result is wrapped in the corresponding `DispatchTensor` variant.\nmacro_rules! unary_op_arms {\n    (\n        $kind:ident,\n        $inner_kind:ident,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = $tensor.checkpointing;\n\n        match $tensor.kind {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    let $inner = $inner.$inner_kind();\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Operation not marked for autodiff.\")\n        }\n    }};\n\n    // Operations that do not return a tensor kind\n    (\n        $inner_kind:ident,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match $tensor.kind {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    let $inner = $inner.$inner_kind();\n                    $body\n                }\n            )*\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Operation not marked for autodiff.\")\n        }\n    }};\n}\n\n/// Backend dispatch for unary operations.\n///\n/// When the return `=> Kind` is not provided, the operation output is not wrapped in a dispatch tensor (e.g., `into_data(..)`)\nmacro_rules! unary_op {\n    ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => {\n        backend_list!(unary_op_arms, $kind, $inner_kind, $tensor, |$inner| {\n            $body\n        })\n    };\n    ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => {\n        backend_list!(unary_op_arms, $inner_kind, $tensor, |$inner| { $body })\n    };\n}\n\n/// Match arm generator for `unary_float`.\n///\n/// Similar to `unary_op_arms`, but float tensors are checked for autodiff support.\nmacro_rules! unary_float_arms {\n    (\n        $mode:ident, // `owned` or `ref`\n        $kind:ident,\n        $inner_kind:ident,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = $tensor.checkpointing;\n\n        match $tensor.kind {\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(inner) => {\n                unary_float_arms!(\n                    @autodiff $mode,\n                    checkpointing,\n                    $kind,\n                    { if_mode!($mode, &**inner, *inner) },\n                    |$inner| $body;\n                    $([$Backend, $cfg]),*\n                )\n            },\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    let $inner = unary_float_arms!(@unwrap $mode, $inner, $inner_kind);\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend(\n                            $crate::BackendTensor::$kind($body)\n                        ),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n        }\n    }};\n\n    // --- Autodiff recursive arm ---\n    (\n        @autodiff $mode:ident,\n        $ckp:expr,\n        $kind:ident,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match $tensor {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    with_autodiff_backend!($Backend, $ckp, |B| {\n                        let $inner = unary_float_arms!(@unwrap_ad $mode, $inner);\n                        wrap_float!(@wrap_autodiff $kind, $Backend, $ckp, { $body })\n                    })\n                }\n            )*\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\")\n        }\n    }};\n\n    // --- Non-wrapping arms (operations not returning a tensor) ---\n    (\n        $mode:ident,\n        $inner_kind:ident,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = &$tensor.checkpointing;\n\n        match { if_mode!($mode, &$tensor.kind, $tensor.kind) } {\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(inner) => {\n                unary_float_arms!(\n                    @autodiff $mode,\n                    checkpointing,\n                    { if_mode!($mode, &**inner, *inner) },\n                    |$inner| $body;\n                    $([$Backend, $cfg]),*\n                )\n            },\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    type B = $Backend<f32>;\n                    let $inner = unary_float_arms!(@unwrap $mode, $inner, $inner_kind);\n                    $body\n                }\n            )*\n        }\n    }};\n    (\n        @autodiff $mode:ident,\n        $ckp:expr,\n        $tensor:expr,\n        |$inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match $tensor {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend($inner) => {\n                    with_autodiff_backend!($Backend, $ckp, |B| {\n                        let $inner = unary_float_arms!(@unwrap_ad $mode, $inner);\n                        $body\n                    })\n                }\n            )*\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\")\n        }\n    }};\n\n    // --- Helpers to unwarp the tensor based on owned/ref ---\n    (@unwrap owned, $inner:ident, $inner_kind:ident) => { $inner.$inner_kind() };\n    (@unwrap ref, $inner:ident, $inner_kind:ident) => {\n        paste::paste! { $inner.[< as_ $inner_kind >]() }\n    };\n\n    (@unwrap_ad owned, $inner:ident) => { $inner.autodiff() };\n    (@unwrap_ad ref, $inner:ident) => { $inner.as_autodiff() };\n\n}\n\n#[cfg(feature = \"autodiff\")]\n/// Utility to pick a token based on mode\nmacro_rules! if_mode {\n    (ref, $if_ref:expr, $if_owned:expr) => {\n        $if_ref\n    };\n    (owned, $if_ref:expr, $if_owned:expr) => {\n        $if_owned\n    };\n}\n\n/// Backend dispatch for float unary operations (that might support autodiff).\n///\n/// When the return `=> Kind` is not provided, the operation output is not wrapped in a dispatch tensor (e.g., `into_data(..)`)\nmacro_rules! unary_float {\n    // Owned with return kind\n    ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => {\n        backend_list!(\n            unary_float_arms,\n            owned,\n            $kind,\n            $inner_kind,\n            $tensor,\n            |$inner| { $body }\n        )\n    };\n    // Owned without return kind\n    ($tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => {\n        backend_list!(unary_float_arms, owned, $inner_kind, $tensor, |$inner| {\n            $body\n        })\n    };\n    // Reference without return kind\n    (ref $tensor:expr, $inner_kind:ident, |$inner:ident| $body:expr) => {\n        backend_list!(unary_float_arms, ref, $inner_kind, $tensor, |$inner| {\n            $body\n        })\n    };\n}\n\n/// Match arm generator for `binary_op`.\n/// Matches two tensors to ensure they share the same backend before unwrapping them for the operation.\nmacro_rules! binary_op_arms {\n    (\n        $kind:ident,\n        ($lhs:expr, $lhs_kind:ident),\n        ($rhs:expr, $rhs_kind:ident),\n        |$lhs_inner:ident, $rhs_inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing);\n\n        match ($lhs.kind, $rhs.kind) {\n            $(\n                #[cfg($cfg)]\n                ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    type B = $Backend<f32>;\n                    let $lhs_inner = $lhs_inner.$lhs_kind();\n                    let $rhs_inner = $rhs_inner.$rhs_kind();\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n            #[allow(unreachable_patterns)]\n            (lhs, rhs) => {\n                panic!(\n                    \"The provided tensors are not on the same backend. Got backends {:?} and {:?}.\", lhs, rhs\n                );\n            }\n        }\n    }};\n}\n\n/// Backend dispatch for binary operations.\n/// Automatically verifies that both tensors reside on the same backend.\nmacro_rules! binary_op {\n    (($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr => $kind:ident) => {\n        backend_list!(\n            binary_op_arms,\n            $kind,\n            ($lhs, $lhs_kind),\n            ($rhs, $rhs_kind),\n            |$lhs_inner, $rhs_inner| { $body }\n        )\n    };\n}\n\n/// Match arm generator for `binary_float`.\n/// Matches two tensors to ensure they share the same backend before unwrapping them for the operation.\nmacro_rules! binary_float_arms {\n    // (float, float) binary op\n    (\n        $kind:ident,\n        ($lhs:expr, float),\n        ($rhs:expr, float),\n        |$lhs_inner:ident, $rhs_inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing);\n\n        match ($lhs.kind, $rhs.kind) {\n            // Autodiff arms first\n            #[cfg(feature = \"autodiff\")]\n            ($crate::DispatchTensorKind::Autodiff(lhs_inner), $crate::DispatchTensorKind::Autodiff(rhs_inner)) => {\n                // Recursively dispatch on inner\n                binary_float_arms!(\n                    @autodiff\n                    $kind,\n                    (*lhs_inner, autodiff, checkpointing),\n                    (*rhs_inner, autodiff, checkpointing),\n                    |$lhs_inner, $rhs_inner| $body;\n                    $([$Backend, $cfg]),*\n                )\n            },\n            $(\n                #[cfg($cfg)]\n                ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    type B = $Backend<f32>;\n                    let $lhs_inner = $lhs_inner.float();\n                    let $rhs_inner = $rhs_inner.float();\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n            #[allow(unreachable_patterns)]\n            (lhs, rhs) => {\n                panic!(\n                    \"The provided tensors are not on the same backend. Got backends {:?} and {:?}.\", lhs, rhs\n                );\n            }\n        }\n    }};\n    // (float, any) binary op\n    (\n        $kind:ident,\n        ($lhs:expr, float),\n        ($rhs:expr, $rhs_kind:ident),\n        |$lhs_inner:ident, $rhs_inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = $crate::validate_checkpointing($lhs.checkpointing, $rhs.checkpointing);\n\n        match ($lhs.kind, $rhs.kind) {\n            $(\n                // Autodiff arms first\n                #[cfg(all(feature = \"autodiff\", $cfg))]\n                ($crate::DispatchTensorKind::Autodiff(lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    // Match on inner\n                    match *lhs_inner {\n                        $crate::DispatchTensorKind::$Backend($lhs_inner) => {\n                            with_autodiff_backend!($Backend, checkpointing, |B| {\n                                let $lhs_inner = $lhs_inner.autodiff();\n                                let $rhs_inner = $rhs_inner.$rhs_kind();\n                                wrap_float!(\n                                    @wrap_autodiff\n                                    $kind,\n                                    $Backend,\n                                    checkpointing,\n                                    { $body }\n                                )\n                            })\n                        }\n                        $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\"),\n                        #[allow(unreachable_patterns)]\n                        _ => panic!(\"The provided tensors are not on the same backend.\")\n                    }\n                },\n\n                #[cfg($cfg)]\n                ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    type B = $Backend<f32>;\n                    let $lhs_inner = $lhs_inner.float();\n                    let $rhs_inner = $rhs_inner.$rhs_kind();\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n            #[allow(unreachable_patterns)]\n            (lhs, rhs) => {\n                panic!(\n                    \"The provided tensors are not on the same backend. Got backends {:?} and {:?}.\", lhs, rhs\n                );\n            }\n        }\n    }};\n    (\n        $kind:ident,\n        ($lhs:expr, $lhs_kind:ident),\n        ($rhs:expr, $rhs_kind:ident),\n        |$lhs_inner:ident, $rhs_inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match ($lhs, $rhs) {\n            $(\n                #[cfg($cfg)]\n                ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    type B = $Backend<f32>;\n                    let $lhs_inner = $lhs_inner.$lhs_kind();\n                    let $rhs_inner = $rhs_inner.$rhs_kind();\n                    $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body))\n                }\n            )*\n            (lhs, rhs) => {\n                panic!(\n                    \"The provided tensors are not on the same backend. Got backends {:?} and {:?}.\", lhs, rhs\n                );\n            }\n        }\n    }};\n    // Autodiff (lhs, rhs) tensors\n    (\n        @autodiff\n        $kind:ident,\n        ($lhs:expr, $lhs_kind:ident, $ckp_lhs:expr),\n        ($rhs:expr, $rhs_kind:ident, $ckp_rhs:expr),\n        |$lhs_inner:ident, $rhs_inner:ident| $body:expr;\n        $([$Backend:ident, $cfg:meta]),*\n    ) => {{\n        match ($lhs, $rhs) {\n            $(\n                #[cfg($cfg)]\n                ($crate::DispatchTensorKind::$Backend($lhs_inner), $crate::DispatchTensorKind::$Backend($rhs_inner)) => {\n                    with_autodiff_backend!($Backend, $ckp_lhs, |B| {\n                        let $lhs_inner = $lhs_inner.$lhs_kind();\n                        let $rhs_inner = $rhs_inner.$rhs_kind();\n                        wrap_float!(\n                            @wrap_autodiff\n                            $kind,\n                            $Backend,\n                            $ckp_lhs,\n                            { $body }\n                        )\n                    })\n                }\n            )*\n            #[cfg(feature = \"autodiff\")]\n            ($crate::DispatchTensorKind::Autodiff(..), _) | (_, $crate::DispatchTensorKind::Autodiff(..))  => panic!(\"Autodiff should not wrap an autodiff tensor.\"),\n            #[allow(unreachable_patterns)]\n            (lhs, rhs) => {\n                panic!(\n                    \"The provided tensors are not on the same backend. Got backends {:?} and {:?}.\", lhs, rhs\n                );\n            }\n        }\n    }};\n\n}\n\n/// Backend dispatch for binary operations.\n/// Automatically verifies that both tensors reside on the same backend.\nmacro_rules! binary_float {\n    (($lhs:expr, $lhs_kind:ident), ($rhs:expr, $rhs_kind:ident), |$lhs_inner:ident, $rhs_inner:ident| $body:expr => $kind:ident) => {\n        backend_list!(\n            binary_float_arms,\n            $kind,\n            ($lhs, $lhs_kind),\n            ($rhs, $rhs_kind),\n            |$lhs_inner, $rhs_inner| { $body }\n        )\n    };\n}\n\n/// The core logic for a single backend in a `multi_op`.\n/// Handles the manual unwrapping of required/optional inputs and the\n/// re-wrapping of multiple required/optional output tensors.\nmacro_rules! multi_op_arm {\n    (\n        $Backend:ident,\n        $ckp:ident,\n        [ $( ($x:ident, $x_kind:ident) ),+ ],\n        [ $( ($opt_in:ident, $opt_kind:ident) ),* ],\n        [ $( ($out:ident, $out_kind:ident) ),+  ],\n        [ $( $opt_out:ident ),* ],\n        $body:expr\n    ) => {{\n        type B = $Backend<f32>;\n\n        // Required inputs\n        $(\n            let $x = match $x.kind {\n                $crate::DispatchTensorKind::$Backend(inner) => inner.$x_kind(),\n                #[allow(unreachable_patterns)]\n                _ => panic!(\"Input tensor {} is on the wrong device\", stringify!($x)),\n            };\n        )+\n\n        // Optional inputs\n        $(\n            let $opt_in = $opt_in.map(|o| match o.kind {\n                $crate::DispatchTensorKind::$Backend(inner) => inner.$opt_kind(),\n                #[allow(unreachable_patterns)]\n                _ => panic!(\"Optional tensor {} is on the wrong device\", stringify!($opt_in)),\n            });\n        )*\n\n        let ($($out),+, $($opt_out),*) = $body;\n\n        // Outputs and optional outputs\n        (\n            $(\n                $crate::DispatchTensor {\n                    kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$out_kind($out)),\n                    #[cfg(feature = \"autodiff\")]\n                    checkpointing: $ckp,\n                }\n            ),+,\n            $(\n                $opt_out.map(|t|\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Float(t)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing: $ckp,\n                    }\n                )\n            ),*\n        )\n    }};\n}\n\n#[cfg(feature = \"autodiff\")]\nmacro_rules! wrap_input_autodiff {\n    ($Backend:ident, $inner:expr, int) => {\n        $inner.int()\n    };\n    ($Backend:ident, $inner:expr, bool) => {\n        $inner.bool()\n    };\n    // Float tensors: wrap with autodiff\n    ($Backend:ident, $inner:expr, float) => {\n        $inner.autodiff()\n    };\n}\n\n#[cfg(feature = \"autodiff\")]\n// DispatchTensorKind::Autodiff(DispatchTensorKind::$Backend(BackendTensor::Autodiff()))\nmacro_rules! multi_op_arm_autodiff {\n    (\n        $Backend:ident,\n        $ckp:ident,\n        [ $( ($x:ident, $x_kind:ident) ),+ ],\n        [ $( ($opt_in:ident, $opt_kind:ident) ),* ],\n        [ $( ($out:ident, $out_kind:ident) ),+  ],\n        [ $( $opt_out:ident ),* ],\n        $body:expr\n    ) => {{\n        // type B = Autodiff<$Backend<f32>>;\n        with_autodiff_backend!($Backend, $ckp, |B| {\n            // Required inputs\n            $(\n                let $x = match $x.kind {\n                    $crate::DispatchTensorKind::Autodiff(inner) => {\n                        match *inner {\n                            $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $x_kind),\n                            _ => panic!(\"Input tensor {} is on the wrong device\", stringify!($x)),\n                        }\n                    },\n                    // Unreachable, except when input is int\n                    $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $x_kind),\n                    #[allow(unreachable_patterns)]\n                    _ => panic!(\"Input tensor {} is on the wrong device\", stringify!($x)),\n                };\n            )+\n\n            // Optional inputs (always assumed to be float / autodiff)\n            $(\n                let $opt_in = $opt_in.map(|o| match o.kind {\n                    $crate::DispatchTensorKind::Autodiff(inner) => {\n                        match *inner {\n                            $crate::DispatchTensorKind::$Backend(inner) => wrap_input_autodiff!($Backend, inner, $opt_kind),\n                            _ => panic!(\"Input tensor {} is on the wrong device\", stringify!($opt_in)),\n                        }\n                    },\n                    _ => panic!(\"Optional tensor {} is on the wrong device\", stringify!($opt_in)),\n                });\n            )*\n\n            let ($($out),+, $($opt_out),*) = $body;\n\n            // Outputs and optional outputs\n            (\n                $( wrap_float!(@wrap_autodiff $out_kind, $Backend, $ckp, $out) ),+,\n                $( $opt_out.map(|t| wrap_float!(@wrap_autodiff Float, $Backend, $ckp, t)) ),*\n            )\n        })\n    }};\n}\n\n/// Helper to extract the first identifier from an input list.\n/// Used to determine the device/backend for dispatching multi-tensor operations.\nmacro_rules! first_input {\n    ([ ($x:ident, $kind:ident) $(, $rest:tt)* ]) => {\n        $x\n    };\n}\n\n/// Match arm generator for `multi_op`.\n/// Determines the backend based on the first input and delegates to `multi_op_arm`\n/// to handle the repetition-heavy unwrapping and wrapping logic.\nmacro_rules! multi_op_arms_autodiff {\n    (\n        $inputs:tt,\n        $opt_inputs:tt,\n        $outputs:tt,\n        $opt_outputs:tt,\n        $body:expr;\n        $( [$Backend:ident, $cfg:meta] ),*\n    ) => {{\n        let first_input = &first_input!($inputs);\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = first_input.checkpointing;\n        match &first_input.kind {\n            // Autodiff first\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(inner) => {\n                match **inner {\n                    $(\n                        #[cfg($cfg)]\n                        $crate::DispatchTensorKind::$Backend(_) => {\n                            multi_op_arm_autodiff!(\n                                $Backend,\n                                checkpointing,\n                                $inputs,\n                                $opt_inputs,\n                                $outputs,\n                                $opt_outputs,\n                                $body\n                            )\n                        }\n                    )*\n                    $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend(_) => {\n                    multi_op_arm!(\n                        $Backend,\n                        checkpointing,\n                        $inputs,\n                        $opt_inputs,\n                        $outputs,\n                        $opt_outputs,\n                        $body\n                    )\n                }\n            )*\n        }\n    }};\n}\n\n/// Match arm generator for `multi_op`.\n///\n/// Similar to `multi_op_arms`, but skips autodiff checks.\nmacro_rules! multi_op_arms {\n    (\n        $inputs:tt,\n        $opt_inputs:tt,\n        $outputs:tt,\n        $opt_outputs:tt,\n        $body:expr;\n        $( [$Backend:ident, $cfg:meta] ),*\n    ) => {{\n        let first_input = &first_input!($inputs);\n        let checkpointing = if cfg!(feature = \"autodiff\") {\n            first_input.checkpointing\n        } else {\n            $crate::CheckpointingStrategy::None\n        };\n\n        match first_input.kind {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend(_) => {\n                    multi_op_arm!(\n                        $Backend,\n                        checkpointing,\n                        $inputs,\n                        $opt_inputs,\n                        $outputs,\n                        $opt_outputs,\n                        $body\n                    )\n                }\n            )*\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Operation not marked for autodiff.\")\n        }\n    }};\n}\n\n/// High-level macro for complex module operations (e.g., conv2d) and multi-tensor operations.\n/// Handles variable numbers of required/optional inputs and wraps multiple outputs.\n///\n/// Usage:\n/// ```ignore\n/// multi_op!(\n///     inputs[(x, float), (weight, float)],\n///     opt_inputs[(bias, float)],\n///     => Float,\n///     B::conv2d(x, weight, bias, options)\n/// )\n/// ```\nmacro_rules! multi_op {\n    // --- Single output shorthands ---\n    // Automatically wraps body in tuple and extracts .0\n    (\n        inputs[$( ($x:ident, $kind:ident) ),+],\n        => Float,\n        $body:expr\n    ) => {\n        multi_op!(\n            inputs[$( ($x, $kind) ),+],\n            opt_inputs[],\n            outputs[(out, Float)],\n            opt_outputs[],\n            { ($body,) }\n        )\n        .0\n    };\n    (\n        inputs[$( ($x:ident, $kind:ident) ),+],\n        opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ],\n        => $out_kind:ident,\n        $body:expr\n    ) => {\n        multi_op!(\n            inputs[$( ($x, $kind) ),+],\n            opt_inputs[ $(($opt_in, $opt_kind)),* ],\n            outputs[(out, $out_kind)],\n            opt_outputs[],\n            { ($body,) }\n        )\n        .0\n    };\n    // Int/Bool op specialization (not marked for autodiff)\n    (\n        inputs[$( ($x:ident, $kind:ident) ),+],\n        => $out_kind:ident,\n        $body:expr\n    ) => {\n        backend_list!(\n            multi_op_arms,\n            [ $(($x, $kind)),+ ],\n            [],\n            [ (out, $out_kind) ],\n            [],\n            { ($body,) }\n        ).0\n    };\n\n    // --- Required + optional for both inputs and outputs ---\n    (\n        inputs[ $(($x:ident, $kind:ident)),+ ],\n        opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ],\n        outputs[ $( ($out:ident, $out_kind:ident) ),+ ],\n        opt_outputs[ $($opt_out:ident),* ],\n        $body:expr\n    ) => {\n        backend_list!(\n            multi_op_arms_autodiff,\n            [ $(($x, $kind)),+ ],\n            [ $(($opt_in, $opt_kind)),* ],\n            [ $(($out, $out_kind)),+ ],\n            [ $($opt_out),* ],\n            $body\n        )\n    };\n\n    (\n        inputs[ $(($x:ident, $kind:ident)),+ ],\n        opt_inputs[ $(($opt_in:ident, $opt_kind:ident)),* ],\n        outputs[ $($out:ident),+ ],\n        $body:expr\n    ) => {\n        multi_op!(\n            inputs[ $(($x, $kind)),+ ],\n            opt_inputs[ $(($opt_in, $opt_kind)),* ],\n            outputs[ $(($out, Float)),+ ],\n            opt_outputs[],\n            $body\n        )\n    };\n\n    (\n        inputs[ $(($x:ident, $kind:ident)),+ ],\n        outputs[ $( ($out:ident, $out_kind:ident) ),+ ],\n        $body:expr\n    ) => {\n        multi_op!(\n            inputs[ $(($x, $kind)),+ ],\n            opt_inputs[],\n            outputs[ $(($out, $out_kind)),+ ],\n            opt_outputs[],\n            $body\n        )\n    };\n}\n\n/// Unwraps a `Vec<DispatchTensor>` for a known backend.\nmacro_rules! unwrap_vec {\n    ($Backend:ident, $vec:expr, $kind:ident) => {\n        $vec.into_iter()\n            .map(|t| match t.kind {\n                $crate::DispatchTensorKind::$Backend(inner) => inner.$kind(),\n                #[allow(unreachable_patterns)]\n                _ => panic!(\n                    \"Tensor is on the wrong backend (expected {}).\",\n                    stringify!($Backend)\n                ),\n            })\n            .collect::<Vec<_>>()\n    };\n\n    // Autodiff-wrapped backend\n    (@autodiff $Backend:ident, $vec:expr, $kind:ident) => {\n        $vec.into_iter()\n            .map(|t| match t.kind {\n                $crate::DispatchTensorKind::Autodiff(inner) => match *inner {\n                    $crate::DispatchTensorKind::$Backend(inner) => inner.$kind(),\n                    _ => panic!(\n                        \"Autodiff float tensor is on the wrong backend (expected {}).\",\n                        stringify!($Backend)\n                    ),\n                },\n                _ => panic!(\n                    \"Expected autodiff-wrapped float tensor for backend {}.\",\n                    stringify!($Backend)\n                ),\n            })\n            .collect::<Vec<_>>()\n    };\n}\n\n/// Match arm generator for `vec_op`.\nmacro_rules! vec_op_arms {\n    (Float, $inner_kind:ident, $tensors:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),*) => {{\n        let first = &$tensors[0];\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = first.checkpointing;\n\n        match &first.kind {\n            // Autodiff arm first\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(inner) => {\n                // Recursively dispatch on inner\n                match **inner {\n                    $(\n                    #[cfg($cfg)]\n                    $crate::DispatchTensorKind::$Backend(_) => {\n                        with_autodiff_backend!($Backend, checkpointing, |B| {\n                            let $inner = unwrap_vec!(@autodiff $Backend, $tensors, autodiff);\n                            wrap_float!( @wrap_autodiff Float, $Backend, checkpointing, { $body } )\n                        })\n                    }\n                )*\n                    $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend(_) => {\n                    type B = $Backend<f32>;\n\n                    let $inner = unwrap_vec!($Backend, $tensors, $inner_kind);\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::Float($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n        }\n    }};\n    ($kind:ident, $inner_kind:ident, $tensors:expr, |$inner:ident| $body:expr; $([$Backend:ident, $cfg:meta]),*) => {{\n        let first = &$tensors[0];\n        #[cfg(feature = \"autodiff\")]\n        let checkpointing = first.checkpointing;\n        match first.kind {\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend(_) => {\n                    type B = $Backend<f32>;\n\n                    let $inner = unwrap_vec!($Backend, $tensors, $inner_kind);\n                    $crate::DispatchTensor {\n                        kind: $crate::DispatchTensorKind::$Backend($crate::BackendTensor::$kind($body)),\n                        #[cfg(feature = \"autodiff\")]\n                        checkpointing,\n                    }\n                }\n            )*\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Operation not marked for autodiff.\")\n        }\n    }};\n}\n\n/// Backend dispatch for operations on multiple inputs (vec).\n/// Automatically verifies that tensors reside on the first backend.\nmacro_rules! vec_op {\n    ($tensors:expr, $inner_kind:ident, |$inner:ident| $body:expr => $kind:ident) => {\n        backend_list!(vec_op_arms, $kind, $inner_kind, $tensors, |$inner| {\n            $body\n        })\n    };\n}\n\n/// Match arm generator for `transaction_op`.\nmacro_rules! transaction_op_arms {\n    ($tx:ident, $first:expr; $([$Backend:ident, $cfg:meta]),*) => {{\n        match &$first.kind {\n            // Autodiff arm first\n            #[cfg(feature = \"autodiff\")]\n            $crate::DispatchTensorKind::Autodiff(inner) => {\n                // Recursively dispatch on inner\n                match **inner {\n                    $(\n                    #[cfg($cfg)]\n                    $crate::DispatchTensorKind::$Backend(_) => {\n                        type B = $Backend<f32>;\n\n                        // Unwrap vec\n                        let floats = unwrap_vec!(@autodiff $Backend, $tx.read_floats, autodiff_inner);\n                        let ints = unwrap_vec!($Backend, $tx.read_ints, int);\n                        let bools = unwrap_vec!($Backend, $tx.read_bools, bool);\n                        // Not supported\n                        let qfloats = $tx.read_qfloats.into_iter().map(|_t| todo!(\"Quantization not supported yet\")).collect();\n\n                        B::tr_execute(TransactionPrimitive::new(floats, qfloats, ints, bools)).await\n                    }\n                )*\n                    $crate::DispatchTensorKind::Autodiff(..) => panic!(\"Autodiff should not wrap an autodiff tensor.\")\n                }\n            },\n\n            $(\n                #[cfg($cfg)]\n                $crate::DispatchTensorKind::$Backend(_) => {\n                    type B = $Backend<f32>;\n\n                    // Unwrap vec\n                    let floats = unwrap_vec!($Backend, $tx.read_floats, float);\n                    let ints = unwrap_vec!($Backend, $tx.read_ints, int);\n                    let bools = unwrap_vec!($Backend, $tx.read_bools, bool);\n                    // Not supported\n                    let qfloats = $tx.read_qfloats.into_iter().map(|_t| todo!(\"Quantization not supported yet\")).collect();\n\n                    B::tr_execute(TransactionPrimitive::new(floats, qfloats, ints, bools)).await\n                }\n            )*\n        }\n    }};\n}\n\n/// Helper to dispatch a transaction based on the first available tensor.\nmacro_rules! transaction_op {\n    ($tx:ident, $first:expr) => {\n        backend_list!(transaction_op_arms, $tx, $first)\n    };\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/activation.rs",
    "content": "use burn_backend::{Scalar, ops::ActivationOps, tensor::FloatTensor};\n\nuse crate::Dispatch;\nuse crate::backends::*;\n\nimpl ActivationOps<Self> for Dispatch {\n    fn leaky_relu(tensor: FloatTensor<Self>, negative_slope: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::leaky_relu(tensor, negative_slope) => Float)\n    }\n\n    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::relu(tensor) => Float)\n    }\n\n    fn relu_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((output, float), (grad, float), |output, grad| B::relu_backward(output, grad) => Float)\n    }\n\n    fn gelu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::gelu(tensor) => Float)\n    }\n\n    fn prelu(tensor: FloatTensor<Self>, alpha: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (alpha, float), |tensor, alpha| B::prelu(tensor, alpha) => Float)\n    }\n\n    fn gelu_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((x, float), (grad, float), |x, grad| B::gelu_backward(x, grad) => Float)\n    }\n\n    fn sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::sigmoid(tensor) => Float)\n    }\n\n    fn sigmoid_backward(output: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((output, float), (grad, float), |output, grad| B::sigmoid_backward(output, grad) => Float)\n    }\n\n    fn hard_sigmoid(tensor: FloatTensor<Self>, alpha: Scalar, beta: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::hard_sigmoid(tensor, alpha, beta) => Float)\n    }\n\n    fn log_sigmoid(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::log_sigmoid(tensor) => Float)\n    }\n\n    fn log_sigmoid_backward(x: FloatTensor<Self>, grad: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((x, float), (grad, float), |x, grad| B::log_sigmoid_backward(x, grad) => Float)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/bool_tensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, Scalar, TensorData,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, FloatTensor, IntTensor},\n};\nuse burn_std::{Shape, Slice};\n\nuse crate::backends::*;\nuse crate::{Dispatch, DispatchDevice};\n\nimpl BoolTensorOps<Self> for Dispatch {\n    fn bool_empty(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {\n        creation_op!(Bool, device, |device| B::bool_empty(shape, device))\n    }\n\n    fn bool_zeros(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {\n        creation_op!(Bool, device, |device| B::bool_zeros(shape, device))\n    }\n\n    fn bool_ones(shape: Shape, device: &DispatchDevice) -> BoolTensor<Self> {\n        creation_op!(Bool, device, |device| B::bool_ones(shape, device))\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unary_op!(tensor, bool, |tensor| B::bool_into_data(tensor).await)\n    }\n\n    fn bool_from_data(data: TensorData, device: &DispatchDevice) -> BoolTensor<Self> {\n        creation_op!(Bool, device, |device| B::bool_from_data(data, device))\n    }\n\n    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_into_int(tensor) => Int)\n    }\n\n    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_into_float(tensor) => Float)\n    }\n\n    fn bool_device(tensor: &BoolTensor<Self>) -> DispatchDevice {\n        tensor.device()\n    }\n\n    fn bool_to_device(tensor: BoolTensor<Self>, device: &DispatchDevice) -> BoolTensor<Self> {\n        to_device!(\n            Bool,\n            bool,\n            tensor,\n            device,\n            bool_to_device,\n            |inner, device| {\n                let data =\n                    burn_backend::read_sync(B1::bool_into_data(inner)).expect(\"Should read data\");\n                B2::bool_from_data(data, device)\n            }\n        )\n    }\n\n    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_reshape(tensor, shape) => Bool)\n    }\n\n    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_slice(tensor, slices) => Bool)\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        slices: &[Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        binary_op!((tensor, bool), (value, bool), |tensor, value| B::bool_slice_assign(tensor, slices, value) => Bool)\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        multi_op!(\n            inputs[(tensor, bool), (mask, bool), (value, bool)], => Bool,\n            B::bool_mask_where(tensor, mask, value)\n        )\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        binary_op!((tensor, bool), (mask, bool), |tensor, mask| B::bool_mask_fill(tensor, mask, value) => Bool)\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_gather(dim, tensor, indices) => Bool)\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        multi_op!(\n            inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,\n            B::bool_scatter_or(dim, tensor, indices, value)\n        )\n    }\n\n    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_equal(lhs, rhs) => Bool)\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, bool, |lhs| B::bool_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_not(tensor) => Bool)\n    }\n\n    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_and(lhs, rhs) => Bool)\n    }\n\n    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_or(lhs, rhs) => Bool)\n    }\n\n    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_swap_dims(tensor, dim1, dim2) => Bool)\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_permute(tensor, axes) => Bool)\n    }\n\n    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_flip(tensor, axes) => Bool)\n    }\n\n    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_expand(tensor, shape) => Bool)\n    }\n\n    fn bool_unfold(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_unfold(tensor, dim, size, step) => Bool)\n    }\n\n    fn bool_select(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        binary_op!((tensor, bool), (indices, int), |tensor, indices| B::bool_select(tensor, dim, indices) => Bool)\n    }\n\n    fn bool_select_or(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        multi_op!(\n            inputs[(tensor, bool), (indices, int), (value, bool)], => Bool,\n            B::bool_select_or(tensor, dim, indices, value)\n        )\n    }\n\n    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_repeat_dim(tensor, dim, times) => Bool)\n    }\n\n    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {\n        vec_op!(tensors, bool, |tensors| B::bool_cat(tensors, dim) => Bool)\n    }\n\n    fn bool_not_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_not_equal(lhs, rhs) => Bool)\n    }\n\n    fn bool_not_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, bool, |lhs| B::bool_not_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn bool_xor(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, bool), (rhs, bool), |lhs, rhs| B::bool_xor(lhs, rhs) => Bool)\n    }\n\n    fn bool_transpose(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_transpose(tensor) => Bool)\n    }\n\n    fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_any(tensor) => Bool)\n    }\n\n    fn bool_any_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_any_dim(tensor, dim) => Bool)\n    }\n\n    fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_all(tensor) => Bool)\n    }\n\n    fn bool_all_dim(tensor: BoolTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_all_dim(tensor, dim) => Bool)\n    }\n\n    async fn bool_argwhere(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, bool, |tensor| B::bool_argwhere(tensor).await => Int)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/int_tensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, Scalar, TensorData,\n    ops::IntTensorOps,\n    tensor::{BoolTensor, FloatTensor, IntTensor},\n};\nuse burn_std::{IntDType, Shape, Slice};\n\nuse crate::backends::*;\nuse crate::{Dispatch, DispatchDevice};\n\nimpl IntTensorOps<Self> for Dispatch {\n    fn int_empty(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_empty(shape, device, dtype))\n    }\n\n    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unary_op!(tensor, int, |tensor| B::int_into_data(tensor).await)\n    }\n\n    fn int_from_data(data: TensorData, device: &DispatchDevice) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_from_data(data, device))\n    }\n\n    fn int_device(tensor: &IntTensor<Self>) -> DispatchDevice {\n        tensor.device()\n    }\n\n    fn int_to_device(tensor: IntTensor<Self>, device: &DispatchDevice) -> IntTensor<Self> {\n        to_device!(Int, int, tensor, device, int_to_device, |inner, device| {\n            let data = burn_backend::read_sync(B1::int_into_data(inner)).expect(\"Should read data\");\n            B2::int_from_data(data, device)\n        })\n    }\n\n    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_reshape(tensor, shape) => Int)\n    }\n\n    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_slice(tensor, slices) => Int)\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<Self>,\n        slices: &[Slice],\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        binary_op!((tensor, int), (value, int), |tensor, value| B::int_slice_assign(tensor, slices, value) => Int)\n    }\n\n    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_into_float(tensor) => Float)\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        multi_op!(\n            inputs[(tensor, int), (mask, bool), (value, int)], => Int,\n            B::int_mask_where(tensor, mask, value)\n        )\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> IntTensor<Self> {\n        binary_op!((tensor, int), (mask, bool), |tensor, mask| B::int_mask_fill(tensor, mask, value) => Int)\n    }\n\n    fn int_gather(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_gather(dim, tensor, indices) => Int)\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        multi_op!(\n            inputs[(tensor, int), (indices, int), (value, int)], => Int,\n            B::int_scatter_add(dim, tensor, indices, value)\n        )\n    }\n\n    fn int_select(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        binary_op!((tensor, int), (indices, int), |tensor, indices| B::int_select(tensor, dim, indices) => Int)\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        multi_op!(\n            inputs[(tensor, int), (indices, int), (value, int)], => Int,\n            B::int_select_add(tensor, dim, indices, value)\n        )\n    }\n\n    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_equal(lhs, rhs) => Bool)\n    }\n\n    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater(lhs, rhs) => Bool)\n    }\n\n    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_greater_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_greater_equal(lhs, rhs) => Bool)\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_greater_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower(lhs, rhs) => Bool)\n    }\n\n    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_lower_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_lower_equal(lhs, rhs) => Bool)\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_lower_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_add(lhs, rhs) => Int)\n    }\n\n    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_add_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_sub(lhs, rhs) => Int)\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_sub_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_mul(lhs, rhs) => Int)\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_mul_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_div(lhs, rhs) => Int)\n    }\n\n    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_div_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_remainder(lhs, rhs) => Int)\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_remainder_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_matmul(lhs, rhs) => Int)\n    }\n\n    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_sum(tensor) => Int)\n    }\n\n    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_sum_dim(tensor, dim) => Int)\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_prod(tensor) => Int)\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_prod_dim(tensor, dim) => Int)\n    }\n\n    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_mean_dim(tensor, dim) => Int)\n    }\n\n    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_cumsum(tensor, dim) => Int)\n    }\n\n    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_cumprod(tensor, dim) => Int)\n    }\n\n    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_cummin(tensor, dim) => Int)\n    }\n\n    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_cummax(tensor, dim) => Int)\n    }\n\n    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_argmax(tensor, dim) => Int)\n    }\n\n    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_argmin(tensor, dim) => Int)\n    }\n\n    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_abs(tensor) => Int)\n    }\n\n    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_swap_dims(tensor, dim1, dim2) => Int)\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_permute(tensor, axes) => Int)\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_flip(tensor, axes) => Int)\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: burn_backend::Distribution,\n        device: &DispatchDevice,\n    ) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| {\n            B::int_random(shape, distribution, device)\n        })\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_expand(tensor, shape) => Int)\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_and(lhs, rhs) => Int)\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::bitwise_and_scalar(lhs, rhs) => Int)\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_or(lhs, rhs) => Int)\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::bitwise_or_scalar(lhs, rhs) => Int)\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_xor(lhs, rhs) => Int)\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::bitwise_xor_scalar(lhs, rhs) => Int)\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::bitwise_not(tensor) => Int)\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_left_shift(lhs, rhs) => Int)\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::bitwise_left_shift_scalar(lhs, rhs) => Int)\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::bitwise_right_shift(lhs, rhs) => Int)\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::bitwise_right_shift_scalar(lhs, rhs) => Int)\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_cast(tensor, dtype) => Int)\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_unfold(tensor, dim, size, step) => Int)\n    }\n\n    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_repeat_dim(tensor, dim, times) => Int)\n    }\n\n    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {\n        vec_op!(tensors, int, |tensors| B::int_cat(tensors, dim) => Int)\n    }\n\n    fn int_not_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_not_equal(lhs, rhs) => Bool)\n    }\n\n    fn int_not_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_not_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn int_powi(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_op!((lhs, int), (rhs, int), |lhs, rhs| B::int_powi(lhs, rhs) => Int)\n    }\n\n    fn int_powi_scalar_impl(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        unary_op!(lhs, int, |lhs| B::int_powi_scalar_impl(lhs, rhs) => Int)\n    }\n\n    fn int_clamp_min(tensor: IntTensor<Self>, min: Scalar) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_clamp_min(tensor, min) => Int)\n    }\n\n    fn int_clamp_max(tensor: IntTensor<Self>, max: Scalar) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_clamp_max(tensor, max) => Int)\n    }\n\n    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_clamp(tensor, min, max) => Int)\n    }\n\n    fn int_neg(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_neg(tensor) => Int)\n    }\n\n    fn int_zeros(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_zeros(shape, device, dtype))\n    }\n\n    fn int_ones(shape: Shape, device: &DispatchDevice, dtype: IntDType) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_ones(shape, device, dtype))\n    }\n\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &DispatchDevice,\n        dtype: IntDType,\n    ) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_full(\n            shape, fill_value, device, dtype\n        ))\n    }\n\n    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_mean(tensor) => Int)\n    }\n\n    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_max(tensor) => Int)\n    }\n\n    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_max_dim(tensor, dim) => Int)\n    }\n\n    fn int_max_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, int)],\n            outputs[(out, Int), (indices, Int)],\n            B::int_max_dim_with_indices(tensor, dim)\n        )\n    }\n\n    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_max_abs(tensor) => Int)\n    }\n\n    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_max_abs_dim(tensor, dim) => Int)\n    }\n\n    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_min(tensor) => Int)\n    }\n\n    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_min_dim(tensor, dim) => Int)\n    }\n\n    fn int_min_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, int)],\n            outputs[(out, Int), (indices, Int)],\n            B::int_min_dim_with_indices(tensor, dim)\n        )\n    }\n\n    fn int_transpose(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_transpose(tensor) => Int)\n    }\n\n    fn int_arange_step(\n        range: std::ops::Range<i64>,\n        step: usize,\n        device: &DispatchDevice,\n    ) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_arange_step(\n            range, step, device\n        ))\n    }\n\n    fn int_arange(range: std::ops::Range<i64>, device: &DispatchDevice) -> IntTensor<Self> {\n        creation_op!(Int, device, |device| B::int_arange(range, device))\n    }\n\n    fn int_any(tensor: IntTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_any(tensor) => Bool)\n    }\n\n    fn int_any_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_any_dim(tensor, dim) => Bool)\n    }\n\n    fn int_all(tensor: IntTensor<Self>) -> BoolTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_all(tensor) => Bool)\n    }\n\n    fn int_all_dim(tensor: IntTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_all_dim(tensor, dim) => Bool)\n    }\n\n    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_sign(tensor) => Int)\n    }\n\n    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_sort(tensor, dim, descending) => Int)\n    }\n\n    fn int_sort_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        descending: bool,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, int)],\n            outputs[(out, Int), (indices, Int)],\n            B::int_sort_with_indices(tensor, dim, descending)\n        )\n    }\n\n    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        unary_op!(tensor, int, |tensor| B::int_argsort(tensor, dim, descending) => Int)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/mod.rs",
    "content": "mod activation;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/module.rs",
    "content": "use burn_backend::{\n    ops::{\n        DeformConv2dBackward, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,\n        MaxPool2dWithIndices, ModuleOps,\n    },\n    tensor::{FloatTensor, IntTensor},\n};\n\nuse crate::Dispatch;\nuse crate::backends::*;\n\nimpl ModuleOps<Self> for Dispatch {\n    fn conv2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv2d(x, weight, bias, options)\n        )\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (offset, float), (weight, float)],\n            opt_inputs[(mask, float), (bias, float)],\n            => Float,\n            B::deform_conv2d(x, offset, weight, mask, bias, options)\n        )\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        let (x_grad, offset_grad, weight_grad, mask_grad, bias_grad) = multi_op!(\n            inputs[(x, float), (offset, float), (weight, float), (output_grad, float)],\n            opt_inputs[(mask, float), (bias, float)],\n            outputs[(x_grad, Float), (offset_grad, Float), (weight_grad, Float)],\n            opt_outputs[mask_grad, bias_grad],\n            {\n                let res = B::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options);\n                (res.x_grad, res.offset_grad, res.weight_grad, res.mask_grad, res.bias_grad)\n            }\n        );\n        DeformConv2dBackward::new(x_grad, offset_grad, weight_grad, mask_grad, bias_grad)\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv3d(x, weight, bias, options)\n        )\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv_transpose2d(x, weight, bias, options)\n        )\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv_transpose3d(x, weight, bias, options)\n        )\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(inputs[(x, float)],\n            => Float,\n            B::avg_pool2d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)\n        )\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (grad, float)],\n            => Float,\n            B::avg_pool2d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)\n        )\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float)],\n            => Float,\n            B::adaptive_avg_pool2d(x, output_size)\n        )\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (grad, float)],\n            => Float,\n            B::adaptive_avg_pool2d_backward(x, grad)\n        )\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float)],\n            => Float,\n            B::max_pool2d(x, kernel_size, stride, padding, dilation, ceil_mode)\n        )\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        let (out, indices) = multi_op!(\n            inputs[(x, float)],\n            outputs[(out, Float), (indices, Int)],\n            {\n                let res = B::max_pool2d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);\n                (res.output, res.indices)\n            }\n        );\n        MaxPool2dWithIndices::new(out, indices)\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool2dBackward<Self> {\n        let x_grad = multi_op!(\n            inputs[(x, float), (output_grad, float), (indices, int)],\n            => Float,\n            {\n                let res = B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);\n                res.x_grad\n            }\n        );\n        MaxPool2dBackward::new(x_grad)\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: burn_backend::ops::InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float)],\n            => Float,\n            B::interpolate(x, output_size, options)\n        )\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: burn_backend::ops::InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (grad, float)],\n            => Float,\n            B::interpolate_backward(x, grad, output_size, options)\n        )\n    }\n\n    fn embedding(weights: FloatTensor<Self>, indices: IntTensor<Self>) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(weights, float), (indices, int)],\n            => Float,\n            B::embedding(weights, indices)\n        )\n    }\n\n    fn embedding_backward(\n        weights: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(weights, float), (output_grad, float), (indices, int)],\n            => Float,\n            B::embedding_backward(weights, output_grad, indices)\n        )\n    }\n\n    fn conv1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv1d(x, weight, bias, options)\n        )\n    }\n\n    fn conv1d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv1d_x_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv1d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv1d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv1d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv1d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn conv2d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv2d_x_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv2d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv2d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv2d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv2d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn conv3d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv3d_x_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv3d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv3d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv3d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv3d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn conv_transpose1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float)],\n            opt_inputs[(bias, float)],\n            => Float,\n            B::conv_transpose1d(x, weight, bias, options)\n        )\n    }\n\n    fn conv_transpose1d_x_backward(\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose1d_x_backward(weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose1d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose1d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose1d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose1d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn conv_transpose2d_x_backward(\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose2d_x_backward(weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose2d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose2d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose2d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose2d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn conv_transpose3d_x_backward(\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose3d_x_backward(weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose3d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: burn_backend::ops::ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (weight, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose3d_weight_backward(x, weight, output_grad, options)\n        )\n    }\n\n    fn conv_transpose3d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (bias, float), (output_grad, float)],\n            => Float,\n            B::conv_transpose3d_bias_backward(x, bias, output_grad)\n        )\n    }\n\n    fn unfold4d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        options: burn_backend::ops::UnfoldOptions,\n    ) -> FloatTensor<Self> {\n        multi_op!(inputs[(x, float)], => Float, B::unfold4d(x, kernel_size, options))\n    }\n\n    fn avg_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(inputs[(x, float)], => Float,\n            B::avg_pool1d(x, kernel_size, stride, padding, count_include_pad, ceil_mode)\n        )\n    }\n\n    fn avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (grad, float)],\n            => Float,\n            B::avg_pool1d_backward(x, grad, kernel_size, stride, padding, count_include_pad, ceil_mode)\n        )\n    }\n\n    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {\n        multi_op!(inputs[(x, float)], => Float, B::adaptive_avg_pool1d(x, output_size))\n    }\n\n    fn adaptive_avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(x, float), (grad, float)],\n            => Float,\n            B::adaptive_avg_pool1d_backward(x, grad)\n        )\n    }\n\n    fn max_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        multi_op!(inputs[(x, float)], => Float,\n            B::max_pool1d(x, kernel_size, stride, padding, dilation, ceil_mode))\n    }\n\n    fn max_pool1d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<Self> {\n        let (out, indices) = multi_op!(\n            inputs[(x, float)],\n            outputs[(out, Float), (indices, Int)],\n            {\n                let res = B::max_pool1d_with_indices(x, kernel_size, stride, padding, dilation, ceil_mode);\n                (res.output, res.indices)\n            }\n        );\n        MaxPool1dWithIndices::new(out, indices)\n    }\n\n    fn max_pool1d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool1dBackward<Self> {\n        let x_grad = multi_op!(\n            inputs[(x, float), (output_grad, float), (indices, int)],\n            => Float,\n            {\n                let res = B::max_pool1d_with_indices_backward(x, kernel_size, stride, padding, dilation, ceil_mode, output_grad, indices);\n                res.x_grad\n            }\n        );\n        MaxPool1dBackward::new(x_grad)\n    }\n\n    fn attention(\n        query: FloatTensor<Self>,\n        key: FloatTensor<Self>,\n        value: FloatTensor<Self>,\n        mask: Option<burn_backend::tensor::BoolTensor<Self>>,\n        attn_bias: Option<FloatTensor<Self>>,\n        options: burn_backend::ops::AttentionModuleOptions,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(query, float), (key, float), (value, float)],\n            opt_inputs[(mask, bool), (attn_bias, float)],\n            => Float,\n            B::attention(query, key, value, mask, attn_bias, options)\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, QTensorPrimitive, TensorData, TensorPrimitive,\n    ops::QTensorOps,\n    quantization::QuantizationParametersPrimitive,\n    tensor::{FloatTensor, IntTensor, QuantizedTensor},\n};\nuse burn_std::{QuantPropagation, Shape, Slice};\n\nuse crate::backends::*;\nuse crate::{Dispatch, DispatchDevice};\n\nimpl QTensorOps<Self> for Dispatch {\n    fn q_from_data(data: TensorData, device: &DispatchDevice) -> QuantizedTensor<Self> {\n        creation_op!(Quantized, device, |device| B::q_from_data(data, device))\n    }\n\n    fn quantize(\n        tensor: FloatTensor<Self>,\n        scheme: &burn_std::QuantScheme,\n        qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        binary_op!(\n            (tensor, float),\n            (qparams.scales, float),\n            |tensor, scales| {\n                B::quantize(tensor, scheme, QuantizationParametersPrimitive { scales })\n            } => Quantized\n        )\n    }\n\n    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::dequantize(tensor) => Float)\n    }\n\n    fn q_device(tensor: &QuantizedTensor<Self>) -> DispatchDevice {\n        tensor.device()\n    }\n\n    fn q_to_device(\n        tensor: QuantizedTensor<Self>,\n        device: &DispatchDevice,\n    ) -> QuantizedTensor<Self> {\n        to_device!(\n            Quantized,\n            quantized,\n            tensor,\n            device,\n            q_to_device,\n            |inner, device| {\n                let data =\n                    burn_backend::read_sync(B1::q_into_data(inner)).expect(\"Should read data\");\n                B2::q_from_data(data, device)\n            }\n        )\n    }\n\n    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_reshape(tensor, shape) => Quantized)\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unary_op!(tensor, quantized, |tensor| B::q_into_data(tensor).await)\n    }\n\n    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_expand(tensor, shape) => Quantized)\n    }\n\n    fn q_swap_dims(\n        tensor: QuantizedTensor<Self>,\n        dim1: usize,\n        dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_swap_dims(tensor, dim1, dim2) => Quantized)\n    }\n\n    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_permute(tensor, axes) => Quantized)\n    }\n\n    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_flip(tensor, axes) => Quantized)\n    }\n\n    fn q_select(\n        tensor: QuantizedTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        binary_op!(\n            (tensor, quantized),\n            (indices, int),\n            |tensor, indices| B::q_select(tensor, dim, indices) => Quantized\n        )\n    }\n\n    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {\n        unary_op!(tensor, quantized, |tensor| B::q_slice(tensor, slices) => Quantized)\n    }\n\n    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {\n        // TODO: this would be much cleaner if we consolidated tensor primitive types\n        match (lhs, rhs) {\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {\n                if matches!(lhs.propagation(), QuantPropagation::Propagate) {\n                    let out = binary_op!(\n                        (lhs, quantized),\n                        (rhs, quantized),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::QFloat(out) = B::q_matmul(\n                                TensorPrimitive::QFloat(lhs),\n                                TensorPrimitive::QFloat(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Quantized\n                    );\n                    TensorPrimitive::QFloat(out)\n                } else {\n                    let out = binary_op!(\n                        (lhs, quantized),\n                        (rhs, quantized),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::Float(out) = B::q_matmul(\n                                TensorPrimitive::QFloat(lhs),\n                                TensorPrimitive::QFloat(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Float\n                    );\n                    TensorPrimitive::Float(out)\n                }\n            }\n            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {\n                if matches!(rhs.propagation(), QuantPropagation::Propagate) {\n                    let out = binary_op!(\n                        (lhs, float),\n                        (rhs, quantized),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::QFloat(out) = B::q_matmul(\n                                TensorPrimitive::Float(lhs),\n                                TensorPrimitive::QFloat(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Quantized\n                    );\n                    TensorPrimitive::QFloat(out)\n                } else {\n                    let out = binary_op!(\n                        (lhs, float),\n                        (rhs, quantized),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::Float(out) = B::q_matmul(\n                                TensorPrimitive::Float(lhs),\n                                TensorPrimitive::QFloat(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Float\n                    );\n                    TensorPrimitive::Float(out)\n                }\n            }\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {\n                if matches!(lhs.propagation(), QuantPropagation::Propagate) {\n                    let out = binary_op!(\n                        (lhs, quantized),\n                        (rhs, float),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::QFloat(out) = B::q_matmul(\n                                TensorPrimitive::QFloat(lhs),\n                                TensorPrimitive::Float(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Quantized\n                    );\n                    TensorPrimitive::QFloat(out)\n                } else {\n                    let out = binary_op!(\n                        (lhs, quantized),\n                        (rhs, float),\n                        |lhs, rhs| {\n                            if let TensorPrimitive::Float(out) = B::q_matmul(\n                                TensorPrimitive::QFloat(lhs),\n                                TensorPrimitive::Float(rhs),\n                            ) {\n                                out\n                            } else {\n                                unreachable!()\n                            }\n                        } => Float\n                    );\n                    TensorPrimitive::Float(out)\n                }\n            }\n            _ => unreachable!(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/tensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, Scalar, TensorData,\n    ops::FloatTensorOps,\n    tensor::{BoolTensor, FloatTensor, IntTensor},\n};\nuse burn_std::{FloatDType, Shape, Slice};\n\nuse crate::backends::*;\nuse crate::{Dispatch, DispatchDevice};\n\n// TODO: remove backend default elem type genericsnow that we have per-device defaults\n// https://github.com/tracel-ai/burn/issues/3642\n\nimpl FloatTensorOps<Self> for Dispatch {\n    fn float_from_data(\n        data: burn_backend::TensorData,\n        device: &DispatchDevice,\n    ) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| B::float_from_data(data, device))\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: burn_backend::Distribution,\n        device: &DispatchDevice,\n    ) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| {\n            B::float_random(shape, distribution, device)\n        })\n    }\n\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unary_float!(tensor, float, |tensor| B::float_into_data(tensor).await)\n    }\n\n    fn float_device(tensor: &FloatTensor<Self>) -> DispatchDevice {\n        tensor.device()\n    }\n\n    fn float_to_device(tensor: FloatTensor<Self>, device: &DispatchDevice) -> FloatTensor<Self> {\n        float_to_device!(\n            Float,\n            float,\n            tensor,\n            device,\n            float_to_device,\n            |inner, device| {\n                let data =\n                    burn_backend::read_sync(B1::float_into_data(inner)).expect(\"Should read data\");\n                B2::float_from_data(data, device)\n            }\n        )\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_into_int(tensor) => Int)\n    }\n\n    fn float_empty(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| B::float_empty(shape, device, dtype))\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_add(lhs, rhs) => Float)\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_add_scalar(lhs, rhs) => Float)\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_sub(lhs, rhs) => Float)\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_sub_scalar(lhs, rhs) => Float)\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_mul(lhs, rhs) => Float)\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_mul_scalar(lhs, rhs) => Float)\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_div(lhs, rhs) => Float)\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_div_scalar(lhs, rhs) => Float)\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_remainder(lhs, rhs) => Float)\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_remainder_scalar(lhs, rhs) => Float)\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_matmul(lhs, rhs) => Float)\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_cross(lhs, rhs, dim) => Float)\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_recip(tensor) => Float)\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_swap_dims(tensor, dim1, dim2) => Float)\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_permute(tensor, axes) => Float)\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_flip(tensor, axes) => Float)\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_reshape(tensor, shape) => Float)\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_gather(dim, tensor, indices) => Float)\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(tensor, float), (indices, int), (value, float)], => Float,\n            B::float_scatter_add(dim, tensor, indices, value)\n        )\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (indices, int), |tensor, indices| B::float_select(tensor, dim, indices) => Float)\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(tensor, float), (indices, int), (value, float)], => Float,\n            B::float_select_add(tensor, dim, indices, value)\n        )\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_slice(tensor, slices) => Float)\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (value, float), |tensor, value| B::float_slice_assign(tensor, slices, value) => Float)\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        multi_op!(\n            inputs[(tensor, float), (mask, bool), (value, float)], => Float,\n            B::float_mask_where(tensor, mask, value)\n        )\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (mask, bool), |tensor, mask| B::float_mask_fill(tensor, mask, value) => Float)\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_equal(lhs, rhs) => Bool)\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater(lhs, rhs) => Bool)\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_greater_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_greater_equal(lhs, rhs) => Bool)\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_greater_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower(lhs, rhs) => Bool)\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_lower_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_lower_equal(lhs, rhs) => Bool)\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_lower_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sum(tensor) => Float)\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sum_dim(tensor, dim) => Float)\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_mean_dim(tensor, dim) => Float)\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cumsum(tensor, dim) => Float)\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cumprod(tensor, dim) => Float)\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cummin(tensor, dim) => Float)\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cummax(tensor, dim) => Float)\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cast(tensor, dtype) => Float)\n    }\n\n    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_exp(tensor) => Float)\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_log(tensor) => Float)\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_log1p(tensor) => Float)\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_powf(lhs, rhs) => Float)\n    }\n\n    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_powf_scalar_impl(tensor, value) => Float)\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sqrt(tensor) => Float)\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_abs(tensor) => Float)\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cos(tensor) => Float)\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sin(tensor) => Float)\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_tan(tensor) => Float)\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_cosh(tensor) => Float)\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sinh(tensor) => Float)\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_tanh(tensor) => Float)\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_acos(tensor) => Float)\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_acosh(tensor) => Float)\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_asin(tensor) => Float)\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_asinh(tensor) => Float)\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_atan(tensor) => Float)\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_atanh(tensor) => Float)\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_atan2(lhs, rhs) => Float)\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_round(tensor) => Float)\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_floor(tensor) => Float)\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_ceil(tensor) => Float)\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_trunc(tensor) => Float)\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_erf(tensor) => Float)\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_argmax(tensor, dim) => Int)\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_argmin(tensor, dim) => Int)\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_expand(tensor, shape) => Float)\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| {\n            B::float_unfold(tensor, dim, size, step)\n        } => Float)\n    }\n\n    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_detach(tensor) => Float)\n    }\n\n    fn float_set_require_grad(tensor: FloatTensor<Self>, require_grad: bool) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_set_require_grad(tensor, require_grad) => Float)\n    }\n\n    fn float_is_require_grad(tensor: &FloatTensor<Self>) -> bool {\n        unary_float!(ref tensor, float, |tensor| B::float_is_require_grad(tensor))\n    }\n\n    // Default implementation\n    fn float_zeros(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| B::float_zeros(shape, device, dtype))\n    }\n\n    fn float_ones(shape: Shape, device: &DispatchDevice, dtype: FloatDType) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| B::float_ones(shape, device, dtype))\n    }\n\n    fn float_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &DispatchDevice,\n        dtype: FloatDType,\n    ) -> FloatTensor<Self> {\n        creation_op!(Float, device, |device| B::float_full(\n            shape, fill_value, device, dtype\n        ))\n    }\n\n    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_repeat_dim(tensor, dim, times) => Float)\n    }\n\n    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_clamp_min(tensor, min) => Float)\n    }\n\n    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_clamp_max(tensor, max) => Float)\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_clamp(tensor, min, max) => Float)\n    }\n\n    fn float_neg(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_neg(tensor) => Float)\n    }\n\n    fn float_transpose(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_transpose(tensor) => Float)\n    }\n\n    fn float_not_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float!((lhs, float), (rhs, float), |lhs, rhs| B::float_not_equal(lhs, rhs) => Bool)\n    }\n\n    fn float_not_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_not_equal_elem(lhs, rhs) => Bool)\n    }\n\n    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_prod(tensor) => Float)\n    }\n\n    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_prod_dim(tensor, dim) => Float)\n    }\n\n    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_mean(tensor) => Float)\n    }\n\n    fn float_powi(lhs: FloatTensor<Self>, rhs: IntTensor<Self>) -> FloatTensor<Self> {\n        binary_float!((lhs, float), (rhs, int), |lhs, rhs| B::float_powi(lhs, rhs) => Float)\n    }\n\n    fn float_powi_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        unary_float!(lhs, float, |lhs| B::float_powi_scalar_impl(lhs, rhs) => Float)\n    }\n\n    fn float_powf_scalar(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_powf_scalar(tensor, value) => Float)\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        vec_op!(tensors, float, |tensors| B::float_cat(tensors, dim) => Float)\n    }\n\n    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_max(tensor) => Float)\n    }\n\n    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_max_dim(tensor, dim) => Float)\n    }\n\n    fn float_max_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, float)],\n            outputs[(out, Float), (indices, Int)],\n            B::float_max_dim_with_indices(tensor, dim)\n        )\n    }\n\n    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_min(tensor) => Float)\n    }\n\n    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_min_dim(tensor, dim) => Float)\n    }\n\n    fn float_min_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, float)],\n            outputs[(out, Float), (indices, Int)],\n            B::float_min_dim_with_indices(tensor, dim)\n        )\n    }\n\n    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_max_abs(tensor) => Float)\n    }\n\n    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_max_abs_dim(tensor, dim) => Float)\n    }\n\n    fn float_any(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_any(tensor) => Bool)\n    }\n\n    fn float_any_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_any_dim(tensor, dim) => Bool)\n    }\n\n    fn float_all(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_all(tensor) => Bool)\n    }\n\n    fn float_all_dim(tensor: FloatTensor<Self>, dim: usize) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_all_dim(tensor, dim) => Bool)\n    }\n\n    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sign(tensor) => Float)\n    }\n\n    fn float_sort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> FloatTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_sort(tensor, dim, descending) => Float)\n    }\n\n    fn float_sort_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        descending: bool,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        multi_op!(\n            inputs[(tensor, float)],\n            outputs[(out, Float), (indices, Int)],\n            B::float_sort_with_indices(tensor, dim, descending)\n        )\n    }\n\n    fn float_argsort(tensor: FloatTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_argsort(tensor, dim, descending) => Int)\n    }\n\n    fn float_grid_sample_2d(\n        tensor: FloatTensor<Self>,\n        grid: FloatTensor<Self>,\n        options: burn_backend::ops::GridSampleOptions,\n    ) -> FloatTensor<Self> {\n        binary_float!((tensor, float), (grid, float), |tensor, grid| B::float_grid_sample_2d(tensor, grid, options) => Float)\n    }\n\n    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_is_nan(tensor) => Bool)\n    }\n\n    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        unary_float!(tensor, float, |tensor| B::float_is_inf(tensor) => Bool)\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/ops/transaction.rs",
    "content": "use burn_backend::{\n    ExecutionError,\n    ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData},\n};\n\nuse crate::Dispatch;\nuse crate::backends::*;\n\nimpl TransactionOps<Self> for Dispatch {\n    async fn tr_execute(\n        transaction: TransactionPrimitive<Self>,\n    ) -> Result<TransactionPrimitiveData, ExecutionError> {\n        let first_tensor = transaction\n            .read_floats\n            .first()\n            .or(transaction.read_ints.first())\n            .or(transaction.read_bools.first());\n\n        match first_tensor {\n            Some(tensor) => {\n                transaction_op!(transaction, tensor)\n            }\n            None => Ok(TransactionPrimitiveData::default()),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-dispatch/src/tensor.rs",
    "content": "use burn_backend::{Backend, QTensorPrimitive, TensorMetadata};\n\n#[cfg(feature = \"autodiff\")]\nuse crate::CheckpointingStrategy;\nuse crate::backends::*;\n\n#[cfg(feature = \"autodiff\")]\nuse burn_backend::tensor::FloatTensor;\n\n// TODO: if we reduce the different associated types for float/int/bool/quantized tensor primitives down to a single\n// `B::TensorPrimitive` we can simplify this.\n\n/// Tensor which points to a backend tensor primitive kind.\n#[derive(Clone, Debug)]\npub enum BackendTensor<B: Backend> {\n    /// Float tensor handle.\n    Float(B::FloatTensorPrimitive),\n    /// Int tensor handle.\n    Int(B::IntTensorPrimitive),\n    /// Bool tensor handle.\n    Bool(B::BoolTensorPrimitive),\n    /// Quantized tensor handle.\n    Quantized(B::QuantizedTensorPrimitive),\n    #[cfg(feature = \"autodiff\")]\n    /// Autodiff float tensor handle.\n    Autodiff(FloatTensor<Autodiff<B>>),\n}\n\nimpl<B: Backend> BackendTensor<B> {\n    /// Returns the inner float tensor primitive.\n    pub(crate) fn float(self) -> B::FloatTensorPrimitive {\n        match self {\n            BackendTensor::Float(tensor) => tensor,\n            BackendTensor::Int(_) => panic!(\"Should be float, got int\"),\n            BackendTensor::Bool(_) => panic!(\"Should be float, got bool\"),\n            BackendTensor::Quantized(_) => panic!(\"Should be float, got quantized\"),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(_) => panic!(\"Should be float, got autodiff\"),\n        }\n    }\n    /// Returns the inner float tensor primitive.\n    pub(crate) fn as_float(&self) -> &B::FloatTensorPrimitive {\n        match self {\n            BackendTensor::Float(tensor) => tensor,\n            BackendTensor::Int(_) => panic!(\"Should be float, got int\"),\n            BackendTensor::Bool(_) => panic!(\"Should be float, got bool\"),\n            BackendTensor::Quantized(_) => panic!(\"Should be float, got quantized\"),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(_) => panic!(\"Should be float, got autodiff\"),\n        }\n    }\n\n    /// Returns the inner int tensor primitive.\n    pub(crate) fn int(self) -> B::IntTensorPrimitive {\n        match self {\n            BackendTensor::Int(tensor) => tensor,\n            BackendTensor::Float(_) => panic!(\"Should be int, got float\"),\n            BackendTensor::Bool(_) => panic!(\"Should be int, got bool\"),\n            BackendTensor::Quantized(_) => panic!(\"Should be int, got quantized\"),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(_) => panic!(\"Should be int, got autodiff\"),\n        }\n    }\n\n    /// Returns the inner bool tensor primitive.\n    pub(crate) fn bool(self) -> B::BoolTensorPrimitive {\n        match self {\n            BackendTensor::Bool(tensor) => tensor,\n            BackendTensor::Float(_) => panic!(\"Should be bool, got float\"),\n            BackendTensor::Int(_) => panic!(\"Should be bool, got int\"),\n            BackendTensor::Quantized(_) => panic!(\"Should be bool, got quantized\"),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(_) => panic!(\"Should be bool, got autodiff\"),\n        }\n    }\n\n    /// Returns the inner quantized tensor primitive.\n    pub(crate) fn quantized(self) -> B::QuantizedTensorPrimitive {\n        match self {\n            BackendTensor::Quantized(tensor) => tensor,\n            _ => unreachable!(),\n        }\n    }\n\n    #[cfg(feature = \"autodiff\")]\n    /// Returns the inner autodiff tensor primitive.\n    pub(crate) fn autodiff(self) -> FloatTensor<Autodiff<B>> {\n        match self {\n            BackendTensor::Autodiff(tensor) => tensor,\n            // NOTE: this is the panicking code reached in tensor.rs:74:18:\n            _ => unreachable!(),\n        }\n    }\n\n    #[cfg(feature = \"autodiff\")]\n    /// Returns the inner autodiff tensor primitive.\n    pub(crate) fn as_autodiff(&self) -> &FloatTensor<Autodiff<B>> {\n        match self {\n            BackendTensor::Autodiff(tensor) => tensor,\n            _ => unreachable!(),\n        }\n    }\n\n    #[cfg(feature = \"autodiff\")]\n    /// Returns the inner autodiff tensor primitive.\n    pub(crate) fn autodiff_inner(self) -> B::FloatTensorPrimitive {\n        match self {\n            BackendTensor::Autodiff(tensor) => tensor.primitive,\n            _ => unreachable!(),\n        }\n    }\n\n    /// Returns the backend device.\n    pub(crate) fn device(&self) -> B::Device {\n        match self {\n            BackendTensor::Float(tensor) => B::float_device(tensor),\n            BackendTensor::Int(tensor) => B::int_device(tensor),\n            BackendTensor::Bool(tensor) => B::bool_device(tensor),\n            BackendTensor::Quantized(tensor) => B::q_device(tensor),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(tensor) => B::float_device(&tensor.primitive),\n        }\n    }\n}\n\nimpl<B: Backend> TensorMetadata for BackendTensor<B> {\n    fn dtype(&self) -> burn_std::DType {\n        match self {\n            BackendTensor::Float(tensor) => tensor.dtype(),\n            BackendTensor::Int(tensor) => tensor.dtype(),\n            BackendTensor::Bool(tensor) => tensor.dtype(),\n            BackendTensor::Quantized(tensor) => tensor.dtype(),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(tensor) => tensor.dtype(),\n        }\n    }\n\n    fn shape(&self) -> burn_std::Shape {\n        match self {\n            BackendTensor::Float(tensor) => tensor.shape(),\n            BackendTensor::Int(tensor) => tensor.shape(),\n            BackendTensor::Bool(tensor) => tensor.shape(),\n            BackendTensor::Quantized(tensor) => tensor.shape(),\n            #[cfg(feature = \"autodiff\")]\n            BackendTensor::Autodiff(tensor) => tensor.shape(),\n        }\n    }\n}\n\nimpl<B: Backend> QTensorPrimitive for BackendTensor<B> {\n    fn scheme(&self) -> &burn_std::QuantScheme {\n        match self {\n            BackendTensor::Quantized(tensor) => tensor.scheme(),\n            _ => panic!(\n                \"Quantization scheme is not valid for dtype {:?}\",\n                self.dtype(),\n            ),\n        }\n    }\n}\n\n/// A tensor that can dispatch operations to any enabled backend at runtime.\n///\n/// When the `autodiff` feature is enabled, tensors may carry a checkpointing\n/// strategy used to control gradient computation. This is derived from the\n/// device used to create the tensor.\n#[derive(Clone, Debug)]\npub struct DispatchTensor {\n    /// Tensor kind primitive.\n    pub(crate) kind: DispatchTensorKind,\n    // Technically more of a device property, but device is not a dispatch tensor field.\n    #[cfg(feature = \"autodiff\")]\n    pub(crate) checkpointing: CheckpointingStrategy,\n}\n\n/// Internal representation of a [`DispatchTensor`].\n///\n/// This enum contains the concrete backend tensor for each enabled backend.\n/// It is not intended to be used directly; instead, it is manipulated by\n/// the dispatch system to route operations to the correct backend.\n///\n/// Each variant corresponds to a specific backend implementation.\n#[derive(Clone, Debug)]\npub enum DispatchTensorKind {\n    /// The [CPU backend](Cpu) tensor.\n    #[cfg(feature = \"cpu\")]\n    Cpu(BackendTensor<Cpu>),\n\n    /// The [CUDA backend](Cuda) tensor.\n    #[cfg(feature = \"cuda\")]\n    Cuda(BackendTensor<Cuda>),\n\n    /// The [Metal backend](Metal) tensor.\n    #[cfg(wgpu_metal)]\n    Metal(BackendTensor<Metal>),\n\n    /// The [ROCm backend](Rocm) tensor.\n    #[cfg(feature = \"rocm\")]\n    Rocm(BackendTensor<Rocm>),\n\n    /// The [Vulkan backend](Vulkan) tensor.\n    #[cfg(wgpu_vulkan)]\n    Vulkan(BackendTensor<Vulkan>),\n\n    /// The [WebGPU backend](WebGpu) tensor.\n    #[cfg(wgpu_webgpu)]\n    WebGpu(BackendTensor<WebGpu>),\n\n    /// The [NdArray backend](NdArray) tensor.\n    #[cfg(feature = \"ndarray\")]\n    NdArray(BackendTensor<NdArray>),\n\n    /// The [LibTorch backend](LibTorch) tensor.\n    #[cfg(feature = \"tch\")]\n    LibTorch(BackendTensor<LibTorch>),\n\n    /// The [autodiff enabled backend](Autodiff) tensor.\n    #[cfg(feature = \"autodiff\")]\n    Autodiff(Box<DispatchTensorKind>),\n}\n\nimpl TensorMetadata for DispatchTensorKind {\n    fn dtype(&self) -> burn_std::DType {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(tensor) => tensor.dtype(),\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(tensor) => tensor.dtype(),\n            #[cfg(wgpu_metal)]\n            Self::Metal(tensor) => tensor.dtype(),\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(tensor) => tensor.dtype(),\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(tensor) => tensor.dtype(),\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(tensor) => tensor.dtype(),\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(tensor) => tensor.dtype(),\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(tensor) => tensor.dtype(),\n            #[cfg(feature = \"autodiff\")]\n            Self::Autodiff(tensor) => tensor.dtype(),\n        }\n    }\n\n    fn shape(&self) -> burn_std::Shape {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(tensor) => tensor.shape(),\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(tensor) => tensor.shape(),\n            #[cfg(wgpu_metal)]\n            Self::Metal(tensor) => tensor.shape(),\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(tensor) => tensor.shape(),\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(tensor) => tensor.shape(),\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(tensor) => tensor.shape(),\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(tensor) => tensor.shape(),\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(tensor) => tensor.shape(),\n            #[cfg(feature = \"autodiff\")]\n            Self::Autodiff(tensor) => tensor.shape(),\n        }\n    }\n}\n\nimpl QTensorPrimitive for DispatchTensorKind {\n    fn scheme(&self) -> &burn_std::QuantScheme {\n        match self {\n            #[cfg(feature = \"cpu\")]\n            Self::Cpu(tensor) => tensor.scheme(),\n            #[cfg(feature = \"cuda\")]\n            Self::Cuda(tensor) => tensor.scheme(),\n            #[cfg(wgpu_metal)]\n            Self::Metal(tensor) => tensor.scheme(),\n            #[cfg(feature = \"rocm\")]\n            Self::Rocm(tensor) => tensor.scheme(),\n            #[cfg(wgpu_vulkan)]\n            Self::Vulkan(tensor) => tensor.scheme(),\n            #[cfg(wgpu_webgpu)]\n            Self::WebGpu(tensor) => tensor.scheme(),\n            #[cfg(feature = \"ndarray\")]\n            Self::NdArray(tensor) => tensor.scheme(),\n            #[cfg(feature = \"tch\")]\n            Self::LibTorch(tensor) => tensor.scheme(),\n            #[cfg(feature = \"autodiff\")]\n            Self::Autodiff(tensor) => tensor.scheme(),\n        }\n    }\n}\n\nimpl TensorMetadata for DispatchTensor {\n    fn dtype(&self) -> burn_std::DType {\n        self.kind.dtype()\n    }\n\n    fn shape(&self) -> burn_std::Shape {\n        self.kind.shape()\n    }\n}\n\nimpl QTensorPrimitive for DispatchTensor {\n    fn scheme(&self) -> &burn_std::QuantScheme {\n        self.kind.scheme()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Kernel fusion backend decorator for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-fusion\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-fusion\"\ndocumentation = \"https://docs.rs/burn-fusion\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\", \"tracing\"]\nstd = [\"serde/std\", \"tracing?/std\"]\ndoc = [\"default\"]\nmemory-checks = [\"std\"]\n\ntracing = [\n    \"dep:tracing\",\n    \"burn-backend/tracing\",\n    \"burn-ir/tracing\",\n]\n\n[dependencies]\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\" }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\" }\ntracing = { workspace = true, optional = true, features = [\"attributes\"] }\n\nhashbrown = { workspace = true }\nderive-new = { workspace = true }\nspin = { workspace = true }\nlog = { workspace = true }\nserde = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-fusion/README.md",
    "content": "# Burn Fusion\n\nA kernel fusion backend decorator for Burn.\n"
  },
  {
    "path": "crates/burn-fusion/src/backend.rs",
    "content": "use crate::{\n    FusionTensor,\n    client::GlobalFusionClient,\n    stream::{Context, OrderedExecution},\n};\nuse burn_backend::{\n    Backend, DType, DeviceOps, ExecutionError,\n    tensor::{BoolTensor, Device, FloatTensor, IntTensor, QuantizedTensor},\n};\nuse burn_ir::{BackendIr, OperationIr, TensorHandle};\nuse serde::{Serialize, de::DeserializeOwned};\nuse std::marker::PhantomData;\n\n/// Get the client for the given device.\npub fn get_client<B: FusionBackend>(device: &Device<B>) -> Client<B::FusionRuntime> {\n    GlobalFusionClient::load(device)\n}\n\n/// Enable dynamic operation fusion on a backend that implements [fusion backend](crate::FusionBackend).\n#[derive(Clone, Debug, Default)]\npub struct Fusion<B: FusionBackend> {\n    _backend: PhantomData<B>,\n}\n\nimpl<B: FusionBackend> Backend for Fusion<B> {\n    type Device = B::Device;\n\n    type FloatTensorPrimitive = FusionTensor<B::FusionRuntime>;\n\n    type FloatElem = B::FloatElem;\n\n    type IntTensorPrimitive = FusionTensor<B::FusionRuntime>;\n\n    type IntElem = B::IntElem;\n\n    type BoolTensorPrimitive = FusionTensor<B::FusionRuntime>;\n\n    type BoolElem = B::BoolElem;\n\n    type QuantizedTensorPrimitive = FusionTensor<B::FusionRuntime>;\n\n    fn name(device: &Self::Device) -> String {\n        format!(\"fusion<{}>\", B::name(device))\n    }\n\n    fn seed(device: &B::Device, seed: u64) {\n        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);\n        client.drain();\n        B::seed(device, seed);\n    }\n\n    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {\n        let client = GlobalFusionClient::<B::FusionRuntime>::load(device);\n        client.drain();\n        B::sync(device)\n    }\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    fn memory_persistent_allocations<\n        Output: Send,\n        Input: Send,\n        Func: Fn(Input) -> Output + Send,\n    >(\n        device: &Self::Device,\n        input: Input,\n        func: Func,\n    ) -> Output {\n        B::memory_persistent_allocations(device, input, func)\n    }\n\n    fn memory_cleanup(device: &Self::Device) {\n        B::memory_cleanup(device)\n    }\n\n    fn staging<'a, Iter>(data: Iter, device: &Self::Device)\n    where\n        Iter: Iterator<Item = &'a mut burn_backend::TensorData>,\n    {\n        B::staging(data, device);\n    }\n\n    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool {\n        B::supports_dtype(device, dtype)\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {\n        B::dtype_usage(device, dtype)\n    }\n}\n\n/// The status of a [fuser](OperationFuser).\n#[derive(Clone, Debug, Copy, PartialEq, Eq)]\npub enum FuserStatus {\n    /// No more operations can be fused.\n    Closed,\n    /// More operations can be fused.\n    Open,\n}\n\n/// The properties of a [fuser](OperationFuser).\n#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]\npub struct FuserProperties {\n    /// The score of the optimization, higher is better.\n    pub score: u64,\n    /// If the operation is ready to be executed.\n    pub ready: bool,\n}\n\n/// The fusion operation abstraction allows implementations to fuse many\n/// [tensor operations](OperationIr) into one, improving the performance of the backend.\n///\n///\n/// # Notes\n///\n/// The implementations are free to execute the registered operations the way they want to improve\n/// the speed and efficiency of the computational graph. It doesn't mean that all registered\n/// operations should be fused, but that another way of executing them is more efficient.\n///\n/// Also, it is important to return (FuserStatus::Closed) when no more registered operation can\n/// improve the performance.\npub trait OperationFuser<O>: Send {\n    /// Register a new [tensor operation](OperationIr).\n    fn fuse(&mut self, operation: &OperationIr);\n    /// Finish the optimization and create a fusion operation.\n    fn finish(&mut self) -> O;\n    /// Reset the state.\n    fn reset(&mut self);\n    /// Return the builder [status](FuserStatus).\n    fn status(&self) -> FuserStatus;\n    /// Return the builder [properties](FuserProperties).\n    fn properties(&self) -> FuserProperties;\n    /// The number of operation fused.\n    fn len(&self) -> usize;\n    /// If no operations are fused.\n    fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n    /// Clone the optimization builder.\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<O>>;\n}\n\n/// The number of operations contained in the data structure.\npub trait NumOperations: core::fmt::Debug {\n    /// The number of registered operations.\n    fn len(&self) -> usize;\n    /// If the current optimization is empty.\n    fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n}\n\n/// The optimization created from a [fuser](OperationFuser).\npub trait Optimization<R: FusionRuntime>: Send + NumOperations {\n    /// Execute the optimization.\n    fn execute(\n        &mut self,\n        context: &mut Context<'_, R::FusionHandle>,\n        execution: &OrderedExecution<R>,\n    );\n\n    /// Returns the state that can be serialized.\n    fn to_state(&self) -> R::OptimizationState;\n    /// Create the optimization from the state.\n    fn from_state(device: &R::FusionDevice, state: R::OptimizationState) -> Self;\n}\n\n/// Type alias for `<R as FusionRuntime>::FusionDevice`.\npub type FusionDevice<R> = <R as FusionRuntime>::FusionDevice;\n/// Type alias for `<R as FusionRuntime>::FusionHandle`.\npub type FusionHandle<R> = <R as FusionRuntime>::FusionHandle;\n/// Client alias.\npub type Client<R> = GlobalFusionClient<R>;\n\n/// Trait that defines a runtime that will benefits from fused operations.\npub trait FusionRuntime: Send + Sync + Sized + core::fmt::Debug + 'static {\n    /// The state that can be serialized for an optimization.\n    type OptimizationState: Serialize + DeserializeOwned;\n    /// Optimization type for the backend.\n    type Optimization: Optimization<Self>;\n    /// Handle used to store tensor dynamically.\n    type FusionHandle: Clone + Send;\n    /// Device used by the runtime.\n    type FusionDevice: DeviceOps;\n\n    /// The list of fusers that will be used to optimize the computational graph.\n    fn fusers(device: Self::FusionDevice) -> Vec<Box<dyn OperationFuser<Self::Optimization>>>;\n}\n\n/// Trait that allows an existing [backend](Backend) to specify graph optimizations using\n/// [operation fuser](crate::OperationFuser).\npub trait FusionBackend:\n    BackendIr<Handle = FusionHandle<Self::FusionRuntime>, Device = FusionDevice<Self::FusionRuntime>>\n{\n    /// The runtime used for this backend.\n    type FusionRuntime: FusionRuntime;\n\n    /// Cast a float tensor and returns the resulting handle.\n    fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle;\n\n    /// Pointer to the full precision fusion backend.\n    type FullPrecisionBackend: FusionBackend<FusionRuntime = Self::FusionRuntime>;\n}\n\n// Fusion implements `BackendIr` to enable router backend usage.\nimpl<B: FusionBackend> BackendIr for Fusion<B> {\n    type Handle = FusionTensor<B::FusionRuntime>;\n\n    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {\n        handle.handle\n    }\n\n    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {\n        handle.handle\n    }\n\n    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {\n        handle.handle\n    }\n\n    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {\n        handle.handle\n    }\n\n    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {\n        tensor\n    }\n\n    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {\n        tensor\n    }\n}\n\n// TODO: remove once backends no longer rely on generics for default elem types\n/// Returns the bool element dtype.\npub(crate) fn bool_dtype<BT: burn_backend::Element>() -> DType {\n    match BT::dtype() {\n        DType::U32 => DType::Bool(burn_backend::BoolStore::U32),\n        DType::U8 => DType::Bool(burn_backend::BoolStore::U8),\n        other => unimplemented!(\"Invalid bool dtye {other:?}\"),\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/client.rs",
    "content": "use crate::{\n    FusionBackend, FusionDevice, FusionHandle, FusionRuntime, FusionServer, FusionTensor,\n    stream::{OperationStreams, StreamId, execution::Operation},\n};\nuse burn_backend::{Device, DeviceHandle, DeviceId, DeviceService};\nuse burn_backend::{TensorData, backend::ExecutionError};\nuse burn_ir::{OperationIr, TensorId, TensorIr};\nuse std::sync::{\n    Arc,\n    atomic::{AtomicU64, Ordering},\n};\n\n/// Use a mutex to communicate with the fusion server.\npub struct GlobalFusionClient<R: FusionRuntime> {\n    server: DeviceHandle<FusionServer<R>>,\n    device: FusionDevice<R>,\n}\n\nimpl<R: FusionRuntime> DeviceService for FusionServer<R> {\n    fn init(device_id: DeviceId) -> Self {\n        let device = FusionDevice::<R>::from_id(device_id);\n        FusionServer::new(device)\n    }\n\n    fn utilities(&self) -> burn_backend::ServerUtilitiesHandle {\n        Arc::new(())\n    }\n}\n\nimpl<R> Clone for GlobalFusionClient<R>\nwhere\n    R: FusionRuntime,\n{\n    fn clone(&self) -> Self {\n        Self {\n            server: self.server.clone(),\n            device: self.device.clone(),\n        }\n    }\n}\nimpl<R> GlobalFusionClient<R>\nwhere\n    R: FusionRuntime + 'static,\n{\n    /// Loads the client from the given device.\n    pub fn load(device: &FusionDevice<R>) -> Self {\n        Self {\n            device: device.clone(),\n            server: DeviceHandle::new(device.to_id()),\n        }\n    }\n}\n\nstatic COUNTER: AtomicU64 = AtomicU64::new(0);\n\nimpl<R> GlobalFusionClient<R>\nwhere\n    R: FusionRuntime + 'static,\n{\n    /// Create a new client for the given [device](FusionRuntime::FusionDevice).\n    pub fn new(device: FusionDevice<R>) -> Self {\n        Self {\n            device: device.clone(),\n            server: DeviceHandle::new(device.to_id()),\n        }\n    }\n\n    /// Register a new [tensor operation intermediate representation](OperationIr).\n    ///\n    /// Returns the new (uninitialized) output tensor(s) generated by the registered operation.\n    pub fn register<O>(\n        &self,\n        streams: OperationStreams,\n        repr: OperationIr,\n        operation: O,\n    ) -> Vec<FusionTensor<R>>\n    where\n        O: Operation<R> + 'static,\n    {\n        // Create output tensors returned by this operation\n        let outputs = repr\n            .outputs()\n            .map(|output| {\n                FusionTensor::new(\n                    output.id,\n                    output.shape.clone(),\n                    output.dtype,\n                    self.clone(),\n                    StreamId::current(),\n                )\n            })\n            .collect();\n\n        self.server.submit(move |server| {\n            server.register(streams, repr, Arc::new(operation));\n        });\n\n        outputs\n    }\n\n    /// Register all lazy computation.\n    pub fn drain(&self) {\n        let id = StreamId::current();\n        self.server.submit(move |server| server.drain_stream(id));\n    }\n\n    /// Create a new (uninitialized) empty tensor handle and returns its corresponding [tensor id](TensorId).\n    pub fn create_empty_handle(&self) -> TensorId {\n        let value = COUNTER.fetch_add(1, Ordering::Relaxed);\n        TensorId::new(value)\n    }\n\n    /// Get the current device used by all operations handled by this client.\n    pub fn device(&self) -> &FusionDevice<R> {\n        &self.device\n    }\n\n    /// Create a tensor with the given handle and returns its corresponding [tensor id](TensorId).\n    pub fn register_tensor_handle(&self, handle: FusionHandle<R>) -> TensorId {\n        let id = self.create_empty_handle();\n\n        self.server\n            .submit(move |server| server.handles.register_handle(id, handle));\n\n        id\n    }\n\n    /// Read the values contained by a float tensor.\n    pub fn read_tensor_float<B>(\n        self,\n        tensor: TensorIr,\n        stream: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| server.read_float::<B>(tensor, stream))\n            .unwrap()\n    }\n\n    /// Read the values contained by an int tensor.\n    pub fn read_tensor_int<B>(\n        self,\n        tensor: TensorIr,\n        stream: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| server.read_int::<B>(tensor, stream))\n            .unwrap()\n    }\n\n    /// Read the values contained by a bool tensor.\n    pub fn read_tensor_bool<B>(\n        self,\n        tensor: TensorIr,\n        stream: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| server.read_bool::<B>(tensor, stream))\n            .unwrap()\n    }\n\n    /// Read the values contained by a quantized tensor.\n    pub fn read_tensor_quantized<B>(\n        self,\n        tensor: TensorIr,\n        stream: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| server.read_quantized::<B>(tensor, stream))\n            .unwrap()\n    }\n\n    /// Change the client of the given float tensor.\n    pub fn change_client_float<B>(\n        &self,\n        tensor: TensorIr,\n        client: Self,\n        stream: StreamId,\n    ) -> FusionTensor<R>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let dtype = tensor.dtype;\n        let client_cloned = client.clone();\n        let shape = tensor.shape.clone();\n        let id = self.create_empty_handle();\n\n        self.server.submit(move |server| {\n            server.drain_stream(stream);\n            // TODO: We could improve performance here by not requirering blocking.\n            client\n                .server\n                .clone()\n                .submit_blocking_scoped(move |server_other| {\n                    server_other.change_server_float::<B>(\n                        &tensor,\n                        id,\n                        stream,\n                        &client.device,\n                        server,\n                    )\n                })\n        });\n\n        FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current())\n    }\n\n    /// Change the client of the given int tensor.\n    pub fn change_client_int<B>(\n        &self,\n        tensor: TensorIr,\n        client: Self,\n        stream: StreamId,\n    ) -> FusionTensor<R>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let dtype = tensor.dtype;\n        let client_cloned = client.clone();\n        let shape = tensor.shape.clone();\n        let id = self.create_empty_handle();\n\n        self.server.submit(move |server| {\n            server.drain_stream(stream);\n            // TODO: We could improve performance here by not requirering blocking.\n            client\n                .server\n                .clone()\n                .submit_blocking_scoped(move |server_other| {\n                    server_other.change_server_int::<B>(&tensor, id, stream, &client.device, server)\n                })\n        });\n\n        FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current())\n    }\n\n    /// Change the client of the given bool tensor.\n    pub fn change_client_bool<B>(\n        &self,\n        tensor: TensorIr,\n        client: Self,\n        stream: StreamId,\n    ) -> FusionTensor<R>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let dtype = tensor.dtype;\n        let client_cloned = client.clone();\n        let shape = tensor.shape.clone();\n        let id = self.create_empty_handle();\n\n        self.server.submit(move |server| {\n            server.drain_stream(stream);\n            // TODO: We could improve performance here by not requirering blocking.\n            client\n                .server\n                .clone()\n                .submit_blocking_scoped(move |server_other| {\n                    server_other.change_server_bool::<B>(\n                        &tensor,\n                        id,\n                        stream,\n                        &client.device,\n                        server,\n                    )\n                })\n        });\n\n        FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current())\n    }\n\n    /// Change the client of the given quantized tensor.\n    pub fn change_client_quantized<B>(\n        &self,\n        tensor: TensorIr,\n        client: Self,\n        stream: StreamId,\n    ) -> FusionTensor<R>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let dtype = tensor.dtype;\n        let client_cloned = client.clone();\n        let shape = tensor.shape.clone();\n        let id = self.create_empty_handle();\n\n        self.server.submit(move |server| {\n            server.drain_stream(stream);\n            // TODO: We could improve performance here by not requirering blocking.\n            client\n                .server\n                .clone()\n                .submit_blocking_scoped(move |server_other| {\n                    server_other.change_server_quantized::<B>(&tensor, id, &client.device, server)\n                })\n        });\n\n        FusionTensor::new(id, shape, dtype, client_cloned, StreamId::current())\n    }\n\n    /// Resolve the given float tensor to a primitive tensor.\n    pub fn resolve_tensor_float<B>(&self, tensor: FusionTensor<R>) -> B::FloatTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| {\n                server.drain_stream(tensor.stream);\n                server.resolve_server_float::<B>(&tensor.into_ir())\n            })\n            .unwrap()\n    }\n\n    /// Resolve the given int tensor to a primitive tensor.\n    pub fn resolve_tensor_int<B>(&self, tensor: FusionTensor<R>) -> B::IntTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| {\n                server.drain_stream(tensor.stream);\n                server.resolve_server_int::<B>(&tensor.into_ir())\n            })\n            .unwrap()\n    }\n\n    /// Resolve the given bool tensor to a primitive tensor.\n    pub fn resolve_tensor_bool<B>(&self, tensor: FusionTensor<R>) -> B::BoolTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.server\n            .submit_blocking(move |server| {\n                server.drain_stream(tensor.stream);\n                server.resolve_server_bool::<B>(&tensor.into_ir())\n            })\n            .unwrap()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! # Burn Fusion\n//!\n//! This library is a part of the Burn project. It is a standalone crate that\n//! can be used to perform automatic operation fusion on backends that support it.\n\n#[macro_use]\nextern crate derive_new;\n\n/// Client module exposing types to communicate with the fusion server.\npub mod client;\n/// Stream module exposing all tensor operations that can be optimized.\npub mod stream;\n\n/// Search module for stream optimizations.\npub(crate) mod search;\n\nmod backend;\nmod ops;\nmod server;\nmod tensor;\n\npub(crate) use server::*;\n\npub use backend::*;\npub use ops::NoOp;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/activation.rs",
    "content": "use crate::{Fusion, FusionBackend};\nuse burn_backend::ops::ActivationOps;\n\nimpl<B: FusionBackend> ActivationOps<Self> for Fusion<B> {}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/base.rs",
    "content": "use crate::{FusionBackend, stream::Operation};\nuse burn_ir::HandleContainer;\nuse std::marker::PhantomData;\n\n/// A no-operation placeholder for the fusion backend.\n///\n/// `NoOp` is an implementation of [`Operation`] that doesn't execute anything.\n#[derive(new, Clone, Debug)]\npub struct NoOp<B: FusionBackend> {\n    _b: PhantomData<B>,\n}\n\nimpl<B: FusionBackend> Operation<B::FusionRuntime> for NoOp<B> {\n    fn execute(&self, _handles: &mut HandleContainer<B::Handle>) {}\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/binary.rs",
    "content": "#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_float_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(Debug)]\n        struct $name<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> $name<B> {\n            fn new(desc: BinaryOpIr) -> Self {\n                Self {\n                    desc,\n                    _b: PhantomData,\n                }\n            }\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);\n                let output = $ops(lhs, rhs);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_float_cmp_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);\n                let output = $ops(lhs, rhs);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_int_cmp_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(Debug)]\n        struct $name<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> $name<B> {\n            fn new(desc: BinaryOpIr) -> Self {\n                Self {\n                    desc,\n                    _b: PhantomData,\n                }\n            }\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);\n                let output = $ops(lhs, rhs);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_int_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_int_tensor::<B>(&self.desc.rhs);\n                let output = $ops(lhs, rhs);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/bool_tensor.rs",
    "content": "use crate::{\n    Fusion, FusionBackend, bool_dtype, get_client,\n    stream::{OperationStreams, execution::Operation},\n};\nuse burn_backend::{\n    Element, ExecutionError, Scalar, Shape, Slice, TensorData,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntTensor},\n};\nuse burn_ir::{\n    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,\n    GatherOpIr, HandleContainer, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr,\n    OperationOutput, PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr,\n    SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,\n};\nuse std::marker::PhantomData;\n\nuse super::NoOp;\n\nimpl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {\n    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct EmptyOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output = B::bool_empty(self.desc.shape.clone(), &self.device);\n                handles.register_bool_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, bool_dtype::<B::BoolElem>(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseBool(BaseOperationIr::Empty(desc.clone())),\n                EmptyOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct ZerosOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output = B::bool_zeros(self.desc.shape.clone(), &self.device);\n                handles.register_bool_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, bool_dtype::<B::BoolElem>(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseBool(BaseOperationIr::Zeros(desc.clone())),\n                ZerosOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct OnesOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output = B::bool_ones(self.desc.shape.clone(), &self.device);\n                handles.register_bool_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, bool_dtype::<B::BoolElem>(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseBool(BaseOperationIr::Ones(desc.clone())),\n                OnesOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {\n        tensor.bool_into_data::<B>().await\n    }\n\n    fn bool_from_data(data: burn_backend::TensorData, device: &Device<Self>) -> BoolTensor<Self> {\n        let client = get_client::<B>(device);\n        let tensor = B::bool_from_data(data, device);\n        let shape = burn_backend::TensorMetadata::shape(&tensor);\n\n        let handle = B::bool_tensor_handle(tensor);\n        let desc = InitOperationIr::create(shape, bool_dtype::<B::BoolElem>(), || {\n            client.register_tensor_handle(handle)\n        });\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::Init(desc),\n                NoOp::<B>::new(),\n            )\n            .output()\n    }\n\n    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct IntoIntOps<B: FusionBackend> {\n            desc: CastOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_into_int(input);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Bool(BoolOperationIr::IntoInt(desc.clone())),\n                IntoIntOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct IntoFloatOps<B: FusionBackend> {\n            desc: CastOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_into_float(input);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Bool(BoolOperationIr::IntoFloat(desc.clone())),\n                IntoFloatOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {\n        tensor.client.device().clone()\n    }\n\n    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {\n        let device_original: &B::Device = tensor.client.device();\n\n        if device_original == device {\n            return tensor;\n        }\n\n        let id = tensor.stream;\n        let client_target = get_client::<B>(device);\n        let client_original = tensor.client.clone();\n\n        client_original\n            .clone()\n            .change_client_bool::<B>(tensor.into_ir(), client_target, id)\n    }\n\n    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        if tensor.shape == shape {\n            return tensor;\n        }\n\n        #[derive(new, Debug)]\n        struct ReshapeDimsOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_reshape(input, self.desc.out.shape.clone());\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Reshape(desc.clone())),\n                ReshapeDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceOps<B: FusionBackend> {\n            desc: SliceOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n\n                let output = B::bool_slice(tensor, self.desc.ranges.as_slice());\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Slice(desc.clone())),\n                SliceOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        slices: &[Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceAssignOps<B: FusionBackend> {\n            desc: SliceAssignOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_bool_tensor::<B>(&self.desc.value);\n\n                let output = B::bool_slice_assign(tensor, self.desc.ranges.as_slice(), value);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &value]);\n\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc.clone())),\n                SliceAssignOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct CatOps<B: FusionBackend> {\n            desc: CatOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensors = self\n                    .desc\n                    .tensors\n                    .iter()\n                    .map(|tensor| handles.get_bool_tensor::<B>(tensor))\n                    .collect();\n\n                let output = B::bool_cat(tensors, self.desc.dim);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs(&tensors);\n\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Cat(desc.clone())),\n                CatOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct EqualOps<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);\n                let output = B::bool_equal(lhs, rhs);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Equal(desc.clone())),\n                EqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct NotOps<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for NotOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_not(input);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Bool(BoolOperationIr::Not(desc.clone())),\n                NotOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct AndOps<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for AndOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);\n                let output = B::bool_and(lhs, rhs);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Bool(BoolOperationIr::And(desc.clone())),\n                AndOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct OrOps<B: FusionBackend> {\n            desc: BinaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for OrOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_bool_tensor::<B>(&self.desc.rhs);\n                let output = B::bool_or(lhs, rhs);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n        client\n            .register(\n                streams,\n                OperationIr::Bool(BoolOperationIr::Or(desc.clone())),\n                OrOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct SwapDimsOps<B: FusionBackend> {\n            desc: SwapDimsOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::SwapDims(desc.clone())),\n                SwapDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct PermuteDimsOps<B: FusionBackend> {\n            desc: PermuteOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_permute(input, self.desc.axes.as_slice());\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),\n                PermuteDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct ExpandOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_expand(input, self.desc.out.shape.clone());\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Expand(desc.clone())),\n                ExpandOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct FlipOps<B: FusionBackend> {\n            desc: FlipOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_flip(input, self.desc.axes.as_slice());\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Flip(desc.clone())),\n                FlipOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct RepeatDimOps<B: FusionBackend> {\n            desc: RepeatDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n\n                let output = B::bool_repeat_dim(tensor, self.desc.dim, self.desc.times);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc.clone())),\n                RepeatDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_unfold(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct UnfoldOps<B: FusionBackend> {\n            desc: UnfoldOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_bool_tensor::<B>(&self.desc.input);\n                let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),\n                UnfoldOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskWhereOps<B: FusionBackend> {\n            desc: MaskWhereOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_bool_tensor::<B>(&self.desc.value);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::bool_mask_where(tensor, mask, value);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);\n\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc.clone())),\n                MaskWhereOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskFillOps<B: FusionBackend> {\n            desc: MaskFillOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::bool_mask_fill(tensor, mask, self.desc.value.into());\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask]);\n\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::MaskFill(desc.clone())),\n                MaskFillOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct GatherOps<B: FusionBackend> {\n            desc: GatherOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::bool_gather(self.desc.dim, tensor, indices);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices]);\n\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Gather(desc.clone())),\n                GatherOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct ScatterOps<B: FusionBackend> {\n            desc: ScatterOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_bool_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n                let value = handles.get_bool_tensor::<B>(&self.desc.value);\n\n                let output = B::bool_scatter_or(self.desc.dim, tensor, indices, value);\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);\n\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::Scatter(desc.clone())),\n                ScatterOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct EqualElemOps<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for EqualElemOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_bool_tensor::<B>(&self.desc.lhs);\n                let output = B::bool_equal_elem(lhs, self.desc.rhs.into());\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseBool(BaseOperationIr::EqualElem(desc.clone())),\n                EqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/int_tensor.rs",
    "content": "use super::NoOp;\nuse crate::{\n    Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops, bool_dtype, get_client,\n    reduce_int_ops, scalar_int_cmp_ops, scalar_int_ops,\n    stream::{OperationStreams, execution::Operation},\n    unary_int_ops,\n};\nuse burn_backend::{\n    Distribution, Element, ExecutionError, IntDType, Scalar, Shape, Slice, TensorData,\n    ops::IntTensorOps,\n    tensor::{BoolTensor, Device, FloatTensor, IndexingUpdateOp, IntElem, IntTensor},\n};\nuse burn_ir::*;\nuse std::marker::PhantomData;\n\nimpl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {\n    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct EmptyOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output = B::int_empty(\n                    self.desc.shape.clone(),\n                    &self.device,\n                    self.desc.dtype.into(),\n                );\n                handles.register_int_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseInt(BaseOperationIr::Empty(desc.clone())),\n                EmptyOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {\n        tensor.int_into_data::<B>().await\n    }\n\n    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {\n        let client = get_client::<B>(device);\n        let dtype = data.dtype;\n        let tensor = B::int_from_data(data, device);\n        let shape = burn_backend::TensorMetadata::shape(&tensor);\n\n        let handle = B::int_tensor_handle(tensor);\n        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::Init(desc),\n                NoOp::<B>::new(),\n            )\n            .output()\n    }\n\n    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {\n        tensor.client.device().clone()\n    }\n\n    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {\n        let device_original: &B::Device = tensor.client.device();\n\n        if device_original == device {\n            return tensor;\n        }\n\n        let id = tensor.stream;\n        let client_target = get_client::<B>(device);\n        let client_original = tensor.client.clone();\n\n        client_original\n            .clone()\n            .change_client_int::<B>(tensor.into_ir(), client_target, id)\n    }\n\n    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        if tensor.shape == shape {\n            return tensor;\n        }\n\n        #[derive(new, Debug)]\n        struct ReshapeDimsOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_reshape(input, self.desc.out.shape.clone());\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Reshape(desc.clone())),\n                ReshapeDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceOps<B: FusionBackend> {\n            desc: SliceOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n\n                let output = B::int_slice(tensor, self.desc.ranges.as_slice());\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Slice(desc.clone())),\n                SliceOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceAssignOps<B: FusionBackend> {\n            desc: SliceAssignOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_int_tensor::<B>(&self.desc.value);\n\n                let output = B::int_slice_assign(tensor, self.desc.ranges.as_slice(), value);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &value]);\n\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc.clone())),\n                SliceAssignOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(MatmulOps, B::int_matmul);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),\n                MatmulOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskWhereOps<B: FusionBackend> {\n            desc: MaskWhereOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_int_tensor::<B>(&self.desc.value);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::int_mask_where(tensor, mask, value);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);\n\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc.clone())),\n                MaskWhereOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskFillOps<B: FusionBackend> {\n            desc: MaskFillOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::int_mask_fill(tensor, mask, self.desc.value.into());\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask]);\n\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::MaskFill(desc.clone())),\n                MaskFillOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_gather(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct GatherOps<B: FusionBackend> {\n            desc: GatherOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::int_gather(self.desc.dim, tensor, indices);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices]);\n\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Gather(desc.clone())),\n                GatherOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct ScatterOps<B: FusionBackend> {\n            desc: ScatterOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n                let value = handles.get_int_tensor::<B>(&self.desc.value);\n\n                let output = B::int_scatter_add(self.desc.dim, tensor, indices, value);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);\n\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Scatter(desc.clone())),\n                ScatterOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_select(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct SelectOps<B: FusionBackend> {\n            desc: SelectOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::int_select(tensor, self.desc.dim, indices);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices]);\n\n        let client = tensor.client.clone();\n        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Select(desc.clone())),\n                SelectOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct SelectAssignOps<B: FusionBackend> {\n            desc: SelectAssignOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n                let value = handles.get_int_tensor::<B>(&self.desc.value);\n\n                let output = B::int_select_add(tensor, self.desc.dim, indices, value);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);\n\n        let client = tensor.client.clone();\n        let desc = SelectAssignOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc.clone())),\n                SelectAssignOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CatOps<B: FusionBackend> {\n            desc: CatOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensors = self\n                    .desc\n                    .tensors\n                    .iter()\n                    .map(|tensor| handles.get_int_tensor::<B>(tensor))\n                    .collect();\n\n                let output = B::int_cat(tensors, self.desc.dim);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs(&tensors);\n\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Cat(desc.clone())),\n                CatOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_int_cmp_ops!(EqualOps, B::int_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Equal(desc.clone())),\n                EqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_int_cmp_ops!(EqualElemOps, B::int_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::EqualElem(desc.clone())),\n                EqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_int_cmp_ops!(GreaterOps, B::int_greater);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Greater(desc.clone())),\n                GreaterOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_int_cmp_ops!(GreaterElemOps, B::int_greater_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterElem(desc.clone()),\n                ),\n                GreaterElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_int_cmp_ops!(GreaterEqualOps, B::int_greater_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterEqual(desc.clone()),\n                ),\n                GreaterEqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_int_cmp_ops!(GreaterEqualElemOps, B::int_greater_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterEqualElem(desc.clone()),\n                ),\n                GreaterEqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_int_cmp_ops!(LowerOps, B::int_lower);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),\n                LowerOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_int_cmp_ops!(LowerElemOps, B::int_lower_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerElem(desc.clone()),\n                ),\n                LowerElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        binary_int_cmp_ops!(LowerEqualOps, B::int_lower_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerEqual(desc.clone()),\n                ),\n                LowerEqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_int_cmp_ops!(LowerEqualElemOps, B::int_lower_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerEqualElem(desc.clone()),\n                ),\n                LowerEqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(AddOps, B::int_add);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Add(desc.clone())),\n                AddOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(AddOps, B::int_add_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::AddScalar(desc.clone()),\n                ),\n                AddOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(SubOps, B::int_sub);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),\n                SubOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(SubOps, B::int_sub_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::SubScalar(desc.clone()),\n                ),\n                SubOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(MulOps, B::int_mul);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),\n                MulOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(MulOps, B::int_mul_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::MulScalar(desc.clone()),\n                ),\n                MulOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(DivOps, B::int_div);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Div(desc.clone())),\n                DivOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(DivOps, B::int_div_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::DivScalar(desc.clone()),\n                ),\n                DivOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(ModOps, B::int_remainder);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),\n                ModOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(ModOps, B::int_remainder_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::RemScalar(desc.clone()),\n                ),\n                ModOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct ZerosOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.desc.shape.clone();\n                let output = B::int_zeros(shape, &self.device, self.desc.dtype.into());\n                handles.register_int_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseInt(BaseOperationIr::Zeros(desc.clone())),\n                ZerosOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct OnesOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.desc.shape.clone();\n                let output = B::int_ones(shape, &self.device, self.desc.dtype.into());\n                handles.register_int_tensor::<B>(&self.desc.id, output);\n            }\n        }\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseInt(BaseOperationIr::Ones(desc.clone())),\n                OnesOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<Self>,\n        dtype: IntDType,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct FullOps<B: FusionBackend> {\n            out: TensorIr,\n            elem: ScalarIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.out.shape.clone();\n                let output =\n                    B::int_full(shape, self.elem.into(), &self.device, self.out.dtype.into());\n                handles.register_int_tensor::<B>(&self.out.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let dtype = dtype.into();\n        let value = fill_value.into();\n        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Full(desc.clone())),\n                FullOps::<B>::new(desc.out, desc.value, device.clone()),\n            )\n            .output()\n    }\n\n    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(SumOps, B::int_sum, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),\n                SumOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {\n        reduce_int_ops!(SumDimOps, B::int_sum_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),\n                SumDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(ProdOps, B::int_prod, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),\n                ProdOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(ProdDimOps, B::int_prod_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ProdDim(desc.clone())),\n                ProdDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(MeanOps, B::int_mean, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),\n                MeanOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(MeanDimOps, B::int_mean_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MeanDim(desc.clone())),\n                MeanDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumsumOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_cumsum(input, self.desc.axis);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),\n                CumsumOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumprodOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_cumprod(input, self.desc.axis);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumProd(desc.clone())),\n                CumprodOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumminOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_cummin(input, self.desc.axis);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),\n                CumminOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CummaxOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_cummax(input, self.desc.axis);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),\n                CummaxOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(ArgMaxOps, B::int_argmax);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMax(desc.clone())),\n                ArgMaxOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(ArgMinOps, B::int_argmin);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::ArgMin(desc.clone())),\n                ArgMinOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct ClampOps<B: FusionBackend> {\n            desc: ClampOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let output = B::int_clamp(input, self.desc.min.into(), self.desc.max.into());\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let min = min.into();\n        let max = max.into();\n        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Clamp(desc.clone())),\n                ClampOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(AbsOps, B::int_abs);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),\n                AbsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct IntoFloatOps<B: FusionBackend> {\n            desc: CastOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoFloatOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_into_float(input);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), B::FloatElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::IntoFloat(desc.clone())),\n                IntoFloatOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct SwapDimsOps<B: FusionBackend> {\n            desc: SwapDimsOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::SwapDims(desc.clone())),\n                SwapDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(MaxOps, B::int_max, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Max(desc.clone())),\n                MaxOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(MaxDimOps, B::int_max_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),\n                MaxDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_max_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        #[derive(new, Debug)]\n        struct MaxDimWithIndicesOps<B: FusionBackend> {\n            desc: ReduceDimWithIndicesOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let dtype = tensor.dtype;\n        let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(dtype, NumericOperationIr::MaxDimWithIndices(desc.clone())),\n                MaxDimWithIndicesOps::<B>::new(desc),\n            )\n            .outputs()\n            .into()\n    }\n\n    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(MinOps, B::int_min, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::Min(desc.clone())),\n                MinOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(MaxAbsOps, B::int_max_abs, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),\n                MaxAbsOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(MaxAbsDimOps, B::int_max_abs_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(\n                    desc.out.dtype,\n                    NumericOperationIr::MaxAbsDim(desc.clone()),\n                ),\n                MaxAbsDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_int_ops!(MinDimOps, B::int_min_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),\n                MinDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_min_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        #[derive(new, Debug)]\n        struct MinDimWithIndicesOps<B: FusionBackend> {\n            desc: ReduceDimWithIndicesOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n                let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let dtype = tensor.dtype;\n        let desc = ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, dtype, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericInt(dtype, NumericOperationIr::MinDimWithIndices(desc.clone())),\n                MinDimWithIndicesOps::<B>::new(desc),\n            )\n            .outputs()\n            .into()\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct IntRandomOps<B: FusionBackend> {\n            desc: RandomOpIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntRandomOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.desc.out.shape.clone();\n                let output = B::int_random(shape, self.desc.distribution, &self.device);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let dtype = IntElem::<Self>::dtype();\n        let client = get_client::<B>(device);\n        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::NumericInt(dtype, NumericOperationIr::IntRandom(desc.clone())),\n                IntRandomOps::<B>::new(desc, device.clone()),\n            )\n            .output()\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct PermuteDimsOps<B: FusionBackend> {\n            desc: PermuteOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_permute(input, self.desc.axes.as_slice());\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),\n                PermuteDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct ExpandOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_expand(input, self.desc.out.shape.clone());\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Expand(desc.clone())),\n                ExpandOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct FlipDimsOps<B: FusionBackend> {\n            desc: FlipOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let axes = &self.desc.axes;\n                let output = B::int_flip(input, axes);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),\n                FlipDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct RepeatDimOps<B: FusionBackend> {\n            desc: RepeatDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_int_tensor::<B>(&self.desc.tensor);\n\n                let output = B::int_repeat_dim(tensor, self.desc.dim, self.desc.times);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc.clone())),\n                RepeatDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(BitwiseAndOps, B::bitwise_and);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseAnd(desc.clone())),\n                BitwiseAndOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc.clone())),\n                BitwiseAndOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(BitwiseOrOps, B::bitwise_or);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseOr(desc.clone())),\n                BitwiseOrOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc.clone())),\n                BitwiseOrOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(BitwiseXorOps, B::bitwise_xor);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseXor(desc.clone())),\n                BitwiseXorOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc.clone())),\n                BitwiseXorOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        unary_int_ops!(BitwiseNotOps, B::bitwise_not);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseNot(desc.clone())),\n                BitwiseNotOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc.clone())),\n                BitwiseLeftShiftOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(desc.clone())),\n                BitwiseLeftShiftOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseRightShift(desc.clone())),\n                BitwiseRightShiftOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(desc.clone())),\n                BitwiseRightShiftOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct CastOps<B: FusionBackend> {\n            desc: CastOpIr,\n            dtype: burn_backend::IntDType,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output: B::IntTensorPrimitive = B::int_cast(input, self.dtype);\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Cast(desc.clone())),\n                CastOps::<B>::new(desc, dtype),\n            )\n            .output()\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct UnfoldOps<B: FusionBackend> {\n            desc: UnfoldOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),\n                UnfoldOps::<B>::new(desc),\n            )\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/mod.rs",
    "content": "mod activation;\nmod binary;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\nmod unary;\n\nmod base;\npub use base::NoOp;\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/module.rs",
    "content": "use crate::{\n    Fusion, FusionBackend,\n    stream::{OperationStreams, execution::Operation},\n};\nuse burn_backend::{\n    Element,\n    ops::{\n        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions,\n        InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,\n        MaxPool2dWithIndices, ModuleOps,\n    },\n    tensor::{FloatTensor, IntTensor},\n};\nuse burn_ir::*;\nuse std::marker::PhantomData;\n\nmacro_rules! make_ops {\n    ($name:ident, $desc:ty, $fn:expr) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: $desc,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                #[allow(clippy::redundant_closure_call)]\n                $fn(&self.desc, handles)\n            }\n        }\n    };\n}\n\nimpl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {\n    fn conv1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        make_ops!(Conv1dOps, Conv1dOpIr, |desc: &Conv1dOpIr,\n                                          handles: &mut HandleContainer<\n            B::Handle,\n        >| {\n            let x = handles.get_float_tensor::<B>(&desc.x);\n            let weight = handles.get_float_tensor::<B>(&desc.weight);\n            let bias = desc\n                .bias\n                .as_ref()\n                .map(|bias| handles.get_float_tensor::<B>(bias));\n            let output = B::conv1d(x, weight, bias, desc.options.clone().into());\n            handles.register_float_tensor::<B>(&desc.out.id, output);\n        });\n\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = Conv1dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv1d(desc.clone())),\n                Conv1dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv1d_x_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv1dXBackwardOps,\n            Conv1dXBackwardOpIr,\n            |desc: &Conv1dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv1d_x_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv1dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv1dXBackward(desc.clone())),\n                Conv1dXBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv1d_weight_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv1dWeightBackwardOps,\n            Conv1dWeightBackwardOpIr,\n            |desc: &Conv1dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv1d_weight_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv1dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv1dWeightBackward(desc.clone())),\n                Conv1dWeightBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv1d_bias_backward(\n        x: FloatTensor<Fusion<B>>,\n        bias: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv1dBiasBackwardOps,\n            Conv1dBiasBackwardOpIr,\n            |desc: &Conv1dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let bias = handles.get_float_tensor::<B>(&desc.bias);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output = B::conv1d_bias_backward(x, bias, output_grad);\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv1dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(desc.clone())),\n                Conv1dBiasBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        make_ops!(Conv2dOps, Conv2dOpIr, |args: &Conv2dOpIr,\n                                          handles: &mut HandleContainer<\n            B::Handle,\n        >| {\n            let x = handles.get_float_tensor::<B>(&args.x);\n            let weight = handles.get_float_tensor::<B>(&args.weight);\n            let bias = args\n                .bias\n                .as_ref()\n                .map(|bias| handles.get_float_tensor::<B>(bias));\n\n            let output = B::conv2d(x, weight, bias, args.options.clone().into());\n\n            handles.register_float_tensor::<B>(&args.out.id, output);\n        });\n\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = Conv2dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv2d(desc.clone())),\n                Conv2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv2d_x_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv2dXBackwardOps,\n            Conv2dXBackwardOpIr,\n            |desc: &Conv2dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv2d_x_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv2dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv2dXBackward(desc.clone())),\n                Conv2dXBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv2d_weight_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv2dWeightBackwardOps,\n            Conv2dWeightBackwardOpIr,\n            |desc: &Conv2dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv2d_weight_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv2dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv2dWeightBackward(desc.clone())),\n                Conv2dWeightBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv2d_bias_backward(\n        x: FloatTensor<Fusion<B>>,\n        bias: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv2dBiasBackwardOps,\n            Conv2dBiasBackwardOpIr,\n            |desc: &Conv2dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let bias = handles.get_float_tensor::<B>(&desc.bias);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output = B::conv2d_bias_backward(x, bias, output_grad);\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv2dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(desc.clone())),\n                Conv2dBiasBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            DeformConv2dOps,\n            DeformConv2dOpIr,\n            |args: &DeformConv2dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let offset = handles.get_float_tensor::<B>(&args.offset);\n                let weight = handles.get_float_tensor::<B>(&args.weight);\n                let mask = args\n                    .mask\n                    .as_ref()\n                    .map(|mask| handles.get_float_tensor::<B>(mask));\n                let bias = args\n                    .bias\n                    .as_ref()\n                    .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                let output =\n                    B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n        if let Some(mask) = mask.as_ref() {\n            streams.tensor(mask)\n        }\n\n        let client = x.client.clone();\n        let desc = DeformConv2dOpIr::create(\n            x.into_ir(),\n            offset.into_ir(),\n            weight.into_ir(),\n            mask.map(|mask| mask.into_ir()),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::DeformableConv2d(Box::new(desc.clone()))),\n                DeformConv2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        make_ops!(\n            DeformConv2dBackwardOps,\n            DeformConv2dBackwardOpIr,\n            |args: &DeformConv2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let offset = handles.get_float_tensor::<B>(&args.offset);\n                let weight = handles.get_float_tensor::<B>(&args.weight);\n                let mask = args\n                    .mask\n                    .as_ref()\n                    .map(|mask| handles.get_float_tensor::<B>(mask));\n                let bias = args\n                    .bias\n                    .as_ref()\n                    .map(|bias| handles.get_float_tensor::<B>(bias));\n                let output_grad = handles.get_float_tensor::<B>(&args.out_grad);\n\n                let output = B::deform_conv2d_backward(\n                    x,\n                    offset,\n                    weight,\n                    mask,\n                    bias,\n                    output_grad,\n                    args.options.clone().into(),\n                );\n\n                handles.register_float_tensor::<B>(&args.input_grad.id, output.x_grad);\n                handles.register_float_tensor::<B>(&args.offset_grad.id, output.offset_grad);\n                handles.register_float_tensor::<B>(&args.weight_grad.id, output.weight_grad);\n                if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {\n                    handles.register_float_tensor::<B>(&field.id, mask_grad);\n                }\n                if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {\n                    handles.register_float_tensor::<B>(&field.id, bias_grad);\n                }\n            }\n        );\n\n        let has_bias = bias.is_some();\n        let has_mask = mask.is_some();\n\n        let mut streams = OperationStreams::with_inputs([&x, &offset, &weight, &output_grad]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias);\n        }\n        if let Some(mask) = mask.as_ref() {\n            streams.tensor(mask);\n        }\n\n        let client = x.client.clone();\n        let desc = DeformConv2dBackwardOpIr::create(\n            x.into_ir(),\n            offset.into_ir(),\n            weight.into_ir(),\n            mask.map(|mask| mask.into_ir()),\n            bias.map(|bias| bias.into_ir()),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        let mut outputs = client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::DeformableConv2dBackward(Box::new(\n                    desc.clone(),\n                ))),\n                DeformConv2dBackwardOps::<B>::new(desc),\n            )\n            .into_iter();\n\n        // When the number of outputs is variable, the order is important\n        let input_grad = outputs.next().unwrap();\n        let offset_grad = outputs.next().unwrap();\n        let weight_grad = outputs.next().unwrap();\n        let mask_grad = has_mask.then(|| outputs.next().unwrap());\n        let bias_grad = has_bias.then(|| outputs.next().unwrap());\n\n        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        make_ops!(Conv3dOps, Conv3dOpIr, |args: &Conv3dOpIr,\n                                          handles: &mut HandleContainer<\n            B::Handle,\n        >| {\n            let x = handles.get_float_tensor::<B>(&args.x);\n            let weight = handles.get_float_tensor::<B>(&args.weight);\n            let bias = args\n                .bias\n                .as_ref()\n                .map(|bias| handles.get_float_tensor::<B>(bias));\n\n            let output = B::conv3d(x, weight, bias, args.options.clone().into());\n\n            handles.register_float_tensor::<B>(&args.out.id, output);\n        });\n\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = Conv3dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv3d(desc.clone())),\n                Conv3dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv3d_x_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv3dXBackwardOps,\n            Conv3dXBackwardOpIr,\n            |desc: &Conv3dXBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv3d_x_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv3dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv3dXBackward(desc.clone())),\n                Conv3dXBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv3d_weight_backward(\n        x: FloatTensor<Fusion<B>>,\n        weight: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv3dWeightBackwardOps,\n            Conv3dWeightBackwardOpIr,\n            |desc: &Conv3dWeightBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let weight = handles.get_float_tensor::<B>(&desc.weight);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output =\n                    B::conv3d_weight_backward(x, weight, output_grad, desc.options.clone().into());\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &weight, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv3dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv3dWeightBackward(desc.clone())),\n                Conv3dWeightBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv3d_bias_backward(\n        x: FloatTensor<Fusion<B>>,\n        bias: FloatTensor<Fusion<B>>,\n        output_grad: FloatTensor<Fusion<B>>,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            Conv3dBiasBackwardOps,\n            Conv3dBiasBackwardOpIr,\n            |desc: &Conv3dBiasBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&desc.x);\n                let bias = handles.get_float_tensor::<B>(&desc.bias);\n                let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n                let output = B::conv3d_bias_backward(x, bias, output_grad);\n                handles.register_float_tensor::<B>(&desc.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &bias, &output_grad]);\n\n        let client = x.client.clone();\n        let desc = Conv3dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(desc.clone())),\n                Conv3dBiasBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv_transpose1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            ConvTranspose1dOps,\n            ConvTranspose1dOpIr,\n            |args: &ConvTranspose1dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let weight = handles.get_float_tensor::<B>(&args.weight);\n                let bias = args\n                    .bias\n                    .as_ref()\n                    .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                let output = B::conv_transpose1d(x, weight, bias, args.options.clone().into());\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = ConvTranspose1dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::ConvTranspose1d(desc.clone())),\n                ConvTranspose1dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            ConvTranspose2dOps,\n            ConvTranspose2dOpIr,\n            |args: &ConvTranspose2dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let weight = handles.get_float_tensor::<B>(&args.weight);\n                let bias = args\n                    .bias\n                    .as_ref()\n                    .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                let output = B::conv_transpose2d(x, weight, bias, args.options.clone().into());\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = ConvTranspose2dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::ConvTranspose2d(desc.clone())),\n                ConvTranspose2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            ConvTranspose3dOps,\n            ConvTranspose3dOpIr,\n            |args: &ConvTranspose3dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let weight = handles.get_float_tensor::<B>(&args.weight);\n                let bias = args\n                    .bias\n                    .as_ref()\n                    .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                let output = B::conv_transpose3d(x, weight, bias, args.options.clone().into());\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let mut streams = OperationStreams::with_inputs([&x, &weight]);\n        if let Some(bias) = bias.as_ref() {\n            streams.tensor(bias)\n        }\n\n        let client = x.client.clone();\n        let desc = ConvTranspose3dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::ConvTranspose3d(desc.clone())),\n                ConvTranspose3dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn avg_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AvgPool1dOps,\n            AvgPool1dOpIr,\n            |args: &AvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::avg_pool1d(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.count_include_pad,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = AvgPool1dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AvgPool1d(desc.clone())),\n                AvgPool1dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AvgPool2dOps,\n            AvgPool2dOpIr,\n            |args: &AvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::avg_pool2d(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.count_include_pad,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = AvgPool2dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AvgPool2d(desc.clone())),\n                AvgPool2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AvgPool1dBackwardOps,\n            AvgPool1dBackwardOpIr,\n            |args: &AvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let output = B::avg_pool1d_backward(\n                    x,\n                    grad,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.count_include_pad,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &grad]);\n\n        let client = x.client.clone();\n        let desc = AvgPool1dBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(desc.clone())),\n                AvgPool1dBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AvgPool2dBackwardOps,\n            AvgPool2dBackwardOpIr,\n            |args: &AvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let output = B::avg_pool2d_backward(\n                    x,\n                    grad,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.count_include_pad,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &grad]);\n\n        let client = x.client.clone();\n        let desc = AvgPool2dBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(desc.clone())),\n                AvgPool2dBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn max_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            MaxPool1dOps,\n            MaxPool1dOpIr,\n            |args: &MaxPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::max_pool1d(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = MaxPool1dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool1d(desc.clone())),\n                MaxPool1dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            MaxPool2dOps,\n            MaxPool2dOpIr,\n            |args: &MaxPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::max_pool2d(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = MaxPool2dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool2d(desc.clone())),\n                MaxPool2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn max_pool1d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<Self> {\n        make_ops!(\n            MaxPool1dWithIndicesOps,\n            MaxPool1dWithIndicesOpIr,\n            |args: &MaxPool1dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::max_pool1d_with_indices(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output.output);\n                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = MaxPool1dWithIndicesOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            B::IntElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        let [out, out_indices] = client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndices(desc.clone())),\n                MaxPool1dWithIndicesOps::<B>::new(desc),\n            )\n            .outputs();\n\n        MaxPool1dWithIndices::new(out, out_indices)\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        make_ops!(\n            MaxPool2dWithIndicesOps,\n            MaxPool2dWithIndicesOpIr,\n            |args: &MaxPool2dWithIndicesOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::max_pool2d_with_indices(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output.output);\n                handles.register_int_tensor::<B>(&args.out_indices.id, output.indices);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = MaxPool2dWithIndicesOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            B::IntElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        let [out, out_indices] = client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndices(desc.clone())),\n                MaxPool2dWithIndicesOps::<B>::new(desc),\n            )\n            .outputs();\n\n        MaxPool2dWithIndices::new(out, out_indices)\n    }\n\n    fn max_pool1d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool1dBackward<Self> {\n        make_ops!(\n            MaxPool1dWithIndicesBackwardOps,\n            MaxPool1dWithIndicesBackwardOpIr,\n            |args: &MaxPool1dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let indices = handles.get_int_tensor::<B>(&args.indices);\n                let output = B::max_pool1d_with_indices_backward(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                    grad,\n                    indices,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);\n\n        let client = x.client.clone();\n        let desc = MaxPool1dWithIndicesBackwardOpIr::create(\n            x.into_ir(),\n            output_grad.into_ir(),\n            indices.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        let out = client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool1dWithIndicesBackward(\n                    desc.clone(),\n                )),\n                MaxPool1dWithIndicesBackwardOps::<B>::new(desc),\n            )\n            .output();\n\n        MaxPool1dBackward::new(out)\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool2dBackward<Self> {\n        make_ops!(\n            MaxPool2dWithIndicesBackwardOps,\n            MaxPool2dWithIndicesBackwardOpIr,\n            |args: &MaxPool2dWithIndicesBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let indices = handles.get_int_tensor::<B>(&args.indices);\n                let output = B::max_pool2d_with_indices_backward(\n                    x,\n                    args.kernel_size,\n                    args.stride,\n                    args.padding,\n                    args.dilation,\n                    args.ceil_mode,\n                    grad,\n                    indices,\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output.x_grad);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &output_grad, &indices]);\n\n        let client = x.client.clone();\n        let desc = MaxPool2dWithIndicesBackwardOpIr::create(\n            x.into_ir(),\n            output_grad.into_ir(),\n            indices.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        let out = client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::MaxPool2dWithIndicesBackward(\n                    desc.clone(),\n                )),\n                MaxPool2dWithIndicesBackwardOps::<B>::new(desc),\n            )\n            .output();\n\n        MaxPool2dBackward::new(out)\n    }\n\n    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {\n        make_ops!(\n            AdaptiveAvgPool1dOps,\n            AdaptiveAvgPool1dOpIr,\n            |args: &AdaptiveAvgPool1dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::adaptive_avg_pool1d(x, args.output_size);\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(desc.clone())),\n                AdaptiveAvgPool1dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        make_ops!(\n            AdaptiveAvgPool2dOps,\n            AdaptiveAvgPool2dOpIr,\n            |args: &AdaptiveAvgPool2dOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::adaptive_avg_pool2d(x, args.output_size);\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(desc.clone())),\n                AdaptiveAvgPool2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn adaptive_avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AdaptiveAvgPool1dBackwardOps,\n            AdaptiveAvgPool1dBackwardOpIr,\n            |args: &AdaptiveAvgPool1dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let output = B::adaptive_avg_pool1d_backward(x, grad);\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &grad]);\n\n        let client = x.client.clone();\n        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1dBackward(desc.clone())),\n                AdaptiveAvgPool1dBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            AdaptiveAvgPool2dBackwardOps,\n            AdaptiveAvgPool2dBackwardOpIr,\n            |args: &AdaptiveAvgPool2dBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let output = B::adaptive_avg_pool2d_backward(x, grad);\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n        let streams = OperationStreams::with_inputs([&x, &grad]);\n\n        let client = x.client.clone();\n        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2dBackward(desc.clone())),\n                AdaptiveAvgPool2dBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            InterpolateOps,\n            InterpolateOpIr,\n            |args: &InterpolateOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let output = B::interpolate(x, args.output_size, args.options.clone().into());\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x]);\n\n        let client = x.client.clone();\n        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Interpolate(desc.clone())),\n                InterpolateOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        make_ops!(\n            InterpolateBackwardOps,\n            InterpolateBackwardOpIr,\n            |args: &InterpolateBackwardOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let x = handles.get_float_tensor::<B>(&args.x);\n                let grad = handles.get_float_tensor::<B>(&args.grad);\n                let output =\n                    B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let streams = OperationStreams::with_inputs([&x, &grad]);\n\n        let client = x.client.clone();\n        let desc = InterpolateBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            output_size,\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::InterpolateBackward(desc.clone())),\n                InterpolateBackwardOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn attention(\n        query: FloatTensor<Fusion<B>>,\n        key: FloatTensor<Fusion<B>>,\n        value: FloatTensor<Fusion<B>>,\n        mask: Option<burn_backend::tensor::BoolTensor<Fusion<B>>>,\n        attn_bias: Option<FloatTensor<Fusion<B>>>,\n        options: burn_backend::ops::AttentionModuleOptions,\n    ) -> FloatTensor<Fusion<B>> {\n        make_ops!(\n            AttentionOps,\n            AttentionOpIr,\n            |args: &AttentionOpIr, handles: &mut HandleContainer<B::Handle>| {\n                let query = handles.get_float_tensor::<B>(&args.query);\n                let key = handles.get_float_tensor::<B>(&args.key);\n                let value = handles.get_float_tensor::<B>(&args.value);\n                let mask = args.mask.as_ref().map(|m| handles.get_bool_tensor::<B>(m));\n                let attn_bias = args\n                    .attn_bias\n                    .as_ref()\n                    .map(|ab| handles.get_float_tensor::<B>(ab));\n\n                let output = B::attention(\n                    query,\n                    key,\n                    value,\n                    mask,\n                    attn_bias,\n                    args.options.clone().into(),\n                );\n\n                handles.register_float_tensor::<B>(&args.out.id, output);\n            }\n        );\n\n        let mut streams = OperationStreams::with_inputs([&query, &key, &value]);\n        if let Some(mask) = &mask {\n            streams.tensor(mask);\n        }\n        if let Some(attn_bias) = &attn_bias {\n            streams.tensor(attn_bias);\n        }\n\n        let client = query.client.clone();\n        let desc = AttentionOpIr::create(\n            query.into_ir(),\n            key.into_ir(),\n            value.into_ir(),\n            mask.map(|m| m.into_ir()),\n            attn_bias.map(|ab| ab.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::Module(ModuleOperationIr::Attention(desc.clone())),\n                AttentionOps::<B>::new(desc),\n            )\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/qtensor.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn_backend::{\n    DType, Element, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,\n    ops::QTensorOps,\n    quantization::{QuantPropagation, QuantScheme, QuantizationParametersPrimitive},\n    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},\n};\nuse burn_ir::{\n    BaseOperationIr, DequantizeOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, HandleContainer,\n    InitOperationIr, MatmulOpIr, OperationIr, OperationOutput, PermuteOpIr,\n    QuantizationParametersIr, QuantizeOpIr, SelectOpIr, ShapeOpIr, SliceOpIr, SwapDimsOpIr,\n};\n\nuse crate::{\n    Fusion, FusionBackend, get_client,\n    stream::{OperationStreams, execution::Operation},\n};\n\nuse super::NoOp;\n\nimpl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {\n    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {\n        let client = get_client::<B>(device);\n        let dtype = data.dtype;\n        let tensor = B::q_from_data(data, device);\n        let shape = burn_backend::TensorMetadata::shape(&tensor);\n\n        let handle = B::quantized_tensor_handle(tensor);\n        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::Init(desc),\n                NoOp::<B>::new(),\n            )\n            .output()\n    }\n\n    fn quantize(\n        tensor: FloatTensor<Self>,\n        scheme: &QuantScheme,\n        qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct QuantizeOp<B: FusionBackend> {\n            desc: QuantizeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for QuantizeOp<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let scales = handles.get_float_tensor::<B>(&self.desc.qparams.scales);\n\n                let qparams = QuantizationParametersPrimitive { scales };\n                let output = B::quantize(tensor, &self.desc.scheme, qparams);\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &qparams.scales]);\n\n        let client = tensor.client.clone();\n        let qparams = QuantizationParametersIr {\n            scales: qparams.scales.into_ir(),\n        };\n        let desc = QuantizeOpIr::create(tensor.into_ir(), qparams, *scheme, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.tensor.dtype, FloatOperationIr::Quantize(desc.clone())),\n                QuantizeOp::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct DequantizeOp<B: FusionBackend> {\n            desc: DequantizeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for DequantizeOp<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_quantized_tensor::<B>(&self.desc.input);\n\n                let output = B::dequantize(tensor);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let dtype = B::FloatElem::dtype();\n        let desc = DequantizeOpIr::create(tensor.into_ir(), dtype, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())),\n                DequantizeOp::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {\n        tensor.client.device().clone()\n    }\n\n    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {\n        let device_original: &B::Device = tensor.client.device();\n        let device_target: B::Device = device.clone();\n\n        if device_original == &device_target {\n            return tensor;\n        }\n\n        let id = tensor.stream;\n        let client_target = get_client::<B>(&device_target);\n        let client_original = tensor.client.clone();\n\n        client_original.change_client_quantized::<B>(tensor.into_ir(), client_target, id)\n    }\n\n    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        if tensor.shape == shape {\n            return tensor;\n        }\n\n        #[derive(new, Debug)]\n        struct ReshapeDimsOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_quantized_tensor::<B>(&self.desc.input);\n                let output = B::q_reshape(input, self.desc.out.shape.clone());\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),\n                ReshapeDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        tensor.q_into_data::<B>().await\n    }\n\n    fn q_swap_dims(\n        tensor: QuantizedTensor<Self>,\n        dim1: usize,\n        dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct SwapDimsOps<B: FusionBackend> {\n            desc: SwapDimsOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_quantized_tensor::<B>(&self.desc.input);\n                let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2);\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),\n                SwapDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct PermuteDimsOps<B: FusionBackend> {\n            desc: PermuteOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_quantized_tensor::<B>(&self.desc.input);\n                let output = B::q_permute(input, self.desc.axes.as_slice());\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Permute(desc.clone())),\n                PermuteDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct FlipOps<B: FusionBackend> {\n            desc: FlipOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_quantized_tensor::<B>(&self.desc.input);\n                let output = B::q_flip(input, &self.desc.axes);\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Flip(desc.clone())),\n                FlipOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_gather(\n        dim: usize,\n        tensor: QuantizedTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct GatherOps<B: FusionBackend> {\n            desc: GatherOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::q_gather(self.desc.dim, tensor, indices);\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),\n                GatherOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_select(\n        tensor: QuantizedTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct SelectOps<B: FusionBackend> {\n            desc: SelectOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::q_select(tensor, self.desc.dim, indices);\n\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),\n                SelectOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_slice(tensor: QuantizedTensor<Self>, slices: &[Slice]) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceOps<B: FusionBackend> {\n            desc: SliceOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_quantized_tensor::<B>(&self.desc.tensor);\n\n                let output = B::q_slice(tensor, self.desc.ranges.as_slice());\n\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),\n                SliceOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        #[derive(new, Debug)]\n        struct ExpandOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_quantized_tensor::<B>(&self.desc.input);\n                let output = B::q_expand(input, self.desc.out.shape.clone());\n\n                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),\n                ExpandOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {\n        #[derive(new, Debug)]\n        struct MatmulOps<B: FusionBackend> {\n            desc: MatmulOpIr,\n            lhs_quantized: bool,\n            rhs_quantized: bool,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MatmulOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = match self.lhs_quantized {\n                    true => {\n                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.lhs))\n                    }\n                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.lhs)),\n                };\n                let rhs = match self.rhs_quantized {\n                    true => {\n                        TensorPrimitive::QFloat(handles.get_quantized_tensor::<B>(&self.desc.rhs))\n                    }\n                    false => TensorPrimitive::Float(handles.get_float_tensor::<B>(&self.desc.rhs)),\n                };\n                let output = B::q_matmul(lhs, rhs);\n                match output {\n                    TensorPrimitive::Float(output) => {\n                        handles.register_float_tensor::<B>(&self.desc.out.id, output);\n                    }\n                    TensorPrimitive::QFloat(output) => {\n                        handles.register_quantized_tensor::<B>(&self.desc.out.id, output);\n                    }\n                }\n            }\n        }\n\n        let mut propagation = QuantPropagation::Inhibit;\n        let mut scheme = QuantScheme::default();\n        let mut streams = OperationStreams::default();\n        let mut lhs_quantized = false;\n        let mut rhs_quantized = false;\n        match &lhs {\n            TensorPrimitive::QFloat(lhs) => {\n                propagation = lhs.propagation();\n                scheme = *lhs.scheme();\n                lhs_quantized = true;\n                streams.tensor(lhs);\n            }\n            TensorPrimitive::Float(lhs) => {\n                streams.tensor(lhs);\n            }\n        }\n        match &rhs {\n            TensorPrimitive::QFloat(rhs) => {\n                propagation = rhs.propagation();\n                scheme = *rhs.scheme();\n                rhs_quantized = true;\n                streams.tensor(rhs);\n            }\n            TensorPrimitive::Float(rhs) => {\n                streams.tensor(rhs);\n            }\n        }\n\n        let dtype = match propagation {\n            QuantPropagation::Propagate => DType::QFloat(scheme),\n            QuantPropagation::Inhibit => B::FloatElem::dtype(),\n        };\n\n        let client = match &lhs {\n            TensorPrimitive::Float(lhs) => lhs.client.clone(),\n            TensorPrimitive::QFloat(lhs) => lhs.client.clone(),\n        };\n\n        let lhs = match lhs {\n            TensorPrimitive::Float(lhs) => lhs.into_ir(),\n            TensorPrimitive::QFloat(lhs) => lhs.into_ir(),\n        };\n        let rhs = match rhs {\n            TensorPrimitive::Float(rhs) => rhs.into_ir(),\n            TensorPrimitive::QFloat(rhs) => rhs.into_ir(),\n        };\n\n        let desc = MatmulOpIr::create_mixed(lhs, rhs, dtype, || client.create_empty_handle());\n\n        let out = client\n            .register(\n                streams,\n                OperationIr::Float(dtype, FloatOperationIr::Matmul(desc.clone())),\n                MatmulOps::<B>::new(desc, lhs_quantized, rhs_quantized),\n            )\n            .output();\n\n        match propagation {\n            QuantPropagation::Propagate => TensorPrimitive::QFloat(out),\n            QuantPropagation::Inhibit => TensorPrimitive::Float(out),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/tensor.rs",
    "content": "use super::NoOp;\nuse crate::{\n    Fusion, FusionBackend, binary_float_cmp_ops, binary_float_ops, bool_dtype, get_client,\n    reduce_float_ops, reduce_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,\n    stream::{OperationStreams, execution::Operation},\n    unary_float_ops,\n};\nuse burn_backend::{\n    Distribution, Element, ExecutionError, FloatDType, Scalar, Shape, Slice, TensorData,\n    ops::{FloatTensorOps, GridSampleOptions},\n    tensor::{BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntTensor},\n};\nuse burn_ir::*;\nuse std::marker::PhantomData;\n\nimpl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(data),\n        fields(?data.shape, ?data.dtype)\n    ))]\n    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {\n        let client = get_client::<B>(device);\n        let dtype = data.dtype;\n        let tensor = B::float_from_data(data, device);\n        let shape = burn_backend::TensorMetadata::shape(&tensor);\n\n        let handle = B::float_tensor_handle(tensor);\n        let desc = InitOperationIr::create(shape, dtype, || client.register_tensor_handle(handle));\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::Init(desc),\n                NoOp::<B>::new(),\n            )\n            .output()\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct RandomOps<B: FusionBackend> {\n            desc: RandomOpIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for RandomOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output: B::FloatTensorPrimitive = B::float_random(\n                    self.desc.out.shape.clone(),\n                    self.desc.distribution,\n                    &self.device,\n                );\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let dtype = FloatElem::<Self>::dtype();\n        let client = get_client::<B>(device);\n        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::Float(dtype, FloatOperationIr::Random(desc.clone())),\n                RandomOps::<B>::new(desc, device.clone()),\n            )\n            .output()\n    }\n\n    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct ZerosOps<B: FusionBackend> {\n            out: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ZerosOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.out.shape.clone();\n                let output = B::float_zeros(shape, &self.device, self.out.dtype.into());\n                handles.register_float_tensor::<B>(&self.out.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseFloat(BaseOperationIr::Zeros(desc.clone())),\n                ZerosOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct OnesOps<B: FusionBackend> {\n            out: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for OnesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.out.shape.clone();\n                let output = B::float_ones(shape, &self.device, self.out.dtype.into());\n                handles.register_float_tensor::<B>(&self.out.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseFloat(BaseOperationIr::Ones(desc.clone())),\n                OnesOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn float_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<Self>,\n        dtype: FloatDType,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct FullOps<B: FusionBackend> {\n            out: TensorIr,\n            elem: ScalarIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FullOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let shape = self.out.shape.clone();\n                let dtype = self.out.dtype.into();\n                let output: B::FloatTensorPrimitive =\n                    B::float_full(shape, self.elem.into(), &self.device, dtype);\n                handles.register_float_tensor::<B>(&self.out.id, output);\n            }\n        }\n\n        let dtype = dtype.into();\n        let client = get_client::<B>(device);\n        let value = fill_value.into();\n        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::NumericFloat(dtype, NumericOperationIr::Full(desc.clone())),\n                FullOps::<B>::new(desc.out, desc.value, device.clone()),\n            )\n            .output()\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(\n            from = ?tensor.client.device(),\n            shape = ?tensor.shape,\n            dtype = ?tensor.dtype\n        )\n    ))]\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        tensor.into_data::<B>().await\n    }\n\n    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {\n        tensor.client.device().clone()\n    }\n\n    #[cfg_attr(feature = \"tracing\", tracing::instrument(\n        level=\"trace\",\n        skip(tensor),\n        fields(\n            from = ?tensor.client.device(),\n            shape = ?tensor.shape,\n            dtype = ?tensor.dtype,\n        )\n    ))]\n    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {\n        let device_original: &B::Device = tensor.client.device();\n\n        if device_original == device {\n            return tensor;\n        }\n\n        let id = tensor.stream;\n        let client_target = get_client::<B>(device);\n        let client_original = tensor.client.clone();\n\n        client_original\n            .clone()\n            .change_client_float::<B>(tensor.into_ir(), client_target, id)\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {\n        #[derive(new, Debug)]\n        struct IntoIntOps<B: FusionBackend> {\n            desc: CastOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IntoIntOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_into_int(input);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), B::IntElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.input.dtype, FloatOperationIr::IntoInt(desc.clone())),\n                IntoIntOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct EmptyOps<B: FusionBackend> {\n            desc: TensorIr,\n            device: Device<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let output = B::float_empty(\n                    self.desc.shape.clone(),\n                    &self.device,\n                    self.desc.dtype.into(),\n                );\n                handles.register_float_tensor::<B>(&self.desc.id, output);\n            }\n        }\n\n        let client = get_client::<B>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(\n                OperationStreams::default(),\n                OperationIr::BaseFloat(BaseOperationIr::Empty(desc.clone())),\n                EmptyOps::<B>::new(desc.out, device.clone()),\n            )\n            .output()\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(AddOps, B::float_add);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Add(desc.clone())),\n                AddOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(AddOps, B::float_add_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::AddScalar(desc.clone()),\n                ),\n                AddOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct ClampOps<B: FusionBackend> {\n            desc: ClampOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ClampOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let output = B::float_clamp(input, self.desc.min.into(), self.desc.max.into());\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let min = min.into();\n        let max = max.into();\n        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.tensor.dtype,\n                    NumericOperationIr::Clamp(desc.clone()),\n                ),\n                ClampOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(SubOps, B::float_sub);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sub(desc.clone())),\n                SubOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(SubOps, B::float_sub_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::SubScalar(desc.clone()),\n                ),\n                SubOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(MulOps, B::float_mul);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mul(desc.clone())),\n                MulOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(MulOps, B::float_mul_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::MulScalar(desc.clone()),\n                ),\n                MulOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(DivOps, B::float_div);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Div(desc.clone())),\n                DivOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(DivOps, B::float_div_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::DivScalar(desc.clone()),\n                ),\n                DivOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(ModOps, B::float_remainder);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Rem(desc.clone())),\n                ModOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(ModOps, B::float_remainder_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::RemScalar(desc.clone()),\n                ),\n                ModOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(MatmulOps, B::float_matmul);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Matmul(desc.clone())),\n                MatmulOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CrossOps<B: FusionBackend> {\n            desc: CrossOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CrossOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let rhs = handles.get_float_tensor::<B>(&self.desc.rhs);\n                let output = B::float_cross(lhs, rhs, self.desc.dim);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cross(desc.clone())),\n                CrossOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct SwapDimsOps<B: FusionBackend> {\n            desc: SwapDimsOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SwapDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())),\n                SwapDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        if tensor.shape == shape {\n            return tensor;\n        }\n\n        #[derive(new, Debug)]\n        struct ReshapeDimsOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ReshapeDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_reshape(input, self.desc.out.shape.clone());\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())),\n                ReshapeDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct GatherOps<B: FusionBackend> {\n            desc: GatherOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for GatherOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::float_gather(self.desc.dim, tensor, indices);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices]);\n\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Gather(desc.clone())),\n                GatherOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct ScatterOps<B: FusionBackend> {\n            desc: ScatterOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ScatterOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n                let value = handles.get_float_tensor::<B>(&self.desc.value);\n\n                let output = B::float_scatter_add(self.desc.dim, tensor, indices, value);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);\n\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Scatter(desc.clone())),\n                ScatterOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct SelectOps<B: FusionBackend> {\n            desc: SelectOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n\n                let output = B::float_select(tensor, self.desc.dim, indices);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices]);\n\n        let client = tensor.client.clone();\n        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Select(desc.clone())),\n                SelectOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct SelectAssignOps<B: FusionBackend> {\n            desc: SelectAssignOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SelectAssignOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let indices = handles.get_int_tensor::<B>(&self.desc.indices);\n                let value = handles.get_float_tensor::<B>(&self.desc.value);\n\n                let output = B::float_select_add(tensor, self.desc.dim, indices, value);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &indices, &value]);\n\n        let client = tensor.client.clone();\n        let desc = SelectAssignOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc.clone())),\n                SelectAssignOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceOps<B: FusionBackend> {\n            desc: SliceOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n\n                let output = B::float_slice(tensor, self.desc.ranges.as_slice());\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())),\n                SliceOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct SliceAssignOps<B: FusionBackend> {\n            desc: SliceAssignOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for SliceAssignOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_float_tensor::<B>(&self.desc.value);\n\n                let output = B::float_slice_assign(tensor, self.desc.ranges.as_slice(), value);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &value]);\n\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc.clone())),\n                SliceAssignOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskWhereOps<B: FusionBackend> {\n            desc: MaskWhereOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskWhereOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let value = handles.get_float_tensor::<B>(&self.desc.value);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::float_mask_where(tensor, mask, value);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask, &value]);\n\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc.clone())),\n                MaskWhereOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct MaskFillOps<B: FusionBackend> {\n            desc: MaskFillOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaskFillOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let mask = handles.get_bool_tensor::<B>(&self.desc.mask);\n\n                let output = B::float_mask_fill(tensor, mask, self.desc.value.into());\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &mask]);\n\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc.clone())),\n                MaskFillOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float_cmp_ops!(EqualOps, B::float_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Equal(desc.clone())),\n                EqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_float_cmp_ops!(EqualElemOps, B::float_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc.clone())),\n                EqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float_cmp_ops!(GreaterOps, B::float_greater);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::Greater(desc.clone()),\n                ),\n                GreaterOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_float_cmp_ops!(GreaterElemOps, B::float_greater_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterElem(desc.clone()),\n                ),\n                GreaterElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float_cmp_ops!(GreaterEqualOps, B::float_greater_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterEqual(desc.clone()),\n                ),\n                GreaterEqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_float_cmp_ops!(GreaterEqualElemOps, B::float_greater_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::GreaterEqualElem(desc.clone()),\n                ),\n                GreaterEqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float_cmp_ops!(LowerOps, B::float_lower);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.lhs.dtype, NumericOperationIr::Lower(desc.clone())),\n                LowerOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_float_cmp_ops!(LowerElemOps, B::float_lower_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerElem(desc.clone()),\n                ),\n                LowerElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        binary_float_cmp_ops!(LowerEqualOps, B::float_lower_equal);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            bool_dtype::<B::BoolElem>(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerEqual(desc.clone()),\n                ),\n                LowerEqualOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        scalar_float_cmp_ops!(LowerEqualElemOps, B::float_lower_equal_elem);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc =\n            ScalarOpIr::create_comparison(lhs.into_ir(), rhs, bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.lhs.dtype,\n                    NumericOperationIr::LowerEqualElem(desc.clone()),\n                ),\n                LowerEqualElemOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(SumOps, B::float_sum, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Sum(desc.clone())),\n                SumOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(SumDimOps, B::float_sum_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::SumDim(desc.clone())),\n                SumDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ProdOps, B::float_prod, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Prod(desc.clone())),\n                ProdOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(ProdDimOps, B::float_prod_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::ProdDim(desc.clone()),\n                ),\n                ProdDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(MeanOps, B::float_mean, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Mean(desc.clone())),\n                MeanOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(MeanDimOps, B::float_mean_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::MeanDim(desc.clone()),\n                ),\n                MeanDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumsumOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumsumOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_cumsum(input, self.desc.axis);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumSum(desc.clone())),\n                CumsumOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumprodOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumprodOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_cumprod(input, self.desc.axis);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::CumProd(desc.clone()),\n                ),\n                CumprodOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CumminOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_cummin(input, self.desc.axis);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMin(desc.clone())),\n                CumminOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CummaxOps<B: FusionBackend> {\n            desc: DimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CummaxOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_cummax(input, self.desc.axis);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::CumMax(desc.clone())),\n                CummaxOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ExpOps, B::float_exp);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Exp(desc.clone())),\n                ExpOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(LogOps, B::float_log);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Log(desc.clone())),\n                LogOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(Log1pOps, B::float_log1p);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Log1p(desc.clone())),\n                Log1pOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        scalar_float_ops!(PowfOps, B::float_powf_scalar);\n\n        let streams = OperationStreams::with_inputs([&lhs]);\n\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::PowfScalar(desc.clone())),\n                PowfOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(SqrtOps, B::float_sqrt);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sqrt(desc.clone())),\n                SqrtOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(AbsOps, B::float_abs);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Abs(desc.clone())),\n                AbsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(CosOps, B::float_cos);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cos(desc.clone())),\n                CosOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(SinOps, B::float_sin);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sin(desc.clone())),\n                SinOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(TanOps, B::float_tan);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Tan(desc.clone())),\n                TanOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(CoshOps, B::float_cosh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Cosh(desc.clone())),\n                CoshOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(SinhOps, B::float_sinh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Sinh(desc.clone())),\n                SinhOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(TanhOps, B::float_tanh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Tanh(desc.clone())),\n                TanhOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcCosOps, B::float_acos);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCos(desc.clone())),\n                ArcCosOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcCoshOps, B::float_acosh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcCosh(desc.clone())),\n                ArcCoshOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcSinOps, B::float_asin);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSin(desc.clone())),\n                ArcSinOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcSinhOps, B::float_asinh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcSinh(desc.clone())),\n                ArcSinhOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcTanOps, B::float_atan);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan(desc.clone())),\n                ArcTanOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(ArcTanhOps, B::float_atanh);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTanh(desc.clone())),\n                ArcTanhOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(ArcTan2Ops, B::float_atan2);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::ArcTan2(desc.clone())),\n                ArcTan2Ops::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(Recip, B::float_recip);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Recip(desc.clone())),\n                Recip::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(TanhOps, B::float_erf);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Erf(desc.clone())),\n                TanhOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CatOps<B: FusionBackend> {\n            desc: CatOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CatOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensors = self\n                    .desc\n                    .tensors\n                    .iter()\n                    .map(|tensor| handles.get_float_tensor::<B>(tensor))\n                    .collect();\n\n                let output = B::float_cat(tensors, self.desc.dim);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs(&tensors);\n\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Cat(desc.clone())),\n                CatOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_float2int_ops!(ArgMaxOps, B::float_argmax);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        // TODO: rename `create_with_dtype` specifically for ARG / indices\n        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, B::IntElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.input.dtype,\n                    NumericOperationIr::ArgMax(desc.clone()),\n                ),\n                ArgMaxOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct RepeatDimOps<B: FusionBackend> {\n            desc: RepeatDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n\n                let output = B::float_repeat_dim(tensor, self.desc.dim, self.desc.times);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc.clone())),\n                RepeatDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        reduce_float2int_ops!(ArgMinOps, B::float_argmin);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create_arg(tensor.into_ir(), dim, B::IntElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.input.dtype,\n                    NumericOperationIr::ArgMin(desc.clone()),\n                ),\n                ArgMinOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(MaxOps, B::float_max, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Max(desc.clone())),\n                MaxOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(MaxDimOps, B::float_max_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxDim(desc.clone())),\n                MaxDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_max_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        #[derive(new, Debug)]\n        struct MaxDimWithIndicesOps<B: FusionBackend> {\n            desc: ReduceDimWithIndicesOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MaxDimWithIndicesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc =\n            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, B::IntElem::dtype(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.tensor.dtype,\n                    NumericOperationIr::MaxDimWithIndices(desc.clone()),\n                ),\n                MaxDimWithIndicesOps::<B>::new(desc),\n            )\n            .outputs()\n            .into()\n    }\n\n    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(MinOps, B::float_min, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::Min(desc.clone())),\n                MinOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(MinDimOps, B::float_min_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MinDim(desc.clone())),\n                MinDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_min_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        #[derive(new, Debug)]\n        struct MinDimWithIndicesOps<B: FusionBackend> {\n            desc: ReduceDimWithIndicesOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for MinDimWithIndicesOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n                handles.register_int_tensor::<B>(&self.desc.out_indices.id, indices);\n            }\n        }\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc =\n            ReduceDimWithIndicesOpIr::create(tensor.into_ir(), dim, B::IntElem::dtype(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.tensor.dtype,\n                    NumericOperationIr::MinDimWithIndices(desc.clone()),\n                ),\n                MinDimWithIndicesOps::<B>::new(desc),\n            )\n            .outputs()\n            .into()\n    }\n\n    fn float_max_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(MaxAbsOps, B::float_max_abs, reduce);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(desc.out.dtype, NumericOperationIr::MaxAbs(desc.clone())),\n                MaxAbsOps::<B>::new(desc.into()),\n            )\n            .output()\n    }\n\n    fn float_max_abs_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        reduce_float_ops!(MaxAbsDimOps, B::float_max_abs_dim);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::NumericFloat(\n                    desc.out.dtype,\n                    NumericOperationIr::MaxAbsDim(desc.clone()),\n                ),\n                MaxAbsDimOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    // TODO: float_powi w/ burn-cubecl-fusion impl\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        binary_float_ops!(PowOps, B::float_powf);\n\n        let streams = OperationStreams::with_inputs([&lhs, &rhs]);\n\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Powf(desc.clone())),\n                PowOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct PermuteDimsOps<B: FusionBackend> {\n            desc: PermuteOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for PermuteDimsOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_permute(input, self.desc.axes.as_slice());\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())),\n                PermuteDimsOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct ExpandOps<B: FusionBackend> {\n            desc: ShapeOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for ExpandOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_expand(input, self.desc.out.shape.clone());\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())),\n                ExpandOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct FlipOps<B: FusionBackend> {\n            desc: FlipOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for FlipOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_flip(input, &self.desc.axes);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())),\n                FlipOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(RoundOps, B::float_round);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Round(desc.clone())),\n                RoundOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(FloorOps, B::float_floor);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Floor(desc.clone())),\n                FloorOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(CeilOps, B::float_ceil);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Ceil(desc.clone())),\n                CeilOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        unary_float_ops!(TruncOps, B::float_trunc);\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::Trunc(desc.clone())),\n                TruncOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct CastOps<B: FusionBackend> {\n            desc: CastOpIr,\n            dtype: burn_backend::FloatDType,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for CastOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output: B::FloatTensorPrimitive = B::float_cast(input, self.dtype);\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Cast(desc.clone())),\n                CastOps::<B>::new(desc, dtype),\n            )\n            .output()\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct UnfoldOps<B: FusionBackend> {\n            desc: UnfoldOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(\n                streams,\n                OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),\n                UnfoldOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct IsNanOps<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsNanOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_is_nan(input);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc =\n            UnaryOpIr::create_comparison(tensor.into_ir(), bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.input.dtype, FloatOperationIr::IsNan(desc.clone())),\n                IsNanOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        #[derive(new, Debug)]\n        struct IsInfOps<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for IsInfOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = B::float_is_inf(input);\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor]);\n\n        let client = tensor.client.clone();\n        let desc =\n            UnaryOpIr::create_comparison(tensor.into_ir(), bool_dtype::<B::BoolElem>(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.input.dtype, FloatOperationIr::IsInf(desc.clone())),\n                IsInfOps::<B>::new(desc),\n            )\n            .output()\n    }\n\n    fn float_grid_sample_2d(\n        tensor: FloatTensor<Self>,\n        grid: FloatTensor<Self>,\n        options: GridSampleOptions,\n    ) -> FloatTensor<Self> {\n        #[derive(new, Debug)]\n        struct GridSample2dOps<B: FusionBackend> {\n            desc: GridSample2dOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for GridSample2dOps<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let tensor = handles.get_float_tensor::<B>(&self.desc.tensor);\n                let grid = handles.get_float_tensor::<B>(&self.desc.grid);\n                let output =\n                    B::float_grid_sample_2d(tensor, grid, self.desc.options.clone().into());\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n\n        let streams = OperationStreams::with_inputs([&tensor, &grid]);\n\n        let client = tensor.client.clone();\n        let desc =\n            GridSample2dOpIr::create(tensor.into_ir(), grid.into_ir(), options.into(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(\n                streams,\n                OperationIr::Float(desc.out.dtype, FloatOperationIr::GridSample2d(desc.clone())),\n                GridSample2dOps::<B>::new(desc),\n            )\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/transaction.rs",
    "content": "use burn_backend::{\n    backend::ExecutionError,\n    ops::{TransactionOps, TransactionPrimitive},\n};\n\nuse crate::{Fusion, FusionBackend};\n\nimpl<B: FusionBackend> TransactionOps<Fusion<B>> for Fusion<B> {\n    async fn tr_execute(\n        transaction: TransactionPrimitive<Self>,\n    ) -> Result<burn_backend::ops::TransactionPrimitiveData, ExecutionError> {\n        B::tr_execute(TransactionPrimitive::new(\n            transaction\n                .read_floats\n                .into_iter()\n                .map(|t| t.client.clone().resolve_tensor_float::<B>(t))\n                .collect(),\n            transaction\n                .read_qfloats\n                .into_iter()\n                .map(|_t| todo!(\"Quantization not supported yet\"))\n                .collect(),\n            transaction\n                .read_ints\n                .into_iter()\n                .map(|t| t.client.clone().resolve_tensor_int::<B>(t))\n                .collect(),\n            transaction\n                .read_bools\n                .into_iter()\n                .map(|t| t.client.clone().resolve_tensor_bool::<B>(t))\n                .collect(),\n        ))\n        .await\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/ops/unary.rs",
    "content": "#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs.into());\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n    (\n        $name:ident,\n        $ops:expr,\n        noconvert\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_float_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ReduceDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = $ops(input, self.desc.axis);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_float2int_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ReduceDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = $ops(input, self.desc.axis);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_int_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ReduceDimOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = $ops(input, self.desc.axis);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float2int_ops {\n    (\n        $name:ident,\n        $ops:expr,\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs.clone());\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! unary_float_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = $ops(input);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n    (\n        $name:ident,\n        $ops:expr,\n        reduce\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_float_tensor::<B>(&self.desc.input);\n                let output = $ops(input);\n\n                handles.register_float_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! unary_int_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = $ops(input);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n    (\n        $name:ident,\n        $ops:expr,\n        reduce\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: UnaryOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let input = handles.get_int_tensor::<B>(&self.desc.input);\n                let output = $ops(input);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float_cmp_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_float_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs.into());\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_int_cmp_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs.into());\n\n                handles.register_bool_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_int_ops {\n    (\n        $name:ident,\n        $ops:expr\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs.into());\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n    (\n        $name:ident,\n        $ops:expr,\n        noconvert\n    ) => {\n        #[derive(new, Debug)]\n        struct $name<B: FusionBackend> {\n            desc: ScalarOpIr,\n            _b: PhantomData<B>,\n        }\n\n        impl<B: FusionBackend> Operation<B::FusionRuntime> for $name<B> {\n            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {\n                let lhs = handles.get_int_tensor::<B>(&self.desc.lhs);\n                let output = $ops(lhs, self.desc.rhs);\n\n                handles.register_int_tensor::<B>(&self.desc.out.id, output);\n            }\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/search/block.rs",
    "content": "use crate::{FuserStatus, NumOperations, OperationFuser, stream::store::ExecutionStrategy};\nuse burn_ir::{OperationIr, TensorId, TensorIr};\nuse std::{collections::HashSet, sync::Arc};\n\n/// A block represents a list of operations, not necessarily in the same order as the execution\n/// stream.\n///\n/// The start and end position of the relative execution stream are tracked in the block alongside\n/// the ordering.\npub struct Block<O> {\n    builders: Vec<Box<dyn OperationFuser<O>>>,\n    operations: Vec<OperationIr>,\n    ids: HashSet<TensorId>,\n    ordering: Vec<usize>,\n    /// The start position in the relative execution stream.\n    pub start_pos: usize,\n    /// The end position in the relative execution stream.\n    pub end_pos: usize,\n}\n\n/// The result of [registering](Block::register) an [operation](OperationIr).\npub enum RegistrationResult {\n    /// If the [operation](OperationIr) is correctly registered.\n    Accepted,\n    /// If the [operation](OperationIr) isn't part of the graph.\n    ///\n    /// In this case the operation isn't registered.\n    NotPartOfTheGraph,\n}\n\n/// The optimization found for a [block](Block).\n#[derive(Debug, new)]\npub struct BlockOptimization<O> {\n    /// The [execution strategy](ExecutionStrategy) to be used to execute the [block](Block).\n    pub strategy: ExecutionStrategy<O>,\n    /// The ordering of each operation in the relative execution stream.\n    pub ordering: Vec<usize>,\n}\n\nimpl<O: NumOperations> Block<O> {\n    /// Create a new block that will be optimized with the provided [optimization builders](OptimizationBuilder).\n    pub fn new(builders: &[Box<dyn OperationFuser<O>>]) -> Self {\n        Self {\n            builders: builders.iter().map(|o| o.clone_dyn()).collect(),\n            operations: Vec::new(),\n            ids: HashSet::new(),\n            ordering: Vec::new(),\n            start_pos: usize::MAX,\n            end_pos: usize::MIN,\n        }\n    }\n\n    /// Sort the [blocks](Block) based on the start position.\n    pub fn sort(blocks: &mut [Self]) {\n        blocks.sort_by(|a, b| a.start_pos.cmp(&b.start_pos));\n    }\n\n    /// Optimize the block.\n    pub fn optimize(mut self) -> BlockOptimization<O> {\n        match find_best_optimization_index(&mut self.builders) {\n            Some(index) => {\n                let opt = self.builders[index].finish();\n                let opt_len = opt.len();\n                if opt_len < self.operations.len() {\n                    self.ordering.drain(opt_len..);\n                }\n\n                let strategy = ExecutionStrategy::Optimization {\n                    ordering: Arc::new(self.ordering.clone()),\n                    opt,\n                };\n                BlockOptimization::new(strategy, self.ordering)\n            }\n            None => {\n                let strategy = ExecutionStrategy::Operations {\n                    ordering: Arc::new(self.ordering.clone()),\n                };\n                BlockOptimization::new(strategy, self.ordering)\n            }\n        }\n    }\n\n    /// Returns if the block contains any of the provided [tensors](TensorIr).\n    pub fn contains_tensors(&self, tensors: &[&TensorIr]) -> bool {\n        for node in tensors {\n            if self.ids.contains(&node.id) {\n                return true;\n            }\n        }\n\n        false\n    }\n\n    /// Merge the current block with the other one and returns if the operation is successful.\n    ///\n    /// # Warning\n    ///\n    /// This will modify the current block even if the other block isn't correctly merged.\n    pub fn merge(&mut self, other: &Block<O>) -> bool {\n        for (op, pos) in other.operations.iter().zip(&other.ordering) {\n            self.register(op, *pos, true);\n        }\n\n        // The operation is successful if the current block can still be optimized.\n        self.still_optimizing()\n    }\n\n    /// Register an [operation](OperationIr) in the current block.\n    ///\n    /// You need to provide the order of the operation as well as a force flag.\n    ///\n    /// When the force flag is true, the builder will always accept the operation, otherwise it\n    /// might refuse it if the operation [isn't part of the graph](RegistrationResult::NotPartOfTheGraph).\n    ///\n    /// Forcing is useful to fuse operations that are part of different graphs, but included\n    /// in the same optimization.\n    pub fn register(\n        &mut self,\n        operation: &OperationIr,\n        order: usize,\n        force: bool,\n    ) -> RegistrationResult {\n        if self.ids.is_empty() {\n            self.register_op(operation, order);\n            return RegistrationResult::Accepted;\n        }\n        let mut contains = false;\n        for node in operation.nodes() {\n            contains = self.ids.contains(&node.id);\n\n            if contains {\n                break;\n            }\n        }\n\n        if !contains && !force {\n            return RegistrationResult::NotPartOfTheGraph;\n        }\n\n        self.register_op(operation, order);\n        RegistrationResult::Accepted\n    }\n\n    /// If the block can still be optimized further.\n    pub fn still_optimizing(&self) -> bool {\n        let mut num_stopped = 0;\n\n        for optimization in self.builders.iter() {\n            if let FuserStatus::Closed = optimization.status() {\n                num_stopped += 1\n            }\n        }\n\n        num_stopped < self.builders.len()\n    }\n\n    fn register_op(&mut self, operation: &OperationIr, pos: usize) {\n        self.operations.push(operation.clone());\n        self.ordering.push(pos);\n\n        if pos < self.start_pos {\n            self.start_pos = pos;\n        }\n        if pos + 1 > self.end_pos {\n            self.end_pos = pos + 1;\n        }\n\n        for builder in self.builders.iter_mut() {\n            builder.fuse(operation);\n        }\n\n        for node in operation.nodes() {\n            self.ids.insert(node.id);\n        }\n    }\n}\n\nimpl<O> BlockOptimization<O> {\n    /// Maps the ordering of the current block optimization using the given mapping.\n    pub fn map_ordering(&mut self, mapping: &[usize]) {\n        for i in self.ordering.iter_mut() {\n            *i = mapping[*i];\n        }\n        self.strategy.map_ordering(mapping);\n    }\n}\n\nimpl<O> ExecutionStrategy<O> {\n    /// Maps the ordering of the current execution strategy using the given mapping.\n    pub fn map_ordering(&mut self, mapping: &[usize]) {\n        match self {\n            ExecutionStrategy::Optimization { ordering, .. } => {\n                let mut ordering_mapped = ordering.to_vec();\n\n                for o in ordering_mapped.iter_mut() {\n                    *o = mapping[*o];\n                }\n                *ordering = Arc::new(ordering_mapped);\n            }\n            ExecutionStrategy::Operations { ordering } => {\n                let mut ordering_mapped = ordering.to_vec();\n\n                for o in ordering_mapped.iter_mut() {\n                    *o = mapping[*o];\n                }\n\n                *ordering = Arc::new(ordering_mapped);\n            }\n            ExecutionStrategy::Composed(items) => {\n                for item in items.iter_mut() {\n                    item.map_ordering(mapping);\n                }\n            }\n        }\n    }\n}\n\nfn find_best_optimization_index<O>(\n    optimizations: &mut [Box<dyn OperationFuser<O>>],\n) -> Option<usize> {\n    let mut best_index = None;\n    let mut best_score = 0;\n\n    for (i, optimization) in optimizations.iter().enumerate() {\n        let properties = optimization.properties();\n\n        // A score of zero is worse than fusing.\n        if properties.ready && properties.score > best_score {\n            best_index = Some(i);\n            best_score = properties.score;\n        }\n    }\n\n    best_index\n}\n\nimpl<O> PartialEq for Block<O> {\n    fn eq(&self, other: &Self) -> bool {\n        // Since the ordering can be seen as operation ids, we can use it to compare\n        // blocks.\n        let mut sorted_a = self.ordering.clone();\n        let mut sorted_b = other.ordering.clone();\n        sorted_a.sort();\n        sorted_b.sort();\n\n        sorted_a == sorted_b\n    }\n}\n\nimpl<O> core::fmt::Debug for Block<O> {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_fmt(format_args!(\n            \"Block {{ pos: [{:?}, {:?}; {:?}] }}\",\n            self.start_pos,\n            self.end_pos,\n            self.ordering.len(),\n        ))\n    }\n}\n\nimpl<O> Clone for Block<O> {\n    fn clone(&self) -> Self {\n        Self {\n            builders: self.builders.iter().map(|b| b.clone_dyn()).collect(),\n            operations: self.operations.clone(),\n            ids: self.ids.clone(),\n            ordering: self.ordering.clone(),\n            start_pos: self.start_pos,\n            end_pos: self.end_pos,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/search/merging.rs",
    "content": "use super::Block;\nuse crate::NumOperations;\n\n#[derive(Debug, PartialEq)]\n/// The result of [merging](merge_blocks) [blocks](Block).\npub enum MergeBlocksResult<O> {\n    /// All [blocks](Block) merged into one.\n    Full(Block<O>),\n    /// Some [blocks](Block) merged and some failed.\n    Partial {\n        merged: Vec<Block<O>>,\n        failed: Vec<Block<O>>,\n    },\n    /// All [blocks](Block) failed to merge.\n    Fail,\n}\n\n/// Merge multiple [block](Block) together.\n///\n/// The resulting [blocks](Block) might be sorted if the flag is true, otherwise the order isn't\n/// guarantee. This is mostly useful for testing.\n///\n/// # Strategy\n///\n/// The merging strategy is in two steps:\n///\n/// 1. The first step is to recursively try to merge adjacent blocks. This has the advantage of\n///    trying multiple blocks ordering, therefore trying multiple permutation of the blocks.\n///    However, it has the downside of not trying to merge blocks that are further away in the list\n///    of blocks. Since trying all combinations possible is exponential, therefore not possible, we\n///    fallback on the second strategy.\n/// 2. The second step is to reduce blocks by setting an accumulator block, then sequentially\n///    trying to merge the remaining blocks. We try some permutations based on the result from\n///    step1.\npub fn merge_blocks<O: NumOperations>(blocks: &[&Block<O>], sorted: bool) -> MergeBlocksResult<O> {\n    if blocks.is_empty() {\n        return MergeBlocksResult::Fail;\n    }\n\n    if blocks.len() == 1 {\n        return MergeBlocksResult::Full(blocks[0].clone());\n    }\n\n    if blocks.len() == 2 {\n        let block0 = blocks[0];\n        let block1 = blocks[1];\n\n        return match merge_two(block0, block1) {\n            Some(result) => MergeBlocksResult::Full(result),\n            None => MergeBlocksResult::Fail,\n        };\n    }\n\n    let mut step1 = merge_blocks_step1(blocks);\n\n    if step1.full.len() == 1 && step1.failed.is_empty() && step1.partial.is_empty() {\n        MergeBlocksResult::Full(step1.full.remove(0))\n    } else if step1.partial.len() == 1 && step1.failed.is_empty() && step1.full.is_empty() {\n        MergeBlocksResult::Full(step1.partial.remove(0))\n    } else {\n        let result = merge_blocks_step2(step1);\n\n        if !sorted {\n            return result;\n        }\n\n        match result {\n            MergeBlocksResult::Full(block) => MergeBlocksResult::Full(block),\n            MergeBlocksResult::Partial {\n                mut merged,\n                mut failed,\n            } => {\n                Block::sort(&mut merged);\n                Block::sort(&mut failed);\n\n                MergeBlocksResult::Partial { merged, failed }\n            }\n            MergeBlocksResult::Fail => MergeBlocksResult::Fail,\n        }\n    }\n}\n\nstruct MergeBlockStep1<O> {\n    full: Vec<Block<O>>,\n    partial: Vec<Block<O>>,\n    failed: Vec<Block<O>>,\n}\n\nimpl<O> Default for MergeBlockStep1<O> {\n    fn default() -> Self {\n        Self {\n            full: Default::default(),\n            partial: Default::default(),\n            failed: Default::default(),\n        }\n    }\n}\n\nfn merge_blocks_step1<O: NumOperations>(blocks: &[&Block<O>]) -> MergeBlockStep1<O> {\n    let step_size = blocks.len() / 2;\n    let num_steps = f32::ceil(blocks.len() as f32 / step_size as f32) as usize;\n\n    let mut result = MergeBlockStep1::default();\n\n    for i in 0..num_steps {\n        let start = i * step_size;\n        let end = usize::min(start + step_size, blocks.len());\n\n        match merge_blocks(&blocks[start..end], false) {\n            MergeBlocksResult::Full(block) => {\n                result.full.push(block);\n            }\n            MergeBlocksResult::Partial {\n                mut merged,\n                mut failed,\n            } => {\n                result.partial.append(&mut merged);\n                result.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {\n                for b in &blocks[start..end] {\n                    result.failed.push((*b).clone());\n                }\n            }\n        }\n    }\n\n    result\n}\n\nfn merge_blocks_step2<O: NumOperations>(mut step1: MergeBlockStep1<O>) -> MergeBlocksResult<O> {\n    // First let's try to merge partial graphs.\n    if step1.partial.len() > 1 {\n        match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {\n            MergeBlocksResult::Full(block) => {\n                step1.partial = vec![block];\n            }\n            MergeBlocksResult::Partial { merged, mut failed } => {\n                step1.partial = merged;\n                step1.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {}\n        }\n    }\n\n    // Then let's try to merge partial graphs with failed merges.\n    if !step1.failed.is_empty() {\n        step1.partial.append(&mut step1.failed);\n        match merge_accumulator(&step1.partial[0], &step1.partial[1..]) {\n            MergeBlocksResult::Full(block) => {\n                step1.partial = vec![block];\n            }\n            MergeBlocksResult::Partial { merged, mut failed } => {\n                step1.partial = merged;\n                step1.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {}\n        }\n    }\n\n    // Then let's try to merge full graphs.\n    if step1.full.len() > 1 {\n        match merge_accumulator(&step1.full[0], &step1.full[1..]) {\n            MergeBlocksResult::Full(block) => {\n                step1.full = vec![block];\n            }\n            MergeBlocksResult::Partial { merged, mut failed } => {\n                step1.full = merged;\n                step1.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {}\n        }\n    }\n\n    // Then let's try to merge full graphs with failed graphs.\n    if !step1.full.is_empty() {\n        step1.full.append(&mut step1.failed);\n        match merge_accumulator(&step1.full[0], &step1.full[1..]) {\n            MergeBlocksResult::Full(block) => {\n                step1.full = vec![block];\n            }\n            MergeBlocksResult::Partial { merged, mut failed } => {\n                step1.full = merged;\n                step1.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {}\n        }\n    }\n\n    // Then let's try to merge full graphs with partial graphs.\n    if !step1.full.is_empty() || !step1.partial.is_empty() {\n        step1.full.append(&mut step1.partial);\n        match merge_accumulator(&step1.full[0], &step1.full[1..]) {\n            MergeBlocksResult::Full(block) => {\n                step1.full = vec![block];\n            }\n            MergeBlocksResult::Partial { merged, mut failed } => {\n                step1.full = merged;\n                step1.failed.append(&mut failed);\n            }\n            MergeBlocksResult::Fail => {\n                // We do nothing.\n            }\n        }\n    }\n\n    if step1.full.is_empty() {\n        MergeBlocksResult::Fail\n    } else if step1.failed.is_empty() {\n        if step1.full.len() == 1 {\n            MergeBlocksResult::Full(step1.full.remove(0))\n        } else {\n            MergeBlocksResult::Partial {\n                merged: step1.full,\n                failed: vec![],\n            }\n        }\n    } else {\n        MergeBlocksResult::Partial {\n            merged: step1.full,\n            failed: step1.failed,\n        }\n    }\n}\n\nfn merge_accumulator<O: NumOperations>(\n    base: &Block<O>,\n    blocks: &[Block<O>],\n) -> MergeBlocksResult<O> {\n    let mut base = base.clone();\n    let mut merged_failed = Vec::<Block<O>>::new();\n    let mut merged_success = false;\n\n    for block in blocks {\n        let mut base_current = base.clone();\n        match base_current.merge(block) {\n            false => {\n                merged_failed.push((*block).clone());\n            }\n            true => {\n                merged_success = true;\n                base = base_current;\n            }\n        }\n    }\n\n    if merged_success {\n        if merged_failed.is_empty() {\n            MergeBlocksResult::Full(base)\n        } else {\n            MergeBlocksResult::Partial {\n                merged: vec![base],\n                failed: merged_failed,\n            }\n        }\n    } else {\n        MergeBlocksResult::Fail\n    }\n}\n\nfn merge_two<O: NumOperations>(a: &Block<O>, b: &Block<O>) -> Option<Block<O>> {\n    let mut base = a.clone();\n\n    if base.merge(b) {\n        return Some(base);\n    }\n\n    let mut base = b.clone();\n\n    match base.merge(a) {\n        true => Some(base),\n        false => None,\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    pub use crate::stream::execution::tests::{TestOptimization, TestOptimizationBuilder};\n    use crate::{\n        OperationFuser,\n        stream::tests::{operation_1, operation_2, operation_3},\n    };\n\n    #[test]\n    fn test_merge_blocks_no_block() {\n        let actual = merge_blocks::<TestOptimization>(&[], true);\n\n        assert_eq!(actual, MergeBlocksResult::Fail);\n    }\n\n    #[test]\n    fn test_merge_blocks_single() {\n        let builders = builders();\n        let block = Block::new(&builders);\n        let actual = merge_blocks::<TestOptimization>(&[&block], true);\n\n        assert_eq!(actual, MergeBlocksResult::Full(block));\n    }\n\n    #[test]\n    fn test_merge_blocks_two_blocks() {\n        let builders = builders();\n        let mut block1 = Block::new(&builders);\n        let mut block2 = Block::new(&builders);\n        block1.register(&operation_1(), 0, false);\n        block1.register(&operation_1(), 1, false);\n        block2.register(&operation_1(), 2, false);\n        block2.register(&operation_1(), 3, false);\n\n        let actual = merge_blocks::<TestOptimization>(&[&block1, &block2], true);\n\n        let mut expected = Block::new(&builders);\n        expected.register(&operation_1(), 0, false);\n        expected.register(&operation_1(), 1, false);\n        expected.register(&operation_1(), 2, false);\n        expected.register(&operation_1(), 3, false);\n\n        assert_eq!(actual, MergeBlocksResult::Full(expected));\n    }\n\n    #[test]\n    fn test_merge_blocks_three_blocks() {\n        let builders = builders();\n        let mut block1 = Block::new(&builders);\n        let mut block2 = Block::new(&builders);\n        let mut block3 = Block::new(&builders);\n        block1.register(&operation_1(), 0, false);\n        block2.register(&operation_1(), 1, false);\n        block3.register(&operation_1(), 2, false);\n\n        let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);\n\n        let mut expected = Block::new(&builders);\n        expected.register(&operation_1(), 0, false);\n        expected.register(&operation_1(), 1, false);\n        expected.register(&operation_1(), 2, false);\n\n        assert_eq!(actual, MergeBlocksResult::Full(expected));\n    }\n\n    #[test]\n    fn test_merge_blocks_three_blocks_partial() {\n        let builders = builders();\n        let mut block1 = Block::new(&builders);\n        let mut block2 = Block::new(&builders);\n        let mut block3 = Block::new(&builders);\n        block1.register(&operation_1(), 0, false);\n        block2.register(&operation_2(), 1, false);\n        block3.register(&operation_1(), 2, false);\n\n        let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3], true);\n\n        let mut expected1 = Block::new(&builders);\n        let mut expected2 = Block::new(&builders);\n        expected1.register(&operation_1(), 0, false);\n        expected1.register(&operation_1(), 2, false);\n        expected2.register(&operation_2(), 1, false);\n\n        assert_eq!(\n            actual,\n            MergeBlocksResult::Partial {\n                merged: vec![expected1, expected2],\n                failed: vec![]\n            }\n        );\n    }\n\n    #[test]\n    fn test_merge_blocks_four_blocks_partial_with_failure() {\n        let builders = builders();\n        let mut block1 = Block::new(&builders);\n        let mut block2 = Block::new(&builders);\n        let mut block3 = Block::new(&builders);\n        let mut block4 = Block::new(&builders);\n        block1.register(&operation_1(), 0, false);\n        block2.register(&operation_2(), 1, false);\n        block3.register(&operation_1(), 2, false);\n        block4.register(&operation_3(), 3, false);\n\n        let actual = merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4], true);\n\n        let mut expected1 = Block::new(&builders);\n        let mut expected2 = Block::new(&builders);\n        let mut failed = Block::new(&builders);\n        expected1.register(&operation_1(), 0, false);\n        expected1.register(&operation_1(), 2, false);\n        expected2.register(&operation_2(), 1, false);\n        failed.register(&operation_3(), 3, false);\n\n        assert_eq!(\n            actual,\n            MergeBlocksResult::Partial {\n                merged: vec![expected1],\n                failed: vec![expected2, failed]\n            }\n        );\n    }\n\n    #[test]\n    fn test_merge_blocks_five_blocks_partial_with_failure() {\n        let builders = builders();\n        let mut block1 = Block::new(&builders);\n        let mut block2 = Block::new(&builders);\n        let mut block3 = Block::new(&builders);\n        let mut block4 = Block::new(&builders);\n        let mut block5 = Block::new(&builders);\n        block1.register(&operation_1(), 0, false);\n        block2.register(&operation_2(), 1, false);\n        block3.register(&operation_1(), 2, false);\n        block4.register(&operation_3(), 3, false);\n        block5.register(&operation_2(), 4, false);\n\n        let actual =\n            merge_blocks::<TestOptimization>(&[&block1, &block2, &block3, &block4, &block5], true);\n\n        let mut expected1 = Block::new(&builders);\n        let mut expected2 = Block::new(&builders);\n        let mut failed = Block::new(&builders);\n        expected1.register(&operation_1(), 0, false);\n        expected1.register(&operation_1(), 2, false);\n        expected2.register(&operation_2(), 1, false);\n        expected2.register(&operation_2(), 4, false);\n        failed.register(&operation_3(), 3, false);\n\n        assert_eq!(\n            actual,\n            MergeBlocksResult::Partial {\n                merged: vec![expected1, expected2],\n                failed: vec![failed]\n            }\n        );\n    }\n\n    fn builders() -> Vec<Box<dyn OperationFuser<TestOptimization>>> {\n        let builder_1 = TestOptimizationBuilder::new(0, vec![operation_1(); 10]);\n        let builder_2 = TestOptimizationBuilder::new(1, vec![operation_2(); 10]);\n\n        vec![Box::new(builder_1), Box::new(builder_2)]\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/search/mod.rs",
    "content": "mod block;\nmod optimization;\n\npub(super) mod merging;\npub(super) use block::*;\n\npub use optimization::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/search/optimization/blocks.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    NumOperations,\n    search::{\n        Block, BlockOptimization,\n        merging::{MergeBlocksResult, merge_blocks},\n    },\n    stream::store::ExecutionStrategy,\n};\n\n/// Try to optimize a list of [blocks](Block) into a [block optimization](BlockOptimization).\n///\n/// # Notes\n///\n/// What we know here is that every block is independent at that time and can be executed\n/// in any order.\n///\n/// The contract is that the length of operations executed must include all operations. If we don't\n/// find an optimization that can be executed with that constraint, we return a\n/// [BlocksOptimizerResult::WithHoles].\npub struct BlocksOptimizer<O> {\n    blocks: Vec<Block<O>>,\n    resolved: Vec<bool>,\n    last_checked: usize,\n}\n\n/// When we can't find a proper optimization for the provided list of [blocks](Block).\npub enum BlocksOptimizerResult<O> {\n    /// When an optimization fill the hole stream.\n    Full(BlockOptimization<O>),\n    /// The optimization found with the holes indices.\n    WithHoles {\n        strategies: Vec<Box<ExecutionStrategy<O>>>,\n        ordering: Vec<usize>,\n        holes: Vec<usize>,\n    },\n}\n\nenum BlockOptimizationStep<O> {\n    Contiguous {\n        strategy: ExecutionStrategy<O>,\n    },\n    /// Only happen when we fallback on executing a single operation.\n    Operation {\n        strategy: ExecutionStrategy<O>,\n    },\n    WithHoles {\n        strategy: ExecutionStrategy<O>,\n        holes: Vec<usize>,\n    },\n    Stop,\n}\n\nimpl<O: NumOperations> BlocksOptimizer<O> {\n    /// Create a new optimizer with the given blocks.\n    pub fn new(blocks: Vec<Block<O>>) -> Self {\n        let num_ops: usize = blocks.iter().map(|g| g.end_pos).max().unwrap();\n\n        Self {\n            blocks,\n            resolved: vec![false; num_ops],\n            last_checked: 0,\n        }\n    }\n\n    /// Optimizes the blocks.\n    ///\n    /// The strategy is quite simple. We try to merge as much [blocks](Block) together as we can,\n    /// then we iterate over them in order composing optimizations with the remaining blocks, all\n    /// while minimizing fallbacks operations to avoid having holes in the optimization stream.\n    pub fn optimize(mut self) -> BlocksOptimizerResult<O> {\n        self = self.merging_pass();\n\n        let mut strategies = Vec::with_capacity(self.blocks.len());\n        let mut ordering = Vec::new();\n        let mut blocks = Vec::new();\n        core::mem::swap(&mut blocks, &mut self.blocks);\n\n        for block in blocks {\n            match self.optimize_block(block, &mut ordering) {\n                BlockOptimizationStep::Contiguous { strategy } => {\n                    strategies.push(Box::new(strategy));\n                }\n                BlockOptimizationStep::Operation { strategy } => {\n                    strategies.push(Box::new(strategy));\n                    break;\n                }\n                BlockOptimizationStep::WithHoles { strategy, holes } => {\n                    strategies.push(Box::new(strategy));\n\n                    return BlocksOptimizerResult::WithHoles {\n                        strategies,\n                        ordering,\n                        holes,\n                    };\n                }\n                BlockOptimizationStep::Stop => {\n                    break;\n                }\n            }\n        }\n\n        let optimization = match strategies.len() > 1 {\n            true => BlockOptimization {\n                strategy: ExecutionStrategy::Composed(strategies),\n                ordering,\n            },\n            false => BlockOptimization {\n                strategy: *strategies.remove(0),\n                ordering,\n            },\n        };\n\n        BlocksOptimizerResult::Full(optimization)\n    }\n\n    /// Optimize a single block.\n    fn optimize_block(\n        &mut self,\n        block: Block<O>,\n        ordering: &mut Vec<usize>,\n    ) -> BlockOptimizationStep<O> {\n        let last_index = block.end_pos;\n        let mut block_optimization = block.optimize();\n        let opt_size = block_optimization.ordering.len();\n\n        for pos in block_optimization.ordering.iter() {\n            self.update_check(*pos);\n        }\n\n        if self.last_checked != ordering.len() + opt_size {\n            if !ordering.is_empty() {\n                // Don't include that block and need further exploring.\n                return BlockOptimizationStep::Stop;\n            }\n\n            return self.optimize_holes(block_optimization, last_index, ordering);\n        }\n\n        ordering.append(&mut block_optimization.ordering);\n        BlockOptimizationStep::Contiguous {\n            strategy: block_optimization.strategy,\n        }\n    }\n\n    /// The provided optimization has holes.\n    fn optimize_holes(\n        &mut self,\n        mut optimization: BlockOptimization<O>,\n        last_index: usize,\n        ordering_global: &mut Vec<usize>,\n    ) -> BlockOptimizationStep<O> {\n        match optimization.strategy {\n            ExecutionStrategy::Optimization { opt, ordering } => {\n                ordering_global.append(&mut optimization.ordering);\n                let holes = self.find_holes(last_index);\n\n                if holes.is_empty() {\n                    let strategy = ExecutionStrategy::Optimization { opt, ordering };\n                    BlockOptimizationStep::Contiguous { strategy }\n                } else {\n                    let strategy = ExecutionStrategy::Optimization { opt, ordering };\n                    BlockOptimizationStep::WithHoles { strategy, holes }\n                }\n            }\n            ExecutionStrategy::Operations { ordering } => {\n                let min = ordering.iter().min().unwrap();\n                ordering_global.push(*min);\n\n                let strategy = ExecutionStrategy::Operations {\n                    ordering: Arc::new(vec![*min]),\n                };\n                BlockOptimizationStep::Operation { strategy }\n            }\n            _ => unreachable!(),\n        }\n    }\n\n    fn update_check(&mut self, pos: usize) {\n        self.resolved[pos] = true;\n\n        for i in self.last_checked..self.resolved.len() {\n            if self.resolved[i] {\n                self.last_checked += 1;\n            } else {\n                break;\n            }\n        }\n    }\n\n    fn find_holes(&mut self, last: usize) -> Vec<usize> {\n        let mut fallbacks = Vec::new();\n\n        for i in self.last_checked..last {\n            if !self.resolved[i] {\n                fallbacks.push(i);\n                self.resolved[i] = true;\n            }\n            self.last_checked += 1;\n        }\n\n        fallbacks\n    }\n\n    /// Try to merge blocks together.\n    fn merging_pass(mut self) -> Self {\n        if self.blocks.len() == 1 {\n            return self;\n        }\n\n        Block::sort(&mut self.blocks);\n        let blocks = self.blocks.iter().collect::<Vec<_>>();\n\n        match merge_blocks(&blocks, false) {\n            MergeBlocksResult::Full(block) => {\n                self.blocks = vec![block];\n            }\n            MergeBlocksResult::Partial {\n                mut merged,\n                mut failed,\n            } => {\n                merged.append(&mut failed);\n                self.blocks = merged;\n                Block::sort(&mut self.blocks);\n            }\n            MergeBlocksResult::Fail => {}\n        }\n\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/search/optimization/mod.rs",
    "content": "mod blocks;\nmod stream;\n\npub use stream::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/search/optimization/stream.rs",
    "content": "use super::blocks::BlocksOptimizer;\nuse crate::{\n    NumOperations, OperationFuser,\n    search::{\n        Block, BlockOptimization, RegistrationResult,\n        merging::{MergeBlocksResult, merge_blocks},\n        optimization::blocks::BlocksOptimizerResult,\n    },\n    stream::store::ExecutionStrategy,\n};\nuse burn_ir::OperationIr;\n\n/// Optimize a stream of [operations](OperationIr) using a list of [builders](OptimizationBuilder).\npub struct StreamOptimizer<O> {\n    builders: Vec<Box<dyn OperationFuser<O>>>,\n    blocks: Vec<Block<O>>,\n    length: usize,\n    stopped: bool,\n    max_blocks: Option<usize>,\n}\n\nimpl<O: NumOperations> StreamOptimizer<O> {\n    /// Create a new stream optimizer.\n    pub fn new(builders: Vec<Box<dyn OperationFuser<O>>>) -> Self {\n        Self {\n            builders,\n            blocks: Vec::new(),\n            length: 0,\n            stopped: false,\n            // Too high and it may breaks the fusion cache always retriggering explorations.\n            max_blocks: Some(5),\n        }\n    }\n\n    /// Register a new [operation](OperationIr) in the optimizer.\n    ///\n    /// You can use the function [Self::still_optimizing] to know if the operations are actually\n    /// being registered.\n    pub fn register(&mut self, operation: &OperationIr) {\n        if self.stopped {\n            return;\n        }\n\n        if self.blocks.is_empty() {\n            self.on_new_block(operation);\n            self.length += 1;\n            return;\n        }\n\n        match self.merge_blocks(operation, false) {\n            MergeBlockStep::Full | MergeBlockStep::NoNeed => {}\n            MergeBlockStep::Fail | MergeBlockStep::Partial => {\n                // With the given operation, blocks are no longer independent.\n                self.stopped = true;\n                return;\n            }\n        }\n\n        if let Some(max_blocks) = self.max_blocks {\n            if self.register_max_block(operation, max_blocks) {\n                self.length += 1;\n            } else {\n                self.stopped = true;\n            }\n            return;\n        }\n\n        let added_count = self.register_inner(operation, false);\n        if added_count == 0 {\n            self.on_new_block(operation);\n        }\n\n        self.length += 1;\n    }\n\n    /// Optimize the current stream on the given [operations](OperationIr).\n    ///\n    /// # Notes\n    ///\n    /// The operations provided are the same as the ones used in the [register](Self::register)\n    /// method, this simply remove the need for the current type to also keep track of the list of\n    /// operations.\n    pub fn optimize(&self, operations: &[OperationIr]) -> BlockOptimization<O> {\n        let result = BlocksOptimizer::new(self.blocks.clone()).optimize();\n\n        match result {\n            BlocksOptimizerResult::Full(block_optimization) => block_optimization,\n            BlocksOptimizerResult::WithHoles {\n                mut strategies,\n                mut ordering,\n                mut holes,\n            } => {\n                loop {\n                    let mut search = self.new_empty_search();\n\n                    let mut operations_holes = Vec::with_capacity(holes.len());\n\n                    for index in holes.iter() {\n                        let op = &operations[*index];\n                        operations_holes.push(op.clone());\n                        search.register(op);\n                    }\n\n                    let mut optimization_of_holes = search.optimize(&operations_holes);\n\n                    optimization_of_holes.map_ordering(&holes);\n\n                    strategies.push(Box::new(optimization_of_holes.strategy));\n                    holes.drain(0..optimization_of_holes.ordering.len());\n                    ordering.append(&mut optimization_of_holes.ordering);\n\n                    if holes.is_empty() {\n                        break;\n                    }\n                }\n\n                BlockOptimization::new(ExecutionStrategy::Composed(strategies), ordering)\n            }\n        }\n    }\n\n    /// Reset the state of the optimizer.\n    pub fn reset(&mut self) {\n        self.builders.iter_mut().for_each(|b| b.reset());\n        self.length = 0;\n        self.blocks.clear();\n        self.stopped = false;\n    }\n\n    /// Returns if some optimizations are still possible within the stream.\n    pub fn still_optimizing(&self) -> bool {\n        if self.stopped {\n            return false;\n        }\n        if self.blocks.is_empty() {\n            return true;\n        }\n\n        let mut num_stopped = 0;\n\n        for block in self.blocks.iter() {\n            if !block.still_optimizing() {\n                num_stopped += 1\n            }\n        }\n\n        num_stopped < self.blocks.len()\n    }\n\n    fn register_max_block(&mut self, operation: &OperationIr, max_blocks: usize) -> bool {\n        if max_blocks == 1 {\n            // Register in the single block with a force.\n            self.register_inner(operation, true);\n            return true;\n        }\n        let added_count = self.register_inner(operation, false);\n\n        if added_count > 0 {\n            return true;\n        }\n\n        if added_count == 0 && self.blocks.len() < max_blocks {\n            self.on_new_block(operation);\n            return true;\n        }\n\n        self.merge_blocks(operation, true);\n\n        if self.blocks.len() >= max_blocks {\n            self.stopped = true;\n            return false;\n        }\n\n        let added_count = self.register_inner(operation, false);\n\n        if added_count == 0 {\n            self.on_new_block(operation);\n        }\n\n        true\n    }\n\n    fn register_inner(&mut self, operation: &OperationIr, force: bool) -> usize {\n        let mut added_count = 0;\n        for block in self.blocks.iter_mut() {\n            match block.register(operation, self.length, force) {\n                RegistrationResult::Accepted => {\n                    added_count += 1;\n                }\n                RegistrationResult::NotPartOfTheGraph => {}\n            }\n        }\n        added_count\n    }\n\n    fn new_empty_search(&self) -> Self {\n        Self::new(\n            self.builders\n                .iter()\n                .map(|b| {\n                    let mut b = b.clone_dyn();\n                    b.reset();\n                    b\n                })\n                .collect(),\n        )\n    }\n\n    fn merge_blocks(&mut self, operation: &OperationIr, all: bool) -> MergeBlockStep {\n        let nodes = operation.nodes();\n        let mut block_merges = Vec::new();\n\n        for (i, block) in self.blocks.iter().enumerate() {\n            if all || block.contains_tensors(&nodes) {\n                block_merges.push(i);\n            }\n        }\n\n        if block_merges.len() <= 1 {\n            return MergeBlockStep::NoNeed;\n        }\n\n        let blocks_to_merge = self\n            .blocks\n            .iter()\n            .enumerate()\n            .filter_map(|(i, g)| match block_merges.contains(&i) {\n                true => Some(g),\n                false => None,\n            })\n            .collect::<Vec<_>>();\n\n        let merged = merge_blocks(&blocks_to_merge, false);\n\n        let mut clear_blocks = || {\n            let mut indices = block_merges.to_vec();\n            indices.sort();\n\n            for g in indices.into_iter().rev() {\n                self.blocks.remove(g);\n            }\n        };\n\n        match merged {\n            MergeBlocksResult::Full(block) => {\n                clear_blocks();\n                self.blocks.push(block);\n                Block::sort(&mut self.blocks);\n                MergeBlockStep::Full\n            }\n            MergeBlocksResult::Partial {\n                mut merged,\n                mut failed,\n            } => {\n                clear_blocks();\n                self.blocks.append(&mut merged);\n                self.blocks.append(&mut failed);\n                Block::sort(&mut self.blocks);\n                MergeBlockStep::Partial\n            }\n            MergeBlocksResult::Fail => MergeBlockStep::Fail,\n        }\n    }\n\n    fn on_new_block(&mut self, operation: &OperationIr) {\n        let mut block = Block::new(&self.builders);\n        block.register(operation, self.length, true);\n        self.blocks.push(block);\n    }\n}\n\nenum MergeBlockStep {\n    Full,\n    Partial,\n    Fail,\n    NoNeed,\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/server.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    FusionBackend, FusionRuntime,\n    stream::{MultiStream, OperationStreams, StreamId, execution::Operation},\n};\nuse burn_backend::{TensorData, backend::ExecutionError};\nuse burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr};\n\npub struct FusionServer<R: FusionRuntime> {\n    streams: MultiStream<R>,\n    pub(crate) handles: HandleContainer<R::FusionHandle>,\n}\n\nimpl<R> FusionServer<R>\nwhere\n    R: FusionRuntime,\n{\n    pub fn new(device: R::FusionDevice) -> Self {\n        Self {\n            streams: MultiStream::new(device.clone()),\n            handles: HandleContainer::new(),\n        }\n    }\n\n    pub fn register(\n        &mut self,\n        streams: OperationStreams,\n        repr: OperationIr,\n        operation: Arc<dyn Operation<R>>,\n    ) {\n        self.streams\n            .register(streams, repr, operation, &mut self.handles)\n    }\n\n    pub fn drain_stream(&mut self, id: StreamId) {\n        self.streams.drain(&mut self.handles, id)\n    }\n\n    pub fn read_float<B>(\n        &mut self,\n        tensor: TensorIr,\n        id: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        // Make sure all registered operations are executed.\n        // The underlying backend can still be async.\n        self.drain_stream(id);\n        let tensor_float = self.handles.get_float_tensor::<B>(&tensor);\n        self.streams.mark_read(id, &tensor, &self.handles);\n        B::float_into_data(tensor_float)\n    }\n\n    pub fn read_int<B>(\n        &mut self,\n        tensor: TensorIr,\n        id: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        // Make sure all registered operations are executed.\n        // The underlying backend can still be async.\n        self.drain_stream(id);\n        let tensor_int = self.handles.get_int_tensor::<B>(&tensor);\n        self.streams.mark_read(id, &tensor, &self.handles);\n        B::int_into_data(tensor_int)\n    }\n\n    pub fn read_bool<B>(\n        &mut self,\n        tensor: TensorIr,\n        id: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        // Make sure all registered operations are executed.\n        // The underlying backend can still be async.\n        self.drain_stream(id);\n        let tensor_bool = self.handles.get_bool_tensor::<B>(&tensor);\n        self.streams.mark_read(id, &tensor, &self.handles);\n        B::bool_into_data(tensor_bool)\n    }\n\n    pub fn read_quantized<B>(\n        &mut self,\n        tensor: TensorIr,\n        id: StreamId,\n    ) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send + use<R, B>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        // Make sure all registered operations are executed.\n        // The underlying backend can still be async.\n        self.drain_stream(id);\n        let tensor_q = self.handles.get_quantized_tensor::<B>(&tensor);\n        self.streams.mark_read(id, &tensor, &self.handles);\n        B::q_into_data(tensor_q)\n    }\n\n    pub fn change_server_float<B>(\n        &mut self,\n        tensor: &TensorIr,\n        output_id: TensorId,\n        stream_tensor: StreamId,\n        device: &R::FusionDevice,\n        server_device: &mut Self,\n    ) where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let tensor_float = self.handles.get_float_tensor::<B>(tensor);\n        self.streams.mark_read(stream_tensor, tensor, &self.handles);\n\n        let tensor = B::float_to_device(tensor_float, device);\n\n        server_device\n            .handles\n            .register_float_tensor::<B>(&output_id, tensor.clone());\n    }\n\n    pub fn resolve_server_float<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.handles.get_float_tensor::<B>(tensor)\n    }\n\n    pub fn resolve_server_int<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.handles.get_int_tensor::<B>(tensor)\n    }\n\n    pub fn resolve_server_bool<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        self.handles.get_bool_tensor::<B>(tensor)\n    }\n\n    pub fn change_server_int<B>(\n        &mut self,\n        tensor: &TensorIr,\n        output_id: TensorId,\n        stream_tensor: StreamId,\n        device: &R::FusionDevice,\n        server_device: &mut Self,\n    ) where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let tensor_int = self.handles.get_int_tensor::<B>(tensor);\n        self.streams.mark_read(stream_tensor, tensor, &self.handles);\n        let tensor = B::int_to_device(tensor_int, device);\n\n        server_device\n            .handles\n            .register_int_tensor::<B>(&output_id, tensor.clone());\n    }\n\n    pub fn change_server_bool<B>(\n        &mut self,\n        tensor: &TensorIr,\n        output_id: TensorId,\n        stream_tensor: StreamId,\n        device: &R::FusionDevice,\n        server_device: &mut Self,\n    ) where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let tensor_bool = self.handles.get_bool_tensor::<B>(tensor);\n        self.streams.mark_read(stream_tensor, tensor, &self.handles);\n        let tensor = B::bool_to_device(tensor_bool, device);\n\n        server_device\n            .handles\n            .register_bool_tensor::<B>(&output_id, tensor.clone());\n    }\n\n    pub fn change_server_quantized<B>(\n        &mut self,\n        tensor: &TensorIr,\n        output_id: TensorId,\n        device: &R::FusionDevice,\n        server_device: &mut Self,\n    ) where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let tensor = self.handles.get_quantized_tensor::<B>(tensor);\n        let tensor = B::q_to_device(tensor, device);\n\n        server_device\n            .handles\n            .register_quantized_tensor::<B>(&output_id, tensor);\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/base.rs",
    "content": "pub use burn_backend::StreamId;\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/context.rs",
    "content": "use burn_backend::{Shape, Slice};\nuse burn_ir::*;\nuse hashbrown::HashMap;\n\n/// The context contains the relative graph tensor mapping so that a relative tensor id can be\n/// mapped to an existing tensor that can be fetched and updated with the\n/// [handle container](HandleContainer).\n///\n/// It also contains all scalar values, which can change even for the same graph. They are sorted\n/// in the order in which they appear in the graph.\n#[allow(clippy::too_many_arguments)]\n#[derive(new)]\npub struct Context<'a, H> {\n    /// The tensor mapping where local tensor id points to the updated tensor representation.\n    pub tensors: &'a mut HashMap<TensorId, TensorIr>,\n    /// Handle container to retrieve tensors based on their representation.\n    pub handles: &'a mut HandleContainer<H>,\n    /// Scalars found in the graph in the order they appeared.\n    pub scalars: &'a mut HashMap<ScalarId, ScalarIr>,\n    /// Shape mapping from relative shape ids to global (real) shape ids.\n    pub shapes_relative2global: &'a HashMap<usize, usize>,\n}\n\n#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]\n/// Scalar unique identifier.\npub struct ScalarId {\n    /// The value.\n    pub value: u64,\n}\n\npub(crate) struct OperationConverter {\n    tensors_relative2global: HashMap<TensorId, TensorIr>,\n    tensors_global2relative: HashMap<TensorId, TensorIr>,\n    shapes_global2relative: HashMap<usize, usize>,\n    shapes_relative2global: HashMap<usize, usize>,\n    scalars: HashMap<ScalarId, ScalarIr>,\n}\n\nimpl Default for OperationConverter {\n    fn default() -> Self {\n        let mut val = Self {\n            tensors_relative2global: Default::default(),\n            tensors_global2relative: Default::default(),\n            shapes_global2relative: Default::default(),\n            shapes_relative2global: Default::default(),\n            scalars: Default::default(),\n        };\n\n        // global 1 is always shape id 0.\n        val.shapes_global2relative.insert(1, 0);\n        val.shapes_relative2global.insert(0, 1);\n\n        val\n    }\n}\n\n/// Fork of a [context](Context) which owns its data.\npub struct ContextOwned<H> {\n    tensors: HashMap<TensorId, TensorIr>,\n    handles: HandleContainer<H>,\n    scalars: HashMap<ScalarId, ScalarIr>,\n    shapes_relative2global: HashMap<usize, usize>,\n}\n\nimpl<H: Clone> ContextOwned<H> {\n    /// Convert into [context](Context).\n    pub fn as_context(&mut self) -> Context<'_, H> {\n        Context {\n            tensors: &mut self.tensors,\n            handles: &mut self.handles,\n            scalars: &mut self.scalars,\n            shapes_relative2global: &self.shapes_relative2global,\n        }\n    }\n\n    /// Fork the context again.\n    pub fn fork(&self) -> ContextOwned<H> {\n        ContextOwned {\n            tensors: self.tensors.clone(),\n            handles: self.handles.fork(),\n            scalars: self.scalars.clone(),\n            shapes_relative2global: self.shapes_relative2global.clone(),\n        }\n    }\n}\n\nimpl<H: Clone> Context<'_, H> {\n    /// Fork the context into an [owned context](ContextOwned).\n    pub fn fork(&self) -> ContextOwned<H> {\n        ContextOwned {\n            tensors: self.tensors.clone(),\n            handles: self.handles.fork(),\n            scalars: self.scalars.clone(),\n            shapes_relative2global: self.shapes_relative2global.clone(),\n        }\n    }\n}\n\npub(crate) trait RelativeOps {\n    /// Convert (usually an [`OperationIr`]) to a relative form.\n    ///\n    /// The id and the shape of tensors will be computed relative to existing\n    /// operations in the queue. We do this because we want to fuse operations\n    /// that have similar shapes, but we do not care about the exact values.\n    ///\n    /// Similar we do not care about the exact ids of the tensor, but about their\n    /// relative ids (how close they are in the operation queue)\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self;\n}\n\nimpl OperationConverter {\n    pub(crate) fn context<'a, H>(\n        &'a mut self,\n        handles: &'a mut HandleContainer<H>,\n    ) -> Context<'a, H> {\n        Context {\n            handles,\n            tensors: &mut self.tensors_relative2global,\n            scalars: &mut self.scalars,\n            shapes_relative2global: &self.shapes_relative2global,\n        }\n    }\n\n    pub(crate) fn clear(&mut self) {\n        self.tensors_relative2global.clear();\n        self.tensors_global2relative.clear();\n\n        self.shapes_global2relative.clear();\n        self.shapes_relative2global.clear();\n\n        // global 1 is always shape id 0.\n        self.shapes_global2relative.insert(1, 0);\n        self.shapes_relative2global.insert(0, 1);\n\n        self.scalars.clear();\n    }\n}\n\nimpl RelativeOps for OperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            OperationIr::BaseFloat(ops) => OperationIr::BaseFloat(ops.to_relative(converter)),\n            OperationIr::BaseInt(ops) => OperationIr::BaseInt(ops.to_relative(converter)),\n            OperationIr::BaseBool(ops) => OperationIr::BaseBool(ops.to_relative(converter)),\n            OperationIr::NumericFloat(dtype, ops) => {\n                OperationIr::NumericFloat(*dtype, ops.to_relative(converter))\n            }\n            OperationIr::NumericInt(dtype, ops) => {\n                OperationIr::NumericInt(*dtype, ops.to_relative(converter))\n            }\n            OperationIr::Bool(ops) => OperationIr::Bool(ops.to_relative(converter)),\n            OperationIr::Int(ops) => OperationIr::Int(ops.to_relative(converter)),\n            OperationIr::Float(dtype, ops) => {\n                OperationIr::Float(*dtype, ops.to_relative(converter))\n            }\n            OperationIr::Module(ops) => OperationIr::Module(ops.to_relative(converter)),\n            OperationIr::Custom(ops) => OperationIr::Custom(ops.to_relative(converter)),\n            OperationIr::Init(ops) => OperationIr::Init(ops.to_relative(converter)),\n            OperationIr::Drop(tensor) => OperationIr::Drop(tensor.to_relative(converter)),\n        }\n    }\n}\n\nimpl RelativeOps for ModuleOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            ModuleOperationIr::Embedding(desc) => ModuleOperationIr::Embedding(EmbeddingOpIr {\n                weights: desc.weights.to_relative(converter),\n                indices: desc.indices.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::EmbeddingBackward(desc) => {\n                ModuleOperationIr::EmbeddingBackward(EmbeddingBackwardOpIr {\n                    weights: desc.weights.to_relative(converter),\n                    out_grad: desc.out_grad.to_relative(converter),\n                    indices: desc.indices.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv1d(desc) => ModuleOperationIr::Conv1d(Conv1dOpIr {\n                x: desc.x.to_relative(converter),\n                weight: desc.weight.to_relative(converter),\n                bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                options: desc.options.clone(),\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::Conv1dXBackward(desc) => {\n                ModuleOperationIr::Conv1dXBackward(Conv1dXBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv1dWeightBackward(desc) => {\n                ModuleOperationIr::Conv1dWeightBackward(Conv1dWeightBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv1dBiasBackward(desc) => {\n                ModuleOperationIr::Conv1dBiasBackward(Conv1dBiasBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    bias: desc.bias.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv2d(desc) => ModuleOperationIr::Conv2d(Conv2dOpIr {\n                x: desc.x.to_relative(converter),\n                weight: desc.weight.to_relative(converter),\n                bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                options: desc.options.clone(),\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::Conv2dXBackward(desc) => {\n                ModuleOperationIr::Conv2dXBackward(Conv2dXBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv2dWeightBackward(desc) => {\n                ModuleOperationIr::Conv2dWeightBackward(Conv2dWeightBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv2dBiasBackward(desc) => {\n                ModuleOperationIr::Conv2dBiasBackward(Conv2dBiasBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    bias: desc.bias.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv3d(desc) => ModuleOperationIr::Conv3d(Conv3dOpIr {\n                x: desc.x.to_relative(converter),\n                weight: desc.weight.to_relative(converter),\n                bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                options: desc.options.clone(),\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::Conv3dXBackward(desc) => {\n                ModuleOperationIr::Conv3dXBackward(Conv3dXBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv3dWeightBackward(desc) => {\n                ModuleOperationIr::Conv3dWeightBackward(Conv3dWeightBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Conv3dBiasBackward(desc) => {\n                ModuleOperationIr::Conv3dBiasBackward(Conv3dBiasBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    bias: desc.bias.to_relative(converter),\n                    output_grad: desc.output_grad.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::DeformableConv2d(desc) => {\n                ModuleOperationIr::DeformableConv2d(Box::new(DeformConv2dOpIr {\n                    x: desc.x.to_relative(converter),\n                    offset: desc.offset.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),\n                    bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                }))\n            }\n            ModuleOperationIr::DeformableConv2dBackward(desc) => {\n                ModuleOperationIr::DeformableConv2dBackward(Box::new(DeformConv2dBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    offset: desc.offset.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),\n                    bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                    out_grad: desc.out_grad.to_relative(converter),\n                    options: desc.options.clone(),\n                    input_grad: desc.input_grad.to_relative(converter),\n                    offset_grad: desc.offset_grad.to_relative(converter),\n                    weight_grad: desc.weight_grad.to_relative(converter),\n                    mask_grad: desc.mask_grad.as_ref().map(|t| t.to_relative(converter)),\n                    bias_grad: desc.bias_grad.as_ref().map(|t| t.to_relative(converter)),\n                }))\n            }\n            ModuleOperationIr::ConvTranspose1d(desc) => {\n                ModuleOperationIr::ConvTranspose1d(ConvTranspose1dOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::ConvTranspose2d(desc) => {\n                ModuleOperationIr::ConvTranspose2d(ConvTranspose2dOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::ConvTranspose3d(desc) => {\n                ModuleOperationIr::ConvTranspose3d(ConvTranspose3dOpIr {\n                    x: desc.x.to_relative(converter),\n                    weight: desc.weight.to_relative(converter),\n                    bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AvgPool1d(desc) => ModuleOperationIr::AvgPool1d(AvgPool1dOpIr {\n                x: desc.x.to_relative(converter),\n                kernel_size: desc.kernel_size,\n                stride: desc.stride,\n                padding: desc.padding,\n                count_include_pad: desc.count_include_pad,\n                ceil_mode: desc.ceil_mode,\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::AvgPool2d(desc) => ModuleOperationIr::AvgPool2d(AvgPool2dOpIr {\n                x: desc.x.to_relative(converter),\n                kernel_size: desc.kernel_size,\n                stride: desc.stride,\n                padding: desc.padding,\n                count_include_pad: desc.count_include_pad,\n                ceil_mode: desc.ceil_mode,\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::AvgPool1dBackward(desc) => {\n                ModuleOperationIr::AvgPool1dBackward(AvgPool1dBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    count_include_pad: desc.count_include_pad,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AvgPool2dBackward(desc) => {\n                ModuleOperationIr::AvgPool2dBackward(AvgPool2dBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    count_include_pad: desc.count_include_pad,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AdaptiveAvgPool1d(desc) => {\n                ModuleOperationIr::AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr {\n                    x: desc.x.to_relative(converter),\n                    output_size: desc.output_size,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AdaptiveAvgPool2d(desc) => {\n                ModuleOperationIr::AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr {\n                    x: desc.x.to_relative(converter),\n                    output_size: desc.output_size,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AdaptiveAvgPool1dBackward(desc) => {\n                ModuleOperationIr::AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::AdaptiveAvgPool2dBackward(desc) => {\n                ModuleOperationIr::AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::MaxPool1d(desc) => ModuleOperationIr::MaxPool1d(MaxPool1dOpIr {\n                x: desc.x.to_relative(converter),\n                kernel_size: desc.kernel_size,\n                stride: desc.stride,\n                padding: desc.padding,\n                dilation: desc.dilation,\n                ceil_mode: desc.ceil_mode,\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::MaxPool1dWithIndices(desc) => {\n                ModuleOperationIr::MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr {\n                    x: desc.x.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    dilation: desc.dilation,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                    out_indices: desc.out_indices.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::MaxPool1dWithIndicesBackward(desc) => {\n                ModuleOperationIr::MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    indices: desc.indices.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    dilation: desc.dilation,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::MaxPool2d(desc) => ModuleOperationIr::MaxPool2d(MaxPool2dOpIr {\n                x: desc.x.to_relative(converter),\n                kernel_size: desc.kernel_size,\n                stride: desc.stride,\n                padding: desc.padding,\n                dilation: desc.dilation,\n                ceil_mode: desc.ceil_mode,\n                out: desc.out.to_relative(converter),\n            }),\n            ModuleOperationIr::MaxPool2dWithIndices(desc) => {\n                ModuleOperationIr::MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr {\n                    x: desc.x.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    dilation: desc.dilation,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                    out_indices: desc.out_indices.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::MaxPool2dWithIndicesBackward(desc) => {\n                ModuleOperationIr::MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    indices: desc.indices.to_relative(converter),\n                    kernel_size: desc.kernel_size,\n                    stride: desc.stride,\n                    padding: desc.padding,\n                    dilation: desc.dilation,\n                    ceil_mode: desc.ceil_mode,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Interpolate(desc) => {\n                ModuleOperationIr::Interpolate(InterpolateOpIr {\n                    x: desc.x.to_relative(converter),\n                    output_size: desc.output_size,\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::InterpolateBackward(desc) => {\n                ModuleOperationIr::InterpolateBackward(InterpolateBackwardOpIr {\n                    x: desc.x.to_relative(converter),\n                    grad: desc.grad.to_relative(converter),\n                    output_size: desc.output_size,\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            ModuleOperationIr::Attention(desc) => ModuleOperationIr::Attention(AttentionOpIr {\n                query: desc.query.to_relative(converter),\n                key: desc.key.to_relative(converter),\n                value: desc.value.to_relative(converter),\n                mask: desc.mask.as_ref().map(|m| m.to_relative(converter)),\n                attn_bias: desc.attn_bias.as_ref().map(|ab| ab.to_relative(converter)),\n                options: desc.options.clone(),\n                out: desc.out.to_relative(converter),\n            }),\n        }\n    }\n}\n\nimpl RelativeOps for FloatOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            FloatOperationIr::Exp(desc) => FloatOperationIr::Exp(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Log(desc) => FloatOperationIr::Log(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Log1p(desc) => FloatOperationIr::Log1p(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Erf(desc) => FloatOperationIr::Erf(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Powf(desc) => FloatOperationIr::Powf(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::PowfScalar(desc) => FloatOperationIr::PowfScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Sqrt(desc) => FloatOperationIr::Sqrt(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Cos(desc) => FloatOperationIr::Cos(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Sin(desc) => FloatOperationIr::Sin(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Tanh(desc) => FloatOperationIr::Tanh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Tan(desc) => FloatOperationIr::Tan(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Cosh(desc) => FloatOperationIr::Cosh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Sinh(desc) => FloatOperationIr::Sinh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcCos(desc) => FloatOperationIr::ArcCos(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcCosh(desc) => FloatOperationIr::ArcCosh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcSin(desc) => FloatOperationIr::ArcSin(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcSinh(desc) => FloatOperationIr::ArcSinh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcTan(desc) => FloatOperationIr::ArcTan(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcTanh(desc) => FloatOperationIr::ArcTanh(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::ArcTan2(desc) => FloatOperationIr::ArcTan2(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::IntoInt(desc) => FloatOperationIr::IntoInt(CastOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Matmul(desc) => FloatOperationIr::Matmul(MatmulOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Cross(desc) => FloatOperationIr::Cross(CrossOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                dim: desc.dim,\n            }),\n            FloatOperationIr::Random(desc) => FloatOperationIr::Random(RandomOpIr {\n                out: desc.out.to_relative(converter),\n                distribution: desc.distribution,\n            }),\n            FloatOperationIr::Recip(desc) => FloatOperationIr::Recip(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Quantize(desc) => FloatOperationIr::Quantize(QuantizeOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                qparams: QuantizationParametersIr {\n                    scales: desc.qparams.scales.to_relative(converter),\n                },\n                scheme: desc.scheme,\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Dequantize(desc) => FloatOperationIr::Dequantize(DequantizeOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Round(desc) => FloatOperationIr::Round(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Floor(desc) => FloatOperationIr::Floor(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Ceil(desc) => FloatOperationIr::Ceil(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::Trunc(desc) => FloatOperationIr::Ceil(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::IsNan(desc) => FloatOperationIr::IsNan(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::IsInf(desc) => FloatOperationIr::IsInf(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            FloatOperationIr::GridSample2d(desc) => {\n                FloatOperationIr::GridSample2d(GridSample2dOpIr {\n                    tensor: desc.tensor.to_relative(converter),\n                    grid: desc.grid.to_relative(converter),\n                    options: desc.options.clone(),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n        }\n    }\n}\n\nimpl RelativeOps for BoolOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            BoolOperationIr::IntoFloat(desc) => BoolOperationIr::IntoFloat(CastOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BoolOperationIr::IntoInt(desc) => BoolOperationIr::IntoInt(CastOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BoolOperationIr::Not(desc) => BoolOperationIr::Not(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BoolOperationIr::And(desc) => BoolOperationIr::And(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BoolOperationIr::Or(desc) => BoolOperationIr::Or(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n        }\n    }\n}\n\nimpl RelativeOps for IntOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            IntOperationIr::IntoFloat(desc) => IntOperationIr::IntoFloat(CastOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::Matmul(desc) => IntOperationIr::Matmul(MatmulOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseAnd(desc) => IntOperationIr::BitwiseAnd(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseAndScalar(desc) => {\n                IntOperationIr::BitwiseAndScalar(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            IntOperationIr::BitwiseOr(desc) => IntOperationIr::BitwiseOr(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseOrScalar(desc) => IntOperationIr::BitwiseOrScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs,\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseXor(desc) => IntOperationIr::BitwiseXor(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseXorScalar(desc) => {\n                IntOperationIr::BitwiseXorScalar(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            IntOperationIr::BitwiseNot(desc) => IntOperationIr::BitwiseNot(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            IntOperationIr::BitwiseLeftShift(desc) => {\n                IntOperationIr::BitwiseLeftShift(BinaryOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            IntOperationIr::BitwiseLeftShiftScalar(desc) => {\n                IntOperationIr::BitwiseLeftShiftScalar(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            IntOperationIr::BitwiseRightShift(desc) => {\n                IntOperationIr::BitwiseRightShift(BinaryOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            IntOperationIr::BitwiseRightShiftScalar(desc) => {\n                IntOperationIr::BitwiseRightShiftScalar(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n        }\n    }\n}\n\nimpl RelativeOps for CustomOpIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> CustomOpIr {\n        let id = self.id.clone();\n\n        CustomOpIr {\n            id,\n            inputs: self\n                .inputs\n                .iter()\n                .map(|x| x.to_relative(converter))\n                .collect(),\n            outputs: self\n                .outputs\n                .iter()\n                .map(|x| x.to_relative(converter))\n                .collect(),\n        }\n    }\n}\n\nimpl RelativeOps for NumericOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            NumericOperationIr::Add(desc) => NumericOperationIr::Add(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::AddScalar(desc) => NumericOperationIr::AddScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Sub(desc) => NumericOperationIr::Sub(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::SubScalar(desc) => NumericOperationIr::SubScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Div(desc) => NumericOperationIr::Div(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::DivScalar(desc) => NumericOperationIr::DivScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Rem(desc) => NumericOperationIr::Rem(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::RemScalar(desc) => NumericOperationIr::RemScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Mul(desc) => NumericOperationIr::Mul(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MulScalar(desc) => NumericOperationIr::MulScalar(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Abs(desc) => NumericOperationIr::Abs(UnaryOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Full(desc) => NumericOperationIr::Full(FullOpIr {\n                out: desc.out.to_relative(converter),\n                value: desc.value.to_relative(converter),\n            }),\n            NumericOperationIr::MeanDim(desc) => NumericOperationIr::MeanDim(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                axis: desc.axis,\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Mean(desc) => NumericOperationIr::Mean(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Sum(desc) => NumericOperationIr::Sum(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::SumDim(desc) => {\n                NumericOperationIr::SumDim(ReduceDimOpIr {\n                    input: desc.input.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                    axis: desc.axis, // Axis should stay the same.\n                })\n            }\n            NumericOperationIr::Prod(desc) => NumericOperationIr::Prod(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::ProdDim(desc) => NumericOperationIr::ProdDim(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                axis: desc.axis,\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Greater(desc) => NumericOperationIr::Greater(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::GreaterElem(desc) => NumericOperationIr::GreaterElem(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::GreaterEqual(desc) => {\n                NumericOperationIr::GreaterEqual(BinaryOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            NumericOperationIr::GreaterEqualElem(desc) => {\n                NumericOperationIr::GreaterEqualElem(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            NumericOperationIr::Lower(desc) => NumericOperationIr::Lower(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::LowerElem(desc) => NumericOperationIr::LowerElem(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::LowerEqual(desc) => NumericOperationIr::LowerEqual(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::LowerEqualElem(desc) => {\n                NumericOperationIr::LowerEqualElem(ScalarOpIr {\n                    lhs: desc.lhs.to_relative(converter),\n                    rhs: desc.rhs.to_relative(converter),\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            NumericOperationIr::ArgMax(desc) => NumericOperationIr::ArgMax(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis, // Axis should stay the same.\n            }),\n            NumericOperationIr::ArgMin(desc) => NumericOperationIr::ArgMin(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis, // Axis should stay the same.\n            }),\n            NumericOperationIr::Max(desc) => NumericOperationIr::Max(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MaxDimWithIndices(desc) => {\n                NumericOperationIr::MaxDimWithIndices(ReduceDimWithIndicesOpIr {\n                    tensor: desc.tensor.to_relative(converter),\n                    dim: desc.dim,\n                    out: desc.out.to_relative(converter),\n                    out_indices: desc.out_indices.to_relative(converter),\n                })\n            }\n            NumericOperationIr::MinDimWithIndices(desc) => {\n                NumericOperationIr::MinDimWithIndices(ReduceDimWithIndicesOpIr {\n                    tensor: desc.tensor.to_relative(converter),\n                    dim: desc.dim,\n                    out: desc.out.to_relative(converter),\n                    out_indices: desc.out_indices.to_relative(converter),\n                })\n            }\n            NumericOperationIr::Min(desc) => NumericOperationIr::Min(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MaxDim(desc) => NumericOperationIr::MaxDim(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                axis: desc.axis,\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MinDim(desc) => NumericOperationIr::MinDim(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                axis: desc.axis,\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MaxAbs(desc) => NumericOperationIr::MaxAbs(ReduceOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::MaxAbsDim(desc) => NumericOperationIr::MaxAbsDim(ReduceDimOpIr {\n                input: desc.input.to_relative(converter),\n                axis: desc.axis,\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::Clamp(desc) => NumericOperationIr::Clamp(ClampOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                min: desc.min.to_relative(converter),\n                max: desc.max.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::IntRandom(desc) => NumericOperationIr::IntRandom(RandomOpIr {\n                out: desc.out.to_relative(converter),\n                distribution: desc.distribution,\n            }),\n            NumericOperationIr::Powi(desc) => NumericOperationIr::Powi(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            NumericOperationIr::CumSum(desc) => NumericOperationIr::CumSum(DimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis,\n            }),\n            NumericOperationIr::CumProd(desc) => NumericOperationIr::CumProd(DimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis,\n            }),\n            NumericOperationIr::CumMin(desc) => NumericOperationIr::CumMin(DimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis,\n            }),\n            NumericOperationIr::CumMax(desc) => NumericOperationIr::CumMax(DimOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axis: desc.axis,\n            }),\n        }\n    }\n}\n\nimpl RelativeOps for BaseOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        match self {\n            BaseOperationIr::Reshape(desc) => BaseOperationIr::Reshape(ShapeOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::SwapDims(desc) => BaseOperationIr::SwapDims(SwapDimsOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                dim1: desc.dim1,\n                dim2: desc.dim2,\n            }),\n            BaseOperationIr::Permute(desc) => BaseOperationIr::Permute(PermuteOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axes: desc.axes.clone(),\n            }),\n            BaseOperationIr::Expand(desc) => BaseOperationIr::Expand(ShapeOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Unfold(desc) => BaseOperationIr::Unfold(UnfoldOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                dim: desc.dim,\n                size: desc.size,\n                step: desc.step,\n            }),\n            BaseOperationIr::Flip(desc) => BaseOperationIr::Flip(FlipOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n                axes: desc.axes.clone(),\n            }),\n            BaseOperationIr::Slice(desc) => BaseOperationIr::Slice(SliceOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                ranges: desc.ranges.iter().map(|_info| Slice::from(0..1)).collect(),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::SliceAssign(desc) => BaseOperationIr::SliceAssign(SliceAssignOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                ranges: desc.ranges.iter().map(|_range| Slice::from(0..1)).collect(),\n                value: desc.value.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Gather(desc) => BaseOperationIr::Gather(GatherOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                dim: desc.dim,\n                indices: desc.indices.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Scatter(desc) => BaseOperationIr::Scatter(ScatterOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                dim: desc.dim,\n                indices: desc.indices.to_relative(converter),\n                value: desc.value.to_relative(converter),\n                update: desc.update,\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Select(desc) => BaseOperationIr::Select(SelectOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                dim: desc.dim,\n                indices: desc.indices.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::SelectAssign(desc) => {\n                BaseOperationIr::SelectAssign(SelectAssignOpIr {\n                    tensor: desc.tensor.to_relative(converter),\n                    dim: desc.dim,\n                    indices: desc.indices.to_relative(converter),\n                    value: desc.value.to_relative(converter),\n                    update: desc.update,\n                    out: desc.out.to_relative(converter),\n                })\n            }\n            BaseOperationIr::MaskWhere(desc) => BaseOperationIr::MaskWhere(MaskWhereOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                mask: desc.mask.to_relative(converter),\n                value: desc.value.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::MaskFill(desc) => BaseOperationIr::MaskFill(MaskFillOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                mask: desc.mask.to_relative(converter),\n                value: desc.value.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Equal(desc) => BaseOperationIr::Equal(BinaryOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::EqualElem(desc) => BaseOperationIr::EqualElem(ScalarOpIr {\n                lhs: desc.lhs.to_relative(converter),\n                rhs: desc.rhs.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::RepeatDim(desc) => BaseOperationIr::RepeatDim(RepeatDimOpIr {\n                tensor: desc.tensor.to_relative(converter),\n                dim: desc.dim,\n                times: desc.times,\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Cat(desc) => BaseOperationIr::Cat(CatOpIr {\n                tensors: desc\n                    .tensors\n                    .iter()\n                    .map(|tensor| tensor.to_relative(converter))\n                    .collect(),\n                dim: desc.dim,\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Cast(desc) => BaseOperationIr::Cast(CastOpIr {\n                input: desc.input.to_relative(converter),\n                out: desc.out.to_relative(converter),\n            }),\n            BaseOperationIr::Empty(desc) => BaseOperationIr::Empty(desc.to_relative(converter)),\n            BaseOperationIr::Ones(desc) => BaseOperationIr::Ones(desc.to_relative(converter)),\n            BaseOperationIr::Zeros(desc) => BaseOperationIr::Zeros(desc.to_relative(converter)),\n        }\n    }\n}\n\nimpl RelativeOps for InitOperationIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        Self {\n            out: self.out.to_relative(converter),\n        }\n    }\n}\n\nimpl RelativeOps for CreationOpIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        Self {\n            out: self.out.to_relative(converter),\n        }\n    }\n}\n\nimpl RelativeOps for TensorIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        let relative_id = self.id.to_relative(converter);\n\n        // We can create relative shapes by mapping each shape found to an ID, which is a `usize`.\n        let mut relative_shape = Vec::with_capacity(self.shape.rank());\n        for dim in self.shape.iter() {\n            if let Some(dim_id) = converter.shapes_global2relative.get(dim) {\n                // We already saw that dim value before, so we retrieve its ID.\n                relative_shape.push(*dim_id);\n            } else {\n                // We never saw this dim value before, therefore we create a new ID.\n                let dim_id = converter.shapes_global2relative.len();\n                relative_shape.push(dim_id);\n\n                converter.shapes_global2relative.insert(*dim, dim_id);\n                converter.shapes_relative2global.insert(dim_id, *dim);\n            }\n        }\n\n        // We create the relative tensor.\n        let relative_tensor = TensorIr {\n            id: relative_id,\n            shape: Shape::from(relative_shape),\n            status: self.status,\n            dtype: self.dtype,\n        };\n\n        // We update both mappings.\n        converter\n            .tensors_relative2global\n            .insert(relative_id, self.clone());\n        converter\n            .tensors_global2relative\n            .insert(self.id, relative_tensor.clone());\n\n        relative_tensor\n    }\n}\n\nimpl RelativeOps for TensorId {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        if let Some(value) = converter.tensors_global2relative.get(self) {\n            // If we already have the same tensor registered, we have to update its value, but not\n            // its id.\n            value.id\n        } else {\n            // We create a new relative id since we never seen this tensor in the graph before.\n            TensorId::new(converter.tensors_relative2global.len() as u64)\n        }\n    }\n}\n\nimpl RelativeOps for ScalarIr {\n    fn to_relative(&self, converter: &mut OperationConverter) -> Self {\n        if matches!(self, ScalarIr::Bool(_)) {\n            todo!(\"Unsupported dtype ({self:?}) for scalar\")\n        }\n\n        let id = ScalarId {\n            value: converter.scalars.len() as u64,\n        };\n\n        converter.scalars.insert(id, *self);\n        ScalarIr::UInt(id.value)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_backend::DType;\n    use burn_ir::{TensorId, TensorIr, TensorStatus};\n\n    #[test]\n    fn tensor_description_to_relative() {\n        let tensor1 = TensorIr {\n            id: TensorId::new(500),\n            shape: Shape::new([512, 32, 2048]),\n            status: TensorStatus::ReadOnly,\n            dtype: DType::F32,\n        };\n        let tensor2 = TensorIr {\n            id: TensorId::new(501),\n            shape: Shape::new([512, 128, 2048]),\n            status: TensorStatus::ReadOnly,\n            dtype: DType::F32,\n        };\n        let mut converter = OperationConverter::default();\n        let tensor1_local = tensor1.to_relative(&mut converter);\n        let tensor2_local = tensor2.to_relative(&mut converter);\n\n        assert_eq!(\n            tensor1_local,\n            TensorIr {\n                id: TensorId::new(0),\n                shape: Shape::new([1, 2, 3]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32\n            }\n        );\n        assert_eq!(\n            tensor2_local,\n            TensorIr {\n                id: TensorId::new(1),\n                shape: Shape::new([1, 4, 3]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32\n            }\n        );\n    }\n\n    #[test]\n    fn scalar_ir_to_relative() {\n        let scalar1 = ScalarIr::Float(1.0);\n        let scalar2 = ScalarIr::UInt(1);\n        let mut converter = OperationConverter::default();\n        let scalar1_local = scalar1.to_relative(&mut converter);\n        let scalar2_local = scalar2.to_relative(&mut converter);\n\n        assert_eq!(scalar1_local, ScalarIr::UInt(0));\n        assert_eq!(scalar2_local, ScalarIr::UInt(1));\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/base.rs",
    "content": "use burn_ir::HandleContainer;\n\nuse crate::FusionRuntime;\n\n/// The mode in which the execution is done.\n#[derive(Clone, Copy, Debug)]\npub(crate) enum ExecutionMode {\n    Lazy,\n    Sync,\n}\n\n/// General trait to abstract how a single operation is executed.\npub trait Operation<R: FusionRuntime>: Send + Sync + core::fmt::Debug {\n    /// Execute the operation.\n    fn execute(&self, handles: &mut HandleContainer<R::FusionHandle>);\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/explorer.rs",
    "content": "use burn_ir::OperationIr;\n\nuse super::ExecutionMode;\nuse crate::{\n    NumOperations, OperationFuser,\n    search::{BlockOptimization, StreamOptimizer},\n};\n\n/// Explore and create new optimization.\npub struct Explorer<O> {\n    optimizer: StreamOptimizer<O>,\n    num_deferred: usize,\n    num_explored: usize,\n    is_still_optimizing: bool,\n}\n\n/// The result of an exploration done by the [explorer](Explorer).\npub enum ExplorationAction<O> {\n    /// Found a new optimization.\n    Completed(BlockOptimization<O>),\n    /// We should continue exploring before arriving at a conclusion.\n    Continue,\n}\n\nimpl<O: NumOperations> Explorer<O> {\n    /// Create a new explorer.\n    pub(crate) fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {\n        Self {\n            optimizer: StreamOptimizer::new(optimizations),\n            num_deferred: 0,\n            num_explored: 0,\n            is_still_optimizing: true,\n        }\n    }\n\n    /// Indicate that a new operation is added.\n    pub(crate) fn on_new_operation(&mut self) {\n        self.num_deferred += 1;\n    }\n\n    /// If the explorer is up to date.\n    pub(crate) fn is_up_to_date(&self) -> bool {\n        self.num_deferred == 0\n    }\n\n    /// Explore the provided operations.\n    pub(crate) fn explore(\n        &mut self,\n        operations: &[OperationIr],\n        mode: ExecutionMode,\n    ) -> ExplorationAction<O> {\n        self.update(operations);\n\n        // Can only continue exploration when not sync.\n        if let ExecutionMode::Lazy = mode\n            && self.is_still_optimizing\n        {\n            return ExplorationAction::Continue;\n        }\n\n        let optimization = self.optimizer.optimize(operations);\n\n        ExplorationAction::Completed(optimization)\n    }\n\n    /// Reset the state of the explorer to the provided list of operations.\n    pub(crate) fn reset(&mut self, operations: &[OperationIr]) {\n        self.optimizer.reset();\n        self.num_explored = 0;\n        self.num_deferred = operations.len();\n        self.is_still_optimizing = true;\n    }\n\n    /// Register any operations that we had deferred\n    fn update(&mut self, operations: &[OperationIr]) {\n        for i in (0..self.num_deferred).rev() {\n            if !self.is_still_optimizing {\n                break;\n            }\n            let index = operations.len() - 1 - i;\n            let relative = &operations[index];\n\n            self.optimizer.register(relative);\n            self.num_explored += 1;\n\n            self.is_still_optimizing = self.optimizer.still_optimizing();\n        }\n\n        self.num_deferred = 0;\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/mod.rs",
    "content": "pub(crate) mod validator;\n\nmod base;\nmod explorer;\nmod ordering;\nmod policy;\nmod processor;\n\npub use base::*;\npub use ordering::*;\n\npub(crate) use explorer::*;\npub(crate) use policy::*;\npub(crate) use processor::*;\n\n#[cfg(test)]\npub(crate) mod tests;\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/ordering.rs",
    "content": "use std::sync::Arc;\n\nuse burn_ir::HandleContainer;\n\nuse crate::{FusionRuntime, NumOperations, Optimization, stream::Context};\n\nuse super::Operation;\n\n/// Manage the execution of potentially multiple optimizations and operations out of order.\npub struct OrderedExecution<R: FusionRuntime> {\n    operations: Vec<Arc<dyn Operation<R>>>,\n    num_executed: usize,\n    ordering: Option<Arc<Vec<usize>>>,\n}\n\nimpl<R: FusionRuntime> OrderedExecution<R> {\n    /// Returns the operation that can be executed without impacting the state of the execution.\n    ///\n    /// This is useful to implement fallback for optimizations.\n    #[allow(clippy::borrowed_box)]\n    pub fn operation_within_optimization(&self, index: usize) -> Arc<dyn Operation<R>> {\n        match &self.ordering {\n            Some(val) => {\n                let index = val[index];\n                self.operations[index].clone()\n            }\n            None => panic!(\"No ordering provided\"),\n        }\n    }\n\n    pub(crate) fn new(operations: Vec<Arc<dyn Operation<R>>>) -> Self {\n        Self {\n            operations,\n            num_executed: 0,\n            ordering: None,\n        }\n    }\n\n    pub(crate) fn finish(mut self) -> (Vec<Arc<dyn Operation<R>>>, usize) {\n        self.operations.drain(0..self.num_executed);\n        (self.operations, self.num_executed)\n    }\n\n    pub(crate) fn execute_optimization(\n        &mut self,\n        optimization: &mut R::Optimization,\n        context: &mut Context<'_, R::FusionHandle>,\n        ordering: Arc<Vec<usize>>,\n    ) {\n        if ordering.len() > self.operations.len() {\n            panic!(\"Ordering is bigger than operations\");\n        }\n        self.ordering = Some(ordering);\n        let num_drained = optimization.len();\n        optimization.execute(context, self);\n        self.num_executed += num_drained;\n    }\n\n    pub(crate) fn execute_operations(\n        &mut self,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        ordering: &[usize],\n    ) {\n        self.num_executed += ordering.len();\n\n        for id in ordering {\n            let op = &self.operations[*id];\n            op.execute(handles);\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/policy.rs",
    "content": "use burn_ir::OperationIr;\n\nuse super::ExecutionMode;\nuse super::validator::{\n    ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator,\n    ValidatorState,\n};\nuse crate::stream::execution::validator::OperationsValidator;\nuse crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery};\nuse std::marker::PhantomData;\n\n/// The policy keeps track of all possible execution plans for the current operations.\n///\n/// # Details\n///\n/// We keep track of each new operation added and invalidate potential execution plans\n/// when we see a different operation is added.\n///\n/// Therefore, the overhead is very minimal, since the time-complexity of checking for existing\n/// execution plans scales with the number of concurrent potential plans for the current operations,\n/// which isn't supposed to be big at any time.\npub(crate) struct Policy<O> {\n    /// List of potential execution plans that are compatible with current stream segment\n    candidates: Vec<OperationsValidator<ExecutionPlanId>>,\n    /// List of candidate execution plans that have been found; we can still keep searching\n    /// to potentially find a better one.\n    availables: Vec<AvailableItem>,\n    /// The found execution plan that should be executed, along with the number of operations\n    /// in the plan.\n    found: Option<(ExecutionPlanId, usize)>,\n    /// The number of operations that have been analyzed\n    num_operations: usize,\n    _item_type: PhantomData<O>,\n}\n\n#[derive(new)]\nstruct AvailableItem {\n    id: ExecutionPlanId,\n    size: usize,\n    triggers: Vec<TriggerValidator>,\n}\n\n/// Action to be made depending on the stream.\n#[derive(PartialEq, Eq, Debug)]\npub enum Action {\n    /// Continue exploring using the [builder](crate::OptimizationBuilder).\n    Explore,\n    /// The current policy indicates that an exploration may be possible in the future, so the\n    /// best action is to defer any execution.\n    ///\n    /// Sometimes, it can be a false positive and a new exploration should be built from scratch.\n    /// Therefore it's important to keep the previous operations to rebuild the state if it\n    /// happens.\n    Defer,\n    /// An exploration has been found, and the best action is to execute it!\n    Execute(ExecutionPlanId),\n}\n\nimpl<O: core::fmt::Debug> Policy<O> {\n    /// Create a new policy.\n    pub(crate) fn new() -> Self {\n        Self {\n            candidates: Vec::new(),\n            availables: Vec::new(),\n            found: None,\n            num_operations: 0,\n            _item_type: PhantomData,\n        }\n    }\n\n    /// Returns the [action](Action) that should be taken given the state of the policy.\n    pub fn action(\n        &self,\n        store: &ExecutionPlanStore<O>,\n        operations: &[OperationIr],\n        mode: ExecutionMode,\n    ) -> Action {\n        if self.num_operations < operations.len() {\n            panic!(\n                \"Internal Error: Can't retrieve the policy action on a list of operations bigger than what is analyzed.\"\n            );\n        }\n\n        if let Some((id, _length)) = self.found {\n            return Action::Execute(id);\n        }\n\n        match mode {\n            ExecutionMode::Lazy => self.action_lazy(operations),\n            ExecutionMode::Sync => self.action_sync(operations, store),\n        }\n    }\n\n    /// Update the policy state.\n    pub fn update(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {\n        // reset the candidates to contain all execution plans starting with the operation.\n        if self.num_operations == 0 {\n            self.candidates = store\n                .find(SearchQuery::PlansStartingWith(operation))\n                .into_iter()\n                .map(OperationsValidator::new)\n                .collect();\n        }\n\n        self.update_candidates(store, operation);\n        self.check_candidates(store);\n\n        self.update_availables(store, operation);\n        self.check_availables();\n        self.num_operations += 1;\n    }\n\n    // Reset the state of the policy.\n    pub fn reset(&mut self) {\n        self.candidates.clear();\n        self.availables.clear();\n\n        self.num_operations = 0;\n        self.found = None;\n    }\n\n    /// Check which candidates can be removed, and which one can go from\n    /// 'candidate' to 'available'\n    fn check_candidates(&mut self, store: &ExecutionPlanStore<O>) {\n        let mut candidates_to_remove = Vec::new();\n\n        for candidate in self.candidates.iter() {\n            match candidate.state {\n                ValidatorState::Found { size } => {\n                    let item = store.get_unchecked(candidate.id);\n                    let mut triggers = Vec::with_capacity(item.triggers.len());\n\n                    for (index, trigger) in item.triggers.iter().enumerate() {\n                        triggers.push(match trigger {\n                            ExecutionTrigger::OnOperations(_) => TriggerValidator::OnOperations {\n                                matching: OperationsValidator::new(index),\n                                progress: TriggerProgress::NotInit,\n                            },\n                            ExecutionTrigger::OnSync => TriggerValidator::OnSync,\n                            ExecutionTrigger::Always => TriggerValidator::Always,\n                        });\n                    }\n\n                    self.availables\n                        .push(AvailableItem::new(candidate.id, size, triggers));\n                    candidates_to_remove.push(candidate.id);\n                }\n                ValidatorState::Invalidated => {\n                    candidates_to_remove.push(candidate.id);\n                }\n                ValidatorState::Validating => {}\n            };\n        }\n\n        let mut updated_candidates = Vec::new();\n        core::mem::swap(&mut updated_candidates, &mut self.candidates);\n\n        self.candidates = updated_candidates\n            .into_iter()\n            .filter(|candidate| !candidates_to_remove.iter().any(|id| id == &candidate.id))\n            .collect();\n    }\n\n    fn check_availables(&mut self) {\n        for available in self.availables.iter() {\n            for trigger in available.triggers.iter() {\n                match trigger {\n                    TriggerValidator::OnOperations {\n                        matching,\n                        progress: _,\n                    } => {\n                        if let ValidatorState::Found {\n                            size: _size_of_trigger,\n                        } = matching.state\n                        {\n                            self.found = Some((available.id, available.size));\n                            return;\n                        }\n                    }\n                    TriggerValidator::Always => {\n                        self.found = Some((available.id, available.size));\n                        return;\n                    }\n                    TriggerValidator::OnSync => {\n                        // Does nothing during an update.\n                    }\n                }\n            }\n        }\n    }\n\n    fn update_candidates(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {\n        let main_store = ExecutionPlanOperationsStore::new(store);\n\n        self.candidates\n            .iter_mut()\n            .for_each(|candidate| candidate.update(operation, self.num_operations, &main_store));\n    }\n\n    fn update_availables(&mut self, store: &ExecutionPlanStore<O>, operation: &OperationIr) {\n        self.availables.iter_mut().for_each(|available| {\n            let store_trigger = TriggerOperationsStore::new(available.id, store);\n\n            available.triggers.iter_mut().for_each(|trigger| {\n                if let TriggerValidator::OnOperations { matching, progress } = trigger {\n                    match progress {\n                        TriggerProgress::NotInit => {\n                            *progress = TriggerProgress::NumChecked(0);\n                        }\n                        TriggerProgress::NumChecked(num_check) => {\n                            matching.update(operation, *num_check, &store_trigger);\n                            *num_check += 1;\n                        }\n                    }\n                }\n            });\n        });\n    }\n\n    fn action_lazy(&self, operations: &[OperationIr]) -> Action {\n        if !self.candidates.is_empty() {\n            return Action::Defer;\n        }\n\n        for available in self.availables.iter() {\n            if available.size == operations.len() {\n                return Action::Defer;\n            }\n\n            for trigger in available.triggers.iter() {\n                if let TriggerValidator::OnOperations {\n                    matching,\n                    progress: _,\n                } = trigger\n                    && let ValidatorState::Validating = matching.state\n                {\n                    return Action::Defer;\n                }\n            }\n        }\n\n        Action::Explore\n    }\n\n    fn action_sync(&self, operations: &[OperationIr], store: &ExecutionPlanStore<O>) -> Action {\n        for available in self.availables.iter() {\n            if available.size == operations.len() {\n                return Action::Execute(available.id);\n            }\n        }\n\n        for candidate in self.candidates.iter() {\n            let item = store.get_unchecked(candidate.id);\n\n            if item.operations.len() == operations.len() {\n                return Action::Execute(candidate.id);\n            }\n        }\n\n        Action::Explore\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_backend::{DType, Shape};\n    use burn_ir::{FloatOperationIr, TensorId, TensorIr, TensorStatus, UnaryOpIr};\n\n    use super::*;\n    use crate::{\n        search::BlockOptimization,\n        stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger},\n    };\n    use std::ops::Range;\n\n    #[test]\n    fn given_no_optimization_should_explore() {\n        let store = ExecutionPlanStore::default();\n        let mut policy = Policy::new();\n        let stream = TestStream::new(3);\n\n        stream.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(0..3),\n            Action::Explore,\n        );\n    }\n\n    #[test]\n    fn given_existing_optimizations_when_sync_should_execute_one_when_available() {\n        let mut store = ExecutionPlanStore::default();\n        let mut policy = Policy::new();\n        let stream = TestStream::new(3);\n\n        let id_1 = store.add(ExecutionPlan {\n            operations: stream.operations[0..2].to_vec(),\n            triggers: Vec::new(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),\n        });\n        let _id_2 = store.add(ExecutionPlan {\n            operations: stream.operations[0..3].to_vec(),\n            triggers: Vec::new(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),\n        });\n\n        stream.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(0..2),\n            Action::Defer,\n        );\n\n        let action = policy.action(&store, &stream.operations[0..2], ExecutionMode::Sync);\n        assert_eq!(action, Action::Execute(id_1));\n    }\n\n    #[test]\n    fn given_existing_plan_when_found_trigger_should_execute_plan() {\n        let mut store = ExecutionPlanStore::default();\n        let mut policy = Policy::new();\n\n        let stream = TestStream::new(3);\n        let id = store.add(ExecutionPlan {\n            operations: stream.operations[0..2].to_vec(),\n            triggers: stream.operations[2..3]\n                .iter()\n                .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))\n                .collect(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),\n        });\n\n        stream.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(0..2),\n            Action::Defer,\n        );\n        stream.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(2..3),\n            Action::Execute(id),\n        );\n    }\n\n    #[test]\n    fn should_support_multiple_triggers() {\n        let mut store = ExecutionPlanStore::default();\n        let mut policy_1 = Policy::new();\n        let mut policy_2 = Policy::new();\n\n        let mut stream_1 = TestStream::new(2);\n        let mut stream_2 = TestStream::new(2);\n\n        // Create different end operation for each stream.\n        let trigger_id_1 = 5;\n        let trigger_id_2 = 6;\n        stream_1.new_ops(trigger_id_1);\n        stream_2.new_ops(trigger_id_2);\n\n        let id = store.add(ExecutionPlan {\n            operations: stream_1.operations[0..2].to_vec(),\n            triggers: vec![\n                ExecutionTrigger::OnOperations(vec![stream_1.operations[2].clone()]),\n                ExecutionTrigger::OnOperations(vec![stream_2.operations[2].clone()]),\n            ],\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),\n        });\n\n        stream_1.assert_updates(\n            &store,\n            &mut policy_1,\n            AssertUpdatesOptions::OperationsIndex(0..2),\n            Action::Defer,\n        );\n        stream_2.assert_updates(\n            &store,\n            &mut policy_2,\n            AssertUpdatesOptions::OperationsIndex(0..2),\n            Action::Defer,\n        );\n\n        stream_1.assert_updates(\n            &store,\n            &mut policy_1,\n            AssertUpdatesOptions::OperationsIndex(2..3), // First trigger.\n            Action::Execute(id),\n        );\n        stream_2.assert_updates(\n            &store,\n            &mut policy_2,\n            AssertUpdatesOptions::OperationsIndex(2..3), // Second trigger.\n            Action::Execute(id),\n        );\n    }\n\n    #[test]\n    fn should_select_right_optimization() {\n        let mut store = ExecutionPlanStore::default();\n        let mut policy_1 = Policy::new();\n        let mut policy_2 = Policy::new();\n\n        let mut stream_1 = TestStream::new(2);\n        let mut stream_2 = TestStream::new(2);\n\n        // Create different streams after op 2.\n        stream_1.new_ops(4);\n        stream_1.new_ops(5);\n\n        stream_2.new_ops(5);\n        stream_2.new_ops(6);\n\n        let optimization_stream_1 = store.add(ExecutionPlan {\n            operations: stream_1.operations[0..3].to_vec(),\n            triggers: stream_1.operations[3..4]\n                .iter()\n                .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))\n                .collect(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),\n        });\n        let optimization_stream_2 = store.add(ExecutionPlan {\n            operations: stream_2.operations[0..3].to_vec(),\n            triggers: stream_2.operations[3..4]\n                .iter()\n                .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))\n                .collect(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),\n        });\n        assert_ne!(optimization_stream_1, optimization_stream_2);\n\n        stream_1.assert_updates(\n            &store,\n            &mut policy_1,\n            AssertUpdatesOptions::OperationsIndex(0..3),\n            Action::Defer,\n        );\n        stream_2.assert_updates(\n            &store,\n            &mut policy_2,\n            AssertUpdatesOptions::OperationsIndex(0..3),\n            Action::Defer,\n        );\n\n        stream_1.assert_updates(\n            &store,\n            &mut policy_1,\n            AssertUpdatesOptions::OperationsIndex(3..4),\n            Action::Execute(optimization_stream_1),\n        );\n        stream_2.assert_updates(\n            &store,\n            &mut policy_2,\n            AssertUpdatesOptions::OperationsIndex(3..4),\n            Action::Execute(optimization_stream_2),\n        );\n    }\n\n    #[test]\n    fn should_invalidate_wrong_optimizations() {\n        let mut store = ExecutionPlanStore::default();\n        let stream_1 = TestStream::new(4);\n        let mut stream_2 = TestStream::new(2);\n        stream_2.new_ops(6);\n        stream_2.new_ops(7);\n\n        store.add(ExecutionPlan {\n            operations: stream_1.operations[0..3].to_vec(),\n            triggers: stream_1.operations[3..4]\n                .iter()\n                .map(|desc| ExecutionTrigger::OnOperations(vec![desc.clone()]))\n                .collect(),\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(3), Vec::new()),\n        });\n\n        let mut policy = Policy::new();\n        // Same path as stream 1\n        stream_2.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(0..2),\n            Action::Defer,\n        );\n\n        // But is different.\n        stream_2.assert_updates(\n            &store,\n            &mut policy,\n            AssertUpdatesOptions::OperationsIndex(2..4),\n            Action::Explore,\n        );\n    }\n\n    #[derive(Default, Debug)]\n    struct TestStream {\n        tensors: Vec<TensorIr>,\n        operations: Vec<OperationIr>,\n    }\n\n    #[derive(Debug)]\n    enum AssertUpdatesOptions {\n        OperationsIndex(Range<usize>),\n    }\n\n    impl TestStream {\n        /// Create a new test stream with `num_ops` operations registered.\n        pub fn new(num_ops: usize) -> Self {\n            let mut stream = Self::default();\n            for id in 0..num_ops {\n                stream.new_ops(id as u64 + 1);\n            }\n\n            stream\n        }\n\n        /// The first follow should only be cache miss.\n        pub fn assert_updates(\n            &self,\n            optimizations: &ExecutionPlanStore<()>,\n            policy: &mut Policy<()>,\n            options: AssertUpdatesOptions,\n            action: Action,\n        ) {\n            match options {\n                AssertUpdatesOptions::OperationsIndex(range) => {\n                    for i in range {\n                        let stream = &self.operations[0..i];\n                        let next_ops = &self.operations[i];\n                        policy.update(optimizations, next_ops);\n                        let result = policy.action(optimizations, stream, ExecutionMode::Lazy);\n\n                        assert_eq!(result, action);\n                    }\n                }\n            }\n        }\n\n        /// Add a simple operation to the stream.\n        pub fn new_ops(&mut self, out_id: u64) {\n            if self.tensors.is_empty() {\n                // Root node.\n                self.new_empty_node(0);\n            }\n\n            // Out node.\n            self.new_empty_node(out_id);\n\n            self.operations.push(OperationIr::Float(\n                DType::F32,\n                FloatOperationIr::Log(self.unary_description()),\n            ));\n        }\n\n        fn new_empty_node(&mut self, id: u64) {\n            self.tensors.push(TensorIr {\n                id: TensorId::new(id),\n                shape: Shape::new([32, 32, 1]),\n                status: TensorStatus::NotInit,\n                dtype: DType::F32,\n            });\n        }\n\n        fn unary_description(&self) -> UnaryOpIr {\n            let size = self.tensors.len();\n\n            UnaryOpIr {\n                input: self.tensors[size - 2].clone(),\n                out: self.tensors[size - 1].clone(),\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/processor.rs",
    "content": "use burn_ir::OperationIr;\n\nuse super::{ExecutionMode, ExplorationAction, Explorer};\nuse crate::search::BlockOptimization;\nuse crate::stream::execution::{Action, Policy};\nuse crate::stream::store::{ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};\nuse crate::{NumOperations, OperationFuser};\n\n/// Process a [stream segment](StreamSegment) following a [policy](Policy).\npub(crate) struct Processor<O> {\n    policy: Policy<O>,\n    explorer: Explorer<O>,\n}\n\n/// A part of a stream that can be executed partially using [execution plan](ExecutionPlan).\npub(crate) trait StreamSegment<O> {\n    /// The operations in the segment.\n    fn operations(&self) -> &[OperationIr];\n    /// Execute part of the segment using the given plan id.\n    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<O>);\n}\n\nimpl<O: NumOperations> Processor<O> {\n    /// Create a new stream processor.\n    pub fn new(optimizations: Vec<Box<dyn OperationFuser<O>>>) -> Self {\n        Self {\n            policy: Policy::new(),\n            explorer: Explorer::new(optimizations),\n        }\n    }\n\n    /// Process the [stream segment](StreamSegment) with the provided [mode](ExecutionMode).\n    pub fn process<Segment>(\n        &mut self,\n        mut segment: Segment,\n        store: &mut ExecutionPlanStore<O>,\n        mode: ExecutionMode,\n    ) where\n        Segment: StreamSegment<O>,\n    {\n        // We assume that we always register a new operation in lazy mode.\n        if let ExecutionMode::Lazy = mode {\n            self.on_new_operation(&segment, store);\n        }\n\n        loop {\n            if segment.operations().is_empty() {\n                break;\n            }\n\n            let action = self.policy.action(store, segment.operations(), mode);\n\n            match action {\n                Action::Explore => {\n                    self.explore(&mut segment, store, mode);\n\n                    if self.explorer.is_up_to_date() {\n                        break;\n                    }\n                }\n                Action::Defer => {\n                    match mode {\n                        ExecutionMode::Lazy => break,\n                        ExecutionMode::Sync => panic!(\"Can't defer while sync\"),\n                    };\n                }\n                Action::Execute(id) => {\n                    if let ExecutionMode::Sync = mode {\n                        store.add_trigger(id, ExecutionTrigger::OnSync);\n                    }\n\n                    segment.execute(id, store);\n                    self.reset(store, segment.operations());\n                }\n            };\n        }\n    }\n\n    fn on_new_operation<Segment>(&mut self, segment: &Segment, store: &mut ExecutionPlanStore<O>)\n    where\n        Segment: StreamSegment<O>,\n    {\n        self.policy.update(\n            store,\n            segment\n                .operations()\n                .last()\n                .expect(\"At least one operation in the operation list.\"),\n        );\n        self.explorer.on_new_operation();\n    }\n\n    fn explore<Item: StreamSegment<O>>(\n        &mut self,\n        item: &mut Item,\n        store: &mut ExecutionPlanStore<O>,\n        mode: ExecutionMode,\n    ) {\n        match self.explorer.explore(item.operations(), mode) {\n            ExplorationAction::Completed(optim) => {\n                let id = Self::on_exploration_completed(\n                    &self.policy,\n                    item.operations(),\n                    store,\n                    optim,\n                    mode,\n                );\n                item.execute(id, store);\n                self.reset(store, item.operations());\n            }\n            ExplorationAction::Continue => {\n                if let ExecutionMode::Sync = mode {\n                    panic!(\"Can't continue exploring when sync.\")\n                }\n            }\n        }\n    }\n\n    fn reset(&mut self, store: &mut ExecutionPlanStore<O>, operations: &[OperationIr]) {\n        self.explorer.reset(operations);\n        self.policy.reset();\n\n        // Reset the policy state with the remaining operations\n        for operation in operations.iter() {\n            self.policy.update(store, operation);\n        }\n    }\n\n    /// We found an optimization (i.e. a new execution plan).\n    /// Cache it in the store.\n    fn on_exploration_completed(\n        policy: &Policy<O>,\n        operations: &[OperationIr],\n        store: &mut ExecutionPlanStore<O>,\n        optimization: BlockOptimization<O>,\n        mode: ExecutionMode,\n    ) -> ExecutionPlanId {\n        let num_optimized = optimization.ordering.len();\n        let relative = &operations[0..num_optimized];\n\n        match mode {\n            ExecutionMode::Lazy => {\n                let next_ops = &operations[num_optimized..operations.len()];\n\n                let trigger = if next_ops.is_empty() {\n                    // Happens if the next ops is included in the fused operation, and there is no\n                    // way the builder can still continue fusing.\n                    ExecutionTrigger::Always\n                } else {\n                    ExecutionTrigger::OnOperations(next_ops.to_vec())\n                };\n\n                match policy.action(store, relative, ExecutionMode::Sync) {\n                    Action::Execute(id) => {\n                        store.add_trigger(id, trigger);\n                        id\n                    }\n                    _ => {\n                        let plan = ExecutionPlan {\n                            operations: relative.to_vec(),\n                            triggers: vec![trigger],\n                            optimization,\n                        };\n                        store.add(plan)\n                    }\n                }\n            }\n            ExecutionMode::Sync => match policy.action(store, relative, ExecutionMode::Sync) {\n                Action::Execute(id) => {\n                    store.add_trigger(id, ExecutionTrigger::OnSync);\n                    id\n                }\n                _ => {\n                    let plan = ExecutionPlan {\n                        operations: relative.to_vec(),\n                        triggers: vec![ExecutionTrigger::OnSync],\n                        optimization,\n                    };\n                    store.add(plan)\n                }\n            },\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/tests.rs",
    "content": "//! A testing module that ensures the correctness of the explorer, policy, and processor.\n//!\n//! The primary focus is on validating the seamless interaction between these three components to\n//! execute and optimize a stream of operations accurately.\n//!\n//! To test these components effectively, we create mock types for the stream, optimization,\n//! optimization builder, and stream segment. These mock types aid in comprehensively\n//! understanding the process of optimizing streams.\nuse std::sync::Arc;\n\nuse burn_backend::{DType, Shape};\nuse burn_ir::{\n    BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarIr, ScalarOpIr, TensorId,\n    TensorIr, TensorStatus, UnaryOpIr,\n};\n\nuse crate::{\n    FuserProperties, FuserStatus, NumOperations, OperationFuser,\n    search::BlockOptimization,\n    stream::store::{\n        ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger,\n    },\n};\n\nuse super::*;\n\n/// A fake stream of operations for testing purpose.\npub struct TestStream {\n    processor: Processor<TestOptimization>,\n    store: ExecutionPlanStore<TestOptimization>,\n    executed: Vec<ExecutionPlanId>,\n    operations: Vec<OperationIr>,\n}\n\n/// A fake [optimization builder](OptimizationBuilder) for testing purpose.\n///\n/// The optimizer tries to fuse only the `expected_operations` if they appear\n/// in the operations queue\n#[derive(Clone)]\npub struct TestOptimizationBuilder {\n    builder_id: usize,\n    expected_operations: Vec<OperationIr>,\n    actual: Vec<OperationIr>,\n}\n\n/// A fake optimization for testing purpose.\n#[derive(new, Debug, PartialEq)]\npub struct TestOptimization {\n    builder_id: usize,\n    size: usize,\n}\n\nimpl NumOperations for TestOptimization {\n    fn len(&self) -> usize {\n        self.size\n    }\n}\n\n/// A fake [stream segment](StreamSegment) for testing purpose.\n#[derive(new)]\npub struct TestSegment<'i> {\n    operations: &'i mut Vec<OperationIr>,\n    executed: &'i mut Vec<ExecutionPlanId>,\n}\n\nimpl<O> ExecutionStrategy<O> {\n    /// Create an ordered execution strategy with the given size.\n    pub fn operations(size: usize) -> Self {\n        Self::Operations {\n            ordering: Arc::new((0..size).collect()),\n        }\n    }\n}\n\nimpl ExecutionStrategy<TestOptimization> {\n    /// Only use it for testing, to easily create ordered strategies.\n    pub fn optimization(opt: TestOptimization) -> Self {\n        let ordering = Arc::new((0..opt.size).collect());\n        Self::Optimization { opt, ordering }\n    }\n}\n\n/// This is a substantial test case that examines a lengthy scenario with a diverse set of conditions.\n///\n/// While it's usually preferable to split tests into multiple independent scenarios, in this case, it is\n/// crucial to verify that the stream's state is correctly updated when various cases occur consecutively.\n#[test]\nfn should_support_complex_stream() {\n    // We have 2 different optimization builders in this test case.\n    let builder_id_1 = 0;\n    let builder_id_2 = 1;\n\n    // We will have a total of 3 execution plans to execute.\n    let plan_id_1 = 0;\n    let plan_id_2 = 1;\n    let plan_id_3 = 2;\n\n    let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);\n    let builder_2 = TestOptimizationBuilder::new(builder_id_2, vec![operation_2(), operation_2()]);\n    let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);\n\n    // builder_1 is still waiting to see next op is operation_2\n    // builder_2 is closed because it's not the right operation\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(0);\n\n    // No optimization found for the first two operations.\n    stream.add(operation_1());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(1);\n    stream.assert_last_executed(plan_id_1);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_1()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization::new(ExecutionStrategy::operations(2), Vec::new()),\n        },\n    );\n\n    // Nothing to execute.\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(1);\n\n    // Now we should trigger the first optimization builder.\n    stream.add(operation_2());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(2);\n    stream.assert_last_executed(plan_id_2);\n    stream.assert_plan(\n        plan_id_2,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization::new(\n                ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),\n                vec![0, 1],\n            ),\n        },\n    );\n\n    // Nothing to execute.\n    stream.add(operation_2());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(2);\n\n    // Now we should trigger the second optimization builder.\n    stream.add(operation_2());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(3);\n    stream.assert_last_executed(plan_id_3);\n    stream.assert_plan(\n        plan_id_3,\n        ExecutionPlan {\n            operations: vec![operation_2(), operation_2()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_2, 2)),\n                ordering: vec![0, 1],\n            },\n        },\n    );\n\n    // Nothing to execute.\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(3);\n\n    // Now we should trigger the first optimization builder (second plan).\n    stream.add(operation_2());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(4);\n    stream.assert_last_executed(plan_id_2);\n    stream.assert_plan(\n        plan_id_2,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),\n                ordering: vec![0, 1],\n            },\n        },\n    );\n\n    // Nothing to execute.\n    stream.add(operation_2());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(4);\n\n    // Now we should trigger the first optimization builder (third plan).\n    stream.add(operation_2());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(5);\n    stream.assert_last_executed(plan_id_3);\n}\n\n/// In this scenario we will never use an optimization, but we check that we reuse the execution plan stored.\n#[test]\nfn should_reuse_basic_operations() {\n    let builder_id_1 = 0;\n    let plan_id_1 = 0;\n    let plan_id_2 = 1;\n\n    let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);\n    let mut stream = TestStream::new(vec![Box::new(builder_1)]);\n\n    stream.add(operation_3());\n    stream.assert_last_executed(plan_id_1);\n    stream.assert_number_of_operations(0);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_3()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(1),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    stream.add(operation_3());\n    stream.assert_last_executed(plan_id_1);\n    stream.assert_number_of_operations(0);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_3()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(1),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    // Lazy try to build optimization 1.\n    stream.add(operation_1());\n    // But not possible.\n    stream.add(operation_3());\n\n    // Creates a new plan with both operations.\n    stream.assert_plan(\n        plan_id_2,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_3()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(2),\n                ordering: vec![0],\n            },\n        },\n    );\n    stream.assert_number_of_operations(0);\n    stream.assert_last_executed(plan_id_2);\n}\n\n// In this scenario we validate that we support multiple optimization builders with overlapping\n// operations.\n//\n// This is a very long scenario that validates a lot of things.\n#[test]\nfn should_support_overlapping_optimizations() {\n    // We have 2 different optimization builders in this test case.\n    let builder_id_1 = 0;\n    let builder_id_2 = 0;\n\n    // We will have a total of 5 execution plans to execute.\n    let plan_id_1 = 0;\n    let plan_id_2 = 1;\n    let plan_id_3 = 2;\n    let plan_id_4 = 3;\n    let plan_id_5 = 4;\n\n    let builder_1 = TestOptimizationBuilder::new(builder_id_1, vec![operation_1(), operation_2()]);\n    let builder_2 = TestOptimizationBuilder::new(\n        builder_id_2,\n        vec![operation_1(), operation_2(), operation_1(), operation_1()],\n    );\n    let mut stream = TestStream::new(vec![Box::new(builder_1), Box::new(builder_2)]);\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(0);\n\n    stream.add(operation_2());\n    stream.assert_number_of_operations(2);\n    stream.assert_number_of_executions(0);\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(3);\n    stream.assert_number_of_executions(0);\n\n    stream.add(operation_2());\n    stream.assert_number_of_operations(2);\n    stream.assert_number_of_executions(1);\n    stream.assert_last_executed(plan_id_1);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2()],\n            triggers: vec![ExecutionTrigger::OnOperations(vec![\n                operation_1(),\n                operation_2(),\n            ])],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),\n                ordering: vec![0, 1],\n            },\n        },\n    );\n\n    stream.add(operation_2());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(3);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2()],\n            triggers: vec![\n                ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),\n                ExecutionTrigger::OnOperations(vec![operation_2()]),\n            ],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),\n                ordering: vec![0, 1],\n            },\n        },\n    );\n    stream.assert_plan(\n        plan_id_2,\n        ExecutionPlan {\n            operations: vec![operation_2()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(1),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(3);\n\n    stream.add(operation_2());\n    stream.assert_number_of_operations(2);\n    stream.assert_number_of_executions(3);\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(3);\n    stream.assert_number_of_executions(3);\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(4);\n\n    stream.assert_plan(\n        plan_id_3,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2(), operation_1(), operation_1()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 4)),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(1);\n    stream.assert_number_of_executions(4);\n\n    stream.add(operation_2());\n    stream.assert_number_of_operations(2);\n    stream.assert_number_of_executions(4);\n\n    stream.add(operation_1());\n    stream.assert_number_of_operations(3);\n    stream.assert_number_of_executions(4);\n\n    stream.sync();\n    stream.assert_number_of_operations(0);\n    stream.assert_number_of_executions(6);\n    stream.assert_plan(\n        plan_id_1,\n        ExecutionPlan {\n            operations: vec![operation_1(), operation_2()],\n            triggers: vec![\n                ExecutionTrigger::OnOperations(vec![operation_1(), operation_2()]),\n                ExecutionTrigger::OnOperations(vec![operation_2()]),\n                ExecutionTrigger::OnSync,\n            ],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::optimization(TestOptimization::new(builder_id_1, 2)),\n                ordering: vec![0, 1],\n            },\n        },\n    );\n    stream.assert_plan(\n        plan_id_4,\n        ExecutionPlan {\n            operations: vec![operation_1()],\n            triggers: vec![ExecutionTrigger::OnSync],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(1),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    stream.add(operation_3());\n    stream.assert_last_executed(plan_id_5);\n    stream.assert_plan(\n        plan_id_5,\n        ExecutionPlan {\n            operations: vec![operation_3()],\n            triggers: vec![ExecutionTrigger::Always],\n            optimization: BlockOptimization {\n                strategy: ExecutionStrategy::operations(1),\n                ordering: vec![0],\n            },\n        },\n    );\n\n    stream.add(operation_3());\n    stream.assert_last_executed(plan_id_5);\n}\n\nimpl TestStream {\n    /// Create a new stream with the given optimization builders.\n    fn new(optimizations: Vec<Box<dyn OperationFuser<TestOptimization>>>) -> Self {\n        Self {\n            processor: Processor::<TestOptimization>::new(optimizations),\n            store: ExecutionPlanStore::<TestOptimization>::new(),\n            executed: Vec::new(),\n            operations: Vec::new(),\n        }\n    }\n\n    /// Add an operation to the stream.\n    fn add(&mut self, operation: OperationIr) {\n        self.operations.push(operation);\n        self.processor.process(\n            TestSegment::new(&mut self.operations, &mut self.executed),\n            &mut self.store,\n            ExecutionMode::Lazy,\n        );\n    }\n\n    /// Sync the stream.\n    fn sync(&mut self) {\n        self.processor.process(\n            TestSegment::new(&mut self.operations, &mut self.executed),\n            &mut self.store,\n            ExecutionMode::Sync,\n        );\n    }\n\n    /// Assert that the plan has been executed as provided.\n    fn assert_plan(&self, id: ExecutionPlanId, expected: ExecutionPlan<TestOptimization>) {\n        let actual = self.store.get_unchecked(id);\n        assert_eq!(actual.operations, expected.operations, \"Same operations\");\n        assert_eq!(actual.triggers, expected.triggers, \"Same triggers\");\n    }\n\n    /// Assert that the given plan id has been the last executed.\n    fn assert_last_executed(&self, id: ExecutionPlanId) {\n        match self.executed.last() {\n            Some(last_id) => assert_eq!(*last_id, id),\n            None => panic!(\"No plan has been executed\"),\n        }\n    }\n\n    /// Assert the number of executions since the start of the stream.\n    fn assert_number_of_executions(&self, number: usize) {\n        assert_eq!(self.executed.len(), number, \"Number of execution match\");\n    }\n\n    /// Assert the number of operations queued.\n    fn assert_number_of_operations(&self, number: usize) {\n        assert_eq!(self.operations.len(), number);\n    }\n}\n\nimpl TestOptimizationBuilder {\n    /// Create a new optimization builder that follows a pattern with a trigger.\n    pub fn new(builder_id: usize, operations: Vec<OperationIr>) -> Self {\n        Self {\n            builder_id,\n            expected_operations: operations,\n            actual: Vec::new(),\n        }\n    }\n}\n\nimpl OperationFuser<TestOptimization> for TestOptimizationBuilder {\n    /// Register a new operation.\n    fn fuse(&mut self, operation: &OperationIr) {\n        self.actual.push(operation.clone());\n    }\n\n    /// Build the optimization.\n    fn finish(&mut self) -> TestOptimization {\n        TestOptimization::new(self.builder_id, self.len())\n    }\n\n    /// Reset the state.\n    fn reset(&mut self) {\n        self.actual.clear();\n    }\n\n    /// Return the optimization status.\n    fn status(&self) -> FuserStatus {\n        if self.actual.len() < self.expected_operations.len() {\n            let operations = &self.expected_operations[0..self.actual.len()];\n\n            return match self.actual == operations {\n                // Still optimizing.\n                true => FuserStatus::Open,\n                // Never gonna be possible on that stream.\n                false => FuserStatus::Closed,\n            };\n        }\n\n        FuserStatus::Closed\n    }\n\n    /// Return the properties of this optimization.\n    fn properties(&self) -> FuserProperties {\n        if self.actual.len() < self.expected_operations.len() {\n            // Optimization not possible.\n            return FuserProperties {\n                score: 0,\n                ready: false,\n            };\n        }\n\n        let stream_is_ok =\n            self.actual[0..self.expected_operations.len()] == self.expected_operations;\n\n        if !stream_is_ok {\n            // Optimization not possible.\n            return FuserProperties {\n                score: 0,\n                ready: false,\n            };\n        }\n\n        // Optimization possible.\n        FuserProperties {\n            score: self.expected_operations.len() as u64,\n            ready: true,\n        }\n    }\n\n    // The number of operations that should be handle by the optimization.\n    fn len(&self) -> usize {\n        self.expected_operations.len()\n    }\n    fn clone_dyn(&self) -> Box<dyn OperationFuser<TestOptimization>> {\n        Box::new(self.clone())\n    }\n}\n\nimpl StreamSegment<TestOptimization> for TestSegment<'_> {\n    // The operations in the process.\n    fn operations(&self) -> &[OperationIr] {\n        self.operations\n    }\n\n    // Execute the process.\n    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<TestOptimization>) {\n        let execution_plan = store.get_unchecked(id);\n\n        self.execute_strategy(&execution_plan.optimization.strategy);\n\n        self.executed.push(id);\n    }\n}\n\nimpl TestSegment<'_> {\n    fn execute_strategy(&mut self, strategy: &ExecutionStrategy<TestOptimization>) {\n        match strategy {\n            ExecutionStrategy::Optimization { opt, .. } => {\n                self.operations.drain(0..opt.size);\n            }\n            ExecutionStrategy::Operations { ordering } => {\n                self.operations.drain(0..ordering.len());\n            }\n            ExecutionStrategy::Composed(strategies) => {\n                for strategy in strategies {\n                    self.execute_strategy(strategy);\n                }\n            }\n        }\n    }\n}\n\n/// Just a simple operation.\npub fn operation_1() -> OperationIr {\n    OperationIr::NumericFloat(\n        DType::F32,\n        NumericOperationIr::Add(BinaryOpIr {\n            lhs: TensorIr {\n                id: TensorId::new(0),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32,\n            },\n            rhs: TensorIr {\n                id: TensorId::new(1),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32,\n            },\n            out: TensorIr {\n                id: TensorId::new(2),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::NotInit,\n                dtype: DType::F32,\n            },\n        }),\n    )\n}\n\n/// Just a simple operation.\npub fn operation_2() -> OperationIr {\n    OperationIr::NumericFloat(\n        DType::F32,\n        NumericOperationIr::AddScalar(ScalarOpIr {\n            lhs: TensorIr {\n                id: TensorId::new(0),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32,\n            },\n            rhs: ScalarIr::Float(5.0),\n            out: TensorIr {\n                id: TensorId::new(2),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::NotInit,\n                dtype: DType::F32,\n            },\n        }),\n    )\n}\n\n/// Just a simple operation.\npub fn operation_3() -> OperationIr {\n    OperationIr::Float(\n        DType::F32,\n        FloatOperationIr::Log(UnaryOpIr {\n            input: TensorIr {\n                id: TensorId::new(0),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::ReadOnly,\n                dtype: DType::F32,\n            },\n            out: TensorIr {\n                id: TensorId::new(0),\n                shape: Shape::new([32, 32]),\n                status: TensorStatus::NotInit,\n                dtype: DType::F32,\n            },\n        }),\n    )\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/execution/validator.rs",
    "content": "use burn_ir::OperationIr;\n\nuse crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};\n\n/// Compare each operation in the list of operations provided by the [store](OperationsStore)\n/// to verify if the newly added operations match the original list.\n///\n/// It is used by the [policy](crate::stream::execution::Policy) to check each candidate as well\n/// as to verify if a list of operations is optimal to execute based on their triggers.\n#[derive(Debug)]\npub(crate) struct OperationsValidator<ID> {\n    /// The ID used to retrieve the operation list.\n    pub(crate) id: ID,\n    /// The current [state](MatchingState).\n    pub(crate) state: ValidatorState,\n}\n\n/// The state of the validator.\n#[derive(Debug)]\npub(crate) enum ValidatorState {\n    /// A matching operation list has been found.\n    Found { size: usize },\n    /// No matching operation list has been found.\n    Invalidated,\n    /// Potentially going to find a matching operation list when more operations are added.\n    Validating,\n}\n\n/// Provides a list of operations based on an Id.\npub(crate) trait OperationsStore {\n    /// The type used for the identifier.\n    type Id: Copy;\n\n    /// retrieve the list of operations corresponding on the provided id.\n    fn get(&self, id: Self::Id) -> &[OperationIr];\n}\n\nimpl<ID> OperationsValidator<ID> {\n    /// Create a new validator.\n    pub(crate) fn new(id: ID) -> Self {\n        Self {\n            id,\n            state: ValidatorState::Validating,\n        }\n    }\n\n    /// Update the state of the validator based on the newly added operation.\n    pub(crate) fn update<S>(&mut self, added: &OperationIr, added_position: usize, store: &S)\n    where\n        S: OperationsStore<Id = ID>,\n        ID: PartialEq + Copy,\n    {\n        match &self.state {\n            ValidatorState::Found { size: _ } => return,\n            ValidatorState::Invalidated => return,\n            ValidatorState::Validating => {}\n        };\n\n        let item = store.get(self.id);\n        let operation_candidate = match item.get(added_position) {\n            Some(val) => val,\n            None => {\n                self.state = ValidatorState::Invalidated;\n                return;\n            }\n        };\n\n        if operation_candidate != added {\n            self.state = ValidatorState::Invalidated;\n            return;\n        }\n\n        // Finished\n        if item.len() == added_position + 1 {\n            self.state = ValidatorState::Found { size: item.len() };\n        }\n    }\n}\n\n/// [Operations store](OperationsStore) used to retrieve the list of operations for a trigger.\n#[derive(new)]\npub(crate) struct TriggerOperationsStore<'a, O> {\n    id: ExecutionPlanId,\n    store: &'a ExecutionPlanStore<O>,\n}\n\n/// Validates when operations match a trigger.\n#[derive(Debug)]\npub(crate) enum TriggerValidator {\n    OnOperations {\n        matching: OperationsValidator<TriggerId>,\n        progress: TriggerProgress,\n    },\n    Always,\n    OnSync,\n}\n\n/// The progress made into the trigger validation process.\n#[derive(Debug)]\npub(crate) enum TriggerProgress {\n    /// When the validation hasn't started.\n    NotInit,\n    /// The number of operations that have been checked.\n    NumChecked(usize),\n}\n\n/// An execution plan can have many triggers, so we use the position in the list to identify a\n/// trigger.\npub(crate) type TriggerId = usize;\n\nimpl<O: core::fmt::Debug> OperationsStore for TriggerOperationsStore<'_, O> {\n    type Id = TriggerId;\n\n    fn get(&self, id: Self::Id) -> &[OperationIr] {\n        match &self.store.get_unchecked(self.id).triggers[id] {\n            ExecutionTrigger::OnOperations(operations) => operations,\n            ExecutionTrigger::OnSync => &[],\n            ExecutionTrigger::Always => &[],\n        }\n    }\n}\n\n/// [Operations store](OperationsStore) used to retrieve the list of operations for an\n/// [execution plan](crate::stream::store::ExecutionPlan).\n#[derive(new)]\npub(crate) struct ExecutionPlanOperationsStore<'a, O> {\n    store: &'a ExecutionPlanStore<O>,\n}\n\nimpl<O: core::fmt::Debug> OperationsStore for ExecutionPlanOperationsStore<'_, O> {\n    type Id = ExecutionPlanId;\n\n    fn get(&self, id: Self::Id) -> &[OperationIr] {\n        &self.store.get_unchecked(id).operations\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/memory_checks.rs",
    "content": "use hashbrown::HashMap;\nuse std::{\n    fmt::Display,\n    sync::{\n        Arc,\n        atomic::{AtomicU64, Ordering},\n        mpsc::SyncSender,\n    },\n    thread::JoinHandle,\n    time::Duration,\n};\n\nuse burn_ir::{HandleContainer, TensorId, TensorStatus};\nuse burn_std::id::StreamId;\n\nuse crate::FusionRuntime;\n\nuse super::Stream;\n\n/// Memory checks struct to validate there is no memory leak with the fusion runtime.\n#[derive(Clone)]\npub(crate) struct MemoryChecks {\n    sender: SyncSender<Message>,\n    num_queued: Arc<AtomicU64>,\n    // Keeps track of its thread.\n    _handle: Arc<JoinHandle<()>>,\n}\n\nenum Message {\n    Register(StreamAnalyses),\n    Check(SyncSender<MemoryReport>),\n}\n\nenum MemoryReport {\n    Success,\n    NotReady,\n    NotStarted,\n    Fail(String),\n}\n\n#[derive(Default)]\nstruct StreamAnalyses {\n    streams: HashMap<StreamId, Analysis>,\n    num_handles: usize,\n}\n\nimpl Display for StreamAnalyses {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_str(\"\\n==== Fusion Memory Report ====\\n\")?;\n        f.write_fmt(format_args!(\" - Handles: {}\\n\", self.num_handles))?;\n        f.write_fmt(format_args!(\" - Streams: {}\\n\", self.streams.len()))?;\n\n        for (id, analysis) in self.streams.iter() {\n            f.write_fmt(format_args!(\n                \"  - {} => operations: {} cursor: {}\\n\",\n                id, analysis.num_operations, analysis.cursor\n            ))?;\n            for (tid, (origin, status)) in analysis.variables.iter() {\n                f.write_fmt(format_args!(\n                    \"   - {tid} => origin: {origin} status: {status:?}\\n\",\n                ))?;\n            }\n        }\n\n        f.write_str(\"==============================\\n\")\n    }\n}\n\n#[derive(Default, Debug)]\nstruct Analysis {\n    variables: HashMap<TensorId, (StreamId, TensorStatus)>,\n    num_operations: usize,\n    cursor: u64,\n}\n\n#[macro_export]\n/// Export memory checks tests.\nmacro_rules! memory_checks {\n    () => {\n        #[cfg(test)]\n        mod memory_checks {\n            #[test]\n            fn test_memory_leaks() {\n                burn_fusion::stream::memory_checks::check_memory_leaks();\n            }\n        }\n    };\n}\n\nstatic INSTANCE: spin::Mutex<Option<MemoryChecks>> = spin::Mutex::new(None);\n\n/// Performs memory checks and panics if a leak is discovered.\npub fn check_memory_leaks() {\n    let mut num_try_uninit = 0;\n    let max_try = 25;\n\n    loop {\n        let report = fetch_memory_report();\n        match report {\n            MemoryReport::Success => return,\n            MemoryReport::NotReady => {\n                num_try_uninit = 0;\n                std::thread::sleep(Duration::from_millis(100))\n            }\n            MemoryReport::NotStarted => {\n                if num_try_uninit >= max_try {\n                    // Nothing is running on the fusion runtime.\n                    return;\n                }\n                num_try_uninit += 1;\n                std::thread::sleep(Duration::from_millis(100))\n            }\n            MemoryReport::Fail(msg) => panic!(\"{msg}\"),\n        }\n    }\n}\n\nfn fetch_memory_report() -> MemoryReport {\n    let report = INSTANCE.lock();\n\n    let report = match report.as_ref() {\n        Some(client) => client,\n        None => return MemoryReport::NotStarted,\n    };\n\n    let (sender, rec) = std::sync::mpsc::sync_channel(1);\n    match report.sender.send(Message::Check(sender)) {\n        Ok(_) => {}\n        Err(err) => {\n            panic!(\"Channel closed can't send the check call: {err:?}\")\n        }\n    };\n\n    match rec.recv() {\n        Ok(report) => report,\n        Err(err) => panic!(\"Received an error from fetching check results: {err}\"),\n    }\n}\n\nimpl Default for MemoryChecks {\n    fn default() -> Self {\n        let mut instance = INSTANCE.lock();\n        let result = match instance.as_mut() {\n            Some(client) => client.clone(),\n            None => {\n                let this = Self::spawn_new();\n                *instance = Some(this.clone());\n                this\n            }\n        };\n        core::mem::drop(instance);\n        result\n    }\n}\n\nimpl MemoryChecks {\n    pub(crate) fn check<R: FusionRuntime>(\n        &mut self,\n        streams: &HashMap<StreamId, Stream<R>>,\n        handles: &HandleContainer<R::FusionHandle>,\n    ) {\n        let mut analyses = StreamAnalyses {\n            num_handles: handles.num_handles(),\n            streams: Default::default(),\n        };\n\n        for (id, s) in streams.iter() {\n            let analysis = Analysis {\n                variables: s.queue.variables.clone(),\n                num_operations: s.queue.global.len(),\n                cursor: s.cursor,\n            };\n            analyses.streams.insert(*id, analysis);\n        }\n\n        self.num_queued.fetch_add(1, Ordering::Relaxed);\n        match self.sender.send(Message::Register(analyses)) {\n            Ok(..) => {}\n            Err(err) => {\n                panic!(\"Can't register memory checks analysis: {err:?}\")\n            }\n        }\n    }\n\n    fn spawn_new() -> Self {\n        let (sender, rec) = std::sync::mpsc::sync_channel(100);\n        let num_queued = Arc::new(AtomicU64::new(0));\n        let num_queued_moved = num_queued.clone();\n\n        let handle = std::thread::spawn(move || {\n            let mut last_analyses = None;\n\n            loop {\n                let payload = match rec.recv() {\n                    Err(_err) => {\n                        // A client has panic, safe to skip as it may be normal.\n                        continue;\n                    }\n                    Ok(payload) => payload,\n                };\n                match payload {\n                    Message::Register(payload) => {\n                        last_analyses = Some(payload);\n                        num_queued_moved.fetch_sub(1, Ordering::Relaxed);\n                    }\n                    Message::Check(callback) => {\n                        if num_queued_moved.load(Ordering::Relaxed) > 1 {\n                            callback.send(MemoryReport::NotReady).unwrap();\n                            continue;\n                        }\n\n                        // We assume that if nothing has been registered in the last second\n                        // while being at a count of 1, it's the end.\n                        std::thread::sleep(Duration::from_secs(5));\n\n                        if num_queued_moved.load(Ordering::Relaxed) <= 1 {\n                            match last_analyses.take() {\n                                Some(val) => {\n                                    callback.send(Self::final_check(val)).unwrap();\n                                }\n                                None => {\n                                    callback\n                                        .send(MemoryReport::Fail(\"No analyses\".into()))\n                                        .unwrap();\n                                }\n                            }\n                        } else {\n                            callback.send(MemoryReport::NotReady).unwrap();\n                        }\n                    }\n                }\n            }\n        });\n\n        Self {\n            sender,\n            num_queued,\n            _handle: Arc::new(handle),\n        }\n    }\n\n    fn final_check(analyses: StreamAnalyses) -> MemoryReport {\n        if !analyses.streams.is_empty() || analyses.num_handles > 0 {\n            return MemoryReport::Fail(format!(\"{analyses}\"));\n        }\n\n        MemoryReport::Success\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/mod.rs",
    "content": "pub(crate) mod execution;\npub(crate) mod queue;\npub(crate) mod shared_tensors;\npub(crate) mod store;\n\n#[cfg(feature = \"memory-checks\")]\n/// Memory checks module.\npub mod memory_checks;\n\n#[cfg(not(feature = \"memory-checks\"))]\n#[macro_export]\n/// Export memory checks tests.\nmacro_rules! memory_checks {\n    () => {\n        #[cfg(test)]\n        mod memory_checks {\n            #[ignore = \"'memory-checks' disabled\"]\n            #[test]\n            fn test_memory_leaks() {\n                //\n            }\n        }\n    };\n}\n\nmod base;\nmod context;\nmod multi;\n\npub use base::*;\npub use context::*;\npub use execution::*;\npub use multi::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/multi.rs",
    "content": "use std::sync::Arc;\n\nuse burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};\nuse hashbrown::{HashMap, HashSet};\n\nuse super::{\n    StreamId,\n    execution::{ExecutionMode, Operation, Processor, StreamSegment},\n    queue::OperationQueue,\n    shared_tensors::SharedTensors,\n    store::{ExecutionPlanId, ExecutionPlanStore},\n};\nuse crate::{\n    DropOp, FusionRuntime,\n    stream::shared_tensors::{SharedTensorAnalysis, SharedTensorDropAction},\n};\n\n/// Keep track of multiple concurrent lazy streams of operations.\npub struct MultiStream<R: FusionRuntime> {\n    streams: HashMap<StreamId, Stream<R>>,\n    optimizations: ExecutionPlanStore<R::Optimization>,\n    shared_tensors: SharedTensors,\n    device: R::FusionDevice,\n    #[cfg(feature = \"memory-checks\")]\n    memory_checks: super::memory_checks::MemoryChecks,\n}\n\n#[derive(Debug)]\nenum DropAction {\n    SkipSharedTensor,\n    ForceSharedTensor(Vec<StreamId>, TensorId),\n    ContinueDrop,\n}\n\nimpl<R: FusionRuntime> MultiStream<R> {\n    pub(crate) fn new(device: R::FusionDevice) -> Self {\n        Self {\n            streams: HashMap::new(),\n            optimizations: ExecutionPlanStore::new(),\n            shared_tensors: SharedTensors::default(),\n            device,\n            #[cfg(feature = \"memory-checks\")]\n            memory_checks: super::memory_checks::MemoryChecks::default(),\n        }\n    }\n\n    /// Register a new tensor operation.\n    pub(crate) fn register(\n        &mut self,\n        streams: OperationStreams,\n        mut repr: OperationIr,\n        operation: Arc<dyn Operation<R>>,\n        handles: &mut HandleContainer<R::FusionHandle>,\n    ) {\n        let id = self.resolve_streams(&streams, handles, &mut repr);\n\n        let drop_action = match &mut repr {\n            OperationIr::Drop(tensor_ir) => Some(self.handle_drop_op(id, tensor_ir)),\n            _ => None,\n        };\n\n        let sync = match drop_action {\n            Some(DropAction::SkipSharedTensor) => return,\n            Some(DropAction::ContinueDrop) => true,\n            Some(DropAction::ForceSharedTensor(stream_ids, tid)) => {\n                for stream_id in stream_ids {\n                    if let Some(stream) = self.streams.get_mut(&stream_id) {\n                        stream.queue.variables.remove(&tid);\n                        if stream.queue.variables.is_empty() {\n                            self.streams.remove(&stream_id);\n                        }\n                    }\n                }\n                true\n            }\n            None => false,\n        };\n\n        let num_executed = self.enqueue_operation(id, repr, &streams, operation, handles);\n\n        if num_executed > 0\n            && let Some(stream) = self.streams.get_mut(&id)\n        {\n            let cleared = self.shared_tensors.on_executed_ops(id, stream);\n            self.clear_shared_tensors(&cleared, id);\n            let to_drop = self.shared_tensors.clear_tensors(cleared);\n            self.drop_shared_tensors(to_drop, handles, id);\n        }\n\n        let stream = match self.streams.get(&id) {\n            Some(val) => val,\n            None => {\n                #[cfg(feature = \"memory-checks\")]\n                self.memory_checks.check(&self.streams, handles);\n                return;\n            }\n        };\n\n        if !stream.queue.variables.is_empty() && sync {\n            // Not draining the queue can cause a memory leak when a stream is closing.\n            self.drain(handles, id);\n        }\n\n        #[cfg(feature = \"memory-checks\")]\n        self.memory_checks.check(&self.streams, handles);\n    }\n\n    /// Checks if the current operation is a drop.\n    ///\n    /// When a tensor is shared across multiple concurrent streams, dropping a tensor might cause a\n    /// problem when the same tensor is registered lazily on another stream, but not yet executed.\n    fn handle_drop_op(&mut self, id: StreamId, tensor_ir: &mut TensorIr) -> DropAction {\n        match !matches!(tensor_ir.status, TensorStatus::ReadWrite) {\n            true => {\n                let stream = self.streams.get(&id);\n                let on_drop = self\n                    .shared_tensors\n                    .on_drop(id, tensor_ir.id, stream.is_none());\n\n                match on_drop {\n                    SharedTensorDropAction::ForceDrop(streams) => {\n                        tensor_ir.status = TensorStatus::ReadWrite;\n                        DropAction::ForceSharedTensor(streams, tensor_ir.id)\n                    }\n                    SharedTensorDropAction::Skip => DropAction::SkipSharedTensor,\n                }\n            }\n            false => DropAction::ContinueDrop,\n        }\n    }\n\n    /// Enqueue an operation on the queue.\n    fn enqueue_operation(\n        &mut self,\n        id: StreamId,\n        repr: OperationIr,\n        streams: &OperationStreams,\n        operation: Arc<dyn Operation<R>>,\n        handles: &mut HandleContainer<R::FusionHandle>,\n    ) -> usize {\n        let stream = match self.streams.get_mut(&id) {\n            Some(stream) => stream,\n            None => {\n                let stream = Stream::new(self.device.clone());\n                self.streams.insert(id, stream);\n                self.streams\n                    .get_mut(&id)\n                    .expect(\"Just added, so should be included in the hashmap.\")\n            }\n        };\n\n        stream.queue.add(repr, operation, streams, id);\n\n        let len_before = stream.queue.global.len();\n        stream.processor.process(\n            Segment::new(&mut stream.queue, handles),\n            &mut self.optimizations,\n            ExecutionMode::Lazy,\n        );\n        let len_after = stream.queue.global.len();\n        let num_executed = len_before - len_after;\n\n        stream.cursor += num_executed as u64;\n\n        num_executed\n    }\n\n    /// Mark a tensor as read.\n    #[allow(unused_variables)]\n    pub fn mark_read(\n        &mut self,\n        id: StreamId,\n        ir: &TensorIr,\n        handles: &HandleContainer<R::FusionHandle>,\n    ) {\n        if !matches!(ir.status, TensorStatus::ReadWrite) {\n            return;\n        };\n\n        let stream = match self.streams.get_mut(&id) {\n            Some(val) => val,\n            None => return,\n        };\n\n        stream.queue.variables.remove(&ir.id);\n\n        if stream.queue.variables.is_empty() {\n            self.streams.remove(&id);\n        }\n\n        #[cfg(feature = \"memory-checks\")]\n        self.memory_checks.check(&self.streams, handles);\n    }\n\n    /// Drain a stream\n    pub fn drain(&mut self, handles: &mut HandleContainer<R::FusionHandle>, id: StreamId) {\n        if let Some(stream) = self.streams.get_mut(&id) {\n            let old = unsafe { StreamId::swap(id) };\n            let num_executed = stream.queue.global.len();\n            stream.processor.process(\n                Segment::new(&mut stream.queue, handles),\n                &mut self.optimizations,\n                ExecutionMode::Sync,\n            );\n            stream.cursor += num_executed as u64;\n\n            let cleared = self.shared_tensors.on_executed_ops(id, stream);\n            self.clear_shared_tensors(&cleared, id);\n            let to_drop = self.shared_tensors.clear_tensors(cleared);\n\n            self.drop_shared_tensors(to_drop, handles, id);\n            unsafe {\n                StreamId::swap(old);\n            };\n        }\n    }\n\n    /// When one of the provided streams is different from the current stream, we drain them.\n    ///\n    /// Returns the selected stream id.\n    fn resolve_streams(\n        &mut self,\n        streams: &OperationStreams,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        op: &mut OperationIr,\n    ) -> StreamId {\n        let current = streams.current;\n        let nodes = op.nodes();\n\n        let analysis = self.analyse_shared_tensors(&nodes, streams, current);\n\n        self.merge_streams_timelines(handles, &analysis, current, &nodes);\n        self.register_shared_tensors_drop(&analysis, op);\n\n        current\n    }\n\n    /// Drain the stream only if one of the tensor in the given nodes is also included in the\n    /// stream queue.\n    fn resolve_stream(\n        &mut self,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        id: StreamId,\n        nodes: &[&TensorIr],\n    ) {\n        if let Some(stream) = self.streams.get(&id) {\n            for node in nodes {\n                if stream.queue.variables.contains_key(&node.id) {\n                    self.drain(handles, id);\n                    return;\n                }\n            }\n        }\n    }\n\n    fn analyse_shared_tensors(\n        &mut self,\n        nodes: &[&TensorIr],\n        streams: &OperationStreams,\n        current: StreamId,\n    ) -> MultiSharedTensorAnalysis {\n        let mut shared_analysis = MultiSharedTensorAnalysis::default();\n\n        for node in nodes.iter() {\n            let analysis = self\n                .shared_tensors\n                .analyse(current, node, streams, &self.streams);\n            match analysis {\n                SharedTensorAnalysis::SharedFromCurrentStream => {\n                    shared_analysis.current.push((node.id, node.status));\n                }\n                SharedTensorAnalysis::NotShared => {}\n                SharedTensorAnalysis::SharedFromExistingStream {\n                    stream_id,\n                    original_cursor,\n                } => {\n                    shared_analysis\n                        .existing\n                        .push((node.id, stream_id, original_cursor));\n                }\n                SharedTensorAnalysis::SharedFromNewStream { stream_id } => {\n                    shared_analysis.new.push((node.id, stream_id));\n                }\n            }\n        }\n\n        shared_analysis\n    }\n\n    fn merge_streams_timelines(\n        &mut self,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        analysis: &MultiSharedTensorAnalysis,\n        current: StreamId,\n        nodes: &[&TensorIr],\n    ) {\n        // If we only have current tensors that are shared, we're safe to not sync the timelines.\n        if analysis.new.is_empty() && analysis.existing.is_empty() && analysis.current.is_empty() {\n            return;\n        }\n\n        let mut streams_to_sync = HashSet::new();\n        for (_tensor_id, stream_id) in analysis.new.iter() {\n            streams_to_sync.insert(*stream_id);\n        }\n\n        for (_tensor_id, stream_id, original_cursor) in analysis.existing.iter() {\n            if let Some(stream) = self.streams.get(stream_id) {\n                // We only have to sync a stream when the stream isn't up to date with\n                // the original cursor of the current operation.\n                if stream.cursor <= *original_cursor && *stream_id != current {\n                    streams_to_sync.insert(*stream_id);\n                }\n            }\n        }\n\n        for (tensor_id, status) in analysis.current.iter() {\n            if let TensorStatus::ReadWrite = status {\n                for stream in self.shared_tensors.streams_of(tensor_id) {\n                    streams_to_sync.insert(stream);\n                }\n            }\n        }\n\n        for id in streams_to_sync.drain() {\n            log::trace!(\"Drain stream {id} for use in current {current}\");\n            self.resolve_stream(handles, id, nodes);\n        }\n    }\n\n    fn register_shared_tensors_drop(\n        &mut self,\n        analysis: &MultiSharedTensorAnalysis,\n        op: &mut OperationIr,\n    ) {\n        let mut readonly_tensors = Vec::new();\n\n        for (tensor_id, _stream_id) in analysis.new.iter() {\n            readonly_tensors.push(*tensor_id);\n        }\n        for (tensor_id, _stream_id, _cursor) in analysis.existing.iter() {\n            readonly_tensors.push(*tensor_id);\n        }\n        for (tensor_id, status) in analysis.current.iter() {\n            if let TensorStatus::ReadOnly = status {\n                readonly_tensors.push(*tensor_id);\n            }\n        }\n\n        self.shared_tensors\n            .tag_manual_drop(op.mark_read_only(&readonly_tensors));\n    }\n\n    fn drop_shared_tensors(\n        &mut self,\n        tensors: Vec<TensorIr>,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        current: StreamId,\n    ) {\n        for (stream_id, s) in self.streams.iter_mut() {\n            for tensor in tensors.iter() {\n                if let Some((original, _status)) = s.queue.variables.get(&tensor.id)\n                    && original != stream_id\n                {\n                    s.queue.variables.remove(&tensor.id);\n                }\n            }\n        }\n        for tensor in tensors {\n            let streams = OperationStreams {\n                streams: HashMap::new(),\n                current,\n            };\n\n            let op = Arc::new(DropOp { id: tensor.id });\n            self.register(streams, OperationIr::Drop(tensor), op, handles);\n        }\n    }\n    fn clear_shared_tensors(&mut self, tensors: &[TensorId], current: StreamId) {\n        let mut to_remove = Vec::new();\n        for (stream_id, s) in self.streams.iter_mut() {\n            for tensor in tensors.iter() {\n                s.queue.variables.remove(tensor);\n            }\n\n            if s.queue.variables.is_empty() && current != *stream_id {\n                to_remove.push(*stream_id);\n            }\n        }\n\n        for s in to_remove {\n            self.streams.remove(&s);\n        }\n    }\n}\n\npub(crate) struct Stream<R: FusionRuntime> {\n    pub(crate) queue: OperationQueue<R>,\n    processor: Processor<R::Optimization>,\n    pub(crate) cursor: u64,\n}\n\n#[derive(new)]\nstruct Segment<'a, R: FusionRuntime> {\n    queue: &'a mut OperationQueue<R>,\n    handles: &'a mut HandleContainer<R::FusionHandle>,\n}\n\nimpl<R: FusionRuntime> StreamSegment<R::Optimization> for Segment<'_, R> {\n    fn operations(&self) -> &[OperationIr] {\n        &self.queue.relative\n    }\n\n    fn execute(&mut self, id: ExecutionPlanId, store: &mut ExecutionPlanStore<R::Optimization>) {\n        self.queue.execute(id, self.handles, store)\n    }\n}\n\nimpl<R: FusionRuntime> Stream<R> {\n    fn new(device: R::FusionDevice) -> Self {\n        Self {\n            processor: Processor::new(R::fusers(device)),\n            queue: OperationQueue::new(),\n            cursor: 0,\n        }\n    }\n}\n\n#[derive(Debug)]\n/// Manage the streams used for the current [operation](OperationIr).\npub struct OperationStreams {\n    pub(crate) streams: HashMap<TensorId, StreamId>,\n    pub(crate) current: StreamId,\n}\n\nimpl Default for OperationStreams {\n    fn default() -> Self {\n        Self {\n            streams: HashMap::new(),\n            current: StreamId::current(),\n        }\n    }\n}\n\nimpl OperationStreams {\n    /// Register a tensor in the list of streams used for the current [operation](OperationIr).\n    ///\n    /// You only need to register input tensors, not the outputs.\n    /// So init tensor operations should have no streams registered.\n    pub fn tensor<R: FusionRuntime>(&mut self, tensor: &crate::FusionTensor<R>) {\n        self.streams.insert(tensor.id, tensor.stream);\n    }\n\n    pub(crate) fn get(&self, id: TensorId) -> Option<StreamId> {\n        self.streams.get(&id).cloned()\n    }\n\n    /// Create new operation streams with the given inputs.\n    ///\n    /// The inputs are automatically registered.\n    pub fn with_inputs<'a, R: FusionRuntime + 'a, I>(tensors: I) -> Self\n    where\n        I: IntoIterator<Item = &'a crate::FusionTensor<R>>,\n    {\n        let mut streams = OperationStreams::default();\n        for tensor in tensors.into_iter() {\n            streams.tensor(tensor)\n        }\n        streams\n    }\n}\n\n#[derive(Default, Debug)]\nstruct MultiSharedTensorAnalysis {\n    /// Tensors that are shared with other streams, but we're currently executing on the same stream\n    /// the tensor was originally created.\n    current: Vec<(TensorId, TensorStatus)>,\n    /// Tensors that are shared with new streams.\n    new: Vec<(TensorId, StreamId)>,\n    /// Tensors that are shared with existing streams.\n    existing: Vec<(TensorId, StreamId, u64)>,\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/queue/base.rs",
    "content": "use std::sync::Arc;\n\nuse crate::FusionRuntime;\nuse crate::stream::{OperationConverter, OperationStreams, RelativeOps, execution::Operation};\nuse burn_backend::StreamId;\nuse burn_ir::{OperationIr, TensorId, TensorStatus};\n\nuse hashbrown::HashMap;\n\n/// A growing list of [tensor operation descriptions](OperationIr).\npub struct OperationQueue<R: FusionRuntime> {\n    /// List of operation descriptions. These contain the exact tensor IDs\n    /// and shapes so that kernels can be run correctly.\n    ///\n    /// The length of this list is the same as the length of the `operations` list.\n    pub(crate) global: Vec<OperationIr>,\n    /// List of operation descriptions. The tensor IDs and shapes are relative\n    /// because we don't need to know the exact values, but they are sufficient to\n    /// determine which operations can be fused.\n    pub(crate) relative: Vec<OperationIr>,\n    pub(crate) converter: OperationConverter,\n    pub(crate) operations: Vec<Arc<dyn Operation<R>>>,\n    pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,\n}\n\nimpl<R: FusionRuntime> Default for OperationQueue<R> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<R: FusionRuntime> OperationQueue<R> {\n    /// Create a new empty queue.\n    pub fn new() -> Self {\n        Self {\n            global: Vec::new(),\n            relative: Vec::new(),\n            converter: OperationConverter::default(),\n            operations: Vec::new(),\n            variables: HashMap::new(),\n        }\n    }\n\n    /// Add a new tensor operation to the queue.\n    ///\n    /// The new [operation intermediate representation](OperationIr) will be converted to a local\n    /// representation that can be reused when the same pattern emerge in different but similar\n    /// scenario, so that the same optimization can be used.\n    pub fn add(\n        &mut self,\n        global: OperationIr,\n        operation: Arc<dyn Operation<R>>,\n        streams: &OperationStreams,\n        current: StreamId,\n    ) {\n        for node in global.nodes() {\n            if let Some(stream_id) = streams.get(node.id) {\n                self.variables.insert(node.id, (stream_id, node.status));\n            } else {\n                self.variables.insert(node.id, (current, node.status));\n            }\n        }\n        let relative = global.to_relative(&mut self.converter);\n        self.relative.push(relative);\n        self.global.push(global);\n        self.operations.push(operation);\n    }\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn stream_id_from_different_threads() {\n        let current = StreamId::current();\n\n        let thread1 = std::thread::spawn(|| (StreamId::current(), StreamId::current()));\n        let thread2 = std::thread::spawn(StreamId::current);\n\n        let (stream_1, stream_11) = thread1.join().unwrap();\n        let stream_2 = thread2.join().unwrap();\n\n        assert_ne!(current, stream_1, \"Should be different from thread 1\");\n        assert_ne!(current, stream_2, \"Should be different from thread 2\");\n        assert_ne!(\n            stream_1, stream_2,\n            \"Should be different from different threads\"\n        );\n        assert_eq!(\n            stream_1, stream_11,\n            \"Should be the same, since same thread.\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/queue/execution.rs",
    "content": "use std::sync::Arc;\n\nuse burn_ir::{HandleContainer, TensorStatus};\n\nuse crate::{\n    FusionRuntime,\n    search::BlockOptimization,\n    stream::{\n        Context, Operation, OperationConverter, OrderedExecution, RelativeOps,\n        store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy},\n    },\n};\n\nuse super::OperationQueue;\n\nimpl<R: FusionRuntime> OperationQueue<R> {\n    /// Execute the queue partially following the execution strategy from the plan.\n    pub(crate) fn execute(\n        &mut self,\n        id: ExecutionPlanId,\n        handles: &mut HandleContainer<R::FusionHandle>,\n        store: &mut ExecutionPlanStore<R::Optimization>,\n    ) {\n        let plan = store.get_mut_unchecked(id);\n        self.execute_block_optimization(&mut plan.optimization, handles);\n    }\n\n    fn execute_block_optimization(\n        &mut self,\n        step: &mut BlockOptimization<R::Optimization>,\n        handles: &mut HandleContainer<R::FusionHandle>,\n    ) {\n        let mut operations = Vec::new();\n        core::mem::swap(&mut operations, &mut self.operations);\n\n        let (operations, num_drained) =\n            QueueExecution::run(step, &mut self.converter, handles, operations);\n\n        self.operations = operations;\n        self.drain_queue(num_drained, handles);\n    }\n\n    /// Bookkeeping after executing `num_drained` operations from the queue.\n    fn drain_queue(&mut self, num_drained: usize, handles: &mut HandleContainer<R::FusionHandle>) {\n        self.global[0..num_drained]\n            .iter()\n            .flat_map(|desc| desc.nodes())\n            .for_each(|tensor| {\n                if tensor.status == TensorStatus::ReadWrite {\n                    self.variables.remove(&tensor.id);\n                };\n                handles.free(tensor)\n            });\n\n        self.global.drain(0..num_drained);\n\n        self.reset_relative();\n    }\n\n    fn reset_relative(&mut self) {\n        self.relative.clear();\n        self.converter.clear();\n\n        for node in self.global.iter() {\n            let relative = node.to_relative(&mut self.converter);\n            self.relative.push(relative);\n        }\n    }\n}\n\n/// A queue execution has the responsibility to run the provided\n/// [optimization](FusionRuntime::Optimization) without holes.\nenum QueueExecution<'a, R: FusionRuntime> {\n    Single {\n        handles: &'a mut HandleContainer<R::FusionHandle>,\n        converter: &'a mut OperationConverter,\n        execution: OrderedExecution<R>,\n    },\n    Multiple {\n        context: &'a mut Context<'a, R::FusionHandle>,\n        execution: OrderedExecution<R>,\n    },\n}\n\nimpl<'a, R: FusionRuntime> QueueExecution<'a, R> {\n    fn run(\n        optimization: &mut BlockOptimization<R::Optimization>,\n        converter: &'a mut OperationConverter,\n        handles: &'a mut HandleContainer<R::FusionHandle>,\n        operations: Vec<Arc<dyn Operation<R>>>,\n    ) -> (Vec<Arc<dyn Operation<R>>>, usize) {\n        let execution = OrderedExecution::new(operations);\n\n        if matches!(&optimization.strategy, ExecutionStrategy::Composed(..)) {\n            let mut context = converter.context(handles);\n            let mut this = QueueExecution::Multiple {\n                context: &mut context,\n                execution,\n            };\n\n            this = this.execute_strategy(&mut optimization.strategy);\n\n            match this {\n                QueueExecution::Multiple { execution, .. } => execution.finish(),\n                _ => unreachable!(),\n            }\n        } else {\n            let mut this = QueueExecution::Single {\n                handles,\n                converter,\n                execution,\n            };\n            this = this.execute_strategy(&mut optimization.strategy);\n\n            match this {\n                QueueExecution::Single { execution, .. } => execution.finish(),\n                _ => unreachable!(),\n            }\n        }\n    }\n\n    fn execute_strategy(mut self, strategy: &mut ExecutionStrategy<R::Optimization>) -> Self {\n        match &mut self {\n            QueueExecution::Single {\n                handles,\n                converter,\n                execution,\n            } => match strategy {\n                ExecutionStrategy::Optimization { ordering, opt } => {\n                    let mut context = converter.context(handles);\n                    execution.execute_optimization(opt, &mut context, ordering.clone())\n                }\n                ExecutionStrategy::Operations { ordering } => {\n                    execution.execute_operations(handles, ordering)\n                }\n                ExecutionStrategy::Composed(_) => unreachable!(),\n            },\n            QueueExecution::Multiple { context, execution } => match strategy {\n                ExecutionStrategy::Optimization { opt, ordering } => {\n                    execution.execute_optimization(opt, context, ordering.clone());\n                }\n                ExecutionStrategy::Operations { ordering } => {\n                    execution.execute_operations(context.handles, ordering);\n                }\n                ExecutionStrategy::Composed(items) => {\n                    for item in items.iter_mut() {\n                        self = self.execute_strategy(item);\n                    }\n                }\n            },\n        };\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/queue/mod.rs",
    "content": "mod base;\nmod execution;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/shared_tensors.rs",
    "content": "use burn_backend::StreamId;\nuse burn_ir::{TensorId, TensorIr};\nuse hashbrown::HashMap;\n\nuse super::{OperationStreams, Stream};\nuse crate::FusionRuntime;\n\n#[derive(Default)]\n/// Manages tensors that are shared between multiple streams.\npub struct SharedTensors {\n    shared_tensors: HashMap<TensorId, SharedTensor>,\n    shared_tensors_manual_drop: HashMap<TensorId, TensorIr>,\n}\n\n#[derive(Default, Debug)]\n/// A tensor that is shared between multiple streams.\nstruct SharedTensor {\n    streams: HashMap<StreamId, SharedTensorState>,\n}\n\n#[derive(Debug)]\nstruct SharedTensorState {\n    cursor_current: u64,\n    cursor_origin: u64,\n}\n\n#[derive(Debug)]\n/// What do to when a tensor is dropped.\npub enum SharedTensorDropAction {\n    /// Performs the drop and removes the shared tensor from the provided list of\n    /// stream ids.\n    ForceDrop(Vec<StreamId>),\n    /// Skip the drop.\n    Skip,\n}\n\n#[derive(Debug)]\n/// Information about a shared tensor.\npub enum SharedTensorAnalysis {\n    /// The tensor is not shared.\n    NotShared,\n    /// The tensor is shared, but its original stream is the current one.\n    SharedFromCurrentStream,\n    /// The tensor is shared, and its original stream is an existing stream.\n    SharedFromExistingStream {\n        /// The stream id of the existing stream.\n        stream_id: StreamId,\n        /// The position of execution in the existing stream where the tensor was created.\n        original_cursor: u64,\n    },\n    /// The tensor is shared, and its original stream is a new one without any operation\n    /// executed.\n    SharedFromNewStream {\n        /// The stream id of the new stream.\n        stream_id: StreamId,\n    },\n}\n\nimpl SharedTensors {\n    /// Function to call when a drop operation is registered on the given stream and tensor.\n    pub fn on_drop(\n        &mut self,\n        stream_id: StreamId,\n        tensor_id: TensorId,\n        stream_completed: bool,\n    ) -> SharedTensorDropAction {\n        let mut execute_still = false;\n\n        if let Some(shared) = self.shared_tensors.get_mut(&tensor_id) {\n            if stream_completed {\n                shared.drop(stream_id);\n                execute_still = shared.streams.is_empty();\n            }\n        } else {\n            execute_still = true;\n        }\n\n        if execute_still {\n            let state = self.shared_tensors.remove(&tensor_id);\n            self.shared_tensors_manual_drop.remove(&tensor_id);\n\n            return match state {\n                Some(val) => {\n                    let streams = val.streams.keys().copied().collect();\n                    SharedTensorDropAction::ForceDrop(streams)\n                }\n                None => SharedTensorDropAction::ForceDrop(Vec::new()),\n            };\n        }\n\n        SharedTensorDropAction::Skip\n    }\n\n    /// Function to call when one or many operations were executed on the stream.\n    ///\n    /// Returns the tensor id that can be cleared with [Self::clear_tensors]\n    pub fn on_executed_ops<R: FusionRuntime>(\n        &mut self,\n        id: StreamId,\n        stream: &mut Stream<R>,\n    ) -> Vec<TensorId> {\n        let mut cleared = Vec::new();\n        for (tensor_id, state) in self.shared_tensors.iter_mut() {\n            match state.update(id, stream) {\n                SharedTensorUpdate::RemovedFromStream(no_more_stream) => {\n                    stream.queue.variables.remove(tensor_id);\n\n                    if no_more_stream {\n                        cleared.push(*tensor_id);\n                    }\n                }\n                SharedTensorUpdate::ReadyForCleanup => {\n                    cleared.push(*tensor_id);\n                }\n                SharedTensorUpdate::NoChange => {}\n            }\n        }\n        cleared\n    }\n\n    /// Clear the provided tensors and returns the list of tensors that can be manually dropped.\n    pub fn clear_tensors(&mut self, tensors: Vec<TensorId>) -> Vec<TensorIr> {\n        let mut to_drop = Vec::new();\n        for id in tensors {\n            self.shared_tensors.remove(&id);\n\n            if let Some(tensor) = self.shared_tensors_manual_drop.remove(&id) {\n                to_drop.push(tensor);\n            }\n        }\n\n        self.register_manual_drop(to_drop)\n    }\n\n    pub fn streams_of(&mut self, tensor: &TensorId) -> Vec<StreamId> {\n        let mut streams = Vec::new();\n\n        if let Some(value) = self.shared_tensors.get(tensor) {\n            for s in value.streams.keys() {\n                streams.push(*s);\n            }\n        }\n\n        streams\n    }\n\n    /// Analyses the current tensor and updates its state.\n    pub fn analyse<R: FusionRuntime>(\n        &mut self,\n        id: StreamId,\n        node: &TensorIr,\n        streams_op: &OperationStreams,\n        streams: &HashMap<StreamId, Stream<R>>,\n    ) -> SharedTensorAnalysis {\n        let stream_id = match streams_op.streams.get(&node.id) {\n            Some(val) => val,\n            None => {\n                return match self.shared_tensors.contains_key(&node.id) {\n                    true => SharedTensorAnalysis::SharedFromCurrentStream,\n                    false => SharedTensorAnalysis::NotShared,\n                };\n            }\n        };\n\n        if stream_id == &id {\n            return match self.shared_tensors.contains_key(&node.id) {\n                true => SharedTensorAnalysis::SharedFromCurrentStream,\n                false => SharedTensorAnalysis::NotShared,\n            };\n        }\n\n        // Here the node is tagged as newly shared.\n        let stream_current = streams.get(&id);\n        let stream = streams.get(stream_id);\n\n        let state = match self.shared_tensors.get_mut(&node.id) {\n            Some(state) => state,\n            None => {\n                self.shared_tensors.insert(node.id, SharedTensor::default());\n                self.shared_tensors.get_mut(&node.id).unwrap()\n            }\n        };\n\n        state.register_new_stream(id, stream_current);\n        match state.register_new_stream(*stream_id, stream) {\n            Some(origin) => SharedTensorAnalysis::SharedFromExistingStream {\n                stream_id: *stream_id,\n                original_cursor: origin,\n            },\n            None => SharedTensorAnalysis::SharedFromNewStream {\n                stream_id: *stream_id,\n            },\n        }\n    }\n\n    /// Tag the provided tensors as manually dropped.\n    pub fn tag_manual_drop(&mut self, dropped: Vec<TensorIr>) {\n        for tensor in dropped {\n            self.shared_tensors_manual_drop.insert(tensor.id, tensor);\n        }\n    }\n\n    fn register_manual_drop(&mut self, mut tensors: Vec<TensorIr>) -> Vec<TensorIr> {\n        if self.shared_tensors_manual_drop.is_empty() {\n            return tensors;\n        }\n\n        let mut to_drop = Vec::new();\n        for id in self.shared_tensors_manual_drop.keys() {\n            if !self.shared_tensors.contains_key(id) {\n                to_drop.push(*id);\n            }\n        }\n\n        for id in to_drop {\n            let entry = self.shared_tensors_manual_drop.remove(&id).unwrap();\n            tensors.push(entry);\n        }\n\n        tensors\n    }\n}\n\n/// The result from a [SharedTensor::update].\npub enum SharedTensorUpdate {\n    /// The tensor is removed from the current stream.\n    ///\n    /// Also contains if the current stream is empty.\n    RemovedFromStream(bool),\n    /// If the tensor is shared across zero streams.\n    ReadyForCleanup,\n    /// If nothing has been done from the update.\n    NoChange,\n}\n\nimpl SharedTensor {\n    /// Register the tensor as also part of the given stream.\n    ///\n    /// The stream might not exist yet when the current tensor is part of the first operation in\n    /// the newly created stream.\n    fn register_new_stream<R: FusionRuntime>(\n        &mut self,\n        id: StreamId,\n        stream: Option<&Stream<R>>,\n    ) -> Option<u64> {\n        let cursor_current = match stream {\n            Some(stream) => stream.cursor + stream.queue.global.len() as u64,\n            None => 1,\n        };\n\n        match self.streams.get_mut(&id) {\n            Some(s) => {\n                s.cursor_current = cursor_current;\n                Some(s.cursor_origin)\n            }\n            None => {\n                let state = SharedTensorState {\n                    cursor_current,\n                    cursor_origin: cursor_current,\n                };\n                self.streams.insert(id, state);\n                None\n            }\n        }\n    }\n\n    /// Update the current shared tensor state on the given stream.\n    ///\n    /// If the shared tensor is no longer needed on the stream, we will remove it from the list of\n    /// shared streams.\n    fn update<R: FusionRuntime>(&mut self, id: StreamId, stream: &Stream<R>) -> SharedTensorUpdate {\n        let entry = match self.streams.remove(&id) {\n            Some(val) => val,\n            None => {\n                return if self.streams.is_empty() {\n                    SharedTensorUpdate::ReadyForCleanup\n                } else {\n                    SharedTensorUpdate::NoChange\n                };\n            }\n        };\n\n        // We can only free the shared tensor if the latest cursor is executed.\n        if entry.cursor_current <= stream.cursor {\n            SharedTensorUpdate::RemovedFromStream(self.streams.is_empty())\n        } else {\n            self.streams.insert(id, entry);\n            SharedTensorUpdate::NoChange\n        }\n    }\n\n    fn drop(&mut self, id: StreamId) {\n        self.streams.remove(&id);\n    }\n}\n\nimpl core::fmt::Debug for SharedTensors {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_str(\"\\n==== Shared Tensors ====\\n\")?;\n\n        for sh in self.shared_tensors.iter() {\n            f.write_fmt(format_args!(\"  - Shared {}\", sh.0))?;\n            for (id, state) in sh.1.streams.iter() {\n                f.write_fmt(format_args!(\n                    \" [{}, cursor={}..{}] \",\n                    id, state.cursor_origin, state.cursor_current\n                ))?;\n            }\n            f.write_str(\"\\n\")?;\n        }\n        for sh in self.shared_tensors_manual_drop.iter() {\n            f.write_fmt(format_args!(\"  - Manual Drop {}\", sh.0))?;\n            f.write_str(\"\\n\")?;\n        }\n\n        f.write_str(\"========================\\n\")\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/store/base.rs",
    "content": "use std::sync::Arc;\n\nuse crate::search::BlockOptimization;\n\nuse super::{ExecutionPlanIndex, InsertQuery, SearchQuery};\nuse burn_ir::OperationIr;\nuse serde::{Deserialize, Serialize};\n\n/// The store that contains all explorations done on a device.\n#[derive(Default)]\npub(crate) struct ExecutionPlanStore<O> {\n    plans: Vec<ExecutionPlan<O>>,\n    index: ExecutionPlanIndex,\n}\n\n/// How a list of operations should be executed.\n#[derive(PartialEq, Debug, Clone)]\npub(crate) enum ExecutionStrategy<O> {\n    /// An optimization was found, and therefore should be executed.\n    Optimization { opt: O, ordering: Arc<Vec<usize>> },\n    /// No optimization was found, each operation should be executed individually.\n    Operations { ordering: Arc<Vec<usize>> },\n    /// A composition of multiple execution strategies.\n    Composed(Vec<Box<Self>>),\n}\n\n/// The trigger that indicates when to stop exploring.\n#[derive(Debug, PartialEq, Serialize, Deserialize)]\npub(crate) enum ExecutionTrigger {\n    OnOperations(Vec<OperationIr>),\n    OnSync,\n    Always,\n}\n\n/// The unique identifier for an exploration that was executed.\npub(crate) type ExecutionPlanId = usize;\n\n/// The outcome of an exploration that can be stored.\n#[derive(Debug)]\npub(crate) struct ExecutionPlan<O> {\n    /// The operations on which the exploration is related to.\n    pub(crate) operations: Vec<OperationIr>,\n    /// The criteria that signal when this plan should be executed. Only one trigger is necessary.\n    pub(crate) triggers: Vec<ExecutionTrigger>,\n    /// The optimization that should be used when executing this plan.\n    pub(crate) optimization: BlockOptimization<O>,\n}\n\nimpl<O: core::fmt::Debug> ExecutionPlanStore<O> {\n    pub fn new() -> Self {\n        Self {\n            plans: Vec::new(),\n            index: ExecutionPlanIndex::default(),\n        }\n    }\n\n    pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {\n        self.index.find(query)\n    }\n\n    pub fn add(&mut self, exploration: ExecutionPlan<O>) -> ExecutionPlanId {\n        if exploration.operations.is_empty() {\n            panic!(\"Can't add an empty optimization.\");\n        }\n\n        let id = self.plans.len();\n\n        self.index.insert(InsertQuery::NewPlan {\n            operations: &exploration.operations,\n            id,\n        });\n\n        self.plans.push(exploration);\n\n        id\n    }\n\n    pub fn get_mut_unchecked(&mut self, id: ExecutionPlanId) -> &mut ExecutionPlan<O> {\n        &mut self.plans[id]\n    }\n\n    pub fn get_unchecked(&self, id: ExecutionPlanId) -> &ExecutionPlan<O> {\n        &self.plans[id]\n    }\n\n    /// Add a new end condition for an optimization.\n    pub fn add_trigger(&mut self, id: ExecutionPlanId, trigger: ExecutionTrigger) {\n        let criteria = &mut self.plans[id].triggers;\n\n        if !criteria.contains(&trigger) {\n            criteria.push(trigger);\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/store/index.rs",
    "content": "use crate::stream::store::ExecutionPlanId;\nuse burn_ir::OperationIr;\nuse serde::{Deserialize, Serialize};\nuse std::{\n    collections::{HashMap, hash_map::DefaultHasher},\n    hash::{Hash, Hasher},\n};\n\n/// Index used to search optimizations.\n#[derive(Default, Serialize, Deserialize, Clone)]\npub struct ExecutionPlanIndex {\n    /// We can't use `HashMap<OperationIr, Vec<ExecutionPlanId>>` since `OperationIr`\n    /// doesn't implement [`Eq`](core::cmp::Eq).\n    ///\n    /// `OperationIr` can't implement `Eq` since float types don't implement it.\n    ///\n    /// We rely instead on [`PartialEq`](core::cmp::PartialEq) to manually handle hash collisions.\n    /// This is OK because we use `relative` operations where any scalar values are set to zeros,\n    /// see [`RelativeStreamConverter`](crate::stream::RelativeStreamConverter).\n    ///\n    /// Map from the hash of the `OperationIr` to a list of `(OperationIr, index)` pairs,\n    /// where `index` is the index of all the execution plans that start with the `OperationIr`\n    /// in the `starters` list.\n    mapping: HashMap<u64, Vec<(OperationIr, usize)>>,\n    starters: Vec<Vec<ExecutionPlanId>>,\n}\n\npub enum SearchQuery<'a> {\n    PlansStartingWith(&'a OperationIr),\n}\n\npub enum InsertQuery<'a> {\n    NewPlan {\n        operations: &'a [OperationIr],\n        id: ExecutionPlanId,\n    },\n}\n\nimpl ExecutionPlanIndex {\n    /// Search optimizations with the given [query](SearchQuery).\n    pub fn find(&self, query: SearchQuery<'_>) -> Vec<ExecutionPlanId> {\n        match query {\n            SearchQuery::PlansStartingWith(ops) => self.find_starting_with(ops),\n        }\n    }\n\n    /// Register a new optimization with the given [query](InsertQuery).\n    pub fn insert(&mut self, query: InsertQuery<'_>) {\n        match query {\n            InsertQuery::NewPlan { operations, id } => {\n                if let Some(operation) = operations.first() {\n                    self.insert_new_operation(operation, id)\n                }\n            }\n        }\n    }\n\n    /// Find execution plans starting with the `OperationIr`\n    fn find_starting_with(&self, operation: &OperationIr) -> Vec<ExecutionPlanId> {\n        let key = self.operation_key(operation);\n        let values = match self.mapping.get(&key) {\n            Some(val) => val,\n            None => return Vec::new(),\n        };\n\n        if values.is_empty() {\n            return Vec::new();\n        }\n\n        let (_, index) = match values.iter().find(|value| &value.0 == operation) {\n            Some(val) => val,\n            None => return Vec::new(),\n        };\n\n        match self.starters.get(*index) {\n            Some(value) => value.clone(),\n            None => Vec::new(),\n        }\n    }\n\n    /// Update the index for an execution plan starting with operation `ops`\n    fn insert_new_operation(&mut self, ops: &OperationIr, new_id: ExecutionPlanId) {\n        let key = self.operation_key(ops);\n        let values = match self.mapping.get_mut(&key) {\n            Some(val) => val,\n            None => {\n                // New starter ops.\n                let index = self.starters.len();\n                self.starters.push(vec![new_id]);\n                self.mapping.insert(key, vec![(ops.clone(), index)]);\n\n                return;\n            }\n        };\n        let (_, index) = match values.iter_mut().find(|value| &value.0 == ops) {\n            Some(val) => val,\n            None => {\n                // New with hash collision.\n                let index = self.starters.len();\n                self.starters.push(vec![new_id]);\n                values.push((ops.clone(), index));\n                return;\n            }\n        };\n\n        // New optimization for an existing starter.\n        self.starters\n            .get_mut(*index)\n            .expect(\"Should exist\")\n            .push(new_id);\n    }\n\n    // Hash the value of the first operation in a list.\n    fn operation_key(&self, ops: &OperationIr) -> u64 {\n        let mut hasher = DefaultHasher::new();\n        ops.hash(&mut hasher);\n        hasher.finish()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_backend::{DType, Shape};\n    use burn_ir::{\n        BinaryOpIr, NumericOperationIr, ScalarIr, ScalarOpIr, TensorId, TensorIr, TensorStatus,\n    };\n\n    use super::*;\n\n    #[test]\n    fn should_find_optimization_id_based_on_tensor_ops() {\n        let mut index = ExecutionPlanIndex::default();\n        let stream_1 = [ops_1()];\n        let optimization_id_1 = 0;\n\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_1,\n            id: optimization_id_1,\n        });\n\n        let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));\n\n        assert_eq!(found, vec![optimization_id_1]);\n    }\n\n    #[test]\n    fn should_support_multiple_optimization_ids_with_same_starting_ops() {\n        let mut index = ExecutionPlanIndex::default();\n        let stream_1 = [ops_1(), ops_2(), ops_1()];\n        let stream_2 = [ops_1(), ops_1(), ops_2()];\n        let optimization_id_1 = 0;\n        let optimization_id_2 = 1;\n\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_1,\n            id: optimization_id_1,\n        });\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_2,\n            id: optimization_id_2,\n        });\n\n        let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));\n\n        assert_eq!(found, vec![optimization_id_1, optimization_id_2]);\n    }\n\n    #[test]\n    fn should_only_find_optimization_with_correct_starting_ops() {\n        let mut index = ExecutionPlanIndex::default();\n        let stream_1 = [ops_1(), ops_1()];\n        let stream_2 = [ops_2(), ops_1()];\n        let optimization_id_1 = 0;\n        let optimization_id_2 = 1;\n\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_1,\n            id: optimization_id_1,\n        });\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_2,\n            id: optimization_id_2,\n        });\n\n        let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));\n\n        assert_eq!(found, vec![optimization_id_1]);\n    }\n\n    #[test]\n    fn should_handle_hash_collisions() {\n        let mut index = ExecutionPlanIndex::default();\n        let stream_1 = [ops_1(), ops_1()];\n        let stream_2 = [ops_3(), ops_1()];\n        let optimization_id_1 = 0;\n        let optimization_id_2 = 1;\n\n        let stream_1_key = index.operation_key(&stream_1[0]);\n        let stream_2_key = index.operation_key(&stream_2[0]);\n\n        assert_ne!(\n            stream_1_key, stream_2_key,\n            \"Ops 1 and Ops 3 should not have the same hash\"\n        ); // ops 1 and 3 have different variants, so the hash differs\n        assert_ne!(stream_1[0], stream_2[0], \"Ops 1 and Ops 3 are different.\");\n\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_1,\n            id: optimization_id_1,\n        });\n        index.insert(InsertQuery::NewPlan {\n            operations: &stream_2,\n            id: optimization_id_2,\n        });\n\n        let found = index.find(SearchQuery::PlansStartingWith(&stream_1[0]));\n\n        assert_eq!(found, vec![optimization_id_1]);\n    }\n\n    fn ops_1() -> OperationIr {\n        OperationIr::NumericFloat(\n            DType::F32,\n            NumericOperationIr::Add(BinaryOpIr {\n                lhs: TensorIr {\n                    id: TensorId::new(0),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::ReadOnly,\n                    dtype: DType::F32,\n                },\n                rhs: TensorIr {\n                    id: TensorId::new(1),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::ReadOnly,\n                    dtype: DType::F32,\n                },\n                out: TensorIr {\n                    id: TensorId::new(2),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::NotInit,\n                    dtype: DType::F32,\n                },\n            }),\n        )\n    }\n\n    fn ops_2() -> OperationIr {\n        OperationIr::NumericFloat(\n            DType::F32,\n            NumericOperationIr::AddScalar(ScalarOpIr {\n                lhs: TensorIr {\n                    id: TensorId::new(0),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::ReadOnly,\n                    dtype: DType::F32,\n                },\n                rhs: ScalarIr::Float(5.0),\n                out: TensorIr {\n                    id: TensorId::new(2),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::NotInit,\n                    dtype: DType::F32,\n                },\n            }),\n        )\n    }\n\n    fn ops_3() -> OperationIr {\n        OperationIr::NumericFloat(\n            DType::F32,\n            NumericOperationIr::Sub(BinaryOpIr {\n                lhs: TensorIr {\n                    id: TensorId::new(0),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::ReadOnly,\n                    dtype: DType::F32,\n                },\n                rhs: TensorIr {\n                    id: TensorId::new(1),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::ReadOnly,\n                    dtype: DType::F32,\n                },\n                out: TensorIr {\n                    id: TensorId::new(2),\n                    shape: Shape::new([32, 32]),\n                    status: TensorStatus::NotInit,\n                    dtype: DType::F32,\n                },\n            }),\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-fusion/src/stream/store/mod.rs",
    "content": "mod base;\nmod index;\n\npub(crate) use base::*;\npub(super) use index::*;\n"
  },
  {
    "path": "crates/burn-fusion/src/tensor.rs",
    "content": "use crate::{\n    Client, FusionBackend, FusionRuntime,\n    stream::{Operation, OperationStreams, StreamId},\n};\nuse burn_backend::{\n    DType, ExecutionError, QTensorPrimitive, Shape, TensorData, TensorMetadata,\n    quantization::QuantScheme,\n};\nuse burn_ir::{OperationIr, TensorId, TensorIr, TensorStatus};\nuse std::sync::{\n    Arc,\n    atomic::{AtomicU32, Ordering},\n};\n\n/// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind.\npub struct FusionTensor<R: FusionRuntime> {\n    /// Tensor id.\n    pub id: TensorId,\n    /// The shape of the tensor.\n    pub shape: Shape,\n    /// The fusion client.\n    pub client: Client<R>,\n    /// The datatype of the tensor.\n    pub dtype: DType,\n    /// The current stream id this tensor is on.\n    pub stream: StreamId,\n    pub(crate) count: Arc<AtomicU32>,\n}\n\nimpl<R: FusionRuntime> Clone for FusionTensor<R> {\n    fn clone(&self) -> Self {\n        self.count.fetch_add(1, Ordering::Acquire);\n\n        Self {\n            id: self.id,\n            shape: self.shape.clone(),\n            client: self.client.clone(),\n            dtype: self.dtype,\n            stream: self.stream,\n            count: self.count.clone(),\n        }\n    }\n}\n\nimpl<R: FusionRuntime> core::fmt::Debug for FusionTensor<R> {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        f.write_str(\n            format!(\n                \"{{ id: {:?}, shape: {:?}, device: {:?} }}\",\n                self.id,\n                self.shape,\n                self.client.device().clone(),\n            )\n            .as_str(),\n        )\n    }\n}\n\nimpl<R: FusionRuntime> TensorMetadata for FusionTensor<R> {\n    fn dtype(&self) -> DType {\n        self.dtype\n    }\n\n    fn shape(&self) -> Shape {\n        self.shape.clone()\n    }\n\n    fn rank(&self) -> usize {\n        self.shape.num_dims()\n    }\n}\n\nimpl<R: FusionRuntime> FusionTensor<R> {\n    pub(crate) fn new(\n        id: TensorId,\n        shape: Shape,\n        dtype: DType,\n        client: Client<R>,\n        stream: StreamId,\n    ) -> Self {\n        Self {\n            id,\n            shape,\n            client,\n            dtype,\n            stream,\n            count: Arc::new(AtomicU32::new(1)),\n        }\n    }\n\n    fn status(&self, count: u32) -> TensorStatus {\n        if count <= 1 {\n            TensorStatus::ReadWrite\n        } else {\n            TensorStatus::ReadOnly\n        }\n    }\n\n    /// Intermediate representation to be used when using an uninitialized tensor as output.\n    pub fn to_ir_out(&self) -> TensorIr {\n        TensorIr {\n            status: TensorStatus::NotInit,\n            shape: self.shape.clone(),\n            id: self.id,\n            dtype: self.dtype,\n        }\n    }\n\n    /// Intermediate representation to be used when using an initialized tensor used as input.\n    pub fn into_ir(mut self) -> TensorIr {\n        let count = self.count.load(Ordering::Acquire);\n        let status = self.status(count);\n\n        let mut shape_out = Shape::from(Vec::<usize>::new());\n        core::mem::swap(&mut self.shape, &mut shape_out);\n\n        if let TensorStatus::ReadWrite = status {\n            // Avoids an unwanted drop on the same thread.\n            //\n            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor\n            // was consumed with a `ReadWrite` status.\n            self.count.fetch_add(1, Ordering::Acquire);\n        }\n\n        TensorIr {\n            status,\n            shape: shape_out,\n            id: self.id,\n            dtype: self.dtype,\n        }\n    }\n\n    pub(crate) async fn into_data<B>(self) -> Result<TensorData, ExecutionError>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let id = self.stream;\n        let client = self.client.clone();\n        let desc = self.into_ir();\n        client.read_tensor_float::<B>(desc, id).await\n    }\n\n    pub(crate) async fn q_into_data<B>(self) -> Result<TensorData, ExecutionError>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        if let DType::QFloat(_scheme) = self.dtype {\n            let id = self.stream;\n            let client = self.client.clone();\n            let desc = self.into_ir();\n            client.read_tensor_quantized::<B>(desc, id).await\n        } else {\n            panic!(\"Expected quantized float dtype, got {:?}\", self.dtype)\n        }\n    }\n\n    pub(crate) async fn int_into_data<B>(self) -> Result<TensorData, ExecutionError>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let id = self.stream;\n        let client = self.client.clone();\n        let desc = self.into_ir();\n        client.read_tensor_int::<B>(desc, id).await\n    }\n\n    pub(crate) async fn bool_into_data<B>(self) -> Result<TensorData, ExecutionError>\n    where\n        B: FusionBackend<FusionRuntime = R>,\n    {\n        let id = self.stream;\n        let client = self.client.clone();\n        let desc = self.into_ir();\n        client.read_tensor_bool::<B>(desc, id).await\n    }\n}\n\n#[derive(new, Debug)]\npub(crate) struct DropOp {\n    pub(crate) id: TensorId,\n}\n\nimpl<RO: FusionRuntime> Operation<RO> for DropOp {\n    fn execute(&self, handles: &mut burn_ir::HandleContainer<RO::FusionHandle>) {\n        handles.remove_handle(self.id);\n    }\n}\n\nimpl<R: FusionRuntime> Drop for FusionTensor<R> {\n    fn drop(&mut self) {\n        let count = self.count.fetch_sub(1, Ordering::Acquire);\n\n        // Workaround to prevent segfaults when an operation panics\n        if std::thread::panicking() {\n            return;\n        }\n\n        match self.status(count) {\n            TensorStatus::ReadWrite => {\n                let mut shape = Shape::from(Vec::<usize>::new());\n                core::mem::swap(&mut shape, &mut self.shape);\n\n                let ir = TensorIr {\n                    id: self.id,\n                    shape,\n                    status: TensorStatus::ReadWrite,\n                    dtype: self.dtype,\n                };\n                let mut streams = OperationStreams::default();\n                streams.tensor(self);\n\n                self.client\n                    .register(streams, OperationIr::Drop(ir), DropOp { id: self.id });\n            }\n            TensorStatus::ReadOnly => {}\n            TensorStatus::NotInit => {}\n        }\n    }\n}\n\nimpl<R: FusionRuntime> QTensorPrimitive for FusionTensor<R> {\n    fn scheme(&self) -> &QuantScheme {\n        if let DType::QFloat(scheme) = &self.dtype {\n            scheme\n        } else {\n            panic!(\n                \"Quantization scheme is not valid for dtype {:?}\",\n                self.dtype,\n            )\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/Cargo.toml",
    "content": "[package]\nauthors = [\"laggui <lagrange.guillaume.1@gmail.com>\", \"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Intermediate representation for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\"]\nlicense.workspace = true\nname = \"burn-ir\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-ir\"\ndocumentation = \"https://docs.rs/burn-ir\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\nstd = [\"burn-backend/std\"]\ndoc = [\"default\"]\ntracing = [\n    \"burn-backend/tracing\",\n]\n\n[dependencies]\nserde = { workspace = true }\nhashbrown = { workspace = true } # no_std compatible\n\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-ir/README.md",
    "content": "# Burn Intermediate Representation\n\nDefines an Intermediate Representation (IR) used to represent tensors and operations.\n\nThe abstraction over computation allows execution across different targets (e.g., remote backend).\nIt also enables optimization and transformation of tensor computations before execution (e.g.,\noperator fusion).\n"
  },
  {
    "path": "crates/burn-ir/src/backend.rs",
    "content": "use burn_backend::{\n    Backend, Shape,\n    tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},\n};\n\n/// A tensor representation containing a reference to a tensor resource with a given shape.\n#[derive(Clone)]\npub struct TensorHandle<H: Clone> {\n    /// The type that can be used to point to a tensor of any kind.\n    pub handle: H,\n    /// The shape associated to the tensor.\n    pub shape: Shape,\n}\n\n/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor\n/// intermediate representation for compilation purpose or other...\npub trait BackendIr: Backend {\n    /// The type that can be used to point to a tensor of any kind.\n    type Handle: Sync + Send + Clone;\n\n    /// Convert a [handle](BackendIr::Handle) to a [float tensor](Backend::FloatTensorPrimitive).\n    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self>;\n    /// Convert a [handle](BackendIr::Handle) to an [int tensor](Backend::IntTensorPrimitive).\n    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self>;\n    /// Convert a [handle](BackendIr::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).\n    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;\n    /// Convert a [handle](BackendIr::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).\n    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self>;\n\n    /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](BackendIr::Handle).\n    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle;\n    /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](BackendIr::Handle).\n    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle;\n    /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](BackendIr::Handle).\n    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;\n    /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](BackendIr::Handle).\n    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle;\n}\n\n/// Handle which points to a backend tensor primitive kind.\n#[derive(Clone, Debug)]\npub enum HandleKind<B: Backend> {\n    /// Float tensor handle.\n    Float(B::FloatTensorPrimitive),\n    /// Int tensor handle.\n    Int(B::IntTensorPrimitive),\n    /// Bool tensor handle.\n    Bool(B::BoolTensorPrimitive),\n    /// Quantized tensor handle.\n    Quantized(B::QuantizedTensorPrimitive),\n}\n\nimpl<B: Backend> HandleKind<B> {\n    /// Returns the handle kind name.\n    pub fn name(&self) -> &str {\n        match self {\n            HandleKind::Float(_) => \"float\",\n            HandleKind::Int(_) => \"int\",\n            HandleKind::Bool(_) => \"bool\",\n            HandleKind::Quantized(_) => \"quantized\",\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/src/builder.rs",
    "content": "#![allow(missing_docs)]\n\nuse alloc::vec::Vec;\nuse burn_backend::{\n    DType, Distribution, Shape, Slice, SliceOps, calculate_matmul_output,\n    ops::{\n        conv::{\n            calculate_conv_output_shape, calculate_conv_transpose_output_shape,\n            calculate_pool_output_shape,\n        },\n        unfold::calculate_unfold_shape,\n    },\n    quantization::QuantScheme,\n    tensor::IndexingUpdateOp,\n};\n\nuse crate::{ScalarIr, TensorId, TensorIr};\n\nuse super::operation::*;\n\nimpl CreationOpIr {\n    pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {\n        let out = TensorIr::uninit(new_id(), shape, dtype);\n\n        CreationOpIr { out }\n    }\n}\n\nimpl InitOperationIr {\n    pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {\n        let out = TensorIr::uninit(new_id(), shape, dtype);\n\n        InitOperationIr { out }\n    }\n}\n\nimpl RandomOpIr {\n    pub fn create(\n        shape: Shape,\n        dtype: DType,\n        distribution: Distribution,\n        new_id: impl FnOnce() -> TensorId,\n    ) -> Self {\n        let out = TensorIr::uninit(new_id(), shape, dtype);\n\n        RandomOpIr { out, distribution }\n    }\n}\n\nimpl FullOpIr {\n    pub fn create(\n        shape: Shape,\n        dtype: DType,\n        value: ScalarIr,\n        new_id: impl FnOnce() -> TensorId,\n    ) -> Self {\n        // TODO: check that ScalarIr dtype matches dtype?\n        let out = TensorIr::uninit(new_id(), shape, dtype);\n\n        FullOpIr { out, value }\n    }\n}\n\nimpl CastOpIr {\n    pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {\n        let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);\n        CastOpIr { input, out }\n    }\n}\n\nimpl ShapeOpIr {\n    pub fn expand(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {\n        let shape = input.shape.expand(shape).unwrap();\n        Self::create(input, shape, new_id)\n    }\n\n    pub fn reshape(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {\n        let shape = input.shape.reshape(shape).unwrap();\n        Self::create(input, shape, new_id)\n    }\n\n    fn create(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {\n        let out = TensorIr::uninit(new_id(), shape, input.dtype);\n        ShapeOpIr { input, out }\n    }\n}\n\n// \"Lower\" specific operations into a binary or unary op representation.\n// Useful when collecting inputs and outputs and don't care about the other semantics.\nimpl From<MatmulOpIr> for BinaryOpIr {\n    fn from(value: MatmulOpIr) -> Self {\n        Self {\n            lhs: value.lhs,\n            rhs: value.rhs,\n            out: value.out,\n        }\n    }\n}\n\nimpl From<ReduceOpIr> for UnaryOpIr {\n    fn from(value: ReduceOpIr) -> Self {\n        Self {\n            input: value.input,\n            out: value.out,\n        }\n    }\n}\n\n#[derive(Debug)]\n#[allow(missing_docs)]\npub enum IrError {\n    DTypeMismatch,\n}\n\nfn dtype_compat(lhs: &DType, rhs: &DType) -> bool {\n    let lhs_qfloat = matches!(lhs, DType::QFloat(_));\n    let rhs_qfloat = matches!(rhs, DType::QFloat(_));\n    if lhs_qfloat && (rhs_qfloat || rhs.is_float())\n        || lhs.is_float() && (rhs_qfloat || rhs.is_float())\n    {\n        true\n    } else {\n        lhs == rhs\n    }\n}\n\nfn output_check<'a, I>(inputs: I, compat: impl Fn(&DType, &DType) -> bool) -> Result<DType, IrError>\nwhere\n    I: IntoIterator<Item = &'a DType>,\n{\n    let mut iter = inputs.into_iter();\n    let first = iter.next().unwrap();\n    for d in iter {\n        if !compat(first, d) {\n            return Err(IrError::DTypeMismatch);\n        }\n    }\n    Ok(*first)\n}\n\nfn output_dtype<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {\n    output_check(inputs, |a, b| a == b)\n}\n\nfn output_dtype_mixed<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {\n    output_check(inputs, dtype_compat)\n}\n\n/// Macro to implement `create` constructors for operations with a single output.\n///\n/// Supports shape and dtype validation.\nmacro_rules! impl_ir_create {\n    (@create_fn $op:ident { $( $field:ident : $ty:ty ),* $(,)? } , $shape:expr, $dtype:expr) => {\n        #[doc = \"Create a new operation IR from the given inputs.\"]\n        #[doc = \"`new_id` should generate a unique `TensorId` for the uninitialized output tensor.\"]\n        #[allow(clippy::too_many_arguments)]\n        pub fn create($( $field : $ty ),*, new_id: impl FnOnce() -> crate::TensorId) -> $op {\n            let shape = $shape;\n            let dtype = $dtype;\n            let out = TensorIr::uninit(new_id(), shape, dtype);\n            $op { $( $field ),*, out }\n        }\n    };\n\n    // Case: simple op, single `create`\n    (\n        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },\n        shape = $shape:expr,\n        dtype = $dtype:expr\n    ) => {\n        impl $op {\n            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);\n        }\n    };\n\n    // Case: op with one additional constructor that accepts an explicit output dtype\n    (\n        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },\n        shape = $shape:expr,\n        dtype = $dtype:expr,\n        $fn_name:ident ( $extra:ident : $extra_ty:ty )\n    ) => {\n        impl $op {\n            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);\n\n            #[doc = \"Create a new operation IR from the given inputs and the given output dtype.\"]\n            #[allow(clippy::too_many_arguments)]\n            pub fn $fn_name($( $field : $ty ),*, $extra: $extra_ty, new_id: impl FnOnce() -> crate::TensorId) -> Self {\n                let shape = $shape;\n                let _ = $dtype; // still validates dtype if needed\n                let out = TensorIr::uninit(new_id(), shape, $extra);\n                $op { $( $field ),*, out }\n            }\n        }\n    };\n}\n\nimpl_ir_create!(\n    UnaryOpIr { input: TensorIr },\n    shape = input.shape.clone(),\n    dtype = input.dtype,\n    // Additional constructor for unary comparisons\n    create_comparison(bool_dtype: DType)\n);\n\nimpl_ir_create!(\n    BinaryOpIr {\n        lhs: TensorIr,\n        rhs: TensorIr\n    },\n    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),\n    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap(),\n    // Additional constructor for binary comparisons\n    create_comparison(bool_dtype: DType)\n);\n\nimpl_ir_create!(\n    ScalarOpIr {\n        lhs: TensorIr,\n        rhs: ScalarIr\n    },\n    shape = lhs.shape.clone(),\n    dtype = lhs.dtype,\n    // Additional constructor for scalar comparisons\n    create_comparison(bool_dtype: DType)\n);\n\nimpl_ir_create!(\n    MatmulOpIr {\n        lhs: TensorIr,\n        rhs: TensorIr\n    },\n    shape = calculate_matmul_output(&lhs.shape, &rhs.shape).unwrap(),\n    dtype = output_dtype_mixed([&lhs.dtype, &rhs.dtype]).unwrap(),\n    // Additional constructor for mixed dtypes\n    create_mixed(out_dtype: DType)\n);\n\nimpl_ir_create!(\n    SwapDimsOpIr {\n        input: TensorIr,\n        dim1: usize,\n        dim2: usize\n    },\n    shape = input.shape.clone().swapped(dim1, dim2).unwrap(),\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    PermuteOpIr { input: TensorIr, axes: Vec<usize> },\n    shape = input.shape.clone().permuted(&axes).unwrap(),\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    RepeatDimOpIr {\n        tensor: TensorIr,\n        dim: usize,\n        times: usize\n    },\n    shape = tensor.shape.clone().repeat(dim, times).unwrap(),\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    FlipOpIr { input: TensorIr, axes: Vec<usize> },\n    shape = input.shape.clone(), // TODO: check if axes are within the tensor dimensions\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    CatOpIr { tensors: Vec<TensorIr>, dim: usize },\n    shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap(),\n    dtype = output_dtype(tensors.iter().map(|t| &t.dtype)).unwrap()\n);\n\nimpl_ir_create!(\n    GatherOpIr {\n        tensor: TensorIr,\n        dim: usize,\n        indices: TensorIr\n    },\n    shape = indices.shape.clone(), // TODO: check dims compat between tensor and indices\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    ScatterOpIr {\n        tensor: TensorIr,\n        dim: usize,\n        indices: TensorIr,\n        value: TensorIr,\n        update: IndexingUpdateOp\n    },\n    shape = tensor.shape.clone(), // TODO: check dims compat between tensor and indices\n    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()\n);\n\nimpl_ir_create!(\n    ReduceOpIr { input: TensorIr },\n    shape = [1].into(),\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    ReduceDimOpIr {\n        input: TensorIr,\n        axis: usize\n    },\n    shape = input.shape.clone().reduce(axis).unwrap(),\n    dtype = input.dtype,\n    // Additional constructor for argument reduction\n    create_arg(ind_dtype: DType)\n);\n\nimpl_ir_create!(\n    DimOpIr {\n        input: TensorIr,\n        axis: usize\n    },\n    shape = input.shape.clone(), // TODO: check dims within rank\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    SelectOpIr {\n        tensor: TensorIr,\n        dim: usize,\n        indices: TensorIr\n    },\n    // TODO: shape.select?\n    shape = {\n        let mut s = tensor.shape.clone();\n        s[dim] = indices.shape[0];\n        s\n    },\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    SelectAssignOpIr {\n        tensor: TensorIr,\n        dim: usize,\n        indices: TensorIr,\n        value: TensorIr,\n        update: IndexingUpdateOp\n    },\n    // TODO: check value and indices shape match for dim\n    shape = tensor.shape.clone(),\n    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()\n);\n\nimpl_ir_create!(\n    SliceOpIr {\n        tensor: TensorIr,\n        ranges: Vec<Slice>,\n    },\n    shape = tensor.shape.clone().slice(&ranges).unwrap(),\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    SliceAssignOpIr {\n        tensor: TensorIr,\n        ranges: Vec<Slice>,\n        value: TensorIr\n    },\n    // TODO: check slice and value number of elements match\n    shape = tensor.shape.clone(),\n    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()\n);\n\nimpl_ir_create!(\n    MaskWhereOpIr {\n        tensor: TensorIr,\n        mask: TensorIr,\n        value: TensorIr\n    },\n    shape = Shape::broadcast_many([&tensor.shape, &mask.shape, &value.shape]).unwrap(),\n    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()\n);\n\nimpl_ir_create!(\n    MaskFillOpIr {\n        tensor: TensorIr,\n        mask: TensorIr,\n        value: ScalarIr\n    },\n    shape = tensor.shape.broadcast(&mask.shape).unwrap(),\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    ClampOpIr {\n        tensor: TensorIr,\n        min: ScalarIr,\n        max: ScalarIr\n    },\n    shape = tensor.shape.clone(),\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    AvgPool1dOpIr {\n        x: TensorIr,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool\n    },\n    shape = calculate_pool_output_shape(\n        &x.shape,\n        &[kernel_size],\n        &[stride],\n        &[padding],\n        &[1],\n        ceil_mode\n    )\n    .unwrap(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AvgPool1dBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AvgPool2dOpIr {\n        x: TensorIr,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool\n    },\n    shape = calculate_pool_output_shape(\n        &x.shape,\n        &kernel_size,\n        &stride,\n        &padding,\n        &[1, 1],\n        ceil_mode\n    )\n    .unwrap(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AvgPool2dBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    MaxPool1dOpIr {\n        x: TensorIr,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool\n    },\n    shape = calculate_pool_output_shape(\n        &x.shape,\n        &[kernel_size],\n        &[stride],\n        &[padding],\n        &[dilation],\n        ceil_mode\n    )\n    .unwrap(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    MaxPool2dOpIr {\n        x: TensorIr,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool\n    },\n    shape = calculate_pool_output_shape(\n        &x.shape,\n        &kernel_size,\n        &stride,\n        &padding,\n        &dilation,\n        ceil_mode\n    )\n    .unwrap(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    MaxPool1dWithIndicesBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n        indices: TensorIr,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    MaxPool2dWithIndicesBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n        indices: TensorIr,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AdaptiveAvgPool1dOpIr {\n        x: TensorIr,\n        output_size: usize\n    },\n    shape = Shape::new([x.shape[0], x.shape[1], output_size]),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AdaptiveAvgPool2dOpIr {\n        x: TensorIr,\n        output_size: [usize; 2]\n    },\n    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AdaptiveAvgPool1dBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    AdaptiveAvgPool2dBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    InterpolateOpIr {\n        x: TensorIr,\n        output_size: [usize; 2],\n        options: InterpolateOptionsIr\n    },\n    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    InterpolateBackwardOpIr {\n        x: TensorIr,\n        grad: TensorIr,\n        output_size: [usize; 2],\n        options: InterpolateOptionsIr\n    },\n    shape = x.shape.clone(),\n    dtype = x.dtype\n);\n\nimpl_ir_create!(\n    GridSample2dOpIr {\n        tensor: TensorIr,\n        grid: TensorIr,\n        options: GridSampleOptionsIr\n    },\n    // Input tensor: [N, C, H_in, W_in]\n    // Grid: [N, H_out, W_out, 2]\n    // Output: [N, C, H_out, W_out]\n    shape = Shape::new([\n        tensor.shape[0],\n        tensor.shape[1],\n        grid.shape[1],\n        grid.shape[2]\n    ]),\n    dtype = tensor.dtype\n);\n\nimpl_ir_create!(\n    Conv1dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: Conv1dOptionsIr\n    },\n    shape = calculate_conv_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.dilation,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    Conv1dXBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv1dOptionsIr\n    },\n    shape = x.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv1dWeightBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv1dOptionsIr\n    },\n    shape = weight.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv1dBiasBackwardOpIr {\n        x: TensorIr,\n        bias: TensorIr,\n        output_grad: TensorIr,\n    },\n    shape = bias.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv2dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: Conv2dOptionsIr\n    },\n    shape = calculate_conv_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.dilation,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    Conv2dXBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv2dOptionsIr\n    },\n    shape = x.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv2dWeightBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv2dOptionsIr\n    },\n    shape = weight.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv2dBiasBackwardOpIr {\n        x: TensorIr,\n        bias: TensorIr,\n        output_grad: TensorIr,\n    },\n    shape = bias.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv3dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: Conv3dOptionsIr\n    },\n    shape = calculate_conv_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.dilation,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    Conv3dXBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv3dOptionsIr\n    },\n    shape = x.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv3dWeightBackwardOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        output_grad: TensorIr,\n        options: Conv3dOptionsIr\n    },\n    shape = weight.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    Conv3dBiasBackwardOpIr {\n        x: TensorIr,\n        bias: TensorIr,\n        output_grad: TensorIr,\n    },\n    shape = bias.shape.clone(),\n    dtype = output_grad.dtype\n);\n\nimpl_ir_create!(\n    DeformConv2dOpIr {\n        x: TensorIr,\n        offset: TensorIr,\n        weight: TensorIr,\n        mask: Option<TensorIr>,\n        bias: Option<TensorIr>,\n        options: DeformableConv2dOptionsIr\n    },\n    shape = calculate_conv_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.dilation,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&offset.dtype),\n                Some(&weight.dtype),\n                mask.as_ref().map(|m| &m.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    ConvTranspose1dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: ConvTranspose1dOptionsIr\n    },\n    shape = calculate_conv_transpose_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.padding_out,\n            &options.dilation,\n            options.groups,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    ConvTranspose2dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: ConvTranspose2dOptionsIr\n    },\n    shape = calculate_conv_transpose_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.padding_out,\n            &options.dilation,\n            options.groups,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    ConvTranspose3dOpIr {\n        x: TensorIr,\n        weight: TensorIr,\n        bias: Option<TensorIr>,\n        options: ConvTranspose3dOptionsIr\n    },\n    shape = calculate_conv_transpose_output_shape(\n            &x.shape,\n            &weight.shape,\n            &options.stride,\n            &options.padding,\n            &options.padding_out,\n            &options.dilation,\n            options.groups,\n        )\n        .unwrap(),\n    dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap()\n);\n\nimpl_ir_create!(\n    UnfoldOpIr {\n        input: TensorIr,\n        dim: usize,\n        size: usize,\n        step: usize\n    },\n    shape = calculate_unfold_shape(input.shape.clone(), dim, size, step),\n    dtype = input.dtype\n);\n\nimpl_ir_create!(\n    CrossOpIr {\n        lhs: TensorIr,\n        rhs: TensorIr,\n        dim: usize\n    },\n    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),\n    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap()\n);\n\nimpl_ir_create!(\n    QuantizeOpIr {\n        tensor: TensorIr,\n        qparams: QuantizationParametersIr,\n        scheme: QuantScheme\n    },\n    shape = tensor.shape.clone(),\n    dtype = DType::QFloat(scheme)\n);\n\nimpl_ir_create!(\n    AttentionOpIr {\n        query: TensorIr,\n        key: TensorIr,\n        value: TensorIr,\n        mask: Option<TensorIr>,\n        attn_bias: Option<TensorIr>,\n        options: AttentionOptionsIr,\n    },\n    shape = Shape::new([query.shape[0], query.shape[1], query.shape[2], value.shape[3]]),\n    dtype = query.dtype\n);\n\nimpl DequantizeOpIr {\n    pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {\n        let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);\n\n        DequantizeOpIr { input, out }\n    }\n}\n\n// Operations with multiple outputs\n\nimpl ReduceDimWithIndicesOpIr {\n    pub fn create(\n        tensor: TensorIr,\n        dim: usize,\n        dtype_indices: DType,\n        mut new_id: impl FnMut() -> TensorId,\n    ) -> Self {\n        let mut shape = tensor.shape.clone();\n        shape[dim] = 1;\n        let out = TensorIr::uninit(new_id(), shape.clone(), tensor.dtype);\n        let out_indices = TensorIr::uninit(new_id(), shape.clone(), dtype_indices);\n\n        ReduceDimWithIndicesOpIr {\n            tensor,\n            dim,\n            out,\n            out_indices,\n        }\n    }\n}\n\nimpl DeformConv2dBackwardOpIr {\n    #[allow(clippy::too_many_arguments)]\n    pub fn create(\n        x: TensorIr,\n        offset: TensorIr,\n        weight: TensorIr,\n        mask: Option<TensorIr>,\n        bias: Option<TensorIr>,\n        out_grad: TensorIr,\n        options: DeformableConv2dOptionsIr,\n        mut new_id: impl FnMut() -> TensorId,\n    ) -> Self {\n        let dtype = output_dtype(\n            [\n                Some(&x.dtype),\n                Some(&weight.dtype),\n                mask.as_ref().map(|m| &m.dtype),\n                bias.as_ref().map(|b| &b.dtype),\n            ]\n            .iter()\n            .filter_map(|&d| d),\n        )\n        .unwrap();\n\n        let input_grad = TensorIr::uninit(new_id(), x.shape.clone(), dtype);\n        let offset_grad = TensorIr::uninit(new_id(), offset.shape.clone(), dtype);\n        let weight_grad = TensorIr::uninit(new_id(), weight.shape.clone(), dtype);\n        let mask_grad = mask\n            .as_ref()\n            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));\n        let bias_grad = bias\n            .as_ref()\n            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));\n\n        DeformConv2dBackwardOpIr {\n            x,\n            offset,\n            weight,\n            mask,\n            bias,\n            out_grad,\n            options,\n            input_grad,\n            offset_grad,\n            weight_grad,\n            mask_grad,\n            bias_grad,\n        }\n    }\n}\n\nimpl MaxPool1dWithIndicesOpIr {\n    #[allow(clippy::too_many_arguments)]\n    pub fn create(\n        x: TensorIr,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        dtype_indices: DType,\n        mut new_id: impl FnMut() -> TensorId,\n    ) -> Self {\n        let shape = calculate_pool_output_shape(\n            &x.shape,\n            &[kernel_size],\n            &[stride],\n            &[padding],\n            &[dilation],\n            ceil_mode,\n        )\n        .unwrap();\n        let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);\n        let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);\n\n        MaxPool1dWithIndicesOpIr {\n            x,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            out,\n            out_indices,\n        }\n    }\n}\n\nimpl MaxPool2dWithIndicesOpIr {\n    #[allow(clippy::too_many_arguments)]\n    pub fn create(\n        x: TensorIr,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        dtype_indices: DType,\n        mut new_id: impl FnMut() -> TensorId,\n    ) -> Self {\n        let shape = calculate_pool_output_shape(\n            &x.shape,\n            &kernel_size,\n            &stride,\n            &padding,\n            &dilation,\n            ceil_mode,\n        )\n        .unwrap();\n        let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);\n        let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);\n\n        MaxPool2dWithIndicesOpIr {\n            x,\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            out,\n            out_indices,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/src/handle.rs",
    "content": "use hashbrown::HashMap;\n\nuse crate::{BackendIr, TensorHandle, TensorId, TensorIr, TensorStatus};\n\n/// Keep all [tensor handles](BackendIr::Handle) in one place and ensure that all resources\n/// are used optimally.\n#[derive(Default)]\npub struct HandleContainer<H> {\n    handles: HashMap<TensorId, Handle<H>>,\n    counter: u64,\n}\n\nimpl<H: Clone> HandleContainer<H> {\n    /// Fork the container, useful for autotune.\n    pub fn fork(&self) -> Self {\n        let mut handles = HashMap::with_capacity(self.handles.len());\n\n        for (id, handle) in self.handles.iter() {\n            handles.insert(*id, handle.clone());\n        }\n\n        Self {\n            handles,\n            counter: self.counter,\n        }\n    }\n}\n\nimpl<H> core::fmt::Debug for HandleContainer<H> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.debug_struct(\"HandleContainer\")\n            .field(\"handles\", &self.handles.keys()) // only care about the IDs when debugging\n            .field(\"counter\", &self.counter)\n            .finish()\n    }\n}\n\n/// Backend [tensor handle](BackendIr::Handle) wrapper tracking their creation state\n#[derive(Clone)]\npub enum Handle<H> {\n    /// No [tensor handle](BackendIr::Handle) has been created yet\n    NotInit,\n    /// A [tensor handle](BackendIr::Handle) has been created\n    Existing(H),\n}\n\nimpl<H: Clone> HandleContainer<H> {\n    /// Create a new HandleContainer\n    pub fn new() -> Self {\n        Self {\n            handles: HashMap::new(),\n            counter: 0,\n        }\n    }\n\n    /// Register a handle for the given [tensor id](TensorId).\n    pub fn register_handle(&mut self, id: TensorId, handle: H) {\n        self.handles.insert(id, Handle::Existing(handle));\n    }\n\n    /// Whether an handle exists.\n    pub fn has_handle(&mut self, id: &TensorId) -> bool {\n        self.handles.contains_key(id)\n    }\n\n    /// Get the reference to a handle.\n    pub fn get_handle_ref(&self, id: &TensorId) -> Option<&H> {\n        self.handles\n            .get(id)\n            .filter(|h| !matches!(h, Handle::NotInit))\n            .map(|h| match h {\n                Handle::Existing(handle) => handle,\n                Handle::NotInit => unreachable!(),\n            })\n    }\n\n    /// Get the handle for the given [tensor id](TensorId). The status is used to determine if the\n    /// tensor should be popped out of the current tensor map, necessary for inplace operations.\n    ///\n    /// # Warnings\n    ///\n    /// Make sure the status corresponds to the operation you want to execute the handle on,\n    /// otherwise you might remove a tensor handle that will be required in the future.\n    pub fn get_handle(&mut self, id: &TensorId, status: &TensorStatus) -> H {\n        let (id, handle) = self\n            .handles\n            .remove_entry(id)\n            .unwrap_or_else(|| panic!(\"Should have handle for tensor {id:?}\"));\n\n        match handle {\n            Handle::Existing(handle) => match status {\n                TensorStatus::ReadOnly => {\n                    self.handles.insert(id, Handle::Existing(handle.clone()));\n                    handle\n                }\n                TensorStatus::ReadWrite => handle,\n                TensorStatus::NotInit => panic!(\n                    \"Cannot get uninitialized tensor {id:?}. Tensor exist but with wrong status\"\n                ),\n            },\n            Handle::NotInit => panic!(\"Cannot get uninitialized handle {id:?}.\"),\n        }\n    }\n\n    /// Get the tensor handle for the given [tensor intermediate representation](TensorIr).\n    pub fn get_tensor_handle(&mut self, tensor: &TensorIr) -> TensorHandle<H> {\n        TensorHandle {\n            handle: self.get_handle(&tensor.id, &tensor.status),\n            shape: tensor.shape.clone(),\n        }\n    }\n\n    /// Get the [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) corresponding to the\n    /// given [tensor intermediate representation](TensorIr).\n    pub fn get_float_tensor<B>(&mut self, tensor: &TensorIr) -> B::FloatTensorPrimitive\n    where\n        B: BackendIr<Handle = H>,\n    {\n        B::float_tensor(self.get_tensor_handle(tensor))\n    }\n\n    /// Get the [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) corresponding to the\n    /// given [tensor intermediate representation](TensorIr).\n    pub fn get_int_tensor<B>(&mut self, tensor: &TensorIr) -> B::IntTensorPrimitive\n    where\n        B: BackendIr<Handle = H>,\n    {\n        B::int_tensor(self.get_tensor_handle(tensor))\n    }\n\n    /// Get the [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) corresponding to the\n    /// given [tensor intermediate representation](TensorIr).\n    pub fn get_bool_tensor<B>(&mut self, tensor: &TensorIr) -> B::BoolTensorPrimitive\n    where\n        B: BackendIr<Handle = H>,\n    {\n        B::bool_tensor(self.get_tensor_handle(tensor))\n    }\n\n    /// Get the [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) corresponding to the\n    /// given [tensor intermediate representation](TensorIr).\n    pub fn get_quantized_tensor<B>(&mut self, tensor: &TensorIr) -> B::QuantizedTensorPrimitive\n    where\n        B: BackendIr<Handle = H>,\n    {\n        B::quantized_tensor(self.get_tensor_handle(tensor))\n    }\n\n    /// Register a new [float tensor](burn_backend::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId).\n    pub fn register_float_tensor<B>(&mut self, id: &TensorId, tensor: B::FloatTensorPrimitive)\n    where\n        B: BackendIr<Handle = H>,\n    {\n        let handle = B::float_tensor_handle(tensor);\n        self.handles.insert(*id, Handle::Existing(handle));\n    }\n\n    /// Register a new [quantized tensor](burn_backend::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).\n    pub fn register_quantized_tensor<B>(\n        &mut self,\n        id: &TensorId,\n        tensor: B::QuantizedTensorPrimitive,\n    ) where\n        B: BackendIr<Handle = H>,\n    {\n        let handle = B::quantized_tensor_handle(tensor);\n        self.handles.insert(*id, Handle::Existing(handle));\n    }\n\n    /// Register a new [int tensor](burn_backend::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId).\n    pub fn register_int_tensor<B>(&mut self, id: &TensorId, tensor: B::IntTensorPrimitive)\n    where\n        B: BackendIr<Handle = H>,\n    {\n        let handle = B::int_tensor_handle(tensor);\n        self.handles.insert(*id, Handle::Existing(handle));\n    }\n\n    /// Register a new [bool tensor](burn_backend::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId).\n    pub fn register_bool_tensor<B>(&mut self, id: &TensorId, tensor: B::BoolTensorPrimitive)\n    where\n        B: BackendIr<Handle = H>,\n    {\n        let handle = B::bool_tensor_handle(tensor);\n        self.handles.insert(*id, Handle::Existing(handle));\n    }\n\n    /// Remove tensor handle from container.\n    pub fn remove_handle(&mut self, id: TensorId) -> Option<Handle<H>> {\n        self.handles.remove(&id)\n    }\n\n    /// Remove tensor handle from container if writable\n    pub fn free(&mut self, tensor: &TensorIr) {\n        match tensor.status {\n            TensorStatus::ReadOnly => (),\n            TensorStatus::NotInit => (),\n            TensorStatus::ReadWrite => {\n                self.handles.remove(&tensor.id);\n            }\n        };\n    }\n\n    /// Returns the number of handles.\n    pub fn num_handles(&self) -> usize {\n        self.handles.len()\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! Burn intermediate representation.\n\nextern crate alloc;\n\nmod backend;\nmod builder;\nmod handle;\nmod operation;\nmod scalar;\nmod tensor;\n\npub use backend::*;\npub use builder::*;\npub use handle::*;\npub use operation::*;\npub use scalar::*;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-ir/src/operation.rs",
    "content": "use burn_backend::ops::AttentionModuleOptions;\nuse burn_backend::tensor::IndexingUpdateOp;\nuse core::hash::Hash;\nuse serde::{Deserialize, Serialize};\n\nuse alloc::borrow::ToOwned;\nuse alloc::boxed::Box;\nuse alloc::{string::String, vec::Vec};\n\nuse burn_backend::{\n    DType, Distribution, Slice,\n    ops::{\n        ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,\n        GridSamplePaddingMode, InterpolateMode, InterpolateOptions,\n    },\n    quantization::QuantScheme,\n};\n\nuse crate::{ScalarIr, TensorId, TensorIr, TensorStatus};\n\n/// Custom operation in fusion stream, declaring its inputs and outputs.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct CustomOpIr {\n    /// Unique identifier of the operation.\n    pub id: String,\n    /// Input tensors used in the custom operation.\n    pub inputs: Vec<TensorIr>,\n    /// Output tensors used in the custom operation.\n    pub outputs: Vec<TensorIr>,\n}\n\nimpl CustomOpIr {\n    /// Create a new custom operation intermediate representation.\n    pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {\n        Self {\n            id: id.to_owned(),\n            inputs: inputs.to_vec(),\n            outputs: outputs.to_vec(),\n        }\n    }\n\n    /// Cast the intermediate representation, and get the in and output tensors.\n    pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(\n        &self,\n    ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {\n        (\n            self.inputs.as_slice().try_into().expect(\n                \"Wrong number of inputs expected (expected {D}, is {}), check your implementation\",\n            ),\n            self.outputs.as_slice().try_into().expect(\n                \"Wrong number of outputs expected (expected {D}, is {}), check your implementation\",\n            ),\n        )\n    }\n\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        Box::new(self.inputs.iter())\n    }\n\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        Box::new(self.outputs.iter())\n    }\n}\n\n/// Describe all tensor operations possible.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(clippy::large_enum_variant)]\npub enum OperationIr {\n    /// Basic operation on a float tensor.\n    BaseFloat(BaseOperationIr),\n    /// Basic operation on an int tensor.\n    BaseInt(BaseOperationIr),\n    /// Basic operation on a bool tensor.\n    BaseBool(BaseOperationIr),\n    /// Numeric operation on a float tensor.\n    NumericFloat(DType, NumericOperationIr),\n    /// Numeric operation on an int tensor.\n    NumericInt(DType, NumericOperationIr),\n    /// Operation specific to a bool tensor.\n    Bool(BoolOperationIr),\n    /// Operation specific to an int tensor.\n    Int(IntOperationIr),\n    /// Operation specific to a float tensor.\n    Float(DType, FloatOperationIr),\n    /// Module operation.\n    Module(ModuleOperationIr),\n    /// Initialize operation.\n    Init(InitOperationIr),\n    /// A custom operation.\n    Custom(CustomOpIr),\n    /// A tensor is dropped.\n    Drop(TensorIr),\n}\n\n/// Operation intermediate representation specific to a float tensor.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum FloatOperationIr {\n    /// Operation corresponding to [exp](burn_backend::ops::FloatTensorOps::float_exp).\n    Exp(UnaryOpIr),\n    /// Operation corresponding to [log](burn_backend::ops::FloatTensorOps::float_log).\n    Log(UnaryOpIr),\n    /// Operation corresponding to [log1p](burn_backend::ops::FloatTensorOps::float_log1p).\n    Log1p(UnaryOpIr),\n    /// Operation corresponding to [erf](burn_backend::ops::FloatTensorOps::float_erf).\n    Erf(UnaryOpIr),\n    /// Operation corresponding to [powf_scalar](burn_backend::ops::FloatTensorOps::float_powf_scalar).\n    PowfScalar(ScalarOpIr),\n    /// Operation corresponding to [sqrt](burn_backend::ops::FloatTensorOps::float_sqrt).\n    Sqrt(UnaryOpIr),\n    /// Operation corresponding to [cos](burn_backend::ops::FloatTensorOps::float_cos).\n    Cos(UnaryOpIr),\n    /// Operation corresponding to [cosh](burn_backend::ops::FloatTensorOps::float_cosh).\n    Cosh(UnaryOpIr),\n    /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sin).\n    Sin(UnaryOpIr),\n    /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sinh).\n    Sinh(UnaryOpIr),\n    /// Operation corresponding to [tan](burn_backend::ops::FloatTensorOps::float_tan).\n    Tan(UnaryOpIr),\n    /// Operation corresponding to [tanh](burn_backend::ops::FloatTensorOps::float_tanh).\n    Tanh(UnaryOpIr),\n    /// Operation corresponding to [acos](burn_backend::ops::FloatTensorOps::float_acos).\n    ArcCos(UnaryOpIr),\n    /// Operation corresponding to [acosh](burn_backend::ops::FloatTensorOps::float_acosh).\n    ArcCosh(UnaryOpIr),\n    /// Operation corresponding to [asin](burn_backend::ops::FloatTensorOps::float_asin).\n    ArcSin(UnaryOpIr),\n    /// Operation corresponding to [asinh](burn_backend::ops::FloatTensorOps::float_asinh).\n    ArcSinh(UnaryOpIr),\n    /// Operation corresponding to [atan](burn_backend::ops::FloatTensorOps::float_atan).\n    ArcTan(UnaryOpIr),\n    /// Operation corresponding to [atanh](burn_backend::ops::FloatTensorOps::float_atanh).\n    ArcTanh(UnaryOpIr),\n    /// Operation corresponding to [atan2](burn_backend::ops::FloatTensorOps::float_atan2).\n    ArcTan2(BinaryOpIr),\n    /// Operation corresponding to [round](burn_backend::ops::FloatTensorOps::float_round).\n    Round(UnaryOpIr),\n    /// Operation corresponding to [floor](burn_backend::ops::FloatTensorOps::float_floor).\n    Floor(UnaryOpIr),\n    /// Operation corresponding to [ceil](burn_backend::ops::FloatTensorOps::float_ceil).\n    Ceil(UnaryOpIr),\n    /// Operation corresponding to [trunc](burn_backend::ops::FloatTensorOps::float_trunc).\n    Trunc(UnaryOpIr),\n    /// Operation corresponding to [into_int](burn_backend::ops::FloatTensorOps::float_into_int).\n    IntoInt(CastOpIr),\n    /// Operation corresponding to [matmul](burn_backend::ops::FloatTensorOps::float_matmul).\n    Matmul(MatmulOpIr),\n    /// Operation corresponding to [cross](burn_backend::ops::FloatTensorOps::float_cross).\n    Cross(CrossOpIr),\n    /// Operation corresponding to [random](burn_backend::ops::FloatTensorOps::float_random).\n    Random(RandomOpIr),\n    /// Operation corresponding to [recip](burn_backend::ops::FloatTensorOps::float_recip).\n    Recip(UnaryOpIr),\n    /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_nan).\n    IsNan(UnaryOpIr),\n    /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_inf).\n    IsInf(UnaryOpIr),\n    /// Operation corresponding to [quantize](burn_backend::ops::QTensorOps::quantize).\n    Quantize(QuantizeOpIr),\n    /// Operation corresponding to [dequantize](burn_backend::ops::QTensorOps::dequantize).\n    Dequantize(DequantizeOpIr),\n    /// Operation corresponding to [grid_sample_2d](burn_backend::ops::FloatTensorOps::float_grid_sample_2d).\n    GridSample2d(GridSample2dOpIr),\n    /// Operation corresponding to [powf](burn_backend::ops::FloatTensorOps::float_powi).\n    Powf(BinaryOpIr),\n}\n\n/// Operation intermediate representation specific to module.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum ModuleOperationIr {\n    /// Operation corresponding to [embedding](burn_backend::ops::ModuleOps::embedding).\n    Embedding(EmbeddingOpIr),\n    /// Operation corresponding to [embedding_backward](burn_backend::ops::ModuleOps::embedding_backward).\n    EmbeddingBackward(EmbeddingBackwardOpIr),\n    /// Operation corresponding to [conv1d](burn_backend::ops::ModuleOps::conv1d).\n    Conv1d(Conv1dOpIr),\n    /// Operation corresponding to [conv1d_x_backward](burn_backend::ops::ModuleOps::conv1d_x_backward).\n    Conv1dXBackward(Conv1dXBackwardOpIr),\n    /// Operation corresponding to [conv1d_weight_backward](burn_backend::ops::ModuleOps::conv1d_weight_backward).\n    Conv1dWeightBackward(Conv1dWeightBackwardOpIr),\n    /// Operation corresponding to [conv1d_bias_backward](burn_backend::ops::ModuleOps::conv1d_bias_backward).\n    Conv1dBiasBackward(Conv1dBiasBackwardOpIr),\n    /// Operation corresponding to [conv2d](burn_backend::ops::ModuleOps::conv2d).\n    Conv2d(Conv2dOpIr),\n    /// Operation corresponding to [conv2d_x_backward](burn_backend::ops::ModuleOps::conv2d_x_backward).\n    Conv2dXBackward(Conv2dXBackwardOpIr),\n    /// Operation corresponding to [conv2d_weight_backward](burn_backend::ops::ModuleOps::conv2d_weight_backward).\n    Conv2dWeightBackward(Conv2dWeightBackwardOpIr),\n    /// Operation corresponding to [conv2d_bias_backward](burn_backend::ops::ModuleOps::conv2d_bias_backward).\n    Conv2dBiasBackward(Conv2dBiasBackwardOpIr),\n    /// Operation corresponding to [conv3d](burn_backend::ops::ModuleOps::conv3d).\n    Conv3d(Conv3dOpIr),\n    /// Operation corresponding to [conv3d_x_backward](burn_backend::ops::ModuleOps::conv3d_x_backward).\n    Conv3dXBackward(Conv3dXBackwardOpIr),\n    /// Operation corresponding to [conv3d_weight_backward](burn_backend::ops::ModuleOps::conv3d_weight_backward).\n    Conv3dWeightBackward(Conv3dWeightBackwardOpIr),\n    /// Operation corresponding to [conv3d_bias_backward](burn_backend::ops::ModuleOps::conv3d_bias_backward).\n    Conv3dBiasBackward(Conv3dBiasBackwardOpIr),\n    /// Operation corresponding to [deform_conv2d](burn_backend::ops::ModuleOps::deform_conv2d)\n    DeformableConv2d(Box<DeformConv2dOpIr>),\n    /// Operation corresponding to [deform_conv2d_backward](burn_backend::ops::ModuleOps::deform_conv2d_backward)\n    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),\n    /// Operation corresponding to [conv transpose 1d](burn_backend::ops::ModuleOps::conv_transpose1d).\n    ConvTranspose1d(ConvTranspose1dOpIr),\n    /// Operation corresponding to [conv transpose 2d](burn_backend::ops::ModuleOps::conv_transpose2d).\n    ConvTranspose2d(ConvTranspose2dOpIr),\n    /// Operation corresponding to [conv transpose 3d](burn_backend::ops::ModuleOps::conv_transpose3d).\n    ConvTranspose3d(ConvTranspose3dOpIr),\n    /// Operation corresponding to [avg pool 1d](burn_backend::ops::ModuleOps::avg_pool1d).\n    AvgPool1d(AvgPool1dOpIr),\n    /// Operation corresponding to [avg pool 2d](burn_backend::ops::ModuleOps::avg_pool2d).\n    AvgPool2d(AvgPool2dOpIr),\n    /// Operation corresponding to\n    /// [avg pool 1d backward](burn_backend::ops::ModuleOps::avg_pool1d_backward).\n    AvgPool1dBackward(AvgPool1dBackwardOpIr),\n    /// Operation corresponding to\n    /// [avg pool 2d backward](burn_backend::ops::ModuleOps::avg_pool2d_backward).\n    AvgPool2dBackward(AvgPool2dBackwardOpIr),\n    /// Operation corresponding to\n    /// [adaptive avg pool 1d](burn_backend::ops::ModuleOps::adaptive_avg_pool1d).\n    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),\n    /// Operation corresponding to\n    /// [adaptive avg pool 2d](burn_backend::ops::ModuleOps::adaptive_avg_pool2d).\n    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),\n    /// Operation corresponding to\n    /// [adaptive avg pool 1d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool1d_backward).\n    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),\n    /// Operation corresponding to\n    /// [adaptive avg pool 2d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool2d_backward).\n    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),\n    /// Operation corresponding to\n    /// [max pool 1d](burn_backend::ops::ModuleOps::max_pool1d).\n    MaxPool1d(MaxPool1dOpIr),\n    /// Operation corresponding to\n    /// [max pool 1d with indices](burn_backend::ops::ModuleOps::max_pool1d_with_indices).\n    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),\n    /// Operation corresponding to\n    /// [max pool 1d with indices backward](burn_backend::ops::ModuleOps::max_pool1d_with_indices_backward).\n    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),\n    /// Operation corresponding to\n    /// [max pool 2d](burn_backend::ops::ModuleOps::max_pool1d).\n    MaxPool2d(MaxPool2dOpIr),\n    /// Operation corresponding to\n    /// [max pool 2d with indices](burn_backend::ops::ModuleOps::max_pool2d_with_indices).\n    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),\n    /// Operation corresponding to\n    /// [max pool 2d with indices backward](burn_backend::ops::ModuleOps::max_pool2d_with_indices_backward).\n    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),\n    /// Operation corresponding to [interpolate](burn_backend::ops::ModuleOps::interpolate).\n    Interpolate(InterpolateOpIr),\n    /// Operation corresponding to [interpolate backward](burn_backend::ops::ModuleOps::interpolate_backward).\n    InterpolateBackward(InterpolateBackwardOpIr),\n    /// Operation corresponding to [attention](burn_backend::ops::ModuleOps::attention).\n    Attention(AttentionOpIr),\n}\n\n/// Basic operations that can be done on any tensor type.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum BaseOperationIr {\n    /// Operation corresponding to:\n    ///\n    /// Float => [reshape](burn_backend::ops::FloatTensorOps::float_reshape).\n    /// Int => [reshape](burn_backend::ops::IntTensorOps::int_reshape).\n    /// Bool => [reshape](burn_backend::ops::BoolTensorOps::bool_reshape).\n    Reshape(ShapeOpIr),\n\n    /// Operation corresponding to:\n    ///\n    /// Float => [swap_dims](burn_backend::ops::FloatTensorOps::float_swap_dims).\n    /// Int => [swap_dims](burn_backend::ops::IntTensorOps::int_swap_dims).\n    /// Bool => [swap_dims](burn_backend::ops::BoolTensorOps::bool_swap_dims).\n    SwapDims(SwapDimsOpIr),\n\n    /// Operation corresponding to:\n    ///\n    /// Float => [permute](burn_backend::ops::FloatTensorOps::float_permute).\n    /// Int => [permute](burn_backend::ops::IntTensorOps::int_permute).\n    /// Bool => [permute](burn_backend::ops::BoolTensorOps::bool_permute).\n    Permute(PermuteOpIr),\n\n    /// Operation corresponding to:\n    /// Float => [flip](burn_backend::ops::FloatTensorOps::float_flip).\n    /// Int => [flip](burn_backend::ops::IntTensorOps::int_flip).\n    /// Bool => [flip](burn_backend::ops::BoolTensorOps::bool_flip).\n    Flip(FlipOpIr),\n\n    /// Operation corresponding to:\n    ///\n    /// Float => [expand](burn_backend::ops::FloatTensorOps::float_expand).\n    /// Int => [expand](burn_backend::ops::IntTensorOps::int_expand).\n    /// Bool => [expand](burn_backend::ops::BoolTensorOps::bool_expand).\n    Expand(ShapeOpIr),\n\n    /// Unfold windows along an axis.\n    ///\n    Unfold(UnfoldOpIr),\n\n    /// Operation corresponding to:\n    ///\n    /// Float => [slice](burn_backend::ops::FloatTensorOps::float_slice).\n    /// Int => [slice](burn_backend::ops::IntTensorOps::int_slice).\n    /// Bool => [slice](burn_backend::ops::BoolTensorOps::bool_slice).\n    Slice(SliceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [slice assign](burn_backend::ops::FloatTensorOps::float_slice_assign).\n    /// Int => [slice assign](burn_backend::ops::IntTensorOps::int_slice_assign).\n    /// Bool => [slice assign](burn_backend::ops::BoolTensorOps::bool_slice_assign).\n    SliceAssign(SliceAssignOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [select](burn_backend::ops::FloatTensorOps::float_select).\n    /// Int => [select](burn_backend::ops::IntTensorOps::int_select).\n    /// Bool => [select](burn_backend::ops::BoolTensorOps::bool_select).\n    Select(SelectOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [select assign](burn_backend::ops::FloatTensorOps::float_select_add).\n    /// Int => [select assign](burn_backend::ops::IntTensorOps::int_select_add).\n    /// Bool => [select assign](burn_backend::ops::BoolTensorOps::bool_select_or).\n    SelectAssign(SelectAssignOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mask where](burn_backend::ops::FloatTensorOps::float_mask_where).\n    /// Int => [mask where](burn_backend::ops::IntTensorOps::int_mask_where).\n    /// Bool => [mask where](burn_backend::ops::BoolTensorOps::bool_mask_where).\n    MaskWhere(MaskWhereOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mask fill](burn_backend::ops::FloatTensorOps::float_mask_fill).\n    /// Int => [mask fill](burn_backend::ops::IntTensorOps::int_mask_fill).\n    /// Bool => [mask fill](burn_backend::ops::BoolTensorOps::bool_mask_fill).\n    MaskFill(MaskFillOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [gather](burn_backend::ops::FloatTensorOps::float_gather).\n    /// Int => [gather](burn_backend::ops::IntTensorOps::int_gather).\n    /// Bool => [gather](burn_backend::ops::BoolTensorOps::bool_gather).\n    Gather(GatherOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [scatter](burn_backend::ops::FloatTensorOps::float_scatter_add).\n    /// Int => [scatter](burn_backend::ops::IntTensorOps::int_scatter_add).\n    /// Bool => [scatter](burn_backend::ops::BoolTensorOps::bool_scatter_or).\n    Scatter(ScatterOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [equal](burn_backend::ops::FloatTensorOps::float_equal).\n    /// Int => [equal](burn_backend::ops::IntTensorOps::int_equal).\n    /// Bool => [equal](burn_backend::ops::BoolTensorOps::bool_equal).\n    Equal(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [equal elem](burn_backend::ops::FloatTensorOps::float_equal_elem).\n    /// Int => [equal elem](burn_backend::ops::IntTensorOps::int_equal_elem).\n    /// Bool => [equal elem](burn_backend::ops::BoolTensorOps::bool_equal_elem).\n    EqualElem(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [repeat dim](burn_backend::ops::FloatTensorOps::float_repeat_dim).\n    /// Int => [repeat dim](burn_backend::ops::IntTensorOps::int_repeat_dim).\n    /// Bool => [repeat dim](burn_backend::ops::BoolTensorOps::bool_repeat_dim).\n    RepeatDim(RepeatDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [cat](burn_backend::ops::FloatTensorOps::float_cat).\n    /// Int => [cat](burn_backend::ops::IntTensorOps::int_cat).\n    /// Bool => [cat](burn_backend::ops::BoolTensorOps::bool_cat).\n    Cat(CatOpIr),\n    /// Cast operation, no direct operation and should be supported by fusion backend.\n    Cast(CastOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [empty](burn_backend::ops::FloatTensorOps::float_empty).\n    /// Int => [empty](burn_backend::ops::IntTensorOps::int_empty).\n    /// Bool => [empty](burn_backend::ops::BoolTensorOps::bool_empty).\n    Empty(CreationOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [ones](burn_backend::ops::FloatTensorOps::float_ones).\n    /// Int => [ones](burn_backend::ops::IntTensorOps::int_ones).\n    /// Bool => [ones](burn_backend::ops::BoolTensorOps::bool_ones).\n    Ones(CreationOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [zeros](burn_backend::ops::FloatTensorOps::float_zeros).\n    /// Int => [zeros](burn_backend::ops::IntTensorOps::int_zeros).\n    /// Bool => [zeros](burn_backend::ops::BoolTensorOps::bool_zeros).\n    Zeros(CreationOpIr),\n}\n\n/// Numeric operations on int and float tensors.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum NumericOperationIr {\n    /// Operation corresponding to:\n    ///\n    /// Float => [add](burn_backend::ops::FloatTensorOps::float_add).\n    /// Int => [add](burn_backend::ops::IntTensorOps::int_add).\n    Add(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [add scalar](burn_backend::ops::FloatTensorOps::float_add_scalar).\n    /// Int => [add scalar](burn_backend::ops::IntTensorOps::int_add_scalar).\n    AddScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [sub](burn_backend::ops::FloatTensorOps::float_sub).\n    /// Int => [sub](burn_backend::ops::IntTensorOps::int_sub).\n    Sub(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [sub scalar](burn_backend::ops::FloatTensorOps::float_sub_scalar).\n    /// Int => [sub scalar](burn_backend::ops::IntTensorOps::int_sub_scalar).\n    SubScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [div](burn_backend::ops::FloatTensorOps::float_div).\n    /// Int => [div](burn_backend::ops::IntTensorOps::int_div).\n    Div(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [div scalar](burn_backend::ops::FloatTensorOps::float_div_scalar).\n    /// Int => [div scalar](burn_backend::ops::IntTensorOps::int_div_scalar).\n    DivScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [rem](burn_backend::ops::FloatTensorOps::float_remainder).\n    /// Int => [rem](burn_backend::ops::IntTensorOps::int_remainder).\n    Rem(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [rem scalar](burn_backend::ops::FloatTensorOps::float_remainder_scalar).\n    /// Int => [rem scalar](burn_backend::ops::IntTensorOps::int_remainder_scalar).\n    RemScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mul](burn_backend::ops::FloatTensorOps::float_mul).\n    /// Int => [mul](burn_backend::ops::IntTensorOps::int_mul).\n    Mul(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mul scalar](burn_backend::ops::FloatTensorOps::float_mul_scalar).\n    /// Int => [mul scalar](burn_backend::ops::IntTensorOps::int_mul_scalar).\n    MulScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [abs](burn_backend::ops::FloatTensorOps::float_abs).\n    /// Int => [abs](burn_backend::ops::IntTensorOps::int_abs).\n    Abs(UnaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [full](burn_backend::ops::FloatTensorOps::float_full).\n    /// Int => [full](burn_backend::ops::IntTensorOps::int_full).\n    Full(FullOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mean dim](burn_backend::ops::FloatTensorOps::float_mean_dim).\n    /// Int => [mean dim](burn_backend::ops::IntTensorOps::int_mean_dim).\n    MeanDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [mean](burn_backend::ops::FloatTensorOps::float_mean).\n    /// Int => [mean](burn_backend::ops::IntTensorOps::int_mean).\n    Mean(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [sum](burn_backend::ops::FloatTensorOps::float_sum).\n    /// Int => [sum](burn_backend::ops::IntTensorOps::int_sum).\n    Sum(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [sum dim](burn_backend::ops::FloatTensorOps::float_sum_dim).\n    /// Int => [sum dim](burn_backend::ops::IntTensorOps::int_sum_dim).\n    SumDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [prod](burn_backend::ops::FloatTensorOps::float_prod).\n    /// Int => [prod](burn_backend::ops::IntTensorOps::int_prod).\n    Prod(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [prod dim](burn_backend::ops::FloatTensorOps::float_prod_dim).\n    /// Int => [prod dim](burn_backend::ops::IntTensorOps::int_prod_dim).\n    ProdDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [greater](burn_backend::ops::FloatTensorOps::float_greater).\n    /// Int => [greater](burn_backend::ops::IntTensorOps::int_greater).\n    Greater(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [greater elem](burn_backend::ops::FloatTensorOps::float_greater_elem).\n    /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem).\n    GreaterElem(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [greater equal](burn_backend::ops::FloatTensorOps::float_greater_elem).\n    /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem).\n    GreaterEqual(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [greater equal elem](burn_backend::ops::FloatTensorOps::float_greater_equal_elem).\n    /// Int => [greater equal elem](burn_backend::ops::IntTensorOps::int_greater_equal_elem).\n    GreaterEqualElem(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [lower](burn_backend::ops::FloatTensorOps::float_lower).\n    /// Int => [lower](burn_backend::ops::IntTensorOps::int_lower).\n    Lower(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [lower elem](burn_backend::ops::FloatTensorOps::float_lower_elem).\n    /// Int => [lower elem](burn_backend::ops::IntTensorOps::int_lower_elem).\n    LowerElem(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [lower equal](burn_backend::ops::FloatTensorOps::float_lower_equal).\n    /// Int => [lower equal](burn_backend::ops::IntTensorOps::int_lower_equal).\n    LowerEqual(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [lower equal elem](burn_backend::ops::FloatTensorOps::float_lower_equal_elem).\n    /// Int => [lower equal elem](burn_backend::ops::IntTensorOps::int_lower_equal_elem).\n    LowerEqualElem(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [argmax](burn_backend::ops::FloatTensorOps::float_argmax).\n    /// Int => [argmax](burn_backend::ops::IntTensorOps::int_argmax).\n    ArgMax(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [argmin](burn_backend::ops::FloatTensorOps::float_argmin).\n    /// Int => [argmin](burn_backend::ops::IntTensorOps::int_argmin).\n    ArgMin(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [max](burn_backend::ops::FloatTensorOps::float_max).\n    /// Int => [max](burn_backend::ops::IntTensorOps::int_max).\n    Max(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [max dim with indices](burn_backend::ops::FloatTensorOps::float_max_dim_with_indices).\n    /// Int => [max dim with indices](burn_backend::ops::IntTensorOps::int_max_dim_with_indices).\n    MaxDimWithIndices(ReduceDimWithIndicesOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [min dim with indices](burn_backend::ops::FloatTensorOps::float_min_dim_with_indices).\n    /// Int => [min dim with indices](burn_backend::ops::IntTensorOps::int_min_dim_with_indices).\n    MinDimWithIndices(ReduceDimWithIndicesOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [min](burn_backend::ops::FloatTensorOps::float_min).\n    /// Int => [min](burn_backend::ops::IntTensorOps::int_min).\n    Min(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [max dim](burn_backend::ops::FloatTensorOps::float_max_dim).\n    /// Int => [max dim](burn_backend::ops::IntTensorOps::int_max_dim).\n    MaxDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [min dim](burn_backend::ops::FloatTensorOps::float_min_dim).\n    /// Int => [min dim](burn_backend::ops::IntTensorOps::int_min_dim).\n    MinDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [max_abs](burn_backend::ops::FloatTensorOps::float_max_abs).\n    /// Int => [max_abs](burn_backend::ops::IntTensorOps::int_max_abs).\n    MaxAbs(ReduceOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [max_abs dim](burn_backend::ops::FloatTensorOps::float_max_abs_dim).\n    /// Int => [max_abs dim](burn_backend::ops::IntTensorOps::int_max_abs_dim).\n    MaxAbsDim(ReduceDimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [clamp](burn_backend::ops::FloatTensorOps::float_clamp).\n    /// Int => [clamp](burn_backend::ops::IntTensorOps::int_clamp).\n    Clamp(ClampOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [random](burn_backend::ops::IntTensorOps::int_random).\n    IntRandom(RandomOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [powf](burn_backend::ops::FloatTensorOps::float_powi).\n    /// Int => [powf](burn_backend::ops::IntTensorOps::int_powi).\n    Powi(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [cumsum](burn_backend::ops::FloatTensorOps::float_cumsum).\n    /// Int => [cumsum](burn_backend::ops::IntTensorOps::int_cumsum).\n    CumSum(DimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [cumprod](burn_backend::ops::FloatTensorOps::float_cumprod).\n    /// Int => [cumprod](burn_backend::ops::IntTensorOps::int_cumprod).\n    CumProd(DimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [cummin](burn_backend::ops::FloatTensorOps::float_cummin).\n    /// Int => [cummin](burn_backend::ops::IntTensorOps::int_cummin).\n    CumMin(DimOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Float => [cummax](burn_backend::ops::FloatTensorOps::float_cummax).\n    /// Int => [cummax](burn_backend::ops::IntTensorOps::int_cummax).\n    CumMax(DimOpIr),\n}\n\n/// Operation intermediate representation specific to an int tensor.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum IntOperationIr {\n    /// Operation corresponding to [into float](burn_backend::ops::IntTensorOps::int_into_float).\n    IntoFloat(CastOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise and](burn_backend::ops::IntTensorOps::bitwise_and).\n    BitwiseAnd(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise and scalar](burn_backend::ops::IntTensorOps::bitwise_and_scalar).\n    BitwiseAndScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise or](burn_backend::ops::IntTensorOps::bitwise_or).\n    BitwiseOr(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise or scalar](burn_backend::ops::IntTensorOps::bitwise_or_scalar).\n    BitwiseOrScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise xor](burn_backend::ops::IntTensorOps::bitwise_xor).\n    BitwiseXor(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise xor scalar](burn_backend::ops::IntTensorOps::bitwise_xor_scalar).\n    BitwiseXorScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise not](burn_backend::ops::IntTensorOps::bitwise_not).\n    BitwiseNot(UnaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise left shift](burn_backend::ops::IntTensorOps::bitwise_left_shift).\n    BitwiseLeftShift(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise left shift scalar](burn_backend::ops::IntTensorOps::bitwise_left_shift_scalar).\n    BitwiseLeftShiftScalar(ScalarOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise right shift](burn_backend::ops::IntTensorOps::bitwise_right_shift).\n    BitwiseRightShift(BinaryOpIr),\n    /// Operation corresponding to:\n    ///\n    /// Int => [bitwise right shift scalar](burn_backend::ops::IntTensorOps::bitwise_right_shift_scalar).\n    BitwiseRightShiftScalar(ScalarOpIr),\n    /// Operation corresponding to [matmul](burn_backend::ops::IntTensorOps::int_matmul).\n    Matmul(MatmulOpIr),\n}\n\n/// Operation intermediate representation specific to a bool tensor.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub enum BoolOperationIr {\n    /// Operation corresponding to [into float](burn_backend::ops::BoolTensorOps::bool_into_float).\n    IntoFloat(CastOpIr),\n    /// Operation corresponding to [into int](burn_backend::ops::BoolTensorOps::bool_into_int).\n    IntoInt(CastOpIr),\n    /// Operation corresponding to [not](burn_backend::ops::BoolTensorOps::bool_not).\n    Not(UnaryOpIr),\n    /// Operation corresponding to [and](burn_backend::ops::BoolTensorOps::bool_and).\n    And(BinaryOpIr),\n    /// Operation corresponding to [or](burn_backend::ops::BoolTensorOps::bool_or).\n    Or(BinaryOpIr),\n}\n\n/// Swap dim operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct SwapDimsOpIr {\n    /// Input tensor intermediate representation.\n    pub input: TensorIr,\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n    /// The first dim to swap.\n    pub dim1: usize,\n    /// The second dim to swap.\n    pub dim2: usize,\n}\n\n/// Permute operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct PermuteOpIr {\n    /// Input tensor intermediate representation.\n    pub input: TensorIr,\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n    /// The new order of the dimensions.\n    pub axes: Vec<usize>,\n}\n\n/// Shape operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct ShapeOpIr {\n    /// Input tensor intermediate representation.\n    pub input: TensorIr,\n    /// Output tensor intermediate representation with the new shape.\n    pub out: TensorIr,\n}\n\n/// Unfold operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct UnfoldOpIr {\n    /// Input tensor intermediate representation.\n    pub input: TensorIr,\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n\n    /// The selected dim.\n    pub dim: usize,\n    /// The window size.\n    pub size: usize,\n    /// The window step along dim.\n    pub step: usize,\n}\n\n/// Flip operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct FlipOpIr {\n    /// Input tensor intermediate representation.\n    pub input: TensorIr,\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n    /// The dimensions to flip.\n    pub axes: Vec<usize>,\n}\n\n#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct RandomOpIr {\n    pub out: TensorIr,\n    pub distribution: Distribution,\n}\n\n/// Creation operation intermediate representation.\n/// As opposed to [InitOperationIr], creation operations are lazy initialized.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct CreationOpIr {\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n}\n\n/// Full operation intermediate representation.\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct FullOpIr {\n    /// Output tensor intermediate representation.\n    pub out: TensorIr,\n    /// Fill value.\n    pub value: ScalarIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n/// Declares a tensor has been initialized.\n///\n/// It is necessary to register for proper orphan detection and avoid memory leak.\npub struct InitOperationIr {\n    /// The initialized tensor.\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct BinaryOpIr {\n    pub lhs: TensorIr,\n    pub rhs: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MatmulOpIr {\n    pub lhs: TensorIr,\n    pub rhs: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct CrossOpIr {\n    pub lhs: TensorIr,\n    pub rhs: TensorIr,\n    pub out: TensorIr,\n    pub dim: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct UnaryOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ScalarOpIr {\n    pub lhs: TensorIr,\n    // TODO: Make that an enum with `Value` and `Id` variants for relative/global\n    // conversion.\n    pub rhs: ScalarIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]\n#[allow(missing_docs)]\npub struct ReduceOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]\n#[allow(missing_docs)]\npub struct ReduceDimOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n    pub axis: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct CastOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n}\n\n/// IR for operations that operate along a dimension without reducing it.\n/// Unlike `ReduceDimOpIr`, the output shape is the same as the input shape.\n#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]\n#[allow(missing_docs)]\npub struct DimOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n    pub axis: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct GatherOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub indices: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ScatterOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub indices: TensorIr,\n    pub value: TensorIr,\n    pub update: IndexingUpdateOp,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct SelectOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub indices: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct SelectAssignOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub indices: TensorIr,\n    pub value: TensorIr,\n    pub update: IndexingUpdateOp,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct SliceOpIr {\n    pub tensor: TensorIr,\n    pub ranges: Vec<Slice>,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct SliceAssignOpIr {\n    pub tensor: TensorIr,\n    pub ranges: Vec<burn_backend::Slice>,\n    pub value: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaskWhereOpIr {\n    pub tensor: TensorIr,\n    pub mask: TensorIr,\n    pub value: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaskFillOpIr {\n    pub tensor: TensorIr,\n    pub mask: TensorIr,\n    pub value: ScalarIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ClampOpIr {\n    pub tensor: TensorIr,\n    pub min: ScalarIr,\n    pub max: ScalarIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct RepeatDimOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub times: usize,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct CatOpIr {\n    pub tensors: Vec<TensorIr>,\n    pub dim: usize,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ReduceDimWithIndicesOpIr {\n    pub tensor: TensorIr,\n    pub dim: usize,\n    pub out: TensorIr,\n    pub out_indices: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct EmbeddingOpIr {\n    pub weights: TensorIr,\n    pub indices: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct EmbeddingBackwardOpIr {\n    pub weights: TensorIr,\n    pub out_grad: TensorIr,\n    pub indices: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv1dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: Conv1dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv1dXBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv1dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv1dWeightBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv1dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv1dBiasBackwardOpIr {\n    pub x: TensorIr,\n    pub bias: TensorIr,\n    pub output_grad: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv2dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: Conv2dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv2dXBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv2dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv2dWeightBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv2dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv2dBiasBackwardOpIr {\n    pub x: TensorIr,\n    pub bias: TensorIr,\n    pub output_grad: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct DeformConv2dOpIr {\n    pub x: TensorIr,\n    pub offset: TensorIr,\n    pub weight: TensorIr,\n    pub mask: Option<TensorIr>,\n    pub bias: Option<TensorIr>,\n    pub options: DeformableConv2dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct DeformConv2dBackwardOpIr {\n    pub x: TensorIr,\n    pub offset: TensorIr,\n    pub weight: TensorIr,\n    pub mask: Option<TensorIr>,\n    pub bias: Option<TensorIr>,\n    pub out_grad: TensorIr,\n    pub options: DeformableConv2dOptionsIr,\n    pub input_grad: TensorIr,\n    pub offset_grad: TensorIr,\n    pub weight_grad: TensorIr,\n    pub mask_grad: Option<TensorIr>,\n    pub bias_grad: Option<TensorIr>,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv3dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: Conv3dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv3dXBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv3dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv3dWeightBackwardOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub output_grad: TensorIr,\n    pub options: Conv3dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv3dBiasBackwardOpIr {\n    pub x: TensorIr,\n    pub bias: TensorIr,\n    pub output_grad: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose1dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: ConvTranspose1dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose2dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: ConvTranspose2dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose3dOpIr {\n    pub x: TensorIr,\n    pub weight: TensorIr,\n    pub bias: Option<TensorIr>,\n    pub options: ConvTranspose3dOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv1dOptionsIr {\n    pub stride: [usize; 1],\n    pub padding: [usize; 1],\n    pub dilation: [usize; 1],\n    pub groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv2dOptionsIr {\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub dilation: [usize; 2],\n    pub groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct DeformableConv2dOptionsIr {\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub dilation: [usize; 2],\n    pub weight_groups: usize,\n    pub offset_groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct Conv3dOptionsIr {\n    pub stride: [usize; 3],\n    pub padding: [usize; 3],\n    pub dilation: [usize; 3],\n    pub groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose1dOptionsIr {\n    pub stride: [usize; 1],\n    pub padding: [usize; 1],\n    pub padding_out: [usize; 1],\n    pub dilation: [usize; 1],\n    pub groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose2dOptionsIr {\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub padding_out: [usize; 2],\n    pub dilation: [usize; 2],\n    pub groups: usize,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct ConvTranspose3dOptionsIr {\n    pub stride: [usize; 3],\n    pub padding: [usize; 3],\n    pub padding_out: [usize; 3],\n    pub dilation: [usize; 3],\n    pub groups: usize,\n}\n\n/// Quantization parameters intermediate representation.\n#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]\npub struct QuantizationParametersIr {\n    /// The scaling factor.\n    pub scales: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct QuantizeOpIr {\n    pub tensor: TensorIr,\n    pub qparams: QuantizationParametersIr,\n    pub scheme: QuantScheme,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct DequantizeOpIr {\n    pub input: TensorIr,\n    pub out: TensorIr,\n}\n\nimpl From<ConvOptions<1>> for Conv1dOptionsIr {\n    fn from(value: ConvOptions<1>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<ConvOptions<2>> for Conv2dOptionsIr {\n    fn from(value: ConvOptions<2>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<ConvOptions<3>> for Conv3dOptionsIr {\n    fn from(value: ConvOptions<3>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {\n    fn from(value: DeformConvOptions<2>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            dilation: value.dilation,\n            weight_groups: value.weight_groups,\n            offset_groups: value.offset_groups,\n        }\n    }\n}\n\nimpl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {\n    fn from(value: ConvTransposeOptions<1>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            padding_out: value.padding_out,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {\n    fn from(value: ConvTransposeOptions<2>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            padding_out: value.padding_out,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {\n    fn from(value: ConvTransposeOptions<3>) -> Self {\n        Self {\n            stride: value.stride,\n            padding: value.padding,\n            padding_out: value.padding_out,\n            dilation: value.dilation,\n            groups: value.groups,\n        }\n    }\n}\n\nimpl From<Conv1dOptionsIr> for ConvOptions<1> {\n    fn from(val: Conv1dOptionsIr) -> Self {\n        ConvOptions {\n            stride: val.stride,\n            padding: val.padding,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\nimpl From<Conv2dOptionsIr> for ConvOptions<2> {\n    fn from(val: Conv2dOptionsIr) -> Self {\n        ConvOptions {\n            stride: val.stride,\n            padding: val.padding,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\nimpl From<Conv3dOptionsIr> for ConvOptions<3> {\n    fn from(val: Conv3dOptionsIr) -> Self {\n        ConvOptions {\n            stride: val.stride,\n            padding: val.padding,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\nimpl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {\n    fn from(value: DeformableConv2dOptionsIr) -> Self {\n        DeformConvOptions {\n            stride: value.stride,\n            padding: value.padding,\n            dilation: value.dilation,\n            weight_groups: value.weight_groups,\n            offset_groups: value.offset_groups,\n        }\n    }\n}\n\nimpl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {\n    fn from(val: ConvTranspose1dOptionsIr) -> Self {\n        ConvTransposeOptions {\n            stride: val.stride,\n            padding: val.padding,\n            padding_out: val.padding_out,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\nimpl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {\n    fn from(val: ConvTranspose2dOptionsIr) -> Self {\n        ConvTransposeOptions {\n            stride: val.stride,\n            padding: val.padding,\n            padding_out: val.padding_out,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\nimpl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {\n    fn from(val: ConvTranspose3dOptionsIr) -> Self {\n        ConvTransposeOptions {\n            stride: val.stride,\n            padding: val.padding,\n            padding_out: val.padding_out,\n            dilation: val.dilation,\n            groups: val.groups,\n        }\n    }\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AvgPool1dOpIr {\n    pub x: TensorIr,\n    pub kernel_size: usize,\n    pub stride: usize,\n    pub padding: usize,\n    pub count_include_pad: bool,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AvgPool2dOpIr {\n    pub x: TensorIr,\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub count_include_pad: bool,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AvgPool1dBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub kernel_size: usize,\n    pub stride: usize,\n    pub padding: usize,\n    pub count_include_pad: bool,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AvgPool2dBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub count_include_pad: bool,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AdaptiveAvgPool1dOpIr {\n    pub x: TensorIr,\n    pub output_size: usize,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AdaptiveAvgPool2dOpIr {\n    pub x: TensorIr,\n    pub output_size: [usize; 2],\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AdaptiveAvgPool1dBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AdaptiveAvgPool2dBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaxPool1dOpIr {\n    pub x: TensorIr,\n    pub kernel_size: usize,\n    pub stride: usize,\n    pub padding: usize,\n    pub dilation: usize,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaxPool1dWithIndicesOpIr {\n    pub x: TensorIr,\n    pub kernel_size: usize,\n    pub stride: usize,\n    pub padding: usize,\n    pub dilation: usize,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n    pub out_indices: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaxPool1dWithIndicesBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub indices: TensorIr,\n    pub kernel_size: usize,\n    pub stride: usize,\n    pub padding: usize,\n    pub dilation: usize,\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaxPool2dOpIr {\n    pub x: TensorIr,\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub dilation: [usize; 2],\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[allow(missing_docs)]\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\npub struct MaxPool2dWithIndicesOpIr {\n    pub x: TensorIr,\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub dilation: [usize; 2],\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n    pub out_indices: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct MaxPool2dWithIndicesBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub indices: TensorIr,\n    pub kernel_size: [usize; 2],\n    pub stride: [usize; 2],\n    pub padding: [usize; 2],\n    pub dilation: [usize; 2],\n    pub ceil_mode: bool,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub enum InterpolateModeIr {\n    Nearest,\n    Bilinear,\n    Bicubic,\n    Lanczos3,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct InterpolateOptionsIr {\n    pub mode: InterpolateModeIr,\n    pub align_corners: bool,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct InterpolateOpIr {\n    pub x: TensorIr,\n    pub output_size: [usize; 2],\n    pub options: InterpolateOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AttentionOptionsIr {\n    pub scale: Option<ScalarIr>,\n    pub softcap: Option<ScalarIr>,\n    pub is_causal: bool,\n}\n\nimpl From<AttentionOptionsIr> for AttentionModuleOptions {\n    fn from(ir: AttentionOptionsIr) -> Self {\n        AttentionModuleOptions {\n            scale: ir.scale.map(|s| s.elem()),\n            softcap: ir.softcap.map(|s| s.elem()),\n            is_causal: ir.is_causal,\n        }\n    }\n}\n\nimpl From<AttentionModuleOptions> for AttentionOptionsIr {\n    fn from(ir: AttentionModuleOptions) -> Self {\n        AttentionOptionsIr {\n            scale: ir.scale.map(ScalarIr::Float),\n            softcap: ir.softcap.map(ScalarIr::Float),\n            is_causal: ir.is_causal,\n        }\n    }\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct AttentionOpIr {\n    pub query: TensorIr,\n    pub key: TensorIr,\n    pub value: TensorIr,\n    pub mask: Option<TensorIr>,\n    pub attn_bias: Option<TensorIr>,\n    pub options: AttentionOptionsIr,\n    pub out: TensorIr,\n}\n\nimpl From<InterpolateModeIr> for InterpolateMode {\n    fn from(val: InterpolateModeIr) -> Self {\n        match val {\n            InterpolateModeIr::Nearest => Self::Nearest,\n            InterpolateModeIr::Bilinear => Self::Bilinear,\n            InterpolateModeIr::Bicubic => Self::Bicubic,\n            InterpolateModeIr::Lanczos3 => Self::Lanczos3,\n        }\n    }\n}\n\nimpl From<InterpolateOptionsIr> for InterpolateOptions {\n    fn from(val: InterpolateOptionsIr) -> Self {\n        Self::new(val.mode.into()).with_align_corners(val.align_corners)\n    }\n}\n\nimpl From<InterpolateMode> for InterpolateModeIr {\n    fn from(val: InterpolateMode) -> Self {\n        match val {\n            InterpolateMode::Nearest => Self::Nearest,\n            InterpolateMode::Bilinear => Self::Bilinear,\n            InterpolateMode::Bicubic => Self::Bicubic,\n            InterpolateMode::Lanczos3 => Self::Lanczos3,\n        }\n    }\n}\n\nimpl From<InterpolateOptions> for InterpolateOptionsIr {\n    fn from(val: InterpolateOptions) -> Self {\n        Self {\n            mode: val.mode.into(),\n            align_corners: val.align_corners,\n        }\n    }\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct InterpolateBackwardOpIr {\n    pub x: TensorIr,\n    pub grad: TensorIr,\n    pub output_size: [usize; 2],\n    pub options: InterpolateOptionsIr,\n    pub out: TensorIr,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub enum GridSamplePaddingModeIr {\n    Zeros,\n    Border,\n    Reflection,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct GridSampleOptionsIr {\n    pub mode: InterpolateModeIr,\n    pub padding_mode: GridSamplePaddingModeIr,\n    pub align_corners: bool,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub struct GridSample2dOpIr {\n    pub tensor: TensorIr,\n    pub grid: TensorIr,\n    pub options: GridSampleOptionsIr,\n    pub out: TensorIr,\n}\n\nimpl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {\n    fn from(val: GridSamplePaddingModeIr) -> Self {\n        match val {\n            GridSamplePaddingModeIr::Zeros => Self::Zeros,\n            GridSamplePaddingModeIr::Border => Self::Border,\n            GridSamplePaddingModeIr::Reflection => Self::Reflection,\n        }\n    }\n}\n\nimpl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {\n    fn from(val: GridSamplePaddingMode) -> Self {\n        match val {\n            GridSamplePaddingMode::Zeros => Self::Zeros,\n            GridSamplePaddingMode::Border => Self::Border,\n            GridSamplePaddingMode::Reflection => Self::Reflection,\n        }\n    }\n}\n\nimpl From<GridSampleOptionsIr> for GridSampleOptions {\n    fn from(val: GridSampleOptionsIr) -> Self {\n        Self {\n            mode: val.mode.into(),\n            padding_mode: val.padding_mode.into(),\n            align_corners: val.align_corners,\n        }\n    }\n}\n\nimpl From<GridSampleOptions> for GridSampleOptionsIr {\n    fn from(val: GridSampleOptions) -> Self {\n        Self {\n            mode: val.mode.into(),\n            padding_mode: val.padding_mode.into(),\n            align_corners: val.align_corners,\n        }\n    }\n}\n\nimpl OperationIr {\n    /// Get all input [tensors](TensorIr) involved with the current operation.\n    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {\n        match self {\n            OperationIr::BaseFloat(repr) => repr.inputs(),\n            OperationIr::BaseInt(repr) => repr.inputs(),\n            OperationIr::BaseBool(repr) => repr.inputs(),\n            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),\n            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),\n            OperationIr::Bool(repr) => repr.inputs(),\n            OperationIr::Int(repr) => repr.inputs(),\n            OperationIr::Float(_dtype, repr) => repr.inputs(),\n            OperationIr::Module(repr) => repr.inputs(),\n            OperationIr::Init(repr) => repr.inputs(),\n            OperationIr::Custom(repr) => repr.inputs(),\n            OperationIr::Drop(repr) => Box::new([repr].into_iter()),\n        }\n    }\n\n    /// Get all output [tensors](TensorIr) involved with the current operation.\n    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {\n        match self {\n            OperationIr::BaseFloat(repr) => repr.outputs(),\n            OperationIr::BaseInt(repr) => repr.outputs(),\n            OperationIr::BaseBool(repr) => repr.outputs(),\n            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),\n            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),\n            OperationIr::Bool(repr) => repr.outputs(),\n            OperationIr::Int(repr) => repr.outputs(),\n            OperationIr::Float(_dtype, repr) => repr.outputs(),\n            OperationIr::Module(repr) => repr.outputs(),\n            OperationIr::Init(repr) => repr.outputs(),\n            OperationIr::Custom(repr) => repr.outputs(),\n            OperationIr::Drop(_repr) => Box::new([].into_iter()),\n        }\n    }\n\n    /// Get all [tensor](TensorIr) involved with the current operation.\n    pub fn nodes(&self) -> Vec<&TensorIr> {\n        self.inputs().chain(self.outputs()).collect()\n    }\n\n    /// Set the given nodes that are [read write](super::TensorStatus::ReadWrite) to\n    /// [read only](super::TensorStatus::ReadOnly) in the current operation.\n    ///\n    /// Returns the tensor that were updated with their original representation.\n    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        match self {\n            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),\n            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),\n            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),\n            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),\n            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),\n            OperationIr::Bool(repr) => repr.mark_read_only(nodes),\n            OperationIr::Int(repr) => repr.mark_read_only(nodes),\n            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),\n            OperationIr::Module(repr) => repr.mark_read_only(nodes),\n            OperationIr::Init(_) => Vec::new(),\n            OperationIr::Drop(repr) => {\n                let mut output = Vec::new();\n                repr.mark_read_only(nodes, &mut output);\n                output\n            }\n            OperationIr::Custom(repr) => {\n                let mut output = Vec::new();\n\n                for input in repr.inputs.iter_mut() {\n                    input.mark_read_only(nodes, &mut output);\n                }\n\n                output\n            }\n        }\n    }\n}\n\nimpl BaseOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),\n            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),\n            BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),\n            BaseOperationIr::Scatter(repr) => {\n                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())\n            }\n            BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),\n            BaseOperationIr::SelectAssign(repr) => {\n                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())\n            }\n            BaseOperationIr::MaskWhere(repr) => {\n                Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())\n            }\n            BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),\n            BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),\n            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),\n            BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),\n            BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),\n            BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),\n            BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),\n            BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),\n        }\n    }\n\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),\n            BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            BaseOperationIr::Reshape(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::SwapDims(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Permute(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n\n            BaseOperationIr::Expand(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n\n            BaseOperationIr::Flip(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Slice(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::SliceAssign(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.value.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Gather(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Scatter(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n                repr.value.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Select(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::SelectAssign(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n                repr.value.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::MaskWhere(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.mask.mark_read_only(nodes, &mut output);\n                repr.value.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::MaskFill(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.mask.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Equal(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::EqualElem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::RepeatDim(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Cat(repr) => {\n                for t in repr.tensors.iter_mut() {\n                    t.mark_read_only(nodes, &mut output);\n                }\n            }\n            BaseOperationIr::Cast(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Unfold(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BaseOperationIr::Empty(_) => {}\n            BaseOperationIr::Zeros(_) => {}\n            BaseOperationIr::Ones(_) => {}\n        };\n\n        output\n    }\n}\n\nimpl NumericOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),\n            NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),\n            NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Full(_repr) => Box::new([].into_iter()),\n            NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),\n            NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),\n            NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),\n            NumericOperationIr::Powi(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            NumericOperationIr::CumMin(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::CumMax(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::CumProd(repr) => Box::new([&repr.input].into_iter()),\n            NumericOperationIr::CumSum(repr) => Box::new([&repr.input].into_iter()),\n        }\n    }\n\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MaxDimWithIndices(repr) => {\n                Box::new([&repr.out, &repr.out_indices].into_iter())\n            }\n            NumericOperationIr::MinDimWithIndices(repr) => {\n                Box::new([&repr.out, &repr.out_indices].into_iter())\n            }\n            NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::Powi(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),\n            NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            NumericOperationIr::Add(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::AddScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Sub(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::SubScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Mul(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MulScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Div(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::DivScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Rem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::RemScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::GreaterElem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::GreaterEqualElem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::LowerElem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::LowerEqualElem(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Greater(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::GreaterEqual(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Lower(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::LowerEqual(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::ArgMax(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::ArgMin(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Clamp(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Abs(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Full(_) => {}\n            NumericOperationIr::MeanDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Mean(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Sum(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::SumDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Prod(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::ProdDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Max(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MaxDimWithIndices(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MinDimWithIndices(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::Min(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MaxDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MinDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MaxAbs(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::MaxAbsDim(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::IntRandom(_) => {}\n            NumericOperationIr::Powi(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::CumSum(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::CumProd(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::CumMin(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            NumericOperationIr::CumMax(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n        };\n\n        output\n    }\n}\n\nimpl FloatOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            FloatOperationIr::Random(_repr) => Box::new([].into_iter()),\n            FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Quantize(repr) => {\n                Box::new([&repr.tensor, &repr.qparams.scales].into_iter())\n            }\n            FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::GridSample2d(repr) => {\n                Box::new([&repr.tensor, &repr.grid].into_iter())\n            }\n            FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),\n            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            FloatOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n        }\n    }\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),\n            FloatOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            FloatOperationIr::Matmul(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Cross(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Random(_) => {}\n            FloatOperationIr::Exp(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Log(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Log1p(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Erf(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Recip(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::PowfScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Sqrt(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Cos(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Sin(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Tanh(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Round(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Floor(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Ceil(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Trunc(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Quantize(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.qparams.scales.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Dequantize(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::IntoInt(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::IsNan(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::IsInf(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::GridSample2d(repr) => {\n                repr.tensor.mark_read_only(nodes, &mut output);\n                repr.grid.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),\n            FloatOperationIr::ArcTan2(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            FloatOperationIr::Powf(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n        };\n\n        output\n    }\n}\n\nimpl IntOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),\n            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),\n            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),\n            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),\n        }\n    }\n\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),\n            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            IntOperationIr::Matmul(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::IntoFloat(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseAnd(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseAndScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseOr(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseOrScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseXor(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseXorScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseNot(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseLeftShift(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseLeftShiftScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseRightShift(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            IntOperationIr::BitwiseRightShiftScalar(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n            }\n        };\n\n        output\n    }\n}\n\nimpl BoolOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),\n            BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),\n            BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),\n            BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n            BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),\n        }\n    }\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),\n            BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),\n            BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),\n            BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),\n            BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            BoolOperationIr::IntoFloat(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BoolOperationIr::IntoInt(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BoolOperationIr::Not(repr) => {\n                repr.input.mark_read_only(nodes, &mut output);\n            }\n            BoolOperationIr::And(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n            BoolOperationIr::Or(repr) => {\n                repr.lhs.mark_read_only(nodes, &mut output);\n                repr.rhs.mark_read_only(nodes, &mut output);\n            }\n        };\n\n        output\n    }\n}\n\nimpl ModuleOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            ModuleOperationIr::Embedding(repr) => {\n                Box::new([&repr.weights, &repr.indices].into_iter())\n            }\n            ModuleOperationIr::EmbeddingBackward(repr) => {\n                Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())\n            }\n            ModuleOperationIr::Conv1d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::Conv1dXBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv1dWeightBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv1dBiasBackward(repr) => {\n                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv2d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::Conv2dXBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv2dWeightBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv2dBiasBackward(repr) => {\n                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv3d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::Conv3dXBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv3dWeightBackward(repr) => {\n                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::Conv3dBiasBackward(repr) => {\n                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())\n            }\n            ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {\n                (Some(mask), Some(bias)) => {\n                    Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())\n                }\n                (Some(mask), None) => {\n                    Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())\n                }\n                (None, Some(bias)) => {\n                    Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())\n                }\n                (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),\n            },\n            ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {\n                (Some(mask), Some(bias)) => Box::new(\n                    [\n                        &repr.x,\n                        &repr.offset,\n                        &repr.weight,\n                        &repr.out_grad,\n                        mask,\n                        bias,\n                    ]\n                    .into_iter(),\n                ),\n                (Some(mask), None) => Box::new(\n                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),\n                ),\n                (None, Some(bias)) => Box::new(\n                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),\n                ),\n                (None, None) => {\n                    Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())\n                }\n            },\n            ModuleOperationIr::ConvTranspose1d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::ConvTranspose2d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::ConvTranspose3d(repr) => {\n                if let Some(bias) = &repr.bias {\n                    Box::new([&repr.x, &repr.weight, bias].into_iter())\n                } else {\n                    Box::new([&repr.x, &repr.weight].into_iter())\n                }\n            }\n            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::AvgPool1dBackward(repr) => {\n                Box::new([&repr.x, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::AvgPool2dBackward(repr) => {\n                Box::new([&repr.x, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {\n                Box::new([&repr.x, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {\n                Box::new([&repr.x, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {\n                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {\n                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),\n            ModuleOperationIr::InterpolateBackward(repr) => {\n                Box::new([&repr.x, &repr.grad].into_iter())\n            }\n            ModuleOperationIr::Attention(repr) => {\n                if let Some(mask) = &repr.mask {\n                    if let Some(attn_bias) = &repr.attn_bias {\n                        Box::new([&repr.query, &repr.key, &repr.value, mask, attn_bias].into_iter())\n                    } else {\n                        Box::new([&repr.query, &repr.key, &repr.value, mask].into_iter())\n                    }\n                } else if let Some(attn_bias) = &repr.attn_bias {\n                    Box::new([&repr.query, &repr.key, &repr.value, attn_bias].into_iter())\n                } else {\n                    Box::new([&repr.query, &repr.key, &repr.value].into_iter())\n                }\n            }\n        }\n    }\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        match self {\n            ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv1dXBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv1dWeightBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv1dBiasBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv2dXBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv2dWeightBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv2dBiasBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv3dXBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv3dWeightBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Conv3dBiasBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::DeformableConv2dBackward(repr) => {\n                match (&repr.mask_grad, &repr.bias_grad) {\n                    (Some(mask_grad), Some(bias_grad)) => Box::new(\n                        [\n                            &repr.input_grad,\n                            &repr.offset_grad,\n                            &repr.weight_grad,\n                            mask_grad,\n                            bias_grad,\n                        ]\n                        .into_iter(),\n                    ),\n                    (Some(mask_grad), None) => Box::new(\n                        [\n                            &repr.input_grad,\n                            &repr.offset_grad,\n                            &repr.weight_grad,\n                            mask_grad,\n                        ]\n                        .into_iter(),\n                    ),\n                    (None, Some(bias_grad)) => Box::new(\n                        [\n                            &repr.input_grad,\n                            &repr.offset_grad,\n                            &repr.weight_grad,\n                            bias_grad,\n                        ]\n                        .into_iter(),\n                    ),\n                    (None, None) => Box::new(\n                        [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),\n                    ),\n                }\n            }\n            ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::MaxPool1dWithIndices(repr) => {\n                Box::new([&repr.out, &repr.out_indices].into_iter())\n            }\n            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {\n                Box::new([&repr.out].into_iter())\n            }\n            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::MaxPool2dWithIndices(repr) => {\n                Box::new([&repr.out, &repr.out_indices].into_iter())\n            }\n            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {\n                Box::new([&repr.out].into_iter())\n            }\n            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),\n            ModuleOperationIr::Attention(repr) => Box::new([&repr.out].into_iter()),\n        }\n    }\n\n    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {\n        let mut output = Vec::new();\n\n        match self {\n            ModuleOperationIr::Embedding(repr) => {\n                repr.weights.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::EmbeddingBackward(repr) => {\n                repr.weights.mark_read_only(nodes, &mut output);\n                repr.out_grad.mark_read_only(nodes, &mut output);\n                repr.indices.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv1d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::Conv1dXBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv1dWeightBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv1dBiasBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.bias.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::Conv2dXBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv2dWeightBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv2dBiasBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.bias.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv3d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::Conv3dXBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv3dWeightBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Conv3dBiasBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.bias.mark_read_only(nodes, &mut output);\n                repr.output_grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::DeformableConv2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.offset.mark_read_only(nodes, &mut output);\n\n                match (&mut repr.mask, &mut repr.bias) {\n                    (Some(mask), Some(bias)) => {\n                        mask.mark_read_only(nodes, &mut output);\n                        bias.mark_read_only(nodes, &mut output);\n                    }\n                    (Some(mask), None) => {\n                        mask.mark_read_only(nodes, &mut output);\n                    }\n                    (None, Some(bias)) => {\n                        bias.mark_read_only(nodes, &mut output);\n                    }\n                    (None, None) => {}\n                };\n            }\n            ModuleOperationIr::DeformableConv2dBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n                repr.offset.mark_read_only(nodes, &mut output);\n                repr.out_grad.mark_read_only(nodes, &mut output);\n\n                if let Some(mask) = repr.mask.as_mut() {\n                    mask.mark_read_only(nodes, &mut output);\n                }\n                if let Some(bias) = repr.bias.as_mut() {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::ConvTranspose1d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::ConvTranspose2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::ConvTranspose3d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.weight.mark_read_only(nodes, &mut output);\n\n                if let Some(bias) = &mut repr.bias {\n                    bias.mark_read_only(nodes, &mut output);\n                }\n            }\n            ModuleOperationIr::AvgPool1d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AvgPool2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AvgPool1dBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AvgPool2dBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AdaptiveAvgPool1d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AdaptiveAvgPool2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool1d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool1dWithIndices(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool2d(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool2dWithIndices(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Interpolate(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::InterpolateBackward(repr) => {\n                repr.x.mark_read_only(nodes, &mut output);\n                repr.grad.mark_read_only(nodes, &mut output);\n            }\n            ModuleOperationIr::Attention(repr) => {\n                repr.query.mark_read_only(nodes, &mut output);\n                repr.key.mark_read_only(nodes, &mut output);\n                repr.value.mark_read_only(nodes, &mut output);\n                if let Some(mask) = &mut repr.mask {\n                    mask.mark_read_only(nodes, &mut output);\n                }\n                if let Some(attn_bias) = &mut repr.attn_bias {\n                    attn_bias.mark_read_only(nodes, &mut output);\n                }\n            }\n        };\n\n        output\n    }\n}\n\nimpl InitOperationIr {\n    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        Box::new([].into_iter())\n    }\n    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {\n        Box::new([&self.out].into_iter())\n    }\n}\n\nimpl TensorIr {\n    fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {\n        if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {\n            output.push(self.clone());\n            self.status = TensorStatus::ReadOnly;\n        }\n    }\n}\n\nimpl core::hash::Hash for RandomOpIr {\n    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {\n        self.out.hash(state);\n\n        match self.distribution {\n            Distribution::Default => 1u8.hash(state),\n            Distribution::Bernoulli(_) => 2u8.hash(state),\n            Distribution::Uniform(_, _) => 3u8.hash(state),\n            Distribution::Normal(_, _) => 4u8.hash(state),\n        }\n    }\n}\n\n/// Extension trait to extract outputs when registering an operation.\npub trait OperationOutput<O> {\n    /// Extract a single output.\n    fn output(self) -> O;\n\n    /// Extract a fixed number of outputs.\n    fn outputs<const N: usize>(self) -> [O; N];\n}\n\nimpl<O: core::fmt::Debug> OperationOutput<O> for Vec<O> {\n    fn output(self) -> O {\n        let [tensor] = self.outputs();\n        tensor\n    }\n\n    fn outputs<const N: usize>(self) -> [O; N] {\n        self.try_into().unwrap()\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/src/scalar.rs",
    "content": "use burn_backend::{DType, Scalar};\nuse burn_backend::{Element, ElementConversion};\nuse core::hash::Hash;\nuse serde::{Deserialize, Serialize};\n\n/// A scalar representation.\n#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]\n#[allow(missing_docs)]\npub enum ScalarIr {\n    Float(f64),\n    Int(i64),\n    UInt(u64),\n    Bool(bool),\n}\n\nimpl Hash for ScalarIr {\n    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {\n        match self {\n            ScalarIr::Float(x) => x.to_bits().hash(state),\n            ScalarIr::Int(x) => x.hash(state),\n            ScalarIr::UInt(x) => x.hash(state),\n            ScalarIr::Bool(x) => x.hash(state),\n        }\n    }\n}\n\nimpl ScalarIr {\n    /// Creates a scalar with the specified data type.\n    pub fn new<E: ElementConversion>(value: E, dtype: &DType) -> Self {\n        if dtype.is_float() {\n            Self::Float(value.elem())\n        } else if dtype.is_int() {\n            Self::Int(value.elem())\n        } else if dtype.is_uint() {\n            Self::UInt(value.elem())\n        } else if dtype.is_bool() {\n            Self::Bool(value.elem())\n        } else {\n            unimplemented!(\"Scalar not supported for {dtype:?}\")\n        }\n    }\n\n    /// Converts and returns the converted element.\n    pub fn elem<E: Element>(self) -> E {\n        match self {\n            ScalarIr::Float(x) => x.elem(),\n            ScalarIr::Int(x) => x.elem(),\n            ScalarIr::UInt(x) => x.elem(),\n            ScalarIr::Bool(x) => x.elem(),\n        }\n    }\n}\n\n// The enums are similar, but both types have different roles:\n// - `Scalar`: runtime literal value\n// - `ScalarIr`: serializable literal representation (used for IR)\nimpl From<Scalar> for ScalarIr {\n    fn from(value: Scalar) -> Self {\n        match value {\n            Scalar::Float(x) => Self::Float(x),\n            Scalar::Int(x) => Self::Int(x),\n            Scalar::UInt(x) => Self::UInt(x),\n            Scalar::Bool(x) => Self::Bool(x),\n        }\n    }\n}\n\nimpl From<ScalarIr> for Scalar {\n    fn from(value: ScalarIr) -> Self {\n        match value {\n            ScalarIr::Float(x) => Self::Float(x),\n            ScalarIr::Int(x) => Self::Int(x),\n            ScalarIr::UInt(x) => Self::UInt(x),\n            ScalarIr::Bool(x) => Self::Bool(x),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ir/src/tensor.rs",
    "content": "use serde::{Deserialize, Serialize};\n\nuse burn_backend::{DType, Shape};\n\n/// The tensor unique identifier.\n#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]\npub struct TensorId {\n    value: u64,\n}\n\nimpl core::fmt::Display for TensorId {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_fmt(format_args!(\"TensorId({:?})\", self.value))\n    }\n}\n\n/// The status of the current tensor.\n#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]\npub enum TensorStatus {\n    /// The tensor can be read, but not written.\n    ReadOnly,\n    /// The tensor can be mutated inplace.\n    ReadWrite,\n    /// No handle exists for that tensor.\n    NotInit,\n}\n\n/// A tensor definition represents a snapshot of a tensor when it was used.\n///\n/// # Example\n///\n/// A tensor that is used multiple times has its status updated for each operation.\n///\n///   1. Status::NotInit\n///   2. Status::ReadOnly\n///   3. Status::ReadOnly\n///   4. Status::ReadWrite\n#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]\npub struct TensorIr {\n    /// The [tensor id](TensorId).\n    pub id: TensorId,\n    /// The shape of the tensor.\n    pub shape: Shape,\n    /// The [status](TensorStatus) of the tensor when it was used.\n    pub status: TensorStatus,\n    /// The [type](DType) of the tensor.\n    pub dtype: DType,\n}\n\nimpl TensorId {\n    /// Create a new tensor id.\n    pub fn new(value: u64) -> Self {\n        Self { value }\n    }\n}\n\nimpl TensorIr {\n    /// Create a new tensor that is not already initialized.\n    pub fn uninit(id: TensorId, shape: Shape, dtype: DType) -> Self {\n        Self {\n            id,\n            status: TensorStatus::NotInit,\n            shape,\n            dtype,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Ndarray backend for the Burn framework\"\ndocumentation = \"https://docs.rs/burn-ndarray\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-ndarray\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-ndarray\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\nblas-accelerate = [\n    \"blas-src/accelerate\", # Accelerate framework (macOS only)\n    \"ndarray/blas\",\n]\nblas-netlib = [\"blas-src/netlib\", \"ndarray/blas\"]\nblas-openblas = [\"blas-src/openblas\", \"ndarray/blas\", \"openblas-src\"]\nblas-openblas-system = [\n    \"blas-src/openblas\",\n    \"ndarray/blas\",\n    \"openblas-src/system\",\n]\ndefault = [\"std\", \"simd\", \"multi-threads\"]\ndoc = [\"default\"]\nmulti-threads = [\n    \"rayon\",\n    \"ndarray/rayon\",\n    \"matrixmultiply/threading\",\n]\nsimd = [\"macerator\", \"bytemuck\", \"seq-macro\", \"itertools\"]\nstd = [\n    \"burn-autodiff\",\n    \"burn-std/std\",\n    \"burn-backend/std\",\n    \"burn-ir/std\",\n    \"ndarray/std\",\n    \"matrixmultiply/std\",\n    \"rand/std\",\n    \"rand/std_rng\",\n    \"num-traits/std\",\n    \"macerator/std\",\n]\ntracing = [\n    \"burn-autodiff?/tracing\",\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n    \"burn-ir/tracing\",\n]\n\n# Serves as a ref impl for some burn-cubecl kernels\nexport_tests = []\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std when std is disabled **\n\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n\natomic_float = { workspace = true }\nblas-src = { workspace = true, default-features = false, optional = true }      # no-std compatible\nconst-random = { workspace = true }\nlibm = { workspace = true }\nmatrixmultiply = { workspace = true, default-features = false }\nndarray = { workspace = true }\nnum-traits = { workspace = true }\nopenblas-src = { workspace = true, optional = true }\npaste = { workspace = true }\nrand = { workspace = true, default-features = false }\n\n# SIMD\nbytemuck = { workspace = true, optional = true }\nitertools = { version = \"0.14\", optional = true }\nmacerator = { workspace = true, optional = true }\nseq-macro = { version = \"0.3\", optional = true }\n\n# Parallel\nrayon = { workspace = true, optional = true }\n\n[target.'cfg(not(target_has_atomic = \"ptr\"))'.dependencies]\nportable-atomic = { workspace = true }\nportable-atomic-util = { workspace = true }\n\n[dev-dependencies]\nbytes = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-ndarray/README.md",
    "content": "# Burn NdArray\n\n> [Burn](https://github.com/tracel-ai/burn) ndarray backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-ndarray.svg)](https://crates.io/crates/burn-ndarray)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-ndarray/blob/master/README.md)\n\n## Feature Flags\n\nThis crate can be used without the standard library (`#![no_std]`) with `alloc` by disabling the\ndefault `std` feature.\n\nThe following flags support various BLAS options:\n\n- `blas-accelerate` - Accelerate framework (macOS only)\n- `blas-netlib` - Netlib\n- `blas-openblas` - OpenBLAS static linked\n- `blas-openblas-system` - OpenBLAS from the system\n\nNote: under the `no_std` mode, the seed is fixed if the seed is not\ninitialized by `Backend::seed` method.\n\n### Platform Support\n\n| Option     | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |\n| :--------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |\n| Pure Rust  | Yes | No  |  Yes  |  Yes  |   Yes   |   Yes   | Yes | Yes  |\n| Accelerate | Yes | No  |  No   |  Yes  |   No    |   No    | Yes |  No  |\n| Netlib     | Yes | No  |  Yes  |  Yes  |   Yes   |   No    | No  |  No  |\n| Openblas   | Yes | No  |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |\n"
  },
  {
    "path": "crates/burn-ndarray/build.rs",
    "content": "fn main() {\n    // https://github.com/rust-ndarray/ndarray/issues/1197\n    if cfg!(feature = \"blas-accelerate\") {\n        println!(\"cargo:rustc-link-lib=framework=Accelerate\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/backend.rs",
    "content": "use crate::rand::NdArrayRng;\nuse crate::{NdArrayQTensor, NdArrayTensor};\nuse crate::{\n    SharedArray,\n    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},\n};\nuse alloc::string::String;\nuse burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};\nuse burn_backend::{Backend, DType, DeviceId, DeviceOps};\nuse burn_ir::{BackendIr, HandleKind, TensorHandle};\nuse burn_std::BoolStore;\nuse burn_std::stub::Mutex;\nuse core::marker::PhantomData;\nuse rand::SeedableRng;\n\npub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);\n\n/// The device type for the ndarray backend.\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]\npub enum NdArrayDevice {\n    /// The CPU device.\n    #[default]\n    Cpu,\n}\n\nimpl DeviceOps for NdArrayDevice {}\n\nimpl burn_backend::Device for NdArrayDevice {\n    fn from_id(_device_id: DeviceId) -> Self {\n        Self::Cpu\n    }\n\n    fn to_id(&self) -> DeviceId {\n        DeviceId {\n            type_id: 0,\n            index_id: 0,\n        }\n    }\n\n    fn device_count(_type_id: u16) -> usize {\n        1\n    }\n}\n\n/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.\n///\n/// This backend is compatible with CPUs and can be compiled for almost any platform, including\n/// `wasm`, `arm`, and `x86`.\n#[derive(Clone, Copy, Default, Debug)]\npub struct NdArray<E = f32, I = i64, Q = i8>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    _e: PhantomData<E>,\n    _i: PhantomData<I>,\n    _q: PhantomData<Q>,\n}\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    type Device = NdArrayDevice;\n\n    type FloatTensorPrimitive = NdArrayTensor;\n    type FloatElem = E;\n\n    type IntTensorPrimitive = NdArrayTensor;\n    type IntElem = I;\n\n    type BoolTensorPrimitive = NdArrayTensor;\n    type BoolElem = bool;\n\n    type QuantizedTensorPrimitive = NdArrayQTensor;\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    fn name(_device: &Self::Device) -> String {\n        String::from(\"ndarray\")\n    }\n\n    fn seed(_device: &Self::Device, seed: u64) {\n        let rng = NdArrayRng::seed_from_u64(seed);\n        let mut seed = SEED.lock().unwrap();\n        *seed = Some(rng);\n    }\n\n    fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {\n        match dtype {\n            DType::F64\n            | DType::F32\n            | DType::Flex32\n            | DType::I64\n            | DType::I32\n            | DType::I16\n            | DType::I8\n            | DType::U64\n            | DType::U32\n            | DType::U16\n            | DType::U8\n            | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(),\n            DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(),\n            DType::QFloat(scheme) => {\n                match scheme {\n                    QuantScheme {\n                        level: QuantLevel::Tensor | QuantLevel::Block(_),\n                        mode: QuantMode::Symmetric,\n                        #[cfg(not(feature = \"export_tests\"))]\n                            value: QuantValue::Q8F | QuantValue::Q8S,\n                        // For tests, \"native\" sub-byte quant serves as a reference for value equality.\n                        // Values are stored as i8 regardless.\n                        #[cfg(feature = \"export_tests\")]\n                            value:\n                            QuantValue::Q8F\n                            | QuantValue::Q8S\n                            | QuantValue::Q4F\n                            | QuantValue::Q4S\n                            | QuantValue::Q2F\n                            | QuantValue::Q2S,\n                        store: QuantStore::Native,\n                        ..\n                    } => burn_backend::DTypeUsage::general(),\n                    _scheme => burn_backend::DTypeUsageSet::empty(),\n                }\n            }\n        }\n    }\n}\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    type Handle = HandleKind<Self>;\n\n    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {\n        match handle.handle {\n            HandleKind::Float(handle) => handle,\n            _ => panic!(\"Expected float handle, got {}\", handle.handle.name()),\n        }\n    }\n\n    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {\n        match handle.handle {\n            HandleKind::Int(handle) => handle,\n            _ => panic!(\"Expected int handle, got {}\", handle.handle.name()),\n        }\n    }\n\n    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {\n        match handle.handle {\n            HandleKind::Bool(handle) => handle,\n            _ => panic!(\"Expected bool handle, got {}\", handle.handle.name()),\n        }\n    }\n\n    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {\n        match handle.handle {\n            HandleKind::Quantized(handle) => handle,\n            _ => panic!(\"Expected quantized handle, got {}\", handle.handle.name()),\n        }\n    }\n\n    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {\n        HandleKind::Float(tensor)\n    }\n\n    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {\n        HandleKind::Int(tensor)\n    }\n\n    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {\n        HandleKind::Bool(tensor)\n    }\n\n    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {\n        HandleKind::Quantized(tensor)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_backend::QTensorPrimitive;\n\n    #[test]\n    fn should_support_dtypes() {\n        type B = NdArray<f32>;\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F64));\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::Flex32));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::I16));\n        assert!(B::supports_dtype(&device, DType::I8));\n        assert!(B::supports_dtype(&device, DType::U64));\n        assert!(B::supports_dtype(&device, DType::U32));\n        assert!(B::supports_dtype(&device, DType::U16));\n        assert!(B::supports_dtype(&device, DType::U8));\n        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n        assert!(B::supports_dtype(\n            &device,\n            DType::QFloat(NdArrayQTensor::default_scheme())\n        ));\n\n        assert!(!B::supports_dtype(&device, DType::F16));\n        assert!(!B::supports_dtype(&device, DType::BF16));\n        // QuantStore::U32 not supported\n        assert!(!B::supports_dtype(\n            &device,\n            DType::QFloat(QuantScheme::default())\n        ));\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/element.rs",
    "content": "use burn_backend::Element;\nuse num_traits::Signed;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse num_traits::Pow;\n\nuse libm::{log1p, log1pf};\n\n/// A float element for ndarray backend.\npub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd<Self>\nwhere\n    Self: Sized,\n{\n}\n\n/// An int element for ndarray backend.\npub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd<Self> {}\n\n/// A general element for ndarray backend.\npub trait NdArrayElement:\n    Element\n    + ndarray::LinalgScalar\n    + ndarray::ScalarOperand\n    + ExpElement\n    + AddAssignElement\n    + num_traits::FromPrimitive\n    + core::ops::AddAssign\n    + core::cmp::PartialEq\n    + core::ops::Rem<Output = Self>\n{\n}\n\n/// A element for ndarray backend that supports exp ops.\npub trait ExpElement {\n    /// Exponent\n    fn exp_elem(self) -> Self;\n    /// Log\n    fn log_elem(self) -> Self;\n    /// Log1p\n    fn log1p_elem(self) -> Self;\n    /// Powf\n    fn powf_elem(self, value: f32) -> Self;\n    /// Powi\n    fn powi_elem(self, value: i32) -> Self;\n    /// Sqrt\n    fn sqrt_elem(self) -> Self;\n    /// Abs\n    fn abs_elem(self) -> Self;\n}\n\n/// The addition assignment operator implemented for ndarray elements.\npub trait AddAssignElement<Rhs = Self> {\n    /// Performs the addition assignment operation.\n    ///\n    /// For `bool`, this corresponds to logical OR assignment.\n    fn add_assign(&mut self, rhs: Rhs);\n}\n\nimpl<E: NdArrayElement> AddAssignElement for E {\n    fn add_assign(&mut self, rhs: Self) {\n        *self += rhs;\n    }\n}\n\nimpl AddAssignElement for bool {\n    fn add_assign(&mut self, rhs: Self) {\n        *self = *self || rhs; // logical OR for bool\n    }\n}\n\n/// A quantized element for the ndarray backend.\npub trait QuantElement: NdArrayElement {}\n\nimpl QuantElement for i8 {}\n\nimpl FloatNdArrayElement for f64 {}\nimpl FloatNdArrayElement for f32 {}\n\nimpl IntNdArrayElement for i64 {}\nimpl IntNdArrayElement for i32 {}\nimpl IntNdArrayElement for i16 {}\nimpl IntNdArrayElement for i8 {}\n\nimpl IntNdArrayElement for u64 {}\nimpl IntNdArrayElement for u32 {}\nimpl IntNdArrayElement for u16 {}\nimpl IntNdArrayElement for u8 {}\n\nmacro_rules! make_float {\n    (\n        $ty:ty,\n        $log1p:expr\n    ) => {\n        impl NdArrayElement for $ty {}\n\n        #[allow(clippy::cast_abs_to_unsigned)]\n        impl ExpElement for $ty {\n            #[inline(always)]\n            fn exp_elem(self) -> Self {\n                self.exp()\n            }\n\n            #[inline(always)]\n            fn log_elem(self) -> Self {\n                self.ln()\n            }\n\n            #[inline(always)]\n            fn log1p_elem(self) -> Self {\n                $log1p(self)\n            }\n\n            #[inline(always)]\n            fn powf_elem(self, value: f32) -> Self {\n                self.pow(value)\n            }\n\n            #[inline(always)]\n            fn powi_elem(self, value: i32) -> Self {\n                #[cfg(feature = \"std\")]\n                let val = self.powi(value);\n\n                #[cfg(not(feature = \"std\"))]\n                let val = Self::powf_elem(self, value as f32);\n\n                val\n            }\n\n            #[inline(always)]\n            fn sqrt_elem(self) -> Self {\n                self.sqrt()\n            }\n\n            #[inline(always)]\n            fn abs_elem(self) -> Self {\n                self.abs()\n            }\n        }\n    };\n}\nmacro_rules! make_int {\n    (\n        $ty:ty,\n        $abs:expr\n    ) => {\n        impl NdArrayElement for $ty {}\n\n        #[allow(clippy::cast_abs_to_unsigned)]\n        impl ExpElement for $ty {\n            #[inline(always)]\n            fn exp_elem(self) -> Self {\n                (self as f32).exp() as $ty\n            }\n\n            #[inline(always)]\n            fn log_elem(self) -> Self {\n                (self as f32).ln() as $ty\n            }\n\n            #[inline(always)]\n            fn log1p_elem(self) -> Self {\n                log1pf(self as f32) as $ty\n            }\n\n            #[inline(always)]\n            fn powf_elem(self, value: f32) -> Self {\n                (self as f32).pow(value) as $ty\n            }\n\n            #[inline(always)]\n            fn powi_elem(self, value: i32) -> Self {\n                #[cfg(feature = \"std\")]\n                let val = f32::powi(self as f32, value) as $ty;\n\n                #[cfg(not(feature = \"std\"))]\n                let val = Self::powf_elem(self, value as f32);\n\n                val\n            }\n\n            #[inline(always)]\n            fn sqrt_elem(self) -> Self {\n                (self as f32).sqrt() as $ty\n            }\n\n            #[inline(always)]\n            fn abs_elem(self) -> Self {\n                $abs(self)\n            }\n        }\n    };\n}\n\nmake_float!(f64, log1p);\nmake_float!(f32, log1pf);\n\nmake_int!(i64, i64::wrapping_abs);\nmake_int!(i32, i32::wrapping_abs);\nmake_int!(i16, i16::wrapping_abs);\nmake_int!(i8, i8::wrapping_abs);\nmake_int!(u64, |x| x);\nmake_int!(u32, |x| x);\nmake_int!(u16, |x| x);\nmake_int!(u8, |x| x);\n"
  },
  {
    "path": "crates/burn-ndarray/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! Burn ndarray backend.\n\n#[cfg(any(\n    feature = \"blas-netlib\",\n    feature = \"blas-openblas\",\n    feature = \"blas-openblas-system\",\n))]\nextern crate blas_src;\n\nmod backend;\nmod element;\nmod ops;\nmod parallel;\nmod rand;\nmod sharing;\nmod storage;\nmod tensor;\n\npub use backend::*;\npub use element::*;\npub(crate) use sharing::*;\npub(crate) use storage::*;\npub use tensor::*;\n\nextern crate alloc;\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/activation.rs",
    "content": "use crate::{\n    NdArray, NdArrayTensor, SharedArray,\n    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},\n    execute_with_numeric_dtype,\n    ops::NdArrayMathOps,\n};\nuse burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor};\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ActivationOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn relu(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem()))\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/adaptive_avgpool.rs",
    "content": "use crate::{\n    SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,\n};\nuse burn_backend::ElementConversion;\nuse ndarray::Array4;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\npub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    output_size: [usize; 2],\n) -> SharedArray<E> {\n    let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap();\n\n    let mut output = Array4::from_elem(\n        (batch_size, channels, output_size[0], output_size[1]),\n        0.elem(),\n    );\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output = unsafe_shared_out.get();\n            for h in 0..output_size[0] {\n                for w in 0..output_size[1] {\n                    let ih_start = start_index(h, output_size[0], input_height);\n                    let ih_end = end_index(h, output_size[0], input_height);\n                    let iw_start = start_index(w, output_size[1], input_width);\n                    let iw_end = end_index(w, output_size[1], input_width);\n\n                    let mut sum_val: E = 0.elem();\n\n                    for ih in ih_start..ih_end {\n                        for iw in iw_start..iw_end {\n                            sum_val += x[[b, c, ih, iw]];\n                        }\n                    }\n\n                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();\n                    output[[b, c, h, w]] = sum_val / count.elem();\n                }\n            }\n        })\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    grad: SharedArray<E>,\n) -> SharedArray<E> {\n    let [_, _, input_height, input_width] = x.shape().try_into().unwrap();\n    let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap();\n\n    let mut output_grad =\n        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output_grad = unsafe_shared_out.get();\n            for oh in 0..output_height {\n                for ow in 0..output_width {\n                    let ih_start = start_index(oh, output_height, input_height);\n                    let ih_end = end_index(oh, output_height, input_height);\n\n                    let iw_start = start_index(ow, output_width, input_width);\n                    let iw_end = end_index(ow, output_width, input_width);\n\n                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();\n\n                    for ih in ih_start..ih_end {\n                        for iw in iw_start..iw_end {\n                            output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem();\n                        }\n                    }\n                }\n            }\n        })\n    });\n\n    output_grad.into_dyn().into_shared()\n}\n\nfn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize\n}\n\nfn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    let index =\n        (((output_size_index + 1) as f32 * input_size as f32) / output_size as f32).ceil() as usize;\n\n    usize::min(index, input_size)\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/avgpool.rs",
    "content": "use crate::{\n    SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,\n};\n\nuse burn_backend::ElementConversion;\nuse burn_backend::ops::conv::calculate_pool_output_size;\nuse ndarray::Array4;\n\npub(crate) fn avg_pool2d<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> SharedArray<E> {\n    let [kernel_height, kernel_width] = kernel_size;\n    let [padding_height, padding_width] = padding;\n    let [stride_height, stride_width] = stride;\n    let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();\n\n    let out_height = calculate_pool_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        1,\n        x_height,\n        ceil_mode,\n    );\n    let out_width = calculate_pool_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        1,\n        x_width,\n        ceil_mode,\n    );\n\n    // Padded input bounds (for count_include_pad calculation)\n    let padded_height = x_height + 2 * padding_height;\n    let padded_width = x_width + 2 * padding_width;\n\n    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem());\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output = unsafe_shared_out.get();\n\n            for oh in 0..out_height {\n                for ow in 0..out_width {\n                    let mut sum_val: E = 0.elem();\n                    let mut valid_count = 0usize;\n                    let mut padded_count = 0usize;\n\n                    for kh in 0..kernel_height {\n                        let ih = oh * stride_height + kh;\n\n                        for kw in 0..kernel_width {\n                            let iw = ow * stride_width + kw;\n\n                            // Check if within padded bounds (excludes ceil_mode extensions)\n                            if ih < padded_height && iw < padded_width {\n                                padded_count += 1;\n\n                                // Check if within valid (non-padding) input bounds\n                                if ih >= padding_height\n                                    && ih < x_height + padding_height\n                                    && iw >= padding_width\n                                    && iw < x_width + padding_width\n                                {\n                                    let ih_valid = ih - padding_height;\n                                    let iw_valid = iw - padding_width;\n                                    sum_val += x[[b, c, ih_valid, iw_valid]];\n                                    valid_count += 1;\n                                }\n                            }\n                        }\n                    }\n\n                    // count_include_pad: count positions within padded bounds (not ceil_mode extensions)\n                    // !count_include_pad: count only valid (non-padding) positions\n                    let count: E = if count_include_pad {\n                        (padded_count as i32).elem()\n                    } else {\n                        (valid_count as i32).elem()\n                    };\n\n                    output[[b, c, oh, ow]] = sum_val / count;\n                }\n            }\n        })\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    grad: SharedArray<E>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    count_include_pad: bool,\n    _ceil_mode: bool,\n) -> SharedArray<E> {\n    let [kernel_height, kernel_width] = kernel_size;\n    let [stride_height, stride_width] = stride;\n    let [padding_height, padding_width] = padding;\n    let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();\n    let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap();\n\n    // Padded input bounds (for count_include_pad calculation)\n    let padded_height = x_height + 2 * padding_height;\n    let padded_width = x_width + 2 * padding_width;\n\n    let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem());\n    let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output_grad = unsafe_shared_grad.get();\n\n            for oh in 0..out_height {\n                for ow in 0..out_width {\n                    let ih_start_kernel = oh * stride_height;\n                    let iw_start_kernel = ow * stride_width;\n\n                    let ih_end_kernel = ih_start_kernel + kernel_height;\n                    let iw_end_kernel = iw_start_kernel + kernel_width;\n\n                    // Clip to valid input bounds (for gradient distribution)\n                    let ih_start = usize::max(ih_start_kernel, padding_height);\n                    let iw_start = usize::max(iw_start_kernel, padding_width);\n                    let ih_end = usize::min(ih_end_kernel, x_height + padding_height);\n                    let iw_end = usize::min(iw_end_kernel, x_width + padding_width);\n\n                    // Calculate count based on count_include_pad\n                    let count = if count_include_pad {\n                        // Count positions within padded bounds (not ceil_mode extensions)\n                        let ih_start_padded = ih_start_kernel;\n                        let iw_start_padded = iw_start_kernel;\n                        let ih_end_padded = usize::min(ih_end_kernel, padded_height);\n                        let iw_end_padded = usize::min(iw_end_kernel, padded_width);\n                        (ih_end_padded - ih_start_padded) * (iw_end_padded - iw_start_padded)\n                    } else {\n                        // Count only valid (non-padding) positions\n                        (ih_end - ih_start) * (iw_end - iw_start)\n                    };\n\n                    for ih in ih_start..ih_end {\n                        for iw in iw_start..iw_end {\n                            let ih = ih - padding_height;\n                            let iw = iw - padding_width;\n\n                            output_grad[[b, c, ih, iw]] +=\n                                grad[[b, c, oh, ow]] / (count as i32).elem();\n                        }\n                    }\n                }\n            }\n        })\n    });\n\n    output_grad.into_dyn().into_shared()\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/base.rs",
    "content": "use alloc::{vec, vec::Vec};\nuse burn_backend::element::{Element, ElementConversion};\n#[cfg(feature = \"simd\")]\nuse burn_backend::{DType, quantization::QuantValue};\nuse core::fmt::Debug;\nuse core::marker::PhantomData;\nuse ndarray::IntoDimension;\nuse ndarray::SliceInfo;\nuse ndarray::Zip;\nuse ndarray::s;\nuse ndarray::{Array2, ArrayD};\nuse num_traits::Signed;\n#[cfg(feature = \"simd\")]\nuse paste::paste;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\n#[cfg(feature = \"simd\")]\nuse crate::ops::simd::{\n    binary::try_binary_simd,\n    binary_elemwise::{\n        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub,\n        try_binary_scalar_simd,\n    },\n    cmp::{\n        VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd,\n        try_cmp_simd,\n    },\n    unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd},\n};\nuse crate::reshape;\nuse crate::{\n    IntNdArrayElement, ShapeOps,\n    ops::macros::{\n        cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim,\n    },\n};\nuse crate::{SharedArray, element::NdArrayElement};\nuse burn_backend::ops::unfold::calculate_unfold_shape;\nuse burn_backend::{Shape, Slice};\nuse ndarray::ArrayView;\nuse ndarray::Axis;\nuse ndarray::Dim;\nuse ndarray::IxDyn;\nuse ndarray::SliceInfoElem;\n\npub struct NdArrayOps<E> {\n    e: PhantomData<E>,\n}\n\npub(crate) struct NdArrayMathOps<E> {\n    e: PhantomData<E>,\n}\n\nimpl<E> NdArrayOps<E>\nwhere\n    E: Copy + Debug + Element + crate::AddAssignElement,\n{\n    pub fn slice(tensor: ArrayView<E, IxDyn>, slices: &[Slice]) -> SharedArray<E> {\n        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());\n        tensor.slice_move(slices.as_slice()).to_shared()\n    }\n\n    pub fn slice_assign(\n        tensor: SharedArray<E>,\n        slices: &[Slice],\n        value: SharedArray<E>,\n    ) -> SharedArray<E> {\n        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());\n        let mut array = tensor.into_owned();\n        array.slice_mut(slices.as_slice()).assign(&value);\n        array.into_shared()\n    }\n\n    pub fn mask_where(\n        tensor: SharedArray<E>,\n        mask: SharedArray<bool>,\n        source: SharedArray<E>,\n    ) -> SharedArray<E> {\n        let tensor = tensor.broadcast(mask.dim()).unwrap();\n        let source = source.broadcast(mask.dim()).unwrap();\n        Zip::from(&tensor)\n            .and(&mask)\n            .and(&source)\n            .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })\n            .into_shared()\n    }\n\n    pub fn mask_fill(tensor: SharedArray<E>, mask: SharedArray<bool>, value: E) -> SharedArray<E> {\n        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique\n        let mut output = tensor.into_owned();\n        let broadcast_mask = mask.broadcast(output.dim()).unwrap();\n        Zip::from(&mut output)\n            .and(&broadcast_mask)\n            .for_each(|out, &mask_val| {\n                if mask_val {\n                    *out = value;\n                }\n            });\n        output.into_shared()\n    }\n\n    pub fn gather<I: NdArrayElement>(\n        dim: usize,\n        mut tensor: SharedArray<E>,\n        mut indices: SharedArray<I>,\n    ) -> SharedArray<E> {\n        let ndims = tensor.shape().num_dims();\n        if dim != ndims - 1 {\n            tensor.swap_axes(ndims - 1, dim);\n            indices.swap_axes(ndims - 1, dim);\n        }\n        let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape());\n        let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices[ndims - 1]);\n        let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices);\n\n        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));\n        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));\n        let mut output = Array2::from_elem((batch_size, size_index), 0.elem::<E>());\n\n        for b in 0..batch_size {\n            let indices = indices.slice(s!(b, ..));\n            for (i, index) in indices.iter().enumerate() {\n                output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];\n            }\n        }\n\n        let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices);\n\n        if dim != ndims - 1 {\n            output.swap_axes(ndims - 1, dim);\n        }\n\n        output\n    }\n\n    pub fn scatter<I: NdArrayElement>(\n        dim: usize,\n        mut tensor: SharedArray<E>,\n        mut indices: SharedArray<I>,\n        mut value: SharedArray<E>,\n    ) -> SharedArray<E> {\n        let ndims = tensor.shape().num_dims();\n        if dim != ndims - 1 {\n            tensor.swap_axes(ndims - 1, dim);\n            indices.swap_axes(ndims - 1, dim);\n            value.swap_axes(ndims - 1, dim);\n        }\n\n        let (shape_tensor, shape_indices, shape_value) =\n            (tensor.shape().into_shape(), indices.shape(), value.shape());\n        let (size_tensor, size_index, size_value) = (\n            shape_tensor[ndims - 1],\n            shape_indices[ndims - 1],\n            shape_value[ndims - 1],\n        );\n        let batch_size = Self::gather_batch_size(&shape_tensor, shape_indices);\n\n        if shape_value != shape_indices {\n            panic!(\n                \"Invalid dimension: the shape of the index tensor should be the same as the value \\\n                 tensor: Index {:?} value {:?}\",\n                shape_indices, shape_value\n            );\n        }\n\n        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));\n        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value]));\n        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));\n\n        for b in 0..batch_size {\n            let indices = indices.slice(s!(b, ..));\n\n            for (i, index) in indices.iter().enumerate() {\n                let index = index.elem::<i64>() as usize;\n                tensor[[b, index]].add_assign(value[[b, i]]);\n            }\n        }\n\n        let mut output = NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor);\n        if dim != ndims - 1 {\n            output.swap_axes(ndims - 1, dim);\n        }\n        output\n    }\n\n    fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize {\n        let ndims = shape_tensor.num_dims();\n        let mut batch_size = 1;\n\n        for i in 0..ndims - 1 {\n            if shape_tensor[i] != shape_indices[i] {\n                panic!(\n                    \"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \\\n                     {:?}\",\n                    shape_tensor, shape_indices\n                );\n            }\n            batch_size *= shape_indices[i];\n        }\n\n        batch_size\n    }\n\n    pub fn reshape(tensor: SharedArray<E>, shape: Shape) -> SharedArray<E> {\n        reshape!(\n            ty E,\n            shape shape,\n            array tensor,\n            d shape.num_dims()\n        )\n    }\n\n    pub(crate) fn concatenate(\n        arrays: &[ndarray::ArrayView<E, IxDyn>],\n        dim: usize,\n    ) -> SharedArray<E> {\n        let array = ndarray::concatenate(Axis(dim), arrays)\n            .unwrap()\n            .into_shared();\n\n        // Transform column-major layout into row-major (standard) layout. (fix #1053)\n        // Get shape first (via reference), then pass ownership to avoid clone\n        let shape = array.shape().into_shape();\n        Self::reshape(array, shape)\n    }\n\n    pub fn cat(tensors: Vec<SharedArray<E>>, dim: usize) -> SharedArray<E> {\n        let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect();\n        Self::concatenate(&arrays, dim)\n    }\n\n    #[allow(clippy::wrong_self_convention)]\n    fn to_slice_args_with_steps(\n        burn_slices: &[burn_backend::Slice],\n        ndims: usize,\n    ) -> Vec<SliceInfoElem> {\n        let mut slices = vec![SliceInfoElem::NewAxis; ndims];\n\n        for i in 0..ndims {\n            slices[i] = if i < burn_slices.len() {\n                let slice = &burn_slices[i];\n\n                // Check for empty range (would result in no elements)\n                if let Some(end) = slice.end\n                    && slice.start == end\n                {\n                    SliceInfoElem::Slice {\n                        start: 0,\n                        end: Some(0),\n                        step: 1,\n                    }\n                } else {\n                    // Pass slice parameters directly to ndarray\n                    // ndarray handles both positive and negative steps correctly:\n                    // - Positive step: iterates forward from start\n                    // - Negative step: iterates backward from the last element in range\n                    SliceInfoElem::Slice {\n                        start: slice.start,\n                        end: slice.end,\n                        step: slice.step,\n                    }\n                }\n            } else {\n                // Dimension not specified in slices - use full range\n                SliceInfoElem::Slice {\n                    start: 0,\n                    end: None,\n                    step: 1,\n                }\n            }\n        }\n\n        slices\n    }\n\n    pub fn swap_dims(mut tensor: SharedArray<E>, dim1: usize, dim2: usize) -> SharedArray<E> {\n        tensor.swap_axes(dim1, dim2);\n\n        tensor\n    }\n\n    pub fn permute(tensor: SharedArray<E>, axes: &[usize]) -> SharedArray<E> {\n        tensor.permuted_axes(axes.into_dimension())\n    }\n\n    /// Broadcasts the tensor to the given shape\n    pub(crate) fn expand(tensor: SharedArray<E>, shape: Shape) -> SharedArray<E> {\n        tensor\n            .broadcast(shape.into_dimension())\n            .expect(\"The shapes should be broadcastable\")\n            // need to convert view to owned array because NdArrayTensor expects owned array\n            // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)\n            .into_owned()\n            .into_shared()\n    }\n\n    pub fn flip(tensor: SharedArray<E>, axes: &[usize]) -> SharedArray<E> {\n        let slice_items: Vec<_> = (0..tensor.shape().num_dims())\n            .map(|i| {\n                if axes.contains(&i) {\n                    SliceInfoElem::Slice {\n                        start: 0,\n                        end: None,\n                        step: -1,\n                    }\n                } else {\n                    SliceInfoElem::Slice {\n                        start: 0,\n                        end: None,\n                        step: 1,\n                    }\n                }\n            })\n            .collect();\n        let slice_info =\n            SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();\n        tensor.slice(slice_info).into_owned().into_shared()\n    }\n\n    /// Unfold windows along a dimension.\n    ///\n    /// # Warning\n    ///\n    /// This is a copy impl; `ndarray` doesn't expose the layout machinery\n    /// necessary to build the stride view.\n    ///\n    /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``\n    /// * `dim` - the dimension to unfold.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with shape ``[pre=..., windows, post=..., size]``.\n    #[allow(unused)]\n    pub(crate) fn unfold(\n        tensor: SharedArray<E>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> SharedArray<E> {\n        let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);\n        let windows = result_shape[dim];\n\n        let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()];\n        let new_axis = slices.len();\n\n        let mut stack = Vec::with_capacity(windows);\n        for widx in 0..windows {\n            let start = widx * step;\n            let end = start + size;\n            slices[dim] = Slice::new(start as isize, Some(end as isize), 1);\n\n            let mut window_slice =\n                tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice());\n            window_slice.insert_axis_inplace(Axis(new_axis));\n            window_slice.swap_axes(dim, new_axis);\n\n            stack.push(window_slice);\n        }\n        Self::concatenate(&stack, dim)\n    }\n}\n\n#[cfg(feature = \"simd\")]\nmacro_rules! dispatch_binary_simd {\n    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*\n                _ => Err(($lhs, $rhs)),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*\n                DType::QFloat(strategy) => match strategy.value {\n                    QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),\n                    _ => Err(($lhs, $rhs)),\n                },\n                _ => Err(($lhs, $rhs)),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n}\n\n#[cfg(not(feature = \"simd\"))]\nmacro_rules! dispatch_binary_simd {\n    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};\n}\n\n#[cfg(feature = \"simd\")]\nmacro_rules! dispatch_binary_scalar_simd {\n    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*\n                _ => Err($lhs),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*\n                DType::QFloat(strategy) => match strategy.value {\n                    QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),\n                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)\n                },\n                _ => Err($lhs),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n}\n\n#[cfg(not(feature = \"simd\"))]\nmacro_rules! dispatch_binary_scalar_simd {\n    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};\n}\n\n#[cfg(feature = \"simd\")]\nmacro_rules! dispatch_cmp_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)*\n                DType::QFloat(strategy) => match strategy.value {\n                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs),\n                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs))\n                },\n                _ => Err(($lhs, $rhs)),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n}\n\n#[cfg(not(feature = \"simd\"))]\nmacro_rules! dispatch_cmp_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};\n}\n\n#[cfg(feature = \"simd\")]\nmacro_rules! dispatch_cmp_scalar_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)*\n                DType::QFloat(strategy) => match strategy.value {\n                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs),\n                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)\n                },\n                _ => Err($lhs),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n}\n\n#[cfg(not(feature = \"simd\"))]\nmacro_rules! dispatch_cmp_scalar_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};\n}\n\n#[cfg(feature = \"simd\")]\nmacro_rules! dispatch_unary_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{\n        paste! {\n            let simd = match $elem::dtype() {\n                $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)*\n                _ => Err($lhs),\n            };\n            match simd {\n                Ok(out) => return out,\n                Err(args) => args,\n            }\n        }\n    }};\n}\n\n#[cfg(not(feature = \"simd\"))]\nmacro_rules! dispatch_unary_simd {\n    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};\n}\n\n// Helper function to broadcast two tensors to a common shape for comparison operations\n// Returns broadcasted views that can be safely zipped\nfn broadcast_for_comparison<'a, E: Copy, S1, S2>(\n    lhs: &'a ndarray::ArrayBase<S1, ndarray::IxDyn>,\n    rhs: &'a ndarray::ArrayBase<S2, ndarray::IxDyn>,\n) -> (\n    ndarray::ArrayView<'a, E, ndarray::IxDyn>,\n    ndarray::ArrayView<'a, E, ndarray::IxDyn>,\n)\nwhere\n    S1: ndarray::Data<Elem = E>,\n    S2: ndarray::Data<Elem = E>,\n{\n    // Get shapes\n    let lhs_shape = lhs.shape();\n    let rhs_shape = rhs.shape();\n\n    // Compute broadcast shape using ndarray's broadcast compatibility rules\n    let ndims = lhs_shape.len().max(rhs_shape.len());\n    let mut broadcast_shape = vec![1; ndims];\n\n    for i in 0..ndims {\n        let lhs_dim = if i < lhs_shape.len() {\n            lhs_shape[lhs_shape.len() - 1 - i]\n        } else {\n            1\n        };\n        let rhs_dim = if i < rhs_shape.len() {\n            rhs_shape[rhs_shape.len() - 1 - i]\n        } else {\n            1\n        };\n\n        if lhs_dim == rhs_dim {\n            broadcast_shape[ndims - 1 - i] = lhs_dim;\n        } else if lhs_dim == 1 {\n            broadcast_shape[ndims - 1 - i] = rhs_dim;\n        } else if rhs_dim == 1 {\n            broadcast_shape[ndims - 1 - i] = lhs_dim;\n        } else {\n            panic!(\n                \"Incompatible shapes for broadcasting: {:?} and {:?}\",\n                lhs_shape, rhs_shape\n            );\n        }\n    }\n\n    // Create IxDyn from broadcast shape\n    let broadcast_dim = ndarray::IxDyn(&broadcast_shape);\n\n    // Broadcast both arrays\n    let lhs_broadcast = lhs\n        .broadcast(broadcast_dim.clone())\n        .expect(\"Failed to broadcast lhs\");\n    let rhs_broadcast = rhs\n        .broadcast(broadcast_dim)\n        .expect(\"Failed to broadcast rhs\");\n\n    (lhs_broadcast, rhs_broadcast)\n}\n\nimpl<E> NdArrayMathOps<E>\nwhere\n    E: Copy + NdArrayElement,\n{\n    pub fn add(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {\n        let (lhs, rhs) = dispatch_binary_simd!(\n            E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64\n        );\n\n        let array = &lhs + &rhs;\n        array.into_shared()\n    }\n\n    pub fn add_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {\n        let lhs = dispatch_binary_scalar_simd!(\n            E,\n            VecAdd,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            u64,\n            i64,\n            f64\n        );\n\n        let array = lhs + rhs;\n        array.into_shared()\n    }\n\n    pub fn sub(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {\n        let (lhs, rhs) = dispatch_binary_simd!(\n            E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64\n        );\n\n        let array = lhs - rhs;\n        array.into_shared()\n    }\n\n    pub fn sub_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {\n        let lhs = dispatch_binary_scalar_simd!(\n            E,\n            VecSub,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            u64,\n            i64,\n            f64\n        );\n\n        let array = lhs - rhs;\n        array.into_shared()\n    }\n\n    pub fn mul(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {\n        let (lhs, rhs) =\n            dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);\n\n        let array = lhs * rhs;\n        array.into_shared()\n    }\n\n    pub fn mul_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {\n        let lhs = dispatch_binary_scalar_simd!(\n            noq,\n            E,\n            VecMul,\n            lhs,\n            rhs.elem(),\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            f64\n        );\n\n        let array = lhs * rhs;\n        array.into_shared()\n    }\n\n    pub fn div(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {\n        let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);\n\n        let array = lhs / rhs;\n        array.into_shared()\n    }\n\n    pub fn div_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {\n        let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);\n\n        let array = lhs / rhs;\n        array.into_shared()\n    }\n\n    pub fn remainder(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {\n        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique\n        let mut out = lhs.into_owned();\n        Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| {\n            // out_elem holds lhs value; read it before overwriting with remainder\n            let a_f = (*out_elem).to_f64();\n            let b_f = b.to_f64();\n            let r = a_f - b_f * (a_f / b_f).floor();\n            *out_elem = r.elem();\n        });\n        out.into_shared()\n    }\n\n    pub fn remainder_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E>\n    where\n        E: core::ops::Rem<Output = E>,\n    {\n        let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs);\n        array.into_shared()\n    }\n\n    pub fn recip(tensor: SharedArray<E>) -> SharedArray<E> {\n        let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32);\n\n        let array = tensor.map(|x| 1.elem::<E>() / *x);\n        array.into_shared()\n    }\n\n    /// Sum all elements - zero-copy for borrowed storage.\n    pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {\n        let sum = view.sum();\n        ArrayD::from_elem(IxDyn(&[1]), sum).into_shared()\n    }\n\n    /// Mean of all elements - zero-copy for borrowed storage.\n    pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {\n        let mean = view.mean().unwrap();\n        ArrayD::from_elem(IxDyn(&[1]), mean).into_shared()\n    }\n\n    /// Product of all elements - zero-copy for borrowed storage.\n    pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {\n        let prod = view.iter().fold(E::one(), |acc, &x| acc * x);\n        ArrayD::from_elem(IxDyn(&[1]), prod).into_shared()\n    }\n\n    pub fn mean_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        let ndims = tensor.shape().num_dims();\n        match ndims {\n            d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean),\n            _ => panic!(\"Dim not supported {ndims}\"),\n        }\n    }\n\n    pub fn sum_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        let ndims = tensor.shape().num_dims();\n        match ndims {\n            d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum),\n            _ => panic!(\"Dim not supported {ndims}\"),\n        }\n    }\n\n    pub fn prod_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        let ndims = tensor.shape().num_dims();\n        match ndims {\n            d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod),\n            _ => panic!(\"Dim not supported {ndims}\"),\n        }\n    }\n\n    pub fn cumsum(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        cumsum_dim(tensor, dim)\n    }\n\n    pub fn cumprod(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        cumprod_dim(tensor, dim)\n    }\n\n    pub fn select<I: NdArrayElement>(\n        tensor: SharedArray<E>,\n        dim: usize,\n        indices: SharedArray<I>,\n    ) -> SharedArray<E> {\n        let array = tensor.select(\n            Axis(dim),\n            &indices\n                .into_iter()\n                .map(|i| i.elem::<i64>() as usize)\n                .collect::<Vec<_>>(),\n        );\n\n        array.into_shared()\n    }\n\n    pub fn select_assign<I: NdArrayElement>(\n        tensor: SharedArray<E>,\n        dim: usize,\n        indices: SharedArray<I>,\n        value: SharedArray<E>,\n    ) -> SharedArray<E> {\n        let mut output_array = tensor.into_owned();\n\n        for (index_value, index) in indices.into_iter().enumerate() {\n            let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);\n            let value = value.index_axis(Axis(dim), index_value);\n\n            view.zip_mut_with(&value, |a, b| *a += *b);\n        }\n\n        output_array.into_shared()\n    }\n\n    pub(crate) fn elementwise_op<OtherE>(\n        lhs: SharedArray<E>,\n        rhs: SharedArray<OtherE>,\n        var_name: impl FnMut(&E, &OtherE) -> E,\n    ) -> SharedArray<E> {\n        let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view());\n        let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view());\n\n        Zip::from(lhs).and(rhs).map_collect(var_name).into_shared()\n    }\n\n    pub(crate) fn elementwise_op_scalar(\n        lhs: SharedArray<E>,\n        var_name: impl FnMut(E) -> E,\n    ) -> SharedArray<E> {\n        lhs.mapv(var_name).into_shared()\n    }\n\n    pub(crate) fn abs(tensor: SharedArray<E>) -> SharedArray<E> {\n        let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);\n\n        tensor.mapv_into(|a| a.abs_elem()).into_shared()\n    }\n\n    pub(crate) fn equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {\n        let (lhs, rhs) = dispatch_cmp_simd!(\n            E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64\n        );\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs == rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {\n        let lhs = dispatch_cmp_scalar_simd!(\n            E,\n            VecEquals,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        lhs.mapv(|a| a == rhs).into_shared()\n    }\n\n    pub(crate) fn sign_op(tensor: SharedArray<E>) -> SharedArray<E>\n    where\n        E: Signed,\n    {\n        let zero = 0.elem();\n        let one = 1.elem::<E>();\n\n        tensor\n            .mapv(|x| {\n                if x == zero {\n                    zero\n                } else {\n                    match x.is_positive() {\n                        true => one,\n                        false => -one,\n                    }\n                }\n            })\n            .into_shared()\n    }\n}\n\nimpl<E> NdArrayMathOps<E>\nwhere\n    E: Copy + NdArrayElement + PartialOrd,\n{\n    /// Max of all elements - zero-copy for borrowed storage.\n    pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {\n        let max = view\n            .iter()\n            .copied()\n            .reduce(|a, b| if a > b { a } else { b })\n            .expect(\"Cannot compute max of empty tensor\");\n        ArrayD::from_elem(IxDyn(&[1]), max).into_shared()\n    }\n\n    /// Min of all elements - zero-copy for borrowed storage.\n    pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {\n        let min = view\n            .iter()\n            .copied()\n            .reduce(|a, b| if a < b { a } else { b })\n            .expect(\"Cannot compute min of empty tensor\");\n        ArrayD::from_elem(IxDyn(&[1]), min).into_shared()\n    }\n\n    /// Argmax along dimension - zero-copy for borrowed storage.\n    pub fn argmax_view<I: NdArrayElement + PartialOrd>(\n        view: ArrayView<'_, E, IxDyn>,\n        dim: usize,\n    ) -> SharedArray<I> {\n        arg_view(view, dim, CmpType::Max)\n    }\n\n    /// Argmin along dimension - zero-copy for borrowed storage.\n    pub fn argmin_view<I: NdArrayElement + PartialOrd>(\n        view: ArrayView<'_, E, IxDyn>,\n        dim: usize,\n    ) -> SharedArray<I> {\n        arg_view(view, dim, CmpType::Min)\n    }\n\n    pub fn cummin(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        cummin_dim(tensor, dim)\n    }\n\n    pub fn cummax(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n        cummax_dim(tensor, dim)\n    }\n\n    pub fn argmax<I: NdArrayElement + PartialOrd>(\n        tensor: SharedArray<E>,\n        dim: usize,\n    ) -> SharedArray<I> {\n        arg(tensor, dim, CmpType::Max)\n    }\n\n    pub fn argmin<I: NdArrayElement + PartialOrd>(\n        tensor: SharedArray<E>,\n        dim: usize,\n    ) -> SharedArray<I> {\n        arg(tensor, dim, CmpType::Min)\n    }\n\n    pub fn clamp_min(tensor: SharedArray<E>, min: E) -> SharedArray<E> {\n        let mut tensor = dispatch_binary_scalar_simd!(\n            E,\n            VecMax,\n            tensor,\n            min.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            u64,\n            i64,\n            f64\n        );\n\n        tensor.mapv_inplace(|x| match x < min {\n            true => min,\n            false => x,\n        });\n\n        tensor\n    }\n\n    pub fn clamp_max(tensor: SharedArray<E>, max: E) -> SharedArray<E> {\n        let mut tensor = dispatch_binary_scalar_simd!(\n            E,\n            VecMin,\n            tensor,\n            max.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            u64,\n            i64,\n            f64\n        );\n\n        tensor.mapv_inplace(|x| match x > max {\n            true => max,\n            false => x,\n        });\n\n        tensor\n    }\n\n    pub fn clamp(tensor: SharedArray<E>, min: E, max: E) -> SharedArray<E> {\n        let mut tensor = dispatch_binary_scalar_simd!(\n            E,\n            VecClamp,\n            tensor,\n            (min.elem(), max.elem()),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            i32,\n            f32,\n            u64,\n            i64,\n            f64\n        );\n\n        tensor.mapv_inplace(|x| match x < min {\n            true => min,\n            false => match x > max {\n                true => max,\n                false => x,\n            },\n        });\n\n        tensor\n    }\n\n    pub(crate) fn greater(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {\n        let (lhs, rhs) = dispatch_cmp_simd!(\n            E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64\n        );\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs > rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn greater_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {\n        let lhs = dispatch_cmp_scalar_simd!(\n            E,\n            VecGreater,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        lhs.mapv(|a| a > rhs).into_shared()\n    }\n\n    pub(crate) fn greater_equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {\n        let (lhs, rhs) = dispatch_cmp_simd!(\n            E,\n            VecGreaterEq,\n            lhs,\n            rhs,\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs >= rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn greater_equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {\n        let lhs = dispatch_cmp_scalar_simd!(\n            E,\n            VecGreaterEq,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        lhs.mapv(|a| a >= rhs).into_shared()\n    }\n\n    pub(crate) fn lower_equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {\n        let (lhs, rhs) = dispatch_cmp_simd!(\n            E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64\n        );\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs <= rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn lower_equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {\n        let lhs = dispatch_cmp_scalar_simd!(\n            E,\n            VecLowerEq,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        lhs.mapv(|a| a <= rhs).into_shared()\n    }\n\n    pub(crate) fn lower(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {\n        let (lhs, rhs) = dispatch_cmp_simd!(\n            E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64\n        );\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs < rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn lower_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {\n        let lhs = dispatch_cmp_scalar_simd!(\n            E,\n            VecLower,\n            lhs,\n            rhs.elem(),\n            u8,\n            i8,\n            u16,\n            i16,\n            u32,\n            f32,\n            i32,\n            u64,\n            i64,\n            f64\n        );\n\n        lhs.mapv(|a| a < rhs).into_shared()\n    }\n}\n\npub struct NdArrayBitOps<I: IntNdArrayElement>(PhantomData<I>);\n\nimpl<I: IntNdArrayElement> NdArrayBitOps<I> {\n    pub(crate) fn bitand(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {\n        let (lhs, rhs) =\n            dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);\n\n        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {\n            (a.elem::<i64>() & (b.elem::<i64>())).elem()\n        })\n    }\n\n    pub(crate) fn bitand_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {\n        let lhs = dispatch_binary_scalar_simd!(\n            I,\n            VecBitAnd,\n            lhs,\n            rhs.elem(),\n            i8,\n            u8,\n            i16,\n            u16,\n            i32,\n            u32,\n            i64,\n            u64\n        );\n\n        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {\n            (a.elem::<i64>() & rhs.elem::<i64>()).elem()\n        })\n    }\n\n    pub(crate) fn bitor(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {\n        let (lhs, rhs) =\n            dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);\n\n        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {\n            (a.elem::<i64>() | (b.elem::<i64>())).elem()\n        })\n    }\n\n    pub(crate) fn bitor_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {\n        let lhs = dispatch_binary_scalar_simd!(\n            I,\n            VecBitOr,\n            lhs,\n            rhs.elem(),\n            i8,\n            u8,\n            i16,\n            u16,\n            i32,\n            u32,\n            i64,\n            u64\n        );\n\n        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {\n            (a.elem::<i64>() | rhs.elem::<i64>()).elem()\n        })\n    }\n\n    pub(crate) fn bitxor(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {\n        let (lhs, rhs) =\n            dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);\n\n        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {\n            (a.elem::<i64>() ^ (b.elem::<i64>())).elem()\n        })\n    }\n\n    pub(crate) fn bitxor_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {\n        let lhs = dispatch_binary_scalar_simd!(\n            I,\n            VecBitXor,\n            lhs,\n            rhs.elem(),\n            i8,\n            u8,\n            i16,\n            u16,\n            i32,\n            u32,\n            i64,\n            u64\n        );\n\n        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {\n            (a.elem::<i64>() ^ rhs.elem::<i64>()).elem()\n        })\n    }\n\n    pub(crate) fn bitnot(tensor: SharedArray<I>) -> SharedArray<I> {\n        let tensor =\n            dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64);\n\n        NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::<i64>()).elem())\n    }\n}\n\npub struct NdArrayBoolOps;\n\n// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would\n// produce invalid values.\nimpl NdArrayBoolOps {\n    pub(crate) fn equal(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {\n        #[cfg(feature = \"simd\")]\n        let (lhs, rhs) = match try_cmp_simd::<bool, u8, VecEquals>(lhs, rhs) {\n            Ok(out) => return out,\n            Err(args) => args,\n        };\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs == rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn equal_elem(lhs: SharedArray<bool>, rhs: bool) -> SharedArray<bool> {\n        #[cfg(feature = \"simd\")]\n        let lhs = match try_cmp_scalar_simd::<bool, u8, VecEquals>(lhs, rhs.elem()) {\n            Ok(out) => return out,\n            Err(args) => args,\n        };\n\n        lhs.mapv(|a| a == rhs).into_shared()\n    }\n\n    pub(crate) fn and(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {\n        #[cfg(feature = \"simd\")]\n        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitAnd>(lhs, rhs) {\n            Ok(out) => return out,\n            Err(args) => args,\n        };\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs && rhs)\n            .into_shared()\n    }\n\n    pub(crate) fn or(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {\n        #[cfg(feature = \"simd\")]\n        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitOr>(lhs, rhs) {\n            Ok(out) => return out,\n            Err(args) => args,\n        };\n\n        // Use the helper to broadcast both arrays to a common shape\n        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);\n        // Now we can safely zip and compare\n        Zip::from(&lhs_broadcast)\n            .and(&rhs_broadcast)\n            .map_collect(|&lhs, &rhs| lhs || rhs)\n            .into_shared()\n    }\n\n    /// Any element is true - zero-copy for borrowed storage.\n    pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool {\n        view.iter().any(|&x| x)\n    }\n\n    /// All elements are true - zero-copy for borrowed storage.\n    pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool {\n        view.iter().all(|&x| x)\n    }\n}\n\nenum CmpType {\n    Min,\n    Max,\n}\n\nfn arg<E: NdArrayElement + PartialOrd, I: NdArrayElement + PartialOrd>(\n    tensor: SharedArray<E>,\n    dim: usize,\n    cmp: CmpType,\n) -> SharedArray<I> {\n    arg_view(tensor.view(), dim, cmp)\n}\n\n/// View-based argmax/argmin - zero-copy for borrowed storage.\nfn arg_view<E: NdArrayElement + PartialOrd, I: NdArrayElement + PartialOrd>(\n    view: ArrayView<'_, E, IxDyn>,\n    dim: usize,\n    cmp: CmpType,\n) -> SharedArray<I> {\n    let mut reshape = view.shape().to_vec();\n    reshape[dim] = 1;\n\n    let output = view.map_axis(Axis(dim), |arr| {\n        // Find the min/max value in the array, and return its index.\n        let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| {\n            let cmp = match cmp {\n                CmpType::Min => e < &acc.0,\n                CmpType::Max => e > &acc.0,\n            };\n\n            if cmp { (*e, idx) } else { acc }\n        });\n\n        (idx as i64).elem()\n    });\n\n    let output = output.to_shape(Dim(reshape.as_slice())).unwrap();\n\n    output.into_shared()\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_backend::TensorData;\n\n    use crate::NdArrayTensor;\n\n    use super::*;\n\n    #[test]\n    fn should_generate_row_major_layout_for_cat() {\n        let expected_shape: &[usize] = &[4, 6, 2];\n        let expected_strides: &[isize] = &[12, 2, 1];\n        let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([\n            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],\n            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],\n            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],\n            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],\n        ])) else {\n            panic!()\n        };\n        let expected_array = expected_storage.into_shared();\n\n        let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([\n            [1, 2, 3, 4, 5, 6],\n            [7, 8, 9, 10, 11, 12],\n            [13, 14, 15, 16, 17, 18],\n            [19, 20, 21, 22, 23, 24],\n        ])) else {\n            panic!()\n        };\n        let tensor = tensor_storage.into_shared();\n\n        // unsqueeze dim on the outermost axis\n        let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1]));\n        let NdArrayTensor::I32(zeros_storage) =\n            NdArrayTensor::from_data(TensorData::zeros::<i32, _>([4, 6, 1]))\n        else {\n            panic!()\n        };\n        let zeros = zeros_storage.into_shared();\n        // make `ndarray` concatenates array on the outermost axis\n        let array = NdArrayOps::cat([array, zeros].to_vec(), 2);\n\n        assert!(array.is_standard_layout());\n        assert_eq!(array.shape(), expected_shape);\n        assert_eq!(array.strides(), expected_strides);\n        assert_eq!(\n            array.into_iter().collect::<Vec<_>>(),\n            expected_array.into_iter().collect::<Vec<_>>(),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/bool_tensor.rs",
    "content": "// Language\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_backend::Scalar;\nuse burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor};\nuse burn_backend::{\n    backend::ExecutionError,\n    ops::BoolTensorOps,\n    tensor::{BoolTensor, IntTensor},\n};\nuse ndarray::IntoDimension;\n\n// Current crate\nuse crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};\nuse crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};\nuse crate::{NdArrayDevice, SharedArray, slice};\n\n// Workspace crates\nuse burn_backend::{Shape, TensorData, backend::Backend};\n\nuse super::{NdArrayBoolOps, NdArrayOps};\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolTensorOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {\n        if !data.dtype.is_bool() {\n            unimplemented!(\"Unsupported dtype for `bool_from_data`\")\n        }\n        NdArrayTensor::from_data(data)\n    }\n\n    async fn bool_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {\n        Ok(tensor.into_data())\n    }\n\n    fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {\n        tensor\n    }\n\n    fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {\n        NdArrayOps::reshape(tensor.bool(), shape).into()\n    }\n\n    fn bool_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {\n        slice!(tensor, slices)\n    }\n\n    fn bool_into_int(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use mapv directly instead of collecting to Vec and going through TensorData\n        let int_array: SharedArray<I> = tensor.bool().mapv(|b| b.elem()).into_shared();\n        int_array.into()\n    }\n\n    fn bool_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {\n        NdArrayDevice::Cpu\n    }\n\n    fn bool_empty(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {\n        Self::bool_zeros(shape, _device)\n    }\n\n    fn bool_zeros(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {\n        let values = vec![false; shape.num_elements()];\n        NdArrayTensor::from_data(TensorData::new(values, shape))\n    }\n\n    fn bool_ones(shape: Shape, _device: &<NdArray<E> as Backend>::Device) -> NdArrayTensor {\n        let values = vec![true; shape.num_elements()];\n        NdArrayTensor::from_data(TensorData::new(values, shape))\n    }\n\n    fn bool_slice_assign(\n        tensor: NdArrayTensor,\n        slices: &[burn_backend::Slice],\n        value: NdArrayTensor,\n    ) -> NdArrayTensor {\n        NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into()\n    }\n\n    fn bool_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {\n        NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into()\n    }\n\n    fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into()\n    }\n\n    fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor {\n        tensor.bool().mapv(|a| !a).into_shared().into()\n    }\n\n    fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into()\n    }\n\n    fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into()\n    }\n\n    fn bool_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {\n        let arr: SharedArray<E> = tensor.bool().mapv(|b| b.elem()).into_shared();\n        arr.into()\n    }\n\n    fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {\n        NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into()\n    }\n\n    fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {\n        tensor.bool().permuted_axes(axes.into_dimension()).into()\n    }\n\n    fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {\n        NdArrayOps::expand(tensor.bool(), shape).into()\n    }\n\n    fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {\n            let tensor_bool = tensor.bool();\n            let indices_vec: Vec<usize> = indices\n                .into_iter()\n                .map(|i| i.elem::<i64>() as usize)\n                .collect();\n\n            let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);\n            selected.into_shared().into()\n        })\n    }\n\n    fn bool_select_or(\n        tensor: NdArrayTensor,\n        dim: usize,\n        indices: NdArrayTensor,\n        value: NdArrayTensor,\n    ) -> NdArrayTensor {\n        execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {\n            let mut output_array = tensor.bool().into_owned();\n            let value_bool = value.bool();\n\n            for (index_value, index) in indices.into_iter().enumerate() {\n                let index_usize = index.elem::<i64>() as usize;\n                let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);\n                let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);\n                // For boolean tensors, select_assign should use logical OR operation\n                view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);\n            }\n            output_array.into_shared().into()\n        })\n    }\n\n    fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {\n        NdArrayOps::flip(tensor.bool(), axes).into()\n    }\n\n    fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor {\n        NdArrayOps::unfold(tensor.bool(), dim, size, step).into()\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into()\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into()\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        execute_with_int_dtype!(indices, |indices| NdArrayOps::gather(\n            dim,\n            tensor.bool(),\n            indices\n        ))\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter(\n            dim,\n            tensor.bool(),\n            indices,\n            value.bool()\n        ))\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into()\n    }\n\n    fn bool_any(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        // Use view() for zero-copy on borrowed storage with short-circuit evaluation\n        let result = NdArrayBoolOps::any_view(tensor.bool().view());\n        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))\n    }\n\n    fn bool_all(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        // Use view() for zero-copy on borrowed storage with short-circuit evaluation\n        let result = NdArrayBoolOps::all_view(tensor.bool().view());\n        NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1])))\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/conv.rs",
    "content": "use burn_backend::{\n    ElementConversion,\n    ops::{\n        ConvOptions, ConvTransposeOptions,\n        conv::{calculate_conv_output_size, calculate_conv_transpose_output_size},\n    },\n};\nuse ndarray::{\n    Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s,\n};\n\nuse crate::{\n    NdArrayElement, SharedArray, iter_par, iter_range_par,\n    ops::padding::{apply_padding_4d, apply_padding_5d},\n    run_par,\n    sharing::UnsafeSharedRef,\n    tensor::NdArrayTensor,\n};\n\n#[inline(always)]\nfn conv2d_mad_inner<E: NdArrayElement>(\n    mut output: ArrayViewMut2<E>,\n    x: ArrayView2<E>,\n    k: E,\n    k_xy: (usize, usize),\n    out_xy: (usize, usize),\n    stride: (usize, usize),\n    dilation: (usize, usize),\n) {\n    let (kh, kw) = k_xy;\n    let (out_width, out_height) = out_xy;\n    let (stride_width, stride_height) = stride;\n    let (dilation_width, dilation_height) = dilation;\n\n    for oh in 0..out_height {\n        // Construct a sub-slice view of the input row.\n        // This is done upfront so that rustc does not have to emit bounds checks\n        // in the hot loop below.\n        let ir = x\n            .row(oh * stride_height + kh * dilation_height)\n            .to_slice()\n            .unwrap();\n\n        // Ditto. Construct a sub-slice view of the output row, and explicitly specify\n        // the bounds upfront as 0..out_width so that rustc can make the assumption\n        // that all accesses are in-bounds in the below loop.\n        let mut or = output.row_mut(oh);\n        let or = &mut or.as_slice_mut().unwrap()[0..out_width];\n\n        #[allow(clippy::needless_range_loop)]\n        for ow in 0..out_width {\n            let iw = ow * stride_width + kw * dilation_width;\n            or[ow] += ir[iw] * k;\n        }\n    }\n}\n\n#[inline(always)]\nfn conv3d_mad_inner<E: NdArrayElement>(\n    mut output: ArrayViewMut3<E>,\n    x: ArrayView3<E>,\n    k: E,\n    k_xyz: (usize, usize, usize),\n    out_xyz: (usize, usize, usize),\n    stride: (usize, usize, usize),\n    dilation: (usize, usize, usize),\n) {\n    let (kd, kh, kw) = k_xyz;\n    let (out_width, out_height, out_depth) = out_xyz;\n    let (stride_width, stride_height, stride_depth) = stride;\n    let (dilation_width, dilation_height, dilation_depth) = dilation;\n\n    for od in 0..out_depth {\n        let id = od * stride_depth + kd * dilation_depth;\n\n        for oh in 0..out_height {\n            let ih = oh * stride_height + kh * dilation_height;\n\n            // Construct a sub-slice view of the input row.\n            // This is done upfront so that rustc does not have to emit bounds checks\n            // in the hot loop below.\n            let ir = x.slice(s![id, ih, ..]).to_slice().unwrap();\n\n            // Ditto. Construct a sub-slice view of the output row, and explicitly specify\n            // the bounds upfront as 0..out_width so that rustc can make the assumption\n            // that all accesses are in-bounds in the below loop.\n            let or = &mut output\n                .slice_mut(s![od, oh, 0..out_width])\n                .into_slice()\n                .unwrap()[0..out_width];\n\n            #[allow(clippy::needless_range_loop)]\n            for ow in 0..out_width {\n                let iw = ow * stride_width + kw * dilation_width;\n                or[ow] += ir[iw] * k;\n            }\n        }\n    }\n}\n\npub(crate) fn conv2d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    weight: SharedArray<E>,\n    bias: Option<SharedArray<E>>,\n    options: ConvOptions<2>,\n) -> SharedArray<E>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n{\n    let [dilation_height, dilation_width] = options.dilation;\n    let [padding_height, padding_width] = options.padding;\n    let [stride_height, stride_width] = options.stride;\n    let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();\n    let [out_channels, in_channels, kernel_height, kernel_width] =\n        weight.shape().try_into().unwrap();\n    let channels_per_group = out_channels / options.groups;\n\n    let out_height = calculate_conv_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        dilation_height,\n        in_height,\n    );\n    let out_width = calculate_conv_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        dilation_width,\n        in_width,\n    );\n\n    let x = apply_padding_4d::<E>(x, options.padding, 0i32.elem());\n\n    // Convert inputs from dynamic indexes to static to improve perf.\n    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();\n    let weights = weight.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width]));\n\n    run_par!(|| {\n        iter_par!(output.axis_iter_mut(Axis(0)))\n            .enumerate()\n            .for_each(\n                #[inline(never)]\n                |(k, mut output)| {\n                    let b = k / out_channels;\n                    let oc = k % out_channels;\n                    let g = oc / channels_per_group;\n\n                    for ic in (in_channels * g)..(in_channels * (g + 1)) {\n                        let weight_ic = ic - (g * in_channels);\n\n                        let x = x.slice(s![b, ic, .., ..]);\n                        let k = weights.slice(s![oc, weight_ic, .., ..]);\n\n                        for kh in 0..kernel_height {\n                            for kw in 0..kernel_width {\n                                let k = k[[kh, kw]];\n\n                                // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization\n                                // in the case that the stride/dilation is 1.\n                                #[allow(clippy::if_same_then_else)]\n                                if (1, 1, 1, 1)\n                                    == (\n                                        stride_width,\n                                        stride_height,\n                                        dilation_width,\n                                        dilation_height,\n                                    )\n                                {\n                                    conv2d_mad_inner(\n                                        output.view_mut(),\n                                        x.view(),\n                                        k,\n                                        (kh, kw),\n                                        (out_width, out_height),\n                                        (stride_width, stride_height),\n                                        (dilation_width, dilation_height),\n                                    );\n                                } else {\n                                    conv2d_mad_inner(\n                                        output.view_mut(),\n                                        x.view(),\n                                        k,\n                                        (kh, kw),\n                                        (out_width, out_height),\n                                        (stride_width, stride_height),\n                                        (dilation_width, dilation_height),\n                                    );\n                                }\n                            }\n                        }\n                    }\n\n                    if let Some(bias) = &bias {\n                        let bias = bias[oc];\n\n                        for oh in 0..out_height {\n                            // Get a mutable slice reference to the row we're looping over.\n                            // We explicitly define the bounds to 0..out_width so that rustc can make\n                            // the assumption that all accesses are in-bounds.\n                            let mut or = output.row_mut(oh);\n                            let or = &mut or.as_slice_mut().unwrap()[0..out_width];\n\n                            #[allow(clippy::needless_range_loop)]\n                            for ow in 0..out_width {\n                                or[ow] += bias;\n                            }\n                        }\n                    }\n                },\n            );\n    });\n\n    output\n        .to_shape([batch_size, out_channels, out_height, out_width])\n        .unwrap()\n        .into_dyn()\n        .into_shared()\n}\n\npub(crate) fn conv_transpose2d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    weight: SharedArray<E>,\n    bias: Option<SharedArray<E>>,\n    options: ConvTransposeOptions<2>,\n) -> SharedArray<E> {\n    let [dilation_height, dilation_width] = options.dilation;\n    let [padding_height, padding_width] = options.padding;\n    let [stride_height, stride_width] = options.stride;\n    let [out_padding_height, out_padding_width] = options.padding_out;\n    let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();\n    let [in_channels, out_channels, kernel_height, kernel_width] =\n        weight.shape().try_into().unwrap();\n\n    let out_height = calculate_conv_transpose_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        out_padding_height,\n        dilation_height,\n        in_height,\n    );\n    let out_width = calculate_conv_transpose_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        out_padding_width,\n        dilation_width,\n        in_width,\n    );\n\n    let x = x;\n    let mut output = Array4::zeros(Dim([\n        batch_size,\n        out_channels * options.groups,\n        out_height,\n        out_width,\n    ]));\n\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {\n            let b = k / (out_channels * options.groups);\n            let oc = k % out_channels;\n            let g = (k / out_channels) % options.groups;\n\n            let output = unsafe_shared_out.get();\n\n            let oc_out = oc + (out_channels * g);\n            let ic_start = g * (in_channels / options.groups);\n            let ic_end = ic_start + in_channels / options.groups;\n\n            for ic in ic_start..ic_end {\n                for ih in 0..in_height {\n                    for iw in 0..in_width {\n                        for kh in 0..kernel_height {\n                            for kw in 0..kernel_width {\n                                let oh = ih * stride_height + kh * dilation_height;\n                                let ow = iw * stride_width + kw * dilation_width;\n\n                                if oh >= out_height + padding_height\n                                    || ow >= out_width + padding_width\n                                    || oh < padding_height\n                                    || ow < padding_width\n                                {\n                                    continue;\n                                }\n\n                                let oh = oh - padding_height;\n                                let ow = ow - padding_width;\n\n                                output[[b, oc_out, oh, ow]] +=\n                                    x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]];\n                            }\n                        }\n                    }\n                }\n            }\n\n            if let Some(bias) = &bias {\n                for oh in 0..out_height {\n                    for ow in 0..out_width {\n                        output[[b, oc_out, oh, ow]] += bias[oc_out];\n                    }\n                }\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn conv3d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    weight: SharedArray<E>,\n    bias: Option<SharedArray<E>>,\n    options: ConvOptions<3>,\n) -> SharedArray<E>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n{\n    let [dilation_depth, dilation_height, dilation_width] = options.dilation;\n    let [padding_depth, padding_height, padding_width] = options.padding;\n    let [stride_depth, stride_height, stride_width] = options.stride;\n    let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap();\n    let [\n        out_channels,\n        in_channels,\n        kernel_depth,\n        kernel_height,\n        kernel_width,\n    ] = weight.shape().try_into().unwrap();\n    let out_c_per_group = out_channels / options.groups;\n\n    let out_depth = calculate_conv_output_size(\n        kernel_depth,\n        stride_depth,\n        padding_depth,\n        dilation_depth,\n        in_depth,\n    );\n    let out_height = calculate_conv_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        dilation_height,\n        in_height,\n    );\n    let out_width = calculate_conv_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        dilation_width,\n        in_width,\n    );\n\n    let x = apply_padding_5d::<E>(x, options.padding, 0i32.elem());\n\n    // Convert inputs from dynamic indexes to static to improve perf.\n    let x = x.into_dimensionality::<ndarray::Ix5>().unwrap();\n    let weights = weight.into_dimensionality::<ndarray::Ix5>().unwrap();\n\n    let mut output = Array4::zeros(Dim([\n        batch_size * out_channels,\n        out_depth,\n        out_height,\n        out_width,\n    ]));\n\n    run_par!(|| {\n        iter_par!(output.axis_iter_mut(Axis(0)))\n            .enumerate()\n            .for_each(\n                #[inline(never)]\n                |(k, mut output)| {\n                    let b = k / out_channels;\n                    let oc = k % out_channels;\n                    let g = oc / out_c_per_group;\n\n                    for ic in (in_channels * g)..(in_channels * (g + 1)) {\n                        let weight_ic = ic - (g * in_channels);\n\n                        let x = x.slice(s![b, ic, .., .., ..]);\n                        let k = weights.slice(s![oc, weight_ic, .., .., ..]);\n\n                        for kd in 0..kernel_depth {\n                            for kh in 0..kernel_height {\n                                for kw in 0..kernel_width {\n                                    let k = k[[kd, kh, kw]];\n\n                                    // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization\n                                    // in the case that the stride/dilation is 1.\n                                    #[allow(clippy::if_same_then_else)]\n                                    if (1, 1, 1, 1, 1, 1)\n                                        == (\n                                            stride_width,\n                                            stride_height,\n                                            stride_depth,\n                                            dilation_width,\n                                            dilation_height,\n                                            dilation_depth,\n                                        )\n                                    {\n                                        conv3d_mad_inner(\n                                            output.view_mut(),\n                                            x.view(),\n                                            k,\n                                            (kd, kh, kw),\n                                            (out_width, out_height, out_depth),\n                                            (stride_width, stride_height, stride_depth),\n                                            (dilation_width, dilation_height, dilation_depth),\n                                        );\n                                    } else {\n                                        conv3d_mad_inner(\n                                            output.view_mut(),\n                                            x.view(),\n                                            k,\n                                            (kd, kh, kw),\n                                            (out_width, out_height, out_depth),\n                                            (stride_width, stride_height, stride_depth),\n                                            (dilation_width, dilation_height, dilation_depth),\n                                        );\n                                    }\n                                }\n                            }\n                        }\n                    }\n\n                    if let Some(bias) = &bias {\n                        let bias = bias[oc];\n\n                        // Get a mutable iterator to the row we're looping over.\n                        let orows = output.rows_mut();\n                        for mut or in orows {\n                            // We explicitly define the bounds to 0..out_width so that rustc can make\n                            // the assumption that all accesses are in-bounds.\n                            let or = &mut or.as_slice_mut().unwrap()[0..out_width];\n\n                            #[allow(clippy::needless_range_loop)]\n                            for ow in 0..out_width {\n                                or[ow] += bias;\n                            }\n                        }\n                    }\n                },\n            );\n    });\n\n    output\n        .to_shape([batch_size, out_channels, out_depth, out_height, out_width])\n        .unwrap()\n        .into_dyn()\n        .into_shared()\n}\n\npub(crate) fn conv_transpose3d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    weight: SharedArray<E>,\n    bias: Option<SharedArray<E>>,\n    options: ConvTransposeOptions<3>,\n) -> SharedArray<E> {\n    let [dilation_depth, dilation_height, dilation_width] = options.dilation;\n    let [padding_depth, padding_height, padding_width] = options.padding;\n    let [stride_depth, stride_height, stride_width] = options.stride;\n    let [out_padding_depth, out_padding_height, out_padding_width] = options.padding_out;\n    let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap();\n    let [\n        in_channels,\n        out_channels,\n        kernel_depth,\n        kernel_height,\n        kernel_width,\n    ] = weight.shape().try_into().unwrap();\n\n    let out_depth = calculate_conv_transpose_output_size(\n        kernel_depth,\n        stride_depth,\n        padding_depth,\n        out_padding_depth,\n        dilation_depth,\n        in_depth,\n    );\n    let out_height = calculate_conv_transpose_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        out_padding_height,\n        dilation_height,\n        in_height,\n    );\n    let out_width = calculate_conv_transpose_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        out_padding_width,\n        dilation_width,\n        in_width,\n    );\n\n    let x = x;\n    let mut output = Array5::zeros(Dim([\n        batch_size,\n        out_channels * options.groups,\n        out_depth,\n        out_height,\n        out_width,\n    ]));\n\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe {\n            let b = k / (out_channels * options.groups);\n            let oc = k % out_channels;\n            let g = (k / out_channels) % options.groups;\n\n            let output = unsafe_shared_out.get();\n\n            let oc_out = oc + (out_channels * g);\n            let ic_start = g * (in_channels / options.groups);\n            let ic_end = ic_start + in_channels / options.groups;\n\n            for ic in ic_start..ic_end {\n                for id in 0..in_depth {\n                    for ih in 0..in_height {\n                        for iw in 0..in_width {\n                            for kd in 0..kernel_depth {\n                                for kh in 0..kernel_height {\n                                    for kw in 0..kernel_width {\n                                        let od = id * stride_depth + kd * dilation_depth;\n                                        let oh = ih * stride_height + kh * dilation_height;\n                                        let ow = iw * stride_width + kw * dilation_width;\n\n                                        if od >= out_depth + padding_depth\n                                            || oh >= out_height + padding_height\n                                            || ow >= out_width + padding_width\n                                            || od < padding_depth\n                                            || oh < padding_height\n                                            || ow < padding_width\n                                        {\n                                            continue;\n                                        }\n\n                                        let od = od - padding_depth;\n                                        let oh = oh - padding_height;\n                                        let ow = ow - padding_width;\n\n                                        output[[b, oc_out, od, oh, ow]] +=\n                                            x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]];\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n\n            if let Some(bias) = &bias {\n                for od in 0..out_depth {\n                    for oh in 0..out_height {\n                        for ow in 0..out_width {\n                            output[[b, oc_out, od, oh, ow]] += bias[oc_out];\n                        }\n                    }\n                }\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/deform_conv.rs",
    "content": "use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size};\nuse core::ops::AddAssign;\nuse ndarray::{\n    Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4,\n    Zip, s,\n};\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par};\n\nuse super::matmul::matmul;\n\n#[inline(always)]\n#[allow(clippy::too_many_arguments)]\nfn deform_im2col_kernel<F: FloatNdArrayElement>(\n    out_y: usize,\n    out_x: usize,\n    input: ArrayView2<F>,\n    offset: ArrayView3<F>,\n    mask: Option<ArrayView2<F>>,\n    mut columns: ArrayViewMut2<F>,\n    args: DeformConvOptions<2>,\n    (kernel_h, kernel_w): (usize, usize),\n) {\n    // position shape: [in_channels, batch_size, out_h, out_w]\n    // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]\n\n    let (height, width) = input.dim();\n\n    for kernel_y in 0..kernel_h {\n        for kernel_x in 0..kernel_w {\n            let mask_value = mask\n                .map(|it| it[[kernel_y, kernel_x]])\n                .unwrap_or_else(|| F::from_elem(1.0));\n\n            let offset = offset.slice(s![kernel_y, kernel_x, ..]);\n            let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])\n                - F::from_elem(args.padding[0])\n                + offset[0];\n            let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])\n                - F::from_elem(args.padding[1])\n                + offset[1];\n\n            let interpolated = bilinear_interpolate(input, height, width, y, x);\n\n            columns[[kernel_y, kernel_x]] = mask_value * interpolated;\n        }\n    }\n}\n\nfn bilinear_interpolate<F: FloatNdArrayElement>(\n    input: ArrayView2<F>,\n    height: usize,\n    width: usize,\n    y: F,\n    x: F,\n) -> F {\n    // To simplify code\n    let y = y.to_f32();\n    let x = x.to_f32();\n\n    let mut result = F::from_elem(0.0);\n    if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {\n        let y_low = f32::floor(y);\n        let x_low = f32::floor(x);\n        let y_high = (y_low + 1.) as usize;\n        let x_high = (x_low + 1.) as usize;\n\n        let zero = F::from_elem(0.0);\n        let v1: F = if y_low >= 0. && x_low >= 0. {\n            input[[y_low as usize, x_low as usize]]\n        } else {\n            zero\n        };\n        let v2: F = if y_low >= 0. && x_high < width {\n            input[[y_low as usize, x_high]]\n        } else {\n            zero\n        };\n        let v3: F = if y_high < height && x_low >= 0. {\n            input[[y_high, x_low as usize]]\n        } else {\n            zero\n        };\n        let v4: F = if y_high < height && x_high < width {\n            input[[y_high, x_high]]\n        } else {\n            zero\n        };\n\n        let l_y = y - y_low;\n        let l_x = x - x_low;\n        let h_y = 1.0 - l_y;\n        let h_x = 1.0 - l_x;\n\n        let w1 = F::from_elem(h_y * h_x);\n        let w2 = F::from_elem(h_y * l_x);\n        let w3 = F::from_elem(l_y * h_x);\n        let w4 = F::from_elem(l_y * l_x);\n\n        result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;\n    }\n    result\n}\n\npub(crate) fn deform_conv2d<F: FloatNdArrayElement>(\n    input: SharedArray<F>,\n    offset: SharedArray<F>,\n    weight: SharedArray<F>,\n    mask: Option<SharedArray<F>>,\n    bias: Option<SharedArray<F>>,\n    args: DeformConvOptions<2>,\n) -> SharedArray<F>\nwhere\n    NdArrayTensor: From<SharedArray<F>>,\n{\n    let [batch_size, _, in_height, in_width] = input.shape().dims();\n    let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims();\n    let groups = args.weight_groups;\n\n    let weight = weight.as_standard_layout();\n\n    let out_h = calculate_conv_output_size(\n        kernel_h,\n        args.stride[0],\n        args.padding[0],\n        args.dilation[0],\n        in_height,\n    );\n    let out_w = calculate_conv_output_size(\n        kernel_w,\n        args.stride[1],\n        args.padding[1],\n        args.dilation[1],\n        in_width,\n    );\n    let out_dims = (out_h, out_w);\n\n    let input = input.into_dimensionality::<Ix4>().unwrap();\n    let offset = offset.into_dimensionality::<Ix4>().unwrap();\n    let mask = mask.as_ref().map(|it| {\n        it.to_shape((\n            batch_size,\n            args.offset_groups,\n            kernel_h,\n            kernel_w,\n            out_h,\n            out_w,\n        ))\n        .unwrap()\n    });\n\n    let columns = deform_im2col(\n        input.view(),\n        offset.view(),\n        mask.as_ref().map(|it| it.view()),\n        args,\n        out_dims,\n        (kernel_h, kernel_w),\n    );\n\n    let (col_size_0, col_size_1) = columns.dim();\n    let col_size_0 = col_size_0 / groups;\n    let out_c_per_group = out_channels / groups;\n\n    let weight = weight\n        .to_shape((groups, out_c_per_group, col_size_0))\n        .unwrap();\n    let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();\n    let out = matmul(\n        weight.to_owned().into_dyn().into_shared(),\n        columns.to_owned().into_dyn().into_shared(),\n    );\n\n    let mut out = out\n        .into_shape_with_order((out_channels, batch_size, out_h, out_w))\n        .unwrap();\n    out.swap_axes(0, 1);\n\n    if let Some(bias) = bias {\n        let bias = bias.to_shape((1, out_channels, 1, 1)).unwrap();\n        out.add_assign(&bias);\n    }\n\n    out.into_dyn().into_shared()\n}\n\npub(crate) fn deform_im2col<F: FloatNdArrayElement>(\n    input: ArrayView4<F>,\n    offset: ArrayView4<F>,\n    mask: Option<ArrayView6<F>>,\n    args: DeformConvOptions<2>,\n    out_dims: (usize, usize),\n    kernel_dims: (usize, usize),\n) -> Array2<F> {\n    let (batch_size, in_channels, _, _) = input.dim();\n    let (kernel_h, kernel_w) = kernel_dims;\n    let (out_h, out_w) = out_dims;\n    let channels_per_offset_group = in_channels / args.offset_groups;\n\n    let mut columns = Array4::zeros(Dim([\n        in_channels,\n        kernel_h,\n        kernel_w,\n        batch_size * out_h * out_w,\n    ]));\n\n    let groups = args.offset_groups;\n\n    run_par!(|| {\n        iter_par!(columns.axis_iter_mut(Axis(3)))\n            .enumerate()\n            .for_each(|(index, mut columns)| {\n                let out_x = index % out_w;\n                let out_y = (index / out_w) % out_h;\n                let batch = (index / (out_w * out_h)) % batch_size;\n                let offset = offset.slice(s![batch, .., out_y, out_x]);\n                let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap();\n                let mask = mask\n                    .as_ref()\n                    .map(|it| it.slice(s![batch, .., .., .., out_y, out_x]));\n                columns\n                    .axis_iter_mut(Axis(0))\n                    .enumerate()\n                    .for_each(|(in_channel, mut columns)| {\n                        let group_index = in_channel / channels_per_offset_group;\n                        deform_im2col_kernel(\n                            out_y,\n                            out_x,\n                            input.slice(s![batch, in_channel, .., ..]),\n                            offset.slice(s![group_index, .., .., ..]),\n                            mask.as_ref().map(|it| it.slice(s![group_index, .., ..])),\n                            columns.view_mut(),\n                            args.clone(),\n                            kernel_dims,\n                        );\n                    });\n            });\n    });\n\n    columns\n        // Columns is created here, so we know it's contiguous\n        .into_shape_with_order((\n            in_channels * kernel_h * kernel_w,\n            batch_size * out_h * out_w,\n        ))\n        .unwrap()\n}\n\npub mod backward {\n    #[cfg(target_has_atomic = \"32\")]\n    use core::sync::atomic::Ordering;\n\n    use atomic_float::AtomicF32;\n    use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};\n\n    use super::*;\n\n    pub(crate) type DeformConv2dBackward<F> = (\n        SharedArray<F>,\n        SharedArray<F>,\n        SharedArray<F>,\n        Option<SharedArray<F>>,\n        Option<SharedArray<F>>,\n    );\n\n    /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.\n    pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement>(\n        input: SharedArray<F>,\n        offset: SharedArray<F>,\n        weight: SharedArray<F>,\n        mask: Option<SharedArray<F>>,\n        bias: Option<SharedArray<F>>,\n        out_grad: SharedArray<F>,\n        args: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<F> {\n        let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims();\n        let [_, _, kernel_h, kernel_w] = weight.shape().dims();\n        let groups = args.weight_groups;\n        let out_c_per_group = out_channels / groups;\n        let col_shape_1 = batch_size * out_h * out_w;\n        let mut out_grad = out_grad.into_dimensionality::<Ix4>().unwrap();\n\n        let gradient_bias = bias.map(|_| {\n            let out_grad = out_grad\n                .clone()\n                .sum_axis(Axis(0))\n                .sum_axis(Axis(1))\n                .sum_axis(Axis(1));\n\n            out_grad.into_dyn().into_shared()\n        });\n\n        out_grad.swap_axes(0, 1);\n        let out_grad = out_grad\n            .to_shape((groups, out_c_per_group, col_shape_1))\n            .unwrap();\n\n        let input = input.into_dimensionality::<Ix4>().unwrap();\n        let offset = offset.into_dimensionality::<Ix4>().unwrap();\n        let mask = mask.map(|it| {\n            it.into_shape_with_order((\n                batch_size,\n                args.offset_groups,\n                kernel_h,\n                kernel_w,\n                out_h,\n                out_w,\n            ))\n            .unwrap()\n        });\n\n        let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(\n            input.view(),\n            weight,\n            offset.view(),\n            mask.as_ref().map(|it| it.view()),\n            out_grad.view(),\n            &args,\n            (kernel_h, kernel_w),\n        );\n\n        let weight_grad = compute_weight_grad(\n            input.view(),\n            offset.view(),\n            mask.as_ref().map(|it| it.view()),\n            out_grad.view(),\n            args,\n            (kernel_h, kernel_w),\n            (out_h, out_w),\n        );\n\n        (\n            input_gradient,\n            offset_gradient,\n            weight_grad,\n            mask_gradient,\n            gradient_bias,\n        )\n    }\n\n    fn compute_weight_grad<F: FloatNdArrayElement>(\n        input: ArrayView4<F>,\n        offset: ArrayView4<F>,\n        mask: Option<ArrayView6<F>>,\n        out_grad: ArrayView3<F>,\n        options: DeformConvOptions<2>,\n        kernel_dims: (usize, usize),\n        out_dims: (usize, usize),\n    ) -> SharedArray<F> {\n        let in_channels = input.dim().1;\n        let (groups, out_c_per_group, _) = out_grad.dim();\n        let (kernel_h, kernel_w) = kernel_dims;\n\n        let in_c_per_group = in_channels / groups;\n\n        let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);\n        let (col_size_0, col_size_1) = columns.dim();\n        let col_size_0 = col_size_0 / groups;\n\n        let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();\n        columns.swap_axes(1, 2);\n\n        let grad_weight = matmul(\n            out_grad.to_owned().into_dyn().into_shared(),\n            columns.to_owned().into_dyn().into_shared(),\n        );\n\n        let grad_weight = grad_weight\n            .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w))\n            .unwrap();\n        grad_weight.into_dyn().into_shared()\n    }\n\n    type InputGradients<F> = (SharedArray<F>, SharedArray<F>, Option<SharedArray<F>>);\n\n    fn backward_gradient_inputs<F: FloatNdArrayElement>(\n        image: ArrayView4<F>,\n        weight: SharedArray<F>,\n        offset: ArrayView4<F>,\n        mask: Option<ArrayView6<F>>,\n        out_grad: ArrayView3<F>,\n        args: &DeformConvOptions<2>,\n        kernel_dims: (usize, usize),\n    ) -> InputGradients<F> {\n        let input_shape = image.dim();\n        let in_channels = input_shape.1;\n        let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims();\n        let (batch_size, _, out_h, out_w) = offset.dim();\n\n        let groups = args.weight_groups;\n        let out_c_per_group = out_channels / groups;\n\n        let col_shape_0 = in_c_per_group * kernel_h * kernel_w;\n\n        let mut weight = weight\n            .to_shape((groups, out_c_per_group, col_shape_0))\n            .unwrap();\n        weight.swap_axes(1, 2);\n        let columns = matmul(\n            weight.to_owned().into_dyn().into_shared(),\n            out_grad.to_owned().into_dyn().into_shared(),\n        );\n\n        let columns = columns\n            .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w))\n            .unwrap();\n\n        let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(\n            columns.view(),\n            image.view(),\n            offset,\n            mask,\n            args,\n            kernel_dims,\n        );\n\n        let input_gradient =\n            compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape);\n\n        (input_gradient, offset_gradient, mask_gradient)\n    }\n\n    fn compute_offset_and_mask_gradient<F: FloatNdArrayElement>(\n        columns: ArrayView6<F>,\n        image: ArrayView4<F>,\n        offset: ArrayView4<F>,\n        mask: Option<ArrayView6<F>>,\n        args: &DeformConvOptions<2>,\n        kernel_dims: (usize, usize),\n    ) -> (SharedArray<F>, Option<SharedArray<F>>) {\n        let (kernel_h, kernel_w) = kernel_dims;\n        let (_, in_channels, height, width) = image.dim();\n        let (batch_size, offset_channels, out_h, out_w) = offset.dim();\n        let offs_groups = args.offset_groups;\n        let channels_per_offset_group = in_channels / args.offset_groups;\n\n        let mut grad_offset = Array5::zeros((\n            offs_groups,\n            kernel_h,\n            kernel_w,\n            2,\n            batch_size * out_h * out_w,\n        ));\n        let mut grad_mask =\n            Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w));\n\n        grad_mask\n            .axis_iter_mut(Axis(3))\n            .zip(grad_offset.axis_iter_mut(Axis(4)))\n            .enumerate()\n            .for_each(|(index, (mut grad_mask, mut grad_offset))| {\n                let out_x = index % out_w;\n                let out_y = (index / out_w) % out_h;\n                let batch = index / (out_w * out_h);\n                let offset = offset.slice(s![batch, .., out_y, out_x]);\n                let offset = offset\n                    .to_shape((offs_groups, kernel_h, kernel_w, 2))\n                    .unwrap();\n                let mask: Option<ArrayView3<F>> = mask\n                    .as_ref()\n                    .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x]));\n                let columns = columns.slice(s![.., .., .., batch, out_y, out_x]);\n                let image = image.slice(s![batch, .., .., ..]);\n\n                for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() {\n                    let grad_mask: &mut F = grad_mask;\n                    let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]);\n                    let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);\n                    let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]);\n                    let columns = columns.slice(s![.., kernel_y, kernel_x]);\n                    let group_offset = group * channels_per_offset_group;\n                    let image = image.slice(s![group_offset.., .., ..]);\n                    let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])\n                        - F::from_elem(args.padding[0])\n                        + offset[0];\n                    let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])\n                        - F::from_elem(args.padding[1])\n                        + offset[1];\n                    for (i, grad_offset) in grad_offset.iter_mut().enumerate() {\n                        let is_y_direction = i % 2 == 0;\n                        let use_mask = mask.is_some();\n\n                        for channel in 0..channels_per_offset_group {\n                            let mask = mask.unwrap_or_else(|| F::one());\n                            let image = image.index_axis(Axis(0), channel);\n                            let weight =\n                                get_coordinate_weight(image, height, width, y, x, is_y_direction);\n                            *grad_offset += mask * weight * columns[channel];\n                            if use_mask && is_y_direction {\n                                *grad_mask += columns[channel]\n                                    * bilinear_interpolate(image, height, width, y, x);\n                            }\n                        }\n                    }\n                }\n            });\n\n        let mask_gradient = mask.map(|_| {\n            let mut grad_mask = grad_mask\n                .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w))\n                .unwrap();\n            grad_mask.swap_axes(0, 1);\n            grad_mask.into_dyn().into_shared()\n        });\n        let mut grad_offset = grad_offset\n            .into_shape_with_order((offset_channels, batch_size, out_h, out_w))\n            .unwrap();\n        grad_offset.swap_axes(0, 1);\n        let offset_gradient = grad_offset.into_dyn().into_shared();\n        (offset_gradient, mask_gradient)\n    }\n\n    fn get_coordinate_weight<F: FloatNdArrayElement>(\n        input: ArrayView2<F>,\n        height: usize,\n        width: usize,\n        y: F,\n        x: F,\n        is_y_direction: bool,\n    ) -> F {\n        let y = y.to_f32();\n        let x = x.to_f32();\n\n        let y_low = f32::floor(y);\n        let x_low = f32::floor(x);\n        let y_high = y_low + 1.;\n        let x_high = x_low + 1.;\n\n        let valid_y_low = y_low >= 0. && y_low < height as f32;\n        let valid_y_high = y_high >= 0. && y_high < height as f32;\n        let valid_x_low = x_low >= 0. && x_low < width as f32;\n        let valid_x_high = x_high >= 0. && x_high < width as f32;\n\n        let bottom_left = if valid_y_low && valid_x_low {\n            input[[y_low as usize, x_low as usize]]\n        } else {\n            F::zero()\n        };\n        let bottom_right = if valid_y_low && valid_x_high {\n            input[[y_low as usize, x_high as usize]]\n        } else {\n            F::zero()\n        };\n        let top_left = if valid_y_high && valid_x_low {\n            input[[y_high as usize, x_low as usize]]\n        } else {\n            F::zero()\n        };\n        let top_right = if valid_y_high && valid_x_high {\n            input[[y_high as usize, x_high as usize]]\n        } else {\n            F::zero()\n        };\n\n        if is_y_direction {\n            let delta_x = F::from_elem(x - x_low);\n            delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left)\n        } else {\n            let delta_y = F::from_elem(y - y_low);\n            delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left)\n        }\n    }\n\n    fn compute_input_grad<F: FloatNdArrayElement>(\n        columns: ArrayView6<F>,\n        offset: ArrayView4<F>,\n        mask: Option<ArrayView6<F>>,\n        args: &DeformConvOptions<2>,\n        kernel_dims: (usize, usize),\n        input_shape: (usize, usize, usize, usize),\n    ) -> SharedArray<F> {\n        let (batch_size, in_channels, height, width) = input_shape;\n        let (kernel_h, kernel_w) = kernel_dims;\n        let offs_groups = args.offset_groups;\n        let channels_per_offset_group = in_channels / offs_groups;\n\n        let grad_in =\n            Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || {\n                AtomicF32::new(0.0)\n            });\n\n        let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| {\n            let group = in_channel / channels_per_offset_group;\n            let offset = offset.slice(s![batch, .., out_y, out_x]);\n            let offset = offset\n                .to_shape((offs_groups, kernel_h, kernel_w, 2))\n                .unwrap();\n            let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);\n            let offset = [offset[0], offset[1]];\n            let mask = mask\n                .as_ref()\n                .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());\n            let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])\n                - F::from_elem(args.padding[0])\n                + offset[0];\n            let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])\n                - F::from_elem(args.padding[1])\n                + offset[1];\n            let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);\n            deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);\n        };\n\n        // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise\n        #[cfg(feature = \"multi-threads\")]\n        run_par!(|| {\n            iter_par!(Zip::indexed(columns))\n                .for_each(|(args0, args1)| compute_for_each(args0, args1))\n        });\n\n        #[cfg(not(feature = \"multi-threads\"))]\n        run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) });\n\n        let grad_in: Array1<F> = grad_in\n            .into_iter()\n            .map(|it| F::from_elem(it.into_inner()))\n            .collect();\n        let grad_in = grad_in\n            .into_shape_with_order((batch_size, in_channels, height, width))\n            .unwrap();\n        grad_in.into_dyn().into_shared()\n    }\n\n    fn deform_col2img_kernel(\n        y: f32,\n        x: f32,\n        mask: Option<f32>,\n        col: f32,\n        grad_input: ArrayView2<AtomicF32>,\n    ) {\n        let (height, width) = grad_input.dim();\n        let mask_value = mask.unwrap_or(1.0);\n\n        for dy in -1..=1 {\n            for dx in -1..=1 {\n                let yp = f32::floor(y) + dy as f32;\n                let xp = f32::floor(x) + dx as f32;\n\n                if yp >= 0.0\n                    && yp < height as f32\n                    && xp >= 0.0\n                    && xp < width as f32\n                    && f32::abs(y - yp) < 1.0\n                    && f32::abs(x - xp) < 1.0\n                {\n                    let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));\n\n                    #[cfg_attr(not(target_has_atomic = \"32\"), allow(unused))]\n                    let value = mask_value * weight * col;\n\n                    #[cfg(target_has_atomic = \"32\")]\n                    grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);\n                    #[cfg(not(target_has_atomic = \"32\"))]\n                    panic!(\"Can't use deformable convolution backwards pass without atomics\");\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/grid_sample.rs",
    "content": "use burn_backend::ElementConversion;\nuse burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse ndarray::Array4;\n\nuse crate::SharedArray;\nuse crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};\n\n/// Sample a tensor using grid-based sampling.\n///\n/// # Arguments\n///\n/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)\n/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].\n///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right\n/// * `options` - Grid sampling options (mode, padding_mode, align_corners)\n///\n/// # Returns\n///\n/// A tensor with shape (N, C, H_out, W_out)\npub(crate) fn grid_sample_2d<E: FloatNdArrayElement>(\n    tensor: SharedArray<E>,\n    grid: SharedArray<E>,\n    options: GridSampleOptions,\n) -> SharedArray<E> {\n    match options.mode {\n        InterpolateMode::Bilinear => (),\n        _ => todo!(\n            \"grid_sample_2d with {:?} mode is not implemented\",\n            options.mode\n        ),\n    }\n\n    let tensor = tensor.into_dimensionality::<ndarray::Ix4>().unwrap();\n    let grid = grid.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let (batch_size, channels, height_in, width_in) = tensor.dim();\n    let (b, height_out, width_out, d) = grid.dim();\n    assert!(batch_size == b);\n    assert!(2 == d);\n\n    let mut output = Array4::zeros((batch_size, channels, height_out, width_out));\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    let sample_count = batch_size * channels * height_out * width_out;\n    let strides = (\n        channels * height_out * width_out,\n        height_out * width_out,\n        width_out,\n    );\n\n    let align = options.align_corners;\n    let pad_mode = options.padding_mode;\n\n    run_par!(|| {\n        iter_range_par!(0, sample_count).for_each(|id| {\n            let (b, c, y, x) = (\n                id / strides.0,\n                id % strides.0 / strides.1,\n                id % strides.1 / strides.2,\n                id % strides.2,\n            );\n\n            let sample_x = grid[(b, y, x, 0)].elem::<f64>();\n            let sample_y = grid[(b, y, x, 1)].elem::<f64>();\n\n            // Convert normalized grid coordinates [-1, 1] to pixel coordinates\n            let (px, py) = if align {\n                // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2\n                // Maps -1 to 0 and 1 to width - 1\n                let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0;\n                let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0;\n                (px, py)\n            } else {\n                // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5\n                // Maps -1 to -0.5 and 1 to width - 0.5\n                let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5;\n                let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5;\n                (px, py)\n            };\n\n            // Bilinear interpolation with the specified padding mode\n            let val =\n                bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align);\n\n            unsafe {\n                let output = unsafe_shared_out.get();\n                output[(b, c, y, x)] = val.elem();\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\n/// Bilinear interpolation at a point with configurable padding mode.\n#[allow(clippy::too_many_arguments)]\nfn bilinear_interpolate<E, S>(\n    source: &ndarray::ArrayBase<S, ndarray::Dim<[usize; 4]>>,\n    b: usize,\n    c: usize,\n    x: f64,\n    y: f64,\n    width: usize,\n    height: usize,\n    padding_mode: GridSamplePaddingMode,\n    align_corners: bool,\n) -> f64\nwhere\n    E: FloatNdArrayElement,\n    S: ndarray::Data<Elem = E>,\n{\n    // Handle inf/nan coordinates\n    if !x.is_finite() || !y.is_finite() {\n        return match padding_mode {\n            GridSamplePaddingMode::Zeros => 0.0,\n            GridSamplePaddingMode::Border => {\n                // Clamp to center of image for inf/nan\n                let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64);\n                let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64);\n                source[(b, c, cy as usize, cx as usize)].elem::<f64>()\n            }\n            GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan\n        };\n    }\n\n    // Apply padding mode to get actual sampling coordinates\n    let (x, y) = match padding_mode {\n        GridSamplePaddingMode::Border => {\n            // Clamp coordinates to valid range [0, size-1]\n            let x = x.clamp(0.0, (width - 1) as f64);\n            let y = y.clamp(0.0, (height - 1) as f64);\n            (x, y)\n        }\n        GridSamplePaddingMode::Reflection => {\n            // Reflect coordinates at boundaries\n            let x = reflect_coordinate(x, width, align_corners);\n            let y = reflect_coordinate(y, height, align_corners);\n            (x, y)\n        }\n        GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read\n    };\n\n    // Get the four corner indices\n    let x0 = x.floor() as i64;\n    let y0 = y.floor() as i64;\n    let x1 = x0.saturating_add(1);\n    let y1 = y0.saturating_add(1);\n\n    // Compute interpolation weights (fractional part)\n    let x_frac = x - x.floor();\n    let y_frac = y - y.floor();\n\n    // Helper to read a value based on padding mode\n    let read_value = |xi: i64, yi: i64| -> f64 {\n        match padding_mode {\n            GridSamplePaddingMode::Zeros => {\n                // Return 0 for out-of-bounds\n                if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 {\n                    source[(b, c, yi as usize, xi as usize)].elem::<f64>()\n                } else {\n                    0.0\n                }\n            }\n            GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {\n                // Coordinates should already be in valid range after clamping/reflection\n                let xi = xi.clamp(0, (width - 1) as i64) as usize;\n                let yi = yi.clamp(0, (height - 1) as i64) as usize;\n                source[(b, c, yi, xi)].elem::<f64>()\n            }\n        }\n    };\n\n    // Read the four corners\n    let v00 = read_value(x0, y0);\n    let v01 = read_value(x0, y1);\n    let v10 = read_value(x1, y0);\n    let v11 = read_value(x1, y1);\n\n    // Bilinear interpolation weights\n    let w00 = (1.0 - x_frac) * (1.0 - y_frac);\n    let w01 = (1.0 - x_frac) * y_frac;\n    let w10 = x_frac * (1.0 - y_frac);\n    let w11 = x_frac * y_frac;\n\n    v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11\n}\n\n/// Reflect a coordinate at the boundaries using a triangle wave pattern.\n///\n/// For align_corners=true: reflects within [0, size-1]\n/// For align_corners=false: reflects within [-0.5, size-0.5]\nfn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {\n    let size_f = size as f64;\n    let (min_val, max_val) = if align_corners {\n        (0.0, size_f - 1.0)\n    } else {\n        (-0.5, size_f - 0.5)\n    };\n\n    let span = max_val - min_val;\n    if span <= 0.0 {\n        return min_val;\n    }\n\n    // Triangle wave formula: span - |((x mod 2*span) - span)|\n    let period = 2.0 * span;\n    let x = (coord - min_val).abs();\n    let x_mod = x - (x / period).floor() * period;\n    span - (x_mod - span).abs() + min_val\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/int_tensor.rs",
    "content": "// Language\nuse crate::rand::get_seeded_rng;\nuse alloc::vec::Vec;\nuse burn_backend::backend::ExecutionError;\nuse burn_backend::ops::IntTensorOps;\nuse burn_backend::tensor::{FloatTensor, IntTensor};\nuse burn_backend::{Distribution, IntDType, Scalar, TensorMetadata};\n\nuse burn_backend::ElementConversion;\n\n// Current crate\nuse crate::cat_with_dtype;\nuse crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor};\nuse crate::{NdArrayDevice, SEED, slice};\nuse crate::{SharedArray, element::QuantElement};\nuse crate::{element::FloatNdArrayElement, ops::matmul::matmul};\nuse crate::{element::IntNdArrayElement, execute_with_int_dtype};\n\n// Workspace crates\nuse super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps};\nuse burn_backend::{DType, Shape, TensorData, backend::Backend};\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor {\n        if data.dtype.is_int() || data.dtype.is_uint() {\n            NdArrayTensor::from_data(data)\n        } else {\n            unimplemented!(\"Unsupported dtype for `int_from_data`: {:?}\", data.dtype)\n        }\n    }\n\n    async fn int_into_data(tensor: NdArrayTensor) -> Result<TensorData, ExecutionError> {\n        Ok(tensor.into_data())\n    }\n\n    fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor {\n        tensor\n    }\n\n    fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape))\n    }\n\n    fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor {\n        slice!(tensor, slices)\n    }\n\n    fn int_device(_tensor: &NdArrayTensor) -> <NdArray<E> as Backend>::Device {\n        NdArrayDevice::Cpu\n    }\n\n    fn int_empty(\n        shape: Shape,\n        device: &<NdArray<E> as Backend>::Device,\n        dtype: IntDType,\n    ) -> NdArrayTensor {\n        Self::int_zeros(shape, device, dtype)\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        execute_with_int_dtype!((lhs, rhs), matmul)\n    }\n\n    fn int_mask_where(\n        tensor: NdArrayTensor,\n        mask: NdArrayTensor,\n        source: NdArrayTensor,\n    ) -> NdArrayTensor {\n        execute_with_int_dtype!((tensor, source), |tensor, source| {\n            NdArrayOps::mask_where(tensor, mask.bool(), source)\n        })\n    }\n\n    fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill(\n            array,\n            mask.bool(),\n            value.elem()\n        ))\n    }\n\n    fn int_slice_assign(\n        tensor: NdArrayTensor,\n        slices: &[burn_backend::Slice],\n        value: NdArrayTensor,\n    ) -> NdArrayTensor {\n        execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign(\n            tensor, slices, value\n        ))\n    }\n\n    fn int_cat(tensors: Vec<NdArrayTensor>, dim: usize) -> NdArrayTensor {\n        cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8])\n    }\n\n    fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal)\n    }\n\n    fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem()))\n    }\n\n    fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater)\n    }\n\n    fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem()))\n    }\n\n    fn int_greater_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal)\n    }\n\n    fn int_greater_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem(\n            array,\n            rhs.elem()\n        ))\n    }\n\n    fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower)\n    }\n\n    fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem()))\n    }\n\n    fn int_lower_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal)\n    }\n\n    fn int_lower_equal_elem(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem(\n            array,\n            rhs.elem()\n        ))\n    }\n\n    fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add)\n    }\n\n    fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem()))\n    }\n\n    fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub)\n    }\n\n    fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem()))\n    }\n\n    fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul)\n    }\n\n    fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem()))\n    }\n\n    fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div)\n    }\n\n    fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem()))\n    }\n\n    fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder)\n    }\n\n    fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar(\n            array,\n            rhs.elem()\n        ))\n    }\n\n    fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::sum_view(\n            array.view()\n        ))\n    }\n\n    fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim))\n    }\n\n    fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(\n            tensor,\n            E,\n            |array: SharedArray<E>| NdArrayMathOps::prod_view(array.view())\n        )\n    }\n\n    fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim))\n    }\n\n    fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(\n            tensor,\n            E,\n            |array: SharedArray<E>| NdArrayMathOps::mean_view(array.view())\n        )\n    }\n\n    fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim))\n    }\n\n    fn int_max(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::max_view(\n            array.view()\n        ))\n    }\n\n    fn int_min(tensor: NdArrayTensor) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| NdArrayMathOps::min_view(\n            array.view()\n        ))\n    }\n\n    fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim))\n    }\n\n    fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim))\n    }\n\n    fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim))\n    }\n\n    fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim))\n    }\n\n    fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {\n            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather(\n                dim, array, idx_array\n            ))\n        })\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: NdArrayTensor,\n        indices: NdArrayTensor,\n        value: NdArrayTensor,\n    ) -> NdArrayTensor {\n        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {\n            execute_with_int_dtype!(indices, |idx_array| NdArrayOps::<I>::scatter(\n                dim, tensor, idx_array, value\n            ))\n        })\n    }\n\n    fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor {\n            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select(\n                array, dim, idx_array\n            ))\n        })\n    }\n\n    fn int_select_add(\n        tensor: NdArrayTensor,\n        dim: usize,\n        indices: NdArrayTensor,\n        value: NdArrayTensor,\n    ) -> NdArrayTensor {\n        execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor {\n            execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::<I>::select_assign(\n                tensor, dim, idx_array, value\n            ))\n        })\n    }\n    fn int_argmax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {\n            NdArrayMathOps::argmax_view::<I>(array.view(), dim)\n        })\n    }\n\n    fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {\n            NdArrayMathOps::argmin_view::<I>(array.view(), dim)\n        })\n    }\n\n    fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem()))\n    }\n\n    fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem()))\n    }\n\n    fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp(\n            array,\n            min.elem(),\n            max.elem()\n        ))\n    }\n\n    fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor {\n        match tensor.dtype() {\n            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {\n                execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [\n                    I64 => i64, I32 => i32, I16 => i16, I8 => i8\n                ])\n            }\n            // Already unsigned\n            DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor,\n            other => panic!(\"Unsupported dtype: {other:?}\"),\n        }\n    }\n\n    fn int_into_float(tensor: NdArrayTensor) -> FloatTensor<Self> {\n        execute_with_int_dtype!(tensor, IntElem, |array: SharedArray<IntElem>| array\n            .mapv(|a: IntElem| a.elem::<E>())\n            .into_shared())\n    }\n\n    fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2))\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &NdArrayDevice,\n    ) -> NdArrayTensor {\n        let mut seed = SEED.lock().unwrap();\n        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);\n\n        let effective_distribution = if distribution == Distribution::Default {\n            Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant\n        } else {\n            distribution\n        };\n\n        let tensor = Self::int_from_data(\n            TensorData::random::<I, _, _>(shape, effective_distribution, &mut rng),\n            device,\n        );\n        *seed = Some(rng);\n        tensor\n    }\n\n    fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op(\n            lhs,\n            rhs,\n            |a: &I, b: &I| { (a.elem::<i64>().pow(b.elem::<u32>())).elem() }\n        ))\n    }\n\n    fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes))\n    }\n\n    fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes))\n    }\n\n    fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor {\n        match tensor.dtype() {\n            DType::I64 | DType::I32 | DType::I16 | DType::I8 => {\n                execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [\n                    I64 => i64, I32 => i32, I16 => i16, I8 => i8\n                ])\n            }\n            DType::U64 | DType::U32 | DType::U16 | DType::U8 => {\n                Self::int_greater_elem(tensor, 0.into())\n            }\n            other => panic!(\"Unsupported dtype: {other:?}\"),\n        }\n    }\n\n    fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, shape))\n    }\n\n    fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand)\n    }\n\n    fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem()))\n    }\n\n    fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor)\n    }\n\n    fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem()))\n    }\n\n    fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor)\n    }\n\n    fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem()))\n    }\n\n    fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot)\n    }\n\n    fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {\n            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {\n                (a.elem::<i64>() << (b.elem::<u32>())).elem()\n            })\n        })\n    }\n\n    fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, I, |array| {\n            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {\n                (a.elem::<i64>() << rhs.elem::<u32>()).elem()\n            })\n        })\n    }\n\n    fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor {\n        execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| {\n            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {\n                (a.elem::<i64>() >> (b.elem::<u32>())).elem()\n            })\n        })\n    }\n\n    fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor {\n        execute_with_int_dtype!(lhs, I, |array| {\n            NdArrayMathOps::elementwise_op_scalar(array, |a: I| {\n                (a.elem::<i64>() >> rhs.elem::<u32>()).elem()\n            })\n        })\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into()))\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step))\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/interpolate.rs",
    "content": "use burn_backend::ElementConversion;\nuse ndarray::{Array4, ArrayBase, DataOwned};\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse crate::{FloatNdArrayElement, ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par};\n\npub(crate) fn nearest_interpolate<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    output_size: [usize; 2],\n) -> SharedArray<E> {\n    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let (batch_size, channels, in_height, in_width) = x.dim();\n    let [out_height, out_width] = output_size;\n\n    let y_ratio = (in_height as f64) / (out_height as f64);\n    let x_ratio = (in_width as f64) / (out_width as f64);\n\n    let out_element_num = batch_size * channels * out_height * out_width;\n    let strides = (\n        channels * out_height * out_width,\n        out_height * out_width,\n        out_width,\n    );\n\n    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, out_element_num).for_each(|id| {\n            let (b, c, h, w) = (\n                id / strides.0,\n                id % strides.0 / strides.1,\n                id % strides.1 / strides.2,\n                id % strides.2,\n            );\n\n            let y_in = (y_ratio * h as f64).floor() as usize;\n            let x_in = (x_ratio * w as f64).floor() as usize;\n\n            unsafe {\n                let output = unsafe_shared_out.get();\n                output[(b, c, h, w)] = x[(b, c, y_in, x_in)];\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn nearest_interpolate_backward<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    grad: SharedArray<E>,\n    output_size: [usize; 2],\n) -> SharedArray<E> {\n    let [batch_size, channels, input_height, input_width] = x.shape().dims();\n    let [output_height, output_width] = output_size;\n\n    let mut output_grad =\n        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output_grad = unsafe_shared_out.get();\n\n            for oh in 0..output_height {\n                for ow in 0..output_width {\n                    let ih = start_index(oh, output_height, input_height);\n                    let iw = start_index(ow, output_width, input_width);\n\n                    output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]]\n                }\n            }\n        })\n    });\n\n    output_grad.into_dyn().into_shared()\n}\n\nfn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {\n    ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize\n}\n\n// clamp ceil(frac) to stay within bounds in case of floating-point imprecision\npub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 {\n    frac.ceil().min(max as f64)\n}\n\npub(crate) fn bilinear_interpolate<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    output_size: [usize; 2],\n    align_corners: bool,\n) -> SharedArray<E> {\n    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let (batch_size, channels, in_height, in_width) = x.dim();\n    let [out_height, out_width] = output_size;\n\n    let out_element_num = batch_size * channels * out_height * out_width;\n    let strides = (\n        channels * out_height * out_width,\n        out_height * out_width,\n        out_width,\n    );\n\n    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, out_element_num).for_each(|id| {\n            let (b, c, h, w) = (\n                id / strides.0,\n                id % strides.0 / strides.1,\n                id % strides.1 / strides.2,\n                id % strides.2,\n            );\n\n            let (y_frac, x_frac) = if align_corners {\n                let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);\n                let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);\n                (y_ratio * h as f64, x_ratio * w as f64)\n            } else {\n                let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;\n                let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;\n                (\n                    y_frac.clamp(0.0, (in_height - 1) as f64),\n                    x_frac.clamp(0.0, (in_width - 1) as f64),\n                )\n            };\n            let val =\n                bilinear_interpolate_single(&x, b, c, x_frac, y_frac, in_width - 1, in_height - 1);\n\n            unsafe {\n                let output = unsafe_shared_out.get();\n                output[(b, c, h, w)] = val.elem();\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn bicubic_interpolate<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    output_size: [usize; 2],\n    align_corners: bool,\n) -> SharedArray<E> {\n    fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 {\n        fn cubic_convolution1(x: f64, a: f64) -> f64 {\n            ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0\n        }\n\n        fn cubic_convolution2(x: f64, a: f64) -> f64 {\n            ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a\n        }\n\n        let coeffs = [\n            cubic_convolution2(t + 1.0, -0.75),\n            cubic_convolution1(t, -0.75),\n            cubic_convolution1(1.0 - t, -0.75),\n            cubic_convolution2(2.0 - t, -0.75),\n        ];\n\n        x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]\n    }\n\n    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let (batch_size, channels, in_height, in_width) = x.dim();\n    let [out_height, out_width] = output_size;\n\n    let out_element_num = batch_size * channels * out_height * out_width;\n    let strides = (\n        channels * out_height * out_width,\n        out_height * out_width,\n        out_width,\n    );\n\n    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, out_element_num).for_each(|id| {\n            let (b, c, h, w) = (\n                id / strides.0,\n                id % strides.0 / strides.1,\n                id % strides.1 / strides.2,\n                id % strides.2,\n            );\n\n            let (y_frac, x_frac) = if align_corners {\n                let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);\n                let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);\n                (y_ratio * h as f64, x_ratio * w as f64)\n            } else {\n                let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;\n                let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;\n                (y_frac, x_frac)\n            };\n            let y0 = y_frac.floor();\n            let yw = y_frac - y0;\n            let y_in = y0 as isize;\n\n            let x0 = x_frac.floor();\n            let xw = x_frac - x0;\n            let x_in = x0 as isize;\n\n            let max_h = (in_height - 1) as isize;\n            let max_w = (in_width - 1) as isize;\n\n            let ys_in = [\n                (y_in - 1).clamp(0, max_h) as usize,\n                y_in.clamp(0, max_h) as usize,\n                (y_in + 1).clamp(0, max_h) as usize,\n                (y_in + 2).clamp(0, max_h) as usize,\n            ];\n\n            let xs_in = [\n                (x_in - 1).clamp(0, max_w) as usize,\n                x_in.clamp(0, max_w) as usize,\n                (x_in + 1).clamp(0, max_w) as usize,\n                (x_in + 2).clamp(0, max_w) as usize,\n            ];\n\n            let coefficients = ys_in.map(|y| {\n                cubic_interp1d(\n                    x[(b, c, y, xs_in[0])].elem(),\n                    x[(b, c, y, xs_in[1])].elem(),\n                    x[(b, c, y, xs_in[2])].elem(),\n                    x[(b, c, y, xs_in[3])].elem(),\n                    xw,\n                )\n            });\n\n            let result = cubic_interp1d(\n                coefficients[0],\n                coefficients[1],\n                coefficients[2],\n                coefficients[3],\n                yw,\n            )\n            .elem();\n\n            unsafe {\n                let output = unsafe_shared_out.get();\n                output[(b, c, h, w)] = result;\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn lanczos3_interpolate<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    output_size: [usize; 2],\n    align_corners: bool,\n) -> SharedArray<E> {\n    fn lanczos3_weight(x: f64) -> f64 {\n        if x == 0.0 {\n            return 1.0;\n        }\n        let abs_x = x.abs();\n        if abs_x >= 3.0 {\n            return 0.0;\n        }\n        let pi = core::f64::consts::PI;\n        let pi_x = pi * x;\n        let pi_x_over_3 = pi_x / 3.0;\n        (pi_x.sin() * pi_x_over_3.sin()) / (pi_x * pi_x_over_3)\n    }\n\n    let x = x.into_dimensionality::<ndarray::Ix4>().unwrap();\n\n    let (batch_size, channels, in_height, in_width) = x.dim();\n    let [out_height, out_width] = output_size;\n\n    let out_element_num = batch_size * channels * out_height * out_width;\n    let strides = (\n        channels * out_height * out_width,\n        out_height * out_width,\n        out_width,\n    );\n\n    let mut output = Array4::zeros((batch_size, channels, out_height, out_width));\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, out_element_num).for_each(|id| {\n            let (b, c, h, w) = (\n                id / strides.0,\n                id % strides.0 / strides.1,\n                id % strides.1 / strides.2,\n                id % strides.2,\n            );\n\n            let (y_frac, x_frac) = if align_corners {\n                let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64);\n                let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64);\n                (y_ratio * h as f64, x_ratio * w as f64)\n            } else {\n                let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5;\n                let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5;\n                (y_frac, x_frac)\n            };\n\n            let y0 = y_frac.floor();\n            let x0 = x_frac.floor();\n            let max_h = (in_height - 1) as isize;\n            let max_w = (in_width - 1) as isize;\n\n            // 6x6 separable Lanczos3 filter (skip out-of-bounds positions)\n            let mut result = 0.0;\n            let mut weight_sum = 0.0;\n            for ky in -2..=3 {\n                let yi = y0 as isize + ky;\n                if yi < 0 || yi > max_h {\n                    continue;\n                }\n                let y_idx = yi as usize;\n                let wy = lanczos3_weight(y_frac - (y0 + ky as f64));\n                for kx in -2..=3 {\n                    let xi = x0 as isize + kx;\n                    if xi < 0 || xi > max_w {\n                        continue;\n                    }\n                    let x_idx = xi as usize;\n                    let wx = lanczos3_weight(x_frac - (x0 + kx as f64));\n                    let w = wy * wx;\n                    let pixel: f64 = x[(b, c, y_idx, x_idx)].elem();\n                    result += pixel * w;\n                    weight_sum += w;\n                }\n            }\n            if weight_sum != 0.0 {\n                result /= weight_sum;\n            }\n\n            unsafe {\n                let output = unsafe_shared_out.get();\n                output[(b, c, h, w)] = result.elem();\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n\n/// Sample an element of the source array with bilinear interpolation\n///\n/// * `source` - The tensor to read from. Has shape (batch_size, channels, height, width)\n/// * `b` - The batch to read from\n/// * `c` - The channel to read from\n/// * `x` - The x position to read in the array\n/// * `y` - The y position to read in the array\n/// * `x_max` - The max x position (inclusive)\n/// * `y_max` - The max y position (inclusive)\n///\n/// # Returns\n///\n/// The interpolated value read from the array\npub(crate) fn bilinear_interpolate_single<E, S>(\n    source: &ArrayBase<S, ndarray::Dim<[usize; 4]>>,\n    b: usize,\n    c: usize,\n    x: f64,\n    y: f64,\n    x_max: usize,\n    y_max: usize,\n) -> f64\nwhere\n    E: FloatNdArrayElement,\n    S: DataOwned<Elem = E>,\n{\n    let y0 = y.floor();\n    let y1 = ceil_clamp(y, y_max);\n    let yw = y - y0;\n\n    let x0 = x.floor();\n    let x1 = ceil_clamp(x, x_max);\n    let xw = x - x0;\n\n    let (x0, x1, y0, y1) = (x0 as usize, x1 as usize, y0 as usize, y1 as usize);\n\n    let p_a = source[(b, c, y0, x0)].elem::<f64>() * (1.0 - xw) * (1.0 - yw);\n    let p_b = source[(b, c, y0, x1)].elem::<f64>() * xw * (1.0 - yw);\n    let p_c = source[(b, c, y1, x0)].elem::<f64>() * (1.0 - xw) * yw;\n    let p_d = source[(b, c, y1, x1)].elem::<f64>() * xw * yw;\n\n    p_a + p_b + p_c + p_d\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/macros.rs",
    "content": "macro_rules! keepdim {\n    (\n        $dim:expr,\n        $self:expr,\n        mean\n    ) => {{\n        // Get shape first (via reference), then pass ownership to avoid clone\n        let mut shape = $self.shape().into_shape();\n        shape[$dim] = 1;\n        let tensor: SharedArray<E> = mean_dim($self, $dim);\n        NdArrayOps::reshape(tensor, shape)\n    }};\n    (\n        $dim:expr,\n        $self:expr,\n        sum\n    ) => {{\n        // Get shape first (via reference), then pass ownership to avoid clone\n        let mut shape = $self.shape().into_shape();\n        shape[$dim] = 1;\n        let tensor: SharedArray<E> = sum_dim($self, $dim);\n        NdArrayOps::reshape(tensor, shape)\n    }};\n    (\n        $dim:expr,\n        $self:expr,\n        prod\n    ) => {{\n        // Get shape first (via reference), then pass ownership to avoid clone\n        let mut shape = $self.shape().into_shape();\n        shape[$dim] = 1;\n        let tensor: SharedArray<E> = prod_dim($self, $dim);\n        NdArrayOps::reshape(tensor, shape)\n    }};\n}\n\nuse burn_backend::ElementConversion;\npub(crate) use keepdim;\nuse ndarray::{Axis, Zip};\n\nuse crate::{SharedArray, element::NdArrayElement};\n\npub(crate) fn mean_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n    tensor.mean_axis(Axis(dim)).unwrap().into_shared()\n}\n\npub(crate) fn sum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n    tensor.sum_axis(Axis(dim)).into_shared()\n}\n\npub(crate) fn prod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n    tensor\n        .fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))\n        .into_shared()\n}\n\n/// Generic cumulative operation function with closure-based operation.\npub(crate) fn cumulative_with_op<E, F>(tensor: SharedArray<E>, dim: usize, op: F) -> SharedArray<E>\nwhere\n    E: NdArrayElement,\n    F: Fn(&mut E, &E),\n{\n    let axis = Axis(dim);\n    let shape = tensor.shape().to_vec();\n    // Use into_owned() instead of to_owned() - only copies if shared, avoids copy if unique\n    let mut result = tensor.into_owned();\n    let dim_size = shape[dim];\n\n    for i in 1..dim_size {\n        let prev = result.index_axis(axis, i - 1).to_owned();\n        let mut current = result.index_axis_mut(axis, i);\n        Zip::from(&mut current).and(&prev).for_each(&op);\n    }\n\n    result.into_shared()\n}\n\n// Define all cumulative operation functions using the generic function\npub(crate) fn cumsum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n    cumulative_with_op(tensor, dim, |c, &p| *c = c.add(p.elem()))\n}\n\npub(crate) fn cumprod_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {\n    cumulative_with_op(tensor, dim, |c, &p| *c = c.mul(p.elem()))\n}\n\npub(crate) fn cummin_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(\n    tensor: SharedArray<E>,\n    dim: usize,\n) -> SharedArray<E> {\n    cumulative_with_op(tensor, dim, |c, &p| {\n        if p < *c {\n            *c = p;\n        }\n    })\n}\n\npub(crate) fn cummax_dim<E: NdArrayElement + core::cmp::PartialOrd<E>>(\n    tensor: SharedArray<E>,\n    dim: usize,\n) -> SharedArray<E> {\n    cumulative_with_op(tensor, dim, |c, &p| {\n        if p > *c {\n            *c = p;\n        }\n    })\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/matmul.rs",
    "content": "use crate::UnsafeSharedRef;\nuse crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par};\n\nuse alloc::{vec, vec::Vec};\nuse burn_backend::ElementConversion;\nuse burn_backend::Shape;\nuse ndarray::{IxDyn, s};\n\npub(crate) fn matmul<E: NdArrayElement>(\n    lhs: SharedArray<E>,\n    rhs: SharedArray<E>,\n) -> SharedArray<E> {\n    let shape_lhs = lhs.shape();\n    let shape_rhs = rhs.shape();\n    let ndims = shape_lhs.num_dims();\n    let m = shape_lhs[ndims - 2]; // # of left rows\n    let k = shape_rhs[ndims - 2]; // # of left cols and right rows\n    let n = shape_rhs[ndims - 1]; // # of right cols\n\n    let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs);\n    let l_mat_size = m * k; // size of matrix component of left array\n    let r_mat_size = k * n; // size of matrix component of right array\n    let out_mat_size = m * n; // size of matrix component of output array\n\n    let num_l_batches = shape_lhs.num_elements() / l_mat_size;\n    let num_r_batches = shape_rhs.num_elements() / r_mat_size;\n    let num_out_batches = out_shape.num_elements() / out_mat_size;\n\n    let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k]));\n    let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n]));\n\n    let alpha: E = 1.0.elem();\n    let beta: E = 0.0.elem();\n\n    let out = run_par!(|| {\n        let mut out_array = ndarray::Array3::<E>::zeros((num_out_batches, m, n));\n        let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);\n\n        iter_range_par!(0, num_out_batches).for_each(|out_batch| {\n            // Here, we:\n            //   1. Un-flatten the output batch into a component-based batch index.\n            //   2. Use the strides for left and right batch indices to convert it to a flattened\n            //      batch for left and right.\n            let out_index = strides_out.unflatten(out_batch);\n            let l_batch = strides_lhs.flatten(&out_index);\n            let r_batch = strides_rhs.flatten(&out_index);\n\n            let lhs_slice = lhs_array.slice(s!(l_batch, .., ..));\n            let rhs_slice = rhs_array.slice(s!(r_batch, .., ..));\n\n            unsafe {\n                let mut out_slice = unsafe_shared_out_array\n                    .get()\n                    .slice_mut(s!(out_batch, .., ..));\n\n                ndarray::linalg::general_mat_mul(\n                    alpha,\n                    &lhs_slice,\n                    &rhs_slice,\n                    beta,\n                    &mut out_slice,\n                )\n            }\n        });\n\n        out_array.into_shared().into_dyn()\n    });\n\n    NdArrayOps::reshape(out, out_shape)\n}\n\n#[derive(Debug, PartialEq)]\nstruct Strides {\n    strides: Vec<usize>,\n}\nimpl Strides {\n    fn new(strides: Vec<usize>) -> Self {\n        Strides { strides }\n    }\n\n    fn unflatten(&self, linear_index: usize) -> Vec<usize> {\n        let mut coord = Vec::with_capacity(self.strides.len());\n        let mut rem = linear_index;\n        for stride in self.strides.iter() {\n            coord.push(rem / stride);\n            rem %= stride;\n        }\n        coord\n    }\n\n    fn flatten(&self, index: &Vec<usize>) -> usize {\n        assert_eq!(self.strides.len(), index.len());\n        self.strides\n            .iter()\n            .zip(index)\n            .map(|(stride, index)| stride * index)\n            .sum()\n    }\n}\n\n/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for\n/// the non-matrix dimensions of all arrays.\n///\n/// # Arguments\n/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument.\n/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument.\n///\n/// # Panics\n/// * If `D` is not at least 2.\n/// * If the matrix multiplication dimensions (last 2) are incompatible.\n/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where\n///   one dim is equal to 1 is broadcast.)\nfn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) {\n    let ndims = lsh.num_dims();\n    if ndims < 2 {\n        panic!(\"Matrix multiplication requires an array with at least 2 dimensions.\");\n    }\n\n    // Fetch matrix dimensions and check compatibility.\n    let l_rows = lsh[ndims - 2];\n    let l_cols = lsh[ndims - 1];\n    let r_rows = rsh[ndims - 2];\n    let r_cols = rsh[ndims - 1];\n    if l_cols != r_rows {\n        panic!(\"Dimensions are incompatible for matrix multiplication.\");\n    }\n    // Set matrix dimensions of the output shape.\n    let mut osh = vec![0; ndims];\n    osh[ndims - 2] = l_rows;\n    osh[ndims - 1] = r_cols;\n\n    // Set other array dimensions, broadcasting as necessary.\n    // Compute the strides inline.\n    let mut cur_l_stride: usize = 1;\n    let mut cur_r_stride: usize = 1;\n    let mut cur_o_stride: usize = 1;\n    let mut l_strides = Vec::with_capacity(ndims - 2);\n    let mut r_strides = Vec::with_capacity(ndims - 2);\n    let mut o_strides = Vec::with_capacity(ndims - 2);\n    for i in (0..ndims - 2).rev() {\n        let l_dim = lsh[i];\n        let r_dim = rsh[i];\n\n        // Compatible dimensions are:\n        //   1. Both dimensions are equal.\n        //   2. One of the dimensions is equal to 1.\n        let o_dim: usize;\n        if l_dim == r_dim {\n            o_dim = l_dim; // both dimensions are equal\n            l_strides.push(cur_l_stride);\n            r_strides.push(cur_r_stride);\n        } else if l_dim == 1 {\n            o_dim = r_dim; // broadcast the left\n            l_strides.push(0);\n            r_strides.push(cur_r_stride);\n        } else if r_dim == 1 {\n            o_dim = l_dim; // broadcast the right\n            l_strides.push(cur_l_stride);\n            r_strides.push(0);\n        } else {\n            panic!(\"Dimensions differ and cannot be broadcasted.\");\n        }\n        osh[i] = o_dim;\n        o_strides.push(cur_o_stride);\n        cur_o_stride *= o_dim;\n\n        cur_l_stride *= l_dim;\n        cur_r_stride *= r_dim;\n    }\n    l_strides.reverse();\n    r_strides.reverse();\n    o_strides.reverse();\n\n    (\n        Shape::from(osh),\n        Strides::new(l_strides),\n        Strides::new(r_strides),\n        Strides::new(o_strides),\n    )\n}\n\npub(crate) fn cross<E: NdArrayElement>(\n    lhs: SharedArray<E>,\n    rhs: SharedArray<E>,\n    dim: usize,\n) -> SharedArray<E> {\n    let shape_lhs = lhs.shape();\n    let shape_rhs = rhs.shape();\n    let ndims = shape_lhs.num_dims();\n\n    // Broadcast the shapes except along dim\n    let mut broadcast_shape = vec![0; ndims];\n    for i in 0..ndims {\n        if i == dim {\n            broadcast_shape[i] = shape_lhs[i]; // already checked to be 3\n        } else {\n            let l = shape_lhs[i];\n            let r = shape_rhs[i];\n            if l == r {\n                broadcast_shape[i] = l;\n            } else if l == 1 {\n                broadcast_shape[i] = r;\n            } else if r == 1 {\n                broadcast_shape[i] = l;\n            } else {\n                panic!(\"Tensors are not broadcastable along dimension {}\", i);\n            }\n        }\n    }\n\n    // Broadcast lhs and rhs\n    let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() {\n        lhs\n    } else {\n        NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone()))\n    };\n    let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() {\n        rhs\n    } else {\n        NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone()))\n    };\n\n    // Now, move dim to the last dimension\n    let mut perm = (0..ndims).collect::<Vec<_>>();\n    perm.remove(dim);\n    perm.push(dim);\n\n    let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm);\n    let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm);\n\n    // Reshape to (*, 3)\n    let total_elements = lhs_permuted.shape().num_elements();\n    let batch_size = total_elements / 3;\n    let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3]));\n    let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3]));\n\n    // Compute cross product\n    let mut result = ndarray::ArrayD::<E>::zeros(IxDyn(&[batch_size, 3]));\n    for i in 0..batch_size {\n        let a1 = lhs_reshaped[IxDyn(&[i, 0])];\n        let a2 = lhs_reshaped[IxDyn(&[i, 1])];\n        let a3 = lhs_reshaped[IxDyn(&[i, 2])];\n        let b1 = rhs_reshaped[IxDyn(&[i, 0])];\n        let b2 = rhs_reshaped[IxDyn(&[i, 1])];\n        let b3 = rhs_reshaped[IxDyn(&[i, 2])];\n        result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2));\n        result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3));\n        result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1));\n    }\n\n    let result_shared = result.into_shared();\n\n    // Reshape back to the broadcast shape with dim at the end\n    let mut result_shape = broadcast_shape;\n    result_shape.remove(dim);\n    result_shape.push(3);\n    let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape));\n\n    // Permute back\n    let mut inv_perm = vec![0; ndims];\n    for (i, &p) in perm.iter().enumerate() {\n        inv_perm[p] = i;\n    }\n    NdArrayOps::permute(result_reshaped, &inv_perm)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    impl Strides {\n        fn empty() -> Self {\n            Strides {\n                strides: Vec::with_capacity(0),\n            }\n        }\n    }\n\n    #[test]\n    fn test_output_shape() {\n        // plain matrix multiply\n        assert_eq!(\n            output_shape(&[5, 3], &[3, 7]),\n            (\n                Shape::from([5, 7]),\n                Strides::empty(),\n                Strides::empty(),\n                Strides::empty()\n            )\n        );\n        // matrix multiply with one extra stack dimension\n        assert_eq!(\n            output_shape(&[4, 5, 3], &[4, 3, 7]),\n            (\n                Shape::from([4, 5, 7]),\n                Strides::new(vec![1]),\n                Strides::new(vec![1]),\n                Strides::new(vec![1])\n            )\n        );\n        // rank 3, broadcast left\n        assert_eq!(\n            output_shape(&[1, 5, 3], &[4, 3, 7]),\n            (\n                Shape::from([4, 5, 7]),\n                Strides::new(vec![0]),\n                Strides::new(vec![1]),\n                Strides::new(vec![1])\n            )\n        );\n        // rank 3, broadcast right\n        assert_eq!(\n            output_shape(&[4, 5, 3], &[1, 3, 7]),\n            (\n                Shape::from([4, 5, 7]),\n                Strides::new(vec![1]),\n                Strides::new(vec![0]),\n                Strides::new(vec![1])\n            )\n        );\n        // rank 4, multi broadcast\n        assert_eq!(\n            output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]),\n            (\n                Shape::from([8, 4, 5, 7]),\n                Strides::new(vec![0, 1]),\n                Strides::new(vec![1, 0]),\n                Strides::new(vec![4, 1])\n            )\n        );\n        // rank 5, multi-broadcast\n        assert_eq!(\n            output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]),\n            (\n                Shape::from([8, 3, 4, 5, 7]),\n                Strides::new(vec![0, 4, 1]),\n                Strides::new(vec![3, 1, 0]),\n                Strides::new(vec![12, 4, 1])\n            )\n        )\n    }\n\n    #[test]\n    #[should_panic(\n        expected = \"Matrix multiplication requires an array with at least 2 dimensions.\"\n    )]\n    fn test_output_shape_too_small() {\n        output_shape(&[4], &[4]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Dimensions are incompatible for matrix multiplication.\")]\n    fn test_output_shape_bad_matrix_dims() {\n        output_shape(&[5, 3], &[4, 7]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Dimensions differ and cannot be broadcasted.\")]\n    fn test_output_shape_non_broadcast() {\n        output_shape(&[4, 5, 3], &[2, 3, 7]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/maxpool.rs",
    "content": "use crate::{\n    ShapeOps, SharedArray,\n    element::{FloatNdArrayElement, IntNdArrayElement},\n    iter_range_par,\n    ops::padding::apply_padding_4d,\n    run_par,\n    sharing::UnsafeSharedRef,\n};\n\nuse burn_backend::ElementConversion;\nuse burn_backend::ops::conv::calculate_pool_output_size;\nuse ndarray::Array4;\n\npub(crate) fn max_pool2d<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n) -> SharedArray<E> {\n    let [kernel_height, kernel_width] = kernel_size;\n    let [padding_height, padding_width] = padding;\n    let [stride_height, stride_width] = stride;\n    let [dilation_height, dilation_width] = dilation;\n    let [batch_size, channels, x_height, x_width] = x.shape().dims();\n    let inf = (-f32::INFINITY).elem::<E>();\n\n    let out_height = calculate_pool_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        dilation_height,\n        x_height,\n        ceil_mode,\n    );\n    let out_width = calculate_pool_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        dilation_width,\n        x_width,\n        ceil_mode,\n    );\n\n    // Calculate extra padding needed for ceil_mode\n    // The maximum input position accessed is: (out_size - 1) * stride + (kernel_size - 1) * dilation\n    // This must be < input_size + 2 * total_padding\n    let max_ih =\n        (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;\n    let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;\n    let padded_height = x_height + 2 * padding_height;\n    let padded_width = x_width + 2 * padding_width;\n    let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));\n    let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));\n    let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];\n\n    let x = apply_padding_4d::<E>(x, total_padding, inf);\n\n    // Offset to account for extra padding (extra_pad is added on both sides by apply_padding_4d)\n    let offset_h = extra_pad_h;\n    let offset_w = extra_pad_w;\n\n    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output = unsafe_shared_out.get();\n\n            for oh in 0..out_height {\n                for ow in 0..out_width {\n                    let mut max_val = inf;\n\n                    for kh in 0..kernel_height {\n                        let ih = offset_h + oh * stride_height + kh * dilation_height;\n\n                        for kw in 0..kernel_width {\n                            let iw = offset_w + ow * stride_width + kw * dilation_width;\n\n                            let val = x[[b, c, ih, iw]];\n\n                            if val > max_val {\n                                max_val = val;\n                            }\n                        }\n                    }\n\n                    output[[b, c, oh, ow]] = max_val;\n                }\n            }\n        })\n    });\n\n    output.into_dyn().into_shared()\n}\n\npub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, I: IntNdArrayElement>(\n    x: SharedArray<E>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n) -> (SharedArray<E>, SharedArray<I>) {\n    let [kernel_height, kernel_width] = kernel_size;\n    let [padding_height, padding_width] = padding;\n    let [stride_height, stride_width] = stride;\n    let [dilation_height, dilation_width] = dilation;\n    let [batch_size, channels, x_height, x_width] = x.shape().dims();\n    let inf = (-f32::INFINITY).elem::<E>();\n\n    let out_height = calculate_pool_output_size(\n        kernel_height,\n        stride_height,\n        padding_height,\n        dilation_height,\n        x_height,\n        ceil_mode,\n    );\n    let out_width = calculate_pool_output_size(\n        kernel_width,\n        stride_width,\n        padding_width,\n        dilation_width,\n        x_width,\n        ceil_mode,\n    );\n\n    // Calculate extra padding needed for ceil_mode\n    let max_ih =\n        (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;\n    let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;\n    let padded_height = x_height + 2 * padding_height;\n    let padded_width = x_width + 2 * padding_width;\n    let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));\n    let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));\n    let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];\n\n    let x = apply_padding_4d::<E>(x, total_padding, inf);\n\n    // Offset to account for extra padding\n    let offset_h = extra_pad_h;\n    let offset_w = extra_pad_w;\n\n    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);\n    let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));\n\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n    let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output = unsafe_shared_out.get();\n            let indices = unsafe_shared_indices.get();\n\n            for oh in 0..out_height {\n                for ow in 0..out_width {\n                    let mut max_val = inf;\n                    let mut index = 0;\n\n                    for kh in 0..kernel_height {\n                        let ih = offset_h + oh * stride_height + kh * dilation_height;\n\n                        for kw in 0..kernel_width {\n                            let iw = offset_w + ow * stride_width + kw * dilation_width;\n                            let val = x[[b, c, ih, iw]];\n\n                            if val > max_val {\n                                max_val = val;\n\n                                // Calculate index in original (unpadded) input\n                                let ih_orig = ih as i64 - (total_padding[0]) as i64;\n                                let iw_orig = iw as i64 - (total_padding[1]) as i64;\n\n                                // Clamp to valid range for index calculation\n                                let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1);\n                                let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1);\n\n                                index = ih_clamped * x_width as i64 + iw_clamped;\n                            }\n                        }\n                    }\n\n                    output[[b, c, oh, ow]] = max_val;\n                    indices[[b, c, oh, ow]] = index.elem();\n                }\n            }\n        })\n    });\n\n    let output = output.into_dyn().into_shared();\n    let indices = indices.into_dyn().into_shared();\n\n    (output, indices)\n}\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(\n    x: SharedArray<E>,\n    _kernel_size: [usize; 2],\n    _stride: [usize; 2],\n    _padding: [usize; 2],\n    _dilation: [usize; 2],\n    _ceil_mode: bool,\n    output_grad: SharedArray<E>,\n    indices: SharedArray<I>,\n) -> SharedArray<E> {\n    let [_batch_size, _channels, height, width] = output_grad.shape().dims();\n    let [batch_size, channels, height_x, width_x] = x.shape().dims();\n\n    let output_grad = output_grad;\n    let indices = indices;\n\n    let mut output = Array4::zeros((batch_size, channels, height_x, width_x));\n\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n    run_par!(|| {\n        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {\n            let b = k / channels;\n            let c = k % channels;\n\n            let output = unsafe_shared_out.get();\n\n            for h in 0..height {\n                for w in 0..width {\n                    let index = indices[[b, c, h, w]].elem::<i64>();\n                    let grad = output_grad[[b, c, h, w]];\n\n                    let index_h = index as usize / width_x;\n                    let index_w = index as usize % width_x;\n\n                    output[[b, c, index_h, index_w]] += grad;\n                }\n            }\n        });\n    });\n\n    output.into_dyn().into_shared()\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/mod.rs",
    "content": "mod activation;\nmod base;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\n#[cfg(feature = \"simd\")]\nmod simd;\nmod tensor;\nmod transaction;\n\npub(crate) mod adaptive_avgpool;\npub(crate) mod avgpool;\npub(crate) mod conv;\npub(crate) mod deform_conv;\npub(crate) mod grid_sample;\npub(crate) mod interpolate;\npub(crate) mod macros;\npub(crate) mod matmul;\npub(crate) mod maxpool;\npub(crate) mod padding;\npub(crate) mod quantization;\n\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/module.rs",
    "content": "use super::{\n    adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},\n    avgpool::{avg_pool2d, avg_pool2d_backward},\n    conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d},\n    deform_conv::{backward::deform_conv2d_backward, deform_conv2d},\n    interpolate::{\n        bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate,\n    },\n    maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},\n};\n#[cfg(feature = \"simd\")]\nuse crate::ops::simd::{\n    avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd,\n};\nuse crate::{\n    NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype,\n    tensor::NdArrayTensor,\n};\nuse crate::{\n    element::{IntNdArrayElement, QuantElement},\n    ops::interpolate::nearest_interpolate_backward,\n};\nuse burn_backend::{\n    ElementConversion, TensorMetadata,\n    ops::{attention::attention_fallback, *},\n    tensor::FloatTensor,\n};\n\nmacro_rules! module_op {\n    // Module op with inputs (inp), optional (opt) and arguments (args).\n    // Converts NdArrayStorage to SharedArray for compatibility with existing operations.\n    (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{\n        #[allow(unused_parens, unreachable_patterns)]\n        match ($($x),+) {\n            ($(NdArrayTensor::F32($x)),+) => {\n                type $element = f32;\n                $op(\n                    $($x.into_shared()),+\n                    $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!(\"Optional argument type mismatch\") }))*\n                )\n            }\n            ($(NdArrayTensor::F64($x)),+) => {\n                type $element = f64;\n                $op(\n                    $($x.into_shared()),+\n                    $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!(\"Optional argument type mismatch\") }))*\n                )\n            }\n            _ => panic!(\"Data type mismatch\"),\n        }\n    }};\n}\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ModuleOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn conv2d(\n        x: NdArrayTensor,\n        weight: NdArrayTensor,\n        bias: Option<NdArrayTensor>,\n        options: ConvOptions<2>,\n    ) -> NdArrayTensor {\n        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {\n            #[cfg(feature = \"simd\")]\n            let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) {\n                Ok(out) => return out.into(),\n                Err(args) => args,\n            };\n            conv2d::<E>(x, weight, bias, options).into()\n        })\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        module_op!(\n            inp(x, offset, weight),\n            opt(mask, bias),\n            E,\n            |x, offset, weight, mask, bias| deform_conv2d::<E>(\n                x, offset, weight, mask, bias, options\n            )\n            .into()\n        )\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        module_op!(\n            inp(x, offset, weight, output_grad),\n            opt(mask, bias),\n            E,\n            |x, offset, weight, output_grad, mask, bias| {\n                let (x, offset, weight, mask, bias) = deform_conv2d_backward::<E>(\n                    x,\n                    offset,\n                    weight,\n                    mask,\n                    bias,\n                    output_grad,\n                    options,\n                );\n                DeformConv2dBackward::new(\n                    x.into(),\n                    offset.into(),\n                    weight.into(),\n                    mask.map(|m| m.into()),\n                    bias.map(|b| b.into()),\n                )\n            }\n        )\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {\n            conv_transpose2d::<E>(x, weight, bias, options).into()\n        })\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x), opt(), E, |x| {\n            #[cfg(feature = \"simd\")]\n            let x = match if ceil_mode {\n                // SIMD path doesn't support ceil_mode yet, skip it\n                Err(x)\n            } else {\n                try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad)\n            } {\n                Ok(out) => return out.into(),\n                Err(x) => x,\n            };\n            avg_pool2d::<E>(\n                x,\n                kernel_size,\n                stride,\n                padding,\n                count_include_pad,\n                ceil_mode,\n            )\n            .into()\n        })\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::<E>(\n            x,\n            grad,\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode\n        )\n        .into())\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x), opt(), E, |x| {\n            #[cfg(feature = \"simd\")]\n            let x = match if ceil_mode {\n                // SIMD path doesn't support ceil_mode yet, skip it\n                Err(x)\n            } else {\n                try_max_pool2d_simd(x, kernel_size, stride, padding, dilation)\n            } {\n                Ok(out) => return out.into(),\n                Err(x) => x,\n            };\n            max_pool2d::<E>(x, kernel_size, stride, padding, dilation, ceil_mode).into()\n        })\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<NdArray<E, I, Q>> {\n        module_op!(inp(x), opt(), E, |x| {\n            let (output, indices) = max_pool2d_with_indices::<E, I>(\n                x,\n                kernel_size,\n                stride,\n                padding,\n                dilation,\n                ceil_mode,\n            );\n            MaxPool2dWithIndices::new(output.into(), indices.into())\n        })\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: NdArrayTensor,\n    ) -> MaxPool2dBackward<NdArray<E, I, Q>> {\n        execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray<IntElem>| {\n            // Convert indices from runtime dtype to the expected I type\n            // (pool indices are bounded by tensor dimensions, so conversion is safe)\n            let indices: SharedArray<I> = idx_s.mapv(|x| x.elem()).into_shared();\n            module_op!(inp(x, output_grad), opt(), E, |x, output_grad| {\n                let output = max_pool2d_backward::<E, I>(\n                    x,\n                    kernel_size,\n                    stride,\n                    padding,\n                    dilation,\n                    ceil_mode,\n                    output_grad,\n                    indices,\n                );\n                MaxPool2dBackward::new(output.into())\n            })\n        })\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::<E>(\n            x,\n            output_size\n        )\n        .into())\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x, grad), opt(), E, |x, grad| {\n            adaptive_avg_pool2d_backward::<E>(x, grad).into()\n        })\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        match options.mode {\n            InterpolateMode::Nearest => {\n                module_op!(inp(x), opt(), E, |x| nearest_interpolate::<E>(\n                    x,\n                    output_size\n                )\n                .into())\n            }\n            InterpolateMode::Bilinear => {\n                let align_corners = options.align_corners;\n                module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(\n                    x,\n                    output_size,\n                    align_corners\n                )\n                .into())\n            }\n            InterpolateMode::Bicubic => {\n                let align_corners = options.align_corners;\n                module_op!(inp(x), opt(), E, |x| bicubic_interpolate::<E>(\n                    x,\n                    output_size,\n                    align_corners\n                )\n                .into())\n            }\n            InterpolateMode::Lanczos3 => {\n                let align_corners = options.align_corners;\n                module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::<E>(\n                    x,\n                    output_size,\n                    align_corners\n                )\n                .into())\n            }\n        }\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        match options.mode {\n            InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {\n                nearest_interpolate_backward::<E>(x, grad, output_size).into()\n            }),\n            InterpolateMode::Bilinear => {\n                panic!(\"bilinear interpolation backward is not supported for ndarray backend\")\n            }\n            InterpolateMode::Bicubic => {\n                panic!(\"bicubic interpolation backward is not supported for ndarray backend\")\n            }\n            InterpolateMode::Lanczos3 => {\n                panic!(\"lanczos3 interpolation backward is not supported for ndarray backend\")\n            }\n        }\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::<E>(\n            x, weight, bias, options\n        )\n        .into())\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| {\n            conv_transpose3d::<E>(x, weight, bias, options).into()\n        })\n    }\n\n    fn attention(\n        query: FloatTensor<Self>,\n        key: FloatTensor<Self>,\n        value: FloatTensor<Self>,\n        mask: Option<burn_backend::tensor::BoolTensor<Self>>,\n        attn_bias: Option<FloatTensor<Self>>,\n        options: AttentionModuleOptions,\n    ) -> FloatTensor<Self> {\n        attention_fallback::<Self>(query, key, value, mask, attn_bias, options)\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/padding.rs",
    "content": "use crate::{NdArrayElement, SharedArray};\nuse ndarray::{Array4, Array5};\n\nuse super::NdArrayOps;\n\npub(crate) fn apply_padding_4d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    padding: [usize; 2],\n    elem: E,\n) -> SharedArray<E> {\n    let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap();\n    let [padding_height, padding_width] = padding;\n    let padded_height = height + 2 * padding_height;\n    let padded_width = width + 2 * padding_width;\n\n    let x_new = Array4::from_elem(\n        (batch_size, input_channels, padded_height, padded_width),\n        elem,\n    );\n    let mut x_new = x_new.into_shared().into_dyn();\n\n    x_new = NdArrayOps::slice_assign(\n        x_new,\n        &[\n            burn_backend::Slice::from(0..batch_size),\n            burn_backend::Slice::from(0..input_channels),\n            burn_backend::Slice::from(padding_height..height + padding_height),\n            burn_backend::Slice::from(padding_width..width + padding_width),\n        ],\n        x,\n    );\n\n    x_new\n}\n\npub(crate) fn apply_padding_5d<E: NdArrayElement>(\n    x: SharedArray<E>,\n    padding: [usize; 3],\n    elem: E,\n) -> SharedArray<E> {\n    let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap();\n    let [padding_depth, padding_height, padding_width] = padding;\n    let padded_depth = depth + 2 * padding_depth;\n    let padded_height = height + 2 * padding_height;\n    let padded_width = width + 2 * padding_width;\n\n    let x_new = Array5::from_elem(\n        (\n            batch_size,\n            input_channels,\n            padded_depth,\n            padded_height,\n            padded_width,\n        ),\n        elem,\n    );\n    let mut x_new = x_new.into_shared().into_dyn();\n\n    x_new = NdArrayOps::slice_assign(\n        x_new,\n        &[\n            burn_backend::Slice::from(0..batch_size),\n            burn_backend::Slice::from(0..input_channels),\n            burn_backend::Slice::from(padding_depth..depth + padding_depth),\n            burn_backend::Slice::from(padding_height..height + padding_height),\n            burn_backend::Slice::from(padding_width..width + padding_width),\n        ],\n        x,\n    );\n\n    x_new\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/qtensor.rs",
    "content": "use alloc::{vec, vec::Vec};\n\nuse burn_backend::{\n    DType, ExecutionError, Shape, TensorData, TensorMetadata,\n    ops::QTensorOps,\n    quantization::{\n        QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue,\n        QuantizationParametersPrimitive, QuantizedBytes,\n    },\n    tensor::{FloatTensor, IntTensor, QuantizedTensor},\n};\n\nuse crate::{\n    FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray,\n    element::{IntNdArrayElement, QuantElement},\n    execute_with_dtype, execute_with_int_dtype, execute_with_numeric_dtype, slice,\n};\n\nuse super::quantization::{QuantizationStrategy, SymmetricQuantization};\nuse super::{NdArrayMathOps, NdArrayOps};\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QTensorOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor<Self> {\n        match data.dtype {\n            DType::QFloat(scheme) => {\n                let shape = data.shape.clone();\n                let num_elements = data.num_elements();\n                let q_bytes = QuantizedBytes {\n                    bytes: data.into_bytes(),\n                    scheme,\n                    num_elements,\n                };\n\n                match scheme {\n                    QuantScheme {\n                        level: QuantLevel::Tensor | QuantLevel::Block(_),\n                        mode: QuantMode::Symmetric,\n                        value: QuantValue::Q8F | QuantValue::Q8S,\n                        ..\n                    } => {\n                        // We can load QuantStore::U32 w/ QuantizedBytes impl\n                        let (values, qparams) = q_bytes.into_vec_i8();\n                        let data = TensorData::new(values, shape);\n                        // Overwrite storage\n                        let scheme = scheme.with_store(QuantStore::Native);\n\n                        let qparams = qparams\n                            .scales\n                            .into_iter()\n                            .map(|scales| QParams { scales })\n                            .collect();\n\n                        NdArrayQTensor {\n                            qtensor: NdArrayTensor::from_data(data),\n                            scheme,\n                            qparams,\n                        }\n                    }\n                    QuantScheme {\n                        value:\n                            QuantValue::Q4F\n                            | QuantValue::Q4S\n                            | QuantValue::Q2F\n                            | QuantValue::Q2S\n                            | QuantValue::E2M1\n                            | QuantValue::E4M3\n                            | QuantValue::E5M2,\n                        ..\n                    } => unimplemented!(\"from_data not supported for scheme {scheme:?}\"),\n                }\n            }\n            _ => panic!(\n                \"Invalid dtype (expected DType::QFloat, got {:?})\",\n                data.dtype\n            ),\n        }\n    }\n\n    fn quantize(\n        tensor: FloatTensor<Self>,\n        scheme: &QuantScheme,\n        qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        let shape = tensor.shape();\n        let data_f = tensor.into_data();\n        let scales = qparams.scales.into_data().convert::<f32>();\n\n        // Implement with ndarray instead of QuantizationStrategy?\n        let (data, qparams) = match scheme {\n            QuantScheme {\n                level: QuantLevel::Tensor,\n                mode: QuantMode::Symmetric,\n                #[cfg(not(feature = \"export_tests\"))]\n                    value: QuantValue::Q8F | QuantValue::Q8S,\n                // For tests, \"native\" sub-byte quant serves as a reference for value equality.\n                // Values are stored as i8 regardless.\n                #[cfg(feature = \"export_tests\")]\n                    value:\n                    QuantValue::Q8F\n                    | QuantValue::Q8S\n                    | QuantValue::Q4F\n                    | QuantValue::Q4S\n                    | QuantValue::Q2F\n                    | QuantValue::Q2S,\n                store: QuantStore::Native,\n                ..\n            } => {\n                let scales = scales.iter().next().unwrap();\n                let strategy = QuantizationStrategy::PerTensorSymmetric(\n                    SymmetricQuantization::init(scales, scheme.value),\n                );\n                let values = strategy.quantize(data_f.as_slice().unwrap());\n                (\n                    TensorData::quantized(values, shape.clone(), *scheme, &[scales]),\n                    vec![QParams { scales }],\n                )\n            }\n            QuantScheme {\n                level: QuantLevel::Block(block_size),\n                mode: QuantMode::Symmetric,\n                #[cfg(not(feature = \"export_tests\"))]\n                    value: QuantValue::Q8F | QuantValue::Q8S,\n                #[cfg(feature = \"export_tests\")]\n                    value:\n                    QuantValue::Q8F\n                    | QuantValue::Q8S\n                    | QuantValue::Q4F\n                    | QuantValue::Q4S\n                    | QuantValue::Q2F\n                    | QuantValue::Q2S,\n                store: QuantStore::Native,\n                ..\n            } => {\n                let scales = scales.as_slice().unwrap();\n                let (strategy, qparams) = scales\n                    .iter()\n                    .map(|&s| {\n                        (\n                            SymmetricQuantization::init(s, scheme.value),\n                            QParams { scales: s },\n                        )\n                    })\n                    .unzip();\n                let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size);\n                let values = strategy.quantize(data_f.as_slice().unwrap());\n                (\n                    TensorData::quantized(values, shape.clone(), *scheme, scales),\n                    qparams,\n                )\n            }\n            scheme => unimplemented!(\"Quantization not supported for scheme {scheme:?}\"),\n        };\n\n        let num_elements = data.num_elements();\n        let q_bytes = QuantizedBytes {\n            bytes: data.into_bytes(),\n            scheme: *scheme,\n            num_elements,\n        };\n        let (values, _) = q_bytes.into_vec_i8();\n        let data = TensorData::new(values, shape).convert::<Q>();\n\n        NdArrayQTensor {\n            qtensor: NdArrayTensor::from_data(data),\n            scheme: *scheme,\n            qparams,\n        }\n    }\n\n    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        let strategy = tensor.strategy();\n        let scheme = tensor.scheme;\n        let shape = tensor.shape();\n        let data = match tensor.qtensor {\n            NdArrayTensor::I8(storage) => {\n                let data = storage.into_shared().into_iter().collect();\n                dequantize(data, shape, scheme, &strategy)\n            }\n            _ => unreachable!(),\n        };\n        NdArrayTensor::from_data(data)\n    }\n\n    fn q_device(_tensor: &QuantizedTensor<Self>) -> NdArrayDevice {\n        NdArrayDevice::Cpu\n    }\n\n    fn q_to_device(\n        tensor: QuantizedTensor<Self>,\n        _device: &NdArrayDevice,\n    ) -> QuantizedTensor<Self> {\n        tensor\n    }\n\n    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::reshape(array, shape)\n            }),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        let shape = tensor.qtensor.shape();\n        let scales = tensor.qparams.iter().map(|q| q.scales).collect::<Vec<_>>();\n        Ok(execute_with_numeric_dtype!(\n            tensor.qtensor,\n            E,\n            |array: SharedArray<E>| {\n                let values = array.into_iter().collect();\n                TensorData::quantized(values, shape, tensor.scheme, &scales)\n            }\n        ))\n    }\n\n    fn q_swap_dims(\n        tensor: QuantizedTensor<Self>,\n        dim1: usize,\n        dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::swap_dims(array, dim1, dim2)\n            }),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::permute(array, axes)\n            }),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_flip(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::flip(array, axes)\n            }),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_gather(\n        dim: usize,\n        tensor: QuantizedTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<\n            IntElem,\n        >|\n         -> NdArrayTensor {\n            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::gather(dim, array, idx_array)\n            })\n        });\n        NdArrayQTensor {\n            qtensor,\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_select(\n        tensor: QuantizedTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray<\n            IntElem,\n        >|\n         -> NdArrayTensor {\n            execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayMathOps::select(array, dim, idx_array)\n            })\n        });\n        NdArrayQTensor {\n            qtensor,\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_slice(\n        tensor: QuantizedTensor<Self>,\n        slices: &[burn_backend::Slice],\n    ) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: slice!(tensor.qtensor, slices),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n\n    fn q_argmax(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {\n        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n            NdArrayMathOps::argmax::<I>(array, dim)\n        })\n    }\n\n    fn q_argmin(tensor: QuantizedTensor<Self>, dim: usize) -> IntTensor<Self> {\n        execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n            NdArrayMathOps::argmin::<I>(array, dim)\n        })\n    }\n\n    fn q_expand(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {\n        NdArrayQTensor {\n            qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray<E>| {\n                NdArrayOps::expand(array, shape)\n            }),\n            scheme: tensor.scheme,\n            qparams: tensor.qparams,\n        }\n    }\n}\n\nfn dequantize<Q: QuantElement>(\n    data: Vec<Q>,\n    shape: Shape,\n    scheme: QuantScheme,\n    strategy: &QuantizationStrategy,\n) -> TensorData {\n    let qparams = match strategy {\n        QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale],\n        QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => {\n            quant.iter().map(|q| q.scale).collect()\n        }\n    };\n    let q_bytes = QuantizedBytes::new(data, scheme, &qparams);\n    let (values, _qparams) = q_bytes.into_vec_i8();\n    TensorData::new(strategy.dequantize(&values), shape)\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/quantization.rs",
    "content": "use alloc::vec::Vec;\nuse num_traits::{Float, PrimInt};\n\nuse burn_backend::quantization::{BlockSize, QuantValue};\n\n// NOTE: this mainly serves as a simple reference implementation.\n// The de/quantization ops should be refactored to use ndarray.\n\n/// Quantization strategy.\n#[derive(Debug, Clone, PartialEq, Eq)]\npub enum QuantizationStrategy {\n    /// Per-tensor symmetric quantization.\n    PerTensorSymmetric(SymmetricQuantization<f32>),\n    /// Per-block symmetric quantization.\n    PerBlockSymmetric(Vec<SymmetricQuantization<f32>>, BlockSize),\n}\n\nimpl QuantizationStrategy {\n    /// Quantize the values to a lower precision data type.\n    pub fn quantize(&self, values: &[f32]) -> Vec<i8> {\n        match self {\n            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values),\n            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {\n                let block_elems = block_size.num_elements();\n                let num_blocks = strategy.len();\n                let numel = values.len();\n                assert_eq!(\n                    numel / block_elems,\n                    num_blocks,\n                    \"Invalid per-block quantization with num blocks {num_blocks} and {numel} values\"\n                );\n                values\n                    .chunks(block_elems)\n                    .enumerate()\n                    .flat_map(|(block_id, block)| strategy[block_id].quantize(block))\n                    .collect()\n            }\n        }\n    }\n\n    /// Dequantize the values to a higher precision data type.\n    pub fn dequantize(&self, values: &[i8]) -> Vec<f32> {\n        match self {\n            QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values),\n            QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => {\n                let block_elems = block_size.num_elements();\n                let num_blocks = strategy.len();\n                let numel = values.len();\n                assert_eq!(\n                    numel / block_elems,\n                    num_blocks,\n                    \"Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values\"\n                );\n                values\n                    .chunks(block_elems)\n                    .enumerate()\n                    .flat_map(|(block_id, block)| strategy[block_id].dequantize(block))\n                    .collect()\n            }\n        }\n    }\n}\n\n/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision\n/// data type `Q` and vice-versa.\npub trait Quantization<E: Float + Send + Sync> {\n    /// Returns the quantization range `[a, b]`.\n    fn range(&self) -> (E, E);\n    /// Convert the values to a lower precision data type.\n    fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q>;\n    /// Convert a single value to a lower precision data type.\n    fn quantize_one<Q: PrimInt>(&self, value: E) -> Q;\n    /// Convert the values back to a higher precision data type.\n    fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E>;\n    /// Convert a single value back to a higher precision data type.\n    fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E;\n}\n\nfn valid_scale<E: Float>(mut scale: E) -> E {\n    // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the\n    // scale to 0.1 to avoid division by zero.\n    if scale.eq(&E::zero()) {\n        scale = E::from(0.1).unwrap();\n    }\n    scale\n}\n\n/// Symmetric quantization scheme.\n#[derive(Debug, Clone, Copy)]\npub struct SymmetricQuantization<E: Float + Send + Sync> {\n    /// The scaling factor.\n    pub scale: E,\n    // The quantization value data type.\n    value: QuantValue,\n}\n\nimpl<E: Float + Send + Sync> SymmetricQuantization<E> {\n    /// Initialize a symmetric quantization scheme with the given parameters.\n    pub fn init(scale: E, value: QuantValue) -> Self {\n        Self {\n            scale: valid_scale(scale),\n            value,\n        }\n    }\n\n    #[allow(dead_code)]\n    /// Create a new quantization scheme for an input range `[alpha, beta]`.\n    fn new(alpha: E, beta: E, value: QuantValue) -> Self {\n        let (a, b) = value.range();\n        let a = E::from(a).unwrap();\n        let b = E::from(b).unwrap();\n\n        // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range\n        let alpha = alpha.abs().max(beta.abs());\n        let scale = valid_scale((alpha + alpha) / (b - a));\n        Self { scale, value }\n    }\n}\n\nimpl<E: Float + Send + Sync> Quantization<E> for SymmetricQuantization<E> {\n    fn quantize<Q: PrimInt>(&self, values: &[E]) -> Vec<Q> {\n        values.iter().map(|x| self.quantize_one(*x)).collect()\n    }\n\n    fn dequantize<Q: PrimInt>(&self, values: &[Q]) -> Vec<E> {\n        values.iter().map(|x_q| self.dequantize_one(*x_q)).collect()\n    }\n\n    fn quantize_one<Q: PrimInt>(&self, value: E) -> Q {\n        let (a, b) = self.range();\n\n        // x_q = clamp(round(x / scale), a, b)\n        Q::from(value.div(self.scale).round().clamp(a, b)).unwrap()\n    }\n\n    fn dequantize_one<Q: PrimInt>(&self, value: Q) -> E {\n        // x = scale * x_q\n        self.scale * E::from(value).unwrap()\n    }\n\n    fn range(&self) -> (E, E) {\n        let (a, b) = self.value.range();\n        let a = E::from(a).unwrap();\n        let b = E::from(b).unwrap();\n        (a, b)\n    }\n}\n\nimpl<E: Float + Send + Sync> PartialEq for SymmetricQuantization<E> {\n    fn eq(&self, other: &Self) -> bool {\n        self.scale == other.scale\n    }\n}\n\nimpl<E: Float + Send + Sync> Eq for SymmetricQuantization<E> {}\n\n#[cfg(test)]\nmod tests {\n    use burn_backend::TensorData;\n\n    use super::*;\n    use alloc::vec;\n\n    #[test]\n    fn test_int8_symmetric_quantization() {\n        let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5];\n        let expected_q = vec![-127, -71, 0, 35];\n        let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063];\n\n        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);\n\n        let q: Vec<i8> = symmetric.quantize(&x);\n        assert_eq!(q, expected_q);\n\n        let d = symmetric.dequantize(&expected_q);\n\n        assert_eq!(d, expected_d);\n    }\n\n    #[test]\n    fn test_int8_symmetric_quantization_per_block() {\n        let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5];\n        let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35];\n        let expected_d = vec![\n            -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063,\n        ];\n\n        let symmetric = SymmetricQuantization::<f32>::new(-1.8, 0.5, QuantValue::Q8S);\n        let strategy = QuantizationStrategy::PerBlockSymmetric(\n            vec![symmetric, symmetric],\n            BlockSize::new([4]),\n        );\n\n        let q: Vec<i8> = strategy.quantize(&x);\n        assert_eq!(q, expected_q);\n\n        let d = symmetric.dequantize(&expected_q);\n\n        assert_eq!(d, expected_d);\n    }\n\n    #[test]\n    fn should_support_dequantize() {\n        let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization {\n            scale: 0.1,\n            value: QuantValue::Q8S,\n        });\n\n        let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]);\n\n        let output = TensorData::new(output, [2, 3]);\n\n        output.assert_approx_eq::<f32>(\n            &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]),\n            Default::default(),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/avgpool.rs",
    "content": "use core::{marker::PhantomData, mem::transmute};\n\nuse crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};\n\nuse burn_backend::DType;\nuse burn_backend::{Element, ElementConversion};\nuse bytemuck::Zeroable;\nuse macerator::{Simd, VAdd, VDiv};\nuse ndarray::{Array4, s};\nuse nhwc::avg_pool_nhwc;\n\nuse super::should_use_simd;\n\n#[macerator::with_simd]\nfn is_accelerated<S: Simd, T: VAdd + VDiv>(_x: PhantomData<T>) -> bool {\n    <T as VAdd>::is_accelerated::<S>() && <T as VDiv>::is_accelerated::<S>()\n}\n\npub(crate) fn try_avg_pool2d_simd<E: Element>(\n    x: SharedArray<E>,\n    ksize: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    with_pad: bool,\n) -> Result<SharedArray<E>, SharedArray<E>> {\n    // Strides must be unit, dilation isn't supported, rows must be contiguous\n    if x.strides()[1] != 1 || !should_use_simd(x.shape()[1]) {\n        return Err(x);\n    }\n\n    match E::dtype() {\n        DType::F64 if is_accelerated::<f64>(PhantomData) => Ok(cast(avg_pool_nhwc::<f64>(\n            cast(x),\n            ksize,\n            stride,\n            padding,\n            with_pad,\n        ))),\n        DType::F32 if is_accelerated::<f32>(PhantomData) => Ok(cast(avg_pool_nhwc::<f32>(\n            cast(x),\n            ksize,\n            stride,\n            padding,\n            with_pad,\n        ))),\n        _ => Err(x),\n    }\n}\n\nfn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {\n    unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }\n}\n\nmod nhwc {\n    use itertools::Itertools;\n    use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned};\n    use ndarray::{ArrayView3, ArrayViewMut3};\n    use seq_macro::seq;\n\n    use crate::ops::simd::lanes;\n\n    use super::*;\n\n    // Until you can use associated constants as array size, we need to hardcode this.\n    // The most common config (x86-v3) has 16 registers, so use half of them for accumulators.\n    const BLOCK_REGISTERS: usize = 8;\n\n    pub(crate) fn avg_pool_nhwc<E: Element + VAdd + VDiv>(\n        x: SharedArray<E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        with_pad: bool,\n    ) -> SharedArray<E> {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n        let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();\n        let lanes = lanes::<E>();\n\n        let ch_block = lanes * BLOCK_REGISTERS;\n\n        let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1;\n        let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1;\n\n        let mut output = unsafe {\n            Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()\n        };\n        let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n        let x = x.view();\n        let x = x.permuted_axes(vec![0, 2, 3, 1]);\n\n        // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`.\n        // An exclusive loop will always have `lanes * blocking factor` elements in bounds.\n        let blocks = channels / ch_block;\n        let blocks_end = blocks * ch_block;\n        // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An\n        // exclusive loop will always have `lanes` elements in bounds.\n        let simd_end = channels / lanes * lanes;\n        let num_simd_unblocked = (simd_end - blocks_end) / lanes;\n        let remainder = channels - simd_end;\n\n        run_par!(|| {\n            // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.\n            iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {\n                let block = k % blocks;\n                let b = k / blocks;\n\n                let output = unsafe_shared_out.get();\n\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n\n                loop_blocked(x, out, kernel_size, stride, padding, with_pad, block);\n            });\n            // SAFETY: See `loop_unblocked`\n            iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe {\n                let ch = (k % num_simd_unblocked) * lanes + blocks_end;\n                let b = k / num_simd_unblocked;\n\n                let output = unsafe_shared_out.get();\n\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n\n                loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch);\n            });\n            // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.\n            iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {\n                let ch = (k % remainder) + simd_end;\n                let b = k / remainder;\n\n                let output = unsafe_shared_out.get();\n\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n\n                loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch);\n            });\n        });\n\n        output = output.permuted_axes([0, 3, 1, 2]);\n\n        output.into_dyn().into_shared()\n    }\n\n    /// Execute the blocked (unrolled) portion of the pool.\n    #[allow(\n        clippy::too_many_arguments,\n        clippy::erasing_op,\n        clippy::identity_op,\n        unused_mut\n    )]\n    #[macerator::with_simd]\n    fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>(\n        x: ArrayView3<'a, E>,\n        mut out: ArrayViewMut3<'a, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        with_pad: bool,\n        block: usize,\n    ) where\n        'a: 'a,\n    {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n        let lanes = E::lanes::<S>();\n\n        let ch_block = lanes * BLOCK_REGISTERS;\n\n        // If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds\n        for oh in pad_h..out_height.saturating_sub(pad_h) {\n            for ow in pad_w..out_width.saturating_sub(pad_w) {\n                seq!(N in 0..8 {\n                    let mut sum~N: Vector<S, E> = Zeroable::zeroed();\n                });\n                let ch = block * ch_block;\n                let ch_end = ch + ch_block;\n                let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw - pad_w;\n                        let x = x.slice(s![ih, iw, ch..ch_end]);\n\n                        seq!(N in 0..8 {\n                            // SAFETY:\n                            // Load a full vector from x[N * lanes]. This is bounds checked by the\n                            // slice above.\n                            sum~N += unsafe { vload_unaligned(&x[N * lanes]) };\n                        });\n                    }\n                }\n\n                let count = kernel_height * kernel_width;\n                let count = (count as u64).elem::<E>();\n                let count_v = count.splat();\n                seq!(N in 0..8 {\n                    let s~N = sum~N / count_v;\n                    // SAFETY:\n                    // Store a full vector to out[N * lanes]. This is bounds checked by the\n                    // slice above.\n                    unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };\n                });\n            }\n        }\n\n        // Border pixels need bounds checks\n        if (pad_h, pad_w) != (0, 0) {\n            let v_borders = (0..pad_h)\n                .chain(out_height.saturating_sub(pad_h)..out_height)\n                .cartesian_product(0..out_width);\n            let h_borders = (0..out_height)\n                .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));\n\n            for (oh, ow) in v_borders.chain(h_borders) {\n                seq!(N in 0..8 {\n                    let mut sum~N: Vector<S, E> = Zeroable::zeroed();\n                });\n                let mut count: usize = 0;\n                let ch = block * ch_block;\n                let ch_end = ch + ch_block;\n                let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n                        count += 1;\n\n                        let x = x.slice(s![ih, iw, ch..ch_end]);\n\n                        seq!(N in 0..8 {\n                            // SAFETY:\n                            // Load a full vector from x[N * lanes]. This is bounds checked by the\n                            // slice above.\n                            sum~N += unsafe { vload_unaligned(&x[N * lanes]) };\n                        });\n                    }\n                }\n\n                if with_pad {\n                    count = kernel_height * kernel_width;\n                }\n\n                let count = (count as u64).elem::<E>();\n                let count_v = count.splat();\n                seq!(N in 0..8 {\n                    let s~N = sum~N / count_v;\n                    // SAFETY:\n                    // Store a full vector to out[N * lanes]. This is bounds checked by the\n                    // slice above.\n                    unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };\n                });\n            }\n        }\n    }\n\n    /// Execute the unblocked (not unrolled) portion of the pool.\n    ///\n    /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`.\n    #[allow(clippy::too_many_arguments, unused_mut)]\n    #[macerator::with_simd]\n    unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>(\n        x: ArrayView3<'a, E>,\n        mut out: ArrayViewMut3<'a, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        with_pad: bool,\n        ch: usize,\n    ) where\n        'a: 'a,\n    {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n\n        // If pixels are not within padding range, bounds checks are always true\n        for oh in pad_h..out_height - pad_h {\n            for ow in pad_w..out_width - pad_w {\n                let mut sum: Vector<S, E> = Zeroable::zeroed();\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw - pad_w;\n                        // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`\n                        let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) };\n                        sum += s0;\n                    }\n                }\n\n                let count = kernel_height * kernel_width;\n                let count: E = (count as u64).elem();\n                let count_v = count.splat();\n                let s0 = sum / count_v;\n                // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.\n                unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };\n            }\n        }\n\n        // Border pixels need bounds checks\n        if (pad_h, pad_w) != (0, 0) {\n            let v_borders = (0..pad_h)\n                .chain(out_height.saturating_sub(pad_h)..out_height)\n                .cartesian_product(0..out_width);\n            let h_borders = (0..out_height)\n                .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));\n\n            for (oh, ow) in v_borders.chain(h_borders) {\n                let mut sum: Vector<S, E> = Zeroable::zeroed();\n                let mut count: usize = 0;\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n                        count += 1;\n\n                        // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`\n                        sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) };\n                    }\n                }\n\n                if with_pad {\n                    count = kernel_height * kernel_width;\n                }\n\n                let count = (count as u64).elem::<E>();\n                let count_v = count.splat();\n                let s0 = sum / count_v;\n                // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.\n                unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) };\n            }\n        }\n    }\n\n    /// Execute scalar portion of the pooling\n    #[allow(clippy::too_many_arguments)]\n    fn loop_scalar<E: Element + VAdd + VDiv>(\n        x: ArrayView3<'_, E>,\n        mut out: ArrayViewMut3<'_, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        with_pad: bool,\n        ch: usize,\n    ) {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n\n        // If pixels are not within padding range, bounds checks are always true\n        for oh in pad_h..out_height.saturating_sub(pad_h) {\n            for ow in pad_w..out_width.saturating_sub(pad_w) {\n                let mut sum: E = Zeroable::zeroed();\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw - pad_w;\n                        sum = sum + x[[ih, iw, ch]];\n                    }\n                }\n\n                let count = (kernel_height * kernel_width) as u64;\n                out[[oh, ow, ch]] = sum / count.elem();\n            }\n        }\n\n        // Border pixels need bounds checks\n        if (pad_h, pad_w) != (0, 0) {\n            let v_borders = (0..pad_h)\n                .chain(out_height.saturating_sub(pad_h)..out_height)\n                .cartesian_product(0..out_width);\n            let h_borders = (0..out_height)\n                .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));\n\n            for (oh, ow) in v_borders.chain(h_borders) {\n                let mut sum: E = Zeroable::zeroed();\n                let mut count: usize = 0;\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n                        count += 1;\n                        sum = sum + x[[ih, iw, ch]];\n                    }\n                }\n\n                if with_pad {\n                    count = kernel_height * kernel_width;\n                }\n\n                out[[oh, ow, ch]] = sum / (count as u64).elem();\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/base.rs",
    "content": "use core::{marker::PhantomData, mem::MaybeUninit};\n\nuse macerator::{Arch, Scalar, Simd};\nuse ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder};\n\n/// Whether SIMD instructions are worth using\n#[cfg(all(\n    any(\n        target_arch = \"x86\",\n        target_arch = \"x86_64\",\n        target_arch = \"aarch64\",\n        target_arch = \"wasm32\",\n        target_arch = \"loongarch64\"\n    ),\n    not(test)\n))]\npub fn should_use_simd(len: usize) -> bool {\n    len >= 32\n}\n\n/// Whether SIMD instructions are worth using\n#[cfg(all(\n    not(any(\n        target_arch = \"x86\",\n        target_arch = \"x86_64\",\n        target_arch = \"aarch64\",\n        target_arch = \"wasm32\",\n        target_arch = \"loongarch64\"\n    )),\n    not(test)\n))]\npub fn should_use_simd(_len: usize) -> bool {\n    false\n}\n\n#[cfg(test)]\npub fn should_use_simd(_len: usize) -> bool {\n    true\n}\n\npub(crate) fn lanes<E: Scalar>() -> usize {\n    #[allow(non_camel_case_types)]\n    struct lanes<__T0>(__T0);\n\n    impl<E: Scalar> ::macerator::WithSimd for lanes<PhantomData<E>> {\n        type Output = usize;\n        #[inline(always)]\n        fn with_simd<__S: ::macerator::Simd>(self) -> <Self as ::macerator::WithSimd>::Output {\n            let Self(__ty) = self;\n            #[allow(unused_unsafe)]\n            unsafe {\n                lanes_simd::<__S, E>(__ty)\n            }\n        }\n    }\n    (Arch::new()).dispatch(lanes(PhantomData::<E>))\n}\n\nfn lanes_simd<S: Simd, E: Scalar>(_ty: PhantomData<E>) -> usize {\n    E::lanes::<S>()\n}\n\npub(crate) fn uninit_array_like<In, Out>(reference: &ArcArray<In, IxDyn>) -> ArrayD<Out> {\n    let shape = reference.raw_dim();\n    let strides = reference.strides();\n    let strides = strides.iter().map(|it| *it as usize).collect::<Vec<_>>();\n    let shape_strides = shape.strides(IxDyn(&strides));\n    let size = reference.len();\n    let mut out_data: Vec<MaybeUninit<Out>> = Vec::with_capacity(size);\n    unsafe { out_data.set_len(size) };\n    unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() }\n}\n\npub trait MinMax {\n    fn min(self, other: Self) -> Self;\n    fn max(self, other: Self) -> Self;\n}\n\nmacro_rules! impl_minmax {\n    ($ty: ty) => {\n        impl MinMax for $ty {\n            fn min(self, other: Self) -> Self {\n                Ord::min(self, other)\n            }\n            fn max(self, other: Self) -> Self {\n                Ord::max(self, other)\n            }\n        }\n    };\n    ($($ty: ty),*) => {\n        $(impl_minmax!($ty);)*\n    }\n}\n\nimpl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);\n\nimpl MinMax for f32 {\n    fn min(self, other: Self) -> Self {\n        self.min(other)\n    }\n\n    fn max(self, other: Self) -> Self {\n        self.max(other)\n    }\n}\n\nimpl MinMax for f64 {\n    fn min(self, other: Self) -> Self {\n        self.min(other)\n    }\n\n    fn max(self, other: Self) -> Self {\n        self.max(other)\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/binary.rs",
    "content": "use core::{marker::PhantomData, slice};\n\nuse burn_backend::Element;\nuse macerator::{\n    Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned,\n    vstore_unaligned,\n};\nuse ndarray::ArrayD;\nuse seq_macro::seq;\n\nuse crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};\n\nuse super::{\n    MinMax,\n    binary_elemwise::{\n        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub,\n    },\n    should_use_simd,\n};\n\npub trait SimdBinop<T: Scalar, Out: Scalar> {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, Out>;\n    fn apply(lhs: T, rhs: T) -> Out;\n    fn is_accelerated<S: Simd>() -> bool;\n}\n\nimpl<T: VAdd> SimdBinop<T, T> for VecAdd {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs + rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs + rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VAdd>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VDiv> SimdBinop<T, T> for VecDiv {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs / rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs / rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VDiv>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VMul> SimdBinop<T, T> for VecMul {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs * rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs * rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VMul>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VSub> SimdBinop<T, T> for VecSub {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs - rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs - rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VSub>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VOrd + MinMax> SimdBinop<T, T> for VecMin {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs.min(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        MinMax::min(lhs, rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_min_max_accelerated::<S>()\n    }\n}\n\nimpl<T: VOrd + MinMax> SimdBinop<T, T> for VecMax {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs.max(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        MinMax::max(lhs, rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_min_max_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitAnd> SimdBinop<T, T> for VecBitAnd {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs & rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs.bitand(rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitAnd>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitOr> SimdBinop<T, T> for VecBitOr {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs | rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs.bitor(rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitOr>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitXor> SimdBinop<T, T> for VecBitXor {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Vector<S, T> {\n        lhs ^ rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs.bitxor(rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitXor>::is_accelerated::<S>()\n    }\n}\n\n#[macerator::with_simd]\nfn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdBinop<T, Out>>(\n    _x: PhantomData<(T, Out, Op)>,\n) -> bool {\n    Op::is_accelerated::<S>()\n}\n\n#[allow(clippy::result_large_err)]\npub fn try_binary_simd<\n    E: Element,\n    EOut: Element,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdBinop<T, Out>,\n>(\n    lhs: SharedArray<E>,\n    rhs: SharedArray<E>,\n) -> Result<SharedArray<EOut>, (SharedArray<E>, SharedArray<E>)> {\n    let lhs_len = lhs.len();\n    let rhs_len = rhs.len();\n    if !should_use_simd(lhs_len.max(rhs_len))\n        || !lhs.is_standard_layout()\n        || !rhs.is_standard_layout()\n        || lhs.shape() != rhs.shape()\n        || !is_accelerated::<T, Out, Op>(PhantomData)\n    {\n        return Err((lhs, rhs));\n    }\n    // Used to assert traits based on the dynamic `DType`.\n    let lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };\n    let rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };\n    let out = binary_simd_same::<T, Out, Op>(lhs, rhs);\n\n    // Used to assert traits based on the dynamic `DType`.\n    let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };\n    Ok(out)\n}\n\nfn binary_simd_same<\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdBinop<T, Out>,\n>(\n    lhs: SharedArray<T>,\n    rhs: SharedArray<T>,\n) -> SharedArray<Out> {\n    let out = if lhs.is_unique() {\n        let mut buf = lhs.into_owned();\n        let lhs = buf.as_slice_mut().unwrap();\n        let rhs = rhs.as_slice().unwrap();\n        let out =\n            unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) };\n        binary(lhs, rhs, out, PhantomData::<Op>);\n        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }\n    } else if rhs.is_unique() {\n        let mut buf = rhs.into_owned();\n        let lhs = lhs.as_slice().unwrap();\n        let rhs = buf.as_slice_mut().unwrap();\n        let out =\n            unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) };\n        binary(lhs, rhs, out, PhantomData::<Op>);\n        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buf) }\n    } else {\n        let mut out = uninit_array_like(&lhs);\n        let lhs = lhs.as_slice().unwrap();\n        let rhs = rhs.as_slice().unwrap();\n        let out_slice = out.as_slice_mut().unwrap();\n        binary(lhs, rhs, out_slice, PhantomData::<Op>);\n        out\n    };\n    out.into_shared()\n}\n\n#[allow(clippy::erasing_op, clippy::identity_op)]\n#[macerator::with_simd]\nfn binary<\n    'a,\n    S: Simd,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdBinop<T, Out>,\n>(\n    lhs: &'a [T],\n    rhs: &'a [T],\n    out: &'a mut [Out],\n    _op: PhantomData<Op>,\n) where\n    'a: 'a,\n{\n    let lanes = T::lanes::<S>();\n    let mut chunks_lhs = lhs.chunks_exact(8 * lanes);\n    let mut chunks_rhs = rhs.chunks_exact(8 * lanes);\n    let mut chunks_out = out.chunks_exact_mut(8 * lanes);\n    while let Some(((lhs, rhs), out)) = chunks_lhs\n        .next()\n        .zip(chunks_rhs.next())\n        .zip(chunks_out.next())\n    {\n        seq!(N in 0..8 {\n            // Load one full vector from `lhs`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };\n            // Load one full vector from `rhs`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };\n            let s~N = Op::apply_vec(lhs~N, rhs~N);\n            // Store one full vector to `out`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };\n        });\n    }\n    let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);\n    let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);\n    let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);\n    while let Some(((lhs, rhs), out)) = chunks_lhs\n        .next()\n        .zip(chunks_rhs.next())\n        .zip(chunks_out.next())\n    {\n        // Load one full vector from `lhs`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };\n        // Load one full vector from `rhs`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };\n        let s0 = Op::apply_vec(lhs0, rhs0);\n        // Store one full vector to `out`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };\n    }\n\n    for ((lhs, rhs), out) in chunks_lhs\n        .remainder()\n        .iter()\n        .zip(chunks_rhs.remainder())\n        .zip(chunks_out.into_remainder())\n    {\n        *out = Op::apply(*lhs, *rhs)\n    }\n}\n\n/// Unsafely alias a slice to use as an inline argument\nfn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {\n    let ptr = slice.as_mut_ptr();\n    let len = slice.len();\n    unsafe { slice::from_raw_parts_mut(ptr, len) }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/binary_elemwise.rs",
    "content": "use core::marker::PhantomData;\n\nuse bytemuck::cast;\nuse macerator::{\n    Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload,\n    vload_unaligned, vstore, vstore_unaligned,\n};\nuse ndarray::ArrayD;\nuse seq_macro::seq;\n\nuse crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};\n\nuse super::{MinMax, should_use_simd};\n\npub trait ScalarSimdBinop<T: Scalar, Out: Scalar> {\n    type Rhs: Copy;\n    type RhsVec<S: Simd>: Copy;\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S>;\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, Out>;\n    fn apply(lhs: T, rhs: Self::Rhs) -> Out;\n    fn is_accelerated<S: Simd>() -> bool;\n}\n\npub struct VecAdd;\npub struct VecDiv;\npub struct VecMul;\npub struct VecSub;\npub struct VecMin;\npub struct VecMax;\npub struct VecClamp;\npub struct VecBitAnd;\npub struct VecBitOr;\npub struct VecBitXor;\n\nimpl<T: VAdd> ScalarSimdBinop<T, T> for VecAdd {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs + rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs + rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VAdd>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VDiv> ScalarSimdBinop<T, T> for VecDiv {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs / rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs / rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VDiv>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VMul> ScalarSimdBinop<T, T> for VecMul {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs * rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs * rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VMul>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VSub> ScalarSimdBinop<T, T> for VecSub {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs - rhs\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs - rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VSub>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMin {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs.min(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs.min(rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_min_max_accelerated::<S>()\n    }\n}\n\nimpl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecMax {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs.max(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> T {\n        lhs.max(rhs)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_min_max_accelerated::<S>()\n    }\n}\n\nimpl<T: VOrd + MinMax> ScalarSimdBinop<T, T> for VecClamp {\n    type Rhs = (T, T);\n    type RhsVec<S: Simd> = (Vector<S, T>, Vector<S, T>);\n\n    fn splat<S: Simd>((min, max): Self::Rhs) -> Self::RhsVec<S> {\n        (min.splat(), max.splat())\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, (min, max): Self::RhsVec<S>) -> Vector<S, T> {\n        lhs.min(max).max(min)\n    }\n\n    fn apply(lhs: T, (min, max): Self::Rhs) -> T {\n        lhs.min(max).max(min)\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_min_max_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitAnd> ScalarSimdBinop<T, T> for VecBitAnd {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs & rhs\n    }\n\n    fn apply(lhs: T, rhs: Self::Rhs) -> T {\n        lhs & rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitAnd>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitOr> ScalarSimdBinop<T, T> for VecBitOr {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs | rhs\n    }\n\n    fn apply(lhs: T, rhs: Self::Rhs) -> T {\n        lhs | rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitOr>::is_accelerated::<S>()\n    }\n}\n\nimpl<T: VBitXor> ScalarSimdBinop<T, T> for VecBitXor {\n    type Rhs = T;\n    type RhsVec<S: Simd> = Vector<S, T>;\n\n    fn splat<S: Simd>(rhs: Self::Rhs) -> Self::RhsVec<S> {\n        rhs.splat()\n    }\n\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Self::RhsVec<S>) -> Vector<S, T> {\n        lhs ^ rhs\n    }\n\n    fn apply(lhs: T, rhs: Self::Rhs) -> T {\n        lhs ^ rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitXor>::is_accelerated::<S>()\n    }\n}\n\n#[macerator::with_simd]\nfn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: ScalarSimdBinop<T, Out>>(\n    _x: PhantomData<(T, Out, Op)>,\n) -> bool {\n    Op::is_accelerated::<S>()\n}\n\npub fn try_binary_scalar_simd<\n    E: NdArrayElement,\n    EOut: NdArrayElement,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: ScalarSimdBinop<T, Out>,\n>(\n    input: SharedArray<E>,\n    elem: Op::Rhs,\n) -> Result<SharedArray<EOut>, SharedArray<E>> {\n    if !should_use_simd(input.len())\n        || input.as_slice_memory_order().is_none()\n        || !is_accelerated::<T, Out, Op>(PhantomData)\n    {\n        return Err(input);\n    }\n    // Used to assert traits based on the dynamic `DType`.\n    let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };\n    let out = if size_of::<T>() == size_of::<Out>()\n        && align_of::<T>() >= align_of::<Out>()\n        && input.is_unique()\n    {\n        unsafe { binary_scalar_simd_inplace::<T, Out, Op>(input, elem) }\n    } else {\n        binary_scalar_simd_owned::<T, Out, Op>(input, elem)\n    };\n    // Used to assert traits based on the dynamic `DType`.\n    let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };\n    Ok(out)\n}\n\n/// Execute operation in place on an owned tensor\n/// SAFETY:\n/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\nunsafe fn binary_scalar_simd_inplace<\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: ScalarSimdBinop<T, Out>,\n>(\n    input: SharedArray<T>,\n    elem: Op::Rhs,\n) -> SharedArray<Out> {\n    let mut buffer = input.into_owned();\n    let slice = buffer.as_slice_memory_order_mut().unwrap();\n    unsafe { binary_scalar_slice_inplace::<T, Out, Op>(slice, elem, PhantomData) };\n    // Buffer has the same elem size and is filled with the operation output, so this is safe\n    let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };\n    out.into_shared()\n}\n\n/// Create a new copy of the tensor as the output\nfn binary_scalar_simd_owned<\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: ScalarSimdBinop<T, Out>,\n>(\n    input: SharedArray<T>,\n    elem: Op::Rhs,\n) -> SharedArray<Out> {\n    let mut out = uninit_array_like(&input);\n    let input = input.as_slice_memory_order().unwrap();\n    let out_slice = out.as_slice_memory_order_mut().unwrap();\n    binary_scalar_slice::<T, Out, Op>(input, out_slice, elem, PhantomData);\n    out.into_shared()\n}\n\n#[inline(always)]\n#[allow(clippy::erasing_op, clippy::identity_op)]\n#[macerator::with_simd]\nfn binary_scalar_slice<\n    'a,\n    S: Simd,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: ScalarSimdBinop<T, Out>,\n>(\n    input: &'a [T],\n    out: &'a mut [Out],\n    rhs: Op::Rhs,\n    _op: PhantomData<Op>,\n) where\n    'a: 'a,\n{\n    let lanes = T::lanes::<S>();\n    let mut chunks_input = input.chunks_exact(8 * lanes);\n    let mut chunks_out = out.chunks_exact_mut(8 * lanes);\n    let rhs_vec = Op::splat::<S>(rhs);\n    while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n        seq!(N in 0..8 {\n            // Load one full vector from `input`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let s~N = unsafe { vload_unaligned(&input[N * lanes]) };\n            let s~N = Op::apply_vec(s~N, rhs_vec);\n            // Store one full vector to `out`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };\n        });\n    }\n    let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);\n    let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);\n    while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n        // Load one full vector from `input`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let s0 = unsafe { vload_unaligned(input.as_ptr()) };\n        let s0 = Op::apply_vec(s0, rhs_vec);\n        // Store one full vector to `out`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };\n    }\n\n    for (input, out) in chunks_input\n        .remainder()\n        .iter()\n        .zip(chunks_out.into_remainder())\n    {\n        *out = Op::apply(*input, rhs)\n    }\n}\n\n/// Execute operation in line.\n/// SAFETY:\n/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\n#[inline(always)]\n#[macerator::with_simd]\nunsafe fn binary_scalar_slice_inplace<\n    'a,\n    S: Simd,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: ScalarSimdBinop<T, Out>,\n>(\n    buf: &'a mut [T],\n    rhs: Op::Rhs,\n    _op: PhantomData<(Out, Op)>,\n) where\n    'a: 'a,\n{\n    let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };\n    for elem in head.iter_mut().chain(tail) {\n        *elem = cast(Op::apply(*elem, rhs));\n    }\n    let mut chunks = main.chunks_exact_mut(8);\n    let rhs = Op::splat::<S>(rhs);\n    for elem in chunks.by_ref() {\n        seq!(N in 0..8 {\n            // Load a full vector from the aligned portion of the buffer.\n            // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n            // always a full vector in bounds.\n            let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };\n            let s~N = Op::apply_vec(s~N, rhs);\n            // Store a full vector at the same position as the input. Cast is safe because `Out` is\n            // size and align compatible\n            unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) };\n        });\n    }\n\n    for elem in chunks.into_remainder() {\n        // Load a full vector from the aligned portion of the buffer.\n        // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n        // always a full vector in bounds.\n        let s0 = unsafe { vload(elem as *const _ as *const T) };\n\n        let s0 = Op::apply_vec(s0, rhs);\n        // Store a full vector at the same position as the input. Cast is safe because `Out` is\n        // size and align compatible\n        unsafe { vstore(elem as *mut _ as *mut Out, s0) };\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/cmp.rs",
    "content": "use core::{marker::PhantomData, slice};\n\nuse burn_backend::Element;\nuse macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned};\nuse ndarray::ArrayD;\nuse seq_macro::seq;\n\nuse crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like};\n\nuse super::should_use_simd;\n\npub trait SimdCmpOp<T: Scalar> {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T>;\n    fn apply(lhs: T, rhs: T) -> bool;\n    fn is_accelerated<S: Simd>() -> bool;\n}\n\npub struct VecEquals;\n\nimpl<T: VEq> SimdCmpOp<T> for VecEquals {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {\n        lhs.eq(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> bool {\n        lhs == rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VEq>::is_accelerated::<S>()\n    }\n}\n\npub struct VecGreater;\n\nimpl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreater {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {\n        lhs.gt(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> bool {\n        lhs > rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_cmp_accelerated::<S>()\n    }\n}\n\npub struct VecGreaterEq;\n\nimpl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecGreaterEq {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {\n        lhs.ge(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> bool {\n        lhs >= rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_cmp_accelerated::<S>()\n    }\n}\n\npub struct VecLowerEq;\n\nimpl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLowerEq {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {\n        lhs.le(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> bool {\n        lhs <= rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_cmp_accelerated::<S>()\n    }\n}\n\npub struct VecLower;\n\nimpl<T: VOrd + PartialOrd> SimdCmpOp<T> for VecLower {\n    fn apply_vec<S: Simd>(lhs: Vector<S, T>, rhs: Vector<S, T>) -> Mask<S, T> {\n        lhs.lt(rhs)\n    }\n\n    fn apply(lhs: T, rhs: T) -> bool {\n        lhs < rhs\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VOrd>::is_cmp_accelerated::<S>()\n    }\n}\n\n#[macerator::with_simd]\nfn is_accelerated<S: Simd, T: Scalar, Op: SimdCmpOp<T>>(_x: PhantomData<(T, Op)>) -> bool {\n    Op::is_accelerated::<S>()\n}\n\n#[allow(clippy::result_large_err)]\npub fn try_cmp_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n    lhs: SharedArray<E>,\n    rhs: SharedArray<E>,\n) -> Result<SharedArray<bool>, (SharedArray<E>, SharedArray<E>)> {\n    let lhs_len = lhs.len();\n    let rhs_len = rhs.len();\n    if !should_use_simd(lhs_len.max(rhs_len))\n        || !lhs.is_standard_layout()\n        || !rhs.is_standard_layout()\n        || lhs.shape() != rhs.shape()\n        || !is_accelerated::<T, Op>(PhantomData)\n    {\n        return Err((lhs, rhs));\n    }\n    // Used to assert traits based on the dynamic `DType`.\n    let lhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(lhs) };\n    let rhs = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(rhs) };\n    let out = cmp_simd_same::<T, Op>(lhs, rhs);\n\n    Ok(out)\n}\n\nfn cmp_simd_same<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n    lhs: SharedArray<T>,\n    rhs: SharedArray<T>,\n) -> SharedArray<bool> {\n    let out = if lhs.is_unique() && size_of::<T>() == size_of::<bool>() {\n        let mut buf = lhs.into_owned();\n        let lhs = buf.as_slice_mut().unwrap();\n        let rhs = rhs.as_slice().unwrap();\n        let out =\n            unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs)) };\n        cmp(lhs, rhs, out, PhantomData::<Op>);\n        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buf) }\n    } else if rhs.is_unique() && size_of::<T>() == size_of::<bool>() {\n        let mut buf = rhs.into_owned();\n        let lhs = lhs.as_slice().unwrap();\n        let rhs = buf.as_slice_mut().unwrap();\n        let out =\n            unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs)) };\n        cmp(lhs, rhs, out, PhantomData::<Op>);\n        unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buf) }\n    } else {\n        let mut out = uninit_array_like(&lhs);\n        let lhs = lhs.as_slice().unwrap();\n        let rhs = rhs.as_slice().unwrap();\n        let out_slice = out.as_slice_mut().unwrap();\n        cmp(lhs, rhs, out_slice, PhantomData::<Op>);\n        out\n    };\n    out.into_shared()\n}\n\n#[allow(clippy::erasing_op, clippy::identity_op)]\n#[macerator::with_simd]\nfn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n    lhs: &'a [T],\n    rhs: &'a [T],\n    out: &'a mut [bool],\n    _op: PhantomData<Op>,\n) where\n    'a: 'a,\n{\n    let lanes = T::lanes::<S>();\n    let mut chunks_lhs = lhs.chunks_exact(8 * lanes);\n    let mut chunks_rhs = rhs.chunks_exact(8 * lanes);\n    let mut chunks_out = out.chunks_exact_mut(8 * lanes);\n    while let Some(((lhs, rhs), out)) = chunks_lhs\n        .next()\n        .zip(chunks_rhs.next())\n        .zip(chunks_out.next())\n    {\n        seq!(N in 0..8 {\n            // Load one full vector from `lhs`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let lhs~N = unsafe { vload_unaligned::<S, _>(&lhs[N * lanes]) };\n            // Load one full vector from `rhs`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) };\n            let s~N = Op::apply_vec(lhs~N, rhs~N);\n            // Store one full vector to `out`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };\n        });\n    }\n    let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes);\n    let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes);\n    let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);\n    while let Some(((lhs, rhs), out)) = chunks_lhs\n        .next()\n        .zip(chunks_rhs.next())\n        .zip(chunks_out.next())\n    {\n        // Load one full vector from `lhs`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let lhs0 = unsafe { vload_unaligned::<S, _>(lhs.as_ptr()) };\n        // Load one full vector from `rhs`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) };\n        let s0 = Op::apply_vec(lhs0, rhs0);\n        // Store one full vector to `out`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };\n    }\n\n    for ((lhs, rhs), out) in chunks_lhs\n        .remainder()\n        .iter()\n        .zip(chunks_rhs.remainder())\n        .zip(chunks_out.into_remainder())\n    {\n        *out = Op::apply(*lhs, *rhs)\n    }\n}\n\n/// Unsafely alias a slice to use as an inline argument\nfn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] {\n    let ptr = slice.as_mut_ptr();\n    let len = slice.len();\n    unsafe { slice::from_raw_parts_mut(ptr, len) }\n}\n\npub use elemwise::try_cmp_scalar_simd;\n\nmod elemwise {\n    use bytemuck::cast;\n    use macerator::vload;\n\n    use super::*;\n\n    pub fn try_cmp_scalar_simd<E: Element, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n        input: SharedArray<E>,\n        elem: T,\n    ) -> Result<SharedArray<bool>, SharedArray<E>> {\n        if !should_use_simd(input.len())\n            || input.as_slice_memory_order().is_none()\n            || !is_accelerated::<T, Op>(PhantomData)\n        {\n            return Err(input);\n        }\n        // Used to assert traits based on the dynamic `DType`.\n        let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };\n        let out = if size_of::<T>() == size_of::<bool>()\n            && align_of::<T>() >= align_of::<bool>()\n            && input.is_unique()\n        {\n            unsafe { cmp_scalar_simd_inplace::<T, Op>(input, elem) }\n        } else {\n            cmp_scalar_simd_owned::<T, Op>(input, elem)\n        };\n        Ok(out)\n    }\n\n    /// Execute operation in place on an owned tensor\n    /// SAFETY:\n    /// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\n    unsafe fn cmp_scalar_simd_inplace<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n        input: SharedArray<T>,\n        elem: T,\n    ) -> SharedArray<bool> {\n        let mut buffer = input.into_owned();\n        let slice = buffer.as_slice_memory_order_mut().unwrap();\n        unsafe { cmp_scalar_slice_inplace::<T, Op>(slice, elem, PhantomData) };\n        // Buffer has the same elem size and is filled with the operation output, so this is safe\n        let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<bool>>(buffer) };\n        out.into_shared()\n    }\n\n    /// Create a new copy of the tensor as the output\n    fn cmp_scalar_simd_owned<T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n        input: SharedArray<T>,\n        elem: T,\n    ) -> SharedArray<bool> {\n        let mut out = uninit_array_like(&input);\n        let input = input.as_slice_memory_order().unwrap();\n        let out_slice = out.as_slice_memory_order_mut().unwrap();\n        cmp_scalar_slice::<T, Op>(input, out_slice, elem, PhantomData);\n        out.into_shared()\n    }\n\n    #[inline(always)]\n    #[allow(clippy::erasing_op, clippy::identity_op)]\n    #[macerator::with_simd]\n    fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n        input: &'a [T],\n        out: &'a mut [bool],\n        rhs: T,\n        _op: PhantomData<Op>,\n    ) where\n        'a: 'a,\n    {\n        let lanes = T::lanes::<S>();\n        let mut chunks_input = input.chunks_exact(8 * lanes);\n        let mut chunks_out = out.chunks_exact_mut(8 * lanes);\n        let rhs_vec = rhs.splat::<S>();\n        while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n            seq!(N in 0..8 {\n                // Load one full vector from `input`.\n                // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n                let s~N = unsafe { vload_unaligned(&input[N * lanes]) };\n                let s~N = Op::apply_vec(s~N, rhs_vec);\n                // Store one full vector to `out`.\n                // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n                unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) };\n            });\n        }\n        let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);\n        let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);\n        while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n            // Load one full vector from `input`.\n            // SAFETY: Guaranteed to be in bounds because `len == lanes`\n            let s0 = unsafe { vload_unaligned(input.as_ptr()) };\n            let s0 = Op::apply_vec(s0, rhs_vec);\n            // Store one full vector to `out`.\n            // SAFETY: Guaranteed to be in bounds because `len == lanes`\n            unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) };\n        }\n\n        for (input, out) in chunks_input\n            .remainder()\n            .iter()\n            .zip(chunks_out.into_remainder())\n        {\n            *out = Op::apply(*input, rhs)\n        }\n    }\n\n    /// Execute operation in line.\n    /// SAFETY:\n    /// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\n    #[inline(always)]\n    #[macerator::with_simd]\n    unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp<T>>(\n        buf: &'a mut [T],\n        rhs: T,\n        _op: PhantomData<Op>,\n    ) where\n        'a: 'a,\n    {\n        let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };\n        for elem in head.iter_mut().chain(tail) {\n            *elem = cast(Op::apply(*elem, rhs));\n        }\n        let mut chunks = main.chunks_exact_mut(8);\n        let rhs = rhs.splat::<S>();\n        for elem in chunks.by_ref() {\n            seq!(N in 0..8 {\n                // Load a full vector from the aligned portion of the buffer.\n                // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n                // always a full vector in bounds.\n                let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };\n                let s~N = Op::apply_vec(s~N, rhs);\n                // Store a full vector at the same position as the input. Cast is safe because `Out` is\n                // size and align compatible\n                unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) };\n            });\n        }\n\n        for elem in chunks.into_remainder() {\n            // Load a full vector from the aligned portion of the buffer.\n            // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n            // always a full vector in bounds.\n            let s0 = unsafe { vload(elem as *const _ as *const T) };\n\n            let s0 = Op::apply_vec(s0, rhs);\n            // Store a full vector at the same position as the input. Cast is safe because `Out` is\n            // size and align compatible\n            unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) };\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/conv.rs",
    "content": "use core::{marker::PhantomData, mem::transmute};\n\nuse burn_backend::{\n    DType, Element,\n    ops::{ConvOptions, conv::calculate_conv_output_size},\n};\nuse bytemuck::Zeroable;\nuse macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned};\nuse ndarray::{\n    ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s,\n};\nuse seq_macro::seq;\n\nuse crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par};\n\ntype Args<E> = (SharedArray<E>, SharedArray<E>, Option<SharedArray<E>>);\n\n#[allow(clippy::result_large_err)]\npub fn try_conv2d_simd<E: FloatNdArrayElement>(\n    x: SharedArray<E>,\n    weight: SharedArray<E>,\n    bias: Option<SharedArray<E>>,\n    options: ConvOptions<2>,\n) -> Result<SharedArray<E>, Args<E>> {\n    match E::dtype() {\n        DType::F64 => conv2d::<f64, _>(x, weight, bias, options, PhantomData),\n        DType::F32 => conv2d::<f32, _>(x, weight, bias, options, PhantomData),\n        DType::I64 => conv2d::<i64, _>(x, weight, bias, options, PhantomData),\n        DType::I32 => conv2d::<i32, _>(x, weight, bias, options, PhantomData),\n        DType::I16 => conv2d::<i16, _>(x, weight, bias, options, PhantomData),\n        DType::U64 => conv2d::<u64, _>(x, weight, bias, options, PhantomData),\n        DType::U32 => conv2d::<u32, _>(x, weight, bias, options, PhantomData),\n        DType::U16 => conv2d::<u16, _>(x, weight, bias, options, PhantomData),\n        _ => Err((x, weight, bias)),\n    }\n}\n\nfn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {\n    unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }\n}\n\n/// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on\n/// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018).\n/// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures.\n/// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. <https://arxiv.org/abs/1808.05567>.\n#[allow(clippy::result_large_err)]\nfn conv2d<E: VMulAdd + Element, T: Element>(\n    x: SharedArray<T>,\n    weight: SharedArray<T>,\n    bias: Option<SharedArray<T>>,\n    options: ConvOptions<2>,\n    _ty: PhantomData<E>,\n) -> Result<SharedArray<T>, Args<T>> {\n    let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap();\n    let channels_per_group = out_channels / options.groups;\n\n    #[macerator::with_simd]\n    fn precheck<S: Simd, E: VMulAdd>(_ty: PhantomData<E>) -> (usize, bool) {\n        (E::lanes::<S>(), E::is_accelerated::<S>())\n    }\n\n    let (lanes, accelerated) = precheck::<E>(PhantomData);\n\n    if !accelerated || !channels_per_group.is_multiple_of(lanes) {\n        return Err((x, weight, bias));\n    }\n\n    let x = cast::<_, E>(x);\n    let weight = cast::<_, E>(weight);\n    let bias = bias.map(|bias| cast::<_, E>(bias));\n\n    let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap();\n    let [dilate_h, dilate_w] = options.dilation;\n    let [stride_h, stride_w] = options.stride;\n    let [pad_h, pad_w] = options.padding;\n    let padded = options.padding != [0, 0];\n    let strided = options.stride != [1, 1] || options.dilation != [1, 1];\n    let grouped = options.groups != 1;\n\n    let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height);\n    let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width);\n\n    let x = x.into_dimensionality::<Ix4>().unwrap();\n    let weights = weight.into_dimensionality::<Ix4>().unwrap();\n    let weights = weights.permuted_axes([1, 2, 3, 0]);\n    let weights = weights.as_standard_layout();\n    let bias = bias.map(|bias| bias.into_dimensionality::<Ix1>().unwrap());\n    // floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`.\n    let oc_blocks = out_channels / lanes;\n\n    let mut out = unsafe {\n        Array4::<E>::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init()\n    };\n    let unsafe_shared_out = UnsafeSharedRef::new(&mut out);\n\n    run_par!(|| {\n        // SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference\n        // is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function.\n        iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe {\n            let b = k / oc_blocks;\n            let ob = k % oc_blocks;\n            let x = x.slice(s![b, .., .., ..]);\n            let out = unsafe_shared_out.get();\n            let mut out = out.slice_mut(s![b, .., .., ..]);\n            let w = weights.view();\n\n            match (padded, strided, grouped) {\n                (true, true, true) => {\n                    conv2d_launch::<E, true, true, true>(x, w, &bias, &mut out, &options, ob)\n                }\n                (true, false, true) => {\n                    conv2d_launch::<E, true, false, true>(x, w, &bias, &mut out, &options, ob)\n                }\n                (false, true, true) => {\n                    conv2d_launch::<E, false, true, true>(x, w, &bias, &mut out, &options, ob)\n                }\n                (false, false, true) => {\n                    conv2d_launch::<E, false, false, true>(x, w, &bias, &mut out, &options, ob)\n                }\n                (true, true, false) => {\n                    conv2d_launch::<E, true, true, false>(x, w, &bias, &mut out, &options, ob)\n                }\n                (true, false, false) => {\n                    conv2d_launch::<E, true, false, false>(x, w, &bias, &mut out, &options, ob)\n                }\n                (false, true, false) => {\n                    conv2d_launch::<E, false, true, false>(x, w, &bias, &mut out, &options, ob)\n                }\n                (false, false, false) => {\n                    conv2d_launch::<E, false, false, false>(x, w, &bias, &mut out, &options, ob)\n                }\n            }\n        });\n    });\n\n    let output = out.permuted_axes([0, 3, 1, 2]);\n    Ok(cast(output.into_dyn().into_shared()))\n}\n\n/// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support\n/// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might\n/// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16).\n/// This should always be conservative, since oversizing it will cause register spills and that's\n/// **much** worse than the performance lost with lower values.\nconst REGISTER_BLOCK: usize = 8;\ninner_with_register_blocking_size!(8);\n\n/// Run a loop of conv2d.\n/// # SAFETY\n/// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`.\n/// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and\n/// `out` must have unit stride for the out channels.\n#[inline(always)]\n#[macerator::with_simd]\nunsafe fn conv2d_launch<\n    'a,\n    S: Simd,\n    E: VMulAdd,\n    const PAD: bool,\n    const STRIDE: bool,\n    const GROUPS: bool,\n>(\n    x: ArrayView3<'a, E>,\n    weights: ArrayView4<'a, E>,\n    bias: &'a Option<ArcArray1<E>>,\n    out: &'a mut ArrayViewMut3<'a, E>,\n    options: &'a ConvOptions<2>,\n    ob: usize,\n) where\n    'a: 'a,\n{\n    let (in_channels, k_height, k_width, out_channels) = weights.dim();\n    let (out_height, out_width, _) = out.dim();\n    let channels_per_group = out_channels / options.groups;\n    let lanes = E::lanes::<S>();\n\n    let [mut pad_h, mut pad_w] = options.padding;\n    let [stride_h, stride_w] = options.stride;\n    let [dilate_h, dilate_w] = options.dilation;\n\n    // Trick compiler into inlining 0 to padding\n    if !PAD {\n        pad_h = 0;\n        pad_w = 0;\n    }\n\n    let oc_b = channels_per_group.min(lanes);\n    let ow_b = REGISTER_BLOCK;\n\n    let ow_start = pad_w;\n    let ow_width = out_width.saturating_sub(2 * pad_w);\n    let oh_start = pad_h;\n    let oh_end = out_height.saturating_sub(pad_h);\n\n    let ow_blocks = ow_width / ow_b;\n    let oc = ob * oc_b;\n    let group = oc / channels_per_group;\n    let mut ic_off = group * in_channels;\n    if !GROUPS {\n        ic_off = 0;\n    }\n\n    unsafe {\n        let bias = if let Some(bias) = &bias {\n            vload_unaligned::<S, _>(&bias[oc])\n        } else {\n            Zeroable::zeroed()\n        };\n\n        for oh in oh_start..oh_end {\n            let mut out = out.slice_mut(s![oh, .., ..]);\n            for ow_block in 0..ow_blocks {\n                let ow = ow_block * ow_b + ow_start;\n\n                #[allow(clippy::if_same_then_else)]\n                if STRIDE {\n                    conv2d_inner_nopad(\n                        &x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w,\n                        dilate_h, dilate_w, k_height, k_width, pad_h, pad_w,\n                    );\n                } else {\n                    conv2d_inner_nopad_nostride(\n                        &x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h,\n                        pad_w,\n                    );\n                }\n            }\n        }\n        conv2d_remainder(\n            x,\n            weights,\n            out,\n            bias,\n            oc,\n            ic_off,\n            ow_blocks * ow_b,\n            stride_h,\n            stride_w,\n            dilate_h,\n            dilate_w,\n            pad_h,\n            pad_w,\n            k_height,\n            k_width,\n        );\n    }\n}\n\n/// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is\n/// much slower, so we want to minimize the amount of pixels that need to be processed by this\n///\n/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector\n/// is in bounds. Weights and `out` must be channels last (with `stride == 1`).\n#[allow(clippy::too_many_arguments)]\n#[inline(always)]\nunsafe fn conv2d_remainder<S: Simd, E: VMulAdd>(\n    x: ArrayView3<E>,\n    weights: ArrayView4<E>,\n    out: &mut ArrayViewMut3<E>,\n    bias: Vector<S, E>,\n    oc: usize,\n    ic_off: usize,\n    owb_end: usize,\n    stride_h: usize,\n    stride_w: usize,\n    dilate_h: usize,\n    dilate_w: usize,\n    pad_h: usize,\n    pad_w: usize,\n    k_height: usize,\n    k_width: usize,\n) {\n    let in_channels = weights.shape()[0];\n    let (_, in_height, in_width) = x.dim();\n    let (out_height, out_width, _) = out.dim();\n    let oh_start = pad_h;\n    let oh_end = out_height.saturating_sub(pad_h);\n    let ow_start = pad_w;\n\n    let height1 = in_height + pad_h;\n    let width1 = in_width + pad_w;\n\n    for oh in (0..oh_start).chain(oh_end..out_height) {\n        for ow in 0..out_width {\n            let mut acc = bias;\n\n            for ic in 0..in_channels {\n                for kh in 0..k_height {\n                    let ih = oh * stride_h + kh * dilate_h;\n                    if (ih < pad_h) | (ih >= height1) {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..k_width {\n                        let iw = ow * stride_w + kw * dilate_w;\n                        if (iw < pad_w) | (iw >= width1) {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n\n                        // Load a full vector from the weights. This is guaranteed to be in bounds\n                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.\n                        // We need to ensure the weights are reshaped appropriately.\n                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };\n\n                        // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the\n                        // compiler can't prove this. We can't use `as_slice` with fixed bounds\n                        // because we want to support arbitrary input layouts. So an unchecked load\n                        // is used.\n                        let i0 = unsafe { x.uget([ic, ih, iw]) }.splat::<S>();\n                        acc = i0.mul_add(f0, acc);\n                    }\n                }\n            }\n\n            // Store a full vector from the output. This is guaranteed to be in bounds\n            // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with\n            // channels last, so this always holds.\n            unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };\n        }\n    }\n    for ow in (0..ow_start).chain(owb_end..out_width) {\n        for oh in 0..out_height {\n            let mut acc = bias;\n\n            for ic in 0..in_channels {\n                for kh in 0..k_height {\n                    let ih = oh * stride_h + kh * dilate_h;\n                    if (ih < pad_h) | (ih >= height1) {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..k_width {\n                        let iw = ow * stride_w + kw * dilate_w;\n                        if (iw < pad_w) | (iw >= width1) {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n\n                        // Load a full vector from the weights. This is guaranteed to be in bounds\n                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.\n                        // We need to ensure the weights are reshaped appropriately.\n                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };\n\n                        // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the\n                        // compiler can't prove this. We can't use `as_slice` with fixed bounds\n                        // because we want to support arbitrary input layouts. So an unchecked load\n                        // is used.\n                        let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::<S>();\n                        acc = i0.mul_add(f0, acc);\n                    }\n                }\n            }\n\n            // Store a full vector from the output. This is guaranteed to be in bounds\n            // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with\n            // channels last, so this always holds.\n            unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };\n        }\n    }\n}\n\nmacro_rules! inner_with_register_blocking_size {\n    ($rb: literal) => {\n        /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than\n        /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is\n        /// guaranteed to always be in bounds (because of the way out size is calculated).\n        ///\n        /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector\n        /// is in bounds. Weights and `out` must be channels last (with `stride == 1`).\n        #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]\n        #[inline(always)]\n        unsafe fn conv2d_inner_nopad<S: Simd, E: VMulAdd>(\n            x: &ArrayView3<E>,\n            weights: &ArrayView4<E>,\n            out: &mut ArrayViewMut2<E>,\n            bias: Vector<S, E>,\n            oh: usize,\n            ow: usize,\n            oc: usize,\n            ic_off: usize,\n            stride_h: usize,\n            stride_w: usize,\n            dilate_h: usize,\n            dilate_w: usize,\n            k_height: usize,\n            k_width: usize,\n            pad_h: usize,\n            pad_w: usize,\n        ) {\n            let in_channels = weights.shape()[0];\n\n            seq!(N in 0..$rb {\n                let mut acc~N = bias;\n            });\n\n            for ic in 0..in_channels {\n                for kh in 0..k_height {\n                    let ih = oh * stride_h + kh * dilate_h - pad_h;\n\n                    for kw in 0..k_width {\n                        // Load a full vector from the weights. This is guaranteed to be in bounds\n                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.\n                        // We need to ensure the weights are reshaped appropriately.\n                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };\n                        let iw = ow * stride_w + kw * dilate_w - pad_w;\n\n                        seq!(N in 0..$rb {\n                            // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the\n                            // compiler can't prove this. We can't use `as_slice` with fixed bounds\n                            // because we want to support arbitrary input layouts. So an unchecked load\n                            // is used.\n                            let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N * stride_w]) }.splat::<S>();\n                        });\n                        seq!(N in 0..$rb {\n                            acc~N = i~N.mul_add(f0, acc~N);\n                        });\n                    }\n                }\n            }\n\n            seq!(N in 0..$rb {\n                // Store a full vector from the output. This is guaranteed to be in bounds\n                // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with\n                // channels last, so this always holds.\n                unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) };\n            });\n        }\n\n        /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than\n        /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is\n        /// guaranteed to always be in bounds (because of the way out size is calculated).\n        ///\n        /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector\n        /// is in bounds. Weights and `out` must be channels last (with `stride == 1`).\n        #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)]\n        #[inline(always)]\n        unsafe fn conv2d_inner_nopad_nostride<S: Simd, E: VMulAdd>(\n            x: &ArrayView3<E>,\n            weights: &ArrayView4<E>,\n            out: &mut ArrayViewMut2<E>,\n            bias: Vector<S, E>,\n            oh: usize,\n            ow: usize,\n            oc: usize,\n            ic_off: usize,\n            k_height: usize,\n            k_width: usize,\n            pad_h: usize,\n            pad_w: usize,\n        ) {\n            let in_channels = weights.shape()[0];\n\n            seq!(N in 0..$rb {\n                let mut acc~N = bias;\n            });\n\n            for ic in 0..in_channels {\n                for kh in 0..k_height {\n                    let ih = oh + kh - pad_h;\n\n                    for kw in 0..k_width {\n                        // Load a full vector from the weights. This is guaranteed to be in bounds\n                        // as long as `oc <= out_channels - simd_lanes` and out channels are last.\n                        // We need to ensure the weights are reshaped appropriately.\n                        let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) };\n                        let iw = ow + kw - pad_w;\n\n                        seq!(N in 0..$rb {\n                            // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the\n                            // compiler can't prove this. We can't use `as_slice` with fixed bounds\n                            // because we want to support arbitrary input layouts. So an unchecked load\n                            // is used.\n                            let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N]) }.splat::<S>();\n                        });\n                        seq!(N in 0..$rb {\n                            acc~N = i~N.mul_add(f0, acc~N);\n                        });\n                    }\n                }\n            }\n\n            seq!(N in 0..$rb {\n                // Store a full vector from the output. This is guaranteed to be in bounds\n                // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with\n                // channels last, so this always holds.\n                unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) };\n            });\n        }\n    };\n}\npub(crate) use inner_with_register_blocking_size;\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/maxpool.rs",
    "content": "use core::{marker::PhantomData, mem::transmute};\n\nuse crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};\n\nuse burn_backend::{BoolStore, DType, Element, quantization::QuantValue};\nuse macerator::{Simd, VOrd};\nuse ndarray::{Array4, s};\nuse nhwc::max_pool2d_nhwc;\n\nuse super::{MinMax, should_use_simd};\n\n#[macerator::with_simd]\nfn is_accelerated_impl<S: Simd, T: VOrd>(_x: PhantomData<T>) -> bool {\n    <T as VOrd>::is_min_max_accelerated::<S>()\n}\n\nfn is_accelerated<T: VOrd>() -> bool {\n    is_accelerated_impl::<T>(PhantomData)\n}\n\nmacro_rules! launch_kernel {\n    ($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => {\n        match <$ty as Element>::dtype() {\n            DType::F64 if is_accelerated::<f64>() => Ok(cast($func::<f64>(cast($x), $($arg),*))),\n            DType::F32 if is_accelerated::<f32>() => Ok(cast($func::<f32>(cast($x), $($arg),*))),\n            DType::I64 if is_accelerated::<i64>() => Ok(cast($func::<i64>(cast($x), $($arg),*))),\n            DType::I32 if is_accelerated::<i32>() => Ok(cast($func::<i32>(cast($x), $($arg),*))),\n            DType::I16 if is_accelerated::<i16>() => Ok(cast($func::<i16>(cast($x), $($arg),*))),\n            DType::I8 if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),\n            DType::U64 if is_accelerated::<u64>() => Ok(cast($func::<u64>(cast($x), $($arg),*))),\n            DType::U32 if is_accelerated::<u32>() => Ok(cast($func::<u32>(cast($x), $($arg),*))),\n            DType::U16 if is_accelerated::<u16>() => Ok(cast($func::<u16>(cast($x), $($arg),*))),\n            DType::U8 if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),\n            DType::Bool(BoolStore::Native) if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),\n            DType::QFloat(scheme) => match scheme.value {\n                QuantValue::Q8F | QuantValue::Q8S if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),\n                _ => Err($x)\n            },\n            _ => Err($x),\n        }\n    };\n}\n\npub(crate) fn try_max_pool2d_simd<E: Element>(\n    x: SharedArray<E>,\n    ksize: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n) -> Result<SharedArray<E>, SharedArray<E>> {\n    let [_, c, _, _] = x.shape().try_into().unwrap();\n    if !should_use_simd(c) || x.strides()[1] != 1 {\n        return Err(x);\n    }\n\n    launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation)\n}\n\nfn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {\n    unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }\n}\n\nmod nhwc {\n    use itertools::Itertools;\n    use macerator::{Simd, vload_unaligned, vstore_unaligned};\n    use ndarray::{ArrayView3, ArrayViewMut3, Ix4};\n    use seq_macro::seq;\n\n    use crate::ops::simd::lanes;\n\n    use super::*;\n\n    // Until you can use associated constants as array size, we need to hardcode this.\n    // The most common config (x86-v3) has 16 registers, so use half of them for accumulators.\n    const BLOCK_REGISTERS: usize = 8;\n\n    pub(crate) fn max_pool2d_nhwc<E: Element + VOrd + MinMax>(\n        x: SharedArray<E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n    ) -> SharedArray<E> {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n        let [dilation_height, dilation_width] = dilation;\n        let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();\n        let lanes = lanes::<E>();\n\n        let ch_block = lanes * BLOCK_REGISTERS;\n\n        let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1)\n            / stride_height)\n            + 1;\n        let out_width =\n            ((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1;\n\n        let mut output = unsafe {\n            Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()\n        };\n        let unsafe_shared_out = UnsafeSharedRef::new(&mut output);\n\n        let x = x.into_dimensionality::<Ix4>().unwrap();\n        let x = x.view();\n        let x = x.permuted_axes([0, 2, 3, 1]);\n\n        // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`.\n        // An exclusive loop will always have `lanes * blocking factor` elements in bounds.\n        let blocks = channels / ch_block;\n        let blocks_end = blocks * ch_block;\n        // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An\n        // exclusive loop will always have `lanes` elements in bounds.\n        let simd_end = channels / lanes * lanes;\n        let simd_unblocked = (simd_end - blocks_end) / lanes;\n        let remainder = channels - simd_end;\n\n        run_par!(|| {\n            // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.\n            iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {\n                let block = k % blocks;\n                let b = k / blocks;\n\n                let output = unsafe_shared_out.get();\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n                loop_blocked(x, out, kernel_size, stride, padding, dilation, block);\n            });\n            // SAFETY: See `loop_unblocked`\n            iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe {\n                let ch = (k % simd_unblocked) * lanes + blocks_end;\n                let b = k / simd_unblocked;\n\n                let output = unsafe_shared_out.get();\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n                loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch);\n            });\n            // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe.\n            iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {\n                let ch = (k % remainder) + simd_end;\n                let b = k / remainder;\n\n                let output = unsafe_shared_out.get();\n                let x = x.slice(s![b, .., .., ..]);\n                let out = output.slice_mut(s![b, .., .., ..]);\n                loop_scalar(x, out, kernel_size, stride, padding, dilation, ch);\n            });\n        });\n\n        output = output.permuted_axes([0, 3, 1, 2]);\n\n        output.into_dyn().into_shared()\n    }\n\n    /// Execute the blocked (unrolled) portion of the pool.\n    #[allow(\n        clippy::too_many_arguments,\n        clippy::erasing_op,\n        clippy::identity_op,\n        unused_mut\n    )]\n    #[inline(always)]\n    #[macerator::with_simd]\n    fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>(\n        x: ArrayView3<'a, E>,\n        mut out: ArrayViewMut3<'a, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        block: usize,\n    ) where\n        'a: 'a,\n    {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n        let [dilation_height, dilation_width] = dilation;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n        let lanes = E::lanes::<S>();\n        let ch_block = lanes * BLOCK_REGISTERS;\n\n        let min = E::MIN.splat::<S>();\n        // If outside padding area, kernels are guaranteed to be in bounds\n        for oh in pad_h..out_height.saturating_sub(pad_h) {\n            for ow in pad_w..out_width.saturating_sub(pad_w) {\n                seq!(N in 0..8 {\n                    let mut acc~N = min;\n                });\n                let ch = block * ch_block;\n                let ch_end = ch + ch_block;\n                let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh * dilation_height - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw * dilation_width - pad_w;\n                        let x = x.slice(s![ih, iw, ch..ch_end]);\n\n                        seq!(N in 0..8 {\n                            // SAFETY:\n                            // Load a full vector from x[N * lanes]. This is bounds checked by the\n                            // slice above.\n                            acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });\n                        });\n                    }\n                }\n\n                seq!(N in 0..8 {\n                    // SAFETY:\n                    // Store a full vector to out[N * lanes]. This is bounds checked by the\n                    // slice above.\n                    unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };\n                });\n            }\n        }\n\n        // Border pixels need bounds checks\n        if (pad_h, pad_w) != (0, 0) {\n            let v_borders = (0..pad_h)\n                .chain(out_height.saturating_sub(pad_h)..out_height)\n                .cartesian_product(0..out_width);\n            let h_borders = (0..out_height)\n                .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));\n\n            for (oh, ow) in v_borders.chain(h_borders) {\n                seq!(N in 0..8 {\n                    let mut acc~N = min;\n                });\n                let ch = block * ch_block;\n                let ch_end = ch + ch_block;\n                let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh * dilation_height;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw * dilation_width;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n\n                        let x = x.slice(s![ih, iw, ch..ch_end]);\n\n                        seq!(N in 0..8 {\n                            // SAFETY:\n                            // Load a full vector from x[N * lanes]. This is bounds checked by the\n                            // slice above.\n                            acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });\n                        });\n                    }\n                }\n\n                seq!(N in 0..8 {\n                    // SAFETY:\n                    // Store a full vector to out[N * lanes]. This is bounds checked by the\n                    // slice above.\n                    unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };\n                });\n            }\n        }\n    }\n\n    /// Execute the unblocked (not unrolled) portion of the pool.\n    ///\n    /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`.\n    #[allow(clippy::too_many_arguments, unused_mut)]\n    #[inline(always)]\n    #[macerator::with_simd]\n    unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>(\n        x: ArrayView3<'a, E>,\n        mut out: ArrayViewMut3<'a, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ch: usize,\n    ) where\n        'a: 'a,\n    {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n        let [dilation_height, dilation_width] = dilation;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n\n        for oh in pad_h..out_height.saturating_sub(pad_h) {\n            for ow in pad_w..out_width.saturating_sub(pad_w) {\n                let mut acc = E::MIN.splat::<S>();\n                let out = &mut out[[oh, ow, ch]];\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh * dilation_height - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw * dilation_width - pad_w;\n                        // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`\n                        acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });\n                    }\n                }\n                // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.\n                unsafe { vstore_unaligned(out, acc) };\n            }\n        }\n\n        // Border pixels need bounds checks\n        if (pad_h, pad_w) != (0, 0) {\n            let v_borders = (0..pad_h)\n                .chain(out_height.saturating_sub(pad_h)..out_height)\n                .cartesian_product(0..out_width);\n            let h_borders = (0..out_height)\n                .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));\n\n            for (oh, ow) in v_borders.chain(h_borders) {\n                let mut acc = E::MIN.splat::<S>();\n                let out = &mut out[[oh, ow, ch]];\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh * dilation_height;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw * dilation_width;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n                        // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes`\n                        acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });\n                    }\n                }\n                // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`.\n                unsafe { vstore_unaligned(out, acc) };\n            }\n        }\n    }\n\n    fn loop_scalar<E: Element + MinMax>(\n        x: ArrayView3<'_, E>,\n        mut out: ArrayViewMut3<'_, E>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ch: usize,\n    ) {\n        let [kernel_height, kernel_width] = kernel_size;\n        let [pad_h, pad_w] = padding;\n        let [stride_height, stride_width] = stride;\n        let [dilation_height, dilation_width] = dilation;\n\n        let (x_height, x_width, _) = x.dim();\n        let (out_height, out_width, _) = out.dim();\n\n        for oh in 0..out_height {\n            for ow in 0..out_width {\n                let mut acc = E::MIN;\n\n                for kh in 0..kernel_height {\n                    let ih = oh * stride_height + kh * dilation_height;\n                    if ih < pad_h || ih >= x_height + pad_h {\n                        continue;\n                    }\n                    let ih = ih - pad_h;\n\n                    for kw in 0..kernel_width {\n                        let iw = ow * stride_width + kw * dilation_width;\n                        if iw < pad_w || iw >= x_width + pad_w {\n                            continue;\n                        }\n                        let iw = iw - pad_w;\n                        acc = acc.max(x[[ih, iw, ch]]);\n                    }\n                }\n\n                out[[oh, ow, ch]] = acc;\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/mod.rs",
    "content": "pub(crate) mod avgpool;\nmod base;\npub(crate) mod binary;\npub(crate) mod binary_elemwise;\npub(crate) mod cmp;\npub(crate) mod conv;\npub(crate) mod maxpool;\npub(crate) mod unary;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/simd/unary.rs",
    "content": "use core::marker::PhantomData;\n\nuse bytemuck::cast;\nuse macerator::{\n    Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned,\n};\nuse ndarray::ArrayD;\nuse num_traits::Signed;\nuse seq_macro::seq;\n\nuse crate::{NdArrayElement, SharedArray};\n\nuse super::should_use_simd;\n\npub trait SimdUnop<T: Scalar, Out: Scalar> {\n    fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, Out>;\n    fn apply(input: T) -> Out;\n    fn is_accelerated<S: Simd>() -> bool;\n}\n\npub struct RecipVec;\n\nimpl SimdUnop<f32, f32> for RecipVec {\n    fn apply_vec<S: Simd>(input: Vector<S, f32>) -> Vector<S, f32> {\n        input.recip()\n    }\n\n    fn apply(input: f32) -> f32 {\n        input.recip()\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <f32 as VRecip>::is_accelerated::<S>()\n    }\n}\n\npub struct VecAbs;\n\nimpl<T: VAbs + Signed> SimdUnop<T, T> for VecAbs {\n    fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {\n        input.abs()\n    }\n\n    fn apply(input: T) -> T {\n        input.abs()\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VAbs>::is_accelerated::<S>()\n    }\n}\n\npub struct VecBitNot;\n\nimpl<T: VBitNot> SimdUnop<T, T> for VecBitNot {\n    fn apply_vec<S: Simd>(input: Vector<S, T>) -> Vector<S, T> {\n        !input\n    }\n\n    fn apply(input: T) -> T {\n        input.not()\n    }\n\n    fn is_accelerated<S: Simd>() -> bool {\n        <T as VBitNot>::is_accelerated::<S>()\n    }\n}\n\n#[macerator::with_simd]\nfn is_accelerated<S: Simd, T: Scalar, Out: Scalar, Op: SimdUnop<T, Out>>(\n    _x: PhantomData<(T, Out, Op)>,\n) -> bool {\n    Op::is_accelerated::<S>()\n}\n\npub fn try_unary_simd<\n    E: NdArrayElement,\n    EOut: NdArrayElement,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdUnop<T, Out>,\n>(\n    input: SharedArray<E>,\n) -> Result<SharedArray<EOut>, SharedArray<E>> {\n    if !should_use_simd(input.len())\n        || input.as_slice_memory_order().is_none()\n        || !is_accelerated::<T, Out, Op>(PhantomData)\n    {\n        return Err(input);\n    }\n    // Used to assert traits based on the dynamic `DType`.\n    let input = unsafe { core::mem::transmute::<SharedArray<E>, SharedArray<T>>(input) };\n    let out = if size_of::<T>() == size_of::<Out>()\n        && align_of::<T>() >= align_of::<Out>()\n        && input.is_unique()\n    {\n        unsafe { unary_scalar_simd_inplace::<T, Out, Op>(input) }\n    } else {\n        unary_scalar_simd_owned::<T, Out, Op>(input)\n    };\n    // Used to assert traits based on the dynamic `DType`.\n    let out = unsafe { core::mem::transmute::<SharedArray<Out>, SharedArray<EOut>>(out) };\n    Ok(out)\n}\n\n/// Execute operation in line.\n/// SAFETY:\n/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\nunsafe fn unary_scalar_simd_inplace<\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdUnop<T, Out>,\n>(\n    input: SharedArray<T>,\n) -> SharedArray<Out> {\n    let mut buffer = input.into_owned();\n    let slice = buffer.as_slice_memory_order_mut().unwrap();\n    // This is only called when in and out have the same size, so it's safe\n    unsafe { unary_slice_inplace::<T, Out, Op>(slice, PhantomData) };\n    // Buffer has the same elem size and is filled with the operation output, so this is safe\n    let out = unsafe { core::mem::transmute::<ArrayD<T>, ArrayD<Out>>(buffer) };\n    out.into_shared()\n}\n\nfn unary_scalar_simd_owned<\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdUnop<T, Out>,\n>(\n    input: SharedArray<T>,\n) -> SharedArray<Out> {\n    let mut out = unsafe { ArrayD::uninit(input.shape()).assume_init() };\n    let input = input.as_slice_memory_order().unwrap();\n    let out_slice = out.as_slice_memory_order_mut().unwrap();\n    unary_slice::<T, Out, Op>(input, out_slice, PhantomData);\n    out.into_shared()\n}\n\n#[allow(clippy::erasing_op, clippy::identity_op)]\n#[macerator::with_simd]\nfn unary_slice<\n    'a,\n    S: Simd,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdUnop<T, Out>,\n>(\n    input: &'a [T],\n    out: &'a mut [Out],\n    _op: PhantomData<Op>,\n) where\n    'a: 'a,\n{\n    let lanes = T::lanes::<S>();\n    let mut chunks_input = input.chunks_exact(8 * lanes);\n    let mut chunks_out = out.chunks_exact_mut(8 * lanes);\n    while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n        seq!(N in 0..8 {\n            // Load one full vector from `input`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            let s~N = unsafe { vload_unaligned(&input[N * lanes]) };\n            let s~N = Op::apply_vec::<S>(s~N);\n            // Store one full vector to `out`.\n            // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes`\n            unsafe { vstore_unaligned(&mut out[N * lanes], s~N) };\n        });\n    }\n    let mut chunks_input = chunks_input.remainder().chunks_exact(lanes);\n    let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes);\n    while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) {\n        // Load one full vector from `input`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        let s0 = unsafe { vload_unaligned(input.as_ptr()) };\n        let s0 = Op::apply_vec::<S>(s0);\n        // Store one full vector to `out`.\n        // SAFETY: Guaranteed to be in bounds because `len == lanes`\n        unsafe { vstore_unaligned(out.as_mut_ptr(), s0) };\n    }\n\n    for (input, out) in chunks_input\n        .remainder()\n        .iter()\n        .zip(chunks_out.into_remainder())\n    {\n        *out = Op::apply(*input)\n    }\n}\n\n/// Execute operation in line.\n/// SAFETY:\n/// Must ensure `size_of::<T> == size_of::<Out>` and `align_of::<T> >= align_of::<Out>`.\n#[macerator::with_simd]\nunsafe fn unary_slice_inplace<\n    'a,\n    S: Simd,\n    T: NdArrayElement + Scalar,\n    Out: NdArrayElement + Scalar,\n    Op: SimdUnop<T, Out>,\n>(\n    buf: &'a mut [T],\n    _op: PhantomData<(Out, Op)>,\n) where\n    'a: 'a,\n{\n    let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };\n    for elem in head.iter_mut().chain(tail) {\n        *elem = cast(Op::apply(*elem));\n    }\n    let mut chunks = main.chunks_exact_mut(8);\n    for elem in chunks.by_ref() {\n        seq!(N in 0..8 {\n            // Load a full vector from the aligned portion of the buffer.\n            // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n            // always a full vector in bounds.\n            let s~N = unsafe { vload(&elem[N] as *const _ as *const T) };\n            let s~N = Op::apply_vec::<S>(s~N);\n            // Store a full vector at the same position as the input. Cast is safe because `Out` is\n            // size and align compatible\n            unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) };\n        });\n    }\n\n    for elem in chunks.into_remainder() {\n        // Load a full vector from the aligned portion of the buffer.\n        // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is\n        // always a full vector in bounds.\n        let s0 = unsafe { vload(elem as *const _ as *const T) };\n\n        let s0 = Op::apply_vec::<S>(s0);\n        // Store a full vector at the same position as the input. Cast is safe because `Out` is\n        // size and align compatible\n        unsafe { vstore(elem as *mut _ as *mut Out, s0) };\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/tensor.rs",
    "content": "// Language\nuse alloc::vec::Vec;\nuse burn_backend::backend::ExecutionError;\nuse burn_backend::ops::GridSampleOptions;\nuse burn_backend::tensor::FloatTensor;\nuse burn_backend::{TensorMetadata, element::cast::ToElement};\n\n// Current crate\nuse super::{\n    NdArrayMathOps, NdArrayOps,\n    matmul::{cross, matmul},\n};\nuse crate::{\n    NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor,\n};\nuse crate::{NdArrayDevice, SEED, slice};\nuse crate::{\n    SharedArray,\n    element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement},\n};\nuse crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d};\n\n// Workspace crates\nuse crate::rand::get_seeded_rng;\nuse burn_backend::{Distribution, FloatDType, Scalar};\nuse burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps};\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float;\n\nuse libm::erf;\n\n#[cfg(feature = \"std\")]\n#[allow(dead_code)]\nfn round_ties_even_wrapper(x: f64) -> f64 {\n    x.round_ties_even()\n}\n\n#[cfg(not(feature = \"std\"))]\n#[allow(dead_code)]\nfn round_ties_even_wrapper(x: f64) -> f64 {\n    if (x - x.floor()) == 0.5 {\n        (x * 0.5).round() * 2.0\n    } else {\n        x.round()\n    }\n}\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {\n        NdArrayTensor::from_data(data)\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &NdArrayDevice,\n    ) -> FloatTensor<Self> {\n        let mut seed = SEED.lock().unwrap();\n        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);\n        let tensor = Self::float_from_data(\n            TensorData::random::<E, _, _>(shape, distribution, &mut rng),\n            device,\n        );\n        *seed = Some(rng);\n        tensor\n    }\n\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        Ok(tensor.into_data())\n    }\n\n    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {\n        NdArrayDevice::Cpu\n    }\n\n    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {\n        tensor\n    }\n\n    fn float_empty(\n        shape: Shape,\n        device: &<NdArray<E> as Backend>::Device,\n        dtype: FloatDType,\n    ) -> FloatTensor<Self> {\n        Self::float_zeros(shape, device, dtype)\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::add_scalar(array, rhs.elem())\n        })\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::sub_scalar(array, rhs.elem())\n        })\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::mul_scalar(array, rhs.elem())\n        })\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::div_scalar(array, rhs.elem())\n        })\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::remainder_scalar(array, rhs.elem())\n        })\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), matmul)\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::recip(array)\n        })\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::swap_dims(array, dim1, dim2)\n        })\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::reshape(array, shape)\n        })\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: NdArrayTensor,\n    ) -> FloatTensor<Self> {\n        execute_with_int_dtype!(\n            indices,\n            IntElem,\n            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {\n                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n                    NdArrayOps::gather(dim, array, idx_array)\n                })\n            }\n        )\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: NdArrayTensor,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        execute_with_int_dtype!(\n            indices,\n            IntElem,\n            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {\n                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(\n                    dim, tensor, idx_array, value\n                ))\n            }\n        )\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: NdArrayTensor,\n    ) -> FloatTensor<Self> {\n        execute_with_int_dtype!(\n            indices,\n            IntElem,\n            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {\n                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n                    NdArrayMathOps::select(array, dim, idx_array)\n                })\n            }\n        )\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: NdArrayTensor,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        execute_with_int_dtype!(\n            indices,\n            IntElem,\n            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {\n                execute_with_float_dtype!((tensor, value), |tensor, value| {\n                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)\n                })\n            }\n        )\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {\n        slice!(tensor, slices)\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!((tensor, value), |tensor, value| {\n            NdArrayOps::slice_assign(tensor, slices, value)\n        })\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: NdArrayTensor,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!((tensor, value), |tensor, value| {\n            NdArrayOps::mask_where(tensor, mask.bool(), value)\n        })\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: NdArrayTensor,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::mask_fill(array, mask.bool(), value.elem())\n        })\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::equal_elem(array, rhs.elem())\n        })\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::greater_elem(array, rhs.elem())\n        })\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {\n            NdArrayMathOps::greater_equal(lhs, rhs)\n        })\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::greater_equal_elem(array, rhs.elem())\n        })\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::lower_elem(array, rhs.elem())\n        })\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {\n            NdArrayMathOps::lower_equal(lhs, rhs)\n        })\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> NdArrayTensor {\n        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::lower_equal_elem(array, rhs.elem())\n        })\n    }\n\n    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor\n    }\n\n    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::mean_view(array.view())\n        })\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::sum_view(array.view())\n        })\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::mean_dim(array, dim)\n        })\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::cumsum(array, dim)\n        })\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::cumprod(array, dim)\n        })\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::cummin(array, dim)\n        })\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::cummax(array, dim)\n        })\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::sum_dim(array, dim)\n        })\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::argmax_view::<I>(array.view(), dim)\n        })\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> NdArrayTensor {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::argmin_view::<I>(array.view(), dim)\n        })\n    }\n\n    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()\n        })\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()\n        })\n    }\n\n    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::prod_view(array.view())\n        })\n    }\n\n    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::prod_dim(array, dim)\n        })\n    }\n\n    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::max_view(array.view())\n        })\n    }\n\n    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        // Use view() for zero-copy on borrowed storage\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::min_view(array.view())\n        })\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()\n        })\n    }\n\n    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))\n                .into_shared()\n        })\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()\n        })\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::abs(array)\n        })\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {\n            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))\n        })\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())\n                .into_shared()\n        })\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())\n                .into_shared()\n        })\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array\n                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())\n                .into_shared()\n        })\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        cat_with_dtype!(tensors, dim, [F64, F32])\n    }\n\n    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::clamp_min(array, min.elem())\n        })\n    }\n\n    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::clamp_max(array, max.elem())\n        })\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::clamp(array, min.elem(), max.elem())\n        })\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> NdArrayTensor {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()\n        })\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {\n            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))\n        })\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::permute(array, axes)\n        })\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::flip(array, axes)\n        })\n    }\n\n    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayMathOps::sign_op(array)\n        })\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::expand(array, shape)\n        })\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            cast_to_dtype(array, dtype.into())\n        })\n    }\n\n    fn float_grid_sample_2d(\n        tensor: FloatTensor<Self>,\n        grid: FloatTensor<Self>,\n        options: GridSampleOptions,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(\n            tensor, grid, options\n        ))\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {\n            NdArrayOps::unfold(array, dim, size, step)\n        })\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/ops/transaction.rs",
    "content": "use crate::{\n    FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray,\n    element::{IntNdArrayElement, QuantElement},\n};\nuse burn_backend::ops::TransactionOps;\n\nimpl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> TransactionOps<Self>\n    for NdArray<E, I, Q>\nwhere\n    NdArrayTensor: From<SharedArray<E>>,\n    NdArrayTensor: From<SharedArray<I>>,\n{\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/parallel.rs",
    "content": "/// Macro for running a function in parallel.\n#[cfg(feature = \"multi-threads\")]\n#[macro_export(local_inner_macros)]\nmacro_rules! run_par {\n    (\n        $func:expr\n    ) => {{\n        use rayon::prelude::*;\n\n        #[allow(clippy::redundant_closure_call)]\n        rayon::scope(|_| $func())\n    }};\n}\n\n/// Macro for running a function in parallel.\n#[cfg(not(feature = \"multi-threads\"))]\n#[macro_export(local_inner_macros)]\nmacro_rules! run_par {\n    (\n        $func:expr\n    ) => {{ $func() }};\n}\n\n/// Macro for iterating in parallel.\n#[cfg(not(feature = \"multi-threads\"))]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_par {\n    (\n        $iter:expr\n    ) => {{ $iter }};\n}\n\n/// Macro for iterating in parallel.\n#[cfg(feature = \"multi-threads\")]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_par {\n    (\n        $iter:expr\n    ) => {{ $iter.into_par_iter() }};\n}\n\n/// Macro for iterating in parallel.\n#[cfg(feature = \"multi-threads\")]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_slice_par {\n    (\n        $slice:expr\n    ) => {{ $slice.into_par_iter() }};\n}\n\n/// Macro for iterating in parallel.\n#[cfg(not(feature = \"multi-threads\"))]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_slice_par {\n    (\n        $slice:expr\n    ) => {{ $slice.iter() }};\n}\n\n/// Macro for iterating over a range in parallel.\n#[cfg(feature = \"multi-threads\")]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_range_par {\n    (\n        $start:expr, $end:expr\n    ) => {{ ($start..$end).into_par_iter() }};\n}\n\n/// Macro for iterating over a range in parallel.\n#[cfg(not(feature = \"multi-threads\"))]\n#[macro_export(local_inner_macros)]\nmacro_rules! iter_range_par {\n    (\n        $start:expr, $end:expr\n    ) => {{ ($start..$end) }};\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/rand.rs",
    "content": "//! Random number generation utilities for burn-ndarray\n\n#[cfg(not(feature = \"std\"))]\nuse rand::rngs::SmallRng;\n#[cfg(feature = \"std\")]\nuse rand::rngs::StdRng;\n\n/// Type alias for the RNG used by burn-ndarray\n#[cfg(feature = \"std\")]\npub type NdArrayRng = StdRng;\n#[cfg(not(feature = \"std\"))]\npub type NdArrayRng = SmallRng;\n\n#[cfg(not(feature = \"std\"))]\nuse rand::SeedableRng;\n\n/// Get a seeded random number generator\n///\n/// For std builds, uses OS entropy.\n/// For no_std builds, uses a compile-time random seed.\n#[cfg(feature = \"std\")]\npub fn get_seeded_rng() -> NdArrayRng {\n    // Use the standard implementation from burn-std\n    burn_std::rand::get_seeded_rng()\n}\n\n/// Get a seeded random number generator\n///\n/// For std builds, uses OS entropy.\n/// For no_std builds, uses a compile-time random seed.\n#[cfg(not(feature = \"std\"))]\npub fn get_seeded_rng() -> NdArrayRng {\n    // Use compile-time random seed for no_std\n    const SEED: u64 = const_random::const_random!(u64);\n    SmallRng::seed_from_u64(SEED)\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/sharing.rs",
    "content": "use core::cell::UnsafeCell;\n\n/// Similar to `SyncUnsafeCell` see [Rust issues](https://github.com/rust-lang/rust/issues/95439).\npub(crate) struct UnsafeSharedRef<'a, T> {\n    cell: UnsafeCell<&'a mut T>,\n}\n\nunsafe impl<T> Sync for UnsafeSharedRef<'_, T> {}\n\nimpl<'a, T> UnsafeSharedRef<'a, T> {\n    pub fn new(data: &'a mut T) -> Self {\n        Self {\n            cell: UnsafeCell::new(data),\n        }\n    }\n    pub unsafe fn get(&self) -> &'a mut T {\n        unsafe { core::ptr::read(self.cell.get()) }\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/storage.rs",
    "content": "//! Copy-on-write storage for zero-copy tensor loading.\n//!\n//! This module provides `NdArrayStorage<E>`, which enables true zero-copy loading\n//! from burnpack files. When data is borrowed from external memory (like mmap'd files\n//! or static data), it remains zero-copy until a mutating operation is performed,\n//! at which point it's copied (copy-on-write semantics).\n//!\n//! This integrates with ndarray's existing COW patterns - operations that check\n//! `is_unique()` will see borrowed data as non-unique, triggering the allocation path.\n\nuse burn_backend::Element;\nuse burn_std::{Bytes, Shape};\nuse core::mem;\nuse ndarray::{ArcArray, ArrayView, IxDyn};\n\n/// Storage that supports both owned data and borrowed (zero-copy) data.\n///\n/// # Copy-on-Write Semantics\n///\n/// - **Borrowed**: Data from external source (burnpack, mmap, static).\n///   Reports `is_unique() == false` to trigger copy on mutation.\n/// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount.\n///\n/// # Example\n///\n/// ```ignore\n/// // Zero-copy load\n/// let storage = NdArrayStorage::from_borrowed(bytes, shape);\n/// storage.is_unique();  // false - will copy on mutation\n///\n/// // Read operations use view() - zero-copy\n/// let view = storage.view();\n///\n/// // Mutation converts to owned\n/// let owned = storage.into_owned();  // Copies here\n/// ```\n#[derive(Debug)]\npub enum NdArrayStorage<E: Element> {\n    /// Borrowed from external source (e.g., burnpack zero-copy load).\n    /// Keeps `Bytes` alive to ensure the referenced memory is valid.\n    Borrowed {\n        /// Source bytes - keeps external memory alive via reference counting\n        bytes: Bytes,\n        /// Shape of the tensor\n        shape: Shape,\n    },\n\n    /// Standard owned storage with ArcArray COW semantics.\n    Owned(ArcArray<E, IxDyn>),\n}\n\nimpl<E: Element> Clone for NdArrayStorage<E> {\n    fn clone(&self) -> Self {\n        match self {\n            // For borrowed data, clone the Bytes (cheap Arc clone) and shape\n            Self::Borrowed { bytes, shape } => Self::Borrowed {\n                bytes: bytes.clone(),\n                shape: shape.clone(),\n            },\n            // For owned data, clone the ArcArray (cheap Arc clone)\n            Self::Owned(arr) => Self::Owned(arr.clone()),\n        }\n    }\n}\n\nimpl<E: Element> NdArrayStorage<E> {\n    /// Create borrowed storage from external bytes.\n    ///\n    /// Returns the bytes and shape back on failure (misaligned or too small),\n    /// enabling zero-copy even for native allocations by avoiding defensive cloning.\n    ///\n    /// # Requirements\n    ///\n    /// The caller must ensure that:\n    /// - The `Bytes` contain valid data for the element type `E`\n    /// - The data is contiguous in row-major (C) order matching the provided shape\n    ///\n    /// These requirements are upheld when loading from `TensorData` (burnpack, etc.)\n    /// which always stores data contiguously in row-major order.\n    pub fn from_borrowed(bytes: Bytes, shape: impl Into<Shape>) -> Result<Self, (Bytes, Shape)> {\n        let shape = shape.into();\n        // Validate alignment\n        let ptr = bytes.as_ptr();\n        if !(ptr as usize).is_multiple_of(mem::align_of::<E>()) {\n            return Err((bytes, shape));\n        }\n\n        // Validate size (using checked arithmetic to prevent overflow)\n        let num_elements = match shape\n            .iter()\n            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))\n        {\n            Some(n) => n,\n            None => return Err((bytes, shape)),\n        };\n        let expected_size = match num_elements.checked_mul(mem::size_of::<E>()) {\n            Some(s) => s,\n            None => return Err((bytes, shape)),\n        };\n        if bytes.len() < expected_size {\n            return Err((bytes, shape));\n        }\n\n        Ok(Self::Borrowed { bytes, shape })\n    }\n\n    /// Create owned storage from an ArcArray.\n    #[inline]\n    pub fn from_owned(array: ArcArray<E, IxDyn>) -> Self {\n        Self::Owned(array)\n    }\n\n    /// Returns whether this storage is uniquely owned and can be mutated in-place.\n    ///\n    /// - **Borrowed**: Always returns `false` to trigger copy-on-write.\n    /// - **Owned**: Delegates to `ArcArray::is_unique()`.\n    ///\n    /// This integrates with existing SIMD code patterns like:\n    /// ```ignore\n    /// if tensor.is_unique() {\n    ///     // mutate in place\n    /// } else {\n    ///     // allocate new\n    /// }\n    /// ```\n    #[inline]\n    pub fn is_unique(&self) -> bool {\n        match self {\n            Self::Borrowed { .. } => false, // Force copy path\n            Self::Owned(arr) => arr.is_unique(),\n        }\n    }\n\n    /// Get a read-only view of the data.\n    ///\n    /// This is zero-copy for both borrowed and owned variants.\n    #[inline]\n    pub fn view(&self) -> ArrayView<'_, E, IxDyn> {\n        match self {\n            Self::Borrowed { bytes, shape } => {\n                let ptr = bytes.as_ptr() as *const E;\n                let dim = IxDyn(shape);\n                // SAFETY:\n                // - `bytes` is kept alive for the lifetime of `self`\n                // - Alignment was validated in `from_borrowed`\n                // - Size was validated in `from_borrowed`\n                unsafe { ArrayView::from_shape_ptr(dim, ptr) }\n            }\n            Self::Owned(arr) => arr.view(),\n        }\n    }\n\n    /// Convert to owned ArcArray.\n    ///\n    /// - **Borrowed**: Copies the data into a new ArcArray.\n    /// - **Owned + unique**: Returns the array without copying.\n    /// - **Owned + shared**: Clones the data.\n    pub fn into_owned(self) -> ArcArray<E, IxDyn> {\n        match self {\n            Self::Borrowed { bytes, shape } => {\n                let ptr = bytes.as_ptr() as *const E;\n                let dim = IxDyn(&shape);\n                // SAFETY: Same as view() - bytes is valid for this scope\n                let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };\n                view.to_owned().into_shared()\n            }\n            Self::Owned(arr) => arr,\n        }\n    }\n\n    /// Convert to shared ArcArray, suitable for returning from operations.\n    ///\n    /// This is equivalent to `into_owned()` but named for clarity.\n    #[inline]\n    pub fn into_shared(self) -> ArcArray<E, IxDyn> {\n        self.into_owned()\n    }\n\n    /// Get the shape of the tensor.\n    pub fn shape(&self) -> &[usize] {\n        match self {\n            Self::Borrowed { shape, .. } => shape,\n            Self::Owned(arr) => arr.shape(),\n        }\n    }\n\n    /// Get the number of dimensions.\n    #[inline]\n    pub fn ndim(&self) -> usize {\n        self.shape().len()\n    }\n\n    /// Get the total number of elements.\n    #[inline]\n    pub fn len(&self) -> usize {\n        self.shape().iter().product()\n    }\n\n    /// Check if the tensor is empty.\n    #[inline]\n    pub fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n\n    /// Returns `true` if this is borrowed (zero-copy) storage.\n    #[inline]\n    pub fn is_borrowed(&self) -> bool {\n        matches!(self, Self::Borrowed { .. })\n    }\n\n    /// Returns `true` if this is owned storage.\n    #[inline]\n    pub fn is_owned(&self) -> bool {\n        matches!(self, Self::Owned(_))\n    }\n\n    /// Ensure owned and return mutable reference to the ArcArray.\n    ///\n    /// Converts borrowed to owned if necessary.\n    pub fn ensure_owned(&mut self) -> &mut ArcArray<E, IxDyn> {\n        if let Self::Borrowed { bytes, shape } = self {\n            let ptr = bytes.as_ptr() as *const E;\n            let dim = IxDyn(shape);\n            // SAFETY: Same as view()\n            let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };\n            *self = Self::Owned(view.to_owned().into_shared());\n        }\n        match self {\n            Self::Owned(arr) => arr,\n            Self::Borrowed { .. } => unreachable!(),\n        }\n    }\n}\n\n/// Convert from ArcArray to NdArrayStorage.\nimpl<E: Element> From<ArcArray<E, IxDyn>> for NdArrayStorage<E> {\n    fn from(array: ArcArray<E, IxDyn>) -> Self {\n        Self::Owned(array)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use alloc::{vec, vec::Vec};\n    use burn_std::Bytes;\n\n    #[test]\n    fn test_borrowed_is_not_unique() {\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        assert!(!storage.is_unique());\n        assert!(storage.is_borrowed());\n    }\n\n    #[test]\n    fn test_owned_unique_when_single_ref() {\n        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();\n        let storage = NdArrayStorage::from_owned(array);\n\n        assert!(storage.is_unique());\n        assert!(storage.is_owned());\n    }\n\n    #[test]\n    fn test_owned_not_unique_when_cloned() {\n        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();\n        let storage = NdArrayStorage::from_owned(array);\n        let _clone = storage.clone();\n\n        assert!(!storage.is_unique());\n    }\n\n    #[test]\n    fn test_view_zero_copy() {\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        let view = storage.view();\n        assert_eq!(view[[0, 0]], 1.0);\n        assert_eq!(view[[1, 1]], 4.0);\n    }\n\n    #[test]\n    fn test_into_owned_copies_borrowed() {\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        let owned = storage.into_owned();\n        assert_eq!(owned[[0, 0]], 1.0);\n        assert_eq!(owned[[1, 1]], 4.0);\n    }\n\n    #[test]\n    fn test_from_borrowed_validates_alignment() {\n        use burn_std::AllocationProperty;\n\n        // Test 1: Properly aligned data should succeed\n        let aligned_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let aligned_bytes = Bytes::from_elems(aligned_data);\n\n        // Verify test setup - should be 4-byte aligned for f32\n        assert_eq!(\n            (aligned_bytes.as_ptr() as usize) % core::mem::align_of::<f32>(),\n            0,\n            \"Test setup: f32 data should be properly aligned\"\n        );\n\n        let result = NdArrayStorage::<f32>::from_borrowed(aligned_bytes, [2, 2]);\n        assert!(\n            result.is_ok(),\n            \"from_borrowed should succeed for properly aligned data\"\n        );\n\n        // Test 2: Misaligned data should fail\n        // Create a buffer large enough to find a misaligned offset\n        // (static data placement varies by platform, so we find an offset dynamically)\n        let buffer: &[u8] = &[0u8; 32];\n        let shared = bytes::Bytes::from_static(buffer);\n        let base = shared.as_ptr() as usize;\n        let align = core::mem::align_of::<f32>();\n\n        // Find an offset in 1..align that produces misalignment (at least one must exist)\n        let misalign_offset = (1..align)\n            .find(|&off| !(base + off).is_multiple_of(align))\n            .expect(\"Should find a misaligned offset\");\n\n        let sliced = shared.slice(misalign_offset..(misalign_offset + 16));\n        let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other);\n\n        // Verify test setup - should NOT be 4-byte aligned\n        assert_ne!(\n            (misaligned_bytes.as_ptr() as usize) % align,\n            0,\n            \"Test setup: sliced data should be misaligned for f32\"\n        );\n\n        let result = NdArrayStorage::<f32>::from_borrowed(misaligned_bytes, [4]);\n        assert!(\n            result.is_err(),\n            \"from_borrowed should return Err for misaligned data\"\n        );\n    }\n\n    #[test]\n    fn test_insufficient_size_returns_err() {\n        // Create bytes that are too small for the requested shape\n        let data: Vec<f32> = vec![1.0, 2.0]; // 8 bytes\n        let bytes = Bytes::from_elems(data);\n\n        // Try to create storage for 4 elements (needs 16 bytes)\n        let result = NdArrayStorage::<f32>::from_borrowed(bytes, [4]);\n        assert!(\n            result.is_err(),\n            \"from_borrowed should return Err when bytes are too small\"\n        );\n    }\n\n    // ==========================================================================\n    // Zero-copy hardening tests\n    // These tests verify the zero-copy guarantee is maintained. If any of these\n    // fail, it indicates a regression in zero-copy functionality.\n    // ==========================================================================\n\n    #[test]\n    fn test_zero_copy_native_allocation() {\n        // CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy\n        // on initial load. The view() must return a pointer to the SAME memory.\n        //\n        // Note: Native allocations copy on clone (this is expected), but the initial\n        // load is still zero-copy, avoiding an extra copy in the common case where\n        // the tensor is used without cloning.\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let original_ptr = bytes.as_ptr();\n\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        // Initial load must be zero-copy\n        let view = storage.view();\n        let view_ptr = view.as_ptr() as *const u8;\n\n        assert_eq!(\n            original_ptr, view_ptr,\n            \"ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes\"\n        );\n\n        // Verify data integrity\n        assert_eq!(view[[0, 0]], 1.0);\n        assert_eq!(view[[0, 1]], 2.0);\n        assert_eq!(view[[1, 0]], 3.0);\n        assert_eq!(view[[1, 1]], 4.0);\n    }\n\n    #[test]\n    fn test_zero_copy_shared_bytes_pointer_identity() {\n        // CRITICAL: Test with SharedBytesAllocationController for true zero-copy.\n        // This simulates the actual burnpack/mmap loading path.\n        use burn_std::AllocationProperty;\n\n        // Create static-like data using bytes::Bytes\n        let data: &[u8] = &[\n            0, 0, 128, 63, // 1.0f32 in little-endian\n            0, 0, 0, 64, // 2.0f32\n            0, 0, 64, 64, // 3.0f32\n            0, 0, 128, 64, // 4.0f32\n        ];\n        let shared = bytes::Bytes::from_static(data);\n        let original_ptr = shared.as_ptr();\n\n        // Create Bytes with SharedBytesAllocationController\n        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);\n\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        // Verify pointer identity\n        let view_ptr = storage.view().as_ptr() as *const u8;\n        assert_eq!(\n            original_ptr, view_ptr,\n            \"ZERO-COPY REGRESSION: SharedBytes view must point to original static data\"\n        );\n\n        // Clone should also share the same memory\n        let cloned = storage.clone();\n        let cloned_ptr = cloned.view().as_ptr() as *const u8;\n        assert_eq!(\n            original_ptr, cloned_ptr,\n            \"ZERO-COPY REGRESSION: SharedBytes clone must share memory\"\n        );\n    }\n\n    #[test]\n    fn test_clone_borrowed_stays_borrowed() {\n        // Verify that cloning borrowed storage produces another borrowed storage.\n        // Note: The underlying Bytes may or may not share memory depending on\n        // the allocation controller (native allocations copy, file-backed may share).\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n        let cloned = storage.clone();\n\n        // Both should still be borrowed (the storage type is preserved)\n        assert!(\n            storage.is_borrowed(),\n            \"ZERO-COPY REGRESSION: original should remain borrowed after clone\"\n        );\n        assert!(\n            cloned.is_borrowed(),\n            \"ZERO-COPY REGRESSION: clone should be borrowed type\"\n        );\n\n        // Both should report not unique (important for COW behavior)\n        assert!(\n            !storage.is_unique(),\n            \"ZERO-COPY REGRESSION: original should not be unique after clone\"\n        );\n        assert!(\n            !cloned.is_unique(),\n            \"ZERO-COPY REGRESSION: clone should not be unique\"\n        );\n\n        // Data should be identical\n        assert_eq!(storage.view(), cloned.view(), \"Clone should have same data\");\n    }\n\n    #[test]\n    fn test_zero_copy_triggers_copy_on_mutation() {\n        // Verify that into_owned() on borrowed data creates a NEW allocation\n        // (this is the \"copy\" in copy-on-write)\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let original_ptr = bytes.as_ptr();\n\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        assert!(storage.is_borrowed(), \"should start as borrowed\");\n\n        let owned = storage.into_owned();\n        let owned_ptr = owned.as_ptr() as *const u8;\n\n        assert_ne!(\n            original_ptr, owned_ptr,\n            \"into_owned() on borrowed data MUST allocate new memory (copy-on-write)\"\n        );\n    }\n\n    #[test]\n    fn test_borrowed_reports_not_unique() {\n        // CRITICAL: Borrowed storage must report is_unique() == false\n        // This is what triggers copy-on-write in mutation operations\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect(\"should create\");\n\n        assert!(\n            !storage.is_unique(),\n            \"ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \\\n             to trigger copy-on-write. If this is true, mutations will corrupt shared data!\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-ndarray/src/tensor.rs",
    "content": "use burn_backend::{\n    DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,\n    quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},\n};\nuse burn_std::BoolStore;\n\nuse crate::NdArrayStorage;\nuse crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};\nuse alloc::vec::Vec;\nuse ndarray::{ArcArray, ArrayD, IxDyn};\n\n/// Concrete storage type for ndarray (owned with COW semantics via Arc)\npub type SharedArray<E> = ArcArray<E, IxDyn>;\n\n/// Tensor primitive used by the [ndarray backend](crate::NdArray).\n///\n/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.\n/// When data is borrowed from external sources (like burnpack files),\n/// it remains zero-copy until a mutating operation is performed.\n#[derive(Debug, Clone)]\n#[allow(missing_docs)]\npub enum NdArrayTensor {\n    F64(NdArrayStorage<f64>),\n    F32(NdArrayStorage<f32>),\n    I64(NdArrayStorage<i64>),\n    I32(NdArrayStorage<i32>),\n    I16(NdArrayStorage<i16>),\n    I8(NdArrayStorage<i8>),\n    U64(NdArrayStorage<u64>),\n    U32(NdArrayStorage<u32>),\n    U16(NdArrayStorage<u16>),\n    U8(NdArrayStorage<u8>),\n    Bool(NdArrayStorage<bool>),\n}\n\nimpl NdArrayTensor {\n    /// Extract bool array, converting to owned if necessary.\n    pub(crate) fn bool(self) -> SharedArray<bool> {\n        match self {\n            NdArrayTensor::Bool(storage) => storage.into_shared(),\n            _ => unimplemented!(\"Expected bool tensor, got {:?}\", self.dtype()),\n        }\n    }\n\n    /// Returns true if this tensor uses borrowed (zero-copy) storage.\n    #[inline]\n    pub fn is_borrowed(&self) -> bool {\n        macro_rules! check {\n            ($($variant:ident),*) => {\n                match self {\n                    $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*\n                }\n            };\n        }\n        check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)\n    }\n}\n\npub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor\nwhere\n    NdArrayTensor: From<SharedArray<E1>>,\n{\n    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {\n        array.mapv(|a| a.elem()).into_shared()\n    }\n\n    if E1::dtype() == dtype {\n        return array.into();\n    }\n\n    match dtype {\n        DType::F64 => cast::<E1, f64>(array).into(),\n        DType::F32 => cast::<E1, f32>(array).into(),\n        DType::Flex32 => cast::<E1, f32>(array).into(),\n        DType::I64 => cast::<E1, i64>(array).into(),\n        DType::I32 => cast::<E1, i32>(array).into(),\n        DType::I16 => cast::<E1, i16>(array).into(),\n        DType::I8 => cast::<E1, i8>(array).into(),\n        DType::U64 => cast::<E1, u64>(array).into(),\n        DType::U32 => cast::<E1, u32>(array).into(),\n        DType::U16 => cast::<E1, u16>(array).into(),\n        DType::U8 => cast::<E1, u8>(array).into(),\n        DType::Bool(BoolStore::Native) => cast::<E1, bool>(array).into(),\n        dtype => panic!(\"Unsupported dtype: {dtype:?}\"),\n    }\n}\n\nmacro_rules! impl_from {\n    ($($ty: ty => $dtype: ident),*) => {\n        // From SharedArray (owned) -> NdArrayTensor\n        $(impl From<SharedArray<$ty>> for NdArrayTensor {\n           fn from(value: SharedArray<$ty>) -> NdArrayTensor {\n                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))\n           }\n        })*\n\n        // From NdArrayStorage -> NdArrayTensor\n        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {\n           fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {\n                NdArrayTensor::$dtype(value)\n           }\n        })*\n    };\n}\n\nimpl_from!(\n    f64 => F64, f32 => F32,\n    i64 => I64, i32 => I32, i16 => I16, i8 => I8,\n    u64 => U64, u32 => U32, u16 => U16, u8 => U8,\n    bool => Bool\n);\n\n/// Macro to execute an operation on a given element type.\n///\n/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.\n///\n/// # Panics\n/// Since there is no automatic type cast at this time, binary operations for different\n/// floating point precision data types will panic with a data type mismatch.\n#[macro_export]\nmacro_rules! execute_with_dtype {\n    (($lhs:expr, $rhs:expr),$element:ident,  $op:expr, [$($dtype: ident => $ty: ty),*]) => {{\n        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);\n        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);\n        match ($lhs, $rhs) {\n            $(\n                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {\n                    #[allow(unused)]\n                    type $element = $ty;\n                    // Convert storage to SharedArray for compatibility with existing operations\n                    $op(lhs.into_shared(), rhs.into_shared()).into()\n                }\n            )*\n            _ => panic!(\n                \"Data type mismatch (lhs: {:?}, rhs: {:?})\",\n                lhs_dtype, rhs_dtype\n            ),\n        }\n    }};\n    // Binary op: type automatically inferred by the compiler\n    (($lhs:expr, $rhs:expr), $op:expr) => {{\n        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)\n    }};\n\n    // Binary op: generic type cannot be inferred for an operation\n    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [\n            F64 => f64, F32 => f32,\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8,\n            Bool => bool\n        ])\n    }};\n\n    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{\n        match $tensor {\n            $(\n                $crate::NdArrayTensor::$dtype(storage) => {\n                    #[allow(unused)]\n                    type $element = $ty;\n                    // Convert to SharedArray for compatibility with most operations\n                    $op(storage.into_shared()).into()\n                }\n            )*\n            #[allow(unreachable_patterns)]\n            other => unimplemented!(\"unsupported dtype: {:?}\", other.dtype())\n        }\n    }};\n    // Unary op: type automatically inferred by the compiler\n    ($tensor:expr, $op:expr) => {{\n        $crate::execute_with_dtype!($tensor, E, $op)\n    }};\n\n    // Unary op: generic type cannot be inferred for an operation\n    ($tensor:expr, $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!($tensor, $element, $op, [\n            F64 => f64, F32 => f32,\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8,\n            Bool => bool\n        ])\n    }};\n}\n\n/// Macro to execute an operation a given element type.\n/// Only handles float types.\n///\n/// # Panics\n/// Since there is no automatic type cast at this time, binary operations for different\n/// floating point precision data types will panic with a data type mismatch.\n#[macro_export]\nmacro_rules! execute_with_float_dtype {\n    // Binary op: type automatically inferred by the compiler\n    (($lhs:expr, $rhs:expr), $op:expr) => {{\n        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)\n    }};\n\n    // Binary op: generic type cannot be inferred for an operation\n    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [\n            F64 => f64, F32 => f32\n        ])\n    }};\n\n    // Unary op: type automatically inferred by the compiler\n    ($tensor:expr, $op:expr) => {{\n        $crate::execute_with_float_dtype!($tensor, E, $op)\n    }};\n\n    // Unary op: generic type cannot be inferred for an operation\n    ($tensor:expr, $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!($tensor, $element, $op, [\n            F64 => f64, F32 => f32\n        ])\n    }};\n}\n\n/// Macro to execute an operation a given element type.\n/// Only handles int types.\n///\n/// # Panics\n/// Since there is no automatic type cast at this time, binary operations for different\n/// floating point precision data types will panic with a data type mismatch.\n#[macro_export]\nmacro_rules! execute_with_int_dtype {\n    // Binary op: type automatically inferred by the compiler\n    (($lhs:expr, $rhs:expr), $op:expr) => {{\n        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)\n    }};\n\n    // Binary op: generic type cannot be inferred for an operation\n    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8\n        ])\n    }};\n\n    // Unary op: type automatically inferred by the compiler\n    ($tensor:expr, $op:expr) => {{\n        $crate::execute_with_int_dtype!($tensor, E, $op)\n    }};\n\n    // Unary op: generic type cannot be inferred for an operation\n    ($tensor:expr, $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!($tensor, $element, $op, [\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8\n        ])\n    }};\n}\n\n/// Macro to execute an operation a given element type.\n/// Only handles numeric types\n///\n/// # Panics\n/// Since there is no automatic type cast at this time, binary operations for different\n/// floating point precision data types will panic with a data type mismatch.\n#[macro_export]\nmacro_rules! execute_with_numeric_dtype {\n    // Binary op: type automatically inferred by the compiler\n    (($lhs:expr, $rhs:expr), $op:expr) => {{\n        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)\n    }};\n\n    // Binary op: generic type cannot be inferred for an operation\n    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [\n            F64 => f64, F32 => f32,\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8\n        ])\n    }};\n\n    // Unary op: type automatically inferred by the compiler\n    ($tensor:expr, $op:expr) => {{\n        $crate::execute_with_numeric_dtype!($tensor, E, $op)\n    }};\n\n    // Unary op: generic type cannot be inferred for an operation\n    ($tensor:expr, $element:ident, $op:expr) => {{\n        $crate::execute_with_dtype!($tensor, $element, $op, [\n            F64 => f64, F32 => f32,\n            I64 => i64, I32 => i32, I16 => i16, I8 => i8,\n            U64 => u64, U32 => u32, U16 => u16, U8 => u8\n        ])\n    }};\n}\n\n/// Macro to execute a cat operation on a given set of element types.\n///\n/// Uses zero-copy views from storage for concatenation.\n///\n/// # Panics\n/// Since there is no automatic type cast at this time, binary operations for different\n/// floating point precision data types will panic with a data type mismatch.\n#[macro_export]\nmacro_rules! cat_with_dtype {\n    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {\n        match &$tensors[0] {\n            $(NdArrayTensor::$dtype(_) => {\n                let tensors = $tensors\n                    .iter()\n                    .map(|t| {\n                        if let NdArrayTensor::$dtype(storage) = t {\n                            // Use storage.view() for zero-copy access\n                            storage.view()\n                        } else {\n                            panic!(\"Concatenate data type mismatch (expected {:?}, got {:?})\", $tensors[0].dtype(), t.dtype())\n                        }\n                    })\n                    .collect::<Vec<_>>();\n                NdArrayOps::concatenate(&tensors, $dim).into()\n            })*\n            _ => panic!(\"Unsupported dtype: {:?}\", $tensors[0].dtype())\n        }\n    };\n}\n\nimpl TensorMetadata for NdArrayTensor {\n    fn dtype(&self) -> DType {\n        match self {\n            NdArrayTensor::F64(_) => DType::F64,\n            NdArrayTensor::F32(_) => DType::F32,\n            NdArrayTensor::I64(_) => DType::I64,\n            NdArrayTensor::I32(_) => DType::I32,\n            NdArrayTensor::I16(_) => DType::I16,\n            NdArrayTensor::I8(_) => DType::I8,\n            NdArrayTensor::U64(_) => DType::U64,\n            NdArrayTensor::U32(_) => DType::U32,\n            NdArrayTensor::U16(_) => DType::U16,\n            NdArrayTensor::U8(_) => DType::U8,\n            NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native),\n        }\n    }\n\n    fn shape(&self) -> Shape {\n        // Use storage's shape method (works for both borrowed and owned)\n        macro_rules! get_shape {\n            ($($variant:ident),*) => {\n                match self {\n                    $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*\n                }\n            };\n        }\n        get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)\n    }\n\n    fn rank(&self) -> usize {\n        self.shape().num_dims()\n    }\n}\n\npub(crate) trait ShapeOps {\n    fn num_dims(self) -> usize;\n    fn num_elements(self) -> usize;\n    fn dims<const N: usize>(self) -> [usize; N];\n    fn into_shape(self) -> Shape;\n}\n\nimpl ShapeOps for &[usize] {\n    fn num_dims(self) -> usize {\n        self.len()\n    }\n\n    fn num_elements(self) -> usize {\n        self.iter().product()\n    }\n\n    fn dims<const N: usize>(self) -> [usize; N] {\n        self.try_into().unwrap()\n    }\n\n    fn into_shape(self) -> Shape {\n        Shape::from(self)\n    }\n}\n\nmod utils {\n    use burn_std::tensor::is_contiguous;\n\n    use super::*;\n\n    impl NdArrayTensor {\n        pub(crate) fn into_data(self) -> TensorData {\n            let shape = self.shape();\n            let contiguous = self.is_contiguous();\n\n            fn inner<E: Element>(\n                shape: Shape,\n                is_contiguous: bool,\n                array: ArcArray<E, IxDyn>,\n            ) -> TensorData {\n                let vec = if is_contiguous {\n                    match array.try_into_owned_nocopy() {\n                        Ok(owned) => {\n                            let (mut vec, offset) = owned.into_raw_vec_and_offset();\n                            if let Some(offset) = offset {\n                                vec.drain(..offset);\n                            }\n                            if vec.len() > shape.num_elements() {\n                                vec.drain(shape.num_elements()..vec.len());\n                            }\n                            vec\n                        }\n                        Err(array) => array.into_iter().collect(),\n                    }\n                } else {\n                    array.into_iter().collect()\n                };\n\n                TensorData::new(vec, shape)\n            }\n\n            // Convert storage to owned array before extracting data\n            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))\n        }\n\n        pub(crate) fn is_contiguous(&self) -> bool {\n            // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)\n            // For owned data, we check the strides\n            macro_rules! check_contiguous {\n                ($($variant:ident),*) => {\n                    match self {\n                        $(NdArrayTensor::$variant(storage) => {\n                            match storage {\n                                NdArrayStorage::Borrowed { .. } => {\n                                    // Borrowed storage requires contiguous row-major data\n                                    // (see NdArrayStorage::from_borrowed documentation)\n                                    true\n                                }\n                                NdArrayStorage::Owned(array) => {\n                                    let shape = array.shape();\n                                    let mut strides = Vec::with_capacity(array.strides().len());\n                                    for &stride in array.strides() {\n                                        if stride <= 0 {\n                                            return false;\n                                        }\n                                        strides.push(stride as usize);\n                                    }\n                                    is_contiguous(shape, &strides)\n                                }\n                            }\n                        })*\n                    }\n                };\n            }\n            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)\n        }\n    }\n}\n\n/// Converts a slice of usize to a typed dimension.\n#[macro_export(local_inner_macros)]\nmacro_rules! to_typed_dims {\n    (\n        $n:expr,\n        $dims:expr,\n        justdim\n    ) => {{\n        let mut dims = [0; $n];\n        for i in 0..$n {\n            dims[i] = $dims[i];\n        }\n        let dim: Dim<[usize; $n]> = Dim(dims);\n        dim\n    }};\n}\n\n/// Reshapes an array into a tensor.\n#[macro_export(local_inner_macros)]\nmacro_rules! reshape {\n    (\n        ty $ty:ty,\n        n $n:expr,\n        shape $shape:expr,\n        array $array:expr\n    ) => {{\n        let dim = $crate::to_typed_dims!($n, $shape, justdim);\n        let array = match $array.is_standard_layout() {\n            true => {\n                match $array.to_shape(dim) {\n                    Ok(val) => val.into_shared(),\n                    Err(err) => {\n                        core::panic!(\"Shape should be compatible shape={dim:?}: {err:?}\");\n                    }\n                }\n            },\n            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),\n        };\n        array.into_dyn()\n    }};\n    (\n        ty $ty:ty,\n        shape $shape:expr,\n        array $array:expr,\n        d $D:expr\n    ) => {{\n        match $D {\n            1 => reshape!(ty $ty, n 1, shape $shape, array $array),\n            2 => reshape!(ty $ty, n 2, shape $shape, array $array),\n            3 => reshape!(ty $ty, n 3, shape $shape, array $array),\n            4 => reshape!(ty $ty, n 4, shape $shape, array $array),\n            5 => reshape!(ty $ty, n 5, shape $shape, array $array),\n            6 => reshape!(ty $ty, n 6, shape $shape, array $array),\n            _ => core::panic!(\"NdArray supports arrays up to 6 dimensions, received: {}\", $D),\n        }\n    }};\n}\n\n/// Slice a tensor\n#[macro_export]\nmacro_rules! slice {\n    ($tensor:expr, $slices:expr) => {\n        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)\n    };\n    ($tensor:expr, $slices:expr, $($variant:ident),*) => {\n        match $tensor {\n            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*\n        }\n    };\n}\n\nimpl NdArrayTensor {\n    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).\n    ///\n    /// This method attempts zero-copy loading when possible. If the data has properly\n    /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,\n    /// it falls back to copying the data.\n    ///\n    /// Zero-copy loading works when:\n    /// - The data's bytes are properly aligned for the element type\n    /// - The bytes can be borrowed (e.g., from mmap'd file or static data)\n    pub fn from_data(data: TensorData) -> NdArrayTensor {\n        // Only use Borrowed storage for non-native allocations (e.g., burnpack mmap/file).\n        // For native Rust heap allocations (the common case), go directly to owned storage:\n        // `from_data_owned` reclaims the Vec zero-copy via `into_vec`, while\n        // Borrowed storage would trigger a full memcopy on every single operation\n        // (because `is_unique()` always returns false for Borrowed).\n        use burn_backend::AllocationProperty;\n        if data.bytes.property() != AllocationProperty::Native {\n            match Self::try_from_data_borrowed(data) {\n                Ok(tensor) => return tensor,\n                Err(data) => return Self::from_data_owned(data),\n            }\n        }\n        Self::from_data_owned(data)\n    }\n\n    /// Try to create a tensor with borrowed storage (zero-copy).\n    ///\n    /// Takes ownership of TensorData and returns it back on failure.\n    /// No cloning occurs - bytes are moved into storage or returned on failure.\n    ///\n    /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).\n    fn try_from_data_borrowed(data: TensorData) -> Result<NdArrayTensor, TensorData> {\n        let TensorData {\n            bytes,\n            shape,\n            dtype,\n        } = data;\n\n        macro_rules! try_borrow {\n            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {\n                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {\n                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),\n                    Err((bytes, shape)) => (bytes, shape),\n                }\n            };\n        }\n\n        // Try to create borrowed storage; get bytes back on failure\n        let (bytes, shape) = match dtype {\n            DType::F64 => try_borrow!(f64, F64, bytes, shape),\n            DType::F32 => try_borrow!(f32, F32, bytes, shape),\n            DType::I64 => try_borrow!(i64, I64, bytes, shape),\n            DType::I32 => try_borrow!(i32, I32, bytes, shape),\n            DType::I16 => try_borrow!(i16, I16, bytes, shape),\n            DType::I8 => try_borrow!(i8, I8, bytes, shape),\n            DType::U64 => try_borrow!(u64, U64, bytes, shape),\n            DType::U32 => try_borrow!(u32, U32, bytes, shape),\n            DType::U16 => try_borrow!(u16, U16, bytes, shape),\n            DType::U8 => try_borrow!(u8, U8, bytes, shape),\n            DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape),\n            _ => (bytes, shape), // QFloat not supported for zero-copy\n        };\n\n        Err(TensorData {\n            bytes,\n            shape,\n            dtype,\n        })\n    }\n\n    /// Create a tensor with owned storage.\n    ///\n    /// This may or may not copy data depending on whether the underlying bytes\n    /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,\n    /// no copy occurs; otherwise data is copied to a new allocation.\n    fn from_data_owned(data: TensorData) -> NdArrayTensor {\n        let shape = data.shape.to_vec(); // TODO: into_vec\n\n        macro_rules! execute {\n            ($data: expr, [$($dtype: pat => $ty: ty),*]) => {\n                match $data.dtype {\n                    $( $dtype => {\n                        match data.into_vec::<$ty>() {\n                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),\n                            Err(err) => panic!(\"Data should have the same element type as the tensor {err:?}\"),\n                        }.into()\n                    }, )*\n                    other => unimplemented!(\"Unsupported dtype {other:?}\"),\n                }\n            };\n        }\n\n        execute!(data, [\n            DType::F64 => f64, DType::F32 => f32,\n            DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8,\n            DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8,\n            DType::Bool(BoolStore::Native) => bool\n        ])\n    }\n}\n\n/// A quantized tensor for the ndarray backend.\n#[derive(Clone, Debug)]\npub struct NdArrayQTensor {\n    /// The quantized tensor.\n    pub qtensor: NdArrayTensor,\n    /// The quantization scheme.\n    pub scheme: QuantScheme,\n    /// The quantization parameters.\n    pub qparams: Vec<QParams<f32>>,\n}\n\nimpl NdArrayQTensor {\n    /// Returns the quantization strategy, including quantization parameters, for the given tensor.\n    pub fn strategy(&self) -> QuantizationStrategy {\n        match self.scheme {\n            QuantScheme {\n                level: QuantLevel::Tensor,\n                mode: QuantMode::Symmetric,\n                value:\n                    QuantValue::Q8F\n                    | QuantValue::Q8S\n                    | QuantValue::E4M3\n                    | QuantValue::E5M2\n                    | QuantValue::Q4F\n                    | QuantValue::Q4S\n                    | QuantValue::E2M1\n                    | QuantValue::Q2F\n                    | QuantValue::Q2S,\n                ..\n            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(\n                self.qparams[0].scales,\n                self.scheme.value,\n            )),\n            QuantScheme {\n                level: QuantLevel::Block(block_size),\n                mode: QuantMode::Symmetric,\n                value:\n                    QuantValue::Q8F\n                    | QuantValue::Q8S\n                    | QuantValue::E4M3\n                    | QuantValue::E5M2\n                    | QuantValue::Q4F\n                    | QuantValue::Q4S\n                    | QuantValue::E2M1\n                    | QuantValue::Q2F\n                    | QuantValue::Q2S,\n                ..\n            } => QuantizationStrategy::PerBlockSymmetric(\n                self.qparams\n                    .iter()\n                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))\n                    .collect(),\n                block_size,\n            ),\n        }\n    }\n}\n\nimpl QTensorPrimitive for NdArrayQTensor {\n    fn scheme(&self) -> &QuantScheme {\n        &self.scheme\n    }\n\n    fn default_scheme() -> QuantScheme {\n        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)\n    }\n}\n\nimpl TensorMetadata for NdArrayQTensor {\n    fn dtype(&self) -> DType {\n        DType::QFloat(self.scheme)\n    }\n\n    fn shape(&self) -> Shape {\n        self.qtensor.shape()\n    }\n\n    fn rank(&self) -> usize {\n        self.shape().num_dims()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::NdArray;\n    use alloc::vec;\n\n    use super::*;\n    use burn_backend::{\n        Distribution,\n        ops::{FloatTensorOps, QTensorOps},\n        quantization::{QuantStore, QuantizationParametersPrimitive},\n    };\n    use burn_std::rand::get_seeded_rng;\n\n    #[test]\n    fn should_support_into_and_from_data_1d() {\n        let data_expected = TensorData::random::<f32, _, _>(\n            Shape::new([3]),\n            Distribution::Default,\n            &mut get_seeded_rng(),\n        );\n        let tensor = NdArrayTensor::from_data(data_expected.clone());\n\n        let data_actual = tensor.into_data();\n\n        assert_eq!(data_expected, data_actual);\n    }\n\n    #[test]\n    fn should_support_into_and_from_data_2d() {\n        let data_expected = TensorData::random::<f32, _, _>(\n            Shape::new([2, 3]),\n            Distribution::Default,\n            &mut get_seeded_rng(),\n        );\n        let tensor = NdArrayTensor::from_data(data_expected.clone());\n\n        let data_actual = tensor.into_data();\n\n        assert_eq!(data_expected, data_actual);\n    }\n\n    #[test]\n    fn should_support_into_and_from_data_3d() {\n        let data_expected = TensorData::random::<f32, _, _>(\n            Shape::new([2, 3, 4]),\n            Distribution::Default,\n            &mut get_seeded_rng(),\n        );\n        let tensor = NdArrayTensor::from_data(data_expected.clone());\n\n        let data_actual = tensor.into_data();\n\n        assert_eq!(data_expected, data_actual);\n    }\n\n    #[test]\n    fn should_support_into_and_from_data_4d() {\n        let data_expected = TensorData::random::<f32, _, _>(\n            Shape::new([2, 3, 4, 2]),\n            Distribution::Default,\n            &mut get_seeded_rng(),\n        );\n        let tensor = NdArrayTensor::from_data(data_expected.clone());\n\n        let data_actual = tensor.into_data();\n\n        assert_eq!(data_expected, data_actual);\n    }\n\n    #[test]\n    fn should_support_qtensor_strategy() {\n        type B = NdArray<f32, i64, i8>;\n        let scale: f32 = 0.009_019_608;\n        let device = Default::default();\n\n        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);\n        let scheme = QuantScheme::default()\n            .with_value(QuantValue::Q8S)\n            .with_store(QuantStore::Native);\n        let qparams = QuantizationParametersPrimitive {\n            scales: B::float_from_data(TensorData::from([scale]), &device),\n        };\n        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);\n\n        assert_eq!(qtensor.scheme(), &scheme);\n        assert_eq!(\n            qtensor.strategy(),\n            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(\n                scale,\n                QuantValue::Q8S\n            ))\n        );\n    }\n\n    // ==========================================================================\n    // Zero-copy integration tests\n    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.\n    // ==========================================================================\n\n    #[test]\n    fn zero_copy_creates_borrowed_storage_for_non_native() {\n        // Verify that from_data creates borrowed storage for non-native allocations\n        // (e.g. burnpack mmap/file data tagged with AllocationProperty::Other or File).\n        // Native heap allocations intentionally use Owned storage for performance.\n        use burn_backend::AllocationProperty;\n        use burn_std::Bytes;\n\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        // Tag as Other to simulate burnpack / mmap data (non-native backing storage)\n        let non_native_bytes = Bytes::from_shared(\n            bytes::Bytes::copy_from_slice(&*bytes),\n            AllocationProperty::Other,\n        );\n        let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32);\n\n        let tensor = NdArrayTensor::from_data(tensor_data);\n\n        match &tensor {\n            NdArrayTensor::F32(storage) => {\n                assert!(\n                    storage.is_borrowed(),\n                    \"ZERO-COPY REGRESSION: from_data should create borrowed storage \\\n                     for non-native (e.g. burnpack) TensorData\"\n                );\n                assert!(\n                    !storage.is_unique(),\n                    \"ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false\"\n                );\n            }\n            _ => panic!(\"Expected F32 tensor\"),\n        }\n    }\n\n    #[test]\n    fn native_alloc_creates_owned_storage() {\n        // Native heap allocations must use Owned storage so that is_unique()\n        // returns true and ndarray can perform in-place mutations without copying.\n        use burn_std::Bytes;\n\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data); // AllocationProperty::Native\n        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);\n\n        let tensor = NdArrayTensor::from_data(tensor_data);\n\n        match &tensor {\n            NdArrayTensor::F32(storage) => {\n                assert!(\n                    !storage.is_borrowed(),\n                    \"PERF REGRESSION: from_data must NOT create borrowed storage \\\n                     for native heap allocations (is_unique() would always be false)\"\n                );\n            }\n            _ => panic!(\"Expected F32 tensor\"),\n        }\n    }\n\n    #[test]\n    fn zero_copy_data_integrity() {\n        // Verify data is correctly accessible through borrowed storage\n        use burn_std::Bytes;\n\n        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];\n        let bytes = Bytes::from_elems(data);\n        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);\n\n        let tensor = NdArrayTensor::from_data(tensor_data);\n\n        match &tensor {\n            NdArrayTensor::F32(storage) => {\n                let view = storage.view();\n                assert_eq!(view[[0, 0]], 1.0);\n                assert_eq!(view[[0, 1]], 2.0);\n                assert_eq!(view[[1, 0]], 3.0);\n                assert_eq!(view[[1, 1]], 4.0);\n            }\n            _ => panic!(\"Expected F32 tensor\"),\n        }\n    }\n\n    #[test]\n    fn zero_copy_fallback_when_bytes_owned() {\n        // When TensorData owns bytes exclusively, it may use the copy path\n        // This is expected behavior - verify it still works correctly\n        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);\n        let tensor = NdArrayTensor::from_data(data.clone());\n        let result = tensor.into_data();\n\n        assert_eq!(data, result, \"Data should round-trip correctly\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Neural network building blocks for the Burn deep learning framework\"\ndocumentation = \"https://docs.rs/burn-nn\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-nn\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-nn\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"std\",\n    \"burn-core/default\",\n]\ndoc = [\n    \"std\",\n    # Doc features\n    \"burn-core/doc\",\n    \"pretrained\",\n]\npretrained = [\"std\", \"burn-store/pytorch\", \"burn-std/network\", \"dirs\"]\n# Added for some test cases that should only be run locally \n# (e.g., test cases with pretrained weights for gram matrix loss)\ntest-local = []  \nstd = [\n    \"burn-core/std\",\n    \"num-traits/std\",\n    \"burn-store?/std\",\n    \"burn-std?/std\",\n]\ntracing = [\n    \"burn-core/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-rocm?/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-wgpu?/tracing\",\n    \"burn-fusion?/tracing\",\n]\n\ntest-cuda = [\n    \"burn-cuda/default\",\n] # To use cuda during testing, default uses ndarray.\ntest-rocm = [\n    \"burn-rocm/default\",\n] # To use hip during testing, default uses ndarray.\ntest-tch = [\n    \"burn-tch/default\",\n] # To use tch during testing, default uses ndarray.\ntest-wgpu = [\n    \"burn-wgpu/default\",\n] # To use wgpu during testing, default uses ndarray.\ntest-vulkan = [\n    \"test-wgpu\",\n    \"burn-wgpu/vulkan\",\n] # To use wgpu-spirv during testing, default uses ndarray.\ntest-metal = [\n    \"test-wgpu\",\n    \"burn-wgpu/metal\",\n] # To use wgpu-spirv during testing, default uses ndarray.\n\n# Memory checks are disabled by default\ntest-memory-checks = [\"burn-fusion/memory-checks\"]\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std when std is disabled **\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", default-features = false }\n\nnum-traits = { workspace = true }\n\n# FOR TESTING\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-remote = { path = \"../burn-remote\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\n\n# For loss functions requiring pretrained models (e.g., Gram Matrix Loss)\nburn-store = { path = \"../burn-store\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\ndirs = { workspace = true, optional = true }\n\n[dev-dependencies]\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\" }\nrstest = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-nn/README.md",
    "content": "# Burn Neural Networks\n\nCore building blocks for Burn neural networks."
  },
  {
    "path": "crates/burn-nn/src/activation/activation_wrapper.rs",
    "content": "use burn_core as burn;\n\nuse crate::activation::{\n    Celu, CeluConfig, Elu, EluConfig, Gelu, HardShrink, HardShrinkConfig, HardSigmoid,\n    HardSigmoidConfig, HardSwish, LeakyRelu, LeakyReluConfig, PRelu, PReluConfig, Relu, Selu,\n    Shrink, ShrinkConfig, Sigmoid, SoftShrink, SoftShrinkConfig, Softplus, SoftplusConfig,\n    Softsign, SwiGlu, SwiGluConfig, Tanh, ThresholdedRelu, ThresholdedReluConfig,\n};\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// [`Activation`] Configuration.\n#[derive(Config, Debug)]\n#[non_exhaustive]\npub enum ActivationConfig {\n    /// [`Gelu`] activation layer.\n    Gelu,\n\n    /// [`Gelu`] activation layer with tanh approximation.\n    GeluApproximate,\n\n    /// [`PRelu`] activation layer.\n    PRelu(PReluConfig),\n\n    /// [`Relu`] activation layer.\n    Relu,\n\n    /// [`LeakyRelu`] activation layer.\n    LeakyRelu(LeakyReluConfig),\n\n    /// [`SwiGlu`] activation layer.\n    SwiGlu(SwiGluConfig),\n\n    /// [`Selu`] activation layer.\n    Selu,\n\n    /// [`Sigmoid`] activation layer.\n    Sigmoid,\n\n    /// [`Tanh`] activation layer.\n    Tanh,\n\n    /// [`HardSigmoid`] activation layer.\n    HardSigmoid(HardSigmoidConfig),\n\n    /// [`HardSwish`] activation layer.\n    HardSwish,\n\n    /// [`Softplus`] activation layer.\n    Softplus(SoftplusConfig),\n\n    /// [`Softsign`] activation layer.\n    Softsign,\n\n    /// [`Elu`] activation layer.\n    Elu(EluConfig),\n\n    /// [`Celu`] activation layer.\n    Celu(CeluConfig),\n\n    /// [`ThresholdedRelu`] activation layer.\n    ThresholdedRelu(ThresholdedReluConfig),\n\n    /// [`HardShrink`] activation layer.\n    HardShrink(HardShrinkConfig),\n\n    /// [`SoftShrink`] activation layer.\n    SoftShrink(SoftShrinkConfig),\n\n    /// [`Shrink`] activation layer.\n    Shrink(ShrinkConfig),\n}\n\nimpl From<PReluConfig> for ActivationConfig {\n    fn from(config: PReluConfig) -> Self {\n        Self::PRelu(config)\n    }\n}\n\nimpl From<LeakyReluConfig> for ActivationConfig {\n    fn from(config: LeakyReluConfig) -> Self {\n        Self::LeakyRelu(config)\n    }\n}\n\nimpl From<SwiGluConfig> for ActivationConfig {\n    fn from(config: SwiGluConfig) -> Self {\n        Self::SwiGlu(config)\n    }\n}\n\nimpl From<HardSigmoidConfig> for ActivationConfig {\n    fn from(config: HardSigmoidConfig) -> Self {\n        Self::HardSigmoid(config)\n    }\n}\n\nimpl From<SoftplusConfig> for ActivationConfig {\n    fn from(config: SoftplusConfig) -> Self {\n        Self::Softplus(config)\n    }\n}\n\nimpl From<EluConfig> for ActivationConfig {\n    fn from(config: EluConfig) -> Self {\n        Self::Elu(config)\n    }\n}\n\nimpl From<CeluConfig> for ActivationConfig {\n    fn from(config: CeluConfig) -> Self {\n        Self::Celu(config)\n    }\n}\n\nimpl From<ThresholdedReluConfig> for ActivationConfig {\n    fn from(config: ThresholdedReluConfig) -> Self {\n        Self::ThresholdedRelu(config)\n    }\n}\n\nimpl From<HardShrinkConfig> for ActivationConfig {\n    fn from(config: HardShrinkConfig) -> Self {\n        Self::HardShrink(config)\n    }\n}\n\nimpl From<SoftShrinkConfig> for ActivationConfig {\n    fn from(config: SoftShrinkConfig) -> Self {\n        Self::SoftShrink(config)\n    }\n}\n\nimpl From<ShrinkConfig> for ActivationConfig {\n    fn from(config: ShrinkConfig) -> Self {\n        Self::Shrink(config)\n    }\n}\n\nimpl ActivationConfig {\n    /// Initialize a wrapped activation layer.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Activation<B> {\n        match self {\n            ActivationConfig::Relu => Relu.into(),\n            ActivationConfig::LeakyRelu(conf) => conf.init().into(),\n            ActivationConfig::Gelu => Gelu::new().into(),\n            ActivationConfig::GeluApproximate => Gelu::new_approximate().into(),\n            ActivationConfig::PRelu(conf) => conf.init(device).into(),\n            ActivationConfig::SwiGlu(conf) => conf.init(device).into(),\n            ActivationConfig::HardSigmoid(conf) => conf.init().into(),\n            ActivationConfig::HardSwish => HardSwish.into(),\n            ActivationConfig::Softplus(conf) => conf.init().into(),\n            ActivationConfig::Selu => Selu.into(),\n            ActivationConfig::Sigmoid => Sigmoid.into(),\n            ActivationConfig::Tanh => Tanh.into(),\n            ActivationConfig::Softsign => Softsign.into(),\n            ActivationConfig::Elu(conf) => conf.init().into(),\n            ActivationConfig::Celu(conf) => conf.init().into(),\n            ActivationConfig::HardShrink(conf) => conf.init().into(),\n            ActivationConfig::SoftShrink(conf) => conf.init().into(),\n            ActivationConfig::Shrink(conf) => conf.init().into(),\n            ActivationConfig::ThresholdedRelu(conf) => conf.init().into(),\n        }\n    }\n}\n\n/// Activation Layer Wrapper.\n///\n/// Provides support for many in-built `burn::nn` activations.\n#[derive(Module, Debug)]\n#[non_exhaustive]\n#[allow(clippy::large_enum_variant)]\npub enum Activation<B: Backend> {\n    /// [`Gelu`] activation layer.\n    Gelu(Gelu),\n\n    /// [`PRelu`] activation layer.\n    PRelu(PRelu<B>),\n\n    /// [`Relu`] activation layer.\n    Relu(Relu),\n\n    /// [`LeakyRelu`] activation layer.\n    LeakyRelu(LeakyRelu),\n\n    /// [`SwiGlu`] activation layer.\n    SwiGlu(SwiGlu<B>),\n\n    /// [`Selu`] activation layer.\n    Selu(Selu),\n\n    /// [`Sigmoid`] activation layer.\n    Sigmoid(Sigmoid),\n\n    /// [`Tanh`] activation layer.\n    Tanh(Tanh),\n\n    /// [`HardSigmoid`] activation layer.\n    HardSigmoid(HardSigmoid),\n\n    /// [`HardSwish`] activation layer.\n    HardSwish(HardSwish),\n\n    /// [`Softplus`] activation layer.\n    Softplus(Softplus),\n\n    /// [`Softsign`] activation layer.\n    Softsign(Softsign),\n\n    /// [`Elu`] activation layer.\n    Elu(Elu),\n\n    /// [`Celu`] activation layer.\n    Celu(Celu),\n\n    /// [`ThresholdedRelu`] activation layer.\n    ThresholdedRelu(ThresholdedRelu),\n\n    /// [`HardShrink`] activation layer.\n    HardShrink(HardShrink),\n\n    /// [`SoftShrink`] activation layer.\n    SoftShrink(SoftShrink),\n\n    /// [`Shrink`] activation layer.\n    Shrink(Shrink),\n}\n\nimpl<B: Backend> From<Gelu> for Activation<B> {\n    fn from(layer: Gelu) -> Self {\n        Self::Gelu(layer)\n    }\n}\n\nimpl<B: Backend> From<PRelu<B>> for Activation<B> {\n    fn from(layer: PRelu<B>) -> Self {\n        Self::PRelu(layer)\n    }\n}\n\nimpl<B: Backend> From<Relu> for Activation<B> {\n    fn from(layer: Relu) -> Self {\n        Self::Relu(layer)\n    }\n}\n\nimpl<B: Backend> From<LeakyRelu> for Activation<B> {\n    fn from(layer: LeakyRelu) -> Self {\n        Self::LeakyRelu(layer)\n    }\n}\n\nimpl<B: Backend> From<SwiGlu<B>> for Activation<B> {\n    fn from(layer: SwiGlu<B>) -> Self {\n        Self::SwiGlu(layer)\n    }\n}\n\nimpl<B: Backend> From<Selu> for Activation<B> {\n    fn from(layer: Selu) -> Self {\n        Self::Selu(layer)\n    }\n}\n\nimpl<B: Backend> From<Sigmoid> for Activation<B> {\n    fn from(layer: Sigmoid) -> Self {\n        Self::Sigmoid(layer)\n    }\n}\n\nimpl<B: Backend> From<Tanh> for Activation<B> {\n    fn from(layer: Tanh) -> Self {\n        Self::Tanh(layer)\n    }\n}\n\nimpl<B: Backend> From<HardSigmoid> for Activation<B> {\n    fn from(layer: HardSigmoid) -> Self {\n        Self::HardSigmoid(layer)\n    }\n}\n\nimpl<B: Backend> From<HardSwish> for Activation<B> {\n    fn from(layer: HardSwish) -> Self {\n        Self::HardSwish(layer)\n    }\n}\n\nimpl<B: Backend> From<Softplus> for Activation<B> {\n    fn from(layer: Softplus) -> Self {\n        Self::Softplus(layer)\n    }\n}\n\nimpl<B: Backend> From<Softsign> for Activation<B> {\n    fn from(layer: Softsign) -> Self {\n        Self::Softsign(layer)\n    }\n}\n\nimpl<B: Backend> From<Elu> for Activation<B> {\n    fn from(layer: Elu) -> Self {\n        Self::Elu(layer)\n    }\n}\n\nimpl<B: Backend> From<Celu> for Activation<B> {\n    fn from(layer: Celu) -> Self {\n        Self::Celu(layer)\n    }\n}\n\nimpl<B: Backend> From<ThresholdedRelu> for Activation<B> {\n    fn from(layer: ThresholdedRelu) -> Self {\n        Self::ThresholdedRelu(layer)\n    }\n}\n\nimpl<B: Backend> From<HardShrink> for Activation<B> {\n    fn from(layer: HardShrink) -> Self {\n        Self::HardShrink(layer)\n    }\n}\n\nimpl<B: Backend> From<SoftShrink> for Activation<B> {\n    fn from(layer: SoftShrink) -> Self {\n        Self::SoftShrink(layer)\n    }\n}\n\nimpl<B: Backend> From<Shrink> for Activation<B> {\n    fn from(layer: Shrink) -> Self {\n        Self::Shrink(layer)\n    }\n}\n\nimpl<B: Backend> Activation<B> {\n    /// Forward pass.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        match self {\n            Activation::Relu(layer) => layer.forward(input),\n            Activation::LeakyRelu(layer) => layer.forward(input),\n            Activation::Gelu(layer) => layer.forward(input),\n            Activation::PRelu(layer) => layer.forward(input),\n            Activation::SwiGlu(layer) => layer.forward(input),\n            Activation::HardSigmoid(layer) => layer.forward(input),\n            Activation::HardSwish(layer) => layer.forward(input),\n            Activation::Softplus(layer) => layer.forward(input),\n            Activation::Selu(layer) => layer.forward(input),\n            Activation::Sigmoid(layer) => layer.forward(input),\n            Activation::Tanh(layer) => layer.forward(input),\n            Activation::Softsign(layer) => layer.forward(input),\n            Activation::Elu(layer) => layer.forward(input),\n            Activation::Celu(layer) => layer.forward(input),\n            Activation::ThresholdedRelu(layer) => layer.forward(input),\n            Activation::HardShrink(layer) => layer.forward(input),\n            Activation::SoftShrink(layer) => layer.forward(input),\n            Activation::Shrink(layer) => layer.forward(input),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::module::Module;\n\n    fn make_input<B: Backend>(device: &B::Device) -> Tensor<B, 2> {\n        Tensor::from_data([[-1.0, -0.5, 0.0], [1.0, 0.5, 0.0]], device)\n    }\n\n    fn expect_tensor<B: Backend, const D: usize>(actual: Tensor<B, D>, expected: Tensor<B, D>) {\n        actual.to_data().assert_eq(&expected.to_data(), true);\n    }\n\n    fn check_stateless_config_output<B: Backend, const D: usize>(\n        config: ActivationConfig,\n        input: Tensor<B, D>,\n        expected: Tensor<B, D>,\n        device: &B::Device,\n    ) {\n        let act = config.init(device);\n        let output = act.forward(input);\n        expect_tensor(output, expected);\n    }\n\n    #[test]\n    fn test_gelu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Gelu::new().forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Gelu, input, expected, &device)\n    }\n\n    #[test]\n    fn test_gelu_approximate() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Gelu::new_approximate().forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::GeluApproximate, input, expected, &device)\n    }\n\n    #[test]\n    fn test_prelu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = PReluConfig::new();\n        let expected = inner_config.init(&device).forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_relu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Relu.forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Relu, input, expected, &device)\n    }\n\n    #[test]\n    fn test_leaky_relu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = LeakyReluConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_swi_glu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let d_input = input.shape()[1];\n        let d_output = 2 * d_input;\n\n        let inner_config = SwiGluConfig::new(d_input, d_output);\n        let mut reference: SwiGlu<TestBackend> = inner_config.init(&device);\n\n        let config: ActivationConfig = inner_config.into();\n        let layer = config.init(&device);\n\n        match &layer {\n            Activation::SwiGlu(inner) => {\n                // Clone the initialized weights.\n                let state = inner.clone().into_record();\n                reference = reference.load_record(state);\n            }\n            _ => unreachable!(),\n        };\n\n        expect_tensor(\n            layer.forward(input.clone()),\n            reference.forward(input.clone()),\n        )\n    }\n\n    #[test]\n    fn test_selu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Selu.forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Selu, input, expected, &device)\n    }\n\n    #[test]\n    fn test_sigmoid() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Sigmoid.forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Sigmoid, input, expected, &device)\n    }\n\n    #[test]\n    fn test_tanh() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Tanh.forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Tanh, input, expected, &device)\n    }\n\n    #[test]\n    fn test_hard_sigmoid() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = HardSigmoidConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_softsign() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let expected = Softsign.forward(input.clone());\n\n        check_stateless_config_output(ActivationConfig::Softsign, input, expected, &device)\n    }\n\n    #[test]\n    fn test_elu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = EluConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_softplus() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = SoftplusConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_celu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = CeluConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_thresholded_relu() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = ThresholdedReluConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_hard_shrink() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = HardShrinkConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_soft_shrink() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = SoftShrinkConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n\n    #[test]\n    fn test_shrink() {\n        let device = Default::default();\n        let input = make_input::<TestBackend>(&device);\n\n        let inner_config = ShrinkConfig::new();\n        let expected = inner_config.init().forward(input.clone());\n\n        check_stateless_config_output(inner_config.into(), input, expected, &device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/celu.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::celu;\nuse burn::tensor::backend::Backend;\n\n/// CELU (Continuously Differentiable Exponential Linear Unit) layer.\n///\n/// Applies the CELU function element-wise:\n/// `celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`\n///\n/// Should be created with [CeluConfig](CeluConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Celu {\n    /// The alpha value for the CELU formulation.\n    pub alpha: f64,\n}\n\n/// Configuration to create a [Celu](Celu) layer using the [init function](CeluConfig::init).\n#[derive(Config, Debug)]\npub struct CeluConfig {\n    /// The alpha value for the CELU formulation. Default is 1.0\n    #[config(default = \"1.0\")]\n    pub alpha: f64,\n}\n\nimpl CeluConfig {\n    /// Initialize a new [Celu](Celu) Layer\n    pub fn init(&self) -> Celu {\n        Celu { alpha: self.alpha }\n    }\n}\n\nimpl ModuleDisplay for Celu {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"alpha\", &self.alpha).optional()\n    }\n}\n\nimpl Celu {\n    /// Forward pass for the Celu layer.\n    ///\n    /// See [celu](burn::tensor::activation::celu) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        celu(input, self.alpha)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_celu_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Celu = CeluConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, -0.5, -1.0]]), &device);\n        let out = model.forward(input);\n        // celu(0.5, 1) = 0.5\n        // celu(-0.5, 1) = 1 * (exp(-0.5) - 1) = -0.393469\n        // celu(-1.0, 1) = 1 * (exp(-1) - 1) = -0.632121\n        let expected = TensorData::from([[0.5, -0.393469, -0.632121]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_celu_with_alpha() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Celu = CeluConfig::new().with_alpha(2.0).init();\n        let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, -2.0]]), &device);\n        let out = model.forward(input);\n        // celu(0, 2) = 0\n        // celu(-2, 2) = 2 * (exp(-1) - 1) = -1.264241\n        let expected = TensorData::from([[0.0, -1.264241]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = CeluConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"Celu {alpha: 1}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/elu.rs",
    "content": "use burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_core as burn;\n\nuse burn::tensor::activation::elu;\n\n/// ELU (Exponential Linear Unit) layer.\n///\n/// Should be created with [EluConfig](EluConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Elu {\n    /// The alpha value.\n    pub alpha: f64,\n}\n/// Configuration to create an [Elu](Elu) layer using the [init function](EluConfig::init).\n#[derive(Config, Debug)]\npub struct EluConfig {\n    /// The alpha value. Default is 1.0\n    #[config(default = \"1.0\")]\n    pub alpha: f64,\n}\nimpl EluConfig {\n    /// Initialize a new [Elu](Elu) Layer\n    pub fn init(&self) -> Elu {\n        Elu { alpha: self.alpha }\n    }\n}\n\nimpl ModuleDisplay for Elu {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"alpha\", &self.alpha).optional()\n    }\n}\n\nimpl Elu {\n    /// Forward pass for the ELU layer.\n    ///\n    /// See [elu](burn::tensor::activation::elu) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        elu(input, self.alpha)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_elu_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Elu = EluConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);\n        let out = model.forward(input);\n        // elu(0.4410, 1.0) = 0.4410\n        // elu(-0.2507, 1.0) = 1.0 * (exp(-0.2507) - 1) = -0.22186\n        let expected = TensorData::from([[0.4410, -0.22186]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = EluConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"Elu {alpha: 1}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/gelu.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the Gaussian Error Linear Units function element-wise.\n///\n/// See also [gelu](burn::tensor::activation::gelu)\n///\n/// When `approximate` is true, uses the tanh approximation:\n/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`\n#[derive(Module, Clone, Debug, Default)]\npub struct Gelu {\n    /// Whether to use tanh approximation.\n    pub approximate: bool,\n}\n\nimpl Gelu {\n    /// Create the module with exact GELU.\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Create the module with tanh approximation.\n    pub fn new_approximate() -> Self {\n        Self { approximate: true }\n    }\n\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        if self.approximate {\n            burn::tensor::activation::gelu_approximate(input)\n        } else {\n            burn::tensor::activation::gelu(input)\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::Tolerance;\n    use burn::tensor::ops::FloatElem;\n\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn display() {\n        let layer = Gelu::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Gelu {\\n  approximate: false\\n}\");\n    }\n\n    #[test]\n    fn forward_approximate() {\n        let device = Default::default();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[-1.0, 0.0, 1.0], [0.5, -0.5, 2.0]], &device);\n\n        let output = Gelu::new_approximate().forward(input);\n\n        // PyTorch: torch.nn.functional.gelu(x, approximate=\"tanh\")\n        let expected = Tensor::<TestBackend, 2>::from_data(\n            [\n                [-0.1588079929, 0.0000000000, 0.8411920071],\n                [0.3457140028, -0.1542859972, 1.9545977116],\n            ],\n            &device,\n        );\n\n        output\n            .into_data()\n            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::rel_abs(1e-5, 1e-5));\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/glu.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the gated linear unit function.\n///\n/// See also [glu](burn::tensor::activation::glu)\n#[derive(Module, Clone, Debug, Default)]\npub struct GLU {\n    dim: usize,\n}\n\nimpl GLU {\n    /// Create the module.\n    ///\n    /// # Arguments\n    /// * `dim` - The dimension on which to split the input.\n    pub fn new(dim: usize) -> Self {\n        Self { dim }\n    }\n\n    /// Applies the gated linear unit function.\n    ///\n    /// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.\n    ///\n    /// **Note**:\n    /// * The size of the input tensor along `dim` must be divisible by 2.\n    ///\n    /// ### Arguments\n    /// * `tensor` - The input tensor.\n    ///\n    /// ### Returns\n    /// * A tensor with the same shape as the input, except the size along `dim` is halved.\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::glu(input, self.dim)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = GLU::new(1);\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"GLU {\\n  dim: 1\\n}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/hard_shrink.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::hard_shrink;\nuse burn::tensor::backend::Backend;\n\n/// Hard Shrink layer.\n///\n/// Applies the Hard Shrink function element-wise:\n/// `hard_shrink(x) = x if |x| > lambda else 0`\n///\n/// Should be created with [HardShrinkConfig](HardShrinkConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct HardShrink {\n    /// The lambda value for the Hard Shrink formulation.\n    pub lambda: f64,\n}\n\n/// Configuration to create a [HardShrink](HardShrink) layer using the [init function](HardShrinkConfig::init).\n#[derive(Config, Debug)]\npub struct HardShrinkConfig {\n    /// The lambda value for the Hard Shrink formulation. Default is 0.5\n    #[config(default = \"0.5\")]\n    pub lambda: f64,\n}\n\nimpl HardShrinkConfig {\n    /// Initialize a new [HardShrink](HardShrink) Layer\n    pub fn init(&self) -> HardShrink {\n        HardShrink {\n            lambda: self.lambda,\n        }\n    }\n}\n\nimpl ModuleDisplay for HardShrink {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"lambda\", &self.lambda).optional()\n    }\n}\n\nimpl HardShrink {\n    /// Forward pass for the Hard Shrink layer.\n    ///\n    /// See [hard_shrink](burn::tensor::activation::hard_shrink) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        hard_shrink(input, self.lambda)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_hard_shrink_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: HardShrink = HardShrinkConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -1.0], [8.0, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn test_hard_shrink_with_lambda() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: HardShrink = HardShrinkConfig::new().with_lambda(0.2).init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.1, -0.1, -0.3], [0.5, 0.1, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -0.3], [0.5, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn display() {\n        let config = HardShrinkConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"HardShrink {lambda: 0.5}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/hard_sigmoid.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::hard_sigmoid;\nuse burn::tensor::backend::Backend;\n\n/// Hard Sigmoid layer.\n///\n/// Should be created with [HardSigmoidConfig](HardSigmoidConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct HardSigmoid {\n    /// The alpha value.\n    pub alpha: f64,\n    /// The beta value.\n    pub beta: f64,\n}\n/// Configuration to create a [Hard Sigmoid](HardSigmoid) layer using the [init function](HardSigmoidConfig::init).\n#[derive(Config, Debug)]\npub struct HardSigmoidConfig {\n    /// The alpha value. Default is 0.2\n    #[config(default = \"0.2\")]\n    pub alpha: f64,\n    /// The beta value. Default is 0.5\n    #[config(default = \"0.5\")]\n    pub beta: f64,\n}\nimpl HardSigmoidConfig {\n    /// Initialize a new [Hard Sigmoid](HardSigmoid) Layer\n    pub fn init(&self) -> HardSigmoid {\n        HardSigmoid {\n            alpha: self.alpha,\n            beta: self.beta,\n        }\n    }\n}\n\nimpl ModuleDisplay for HardSigmoid {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"alpha\", &self.alpha)\n            .add(\"beta\", &self.beta)\n            .optional()\n    }\n}\n\nimpl HardSigmoid {\n    /// Forward pass for the Hard Sigmoid layer.\n    ///\n    /// See [hard_sigmoid](burn::tensor::activation::hard_sigmoid) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        hard_sigmoid(input, self.alpha, self.beta)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_hard_sigmoid_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: HardSigmoid = HardSigmoidConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.5882, 0.44986]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = HardSigmoidConfig::new().init();\n        assert_eq!(\n            alloc::format!(\"{config}\"),\n            \"HardSigmoid {alpha: 0.2, beta: 0.5}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/hard_swish.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::hard_swish;\nuse burn::tensor::backend::Backend;\n\n/// Hard Swish layer.\n#[derive(Module, Clone, Debug, Default)]\npub struct HardSwish;\n\nimpl HardSwish {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self\n    }\n\n    /// Forward pass for the Hard Swish layer.\n    ///\n    /// See [hard_swish](burn::tensor::activation::hard_swish) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        hard_swish(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_hard_swish_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model = HardSwish::new();\n\n        let input = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[3.0f32, -3.0], [0.0, 1.0]]),\n            &device,\n        );\n        let out = model.forward(input);\n        let expected = TensorData::from([[3.0f32, 0.0], [0.0, 0.6666667]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let layer = HardSwish::new();\n        assert_eq!(alloc::format!(\"{layer}\"), \"HardSwish\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/leaky_relu.rs",
    "content": "use burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_core as burn;\n\nuse burn::tensor::activation::leaky_relu;\n\n/// Leaky ReLu layer.\n///\n/// Should be created with [LeakyReluConfig](LeakyReluConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct LeakyRelu {\n    /// The negative slope.\n    pub negative_slope: f64,\n}\n/// Configuration to create a [Leaky Relu](LeakyRelu) layer using the [init function](LeakyReluConfig::init).\n#[derive(Config, Debug)]\npub struct LeakyReluConfig {\n    /// The negative slope. Default is 0.01\n    #[config(default = \"0.01\")]\n    pub negative_slope: f64,\n}\nimpl LeakyReluConfig {\n    /// Initialize a new [Leaky Relu](LeakyRelu) Layer\n    pub fn init(&self) -> LeakyRelu {\n        LeakyRelu {\n            negative_slope: self.negative_slope,\n        }\n    }\n}\n\nimpl ModuleDisplay for LeakyRelu {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"negative_slope\", &self.negative_slope)\n            .optional()\n    }\n}\n\nimpl LeakyRelu {\n    /// Forward pass for the Leaky ReLu layer.\n    ///\n    /// See [leaky_relu](burn::tensor::activation::leaky_relu) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        leaky_relu(input, self.negative_slope)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_leaky_relu_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: LeakyRelu = LeakyReluConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.4410, -0.2507]]), &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.4410, -0.002507]]);\n        out.to_data().assert_eq(&expected, false);\n    }\n    #[test]\n    fn test_leaky_relu_forward_multi_dim() {\n        let input = [\n            [\n                [-1.0222, 1.5810, 0.3457, -1.3530],\n                [0.0231, 0.8681, 0.2473, -0.0377],\n                [0.3520, -1.1199, 1.2219, 0.2804],\n            ],\n            [\n                [1.0002, 0.7259, 0.8779, 0.2084],\n                [1.5615, -0.1057, -0.4886, -1.5184],\n                [-0.5523, -0.2741, -0.0210, -1.1352],\n            ],\n        ];\n        let expected = TensorData::from([\n            [\n                [-1.0222e-02, 1.5810e+00, 3.457e-01, -1.3530e-02],\n                [2.31e-02, 8.681e-01, 2.473e-01, -3.77e-04],\n                [3.52e-01, -1.1199e-02, 1.2219e+00, 2.804e-01],\n            ],\n            [\n                [1.0002e+00, 7.259e-01, 8.779e-01, 2.084e-01],\n                [1.5615e+00, -1.057e-03, -4.886e-03, -1.5184e-02],\n                [-5.523e-03, -2.741e-03, -2.1e-04, -1.1352e-02],\n            ],\n        ]);\n\n        let device = <TestBackend as Backend>::Device::default();\n        let model: LeakyRelu = LeakyReluConfig::new().init();\n        let input_data = Tensor::<TestBackend, 3>::from_data(TensorData::from(input), &device);\n        let actual_output = model.forward(input_data);\n        actual_output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default())\n    }\n\n    #[test]\n    fn display() {\n        let config = LeakyReluConfig::new().init();\n        assert_eq!(\n            alloc::format!(\"{config}\"),\n            \"LeakyRelu {negative_slope: 0.01}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/mod.rs",
    "content": "//! # Activation Layers\n//!\n//! Users who desire a selectable activation function should\n//! consider [`Activation`], which provides an abstraction over:\n//! * [`Relu`] - the default,\n//! * ['PRelu']\n//! * [`Gelu`]\n//! * [`LeakyRelu`]\n//! * [`SwiGlu`]\n//! * [`Selu`]\n//! * [`Sigmoid`]\n//! * [`HardSigmoid`]\n//! * [`HardSwish`]\n//! * [`Softplus`]\n//! * [`Softsign`]\n//! * [`Tanh`]\n//! * [`Elu`]\n//! * [`Celu`]\n//! * [`ThresholdedRelu`]\n//!\n//! The activation layer [`GLU`] has shape-changing behaviors\n//! not compatible with the common API, and is not included\n//! in the abstraction wrappers.\n\nmod activation_wrapper;\n\n// These are pub(crate) for dual-export in `nn` without re-exporting\n// all of `nn.activation`, or manually listing each symbol.\npub(crate) mod celu;\npub(crate) mod elu;\npub(crate) mod gelu;\npub(crate) mod glu;\npub(crate) mod hard_shrink;\npub(crate) mod hard_sigmoid;\npub(crate) mod hard_swish;\npub(crate) mod leaky_relu;\npub(crate) mod prelu;\npub(crate) mod relu;\npub(crate) mod selu;\npub(crate) mod shrink;\npub(crate) mod sigmoid;\npub(crate) mod soft_shrink;\npub(crate) mod softplus;\npub(crate) mod softsign;\npub(crate) mod swiglu;\npub(crate) mod tanh;\npub(crate) mod thresholded_relu;\n\npub use activation_wrapper::*;\npub use celu::*;\npub use elu::*;\npub use gelu::*;\npub use glu::*;\npub use hard_shrink::*;\npub use hard_sigmoid::*;\npub use hard_swish::*;\npub use leaky_relu::*;\npub use prelu::*;\npub use relu::*;\npub use selu::*;\npub use shrink::*;\npub use sigmoid::*;\npub use soft_shrink::*;\npub use softplus::*;\npub use softsign::*;\npub use swiglu::*;\npub use tanh::*;\npub use thresholded_relu::*;\n"
  },
  {
    "path": "crates/burn-nn/src/activation/prelu.rs",
    "content": "use burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_core as burn;\n/// Parametric Relu layer.\n///\n/// Should be created using [PReluConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct PRelu<B: Backend> {\n    /// the weights learnt for PReLu. can be of shape \\[1\\] or \\[num_parameters\\] in which case it must\n    /// be the same as number of channels in the input tensor\n    pub alpha: Param<Tensor<B, 1>>,\n\n    /// Alpha value for the PRelu layer\n    pub alpha_value: f64,\n}\n\nimpl<B: Backend> ModuleDisplay for PRelu<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [num_parameters] = self.alpha.shape().dims();\n\n        content\n            .add(\"num_parameters\", &num_parameters)\n            .add(\"alpha_value\", &self.alpha_value)\n            .optional()\n    }\n}\n\n/// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init).\n#[derive(Config, Debug)]\npub struct PReluConfig {\n    /// The number of parameters.\n    #[config(default = \"1\")]\n    pub num_parameters: usize,\n    /// The learnable weight alpha. Default is 0.25\n    #[config(default = \"0.25\")]\n    pub alpha: f64,\n}\n\nimpl PReluConfig {\n    /// Initialize a new [Parametric Relu](PRelu) Layer\n    pub fn init<B: Backend>(&self, device: &B::Device) -> PRelu<B> {\n        PRelu {\n            // alpha is a tensor of length num_parameters\n            alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device),\n            alpha_value: self.alpha,\n        }\n    }\n}\n\nimpl<B: Backend> PRelu<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    ///\n    /// See also [prelu](burn::tensor::activation::prelu) for more information.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::prelu(input, self.alpha.val())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn display() {\n        let layer = PReluConfig::new().init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/relu.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the rectified linear unit function element-wise\n/// See also [relu](burn::tensor::activation::relu)\n///\n#[derive(Module, Clone, Debug, Default)]\npub struct Relu;\n\nimpl Relu {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self {}\n    }\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::relu(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = Relu::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Relu\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/selu.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the Scaled Exponential Linear Unit function element-wise.\n/// See also [selu](burn::tensor::activation::selu)\n#[derive(Module, Clone, Debug, Default)]\npub struct Selu;\n\nimpl Selu {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self {}\n    }\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::selu(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = Selu::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Selu\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/shrink.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::shrink;\nuse burn::tensor::backend::Backend;\n\n/// Shrink layer.\n///\n/// Applies the Shrink function element-wise:\n/// `shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`\n///\n/// Should be created with [ShrinkConfig](ShrinkConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Shrink {\n    /// The lambda value for the Shrink formulation.\n    pub lambda: f64,\n    /// The bias value for the Shrink formulation.\n    // Usually bias = lambda, but need this to handle onnx spec https://onnx.ai/onnx/operators/onnx__Shrink.html\n    pub bias: f64,\n}\n\n/// Configuration to create a [Shrink](Shrink) layer using the [init function](ShrinkConfig::init).\n#[derive(Config, Debug)]\npub struct ShrinkConfig {\n    /// The lambda value for the Shrink formulation. Default is 0.5\n    #[config(default = \"0.5\")]\n    pub lambda: f64,\n    /// The bias value for the Shrink formulation. Default is 0.5.\n    #[config(default = \"0.5\")]\n    pub bias: f64,\n}\n\nimpl ShrinkConfig {\n    /// Initialize a new [Shrink](Shrink) Layer\n    pub fn init(&self) -> Shrink {\n        Shrink {\n            lambda: self.lambda,\n            bias: self.bias,\n        }\n    }\n}\n\nimpl ModuleDisplay for Shrink {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"lambda\", &self.lambda)\n            .add(\"bias\", &self.bias)\n            .optional()\n    }\n}\n\nimpl Shrink {\n    /// Forward pass for the Shrink layer.\n    ///\n    /// See [shrink](burn::tensor::activation::shrink) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        shrink(input, self.lambda, self.bias)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_shrink_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Shrink = ShrinkConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn test_shrink_with_lambda_and_bias() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Shrink = ShrinkConfig::new()\n            .with_lambda(0.25)\n            .with_bias(0.125)\n            .init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -0.375], [0.625, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn display() {\n        let config = ShrinkConfig::new().init();\n        assert_eq!(\n            alloc::format!(\"{config}\"),\n            \"Shrink {lambda: 0.5, bias: 0.5}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/sigmoid.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the sigmoid function element-wise\n/// See also [sigmoid](burn::tensor::activation::sigmoid)\n#[derive(Module, Clone, Debug, Default)]\npub struct Sigmoid;\n\nimpl Sigmoid {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self {}\n    }\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::sigmoid(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = Sigmoid::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Sigmoid\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/soft_shrink.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::soft_shrink;\nuse burn::tensor::backend::Backend;\n\n/// Soft Shrink layer.\n///\n/// Applies the Soft Shrink function element-wise:\n/// `soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`\n///\n/// Should be created with [SoftShrinkConfig](SoftShrinkConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct SoftShrink {\n    /// The lambda value for the Soft Shrink formulation.\n    pub lambda: f64,\n}\n\n/// Configuration to create a [SoftShrink](SoftShrink) layer using the [init function](SoftShrinkConfig::init).\n#[derive(Config, Debug)]\npub struct SoftShrinkConfig {\n    /// The lambda value for the Soft Shrink formulation. Default is 0.5\n    #[config(default = \"0.5\")]\n    pub lambda: f64,\n}\n\nimpl SoftShrinkConfig {\n    /// Initialize a new [SoftShrink](SoftShrink) Layer\n    pub fn init(&self) -> SoftShrink {\n        SoftShrink {\n            lambda: self.lambda,\n        }\n    }\n}\n\nimpl ModuleDisplay for SoftShrink {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"lambda\", &self.lambda).optional()\n    }\n}\n\nimpl SoftShrink {\n    /// Forward pass for the Soft Shrink layer.\n    ///\n    /// See [soft_shrink](burn::tensor::activation::soft_shrink) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        soft_shrink(input, self.lambda)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_soft_shrink_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: SoftShrink = SoftShrinkConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.5, -0.5, -1.0], [8.0, 0.3, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -0.5], [7.5, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn test_soft_shrink_with_lambda() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: SoftShrink = SoftShrinkConfig::new().with_lambda(0.25).init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[0.125, -0.125, -0.5], [0.75, 0.1, 0.0]], &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0_f32, 0.0, -0.25], [0.5, 0.0, 0.0]]);\n        assert_eq!(out.into_data(), expected);\n    }\n\n    #[test]\n    fn display() {\n        let config = SoftShrinkConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"SoftShrink {lambda: 0.5}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/softplus.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::softplus;\nuse burn::tensor::backend::Backend;\n\n/// Softplus layer.\n///\n/// Applies the softplus function element-wise:\n/// `softplus(x) = (1/beta) * log(1 + exp(beta * x))`\n///\n/// Should be created with [SoftplusConfig](SoftplusConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Softplus {\n    /// The beta value.\n    pub beta: f64,\n}\n\n/// Configuration to create a [Softplus](Softplus) layer using the [init function](SoftplusConfig::init).\n#[derive(Config, Debug)]\npub struct SoftplusConfig {\n    /// The beta value. Default is 1.0\n    #[config(default = \"1.0\")]\n    pub beta: f64,\n}\n\nimpl SoftplusConfig {\n    /// Initialize a new [Softplus](Softplus) Layer\n    pub fn init(&self) -> Softplus {\n        Softplus { beta: self.beta }\n    }\n}\n\nimpl ModuleDisplay for Softplus {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"beta\", &self.beta).optional()\n    }\n}\n\nimpl Softplus {\n    /// Forward pass for the Softplus layer.\n    ///\n    /// See [softplus](burn::tensor::activation::softplus) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        softplus(input, self.beta)\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::approx_constant)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_softplus_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Softplus = SoftplusConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0, -1.0]]), &device);\n        let out = model.forward(input);\n        // softplus(0) = log(2) ≈ 0.6931\n        // softplus(1) = log(1 + e) ≈ 1.3133\n        // softplus(-1) = log(1 + e^-1) ≈ 0.3133\n        let expected = TensorData::from([[0.6931, 1.3133, 0.3133]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_softplus_with_beta() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: Softplus = SoftplusConfig::new().with_beta(2.0).init();\n        let input = Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0, 1.0]]), &device);\n        let out = model.forward(input);\n        // softplus(0, beta=2) = (1/2) * log(1 + exp(0)) = 0.5 * log(2) ≈ 0.3466\n        // softplus(1, beta=2) = (1/2) * log(1 + exp(2)) = 0.5 * log(8.389) ≈ 1.0635\n        let expected = TensorData::from([[0.3466, 1.0635]]);\n        out.to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = SoftplusConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"Softplus {beta: 1}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/softsign.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the softsign function element-wise\n/// See also [softsign](burn::tensor::activation::softsign)\n#[derive(Module, Clone, Debug, Default)]\npub struct Softsign;\n\nimpl Softsign {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self {}\n    }\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::softsign(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = Softsign::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Softsign\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/swiglu.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::activation::silu;\nuse burn::tensor::{Tensor, backend::Backend};\n\nuse crate::{Linear, LinearConfig, LinearLayout};\n\n/// Configuration to create a [SwiGlu](SwiGlu) activation layer using the [init function](SwiGluConfig::init).\n#[derive(Config, Debug)]\npub struct SwiGluConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the output features.\n    pub d_output: usize,\n    /// If a bias should be applied during the linear transformation. Default behaviour is False\n    /// for SwiGLU activation implementations.\n    #[config(default = false)]\n    pub bias: bool,\n    /// The type of function used to initialize the linear layer parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n    /// The layout in which the linear parameters are stored.\n    #[config(default = \"LinearLayout::Row\")]\n    pub layout: LinearLayout,\n}\n\n/// Applies the SwiGLU or Swish Gated Linear Unit to the input tensor.\n/// The SwiGLU activation function is defined as:\n/// `SwiGLU(x) = Swish(W_inner * x + b_inner) * (W_outer * x + b_outer)`\n///\n/// Should be created with [SwiGluConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct SwiGlu<B: Backend> {\n    /// The inner linear layer for Swish activation function\n    /// with `d_input` input features and `d_output` output features.\n    pub linear_inner: Linear<B>,\n    /// The outer linear layer for element wise multiplication\n    /// with `d_input` input features and `d_output` output features.\n    pub linear_outer: Linear<B>,\n}\n\nimpl<B: Backend> ModuleDisplay for SwiGlu<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, d_output] = self.linear_inner.weight.shape().dims();\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_output\", &d_output)\n            .add(\"bias\", &self.linear_inner.bias.is_some())\n            .optional()\n    }\n}\n\nimpl SwiGluConfig {\n    /// Initialize a new [SwiGLU](SwiGlu) activation layer.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> SwiGlu<B> {\n        SwiGlu {\n            linear_inner: LinearConfig::new(self.d_input, self.d_output)\n                .with_bias(self.bias)\n                .with_initializer(self.initializer.clone())\n                .with_layout(self.layout)\n                .init(device),\n            linear_outer: LinearConfig::new(self.d_input, self.d_output)\n                .with_bias(self.bias)\n                .with_initializer(self.initializer.clone())\n                .with_layout(self.layout)\n                .init(device),\n        }\n    }\n}\n\nimpl<B: Backend> SwiGlu<B> {\n    /// Applies the Swish Gated Linear Unit to the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, seq_length, d_input]`\n    /// - output: `[batch_size, seq_length, d_output]`\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let x = self.linear_inner.forward(input.clone());\n        let x = silu(x);\n        x.mul(self.linear_outer.forward(input))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_swiglu_forward_no_bias() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = SwiGluConfig::new(3, 3).with_initializer(Initializer::Constant { value: 0.5 });\n        let swiglu = config.init(&device);\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n        let output = swiglu.forward(input);\n        let expected_output = Tensor::<TestBackend, 2>::from_data(\n            [[8.5732, 8.5732, 8.5732], [56.2189, 56.2189, 56.2189]],\n            &device,\n        );\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_swiglu_forward_with_bias() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = SwiGluConfig::new(3, 3)\n            .with_bias(true)\n            .with_initializer(Initializer::Constant { value: 0.5 });\n        let swiglu = config.init(&device);\n        let input =\n            Tensor::<TestBackend, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n        let output = swiglu.forward(input);\n        let expected_output = Tensor::<TestBackend, 2>::from_data(\n            [[11.8909, 11.8909, 11.8909], [63.9785, 63.9785, 63.9785]],\n            &device,\n        );\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = SwiGluConfig::new(3, 5);\n        let swiglu = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{swiglu}\"),\n            \"SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/tanh.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Applies the tanh activation function element-wise\n/// See also [tanh](burn::tensor::activation::tanh)\n#[derive(Module, Clone, Debug, Default)]\npub struct Tanh;\n\nimpl Tanh {\n    /// Create the module.\n    pub fn new() -> Self {\n        Self {}\n    }\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        burn::tensor::activation::tanh(input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let layer = Tanh::new();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Tanh\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/activation/thresholded_relu.rs",
    "content": "use burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_core as burn;\n\nuse burn::tensor::activation::thresholded_relu;\n\n/// Thresholded ReLU layer.\n///\n/// Should be created with [ThresholdedReluConfig](ThresholdedReluConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct ThresholdedRelu {\n    /// The alpha threshold.\n    pub alpha: f64,\n}\n\n/// Configuration to create a [ThresholdedRelu](ThresholdedRelu) layer using the [init function](ThresholdedReluConfig::init).\n#[derive(Config, Debug)]\npub struct ThresholdedReluConfig {\n    /// The alpha threshold. Default is 1.0\n    #[config(default = \"1.0\")]\n    pub alpha: f64,\n}\n\nimpl ThresholdedReluConfig {\n    /// Initialize a new [ThresholdedRelu](ThresholdedRelu) layer.\n    pub fn init(&self) -> ThresholdedRelu {\n        ThresholdedRelu { alpha: self.alpha }\n    }\n}\n\nimpl ModuleDisplay for ThresholdedRelu {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"alpha\", &self.alpha).optional()\n    }\n}\n\nimpl ThresholdedRelu {\n    /// Forward pass for the Thresholded ReLU layer.\n    ///\n    /// See [thresholded_relu](burn::tensor::activation::thresholded_relu) for more information.\n    ///\n    /// # Shapes\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        thresholded_relu(input, self.alpha)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_thresholded_relu_forward() {\n        let device = <TestBackend as Backend>::Device::default();\n        let model: ThresholdedRelu = ThresholdedReluConfig::new().init();\n        let input =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5, 1.5, -0.2]]), &device);\n        let out = model.forward(input);\n        let expected = TensorData::from([[0.0, 1.5, 0.0]]);\n        out.to_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn display() {\n        let config = ThresholdedReluConfig::new().init();\n        assert_eq!(alloc::format!(\"{config}\"), \"ThresholdedRelu {alpha: 1}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![recursion_limit = \"256\"]\n\n//! Burn neural network module.\n\n/// Loss module\npub mod loss;\n\n/// Neural network modules implementations.\npub mod modules;\npub use modules::*;\n\npub mod activation;\npub use activation::{\n    celu::*, elu::*, gelu::*, glu::*, hard_shrink::*, hard_sigmoid::*, leaky_relu::*, prelu::*,\n    relu::*, selu::*, shrink::*, sigmoid::*, soft_shrink::*, softplus::*, softsign::*, swiglu::*,\n    tanh::*, thresholded_relu::*,\n};\n\nmod padding;\npub use padding::*;\n\n// For backward compat, `burn::nn::Initializer`\npub use burn_core::module::Initializer;\n\nextern crate alloc;\n\n/// Backend for test cases\n#[cfg(all(\n    test,\n    not(feature = \"test-tch\"),\n    not(feature = \"test-wgpu\"),\n    not(feature = \"test-cuda\"),\n    not(feature = \"test-rocm\")\n))]\npub type TestBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(all(test, feature = \"test-tch\"))]\n/// Backend for test cases\npub type TestBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(all(test, feature = \"test-wgpu\"))]\n/// Backend for test cases\npub type TestBackend = burn_wgpu::Wgpu;\n\n#[cfg(all(test, feature = \"test-cuda\"))]\n/// Backend for test cases\npub type TestBackend = burn_cuda::Cuda;\n\n#[cfg(all(test, feature = \"test-rocm\"))]\n/// Backend for test cases\npub type TestBackend = burn_rocm::Rocm;\n\n/// Backend for autodiff test cases\n#[cfg(test)]\npub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;\n\n#[cfg(all(test, feature = \"test-memory-checks\"))]\nmod tests {\n    burn_fusion::memory_checks!();\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/binary_cross_entropy.rs",
    "content": "use burn_core as burn;\n\nuse alloc::vec::Vec;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::activation::log_sigmoid;\nuse burn::tensor::{Int, Tensor, backend::Backend};\nuse burn::{config::Config, module::Module};\n\n/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).\n#[derive(Config, Debug)]\npub struct BinaryCrossEntropyLossConfig {\n    /// Create weighted binary cross-entropy with a weight for each class.\n    ///\n    /// The loss of a specific sample will simply be multiplied by its label weight.\n    pub weights: Option<Vec<f32>>,\n\n    /// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).\n    ///\n    /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.\n    /// Alpha = 0 would be the same as default.\n    pub smoothing: Option<f32>,\n\n    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.\n    #[config(default = false)]\n    pub logits: bool,\n}\n\nimpl BinaryCrossEntropyLossConfig {\n    /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).\n    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {\n        self.assertions();\n        BinaryCrossEntropyLoss {\n            weights: self\n                .weights\n                .as_ref()\n                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),\n            smoothing: self.smoothing,\n            logits: self.logits,\n        }\n    }\n\n    fn assertions(&self) {\n        if let Some(alpha) = self.smoothing {\n            assert!(\n                (0.0..=1.).contains(&alpha),\n                \"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}\"\n            );\n        };\n        if let Some(weights) = self.weights.as_ref() {\n            assert!(\n                weights.iter().all(|e| e > &0.),\n                \"Weights of cross-entropy have to be positive.\"\n            );\n        }\n    }\n}\n\n/// Calculate the binary cross entropy loss from the input logits and the targets.\n///\n/// Should be created using [BinaryCrossEntropyLossConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct BinaryCrossEntropyLoss<B: Backend> {\n    /// Weights for cross-entropy.\n    pub weights: Option<Tensor<B, 1>>,\n    /// Label smoothing alpha.\n    pub smoothing: Option<f32>,\n    /// Treat the inputs as logits\n    pub logits: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"weights\", &self.weights)\n            .add(\"smoothing\", &self.smoothing)\n            .add(\"logits\", &self.logits)\n            .optional()\n    }\n}\n\nimpl<B: Backend> BinaryCrossEntropyLoss<B> {\n    /// Compute the criterion on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// Binary:\n    /// - logits: `[batch_size]`\n    /// - targets: `[batch_size]`\n    ///\n    /// Multi-label:\n    /// - logits: `[batch_size, num_classes]`\n    /// - targets: `[batch_size, num_classes]`\n    pub fn forward<const D: usize>(\n        &self,\n        logits: Tensor<B, D>,\n        targets: Tensor<B, D, Int>,\n    ) -> Tensor<B, 1> {\n        self.assertions(&logits, &targets);\n\n        let mut targets_float = targets.clone().float();\n        let shape = targets.dims();\n\n        if let Some(alpha) = self.smoothing {\n            let num_classes = if D > 1 { shape[D - 1] } else { 2 };\n            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;\n        }\n\n        let mut loss = if self.logits {\n            // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`\n            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)\n        } else {\n            // - (target * log(input) + (1 - target) * log(1 - input))\n            // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values\n            (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)\n                - targets_float * logits.log().clamp_min(-100.0)\n        };\n\n        if let Some(weights) = &self.weights {\n            let weights = if D > 1 {\n                weights.clone().expand(shape)\n            } else {\n                // Flatten targets and expand resulting weights to make it compatible with\n                // Tensor<B, D> for binary 1-D case\n                weights\n                    .clone()\n                    .gather(0, targets.flatten(0, 0))\n                    .expand(shape)\n            };\n            loss = loss * weights;\n        }\n\n        loss.mean()\n    }\n\n    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {\n        let logits_dims = logits.dims();\n        let targets_dims = targets.dims();\n        assert!(\n            logits_dims == targets_dims,\n            \"Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?}).\"\n        );\n\n        if let Some(weights) = &self.weights\n            && D > 1\n        {\n            let targets_classes = targets_dims[D - 1];\n            let weights_classes = weights.dims()[0];\n            assert!(\n                weights_classes == targets_classes,\n                \"The number of classes ({weights_classes}) does not match the weights provided ({targets_classes}).\"\n            );\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{TensorData, activation::sigmoid};\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_binary_cross_entropy_preds_all_correct() {\n        let device = Default::default();\n        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .init(&device)\n            .forward(preds, targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.000]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_preds_all_incorrect() {\n        let device = Default::default();\n        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .init(&device)\n            .forward(preds, targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([100.000]); // clamped value\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_binary_cross_entropy() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])\n        // target = torch.tensor([0., 1., 0., 1.])\n        // loss = nn.BCELoss()\n        // sigmoid = nn.Sigmoid()\n        // out = loss(sigmoid(input), target) # tensor(0.7491)\n\n        let device = Default::default();\n        let logits =\n            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .init(&device)\n            .forward(sigmoid(logits), targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.7491]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_with_logits() {\n        let device = Default::default();\n        let logits =\n            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_logits(true)\n            .init(&device)\n            .forward(logits, targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.7491]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_with_weights() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])\n        // target = torch.tensor([0, 1, 0, 1])\n        // weights = torch.tensor([3., 7.]).gather(0, target)\n        // loss = nn.BCELoss(weights)\n        // sigmoid = nn.Sigmoid()\n        // out = loss(sigmoid(input), target.float()) # tensor(3.1531)\n\n        let device = Default::default();\n        let logits =\n            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);\n        let weights = [3., 7.];\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_weights(Some(weights.to_vec()))\n            .init(&device)\n            .forward(sigmoid(logits), targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([3.1531]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_with_smoothing() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])\n        // target = torch.tensor([0., 1., 0., 1.])\n        // target_smooth = target * (1 - 0.1) + (0.1 / 2)\n        // loss = nn.BCELoss()\n        // sigmoid = nn.Sigmoid()\n        // out = loss(sigmoid(input), target_smooth) # tensor(0.7490)\n\n        let device = Default::default();\n        let logits =\n            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);\n        let targets =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_smoothing(Some(0.1))\n            .init(&device)\n            .forward(sigmoid(logits), targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.7490]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_multilabel() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])\n        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])\n        // weights = torch.tensor([3., 7., 0.9])\n        // loss = nn.BCEWithLogitsLoss()\n        // out = loss(input, target) # tensor(0.7112)\n\n        let device = Default::default();\n        let logits = Tensor::<TestBackend, 2>::from_floats(\n            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::from([[1, 0, 1], [1, 0, 0]]),\n            &device,\n        );\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_logits(true)\n            .init(&device)\n            .forward(logits, targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.7112]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_multilabel_with_weights() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])\n        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])\n        // loss = nn.BCEWithLogitsLoss()\n        // out = loss(input, target) # tensor(3.1708)\n\n        let device = Default::default();\n        let logits = Tensor::<TestBackend, 2>::from_floats(\n            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::from([[1, 0, 1], [1, 0, 0]]),\n            &device,\n        );\n        let weights = [3., 7., 0.9];\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_logits(true)\n            .with_weights(Some(weights.to_vec()))\n            .init(&device)\n            .forward(logits, targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([3.1708]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_binary_cross_entropy_multilabel_with_smoothing() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])\n        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])\n        // target_smooth = target * (1 - 0.1) + (0.1 / 3)\n        // loss = nn.BCELoss()\n        // sigmoid = nn.Sigmoid()\n        // out = loss(sigmoid(input), target_smooth) # tensor(0.7228)\n\n        let device = Default::default();\n        let logits = Tensor::<TestBackend, 2>::from_floats(\n            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::from([[1, 0, 1], [1, 0, 0]]),\n            &device,\n        );\n\n        let loss_actual = BinaryCrossEntropyLossConfig::new()\n            .with_smoothing(Some(0.1))\n            .init(&device)\n            .forward(sigmoid(logits), targets)\n            .into_data();\n\n        let loss_expected = TensorData::from([0.7228]);\n        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());\n    }\n\n    #[test]\n    #[should_panic = \"The number of classes\"]\n    fn multilabel_weights_should_match_target() {\n        // import torch\n        // from torch import nn\n        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])\n        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])\n        // loss = nn.BCEWithLogitsLoss()\n        // out = loss(input, target) # tensor(3.1708)\n\n        let device = Default::default();\n        let logits = Tensor::<TestBackend, 2>::from_floats(\n            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::from([[1, 0, 1], [1, 0, 0]]),\n            &device,\n        );\n        let weights = [3., 7.];\n\n        let _loss = BinaryCrossEntropyLossConfig::new()\n            .with_logits(true)\n            .with_weights(Some(weights.to_vec()))\n            .init(&device)\n            .forward(logits, targets);\n    }\n\n    #[test]\n    fn display() {\n        let config =\n            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));\n        let loss = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{loss}\"),\n            \"BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/cosine_embedding.rs",
    "content": "use alloc::format;\n\nuse burn::tensor::linalg::cosine_similarity;\n\nuse burn_core as burn;\n\nuse crate::loss::reduction::Reduction;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::{Int, Tensor, activation::relu, backend::Backend};\n\n/// Configuration for CosineEmbeddingLoss.\n#[derive(Config, Debug)]\npub struct CosineEmbeddingLossConfig {\n    /// Margin for negative samples.\n    #[config(default = 0.0)]\n    pub margin: f32,\n\n    /// Specifies the reduction to apply to the output.\n    #[config(default = \"Reduction::Mean\")]\n    pub reduction: Reduction,\n}\n\nimpl CosineEmbeddingLossConfig {\n    /// Initialize CosineEmbeddingLoss.\n    pub fn init(&self) -> CosineEmbeddingLoss {\n        CosineEmbeddingLoss {\n            margin: self.margin,\n            reduction: self.reduction.clone(),\n        }\n    }\n}\n\n/// Cosine embedding loss between two tensors.\n///\n/// Measures cosine distance between tensors.\n/// Used for learning embeddings or similarity.\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct CosineEmbeddingLoss {\n    /// Margin value. Default: 0.0\n    pub margin: f32,\n\n    /// Reduction method\n    pub reduction: Reduction,\n}\n\nimpl Default for CosineEmbeddingLoss {\n    fn default() -> Self {\n        CosineEmbeddingLossConfig::new().init()\n    }\n}\n\nimpl ModuleDisplay for CosineEmbeddingLoss {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"margin\", &self.margin)\n            .add(\"reduction\", format!(\"{:?}\", &self.reduction).as_str())\n            .optional()\n    }\n}\n\nimpl CosineEmbeddingLoss {\n    /// Creates a new instance\n    pub fn new() -> Self {\n        CosineEmbeddingLossConfig::new().init()\n    }\n\n    /// Compute loss with reduction.\n    ///\n    /// # Shapes\n    ///\n    /// - input1: ``[batch_size, embedding_dim]``\n    /// - input2: ``[batch_size, embedding_dim]``\n    /// - target: ``[batch_size]`` with values 1 or -1\n    ///\n    /// # Returns\n    ///\n    /// Loss tensor of shape ``[1]``\n    pub fn forward<B: Backend>(\n        &self,\n        input1: Tensor<B, 2>,\n        input2: Tensor<B, 2>,\n        target: Tensor<B, 1, Int>,\n    ) -> Tensor<B, 1> {\n        let tensor = self.forward_no_reduction(input1, input2, target);\n        match &self.reduction {\n            Reduction::Mean | Reduction::Auto => tensor.mean(),\n            Reduction::Sum => tensor.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Compute loss without applying reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `input1` - First input tensor of shape ``[batch_size, embedding_dim]``\n    /// * `input2` - Second input tensor of shape ``[batch_size, embedding_dim]``\n    /// * `target` - Target tensor of shape ``[batch_size]`` with values 1 or -1\n    ///\n    /// # Returns\n    ///\n    /// Tensor of per-element losses with shape ``[batch_size]``\n    pub fn forward_no_reduction<B: Backend>(\n        &self,\n        input1: Tensor<B, 2>,\n        input2: Tensor<B, 2>,\n        target: Tensor<B, 1, Int>,\n    ) -> Tensor<B, 1> {\n        self.assertions(&input1, &input2, &target);\n\n        // cos_sim shape: [batch_size, 1]\n        let cos_sim = cosine_similarity(input1, input2, 1, None);\n        // cos_sim shape: [batch_size]\n        let cos_sim: Tensor<B, 1> = cos_sim.squeeze_dim(1);\n\n        let mut loss = cos_sim.zeros_like();\n\n        // Similar pairs (target == 1) - Formula: L = 1 - cos_sim\n        let similar_mask = target.clone().equal_elem(1);\n        let similar_loss = cos_sim.clone().neg().add_scalar(1);\n        loss = loss.mask_where(similar_mask, similar_loss);\n\n        // Dissimilar pairs (target == -1) - Formula: L = max(0, cos_sim - margin)\n        let dissimilar_mask = target.equal_elem(-1);\n        let dissimilar_loss = relu(cos_sim.clone().sub_scalar(self.margin));\n        loss = loss.mask_where(dissimilar_mask, dissimilar_loss);\n\n        // return loss shape: [batch_size]\n        loss\n    }\n\n    fn assertions<B: Backend>(\n        &self,\n        input1: &Tensor<B, 2>,\n        input2: &Tensor<B, 2>,\n        target: &Tensor<B, 1, Int>,\n    ) {\n        let [batch_size1, dim1] = input1.dims();\n        let [batch_size2, dim2] = input2.dims();\n        let [batch_size_target] = target.dims();\n\n        assert_eq!(\n            batch_size1, batch_size2,\n            \"Batch size of input1 ({batch_size1}) must match batch size of input2 ({batch_size2})\"\n        );\n\n        assert_eq!(\n            dim1, dim2,\n            \"Embedding dimension of input1 ({dim1}) must match embedding dimension of input2 ({dim2})\"\n        );\n\n        assert_eq!(\n            batch_size1, batch_size_target,\n            \"Batch size of inputs ({batch_size1}) must match batch size of target ({batch_size_target})\"\n        );\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn cosine_embedding_loss_positive_target() {\n        let device = Default::default();\n\n        // Two identical vectors should have cosine similarity of 1\n        let input1 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        let input2 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        // Target 1 means that inputs should be similar\n        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 1]), &device);\n\n        let loss = CosineEmbeddingLossConfig::new().init();\n        let loss_no_reduction =\n            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());\n        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());\n\n        let loss_sum = loss.forward(input1, input2, target);\n\n        // For identical vectors, 1 - cos_sim = 1 - 1 = 0\n        let expected_no_reduction = TensorData::from([0.0, 0.0]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());\n\n        let expected_mean = TensorData::from([0.0]);\n        loss_mean\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());\n\n        let expected_sum = TensorData::from([0.0]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());\n    }\n\n    #[test]\n    fn cosine_embedding_loss_negative_target() {\n        let device = Default::default();\n\n        // Two identical vectors should have cosine similarity of 1\n        let input1 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        let input2 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        // Target -1 means that inputs should be dissimilar\n        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([-1, -1]), &device);\n\n        // With margin 0.0, max(0, cos_sim - margin) = max(0, 1 - 0) = 1\n        let loss = CosineEmbeddingLossConfig::new().init();\n        let loss_no_reduction =\n            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());\n        let loss_mean = loss.forward(input1.clone(), input2.clone(), target.clone());\n\n        // Create a loss with Sum reduction for testing\n        let loss_sum_config = CosineEmbeddingLossConfig::new().with_reduction(Reduction::Sum);\n        let loss_sum =\n            loss_sum_config\n                .init()\n                .forward(input1.clone(), input2.clone(), target.clone());\n\n        let expected_no_reduction = TensorData::from([1.0, 1.0]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());\n\n        let expected_mean = TensorData::from([1.0]);\n        loss_mean\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());\n\n        let expected_sum = TensorData::from([2.0]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_sum, Tolerance::default());\n\n        // With margin 0.5, max(0, cos_sim - margin) = max(0, 1 - 0.5) = 0.5\n        let loss_with_margin = CosineEmbeddingLossConfig::new().with_margin(0.5).init();\n        let loss_with_margin = loss_with_margin.forward(input1, input2, target);\n\n        let expected = TensorData::from([0.5]);\n        loss_with_margin\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn cosine_embedding_loss_mixed_targets() {\n        let device = Default::default();\n\n        let input1 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        let input2 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 0.0], [0.0, 1.0]]),\n            &device,\n        );\n\n        // Mixed targets\n        let target = Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, -1]), &device);\n\n        let loss = CosineEmbeddingLossConfig::new().init();\n        let loss_no_reduction =\n            loss.forward_no_reduction(input1.clone(), input2.clone(), target.clone());\n        let loss_mean = loss.forward(input1, input2, target);\n\n        let expected_no_reduction = TensorData::from([0.0, 1.0]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::default());\n\n        let expected_mean = TensorData::from([0.5]);\n        loss_mean\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_mean, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = CosineEmbeddingLossConfig::new().with_margin(0.5);\n        let loss = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{loss}\"),\n            \"CosineEmbeddingLoss {margin: 0.5, reduction: Mean}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/cross_entropy.rs",
    "content": "use burn_core as burn;\nuse burn_core::tensor::IndexingUpdateOp;\n\nuse alloc::string::ToString;\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::activation::log_softmax;\nuse burn::tensor::{Bool, Int, Tensor, backend::Backend};\nuse burn::{config::Config, module::Module};\n\n/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).\n#[derive(Config, Debug)]\npub struct CrossEntropyLossConfig {\n    /// Create padded cross entropy.\n    ///\n    /// Prevents pad tokens from impacting loss calculation.\n    pub pad_tokens: Option<Vec<usize>>,\n\n    /// Create weighted cross-entropy.\n    ///\n    /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,\n    ///\n    /// # Pre-conditions\n    ///   - The order of the weight vector should correspond to the label integer assignment.\n    ///   - Targets assigned negative Int's will not be allowed.\n    pub weights: Option<Vec<f32>>,\n\n    /// Create cross-entropy with label smoothing.\n    ///\n    /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.\n    /// Alpha = 0 would be the same as default.\n    pub smoothing: Option<f32>,\n\n    /// Create cross-entropy with probabilities as input instead of logits.\n    ///\n    #[config(default = true)]\n    pub logits: bool,\n}\n\nimpl CrossEntropyLossConfig {\n    /// Initialize [Cross-entropy loss](CrossEntropyLoss).\n    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {\n        self.assertions();\n        CrossEntropyLoss {\n            pad_tokens: self.pad_tokens.clone(),\n            weights: self\n                .weights\n                .as_ref()\n                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),\n            smoothing: self.smoothing,\n            logits: self.logits,\n        }\n    }\n\n    fn assertions(&self) {\n        if let Some(alpha) = self.smoothing {\n            assert!(\n                (0.0..=1.).contains(&alpha),\n                \"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}\"\n            );\n        };\n        if let Some(weights) = self.weights.as_ref() {\n            assert!(\n                weights.iter().all(|e| e > &0.),\n                \"Weights of cross-entropy have to be positive.\"\n            );\n        }\n    }\n}\n\n/// Calculate the cross entropy loss from the input logits and the targets.\n///\n/// Should be created using [CrossEntropyLossConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct CrossEntropyLoss<B: Backend> {\n    /// Pad tokens to ignore in the loss calculation.\n    pub pad_tokens: Option<Vec<usize>>,\n    /// Weights for cross-entropy.\n    pub weights: Option<Tensor<B, 1>>,\n    /// Label smoothing factor.\n    pub smoothing: Option<f32>,\n    /// Use logits as input.\n    pub logits: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {\n            alloc::format!(\"Vec<0..{}>\", pad_tokens.len())\n        } else {\n            \"None\".to_string()\n        };\n\n        content\n            .add(\"pad_tokens\", &pad_tokens)\n            .add(\"weights\", &self.weights)\n            .add(\"smoothing\", &self.smoothing)\n            .add(\"logits\", &self.logits)\n            .optional()\n    }\n}\n\nimpl<B: Backend> CrossEntropyLoss<B> {\n    /// For backward compatibility.\n    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {\n        CrossEntropyLossConfig::new()\n            .with_pad_tokens(pad_index.map(|e| vec![e]))\n            .init(device)\n    }\n\n    /// Compute the criterion on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - logits: `[batch_size, num_targets]`\n    /// - targets: `[batch_size]`\n    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {\n        Self::assertions(logits.clone(), targets.clone());\n        match self.smoothing {\n            Some(alpha) => self.forward_smoothed(logits, targets, alpha),\n            _ => self.forward_default(logits, targets),\n        }\n    }\n\n    fn forward_smoothed(\n        &self,\n        logits: Tensor<B, 2>,\n        targets: Tensor<B, 1, Int>,\n        alpha: f32,\n    ) -> Tensor<B, 1> {\n        let mask = self.padding_mask(&targets);\n        let tensor = if self.logits {\n            log_softmax(logits, 1)\n        } else {\n            logits.log()\n        };\n        let [batch_size, nr_classes] = tensor.dims();\n        let tensor = tensor\n            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);\n\n        match &self.weights {\n            Some(weights) => {\n                let tensor = tensor\n                    * weights\n                        .clone()\n                        .reshape([1, nr_classes])\n                        .repeat_dim(0, batch_size);\n                let weights = weights.clone().gather(0, targets);\n                let tensor = Self::apply_mask_2d(tensor, mask);\n                tensor.sum().neg() / weights.sum()\n            }\n            None => {\n                let tensor = Self::apply_mask_2d(tensor, mask);\n                tensor.sum_dim(1).mean().neg()\n            }\n        }\n    }\n\n    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {\n        let [batch_size] = targets.dims();\n\n        let mask = self.padding_mask(&targets);\n        let tensor = log_softmax(logits, 1);\n        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));\n\n        match &self.weights {\n            Some(weights) => {\n                let weights = weights.clone().gather(0, targets);\n                let tensor = tensor.reshape([batch_size]) * weights.clone();\n                let tensor = Self::apply_mask_1d(tensor, mask);\n                tensor.sum().neg() / weights.sum()\n            }\n            None => {\n                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);\n                tensor.mean().neg()\n            }\n        }\n    }\n\n    fn compute_smoothed_targets(\n        shape: [usize; 2],\n        targets: Tensor<B, 1, Int>,\n        alpha: f32,\n    ) -> Tensor<B, 2> {\n        let [batch_size, nr_classes] = shape;\n        let device = &targets.device();\n        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(\n            1,\n            targets.reshape([batch_size, 1]),\n            Tensor::ones([batch_size, 1], device),\n            IndexingUpdateOp::Add,\n        );\n        targets_matrix * (1. - alpha) + alpha / nr_classes as f32\n    }\n\n    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {\n        let mut mask = None;\n        if let Some(pad_tokens) = &self.pad_tokens {\n            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();\n            for x in pad_tokens {\n                res = res + targets.clone().equal_elem(*x as i64).int();\n            }\n            mask = Some(res.greater_elem(0));\n        }\n\n        mask\n    }\n\n    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {\n        if let Some(mask) = mask {\n            tensor = tensor.mask_fill(mask, 0);\n        }\n\n        tensor\n    }\n\n    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {\n        if let Some(mask) = mask {\n            let [batch_size, nr_classes] = tensor.dims();\n            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);\n        }\n\n        tensor\n    }\n\n    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {\n        let [logits_height, _] = logits.dims();\n        let [targets_height] = targets.dims();\n        assert!(\n            logits_height == targets_height,\n            \"Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height}).\"\n        );\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    macro_rules! setup {\n        () => {{\n            let [batch_size, num_targets] = [4, 5];\n            let device = Default::default();\n            let logits = Tensor::<TestBackend, 2>::random(\n                [batch_size, num_targets],\n                Distribution::Normal(0., 1.0),\n                &device,\n            );\n            let targets =\n                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);\n            let targets_logits = Tensor::<TestBackend, 2>::from_data(\n                TensorData::from([\n                    [0.0, 0.0, 1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 1.0],\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                ]),\n                &device,\n            );\n            (logits, targets, targets_logits)\n        }};\n    }\n\n    macro_rules! setup_padded {\n        () => {{\n            let [batch_size, num_targets, pad_index] = [4, 5, 1];\n            let device = Default::default();\n            let logits = Tensor::<TestBackend, 2>::random(\n                [batch_size, num_targets],\n                Distribution::Normal(0., 1.0),\n                &device,\n            );\n            let targets = Tensor::<TestBackend, 1, Int>::from_data(\n                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),\n                &device,\n            );\n            let targets_logits = Tensor::<TestBackend, 2>::from_data(\n                TensorData::from([\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 1.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ]),\n                &device,\n            );\n            (logits, targets, targets_logits)\n        }};\n    }\n\n    #[test]\n    fn test_cross_entropy_loss_with_weights() {\n        let (logits, targets, targets_logits) = setup!();\n        let weights = vec![1.0, 2., 3., 4., 5.];\n        let device = Default::default();\n        let loss_1 = CrossEntropyLossConfig::new()\n            .with_weights(Some(weights.clone()))\n            .init(&device)\n            .forward(logits.clone(), targets);\n        let tensor = log_softmax(logits, 1);\n        let loss_2 = tensor\n            * targets_logits\n            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)\n                .unsqueeze()\n                .repeat_dim(0, 4);\n        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_label_smoothing_with_weights_and_alpha_zero() {\n        let (logits, targets, _) = setup!();\n        let device = Default::default();\n        let weights = vec![1.0, 2., 3., 4., 5.];\n        let loss_1 = CrossEntropyLossConfig::new()\n            .with_weights(Some(weights.clone()))\n            .init(&device)\n            .forward(logits.clone(), targets.clone());\n        let loss_2 = CrossEntropyLossConfig::new()\n            .with_weights(Some(weights.clone()))\n            .with_smoothing(Some(0.))\n            .init(&device)\n            .forward(logits.clone(), targets);\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_cross_entropy_loss() {\n        let (logits, targets, targets_logits) = setup!();\n        let device = Default::default();\n        let loss_1 = CrossEntropyLossConfig::new()\n            .init(&device)\n            .forward(logits.clone(), targets);\n        let loss_2 = cross_entropy_with_logits(logits, targets_logits);\n\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_label_smoothing_alpha_equal_zero() {\n        let (logits, targets, _) = setup!();\n        let device = Default::default();\n        let loss_1 = CrossEntropyLossConfig::new()\n            .init(&device)\n            .forward(logits.clone(), targets.clone());\n        let loss_2 = CrossEntropyLossConfig::new()\n            .with_smoothing(Some(0.))\n            .init(&device)\n            .forward(logits, targets);\n\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_cross_entropy_loss_with_pad_token() {\n        let (logits, targets, targets_logits) = setup_padded!();\n        let pad_index = 1;\n\n        let loss_1 = CrossEntropyLossConfig::new()\n            .with_pad_tokens(Some(vec![pad_index, 2]))\n            .init(&logits.device())\n            .forward(logits.clone(), targets);\n        let loss_2 = cross_entropy_with_logits(logits, targets_logits);\n\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_label_smoothing_with_zero_alpha_and_pad_token() {\n        let (logits, targets, _) = setup_padded!();\n        let pad_index = 1;\n\n        let loss_1 = CrossEntropyLossConfig::new()\n            .with_pad_tokens(Some(vec![pad_index, 2]))\n            .init(&logits.device())\n            .forward(logits.clone(), targets.clone());\n        let loss_2 = CrossEntropyLossConfig::new()\n            .with_pad_tokens(Some(vec![pad_index, 2]))\n            .with_smoothing(Some(0.))\n            .init(&logits.device())\n            .forward(logits.clone(), targets);\n\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_label_smoothing_target_conversion() {\n        let (logits, targets, _) = setup!();\n        let smoothed_targets =\n            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);\n        let targets_logits = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([\n                [0.01, 0.01, 0.96, 0.01, 0.01],\n                [0.96, 0.01, 0.01, 0.01, 0.01],\n                [0.01, 0.01, 0.01, 0.01, 0.96],\n                [0.01, 0.96, 0.01, 0.01, 0.01],\n            ]),\n            &Default::default(),\n        );\n        smoothed_targets\n            .into_data()\n            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_label_smoothing() {\n        let (logits, targets, _) = setup!();\n        let device = Default::default();\n        let loss_1 = CrossEntropyLossConfig::new()\n            .with_smoothing(Some(0.05))\n            .init(&device)\n            .forward(logits.clone(), targets);\n        let targets_logits = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([\n                [0.01, 0.01, 0.96, 0.01, 0.01],\n                [0.96, 0.01, 0.01, 0.01, 0.01],\n                [0.01, 0.01, 0.01, 0.01, 0.96],\n                [0.01, 0.96, 0.01, 0.01, 0.01],\n            ]),\n            &device,\n        );\n\n        let x = log_softmax(logits, 1);\n        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();\n\n        loss_1\n            .into_data()\n            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = CrossEntropyLossConfig::new()\n            .with_weights(Some(alloc::vec![3., 7., 0.9]))\n            .with_smoothing(Some(0.5));\n        let loss = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{loss}\"),\n            \"CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/ctc.rs",
    "content": "#![allow(clippy::excessive_precision)]\n\nuse super::Reduction;\nuse alloc::vec;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::{Bool, Element, Int, Tensor, backend::Backend, s};\nuse burn_core as burn;\nuse burn_core::tensor::Numeric;\nuse core::f32;\n\n/// Configuration for the [CTC Loss](CTCLoss) module.\n#[derive(Config, Debug)]\npub struct CTCLossConfig {\n    /// The index number used to represent the blank label. Default value is `0`.\n    #[config(default = 0)]\n    pub blank: usize,\n    /// Whether to zero infinite losses and the associated gradients. Default value is `false`.\n    #[config(default = false)]\n    pub zero_infinity: bool,\n}\n\nimpl CTCLossConfig {\n    /// Initialize a new [CTC Loss](CTCLoss) module\n    pub fn init(&self) -> CTCLoss {\n        CTCLoss {\n            blank: self.blank,\n            zero_infinity: self.zero_infinity,\n        }\n    }\n}\n\n/// Computes the Connectionist Temporal Classification (CTC) loss.\n///\n/// Calculates the loss between a continuous (unsegmented) time series and a target sequence.\n/// CTC sums over the probability of all possible alignments of the input to the target,\n/// producing a loss value that is differentiable with respect to each input node.\n///\n/// The input to this loss is expected to be **log-probabilities** (e.g,, via `log_softmax`),\n/// not raw logits.\n///\n/// # References\n///\n/// - [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://www.cs.toronto.edu/~graves/icml_2006.pdf)\n///\n/// # Example\n///\n/// ```rust,ignore\n/// use burn::tensor::{Tensor, Int};\n/// use burn::tensor::activation::log_softmax;\n/// use burn::nn::loss::{CTCLossConfig, CTCLoss};\n///\n/// let device = Default::default();\n///\n/// // Initialize CTC Loss with default configuration\n/// let ctc_loss = CTCLossConfig::new().init();\n///\n/// // Initialize CTC Loss with custom configuration\n/// let ctc_loss = CTCLossConfig::new()\n///     .with_blank(1)\n///     .with_zero_infinity(true)\n///     .init();\n///\n/// // Prepare inputs (Logits shape: [Time, Batch, Class])\n/// // In your actual code, the logits would be the output of your model\n/// let logits = Tensor::<B, 3>::ones([10, 2, 5], &device);\n/// let log_probs = log_softmax(logits, 2);\n///\n/// // Targets shape: [Batch, Max_Target_Len]\n/// // Note: Targets should not contain the blank index (1).\n/// let targets = Tensor::<B, 2, Int>::from_data([[0, 2], [3, 4]], &device);\n///\n/// // Lengths shape: [Batch]\n/// let input_lengths = Tensor::<B, 1, Int>::from_data([10, 8], &device);\n/// let target_lengths = Tensor::<B, 1, Int>::from_data([2, 2], &device);\n///\n/// // Compute loss\n/// let loss = ctc_loss.forward(log_probs, targets, input_lengths, target_lengths);\n/// ```\n#[derive(Module, Clone, Debug)]\npub struct CTCLoss {\n    blank: usize,\n    zero_infinity: bool,\n}\n\nimpl CTCLoss {\n    /// Computes the CTC loss for the input log-probabilities and targets with no reduction applied.\n    ///\n    /// # Arguments\n    ///\n    /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`).\n    /// - `targets`: A 2D tensor containing the target class indices. These indices should not\n    ///   include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.\n    /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each batch. This\n    ///   allows retrieving the actual sequence of log-probabilities from `log_probs` if the batch contains\n    ///   sequences of varying lengths.\n    /// - `target_lengths`: A 1D tensor containing the actual length of the target sequence for each target\n    ///   sequence in `targets`.\n    ///\n    /// # Returns\n    ///\n    /// - A 1D tensor of shape `[batch_size]` containing the loss for each sample.\n    ///\n    /// # Shapes\n    ///\n    /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank.\n    /// - `targets`: `[batch_size, max_target_length]`\n    /// - `input_lengths`: `[batch_size]`\n    /// - `target_lengths`: `[batch_size]`\n    pub fn forward<B: Backend>(\n        &self,\n        log_probs: Tensor<B, 3>,\n        targets: Tensor<B, 2, Int>,\n        input_lengths: Tensor<B, 1, Int>,\n        target_lengths: Tensor<B, 1, Int>,\n    ) -> Tensor<B, 1> {\n        let device = log_probs.device();\n        let [max_input_length, batch_size, num_classes] = log_probs.dims(); // [T, N, C]\n        let max_target_len = targets.dims()[1];\n        let input_lengths_len = input_lengths.dims()[0];\n        let target_lengths_len = target_lengths.dims()[0];\n        self.assertions(\n            batch_size,\n            num_classes,\n            targets.clone(),\n            input_lengths_len,\n            target_lengths_len,\n        );\n\n        // Build the modified label sequence l' by inserting blanks around every label\n        let blank_inserted_targets =\n            self.insert_blanks::<B>(&targets, batch_size, max_target_len, &device);\n\n        // Initialize the forward variable alpha\n        let max_l_prime_len = 2 * max_target_len + 1;\n        let mut log_alpha_t_s =\n            Tensor::<B, 2>::full([batch_size, max_l_prime_len], f32::NEG_INFINITY, &device);\n        log_alpha_t_s = self.initialize_log_alpha(\n            log_probs.clone(),\n            blank_inserted_targets.clone(),\n            log_alpha_t_s,\n        );\n\n        let l_prime_combined_mask = self.create_l_prime_mask(\n            blank_inserted_targets.clone(),\n            batch_size,\n            max_l_prime_len,\n            &device,\n        );\n        let s_mask =\n            self.create_s_mask(max_l_prime_len, batch_size, target_lengths.clone(), &device);\n\n        // Loop over time steps since an arbitrary time step t depends on t - 1\n        for t in 1..max_input_length {\n            let combined_s_t_mask = self.create_combined_s_t_mask(\n                input_lengths.clone(),\n                t,\n                batch_size,\n                max_l_prime_len,\n                s_mask.clone(),\n            );\n            log_alpha_t_s = self.compute_log_alpha_t_s(\n                t,\n                combined_s_t_mask,\n                log_alpha_t_s,\n                l_prime_combined_mask.clone(),\n                log_probs.clone(),\n                blank_inserted_targets.clone(),\n            );\n        }\n\n        let last_blank_indices = target_lengths.mul_scalar(2).reshape([batch_size, 1]);\n        let last_label_indices = last_blank_indices.clone().sub_scalar(1);\n        let log_alpha_last_blank = log_alpha_t_s\n            .clone()\n            .gather(1, last_blank_indices)\n            .squeeze_dim::<1>(1);\n        let log_alpha_last_label = log_alpha_t_s\n            .clone()\n            .gather(1, last_label_indices)\n            .squeeze_dim::<1>(1);\n        let log_likelihood = self.log_sum_exp(log_alpha_last_blank, log_alpha_last_label, &device);\n        let mut ctc_loss_tensor = log_likelihood.neg();\n\n        if self.zero_infinity {\n            let inf_mask = ctc_loss_tensor.clone().is_inf();\n            ctc_loss_tensor = ctc_loss_tensor\n                .clone()\n                .mask_where(inf_mask, ctc_loss_tensor.clone().zeros_like());\n        }\n\n        ctc_loss_tensor\n    }\n\n    /// Computes the CTC loss for the input log-probabilities and targets with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// - `log_probs`: The log-probabilities of the outputs (e.g., from `log_softmax`).\n    /// - `targets`: A 2D tensor containing the target class indices. These indices should not\n    ///   include the blank index used in CTC loss. The targets are padded to the length of the longest sequence.\n    /// - `input_lengths`: A 1D tensor containing the actual length of the input sequence for each batch. This\n    ///   allows retrieving the actual sequence of log-probabilities from `log_probs` if the batch contains\n    ///   sequences of varying lengths.\n    /// - `target_lengths`: A 1D tensor containing the actual length of the target sequence for each target\n    ///   sequence in `targets`.\n    /// - `reduction`: The reduction stratey to apply to the loss tensor containing the CTC loss values for\n    ///   each sample (e.g., mean, sum). For the mean reduction strategy, the output losses will be divided\n    ///   by the target lengths and then the mean over the batch is taken. This follows PyTorch's behavior.\n    ///\n    /// # Returns\n    ///\n    /// - A 1D tensor of shape `[1]` containing the reduced loss value.\n    ///\n    /// # Shapes\n    ///\n    /// - `log_probs`: `[time_steps, batch_size, num_classes]` where `num_classes` includes blank.\n    /// - `targets`: `[batch_size, max_target_length]`\n    /// - `input_lengths`: `[batch_size]`\n    /// - `target_lengths`: `[batch_size]`\n    ///\n    /// # Panics\n    /// - If `reduction` is not one of `Reduction::Auto`, `Reduction::Mean`, and `Reduction::Sum`.\n    /// - If `blank` index is greater than or equal to `num_classes`.\n    /// - If the batch dimension of `log_probs`, `targets`, `input_lengths`, and `target_lengths` do not match.\n    pub fn forward_with_reduction<B: Backend>(\n        &self,\n        log_probs: Tensor<B, 3>,\n        targets: Tensor<B, 2, Int>,\n        input_lengths: Tensor<B, 1, Int>,\n        target_lengths: Tensor<B, 1, Int>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let ctc_loss_tensor =\n            self.forward(log_probs, targets, input_lengths, target_lengths.clone());\n\n        match reduction {\n            Reduction::Auto | Reduction::Mean => {\n                // Following PyTorch's behavior where the output losses are divided\n                // by the target lengths and then the mean over the batch is taken\n                let target_lengths_float = target_lengths.float();\n                ctc_loss_tensor.div(target_lengths_float).mean()\n            }\n            Reduction::Sum => ctc_loss_tensor.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    fn assertions<B: Backend>(\n        &self,\n        batch_size: usize,\n        num_classes: usize,\n        targets: Tensor<B, 2, Int>,\n        input_lengths_len: usize,\n        target_lengths_len: usize,\n    ) {\n        assert!(\n            self.blank < num_classes,\n            \"blank index {} must be less than num_classes {}\",\n            self.blank,\n            num_classes\n        );\n        assert_eq!(\n            targets.dims()[0],\n            batch_size,\n            \"targets batch dimension {} must equal batch_size {}\",\n            targets.dims()[0],\n            batch_size\n        );\n        assert_eq!(\n            input_lengths_len, batch_size,\n            \"input_lengths length {} must equal batch_size {}\",\n            input_lengths_len, batch_size\n        );\n        assert_eq!(\n            target_lengths_len, batch_size,\n            \"target_lengths length {} must equal batch_size {}\",\n            target_lengths_len, batch_size\n        );\n    }\n\n    fn insert_blanks<B: Backend>(\n        &self,\n        targets: &Tensor<B, 2, Int>,\n        batch_size: usize,\n        max_target_len: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 2, Int> {\n        // The modified label sequences have (max_target_len + 1) blank labels\n        let blank_tensor = Tensor::<B, 2, Int>::full(\n            [batch_size, 2 * max_target_len + 1],\n            self.blank as i64,\n            device,\n        );\n\n        blank_tensor.slice_assign(s![.., 1..;2], targets.clone())\n    }\n\n    fn initialize_log_alpha<B: Backend>(\n        &self,\n        log_probs: Tensor<B, 3>,\n        blank_inserted_targets: Tensor<B, 2, Int>,\n        log_alpha_t_s: Tensor<B, 2>,\n    ) -> Tensor<B, 2> {\n        // Given alpha_t(s), we have:\n        // alpha_1(1) = (y_blank)^1  => log_alpha_1(1) = ln(y_blank)^1\n        // alpha_1(2) = (y_l1)^1  => log_alpha_1(2) = ln(y_l1)^1\n        // alpha_1(s) = 0 (for every s > 2)  => log_alpha_1(s) = -neg_inf\n        let log_probs_t0 = log_probs\n            .clone()\n            .slice(s![0..1, .., ..])\n            .squeeze_dim::<2>(0); // shape: [N, C]\n\n        // log_alpha shape: [N, 2*S+1]\n        // log_probs shape: [T, N, C]\n        // log_alpha[:, 0] = log_probs[0, :, blank]\n        let first_blank = blank_inserted_targets.clone().slice(s![.., 0..1]); // [N, 1]\n        // log_probs_t0 have C columns where each represents a unique class (includes blank)\n        let log_prob_blank = log_probs_t0.clone().gather(1, first_blank); // [N, 1]\n        let temp_log_alpha_t_s = log_alpha_t_s.slice_assign(s![.., 0..1], log_prob_blank);\n\n        // log_alpha[:, 1] = log_probs[0, :, targets[:, 0]]\n        let first_label = blank_inserted_targets.clone().slice(s![.., 1..2]); // [N, 1]\n        let log_prob_first_label = log_probs_t0.gather(1, first_label); // [N, 1]\n        temp_log_alpha_t_s.slice_assign(s![.., 1..2], log_prob_first_label)\n    }\n\n    fn right_shift_2d_tensor<B: Backend, K>(\n        &self,\n        org_2d_tensor: Tensor<B, 2, K>,\n        shift_by: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 2, K>\n    where\n        K: Numeric<B>,\n        K::Elem: Element,\n    {\n        assert!(\n            shift_by == 1 || shift_by == 2,\n            \"The parameter shift_by must 1 or 2\"\n        );\n\n        let [rows, cols] = org_2d_tensor.dims();\n        let padding_shape = [rows, shift_by];\n        let padding_tensor = if org_2d_tensor.dtype().is_float() {\n            Tensor::<B, 2, K>::full(padding_shape, f32::NEG_INFINITY, device)\n        } else {\n            Tensor::<B, 2, K>::full(padding_shape, 0, device)\n        };\n        let org_tensor_shortened = org_2d_tensor.slice(s![.., ..cols - shift_by]);\n\n        Tensor::cat(vec![padding_tensor, org_tensor_shortened], 1)\n    }\n\n    fn create_l_prime_mask<B: Backend>(\n        &self,\n        blank_inserted_targets: Tensor<B, 2, Int>,\n        batch_size: usize,\n        max_l_prime_len: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 2, Bool> {\n        let l_prime_s = blank_inserted_targets.clone();\n        let l_prime_s_minus_2 = self.right_shift_2d_tensor(blank_inserted_targets, 2, device);\n\n        // Create a single mask that is true for entries where alpha_{t-1}(s - 2) should also\n        // be added to compute alpha_{t}(s)\n        let s_is_not_blank_mask = l_prime_s.clone().not_equal_elem(self.blank as i64);\n        let s_not_equal_s_minus_2_mask = l_prime_s.not_equal(l_prime_s_minus_2);\n\n        // The 2 leftmost columns of the returned mask should only contain false.\n        // These are invalid positions since s - 2 is a valid index only when s >= 2.\n        let col_indices = Tensor::<B, 1, Int>::arange(0..(max_l_prime_len as i64), device)\n            .reshape([1, max_l_prime_len])\n            .expand([batch_size, max_l_prime_len]);\n        let s_greater_than_1_mask = col_indices.greater_equal_elem(2);\n\n        s_is_not_blank_mask\n            .bool_and(s_not_equal_s_minus_2_mask)\n            .bool_and(s_greater_than_1_mask)\n    }\n\n    fn create_s_mask<B: Backend>(\n        &self,\n        max_l_prime_len: usize,\n        batch_size: usize,\n        target_lengths: Tensor<B, 1, Int>,\n        device: &B::Device,\n    ) -> Tensor<B, 2, Bool> {\n        let col_indices = Tensor::<B, 1, Int>::arange(0..max_l_prime_len as i64, device)\n            .reshape([1, max_l_prime_len]);\n        let col_indices_expanded = col_indices.expand([batch_size, max_l_prime_len]);\n        let blank_inserted_target_lengths = target_lengths\n            .mul_scalar(2)\n            .add_scalar(1)\n            .reshape([batch_size, 1]);\n        let target_lengths_expanded =\n            blank_inserted_target_lengths.expand([batch_size, max_l_prime_len]);\n\n        col_indices_expanded.lower(target_lengths_expanded)\n    }\n\n    fn log_sum_exp<const D: usize, B: Backend>(\n        &self,\n        log_tensor1: Tensor<B, D>,\n        log_tensor2: Tensor<B, D>,\n        device: &B::Device,\n    ) -> Tensor<B, D> {\n        let shape = log_tensor1.dims();\n        let ones_tensor = Tensor::<B, D>::ones(shape, device);\n\n        // Let A and B represent parameters tensor1 and tensor2 respectively.\n        // Let C be the tensor this method returns.\n        // If an entry in both A and B are neg_inf, then the same entry\n        // in C should also contain neg_inf.\n        // If an entry in only one of A or B is neg_inf, then the same entry in\n        // C should contain the value of the other tensor entry which is not neg_inf.\n        let tensor1_is_neg_inf = log_tensor1.clone().equal_elem(f32::NEG_INFINITY);\n        let tensor2_is_neg_inf = log_tensor2.clone().equal_elem(f32::NEG_INFINITY);\n        let temp_tensor1 = ones_tensor\n            .clone()\n            .mask_where(tensor1_is_neg_inf.clone(), log_tensor2.clone());\n        let neg_inf_lse_tensor =\n            temp_tensor1.mask_where(tensor2_is_neg_inf.clone(), log_tensor1.clone());\n\n        // Create sanitized tensors for math operations to prevent NaN. Replace neg_inf\n        // with 0.0. The tensor neg_inf_lse_tensor contains correct values for entries\n        // where at least one of the corresponding entries in log_tensor1 or log_tensor2\n        // is neg_inf. Hence, the math operations below is computing the values for entries\n        // that are not already filled with their actual/correct values. Thus, result for\n        // these positions (where we sanitize) are not used anyway since the\n        // unfilled_entries_mask is applied at the end.\n        let tensor1_safe = log_tensor1\n            .clone()\n            .mask_fill(tensor1_is_neg_inf.clone(), 0.0);\n        let tensor2_safe = log_tensor2\n            .clone()\n            .mask_fill(tensor2_is_neg_inf.clone(), 0.0);\n\n        // Create a mask which contains true for entries whose values were not\n        // set by operations above\n        let filled_entries_mask = tensor1_is_neg_inf.bool_or(tensor2_is_neg_inf);\n        let unfilled_entries_mask = filled_entries_mask.bool_not();\n\n        let max_tensor = tensor1_safe.clone().max_pair(tensor2_safe.clone());\n        let diff_tensor = tensor1_safe.sub(tensor2_safe);\n        let exp_tensor = diff_tensor.abs().neg().exp();\n        let ln_tensor = ones_tensor.add(exp_tensor).log();\n        let lse_tensor = max_tensor.add(ln_tensor);\n        neg_inf_lse_tensor.mask_where(unfilled_entries_mask, lse_tensor)\n    }\n\n    fn create_combined_s_t_mask<B: Backend>(\n        &self,\n        input_lengths: Tensor<B, 1, Int>,\n        t: usize,\n        batch_size: usize,\n        max_l_prime_len: usize,\n        s_mask: Tensor<B, 2, Bool>,\n    ) -> Tensor<B, 2, Bool> {\n        // Create masks for valid t and s\n        let t_mask_1d = input_lengths\n            .clone()\n            .greater_elem(t as i64)\n            .reshape([batch_size, 1]);\n        let t_mask = t_mask_1d.expand([batch_size, max_l_prime_len]);\n\n        t_mask.bool_and(s_mask.clone())\n    }\n\n    fn compute_log_alpha_t_s<B: Backend>(\n        &self,\n        t: usize,\n        combined_s_t_mask: Tensor<B, 2, Bool>,\n        log_alpha_t_s: Tensor<B, 2>,\n        l_prime_combined_mask: Tensor<B, 2, Bool>,\n        log_probs: Tensor<B, 3>,\n        blank_inserted_targets: Tensor<B, 2, Int>,\n    ) -> Tensor<B, 2> {\n        let device = log_probs.device();\n        let log_alpha_t_minus_1 = log_alpha_t_s.clone();\n\n        // No move from last time step: alpha_{t-1}(s)\n        let log_alpha_s = log_alpha_t_minus_1.clone();\n\n        // Single move from last time step: alpha_{t-1}(s - 1)\n        let log_alpha_s_minus_1 =\n            self.right_shift_2d_tensor(log_alpha_t_minus_1.clone(), 1, &device);\n\n        // A skip move (moving 2 positions) from last time step: alpha_{t-1}(s - 2)\n        let log_alpha_s_minus_2 =\n            self.right_shift_2d_tensor(log_alpha_t_minus_1.clone(), 2, &device);\n\n        // Compute alpha_{t}(s) using recursion, corresponding to equation 6 of the paper.\n        let log_alpha_bar = self.log_sum_exp(log_alpha_s, log_alpha_s_minus_1, &device);\n        let log_alpha_bar_plus_log_alpha_s_minus_2 =\n            self.log_sum_exp(log_alpha_bar.clone(), log_alpha_s_minus_2, &device);\n        let log_alpha_s_to_s_minus_2 = log_alpha_bar.mask_where(\n            l_prime_combined_mask.clone(),\n            log_alpha_bar_plus_log_alpha_s_minus_2,\n        ); // [N, 2 * U + 1]\n        let log_probs_t = log_probs.clone().slice(s![t, .., ..]).squeeze_dim::<2>(0); // [N, C]\n        let log_probs_l_prime_s = log_probs_t.gather(1, blank_inserted_targets.clone());\n        let temp_log_alpha_t_s = log_alpha_s_to_s_minus_2.add(log_probs_l_prime_s);\n        log_alpha_t_s.mask_where(combined_s_t_mask, temp_log_alpha_t_s)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_ndarray::{NdArray, NdArrayDevice};\n\n    type TestBackend = NdArray<f32>;\n\n    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {\n        assert_eq!(\n            actual.len(),\n            expected.len(),\n            \"Length mismatch: actual {} vs expected {}\",\n            actual.len(),\n            expected.len()\n        );\n        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {\n            assert!(\n                (a - e).abs() < tol,\n                \"Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})\",\n                i,\n                e,\n                a,\n                (a - e).abs()\n            );\n        }\n    }\n\n    // ---------------------------------------------------------------\n    // insert_blanks tests\n    // ---------------------------------------------------------------\n\n    #[test]\n    fn test_insert_blanks_single_sample() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64, 2, 3]], &device);\n        let result = ctc.insert_blanks::<TestBackend>(&targets, 1, 3, &device);\n        let result_data = result.into_data().to_vec::<i64>().unwrap();\n        assert_eq!(result_data, vec![0, 1, 0, 2, 0, 3, 0]);\n    }\n\n    #[test]\n    fn test_insert_blanks_batch() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64, 2], [3, 4]], &device);\n        let result = ctc.insert_blanks::<TestBackend>(&targets, 2, 2, &device);\n        let result_data = result.into_data().to_vec::<i64>().unwrap();\n        assert_eq!(result_data, vec![0, 1, 0, 2, 0, 0, 3, 0, 4, 0]);\n    }\n\n    #[test]\n    fn test_insert_blanks_custom_blank() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_blank(2).init();\n\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[0_i64, 1]], &device);\n        let result = ctc.insert_blanks::<TestBackend>(&targets, 1, 2, &device);\n        let result_data = result.into_data().to_vec::<i64>().unwrap();\n        // l' = [blank=2, 0, blank=2, 1, blank=2]\n        assert_eq!(result_data, vec![2, 0, 2, 1, 2]);\n    }\n\n    // ---------------------------------------------------------------\n    // Assertions\n    // ---------------------------------------------------------------\n\n    #[test]\n    #[should_panic(expected = \"blank index\")]\n    fn test_ctc_loss_panics_invalid_blank_index() {\n        let device = NdArrayDevice::Cpu;\n        // blank=5 is out of bounds for num_classes=3\n        let ctc = CTCLossConfig::new().with_blank(5).init();\n\n        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 1, 3], &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);\n\n        ctc.forward(log_probs, targets, input_lengths, target_lengths);\n    }\n\n    #[test]\n    #[should_panic(expected = \"must equal batch_size\")]\n    fn test_ctc_loss_panics_mismatched_batch_size() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        // Logits batch size = 2\n        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);\n        // Targets batch size = 1 (Mismatch)\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);\n\n        ctc.forward(log_probs, targets, input_lengths, target_lengths);\n    }\n\n    #[test]\n    #[should_panic(expected = \"input_lengths length\")]\n    fn test_ctc_loss_panics_input_lengths_mismatch() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        // Logits batch size = 2\n        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);\n\n        // Input lengths size = 1 (Mismatch)\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1, 1], &device);\n\n        ctc.forward(log_probs, targets, input_lengths, target_lengths);\n    }\n\n    #[test]\n    #[should_panic(expected = \"target_lengths length\")]\n    fn test_ctc_loss_panics_target_lengths_mismatch() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        // Logits batch size = 2\n        let log_probs = Tensor::<TestBackend, 3>::zeros([2, 2, 3], &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1], [2]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2, 2], &device);\n\n        // Target lengths size = 1 (Mismatch)\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1], &device);\n\n        ctc.forward(log_probs, targets, input_lengths, target_lengths);\n    }\n\n    // ---------------------------------------------------------------\n    // Edge Case & Config Tests\n    // ---------------------------------------------------------------\n\n    #[test]\n    fn test_ctc_loss_repeated_labels_minimum_input_length() {\n        // T=3, N=1, C=2, blank=0, target=[1, 1], uniform P = 1/2.\n        //\n        // The minimum T for target [1, 1] is 3: the only valid path is (1, 0, 1).\n        // prob = (1/2)^3 = 1/8\n        // Loss = -ln(1/8) = 3 * ln(2)\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 2], 0.5_f32.ln(), &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64, 1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i64], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.into_data().to_vec::<f32>().unwrap();\n        let expected = 3.0 * 2.0_f32.ln();\n        assert_approx_equal(&loss_data, &[expected], 1e-3);\n    }\n\n    #[test]\n    fn test_ctc_loss_custom_blank_uniform() {\n        // T=3, N=1, C=3, blank=2, target=[0, 1], uniform P = 1/3.\n        //\n        // Two distinct labels, 3 classes, 3 time steps, just with\n        // blank=2 instead of 0.\n        // 5 valid paths → total = 5/27\n        // Loss = -ln(5/27)\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_blank(2).init();\n\n        let log_probs = Tensor::<TestBackend, 3>::full([3, 1, 3], (1.0_f32 / 3.0).ln(), &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[0_i64, 1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([3_i64], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.into_data().to_vec::<f32>().unwrap();\n        let expected = -(5.0_f32 / 27.0).ln();\n        assert_approx_equal(&loss_data, &[expected], 1e-3);\n    }\n\n    // ---------------------------------------------------------------\n    // zero_infinity tests\n    // ---------------------------------------------------------------\n\n    #[test]\n    fn test_ctc_loss_zero_infinity_produces_inf_when_disabled() {\n        // T=2, N=1, C=3, blank=0, target=[1, 1], input_length=2\n        // Target [1, 1] requires at least 3 time steps → no valid paths → loss = +inf\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_zero_infinity(false).init();\n\n        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64, 1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.into_data().to_vec::<f32>().unwrap();\n        assert!(\n            loss_data[0].is_infinite() && loss_data[0] > 0.0,\n            \"Expected +inf, got {}\",\n            loss_data[0]\n        );\n    }\n\n    #[test]\n    fn test_ctc_loss_zero_infinity_masks_inf_when_enabled() {\n        // Same inputs as above, but zero_infinity=true → loss should be 0.0\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();\n\n        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 3], (1.0_f32 / 3.0).ln(), &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64, 1]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.into_data().to_vec::<f32>().unwrap();\n        assert_approx_equal(&loss_data, &[0.0], 1e-6);\n    }\n\n    #[test]\n    fn test_ctc_loss_zero_infinity_does_not_affect_finite_loss() {\n        // Verify that zero_infinity=true does not change a finite loss value.\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_zero_infinity(true).init();\n\n        let log_probs = Tensor::<TestBackend, 3>::full([2, 1, 2], 0.5_f32.ln(), &device);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data([[1_i64]], &device);\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data([2_i64], &device);\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data([1_i64], &device);\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.into_data().to_vec::<f32>().unwrap();\n        let expected = -(0.75_f32).ln();\n        assert_approx_equal(&loss_data, &[expected], 1e-3);\n    }\n}\n\n#[cfg(test)]\nmod pytorch_comparison_tests {\n    use super::*;\n    use burn::tensor::activation::log_softmax;\n    use burn_autodiff::Autodiff;\n    use burn_core::tensor::TensorData;\n    use burn_ndarray::{NdArray, NdArrayDevice};\n\n    type InnerBackend = NdArray<f32>;\n    type TestBackend = Autodiff<InnerBackend>;\n\n    fn assert_approx_equal(actual: &[f32], expected: &[f32], tol: f32) {\n        assert_eq!(\n            actual.len(),\n            expected.len(),\n            \"Length mismatch: actual {} vs expected {}\",\n            actual.len(),\n            expected.len()\n        );\n        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {\n            assert!(\n                (a - e).abs() < tol,\n                \"Mismatch at index {}: expected {:.6}, got {:.6} (diff: {:.6})\",\n                i,\n                e,\n                a,\n                (a - e).abs()\n            );\n        }\n    }\n\n    /// Deterministic logits: sin((t*7 + n*13 + c*3) * 0.1).\n    fn generate_logits(\n        t_size: usize,\n        n_size: usize,\n        c_size: usize,\n        device: &NdArrayDevice,\n    ) -> Tensor<TestBackend, 3> {\n        let mut data = Vec::with_capacity(t_size * n_size * c_size);\n        for t in 0..t_size {\n            for n in 0..n_size {\n                for c in 0..c_size {\n                    data.push(((t * 7 + n * 13 + c * 3) as f32 * 0.1).sin());\n                }\n            }\n        }\n        Tensor::<TestBackend, 3>::from_data(TensorData::new(data, [t_size, n_size, c_size]), device)\n    }\n\n    /// Runs a CTC forward + backward test and asserts against expected values from PyTorch.\n    ///\n    /// This helper performs the following steps:\n    /// 1. Generates deterministic logits using a sine-wave formula.\n    /// 2. Computes the CTC loss (forward pass).\n    /// 3. Asserts the computed loss matches `expected_losses`.\n    /// 4. Backpropagates the sum of the loss.\n    /// 5. Asserts the resulting gradients w.r.t. logits match `expected_grad_flat`.\n    ///\n    /// # Arguments\n    ///\n    /// - `expected_losses`: per-sample loss values from PyTorch (reduction='none').\n    /// - `expected_grad_flat`: flattened gradient of sum(loss) w.r.t. logits.\n    #[allow(clippy::too_many_arguments)]\n    fn run_comparison(\n        label: &str,\n        t_size: usize,\n        n_size: usize,\n        c_size: usize,\n        targets_flat: Vec<i64>,\n        target_shape: [usize; 2],\n        input_lengths: Vec<i64>,\n        target_lengths: Vec<i64>,\n        blank: usize,\n        expected_losses: &[f32],\n        expected_grad_flat: &[f32],\n        loss_tol: f32,\n        grad_tol: f32,\n    ) {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().with_blank(blank).init();\n\n        let logits = generate_logits(t_size, n_size, c_size, &device).require_grad();\n        let log_probs = log_softmax(logits.clone(), 2);\n\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::new(targets_flat, target_shape),\n            &device,\n        );\n        let input_lengths = Tensor::<TestBackend, 1, Int>::from_data(\n            TensorData::new(input_lengths, [n_size]),\n            &device,\n        );\n        let target_lengths = Tensor::<TestBackend, 1, Int>::from_data(\n            TensorData::new(target_lengths, [n_size]),\n            &device,\n        );\n\n        let loss = ctc.forward(log_probs, targets, input_lengths, target_lengths);\n        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();\n\n        println!(\"=== {} ===\", label);\n        println!(\"  Loss: {:?}\", loss_data);\n        assert_approx_equal(&loss_data, expected_losses, loss_tol);\n\n        let loss_sum = loss.sum();\n        let grads = loss_sum.backward();\n        let logits_grad = logits.grad(&grads).unwrap();\n        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();\n        assert_approx_equal(&grad_data, expected_grad_flat, grad_tol);\n    }\n\n    #[test]\n    fn test_ctc_loss_uniform_input_lengths() {\n        // T=5, N=3, C=4, all input_lengths = 5\n        // Expected losses and gradient from PyTorch\n        let expected_losses = [3.5236570835113525_f32, 3.495313882827759, 4.262677192687988];\n        let expected_grad_flat = [\n            -0.1679008007_f32,\n            -0.4595540464,\n            0.2795598209,\n            0.3478950262,\n            -0.3913056254,\n            -0.0832268298,\n            0.2535884976,\n            0.2209439576,\n            -0.0502742566,\n            0.2766197622,\n            0.2054125518,\n            -0.4317580462,\n            -0.0544800088,\n            -0.3144550920,\n            0.0847885981,\n            0.2841464877,\n            -0.1844545156,\n            -0.2063435912,\n            0.2222184092,\n            0.1685796976,\n            0.0278018005,\n            0.2657383382,\n            -0.0336986706,\n            -0.2598414719,\n            -0.0482986756,\n            -0.0098767160,\n            -0.1533526182,\n            0.2115280181,\n            -0.1380317956,\n            -0.2198686600,\n            0.2042596638,\n            0.1536407918,\n            0.0534787849,\n            0.1819230020,\n            -0.2805589139,\n            0.0451571345,\n            -0.0895631388,\n            0.1996460557,\n            -0.2741115987,\n            0.1640286744,\n            -0.2200077325,\n            -0.1693530381,\n            0.2101601064,\n            0.1792006642,\n            0.0398471877,\n            -0.1131042913,\n            -0.2363226712,\n            0.3095797896,\n            -0.2163617164,\n            0.2740726173,\n            -0.2124865055,\n            0.1547756046,\n            -0.4312027395,\n            -0.0446923785,\n            0.2330704331,\n            0.2428246588,\n            -0.0050083841,\n            -0.6256869435,\n            0.2689785957,\n            0.3617166877,\n        ];\n        run_comparison(\n            \"T=5, N=3, C=4 (uniform input lengths)\",\n            5,\n            3,\n            4,\n            vec![1, 2, 0, 1, 0, 0, 3, 2, 1],\n            [3, 3],\n            vec![5, 5, 5],\n            vec![2, 1, 3],\n            0,\n            &expected_losses,\n            &expected_grad_flat,\n            1e-3,\n            1e-3,\n        );\n    }\n\n    #[test]\n    fn test_ctc_loss_repeated_labels() {\n        // T=8, N=4, C=6, includes consecutive repeated label [1,1,2]\n        // Expected losses and gradient from PyTorch\n        let expected_losses = [\n            8.84203052520752_f32,\n            9.023029327392578,\n            9.398024559020996,\n            9.008068084716797,\n        ];\n        let expected_grad_flat = [\n            -0.2766432464,\n            -0.5202965736,\n            0.1523768753,\n            0.1896236390,\n            0.2200277001,\n            0.2349116206,\n            -0.1854365915,\n            0.2031330466,\n            -0.4260218740,\n            0.1678018719,\n            0.1360142529,\n            0.1045092493,\n            -0.6603536606,\n            0.2278252542,\n            0.1691786796,\n            0.1262856424,\n            0.0972681716,\n            0.0397959016,\n            -0.0894432291,\n            -0.5457318425,\n            0.1490373611,\n            0.1462858170,\n            0.1569476575,\n            0.1829041988,\n            -0.2842915654,\n            -0.4220107496,\n            0.1822281033,\n            0.1889107376,\n            0.1791101843,\n            0.1560532600,\n            -0.1155678406,\n            0.2295538932,\n            -0.2645366490,\n            -0.0288553704,\n            0.1027252972,\n            0.0766806602,\n            -0.5448347330,\n            0.2031028718,\n            0.1589304954,\n            0.1322451383,\n            0.1189499870,\n            -0.0683937520,\n            -0.0873993114,\n            -0.3051757514,\n            -0.2355299890,\n            0.1586059481,\n            0.2018169016,\n            0.2676822543,\n            -0.3225219846,\n            -0.2611543834,\n            0.1922984123,\n            0.1632783115,\n            0.1297036558,\n            0.0983960181,\n            -0.1507159024,\n            0.2256962359,\n            -0.1040333956,\n            -0.1514528394,\n            0.0985243544,\n            0.0819815546,\n            -0.2940836251,\n            0.1586865336,\n            0.1468491107,\n            0.1485087872,\n            0.1639631987,\n            -0.3239239752,\n            -0.0767390430,\n            -0.0434846729,\n            -0.4023587406,\n            -0.0052628326,\n            0.2273432612,\n            0.3005020022,\n            -0.2598774135,\n            -0.2188862711,\n            0.1678501070,\n            0.1352078766,\n            0.1002781317,\n            0.0754275694,\n            -0.1502914876,\n            0.1930875033,\n            -0.0709601715,\n            -0.2219523191,\n            0.1243555173,\n            0.1257609427,\n            -0.0574148744,\n            0.1152269915,\n            0.1307857931,\n            0.1599020809,\n            0.2068412602,\n            -0.5553412437,\n            -0.0536844917,\n            0.0758557543,\n            -0.2106334567,\n            -0.2509877980,\n            0.1757438034,\n            0.2637061775,\n            -0.1759711355,\n            -0.2431350052,\n            0.1071053818,\n            0.1259848624,\n            0.1004033238,\n            0.0856125653,\n            -0.1173698306,\n            0.1213828772,\n            -0.1768893301,\n            -0.2070008069,\n            0.1709136516,\n            0.2089634240,\n            0.0153109450,\n            0.0967332721,\n            0.1268781722,\n            0.1706230640,\n            0.2291058898,\n            -0.6386513710,\n            -0.0536664203,\n            0.1378114969,\n            0.0360041447,\n            -0.2989685237,\n            -0.0084722806,\n            0.1872915775,\n            -0.1523490399,\n            -0.2111770809,\n            -0.0390694551,\n            0.1366800815,\n            0.1302325875,\n            0.1356829405,\n            -0.0982905105,\n            -0.0127884001,\n            -0.3586881459,\n            -0.0259541404,\n            0.2114149332,\n            0.2843062580,\n            -0.0324133746,\n            0.1084750593,\n            0.1447229236,\n            0.1862253845,\n            0.2259712219,\n            -0.6329812407,\n            -0.1173689738,\n            0.1914442331,\n            0.1654772907,\n            -0.1376858056,\n            -0.2194855511,\n            0.1176188141,\n            -0.1529908478,\n            -0.0606661662,\n            -0.3384291232,\n            0.1524862647,\n            0.1777049750,\n            0.2218948901,\n            -0.0923086405,\n            -0.2855934799,\n            -0.3215619624,\n            0.1726681292,\n            0.2303666323,\n            0.2964293361,\n            -0.2508065701,\n            0.1479703039,\n            0.1753441393,\n            0.1917535067,\n            0.1919818372,\n            -0.4562432170,\n            -0.2350299209,\n            0.2257601619,\n            0.1863904297,\n            0.0388212129,\n            -0.2966264784,\n            0.0806845874,\n            -0.1992894858,\n            0.1068909168,\n            -0.5761897564,\n            0.1624972969,\n            0.2155302167,\n            0.2905607820,\n            -0.1168124676,\n            -0.6870660186,\n            0.1488010883,\n            0.1881926507,\n            0.2230074406,\n            0.2438773215,\n            -0.5771554708,\n            0.1980127096,\n            0.1924194694,\n            0.1714663208,\n            0.1415647417,\n            -0.1263078004,\n            -0.3408652246,\n            0.2292248607,\n            0.1707807332,\n            0.1269564927,\n            -0.2634142637,\n            0.0773174241,\n        ];\n        run_comparison(\n            \"T=8, N=4, C=6 (repeated labels)\",\n            8,\n            4,\n            6,\n            vec![1, 1, 2, 0, 2, 3, 2, 1, 5, 0, 0, 0, 1, 2, 3, 4],\n            [4, 4],\n            vec![8, 8, 8, 8],\n            vec![3, 4, 1, 4],\n            0,\n            &expected_losses,\n            &expected_grad_flat,\n            1e-3,\n            1e-3,\n        );\n    }\n\n    #[test]\n    fn test_ctc_loss_long_sequence() {\n        // T=10, N=2, C=8\n        // Expected losses and gradient from PyTorch\n        let expected_losses = [12.629399299621582, 12.298524856567383];\n        let expected_grad_flat = [\n            -0.2570972741,\n            -0.6013792753,\n            0.1061997041,\n            0.1321590245,\n            0.1533492655,\n            0.1637226790,\n            0.1598964781,\n            0.1431493312,\n            -0.2540431321,\n            0.1788398325,\n            -0.4038805366,\n            0.1477340311,\n            0.1197479516,\n            0.0920107216,\n            0.0686140805,\n            0.0509770736,\n            -0.1364373565,\n            -0.3724762201,\n            0.1489177048,\n            -0.0966964588,\n            0.1463697106,\n            0.1275274903,\n            0.1033692732,\n            0.0794258416,\n            -0.1771971881,\n            0.2073454857,\n            -0.3109439015,\n            0.1249521226,\n            -0.0101635465,\n            0.0692621097,\n            0.0533472970,\n            0.0433975980,\n            -0.1398337185,\n            -0.0874802172,\n            0.1705365479,\n            -0.2174201906,\n            0.1150254831,\n            0.0460043959,\n            0.0647982135,\n            0.0483694859,\n            -0.2332949787,\n            0.1969220787,\n            -0.1270586401,\n            0.1098557115,\n            -0.1364655048,\n            0.0715296715,\n            0.0553609394,\n            0.0631506816,\n            -0.2169117928,\n            0.0929956511,\n            0.1624538749,\n            -0.2009791434,\n            0.0904926360,\n            -0.0248185843,\n            0.0532633252,\n            0.0435040221,\n            -0.2313277274,\n            0.1497355998,\n            -0.0024202778,\n            0.1029939279,\n            -0.2776987851,\n            0.0963881761,\n            0.0351882279,\n            0.1271408647,\n            -0.2590557337,\n            0.1577988416,\n            0.1429322213,\n            -0.1401246637,\n            0.0866033062,\n            -0.1151762009,\n            0.0683368817,\n            0.0586853735,\n            -0.1322475076,\n            0.0806737095,\n            0.0528722852,\n            0.0920089707,\n            -0.3037962914,\n            0.1280544847,\n            -0.1391123086,\n            0.2215466499,\n            -0.1918463260,\n            0.1376975775,\n            0.1160097718,\n            -0.0549413785,\n            0.0970225409,\n            -0.2708687484,\n            0.1147320047,\n            0.0521945432,\n            -0.0504456684,\n            -0.0012221609,\n            0.0644332916,\n            0.0818370953,\n            -0.1036835983,\n            0.1512031406,\n            -0.4072600305,\n            0.2651379406,\n            -0.0681083873,\n            0.0860663429,\n            0.0810486302,\n            0.0434282124,\n            0.1056238264,\n            -0.2994530201,\n            0.1729898751,\n            -0.1215954795,\n            -0.0481944978,\n            -0.1697723418,\n            0.0725984722,\n            0.0692019314,\n            0.0859903544,\n            0.1680216491,\n            -0.4071443677,\n            0.2292988002,\n            -0.0205532499,\n            0.0566616580,\n            0.0326749459,\n            0.0861379728,\n            0.1142501161,\n            -0.0448331088,\n            0.2054910213,\n            -0.4298293889,\n            -0.0647637174,\n            -0.4240962267,\n            0.1013666242,\n            -0.0110451467,\n            0.1519176364,\n            0.1661346704,\n            -0.0719586164,\n            0.1524447650,\n            -0.0496110357,\n            0.0562372655,\n            -0.1889088154,\n            0.1013496071,\n            0.1339637935,\n            0.1694275290,\n            0.2007708699,\n            -0.4232292175,\n            -0.0401752405,\n            -0.2951072752,\n            0.1443216652,\n            -0.2857291698,\n            0.1489982456,\n            0.1327733696,\n            0.1096193567,\n            0.0852990299,\n            -0.0413062274,\n            0.0820900649,\n            -0.7903561592,\n            0.1329460591,\n            0.1535883099,\n            0.1631743014,\n            0.1585651338,\n            0.1412984729,\n            -0.1033771932,\n            0.1799504310,\n            0.1697744429,\n            -0.5749052763,\n            0.1189445183,\n            0.0911802500,\n            0.0679325759,\n            0.0505003072,\n        ];\n        run_comparison(\n            \"T=10, N=2, C=8\",\n            10,\n            2,\n            8,\n            vec![1, 3, 5, 7, 2, 2, 4, 6, 1, 3],\n            [2, 5],\n            vec![10, 10],\n            vec![5, 5],\n            0,\n            &expected_losses,\n            &expected_grad_flat,\n            1e-3,\n            1e-3,\n        );\n    }\n\n    #[test]\n    fn test_ctc_loss_mixed_input_lengths() {\n        // T=12, N=3, C=5, input_lengths=[12, 7, 10]\n        // Expected losses and gradient from PyTorch\n        let expected_losses = [10.595505714416504, 6.8078508377075195, 7.705057144165039];\n        let expected_grad_flat = [\n            -0.4790987670,\n            -0.2554937005,\n            0.1991624236,\n            0.2478453964,\n            0.2875846624,\n            -0.3495813310,\n            0.2268397957,\n            0.2150714993,\n            -0.2442178279,\n            0.1518878639,\n            -0.2764556706,\n            0.2474014312,\n            -0.2137086987,\n            0.1371368915,\n            0.1056260392,\n            -0.2729502618,\n            -0.3609606028,\n            0.2159237266,\n            0.2238420397,\n            0.1941450834,\n            -0.2953839302,\n            0.1920599341,\n            0.1974952668,\n            -0.2054278404,\n            0.1112565696,\n            -0.1719199270,\n            0.2299505472,\n            -0.2864859998,\n            0.1497263014,\n            0.0787290633,\n            -0.2035763413,\n            -0.3042884767,\n            0.2126964629,\n            0.1810975969,\n            0.1140707731,\n            -0.2759391963,\n            0.0975771844,\n            0.1823379993,\n            -0.1112988219,\n            0.1073228419,\n            -0.1336459517,\n            0.1869296581,\n            -0.1996247321,\n            0.1846873760,\n            -0.0383463502,\n            -0.2254105806,\n            -0.1834360659,\n            0.1925925612,\n            0.1462381780,\n            0.0700158924,\n            -0.2259973884,\n            -0.0393539183,\n            0.1802661419,\n            -0.0571591072,\n            0.1422442794,\n            -0.0609069727,\n            0.1089282706,\n            -0.0313654318,\n            0.2186669111,\n            -0.2353227735,\n            -0.2840364873,\n            -0.0632198900,\n            0.1755636632,\n            0.1377806067,\n            0.0339120962,\n            -0.1904856712,\n            -0.2139032930,\n            0.1827126741,\n            0.0056131603,\n            0.2160631120,\n            -0.0243270602,\n            -0.0070458520,\n            0.1070247591,\n            0.2239368409,\n            -0.2995886803,\n            -0.2955487072,\n            0.0309870224,\n            0.1654911339,\n            0.1581364125,\n            -0.0590658709,\n            -0.2191396207,\n            -0.3791662455,\n            0.1803640425,\n            0.1225430891,\n            0.2953987718,\n            -0.0436352938,\n            -0.1575258970,\n            0.1785279512,\n            0.1756918877,\n            -0.1530586481,\n            -0.1834939867,\n            0.0909025446,\n            0.1423641294,\n            0.1959712654,\n            -0.2457439601,\n            -0.3619639874,\n            -0.3929221630,\n            0.1820438206,\n            0.2454170734,\n            0.3274252713,\n            -0.0628800318,\n            -0.2567180395,\n            0.2112283260,\n            0.0507859327,\n            0.0575838275,\n            -0.0587697029,\n            0.1174769849,\n            0.0783569664,\n            0.2290501744,\n            -0.3661144078,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            -0.0725664943,\n            -0.1532069892,\n            0.2162397504,\n            -0.1248963475,\n            0.1344300956,\n            -0.0362483934,\n            0.1295878887,\n            -0.0502482466,\n            0.2470482886,\n            -0.2901395261,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            -0.1349253207,\n            0.0867646411,\n            0.1998746395,\n            -0.2658679783,\n            0.1141540110,\n            -0.0705668628,\n            0.1519546807,\n            -0.2509805560,\n            0.2475892603,\n            -0.0779965296,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            -0.2338010073,\n            0.2471641302,\n            0.1834627241,\n            -0.3026831448,\n            0.1058573127,\n            -0.1155209392,\n            0.1921830922,\n            -0.4129956067,\n            0.2229512781,\n            0.1133821756,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            -0.2636392713,\n            0.2323469073,\n            -0.2913427949,\n            0.1800564528,\n            0.1425786912,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n            0.0000000000,\n        ];\n        run_comparison(\n            \"T=12, N=3, C=5 (mixed input lengths)\",\n            12,\n            3,\n            5,\n            vec![1, 4, 2, 0, 3, 1, 0, 0, 2, 4, 1, 3],\n            [3, 4],\n            vec![12, 7, 10],\n            vec![3, 2, 4],\n            0,\n            &expected_losses,\n            &expected_grad_flat,\n            1e-3,\n            1e-3,\n        );\n    }\n\n    #[test]\n    fn test_ctc_loss_sum_reduction() {\n        // Same inputs as comparison_uniform_input_lengths, sum reduction\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        let logits = generate_logits(5, 3, 4, &device).require_grad();\n        let log_probs = log_softmax(logits.clone(), 2);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::new(vec![1_i64, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),\n            &device,\n        );\n        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i64, 5, 5], &device);\n        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i64, 1, 3], &device);\n\n        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Sum);\n        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();\n\n        let expected_sum = 11.2816486359_f32; // Expected value from PyTorch\n        assert_approx_equal(&loss_data, &[expected_sum], 1e-3);\n\n        let grads = loss.backward();\n        let logits_grad = logits.grad(&grads).unwrap();\n        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();\n        // Expected gradient from PyTorch\n        let expected_grad = [\n            -0.1679008007_f32,\n            -0.4595540464,\n            0.2795598209,\n            0.3478950262,\n            -0.3913056254,\n            -0.0832268298,\n            0.2535884976,\n            0.2209439576,\n            -0.0502742566,\n            0.2766197622,\n            0.2054125518,\n            -0.4317580462,\n            -0.0544800088,\n            -0.3144550920,\n            0.0847885981,\n            0.2841464877,\n            -0.1844545156,\n            -0.2063435912,\n            0.2222184092,\n            0.1685796976,\n            0.0278018005,\n            0.2657383382,\n            -0.0336986706,\n            -0.2598414719,\n            -0.0482986756,\n            -0.0098767160,\n            -0.1533526182,\n            0.2115280181,\n            -0.1380317956,\n            -0.2198686600,\n            0.2042596638,\n            0.1536407918,\n            0.0534787849,\n            0.1819230020,\n            -0.2805589139,\n            0.0451571345,\n            -0.0895631388,\n            0.1996460557,\n            -0.2741115987,\n            0.1640286744,\n            -0.2200077325,\n            -0.1693530381,\n            0.2101601064,\n            0.1792006642,\n            0.0398471877,\n            -0.1131042913,\n            -0.2363226712,\n            0.3095797896,\n            -0.2163617164,\n            0.2740726173,\n            -0.2124865055,\n            0.1547756046,\n            -0.4312027395,\n            -0.0446923785,\n            0.2330704331,\n            0.2428246588,\n            -0.0050083841,\n            -0.6256869435,\n            0.2689785957,\n            0.3617166877,\n        ];\n        assert_approx_equal(&grad_data, &expected_grad, 1e-3);\n    }\n\n    #[test]\n    fn test_ctc_loss_mean_reduction() {\n        let device = NdArrayDevice::Cpu;\n        let ctc = CTCLossConfig::new().init();\n\n        let logits = generate_logits(5, 3, 4, &device).require_grad();\n        let log_probs = log_softmax(logits.clone(), 2);\n        let targets = Tensor::<TestBackend, 2, Int>::from_data(\n            TensorData::new(vec![1_i64, 2, 0, 1, 0, 0, 3, 2, 1], [3, 3]),\n            &device,\n        );\n        let il = Tensor::<TestBackend, 1, Int>::from_data([5_i64, 5, 5], &device);\n        let tl = Tensor::<TestBackend, 1, Int>::from_data([2_i64, 1, 3], &device);\n\n        let loss = ctc.forward_with_reduction(log_probs, targets, il, tl, Reduction::Mean);\n        let loss_data = loss.clone().into_data().to_vec::<f32>().unwrap();\n\n        let expected_mean = 2.2260115147_f32; // Expected value from PyTorch\n        assert_approx_equal(&loss_data, &[expected_mean], 1e-3);\n\n        let grads = loss.backward();\n        let logits_grad = logits.grad(&grads).unwrap();\n        let grad_data = logits_grad.into_data().to_vec::<f32>().unwrap();\n        // Expected gradient from PyTorch\n        let expected_grad = [\n            -0.0279834662_f32,\n            -0.0765923411,\n            0.0465933047,\n            0.0579825081,\n            -0.1304352134,\n            -0.0277422778,\n            0.0845294967,\n            0.0736479908,\n            -0.0055860290,\n            0.0307355281,\n            0.0228236169,\n            -0.0479731150,\n            -0.0090800021,\n            -0.0524091832,\n            0.0141314333,\n            0.0473577492,\n            -0.0614848398,\n            -0.0687812045,\n            0.0740728080,\n            0.0561932363,\n            0.0030890885,\n            0.0295264814,\n            -0.0037442972,\n            -0.0288712755,\n            -0.0080497796,\n            -0.0016461194,\n            -0.0255587716,\n            0.0352546684,\n            -0.0460105985,\n            -0.0732895583,\n            0.0680865571,\n            0.0512135960,\n            0.0059420872,\n            0.0202136654,\n            -0.0311732125,\n            0.0050174589,\n            -0.0149271907,\n            0.0332743451,\n            -0.0456852652,\n            0.0273381118,\n            -0.0733359158,\n            -0.0564510152,\n            0.0700533763,\n            0.0597335547,\n            0.0044274656,\n            -0.0125671430,\n            -0.0262580756,\n            0.0343977548,\n            -0.0360602848,\n            0.0456787720,\n            -0.0354144201,\n            0.0257959347,\n            -0.1437342465,\n            -0.0148974592,\n            0.0776901469,\n            0.0809415579,\n            -0.0005564869,\n            -0.0695207715,\n            0.0298865121,\n            0.0401907414,\n        ];\n        assert_approx_equal(&grad_data, &expected_grad, 1e-3);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/huber.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::{config::Config, module::Module};\n\nuse super::Reduction;\n\n/// Configuration to create a [Huber loss](HuberLoss).\n#[derive(Config, Debug)]\npub struct HuberLossConfig {\n    /// The bound where the Huber loss function changes from quadratic to linear behaviour.\n    pub delta: f32,\n}\n\nimpl HuberLossConfig {\n    /// Initialize [Huber loss](HuberLoss).\n    pub fn init(&self) -> HuberLoss {\n        self.assertions();\n        HuberLoss {\n            delta: self.delta,\n            lin_bias: self.delta * self.delta * 0.5,\n        }\n    }\n\n    fn assertions(&self) {\n        assert!(\n            self.delta >= 0., // This also tests for normality\n            \"Delta for Huber loss must be a non-negative number.\"\n        );\n    }\n}\n\n/// Calculate the Huber loss between the inputs and the target.\n///\n/// The loss for each element of the residuals `r = targets - predictions` is given by\n///\n/// ```text\n/// L(r) = 0.5 * r^2                  if |r| <= d\n/// L(r) = 0.5 * d^2 + d * (|r| - d)  if |r| >  d\n/// ```\n///\n/// where `d` is the configured `delta`. In particular, this is equal to the\n/// [L2 Loss](super::MseLoss) for residuals with magnitude smaller than `delta`,\n/// but behaves linearly instead of quadratically for large residuals.\n///\n/// This loss function is less sensitive to outliers than the mean squared error loss.\n///\n/// See also: <https://en.wikipedia.org/wiki/Huber_loss>\n#[derive(Module, Debug, Clone)]\n#[module(custom_display)]\npub struct HuberLoss {\n    /// The bound where the Huber loss function changes from quadratic to linear behaviour.\n    pub delta: f32,\n    /// Precomputed value for the linear bias.\n    pub lin_bias: f32, // delta * delta * 0.5 precomputed\n}\n\nimpl ModuleDisplay for HuberLoss {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"delta\", &self.delta)\n            .add(\"lin_bias\", &self.lin_bias)\n            .optional()\n    }\n}\n\nimpl HuberLoss {\n    /// Compute the loss element-wise for the predictions and targets, then reduce\n    /// to a single loss value.\n    ///\n    /// `Reduction::Auto` behaves as `Reduction::Mean`.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: \\[...dims\\]\n    /// - targets: \\[...dims\\]\n    /// - output: \\[1\\]\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let loss = self.forward_no_reduction(predictions, targets);\n        match reduction {\n            Reduction::Mean | Reduction::Auto => loss.mean(),\n            Reduction::Sum => loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n    /// Compute the loss element-wise for the predictions and targets.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: [...dims]\n    /// - targets: [...dims]\n    /// - output: [...dims]\n    pub fn forward_no_reduction<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let residuals = targets - predictions;\n        self.forward_residuals(residuals)\n    }\n    /// Compute the loss element-wise for the given residuals.\n    ///\n    /// # Shapes\n    ///\n    /// - residuals: [...dims]\n    /// - output: [...dims]\n    pub fn forward_residuals<const D: usize, B: Backend>(\n        &self,\n        residuals: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let is_large = residuals.clone().abs().greater_elem(self.delta);\n        // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the\n        // `sign()` function, in general, suffers from a jump at 0.\n        // Instead the following tensor implements `delta * sign(r)` for values outside\n        // the bound:\n        let softsign = residuals.clone().clamp(-self.delta, self.delta);\n\n        // 0.5 * d^2 + d * (|r| - d) =\n        // d * |r| - 0.5 * d^2\n        // Moreover |r| = sign(r) * r\n        let outside = softsign.mul(residuals.clone()).sub_scalar(self.lin_bias);\n\n        let inside = residuals.square().mul_scalar(0.5);\n        inside.mask_where(is_large, outside)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_huber_loss() {\n        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);\n        let targets = TensorData::from([0., 0., 0., 0., 0.]);\n\n        let device = Default::default();\n\n        let predict = TestTensor::<1>::from_data(predict, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let huber = HuberLossConfig::new(0.5).init();\n\n        let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum);\n        let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto);\n        let loss_no_reduction = huber.forward_no_reduction(predict, targets);\n\n        let expected = TensorData::from([0.875, 0.125, 0., 0.045, 0.375]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([0.284]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([1.42]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn test_huber_ad_loss() {\n        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;\n\n        let predict = TensorData::from([-2., -0.5, 0., 0.3, 1.]);\n        let targets = TensorData::from([0., 0., 0., 0., 0.]);\n\n        let device = Default::default();\n        let predict = TestAutodiffTensor::from_data(predict, &device).require_grad();\n        let targets = TestAutodiffTensor::from_data(targets, &device);\n\n        let loss = HuberLossConfig::new(0.5).init();\n        let loss = loss.forward_no_reduction(predict.clone(), targets);\n\n        let grads = loss.backward();\n        let grads_predict = predict.grad(&grads).unwrap();\n\n        let expected = TensorData::from([-0.5, -0.5, 0., 0.3, 0.5]);\n        grads_predict\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = HuberLossConfig::new(0.5);\n        let loss = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{loss}\"),\n            \"HuberLoss {delta: 0.5, lin_bias: 0.125}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/kldiv.rs",
    "content": "use burn_core as burn;\n\nuse super::Reduction;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::{config::Config, module::Module};\n\n/// Configuration to create a [KLDiv loss](KLDivLoss).\n#[derive(Config, Debug)]\npub struct KLDivLossConfig {\n    /// Specifies whether target is the log space. Default: False.\n    #[config(default = false)]\n    pub log_target: bool,\n}\n\nimpl KLDivLossConfig {\n    /// Initialize [KLDiv Loss](KLDivLoss).\n    pub fn init(&self) -> KLDivLoss {\n        KLDivLoss {\n            log_target: self.log_target,\n        }\n    }\n}\n\n/// Kullback-Leibler Divergence Loss\n///\n/// KL Divergence shows the difference between two probability distributions by measuring information loss\n///\n/// KLDivLoss =\n/// ```tex\n/// y_{true} \\cdot (\\log{y_{true}} - \\log{y_{pred}})\n///     ```\n/// By default, the loss expects the input in the log-space.\n/// The targets may also be provided in the log-space if `log_target` is true.\n///\n/// See\n/// - [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence)\n#[derive(Module, Debug, Clone)]\n#[module(custom_display)]\npub struct KLDivLoss {\n    /// Specifies whether target is the log space. Default: False.\n    pub log_target: bool,\n}\n\nimpl ModuleDisplay for KLDivLoss {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"log_target\", &self.log_target).optional()\n    }\n}\n\nimpl KLDivLoss {\n    /// Compute the criterion on the input tensor.\n    ///\n    /// `Reduction::Auto` behaves as `Reduction::BatchMean`,`Reduction::Mean` dose not align with the math definition.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: \\[batch_size,num_targets\\]\n    /// - targets: \\[batch_size,num_targets\\]\n    /// - output: \\[1\\]\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let loss = self.forward_no_reduction(predictions, targets);\n        match reduction {\n            Reduction::BatchMean | Reduction::Auto => {\n                let batch_size = loss.dims()[0] as f32;\n                loss.sum().div_scalar(batch_size)\n            }\n            Reduction::Mean => loss.mean(),\n            Reduction::Sum => loss.sum(),\n        }\n    }\n    /// Compute the criterion on the input tensor without reducing.\n    pub fn forward_no_reduction<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        match self.log_target {\n            true => targets.clone().exp().mul(targets.sub(predictions)),\n            false => {\n                let epsilon = 1e-8;\n                let log_target = targets.clone().clamp(epsilon, 1.0).log();\n                targets.mul(log_target.sub(predictions))\n            }\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_kl_div_loss() {\n        let predict = TensorData::from([[-1.0, -0.5], [-2.0, -0.2]]);\n        let targets = TensorData::from([[0.4, 0.6], [0.1, 0.9]]);\n\n        let device = Default::default();\n        let predict = TestTensor::<2>::from_data(predict, &device);\n        let targets = TestTensor::<2>::from_data(targets, &device);\n\n        let kl_loss = KLDivLossConfig { log_target: false }.init();\n\n        let loss_sum = kl_loss.forward(predict.clone(), targets.clone(), Reduction::Sum);\n        let loss_batch_mean =\n            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);\n        let loss_no_reduction = kl_loss.forward_no_reduction(predict, targets);\n\n        let expected_no_reduction =\n            TensorData::from([[0.0334837139, -0.0064953566], [-0.0302585065, 0.0851755068]]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_no_reduction, Tolerance::absolute(1e-5));\n\n        let expected_sum = TensorData::from([0.08191]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));\n\n        let expected_batch_mean = TensorData::from([0.04095]);\n        loss_batch_mean\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_batch_mean, Tolerance::absolute(1e-5));\n    }\n\n    #[test]\n    fn test_kl_div_loss_log_target() {\n        let device = Default::default();\n        let predict = TestTensor::<1>::from_data([-1.0, -2.0], &device);\n        let targets = TestTensor::<1>::from_data([-0.5, -1.5], &device);\n\n        let kl_loss = KLDivLossConfig { log_target: true }.init();\n\n        let loss_no_reduction = kl_loss.forward_no_reduction(predict.clone(), targets.clone());\n        let expected_none = TensorData::from([0.3032653299, 0.1115650801]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_none, Tolerance::absolute(1e-5));\n\n        let loss_batch_mean =\n            kl_loss.forward(predict.clone(), targets.clone(), Reduction::BatchMean);\n        let expected_bm = TensorData::from([0.207415204965]);\n        loss_batch_mean\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_bm, Tolerance::absolute(1e-5));\n\n        let loss_sum = kl_loss.forward(predict, targets, Reduction::Sum);\n        let expected_sum = TensorData::from([0.414830409931]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_sum, Tolerance::absolute(1e-5));\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn test_kl_div_ad_loss() {\n        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 2>;\n\n        let device = Default::default();\n        let predict = TestAutodiffTensor::from_data([[-1.0, -0.5]], &device).require_grad();\n        let targets = TestAutodiffTensor::from_data([[0.4, 0.6]], &device);\n\n        let kl_loss = KLDivLossConfig { log_target: false }.init();\n        let loss = kl_loss.forward(predict.clone(), targets, Reduction::Sum);\n\n        let grads = loss.backward();\n        let grads_predict = predict.grad(&grads).unwrap();\n\n        // d/d_pred [target * (log_target - pred)] = -target\n        let expected = TensorData::from([[-0.4, -0.6]]);\n        grads_predict\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = KLDivLossConfig { log_target: true };\n        let loss = config.init();\n\n        assert_eq!(alloc::format!(\"{loss}\"), \"KLDivLoss {log_target: true}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/lp_loss.rs",
    "content": "use super::Reduction;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::{Tensor, backend::Backend};\nuse burn_core as burn;\n\n/// Configuration for the [Lp Loss](LpLoss) module.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_nn::loss::{LpLossConfig, Reduction};\n///\n/// // Create L1 loss (MAE when using mean reduction)\n/// let l1_loss = LpLossConfig::l1();\n///\n/// // Create L2 loss (MSE when using mean reduction)\n/// let l2_loss = LpLossConfig::l2();\n///\n/// // Create custom Lp loss with p=3\n/// let l3_loss = LpLossConfig::new(3.0).init();\n/// ```\n#[derive(Config, Debug)]\npub struct LpLossConfig {\n    /// The exponent `p` determining the type of error measurement.\n    ///\n    /// Common values:\n    /// - `p = 1.0`: L1 loss (MAE with mean reduction) - robust to outliers\n    /// - `p = 2.0`: L2 loss (MSE with mean reduction) - standard choice, differentiable everywhere\n    /// - `p > 2.0`: Increasingly sensitive to large errors (outliers)\n    /// - `0 < p < 1`: More robust to outliers than L1 (quasi-norm)\n    pub p: f64,\n}\n\nimpl LpLossConfig {\n    /// Initializes a [Lp Loss](LpLoss) module.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `p <= 0`.\n    pub fn init(&self) -> LpLoss {\n        self.assertions();\n        LpLoss { p: self.p }\n    }\n\n    /// Creates L1 loss (p=1).\n    ///\n    /// When used with `Reduction::Mean`, this computes Mean Absolute Error (MAE).\n    /// When used with `Reduction::Sum`, this computes Sum of Absolute Errors (SAE).\n    pub fn l1() -> LpLoss {\n        LpLoss { p: 1.0 }\n    }\n\n    /// Creates L2 loss (p=2).\n    ///\n    /// When used with `Reduction::Mean`, this computes Mean Squared Error (MSE).\n    /// When used with `Reduction::Sum`, this computes Sum of Squared Errors (SSE).\n    pub fn l2() -> LpLoss {\n        LpLoss { p: 2.0 }\n    }\n\n    fn assertions(&self) {\n        assert!(self.p > 0.0, \"The order of the norm p must be positive.\")\n    }\n}\n\n/// Computes the Lp Loss between predictions and targets.\n///\n/// This loss function computes the element-wise p-th power of absolute errors,\n/// then reduces them via mean or sum.\n///\n/// # Mathematical Definition\n///\n/// For predictions `ŷ` and targets `y`, the element-wise loss is:\n///\n/// ```text\n/// Lᵢ = |ŷᵢ - yᵢ|ᵖ\n/// ```\n///\n/// With mean reduction (default), the final loss is:\n///\n/// ```text\n/// L = (1/n) × Σᵢ |ŷᵢ - yᵢ|ᵖ\n/// ```\n///\n/// # Notes\n///\n/// - This implementation computes `|error|^p`, **not** the Lp norm `(Σ|error|^p)^(1/p)`.\n/// - The `p = 1` case uses an optimized `abs()` operation.\n/// - The `p = 2` case uses an optimized computation `error * error` instead of `powf`.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_nn::loss::{LpLossConfig, Reduction};\n/// use burn::tensor::Tensor;\n///\n/// // Create L2 loss\n/// let l2_loss = LpLossConfig::l2();\n///\n/// let predictions: Tensor<Backend, 2> = /* model output */;\n/// let targets: Tensor<Backend, 2> = /* ground truth */;\n///\n/// // Compute loss with mean reduction (MSE)\n/// let mse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Mean);\n///\n/// // Compute loss with sum reduction (SSE)\n/// let sse = l2_loss.forward(predictions.clone(), targets.clone(), Reduction::Sum);\n///\n/// // Compute loss with no reduction\n/// let unreduced_l2_loss = l2_loss.forward_no_reduction(predictions, targets);\n/// ```\n#[derive(Module, Clone, Debug)]\npub struct LpLoss {\n    /// The order of the norm (e.g., 1 for L1, 2 for L2).\n    /// Equivalently, the exponent `p` for computing `|error|^p`.\n    pub p: f64,\n}\n\nimpl LpLoss {\n    /// Computes the element-wise loss `|error|^p` with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `predictions` - The model's predicted values.\n    /// * `targets` - The ground truth target values.\n    /// * `reduction` - Specifies how to reduce the element-wise losses:\n    ///   - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.\n    ///   - `Reduction::Sum`: Returns the sum of all element-wise losses.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor containing the reduced loss value.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[...dims]` - Any shape\n    /// - targets: `[...dims]` - Must match predictions shape\n    /// - output: `[1]` - Scalar loss value\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let unreduced_loss = self.forward_no_reduction(predictions, targets);\n\n        match reduction {\n            Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),\n            Reduction::Sum => unreduced_loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Computes the element-wise loss `|error|^p` without reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `predictions` - The model's predicted values.\n    /// * `targets` - The ground truth target values.\n    ///\n    /// # Returns\n    ///\n    /// A tensor of the same shape as the inputs, containing `|prediction - target|^p`\n    /// for each element.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[...dims]` - Any shape\n    /// - targets: `[...dims]` - Must match predictions shape\n    /// - output: `[...dims]` - Same shape as inputs\n    pub fn forward_no_reduction<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let error = predictions.sub(targets);\n\n        // Use simplified/optimized expressions for common cases (p = 1, p = 2)\n        if self.p == 1.0 {\n            // L1 loss\n            error.abs()\n        } else if self.p == 2.0 {\n            // L2 loss\n            error.clone().mul(error)\n        } else {\n            error.abs().powf_scalar(self.p)\n        }\n    }\n\n    /// Computes the element-wise loss `|error|^p` with reduction over specified dimensions.\n    ///\n    /// Calculates element-wise `|predictions - targets|^p`, then takes the mean\n    /// over the specified dimensions. Useful for per-sample or per-channel losses (e.g., when\n    /// working with images).\n    ///\n    /// Dimensions can be provided in any order. They are sorted internally and\n    /// reduced from highest to lowest to ensure indices remain valid.\n    ///\n    /// # Arguments\n    ///\n    /// * `predictions` - The model's predicted values.\n    /// * `targets` - The ground truth target values.\n    /// * `dims` - Dimensions to reduce over.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the specified dimensions reduced to size 1.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// // Image tensor: [batch, C, H, W]\n    /// let l2_loss = LpLossConfig::l2();\n    ///\n    /// // Per-image MSE for PSNR: reduce over C, H, W → [batch, 1, 1, 1]\n    /// let mse_per_image = l2_loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);\n    /// ```\n    pub fn forward_reduce_dims<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        dims: &[usize],\n    ) -> Tensor<B, D> {\n        let error = self.forward_no_reduction(predictions, targets);\n\n        // Sort the dimensions to ascending order\n        let mut sorted_dims = dims.to_vec();\n        sorted_dims.sort();\n\n        // Reduce over specified dimensions\n        error.mean_dims(sorted_dims.as_slice())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_lp_loss_l1_constructor() {\n        let loss_func_l1 = LpLossConfig::l1();\n        let loss_func_p1 = LpLossConfig::new(1.0).init();\n        assert_eq!(loss_func_l1.p, 1.0);\n        assert_eq!(loss_func_l1.p, loss_func_p1.p);\n    }\n\n    #[test]\n    fn test_lp_loss_l2_constructor() {\n        let loss_func_l2 = LpLossConfig::l2();\n        let loss_func_p2 = LpLossConfig::new(2.0).init();\n        assert_eq!(loss_func_l2.p, 2.0);\n        assert_eq!(loss_func_l2.p, loss_func_p2.p);\n    }\n\n    #[test]\n    fn test_lp_loss_l1() {\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),\n            &device,\n        );\n\n        let loss_func = LpLossConfig::l1();\n        let loss_no_reduction =\n            loss_func.forward_no_reduction(predictions.clone(), targets.clone());\n        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);\n\n        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([1.0]);\n        loss_auto.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([4.0]);\n        loss_sum.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_l2() {\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),\n            &device,\n        );\n\n        let loss_func = LpLossConfig::l2();\n        let loss_no_reduction =\n            loss_func.forward_no_reduction(predictions.clone(), targets.clone());\n        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);\n\n        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([1.5]);\n        loss_auto.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([6.0]);\n        loss_sum.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_p_half() {\n        // L0.5 quasi-norm: more robust to outliers than L1\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 0.0]]),\n            &device,\n        );\n\n        let loss_func = LpLossConfig::new(0.5).init();\n        let loss_no_reduction =\n            loss_func.forward_no_reduction(predictions.clone(), targets.clone());\n        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);\n\n        // |1-2|^0.5 = 1, |2-1|^0.5 = 1, |3-3|^0.5 = 0, |4-0|^0.5 = 2\n        let expected = TensorData::from([[1.0, 1.0], [0.0, 2.0]]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([1.0]);\n        loss_auto.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([4.0]);\n        loss_sum.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_p3() {\n        // L3 norm: more sensitive to outliers than L2\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),\n            &device,\n        );\n\n        let loss_func = LpLossConfig::new(3.0).init();\n        let loss_no_reduction =\n            loss_func.forward_no_reduction(predictions.clone(), targets.clone());\n        let loss_auto = loss_func.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_sum = loss_func.forward(predictions, targets, Reduction::Sum);\n\n        // |1-2|^3 = 1, |2-1|^3 = 1, |3-3|^3 = 0, |4-2|^3 = 8\n        let expected = TensorData::from([[1.0, 1.0], [0.0, 8.0]]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([2.5]);\n        loss_auto.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([10.0]);\n        loss_sum.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_zero_error() {\n        // Test when predictions exactly match targets\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = predictions.clone();\n\n        let loss_func_l1 = LpLossConfig::l1();\n        let loss_func_l2 = LpLossConfig::l2();\n\n        let l1_loss = loss_func_l1.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let l2_loss = loss_func_l2.forward(predictions, targets, Reduction::Auto);\n\n        let expected = TensorData::from([0.0]);\n        l1_loss.into_data().assert_eq(&expected, false);\n        l2_loss.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_negative_errors() {\n        // Test that negative errors are handled correctly (absolute value)\n        let device = Default::default();\n        let predictions =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 2.0, 3.0]), &device);\n        let targets =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([3.0, 4.0, 5.0]), &device);\n        let loss_func_l1 = LpLossConfig::l1();\n        let loss_func_p1 = LpLossConfig::new(1.0).init();\n\n        let loss_no_reduction_l1 =\n            loss_func_l1.forward_no_reduction(predictions.clone(), targets.clone());\n        let loss_no_reduction_p1 = loss_func_p1.forward_no_reduction(predictions, targets);\n\n        // All errors are negative: 1-3=-2, 2-4=-2, 3-5=-2, but |error| = 2\n        let expected = TensorData::from([2.0, 2.0, 2.0]);\n        loss_no_reduction_l1.into_data().assert_eq(&expected, false);\n        loss_no_reduction_p1.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_lp_loss_3d_tensor() {\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.0, 2.0], [3.0, 5.0]], [[4.0, 6.0], [7.0, 10.0]]]),\n            &device,\n        );\n        let loss_func_l2 = LpLossConfig::l2();\n        let loss_func_p2 = LpLossConfig::new(2.0).init();\n\n        let loss_l2 = loss_func_l2.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_p2 = loss_func_p2.forward(predictions, targets, Reduction::Auto);\n\n        // Errors: 1, 0, 0, -1, 1, 0, 0, -2\n        // Squared: 1, 0, 0, 1, 1, 0, 0, 4\n        // Mean: 7/8 = 0.875\n        let expected = TensorData::from([0.875]);\n        loss_l2.into_data().assert_eq(&expected, false);\n        loss_p2.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    #[should_panic(expected = \"The order of the norm p must be positive.\")]\n    fn test_lp_loss_negative_p_panics() {\n        let _ = LpLossConfig::new(-1.0).init();\n    }\n\n    #[test]\n    #[should_panic(expected = \"The order of the norm p must be positive.\")]\n    fn test_lp_loss_zero_p_panics() {\n        let _ = LpLossConfig::new(0.0).init();\n    }\n\n    #[test]\n    fn test_lp_loss_fractional_p() {\n        // Test p = 1.5\n        let device = Default::default();\n        let predictions =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.0, 4.0]), &device);\n\n        let targets = Tensor::<TestBackend, 1>::from_data(TensorData::from([1.0, 0.0]), &device);\n\n        let loss_func = LpLossConfig::new(1.5).init();\n        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);\n\n        // |0-1|^1.5 = 1, |4-0|^1.5 = 8\n        let expected = TensorData::from([1.0, 8.0]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_single_dim() {\n        let device = Default::default();\n        // Shape: [2, 3]\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),\n            &device,\n        );\n        let loss_func_l2 = LpLossConfig::l2();\n        let loss_func_p2 = LpLossConfig::new(2.0).init();\n\n        // Reduce over dim 1 -> should give [2, 1] shape\n        let loss_l2 = loss_func_l2.forward_reduce_dims(predictions.clone(), targets.clone(), &[1]);\n        let loss_p2 = loss_func_p2.forward_reduce_dims(predictions, targets, &[1]);\n\n        // Errors row 0: [1, 0, -3] -> squared: [1, 0, 9] -> mean: 10/3\n        // Errors row 1: [3, 0, 0] -> squared: [9, 0, 0] -> mean: 3\n        let expected = TensorData::from([[10.0 / 3.0], [3.0]]);\n        loss_l2\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n        loss_p2\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_first_dim() {\n        let device = Default::default();\n        // Shape: [2, 3]\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0, 2.0, 6.0], [1.0, 5.0, 6.0]]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l2();\n\n        // Reduce over dim 0 -> should give [1, 3] shape\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0]);\n\n        // Squared errors: [[1, 0, 9], [9, 0, 0]]\n        // Mean over dim 0: [5, 0, 4.5]\n        let expected = TensorData::from([[5.0, 0.0, 4.5]]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_multiple_dims() {\n        let device = Default::default();\n        // Shape: [2, 2, 2]\n        let predictions = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.0, 2.0], [3.0, 6.0]], [[4.0, 6.0], [7.0, 10.0]]]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l2();\n\n        // Reduce over dims 1 and 2 -> should give [2, 1, 1] shape\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);\n\n        // Batch 0 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25\n        // Batch 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 5/4 = 1.25\n        let expected = TensorData::from([[[1.25]], [[1.25]]]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_all_dims() {\n        let device = Default::default();\n        // Shape: [2, 2]\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l2();\n\n        // Reduce over all dims -> should give [1, 1] shape\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[0, 1]);\n\n        // Errors: [[-1, 1], [0, 2]] -> squared: [[1, 1], [0, 4]] -> mean: 1.5\n        let expected = TensorData::from([[1.5]]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_image_batch() {\n        // Simulate per-image loss for [batch, C, H, W] tensor (common use case for PSNR)\n        let device = Default::default();\n        // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)\n        let predictions = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[[1.0, 2.0], [3.0, 4.0]]], // Image 1\n                [[[5.0, 6.0], [7.0, 8.0]]], // Image 2\n            ]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[[0.0, 2.0], [3.0, 6.0]]], // Target 1\n                [[[5.0, 5.0], [7.0, 7.0]]], // Target 2\n            ]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l2();\n\n        // Reduce over C, H, W (dims 1, 2, 3) to get per-image MSE\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2, 3]);\n\n        // Image 1 errors: [[1, 0], [0, -2]] -> squared: [[1, 0], [0, 4]] -> mean: 1.25\n        // Image 2 errors: [[0, 1], [0, 1]] -> squared: [[0, 1], [0, 1]] -> mean: 0.5\n        let expected = TensorData::from([[[[1.25]]], [[[0.5]]]]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_with_p1() {\n        let device = Default::default();\n        // Shape: [2, 3]\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0, 5.0, 3.0], [1.0, 5.0, 9.0]]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l1();\n\n        // Reduce over dim 1 -> should give [2, 1] shape\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1]);\n\n        // Abs errors row 0: [1, 3, 0] -> mean: 4/3\n        // Abs errors row 1: [3, 0, 3] -> mean: 2\n        let expected = TensorData::from([[4.0 / 3.0], [2.0]]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_empty_dims() {\n        // Reducing over no dimensions should return the unreduced loss\n        let device = Default::default();\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0, 2.0], [3.0, 6.0]]),\n            &device,\n        );\n        let loss_func = LpLossConfig::l2();\n        let loss_reduce_dims =\n            loss_func.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);\n        let loss_no_reduction = loss_func.forward_no_reduction(predictions, targets);\n\n        // Should be equivalent\n        loss_reduce_dims\n            .into_data()\n            .assert_eq(&loss_no_reduction.into_data(), true);\n    }\n\n    #[test]\n    fn test_forward_reduce_dims_zero_error() {\n        let device = Default::default();\n        // Shape: [2, 2, 2]\n        let predictions = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),\n            &device,\n        );\n        let targets = predictions.clone();\n        let loss_func = LpLossConfig::l2();\n        let loss = loss_func.forward_reduce_dims(predictions, targets, &[1, 2]);\n\n        // All zeros, reduced to shape: [2, 1, 1]\n        let expected = TensorData::from([[[0.0]], [[0.0]]]);\n        loss.into_data().assert_eq(&expected, false);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/mod.rs",
    "content": "#[cfg(feature = \"pretrained\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"pretrained\")))]\nmod pretrained;\n#[cfg(feature = \"pretrained\")]\n#[cfg_attr(docsrs, doc(cfg(feature = \"pretrained\")))]\npub use pretrained::*;\n\nmod binary_cross_entropy;\nmod cosine_embedding;\nmod cross_entropy;\nmod ctc;\nmod huber;\nmod kldiv;\nmod lp_loss;\nmod mse;\nmod poisson;\nmod reduction;\nmod rnnt;\nmod smooth_l1;\n\npub use binary_cross_entropy::*;\npub use cosine_embedding::*;\npub use cross_entropy::*;\npub use ctc::*;\npub use huber::*;\npub use kldiv::*;\npub use lp_loss::*;\npub use mse::*;\npub use poisson::*;\npub use reduction::*;\npub use rnnt::*;\npub use smooth_l1::*;\n"
  },
  {
    "path": "crates/burn-nn/src/loss/mse.rs",
    "content": "use burn_core as burn;\n\nuse crate::loss::reduction::Reduction;\n\nuse burn::module::Module;\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Calculate the mean squared error loss from the input logits and the targets.\n#[derive(Module, Clone, Debug)]\npub struct MseLoss;\n\nimpl Default for MseLoss {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl MseLoss {\n    /// Create the criterion.\n    pub fn new() -> Self {\n        Self\n    }\n\n    /// Compute the criterion on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - logits: [batch_size, num_targets]\n    /// - targets: [batch_size, num_targets]\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        logits: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let tensor = self.forward_no_reduction(logits, targets);\n        match reduction {\n            Reduction::Mean | Reduction::Auto => tensor.mean(),\n            Reduction::Sum => tensor.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Compute the criterion on the input tensor without reducing.\n    pub fn forward_no_reduction<const D: usize, B: Backend>(\n        &self,\n        logits: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        logits.sub(targets).square()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_mse_loss() {\n        let device = Default::default();\n        let logits = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),\n            &device,\n        );\n\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),\n            &device,\n        );\n\n        let mse = MseLoss::new();\n        let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());\n        let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);\n        let loss_sum = mse.forward(logits, targets, Reduction::Sum);\n\n        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);\n        loss_no_reduction.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([1.5]);\n        loss.into_data().assert_eq(&expected, false);\n\n        let expected = TensorData::from([6.0]);\n        loss_sum.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn display() {\n        let loss = MseLoss::new();\n        assert_eq!(alloc::format!(\"{loss}\"), \"MseLoss\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/poisson.rs",
    "content": "use burn_core as burn;\nuse core::f32::consts::PI;\n\nuse burn::tensor::cast::ToElement;\n\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::{config::Config, module::Module};\n\nuse super::Reduction;\n\n/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.\n///\n/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss\n/// behavior, such as whether the input is in log-space, whether to include the Stirling\n/// approximation term, and a small epsilon value to avoid numerical instability.\n#[derive(Config, Debug)]\npub struct PoissonNllLossConfig {\n    /// If `true`, the predictions are expected to be in log-space.\n    ///\n    /// When `log_input` is `true`, the loss is computed as:\n    /// ```text\n    /// L(predictions, target) = exp(predictions) - target * predictions\n    /// ```\n    /// When `log_input` is `false`, the loss is computed as:\n    /// ```text\n    /// L(predictions, target) = predictions - target * log(predictions + eps)\n    /// ```\n    #[config(default = true)]\n    pub log_input: bool,\n    /// Whether to compute the full loss, including the Stirling approximation term.\n    ///\n    /// When `full` is `true`, the Stirling approximation term is added to the loss:\n    /// ```text\n    /// target * log(target) - target + 0.5 * log(2 * PI * target)\n    /// ```\n    #[config(default = false)]\n    pub full: bool,\n    /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.\n    ///\n    /// This epsilon value is added to the predictions to ensure numerical stability\n    /// when computing the logarithm.\n    #[config(default = 1e-8)]\n    pub eps: f64,\n}\n\nimpl PoissonNllLossConfig {\n    /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.\n    ///\n    /// # Panics\n    /// - Panics if `eps` is not a positive number.\n    pub fn init(&self) -> PoissonNllLoss {\n        self.assertions();\n        PoissonNllLoss {\n            log_input: self.log_input,\n            full: self.full,\n            eps: self.eps,\n        }\n    }\n\n    /// Validates the configuration parameters.\n    ///\n    /// # Panics\n    /// - Panics if `eps` is not a positive number.\n    fn assertions(&self) {\n        assert!(\n            self.eps > 0.,\n            \"eps for PoissonNllLoss must be a positive number.\"\n        );\n    }\n}\n\n/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target.\n///\n/// This loss function is used when the target values are assumed to follow a Poisson distribution.\n/// The loss is defined as:\n/// ```text\n/// target ~ Poisson(input)\n/// L(predictions, target) = predictions - target * log(predictions) + log(target!)\n/// ```\n/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula.\n/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss.\n///\n/// For more details, see:\n/// <https://en.wikipedia.org/wiki/Poisson_regression#Maximum_likelihood-based_parameter_estimation>\n#[derive(Module, Debug, Clone)]\n#[module(custom_display)]\npub struct PoissonNllLoss {\n    /// If `true`, the predictions are expected to be in log-space.\n    pub log_input: bool,\n    /// Whether to compute the full loss, including the Stirling approximation term.\n    pub full: bool,\n    /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`.\n    pub eps: f64,\n}\n\nimpl ModuleDisplay for PoissonNllLoss {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"log_input\", &self.log_input)\n            .add(\"full\", &self.full)\n            .add(\"eps\", &self.eps)\n            .optional()\n    }\n}\n\nimpl PoissonNllLoss {\n    /// Computes the loss element-wise for the given predictions and targets, then reduces\n    /// the result to a single loss value.\n    ///\n    /// # Arguments\n    /// - `predictions`: The predicted values.\n    /// - `targets`: The target values.\n    /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`.\n    ///\n    /// # Shapes\n    /// - `predictions`: `[...dims]`\n    /// - `targets`: `[...dims]`\n    /// - `output`: `[1]`\n    ///\n    /// # Panics\n    /// - Panics if the shapes of `predictions` and `targets` do not match.\n    /// - Panics if any target value is negative.\n    /// - Panics if `log_input` is `false` and any prediction value is negative.\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let loss = self.forward_no_reduction(predictions, targets);\n        match reduction {\n            Reduction::Mean | Reduction::Auto => loss.mean(),\n            Reduction::Sum => loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Computes the loss element-wise for the given predictions and targets without reduction.\n    ///\n    /// # Arguments\n    /// - `predictions`: The predicted values.\n    /// - `targets`: The target values.\n    ///\n    /// # Shapes\n    /// - `predictions`: `[...dims]`\n    /// - `targets`: `[...dims]`\n    /// - `output`: `[...dims]`\n    ///\n    /// # Panics\n    /// - Panics if the shapes of `predictions` and `targets` do not match.\n    /// - Panics if any target value is negative.\n    /// - Panics if `log_input` is `false` and any prediction value is negative.\n    pub fn forward_no_reduction<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        self.assertions(&predictions, &targets);\n        let mut loss;\n        if self.log_input {\n            loss = predictions.clone().exp() - targets.clone() * predictions;\n        } else {\n            loss = predictions.clone() - targets.clone() * (predictions + self.eps).log();\n        }\n        if self.full {\n            let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone()\n                + (targets.clone() * 2. * PI).log() * 0.5;\n            loss = loss\n                + log_stirling_term\n                    .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like());\n        }\n        loss\n    }\n\n    /// Validates the input tensors for the loss computation.\n    ///\n    /// # Panics\n    /// - Panics if the shapes of `predictions` and `targets` do not match.\n    /// - Panics if any target value is negative.\n    /// - Panics if `log_input` is `false` and any prediction value is negative.\n    fn assertions<const D: usize, B: Backend>(\n        &self,\n        predictions: &Tensor<B, D>,\n        targets: &Tensor<B, D>,\n    ) {\n        let predictions_dims = predictions.dims();\n        let targets_dims = targets.dims();\n        assert!(\n            predictions_dims == targets_dims,\n            \"Shape of targets ({targets_dims:?}) should correspond to outer shape of predictions ({predictions_dims:?}).\"\n        );\n        assert!(\n            targets\n                .clone()\n                .greater_equal_elem(0.)\n                .all()\n                .into_scalar()\n                .to_bool(),\n            \"All the values of `targets` must be non-negative.\"\n        );\n        if !self.log_input {\n            assert!(\n                predictions\n                    .clone()\n                    .greater_equal_elem(0.)\n                    .all()\n                    .into_scalar()\n                    .to_bool(),\n                \"When `log_input` is `false`, all the values of `predictions` must be non-negative.\"\n            );\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    #![allow(clippy::approx_constant)]\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_poisson_nll_loss() {\n        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);\n        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().init();\n\n        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);\n        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);\n\n        let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([21.0321]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([126.1929]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_poisson_nll_loss_no_log_input() {\n        let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]);\n        let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();\n\n        let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());\n\n        let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_poisson_nll_loss_full() {\n        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);\n        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().with_full(true).init();\n\n        let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);\n        let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n        let loss_no_reduction = poisson.forward_no_reduction(predictions, targets);\n\n        let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]);\n        loss_no_reduction\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([21.9920]);\n        loss.into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([131.9518]);\n        loss_sum\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn test_poisson_nll_loss_gradients() {\n        type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;\n\n        let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]);\n        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]);\n\n        let device = Default::default();\n\n        let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad();\n        let predictions2 = predictions1.clone();\n        let targets = TestAutodiffTensor::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().with_full(false).init();\n        let poisson_full = PoissonNllLossConfig::new().with_full(true).init();\n\n        let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);\n        let loss_full_sum =\n            poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum);\n\n        let grads = loss_sum.backward();\n        let grads_full = loss_full_sum.backward();\n\n        let grads_predictions1 = predictions1.grad(&grads).unwrap();\n        let grads_predictions2 = predictions2.grad(&grads_full).unwrap();\n\n        let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]);\n\n        grads_predictions1\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n        grads_predictions2\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    #[should_panic = \"eps for PoissonNllLoss must be a positive number.\"]\n    fn test_negative_eps() {\n        let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();\n    }\n\n    #[test]\n    #[should_panic = \"All the values of `targets` must be non-negative.\"]\n    fn test_targets_with_negative_values() {\n        let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]);\n        let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().init();\n\n        let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);\n    }\n\n    #[test]\n    #[should_panic = \"Shape of targets\"]\n    fn test_shape_tensors() {\n        let predictions = TensorData::from([0., 1., 2.]);\n        let targets = TensorData::from([0., 1.]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().init();\n\n        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());\n    }\n\n    #[test]\n    #[should_panic = \"When `log_input` is `false`, all the values of `predictions` must be non-negative.\"]\n    fn test_exp_predictions_non_negative() {\n        let predictions = TensorData::from([0.3, -0.1, 0.4]);\n        let targets = TensorData::from([0., 1., 0.]);\n\n        let device = Default::default();\n\n        let predictions = TestTensor::<1>::from_data(predictions, &device);\n        let targets = TestTensor::<1>::from_data(targets, &device);\n\n        let poisson = PoissonNllLossConfig::new().with_log_input(false).init();\n\n        let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());\n    }\n\n    #[test]\n    fn display() {\n        let config = PoissonNllLossConfig::new();\n        let loss = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{loss}\"),\n            \"PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/pretrained/gram_matrix/gram_matrix_loss.rs",
    "content": "use burn_core as burn;\n\nuse super::vgg19::Vgg19;\nuse super::weights::load_vgg19_weights;\nuse crate::loss::Reduction;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Configuration for the [Gram Matrix Loss](GramMatrixLoss) module.\n///\n/// Gram Matrix Loss (often used in Neural Style Transfer) measures the difference in\n/// texture or style between two images. It does this by comparing the spatial correlations\n/// of their feature maps extracted from a pretrained VGG19 network.\n///\n/// # Example\n///\n/// ```rust,ignore\n/// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig;\n///\n/// // Create Gram Matrix Loss with equal weights for all 5 layers\n/// let device = Default::default();\n/// let gram_loss = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0])\n///     .with_use_avg_pool(true)\n///     .init::<B>(&device);\n/// ```\n///\n/// # Reference\n/// [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)\n#[cfg_attr(docsrs, doc(cfg(feature = \"pretrained\")))]\n#[derive(Config, Debug)]\npub struct GramMatrixLossConfig {\n    /// The weights of the layer contributing to the total loss.\n    /// Should have a length of 5 since Gram Matrix Loss uses 5 specific VGG19 layers.\n    pub layer_weights: Vec<f32>,\n\n    /// If true, uses average pooling in the VGG19 feature extractor.\n    /// If false, uses the max pooling.\n    #[config(default = \"false\")]\n    pub use_avg_pool: bool,\n}\n\nimpl GramMatrixLossConfig {\n    /// Initializes a [Gram Matrix Loss](GramMatrixLoss) module.\n    ///\n    /// This will automatically download and load the pretrained VGG19 weights\n    /// if they are not already cached locally.\n    ///\n    /// # Panics\n    ///\n    /// - If `layer_weights` does not contain exactly 5 elements.\n    /// - If any of the weights in `layer_weights` is not non-negative.\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig;\n    ///\n    /// // Create Gram Matrix Loss with equal weights for all 5 layers\n    /// let device = Default::default();\n    /// let gram_loss = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0])\n    ///     .init::<B>(&device);\n    /// ```\n    pub fn init<B: Backend>(&self, device: &B::Device) -> GramMatrixLoss<B> {\n        self.assertions();\n\n        let vgg19 = Vgg19::new(self.use_avg_pool, device);\n        let pretrained_vgg19 = load_vgg19_weights(vgg19).no_grad();\n\n        GramMatrixLoss {\n            layer_weights: self.layer_weights.clone(),\n            feat_extractor: pretrained_vgg19,\n        }\n    }\n\n    fn assertions(&self) {\n        assert!(\n            self.layer_weights.len() == 5,\n            \"The layer_weights vector must contain exactly 5 elements\"\n        );\n        assert!(\n            self.layer_weights.iter().all(|&w| w >= 0.0),\n            \"All layer weights must be non-negative\"\n        );\n    }\n}\n\n/// Computes the Gram Matrix Loss between predictions and targets.\n///\n/// This loss function extracts features from 5 specific layers of a pretrained VGG19 network\n/// (`conv1_1`, `conv2_1`, `conv3_1`, `conv4_1`, `conv5_1`). It computes the Gram matrix for each\n/// layer's feature map, which captures the style/texture information, and calculates the\n/// Mean Squared Error between the Gram matrices of the predictions and targets.\n///\n/// # Note\n///\n/// The Gram Matrix Loss assumes the input tensors are already in the \\[0.0, 1.0\\] range.\n///\n/// # Example\n///\n/// ```rust,ignore\n/// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig;\n///\n/// // Initialize the loss function via its config\n/// let device = Default::default();\n/// // Uses max pool by default\n/// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<B>(&device);\n/// ```\n///\n/// # Reference\n/// [Image Style Transfer Using Convolutional Neural Networks](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf)\n#[cfg_attr(docsrs, doc(cfg(feature = \"pretrained\")))]\n#[derive(Module, Debug)]\npub struct GramMatrixLoss<B: Backend> {\n    /// The weights of the layer contributing to the total loss.\n    /// Should have a length of 5 since Gram Matrix Loss uses 5 layers.\n    pub layer_weights: Vec<f32>,\n    /// Pretrained VGG19 feature extractor\n    pub feat_extractor: Vgg19<B>,\n}\n\nimpl<B: Backend> GramMatrixLoss<B> {\n    /// Computes the Gram Matrix Loss with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// - `predictions` - The model's predicted images. The pixels should be in the \\[0.0, 1.0\\] range.\n    /// - `targets` - The ground truth target images. The pixels should be in the \\[0.0, 1.0\\] range.\n    /// - `reduction` - Specifies how to reduce the batch losses.\n    ///   - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of batch losses.\n    ///   - `Reduction::Sum`: Returns the sum of batch losses.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor containing the reduced loss value.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[batch_size, 3, height, width]`\n    /// - targets: `[batch_size, 3, height, width]`\n    /// - output: `[1]`\n    ///\n    /// # Panics\n    ///\n    /// - If the `reduction` type is not supported.\n    /// - If the input tensors do not have exactly 3 channels.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig;\n    /// use burn::loss::Reduction;\n    ///\n    /// let device = Default::default();\n    /// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<B>(&device);\n    ///\n    /// let predictions = /* [N, 3, H, W] */;\n    /// let targets = /* [N, 3, H, W] */;\n    ///\n    /// # Returns a tensor with shape [1] containing a single loss value\n    /// let loss = loss_fn.forward(predictions, targets, Reduction::Mean);\n    /// ```\n    pub fn forward(\n        &self,\n        predictions: Tensor<B, 4>,\n        targets: Tensor<B, 4>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let unreduced_loss = self.forward_no_reduction(predictions, targets);\n\n        match reduction {\n            Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),\n            Reduction::Sum => unreduced_loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Computes the unreduced Gram Matrix Loss per sample in the batch.\n    ///\n    /// # Arguments\n    ///\n    /// - `predictions` - The model's predicted images. The pixels should be in the \\[0.0, 1.0\\] range.\n    /// - `targets` - The ground truth target images. The pixels should be in the \\[0.0, 1.0\\] range.\n    ///\n    /// # Returns\n    ///\n    /// A 1D tensor containing the total weighted loss for each sample in the batch.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[batch_size, 3, height, width]`\n    /// - targets: `[batch_size, 3, height, width]`\n    /// - output: `[batch_size]`\n    ///\n    /// # Panics\n    ///\n    /// - If the input tensors do not have exactly 3 channels.\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// use burn_nn::loss::pretrained::gram_matrix::GramMatrixLossConfig;\n    ///\n    /// let device = Default::default();\n    /// let loss_fn = GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<B>(&device);\n    ///\n    /// let predictions = /* [N, 3, H, W] */;\n    /// let targets = /* [N, 3, H, W] */;\n    ///\n    /// // Returns a tensor of shape [N] containing the loss for each sample\n    /// let unreduced_loss = loss_fn.forward_no_reduction(predictions, targets);\n    /// ```\n    pub fn forward_no_reduction(\n        &self,\n        predictions: Tensor<B, 4>,\n        targets: Tensor<B, 4>,\n    ) -> Tensor<B, 1> {\n        let pred_processed = self.preprocess_input(predictions);\n        let target_processed = self.preprocess_input(targets);\n\n        // Both vectors contain 5 entries since there are 5 layers\n        // Both feature map tensors already have the shape [N, C, H * W]\n        let pred_features = self.feat_extractor.forward(pred_processed);\n        let mut pred_normalization_factors = Vec::with_capacity(5);\n        for feature_tensor in &pred_features {\n            let [_, c, h_times_w] = feature_tensor.dims();\n            let (c_f, hw_f) = (c as f32, h_times_w as f32);\n            pred_normalization_factors.push(4.0 * c_f * c_f * hw_f * hw_f);\n        }\n\n        let target_features = self.feat_extractor.forward(target_processed);\n\n        // Create vector which will hold loss tensors for each layer\n        let mut loss_tensors = Vec::with_capacity(pred_features.len());\n\n        // Compute and add the weighted loss for each layer to the final loss tensor.\n        // Note that the loss tensor for each layer and the final loss tensors\n        // contains a loss value for each sample in the batch.\n        for (pred_f, target_f) in pred_features.into_iter().zip(target_features) {\n            // Compute Gram matrix as G = F(F^T)\n            // [N, C, H*W] times [N, H*W, C] equals [N, C, C]\n            let pred_gram_matrices = pred_f.clone().matmul(pred_f.clone().transpose());\n            let target_gram_matrices = target_f.clone().matmul(target_f.clone().transpose());\n\n            let gram_matrices_diff = pred_gram_matrices - target_gram_matrices;\n            let gram_matrices_diff_squared = gram_matrices_diff.powi_scalar(2);\n\n            // For each sample, sum over all the entries of the gram matrix.\n            // Equivalently, sum over the last two dimensions (the two C dimensions).\n            let loss = gram_matrices_diff_squared\n                .sum_dims(&[1, 2])\n                .squeeze_dims::<1>(&[1, 2]);\n            loss_tensors.push(loss);\n        }\n\n        // Sum each layer's loss in the vector of loss tensors\n        let scaled_loss_tensors: Vec<Tensor<B, 1>> = loss_tensors\n            .into_iter()\n            .zip(pred_normalization_factors)\n            .zip(self.layer_weights.clone())\n            .map(|((loss_tensor, norm_factor), weight)| {\n                loss_tensor.div_scalar(norm_factor).mul_scalar(weight)\n            })\n            .collect();\n        let stacked_loss_tensors = Tensor::stack::<2>(scaled_loss_tensors, 1);\n        stacked_loss_tensors.sum_dim(1).squeeze_dim(1)\n    }\n\n    /// Applies standard ImageNet normalization to the input tensor for the VGG19 network.\n    ///\n    /// # Note\n    ///\n    /// This method assumes the input tensor is already in the \\[0.0, 1.0\\] range.\n    ///\n    /// # Panics\n    ///\n    /// - If the input tensor does not have exactly 3 channels.\n    fn preprocess_input(&self, tensor: Tensor<B, 4>) -> Tensor<B, 4> {\n        let device = &tensor.device();\n        let channels = tensor.dims()[1];\n        assert!(\n            channels == 3,\n            \"Expected input tensor to have exactly 3 channels, but got {}\",\n            channels\n        );\n\n        // ImageNet normalization constants\n        let mean = Tensor::<B, 1>::from_floats([0.485, 0.456, 0.406], device).reshape([1, 3, 1, 1]);\n        let std = Tensor::<B, 1>::from_floats([0.229, 0.224, 0.225], device).reshape([1, 3, 1, 1]);\n\n        (tensor - mean) / std\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::Distribution;\n\n    #[test]\n    #[should_panic(expected = \"The layer_weights vector must contain exactly 5 elements\")]\n    fn test_gram_matrix_loss_config_invalid_length() {\n        let device = Default::default();\n        GramMatrixLossConfig::new(vec![1.0, 1.0]).init::<TestBackend>(&device);\n    }\n\n    #[test]\n    #[should_panic(expected = \"All layer weights must be non-negative\")]\n    fn test_gram_matrix_loss_config_negative_weights() {\n        let device = Default::default();\n        GramMatrixLossConfig::new(vec![1.0, -1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_config_valid_weights() {\n        let device = Default::default();\n        let layer_weights = vec![0.0, 0.2, 0.2, 0.25, 0.4];\n        let loss_fn = GramMatrixLossConfig::new(layer_weights.clone()).init::<TestBackend>(&device);\n        assert_eq!(\n            loss_fn.layer_weights, layer_weights,\n            \"Expected layer weights vector {:?}, got {:?}\",\n            loss_fn.layer_weights, layer_weights\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"Expected input tensor to have exactly 3 channels, but got 1\")]\n    fn test_gram_matrix_loss_1_channel_panic() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        // 1 channel (Grayscale) should panic\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([2, 1, 16, 16], Distribution::Default, &device);\n        let tensor2 = tensor1.clone();\n\n        let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Expected input tensor to have exactly 3 channels, but got 4\")]\n    fn test_gram_matrix_loss_4_channel_panic() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        // 4 channels (e.g., RGBA) should panic\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([2, 4, 16, 16], Distribution::Default, &device);\n        let tensor2 = tensor1.clone();\n\n        let _ = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_zero_for_identical_inputs() {\n        let device = Default::default();\n\n        // Instantiate using Vgg19::new() to use random weights\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([2, 3, 16, 16], Distribution::Default, &device);\n        let tensor2 = tensor1.clone();\n\n        let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n        let loss_val: f32 = loss.into_scalar();\n\n        // Loss should be exactly 0 (or extremely close due to floating point) when inputs are identical\n        assert!(\n            loss_val.abs() < 1e-4,\n            \"Loss should be zero for identical inputs\"\n        );\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_greater_than_zero_for_different_inputs() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        let tensor1: Tensor<TestBackend, 4> = Tensor::ones([2, 3, 16, 16], &device);\n        let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);\n\n        let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n        let loss_val: f32 = loss.into_scalar();\n\n        assert!(\n            loss_val > 0.0,\n            \"Loss should be positive for different inputs\"\n        );\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_forward_no_reduction_shape() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        let batch_size = 4;\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);\n        let tensor2: Tensor<TestBackend, 4> =\n            Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);\n\n        let unreduced_loss = loss_fn.forward_no_reduction(tensor1, tensor2);\n\n        // Unreduced loss should return a 1D tensor with shape [batch_size]\n        assert_eq!(unreduced_loss.dims(), [batch_size]);\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_reduction_sum_vs_mean() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::new(false, &device),\n        };\n\n        let batch_size = 4;\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);\n        let tensor2: Tensor<TestBackend, 4> =\n            Tensor::random([batch_size, 3, 16, 16], Distribution::Default, &device);\n\n        let loss_mean: f32 = loss_fn\n            .forward(tensor1.clone(), tensor2.clone(), Reduction::Mean)\n            .into_scalar();\n        let loss_sum: f32 = loss_fn\n            .forward(tensor1, tensor2, Reduction::Sum)\n            .into_scalar();\n\n        let expected_sum = loss_mean * (batch_size as f32);\n        let diff = (loss_sum - expected_sum).abs();\n\n        // The sum reduction should be equal to the mean reduction multiplied by the batch size\n        assert!(\n            diff < 1e-4,\n            \"Sum reduction should equal batch_size * Mean reduction\"\n        );\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_with_avg_pool() {\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            // Initialize with use_avg_pool = true\n            feat_extractor: Vgg19::new(true, &device),\n        };\n\n        let batch_size = 4;\n        let tensor1: Tensor<TestBackend, 4> = Tensor::ones([batch_size, 3, 16, 16], &device);\n        let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([batch_size, 3, 16, 16], &device);\n\n        let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n        let loss_val: f32 = loss.into_scalar();\n\n        assert!(\n            loss_val > 0.0,\n            \"Loss should be positive for different inputs using avg pooling\"\n        );\n    }\n\n    #[test]\n    fn test_gram_matrix_loss_autodiff() {\n        use crate::TestAutodiffBackend;\n\n        let device = Default::default();\n        let loss_fn = GramMatrixLoss {\n            layer_weights: vec![1.0, 1.0, 1.0, 1.0, 1.0],\n            feat_extractor: Vgg19::<TestAutodiffBackend>::new(false, &device).no_grad(),\n        };\n\n        // The prediction tensor requires gradients\n        let predictions: Tensor<TestAutodiffBackend, 4> =\n            Tensor::ones([2, 3, 16, 16], &device).require_grad();\n\n        // The target tensor does not require gradients\n        let targets: Tensor<TestAutodiffBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);\n\n        let loss = loss_fn.forward(predictions.clone(), targets, Reduction::Mean);\n        let grads = loss.backward();\n\n        // Verify that gradients were successfully computed for the predictions tensor\n        let pred_grad = predictions.grad(&grads);\n        assert!(\n            pred_grad.is_some(),\n            \"Gradients should be computed for the predictions tensor\"\n        );\n\n        // Verify that VGG19 parameters do not have gradients\n        let conv1_1_weight_grad = loss_fn.feat_extractor.conv1_1.weight.val().grad(&grads);\n        assert!(\n            conv1_1_weight_grad.is_none(),\n            \"Gradients should not be computed for VGG19 parameters\"\n        );\n    }\n\n    #[test]\n    #[cfg(feature = \"test-local\")]\n    fn test_gram_matrix_loss_pretrained_weights_identical_inputs() {\n        let device = Default::default();\n        let loss_fn =\n            GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);\n\n        let tensor1: Tensor<TestBackend, 4> =\n            Tensor::random([2, 3, 16, 16], Distribution::Default, &device);\n        let tensor2 = tensor1.clone();\n\n        let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n        let loss_val: f32 = loss.into_scalar();\n\n        // Loss should be exactly 0 (or extremely close due to floating point) when inputs are identical\n        assert!(\n            loss_val.abs() < 1e-4,\n            \"Loss should be zero for identical inputs\"\n        );\n    }\n\n    #[test]\n    #[cfg(feature = \"test-local\")]\n    fn test_gram_matrix_loss_pretrained_weights_different_inputs() {\n        let device = Default::default();\n        let loss_fn =\n            GramMatrixLossConfig::new(vec![1.0, 1.0, 1.0, 1.0, 1.0]).init::<TestBackend>(&device);\n\n        let tensor1: Tensor<TestBackend, 4> = Tensor::ones([2, 3, 16, 16], &device);\n        let tensor2: Tensor<TestBackend, 4> = Tensor::zeros([2, 3, 16, 16], &device);\n\n        let loss = loss_fn.forward(tensor1, tensor2, Reduction::Mean);\n        let loss_val: f32 = loss.into_scalar();\n\n        assert!(\n            loss_val > 0.0,\n            \"Loss should be positive for different inputs\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/pretrained/gram_matrix/mod.rs",
    "content": "mod gram_matrix_loss;\nmod vgg19;\nmod weights;\n\npub use gram_matrix_loss::*;\n"
  },
  {
    "path": "crates/burn-nn/src/loss/pretrained/gram_matrix/vgg19.rs",
    "content": "use burn_core as burn;\n\nuse crate::PaddingConfig2d;\nuse crate::conv::{Conv2d, Conv2dConfig};\nuse burn::module::Module;\nuse burn::tensor::{\n    Tensor,\n    activation::relu,\n    backend::Backend,\n    module::{avg_pool2d, max_pool2d},\n};\n\n/// VGG19 feature extractor for the Gram Matrix Loss.\n///\n/// This module implements the VGG19 architecture up to the 5th convolutional block.\n/// It is specifically tailored for Neural Style Transfer and Gram Matrix Loss,\n/// extracting and flattening features from the following 5 layers:\n/// - `conv1_1`\n/// - `conv2_1`\n/// - `conv3_1`\n/// - `conv4_1`\n/// - `conv5_1`\n#[derive(Module, Debug)]\npub struct Vgg19<B: Backend> {\n    use_avg_pool: bool,\n\n    // Block 1\n    // Field is made public for testing whether the weights are frozen or not\n    pub conv1_1: Conv2d<B>,\n    conv1_2: Conv2d<B>,\n\n    // Block 2\n    conv2_1: Conv2d<B>,\n    conv2_2: Conv2d<B>,\n\n    // Block 3\n    conv3_1: Conv2d<B>,\n    conv3_2: Conv2d<B>,\n    conv3_3: Conv2d<B>,\n    conv3_4: Conv2d<B>,\n\n    // Block 4\n    conv4_1: Conv2d<B>,\n    conv4_2: Conv2d<B>,\n    conv4_3: Conv2d<B>,\n    conv4_4: Conv2d<B>,\n\n    // Block 5\n    conv5_1: Conv2d<B>,\n}\n\nimpl<B: Backend> Vgg19<B> {\n    /// Creates a new VGG19 feature extractor.\n    ///\n    /// The network is initialized with standard VGG19 configurations (3x3 kernels,\n    /// stride 1, padding 1). Note that the weights are randomly initialized here so\n    /// they should be overwritten by `load_vgg19_weights` before use.\n    pub fn new(use_avg_pool: bool, device: &B::Device) -> Self {\n        // All convolutions use a kernel size of 3 by 3, stride of 1, and\n        // padding of 1.\n        // This combination of kernel size and padding preserves input\n        // dimensions. Thus, `PaddingConfig2d::Same` can be used instead.\n        let conv_config = |in_ch, out_ch| {\n            Conv2dConfig::new([in_ch, out_ch], [3, 3])\n                .with_stride([1, 1])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device)\n        };\n\n        Self {\n            use_avg_pool,\n            // Block 1\n            conv1_1: conv_config(3, 64),\n            conv1_2: conv_config(64, 64),\n            // Block 2\n            conv2_1: conv_config(64, 128),\n            conv2_2: conv_config(128, 128),\n            // Block 3\n            conv3_1: conv_config(128, 256),\n            conv3_2: conv_config(256, 256),\n            conv3_3: conv_config(256, 256),\n            conv3_4: conv_config(256, 256),\n            // Block 4\n            conv4_1: conv_config(256, 512),\n            conv4_2: conv_config(512, 512),\n            conv4_3: conv_config(512, 512),\n            conv4_4: conv_config(512, 512),\n            // Block 5\n            conv5_1: conv_config(512, 512),\n        }\n    }\n\n    /// Performs a forward pass to extract features for the Gram Matrix Loss.\n    ///\n    /// # Arguments\n    ///\n    /// - `x` - Input image tensor of shape `[batch_size, 3, height, width]`.\n    ///\n    /// # Returns\n    ///\n    /// A tuple containing:\n    /// - `features`: A `Vec` of 5 tensors, each representing the flattened feature map\n    ///    from one of the target layers. Shape of each tensor: `[batch_size, channels, height * width]`.\n    /// - `normalization_factors`: A `Vec` of 5 `f32` values, representing the normalization\n    ///    factor `4 * N^2 * M^2` for each layer, used to scale the Gram matrix loss.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 3>> {\n        let pool_2d = |x| {\n            if self.use_avg_pool {\n                avg_pool2d(x, [2, 2], [2, 2], [0, 0], false, false)\n            } else {\n                max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], false)\n            }\n        };\n\n        let mut features = Vec::with_capacity(5);\n\n        // Block 1\n        let x1_1 = relu(self.conv1_1.forward(x));\n        let flattened_x1_1 = x1_1.clone().flatten(2, 3);\n        features.push(flattened_x1_1);\n        let x1_2 = relu(self.conv1_2.forward(x1_1));\n        let x1 = pool_2d(x1_2);\n\n        // Block 2\n        let x2_1 = relu(self.conv2_1.forward(x1));\n        let flattened_x2_1 = x2_1.clone().flatten(2, 3);\n        features.push(flattened_x2_1);\n        let x2_2 = relu(self.conv2_2.forward(x2_1));\n        let x2 = pool_2d(x2_2);\n\n        // Block 3\n        let x3_1 = relu(self.conv3_1.forward(x2));\n        let flattened_x3_1 = x3_1.clone().flatten(2, 3);\n        features.push(flattened_x3_1);\n        let x3_2 = relu(self.conv3_2.forward(x3_1));\n        let x3_3 = relu(self.conv3_3.forward(x3_2));\n        let x3_4 = relu(self.conv3_4.forward(x3_3));\n        let x3 = pool_2d(x3_4);\n\n        // Block 4\n        let x4_1 = relu(self.conv4_1.forward(x3));\n        let flattened_x4_1 = x4_1.clone().flatten(2, 3);\n        features.push(flattened_x4_1);\n        let x4_2 = relu(self.conv4_2.forward(x4_1));\n        let x4_3 = relu(self.conv4_3.forward(x4_2));\n        let x4_4 = relu(self.conv4_4.forward(x4_3));\n        let x4 = pool_2d(x4_4);\n\n        // Block 5\n        let x5_1 = relu(self.conv5_1.forward(x4));\n        let flattened_x5_1 = x5_1.flatten(2, 3);\n        features.push(flattened_x5_1);\n\n        features\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/pretrained/gram_matrix/weights.rs",
    "content": "use burn_core as burn;\n\nuse super::vgg19::Vgg19;\nuse burn::tensor::backend::Backend;\nuse burn_std::network::downloader::download_file_as_bytes;\nuse burn_store::{ModuleSnapshot, PytorchStore};\nuse std::fs::{File, create_dir_all, rename};\nuse std::io::Write;\nuse std::path::PathBuf;\n\nconst VGG19_URL: &str = \"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth\";\n\n/// Resolves and returns the local cache directory for the VGG19 weights.\n///\n/// Creates the directory `~/.cache/burn-pretrained-models/loss/vgg19/`\n/// (or OS equivalent) if it does not already exist.\nfn get_cache_dir() -> PathBuf {\n    let cache_dir = dirs::cache_dir()\n        .expect(\"Failed to get cache directory for Gram Matrix Loss\")\n        .join(\"burn-pretrained-models\")\n        .join(\"loss\")\n        .join(\"vgg19\");\n\n    if !cache_dir.exists() {\n        create_dir_all(&cache_dir).expect(\"Failed to create cache directory for Gram Matrix Loss\");\n    }\n\n    cache_dir\n}\n\n/// Downloads the pretrained weights to the `cache_path` if they don't exist already.\n///\n/// Requires an active internet connection on the first run. Subsequent runs will\n/// use the locally cached `.pth` file.\nfn download_weights_if_not_saved(cache_path: &PathBuf) {\n    if !cache_path.exists() {\n        let bytes = download_file_as_bytes(\n            VGG19_URL,\n            \"Downloading VGG19 ImageNet weights for Gram Matrix Loss...\",\n        );\n\n        // Write to a temporary file. If writing gets completed, then rename to the actual/correct name.\n        // If writing is not completed, the file with the correct name (i.e. `cache_path`) will not exist\n        // so this code block can run again which is the desired behavior.\n        let temp_path = cache_path.with_extension(\"pth.tmp\");\n        let mut file = File::create(&temp_path)\n            .expect(\"Failed to create VGG19 cache file for Gram Matrix Loss\");\n        file.write_all(&bytes)\n            .expect(\"Failed to write VGG19 weights to the cache file for Gram Matrix Loss\");\n\n        rename(temp_path, cache_path)\n            .expect(\"Failed to rename temporary file to the actual VGG19 cache file name for Gram Matrix Loss\");\n    }\n}\n\n/// Loads ImageNet pretrained weights into the provided VGG19 feature extractor.\n///\n/// This function downloads the official PyTorch VGG19 weights, remaps the keys\n/// from PyTorch's `features.X` format to Burn's `convX_Y` format, and loads\n/// them into the module.\n///\n/// # Arguments\n///\n/// - `vgg19` - An initialized VGG19 module with random weights.\n///\n/// # Returns\n///\n/// The VGG19 module with pretrained ImageNet weights loaded.\npub fn load_vgg19_weights<B: Backend>(mut vgg19: Vgg19<B>) -> Vgg19<B> {\n    let cache_dir = get_cache_dir();\n    let cache_path = cache_dir.join(\"vgg19.pth\");\n    download_weights_if_not_saved(&cache_path);\n\n    // Download the pretrained weights from PyTorch\n    let mut store = PytorchStore::from_file(cache_path)\n        .allow_partial(true)\n        // Block 1\n        .with_key_remapping(r\"^features\\.0\\.\", \"conv1_1.\")\n        .with_key_remapping(r\"^features\\.2\\.\", \"conv1_2.\")\n        // Block 2\n        .with_key_remapping(r\"^features\\.5\\.\", \"conv2_1.\")\n        .with_key_remapping(r\"^features\\.7\\.\", \"conv2_2.\")\n        // Block 3\n        .with_key_remapping(r\"^features\\.10\\.\", \"conv3_1.\")\n        .with_key_remapping(r\"^features\\.12\\.\", \"conv3_2.\")\n        .with_key_remapping(r\"^features\\.14\\.\", \"conv3_3.\")\n        .with_key_remapping(r\"^features\\.16\\.\", \"conv3_4.\")\n        // Block 4\n        .with_key_remapping(r\"^features\\.19\\.\", \"conv4_1.\")\n        .with_key_remapping(r\"^features\\.21\\.\", \"conv4_2.\")\n        .with_key_remapping(r\"^features\\.23\\.\", \"conv4_3.\")\n        .with_key_remapping(r\"^features\\.25\\.\", \"conv4_4.\")\n        // Block 5\n        .with_key_remapping(r\"^features\\.28\\.\", \"conv5_1.\");\n\n    let result = vgg19.load_from(&mut store);\n    if let Err(e) = result {\n        eprintln!(\"Warning: Some VGG19 weights could not be loaded: {:?}\", e);\n    }\n\n    vgg19\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/pretrained/mod.rs",
    "content": "mod gram_matrix;\n\npub use gram_matrix::*;\n"
  },
  {
    "path": "crates/burn-nn/src/loss/reduction.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\n\n/// The reduction type for the loss.\n#[derive(Config, Debug)]\npub enum Reduction {\n    /// The mean of the losses will be returned.\n    Mean,\n\n    /// The sum of the losses will be returned.\n    Sum,\n\n    /// The sum of the losses divided by the batch_size will be returned.\n    BatchMean,\n\n    /// The mean of the losses will be returned.\n    Auto,\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/rnnt.rs",
    "content": "use super::Reduction;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::{Bool, Int, Tensor, backend::Backend, s};\nuse burn_core as burn;\nuse core::f32;\n\n/// Configuration for [RNNTLoss](RNNTLoss).\n#[derive(Config, Debug)]\npub struct RNNTLossConfig {\n    /// Index of the blank label in the vocabulary. Default: `0`.\n    #[config(default = 0)]\n    pub blank: usize,\n    /// Treat the inputs as logits, applying a log-softmax on the last dimension internally.\n    /// If `false`, the input must already be log-probabilities. Default: `true`.\n    #[config(default = true)]\n    pub logits: bool,\n}\n\nimpl RNNTLossConfig {\n    /// Initializes a [RNNTLoss](RNNTLoss) module.\n    pub fn init(&self) -> RNNTLoss {\n        RNNTLoss {\n            blank: self.blank,\n            logits: self.logits,\n        }\n    }\n}\n\n/// RNN Transducer (RNNT) loss, as described in\n/// [Sequence Transduction with Recurrent Neural Networks](https://arxiv.org/abs/1211.3711).\n///\n/// Computes the negative log-likelihood over a 2D lattice of encoder time steps (T)\n/// and output labels (U), marginalizing over all valid alignments.\n///\n/// # Example\n///\n/// ```rust,ignore\n/// let rnnt = RNNTLossConfig::new().init();\n///\n/// // logits: [B, T, U+1, V] from the joiner network\n/// let loss = rnnt.forward(logits, targets, logit_lengths, target_lengths);\n/// ```\n#[derive(Module, Clone, Debug)]\npub struct RNNTLoss {\n    blank: usize,\n    logits: bool,\n}\n\nimpl RNNTLoss {\n    /// Computes per-sample RNNT loss (no reduction). Returns shape `[B]`.\n    ///\n    /// - `logits`: `[B, T, U+1, V]` — joiner output (raw logits or log-probs)\n    /// - `targets`: `[B, U]` — target label indices (must not contain blank)\n    /// - `logit_lengths`: `[B]` — actual encoder lengths per sample\n    /// - `target_lengths`: `[B]` — actual target lengths per sample\n    pub fn forward<B: Backend>(\n        &self,\n        logits: Tensor<B, 4>,\n        targets: Tensor<B, 2, Int>,\n        logit_lengths: Tensor<B, 1, Int>,\n        target_lengths: Tensor<B, 1, Int>,\n    ) -> Tensor<B, 1> {\n        let device = logits.device();\n        let [b, max_t, max_up1, v] = logits.dims();\n        let max_u = max_up1 - 1;\n\n        self.check_inputs(b, v, &targets, &logit_lengths, &target_lengths, max_u);\n\n        let log_probs = if self.logits {\n            let vocab_dim = 3; // last dim of [B, T, U+1, V]\n            burn::tensor::activation::log_softmax(logits, vocab_dim)\n        } else {\n            logits\n        };\n\n        let (lpb, lpl) = self.extract_log_probs(log_probs, targets);\n        let u_mask = self.create_u_mask(&target_lengths, b, max_up1, &device);\n        let neg_inf = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);\n\n        // Forward pass: compute log_alpha across the (T, U) lattice\n        let mut alpha = self.init_alpha(&lpl, b, max_up1, &device);\n        alpha = neg_inf.clone().mask_where(u_mask.clone(), alpha);\n\n        let logit_lengths_exp = logit_lengths.clone().reshape([b, 1]).expand([b, max_up1]);\n\n        for t in 1..max_t {\n            let new = self.step_alpha(&alpha, &lpb, &lpl, t);\n            let new = neg_inf.clone().mask_where(u_mask.clone(), new);\n\n            // Only update alpha for samples where t < logit_lengths[b]\n            let valid = logit_lengths_exp.clone().greater_elem(t as i64);\n            alpha = alpha.mask_where(valid, new);\n        }\n\n        self.gather_loss(alpha, &lpb, logit_lengths, target_lengths, b, max_up1)\n    }\n\n    /// Computes RNNT loss with the given reduction. Returns shape `[1]`.\n    pub fn forward_with_reduction<B: Backend>(\n        &self,\n        logits: Tensor<B, 4>,\n        targets: Tensor<B, 2, Int>,\n        logit_lengths: Tensor<B, 1, Int>,\n        target_lengths: Tensor<B, 1, Int>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let loss = self.forward(logits, targets, logit_lengths, target_lengths);\n        match reduction {\n            Reduction::Auto | Reduction::Mean => loss.mean(),\n            Reduction::Sum => loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Gathers `log_prob_blank[B, T, U+1]` and `log_prob_label[B, T, U]` from the full\n    /// log-probability tensor by indexing into the vocab dimension.\n    fn extract_log_probs<B: Backend>(\n        &self,\n        log_probs: Tensor<B, 4>,\n        targets: Tensor<B, 2, Int>,\n    ) -> (Tensor<B, 3>, Tensor<B, 3>) {\n        let [b, max_t, max_up1, v] = log_probs.dims();\n        let max_u = max_up1 - 1;\n        let device = log_probs.device();\n        let vocab_dim = 3;\n\n        // Blank probabilities: gather blank index across vocab dim\n        let blank_idx =\n            Tensor::<B, 4, Int>::full([b, max_t, max_up1, 1], self.blank as i64, &device);\n        let lpb = log_probs\n            .clone()\n            .gather(vocab_dim, blank_idx)\n            .squeeze_dim::<3>(vocab_dim);\n\n        // Label probabilities: gather target labels across vocab dim (only first U positions)\n        let tgt = targets\n            .reshape([b, 1, max_u, 1])\n            .expand([b, max_t, max_u, 1]);\n        let lpl = log_probs\n            .slice(s![.., .., 0..max_u, 0..v])\n            .gather(vocab_dim, tgt)\n            .squeeze_dim::<3>(vocab_dim);\n\n        (lpb, lpl)\n    }\n\n    /// Sets up log_alpha at t=0: `alpha(0,0) = 0`, then cumsum of label probs along u.\n    fn init_alpha<B: Backend>(\n        &self,\n        lpl: &Tensor<B, 3>,\n        b: usize,\n        max_up1: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 2> {\n        let mut alpha = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, device);\n        alpha = alpha.slice_assign(s![.., 0..1], Tensor::zeros([b, 1], device));\n\n        // Label probs at t=0\n        let lpl_0 = lpl.clone().slice(s![.., 0..1, ..]).squeeze_dim::<2>(1);\n        for u in 1..max_up1 {\n            let prev = alpha.clone().slice(s![.., (u - 1)..u]);\n            let lp = lpl_0.clone().slice(s![.., (u - 1)..u]);\n            alpha = alpha.slice_assign(s![.., u..(u + 1)], prev.add(lp));\n        }\n        alpha\n    }\n\n    /// Boolean mask `[B, U+1]` that is true where `u <= target_lengths[b]`.\n    fn create_u_mask<B: Backend>(\n        &self,\n        target_lengths: &Tensor<B, 1, Int>,\n        b: usize,\n        max_up1: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 2, Bool> {\n        let indices = Tensor::<B, 1, Int>::arange(0..max_up1 as i64, device)\n            .reshape([1, max_up1])\n            .expand([b, max_up1]);\n        let lengths = target_lengths.clone().reshape([b, 1]).expand([b, max_up1]);\n        indices.lower_equal(lengths)\n    }\n\n    /// One time step of the forward recurrence:\n    ///\n    ///   alpha(t, u) = logaddexp(\n    ///       alpha(t-1, u) + blank(t-1, u),\n    ///       alpha(t, u-1) + label(t, u-1),\n    ///   )\n    fn step_alpha<B: Backend>(\n        &self,\n        alpha: &Tensor<B, 2>,\n        lpb: &Tensor<B, 3>,\n        lpl: &Tensor<B, 3>,\n        t: usize,\n    ) -> Tensor<B, 2> {\n        let [b, max_up1] = alpha.dims();\n        let device = alpha.device();\n\n        // Blank transition: alpha(t-1, :) + blank_prob(t-1, :)\n        let blank_prob = lpb\n            .clone()\n            .slice(s![.., (t - 1)..t, ..])\n            .squeeze_dim::<2>(1);\n        let from_blank = alpha.clone().add(blank_prob);\n\n        let mut new = Tensor::<B, 2>::full([b, max_up1], f32::NEG_INFINITY, &device);\n        new = new.slice_assign(s![.., 0..1], from_blank.clone().slice(s![.., 0..1]));\n\n        // Label probs at time t\n        let label_prob = lpl\n            .clone()\n            .slice(s![.., t..(t + 1), ..])\n            .squeeze_dim::<2>(1);\n\n        for u in 1..max_up1 {\n            let via_blank = from_blank.clone().slice(s![.., u..(u + 1)]);\n            let via_label = new\n                .clone()\n                .slice(s![.., (u - 1)..u])\n                .add(label_prob.clone().slice(s![.., (u - 1)..u]));\n            new = new.slice_assign(s![.., u..(u + 1)], self.log_sum_exp(via_blank, via_label));\n        }\n        new\n    }\n\n    /// Extracts `-(alpha(T_b, U_b) + blank(T_b, U_b))` for each sample in the batch.\n    fn gather_loss<B: Backend>(\n        &self,\n        alpha: Tensor<B, 2>,\n        lpb: &Tensor<B, 3>,\n        logit_lengths: Tensor<B, 1, Int>,\n        target_lengths: Tensor<B, 1, Int>,\n        b: usize,\n        max_up1: usize,\n    ) -> Tensor<B, 1> {\n        let t_idx = logit_lengths.sub_scalar(1);\n        let u_idx = target_lengths;\n\n        let alpha_tu = alpha\n            .gather(1, u_idx.clone().reshape([b, 1]))\n            .squeeze_dim::<1>(1);\n\n        // Gather blank prob at (T_b, U_b)\n        let t_exp = t_idx.reshape([b, 1, 1]).expand([b, 1, max_up1]);\n        let lpb_t = lpb.clone().gather(1, t_exp).squeeze_dim::<2>(1);\n        let lpb_tu = lpb_t.gather(1, u_idx.reshape([b, 1])).squeeze_dim::<1>(1);\n\n        alpha_tu.add(lpb_tu).neg()\n    }\n\n    fn check_inputs<B: Backend>(\n        &self,\n        b: usize,\n        v: usize,\n        targets: &Tensor<B, 2, Int>,\n        logit_lengths: &Tensor<B, 1, Int>,\n        target_lengths: &Tensor<B, 1, Int>,\n        max_u: usize,\n    ) {\n        assert!(\n            self.blank < v,\n            \"blank index {} must be less than vocab_size {}\",\n            self.blank,\n            v\n        );\n        assert_eq!(\n            targets.dims()[0],\n            b,\n            \"targets batch dimension {} must equal batch_size {}\",\n            targets.dims()[0],\n            b\n        );\n        assert_eq!(\n            targets.dims()[1],\n            max_u,\n            \"targets length dimension {} must equal max_target_len (max_u) {}\",\n            targets.dims()[1],\n            max_u\n        );\n        assert_eq!(\n            logit_lengths.dims()[0],\n            b,\n            \"logit_lengths length {} must equal batch_size {}\",\n            logit_lengths.dims()[0],\n            b\n        );\n        assert_eq!(\n            target_lengths.dims()[0],\n            b,\n            \"target_lengths length {} must equal batch_size {}\",\n            target_lengths.dims()[0],\n            b\n        );\n    }\n\n    /// Numerically stable `log(exp(a) + exp(b))`, handling `-inf` inputs.\n    fn log_sum_exp<const D: usize, B: Backend>(\n        &self,\n        a: Tensor<B, D>,\n        b: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let a_inf = a.clone().equal_elem(f32::NEG_INFINITY);\n        let b_inf = b.clone().equal_elem(f32::NEG_INFINITY);\n\n        // Replace -inf with 0 to prevent NaN in the subtraction (masked out below)\n        let a_safe = a.clone().mask_fill(a_inf.clone(), 0.0);\n        let b_safe = b.clone().mask_fill(b_inf.clone(), 0.0);\n\n        // log(exp(a) + exp(b)) = max(a,b) + log(1 + exp(-|a-b|))\n        let max = a_safe.clone().max_pair(b_safe.clone());\n        let result = max.add(a_safe.sub(b_safe).abs().neg().exp().add_scalar(1.0).log());\n\n        // If a=-inf, result is b; if b=-inf, result is a; if both -inf, stays -inf\n        let result = result.mask_where(a_inf, b);\n        result.mask_where(b_inf, a)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn::tensor::{TensorData, Tolerance};\n    use burn_ndarray::{NdArray, NdArrayDevice};\n\n    type B = NdArray<f32>;\n    const NUM_LABELS: usize = 2; // vocab size for simple unit tests\n\n    #[test]\n    fn config_defaults() {\n        let cfg = RNNTLossConfig::new();\n        assert_eq!(cfg.blank, 0);\n        assert!(cfg.logits);\n    }\n\n    #[test]\n    #[should_panic(expected = \"blank index\")]\n    fn panics_on_invalid_blank() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().with_blank(5).init();\n        rnnt.forward(\n            Tensor::<B, 4>::zeros([1, 2, 2, 3], &dev),\n            Tensor::<B, 2, Int>::from_data([[1_i64]], &dev),\n            Tensor::<B, 1, Int>::from_data([2], &dev),\n            Tensor::<B, 1, Int>::from_data([1], &dev),\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"must equal batch_size\")]\n    fn panics_on_batch_mismatch() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        rnnt.forward(\n            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),\n            Tensor::<B, 2, Int>::from_data([[1_i64]], &dev),\n            Tensor::<B, 1, Int>::from_data([3, 3], &dev),\n            Tensor::<B, 1, Int>::from_data([1, 1], &dev),\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"logit_lengths length\")]\n    fn panics_on_logit_lengths_mismatch() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        rnnt.forward(\n            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),\n            Tensor::<B, 2, Int>::from_data([[1_i64], [2]], &dev),\n            Tensor::<B, 1, Int>::from_data([3], &dev),\n            Tensor::<B, 1, Int>::from_data([1, 1], &dev),\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"target_lengths length\")]\n    fn panics_on_target_lengths_mismatch() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        rnnt.forward(\n            Tensor::<B, 4>::zeros([2, 3, 2, 3], &dev),\n            Tensor::<B, 2, Int>::from_data([[1_i64], [2]], &dev),\n            Tensor::<B, 1, Int>::from_data([3, 3], &dev),\n            Tensor::<B, 1, Int>::from_data([1], &dev),\n        );\n    }\n\n    #[test]\n    fn single_token_uniform_probs() {\n        // B=1, T=2, U=1, V=2, uniform probs: P(blank) = P(label) = 1/V\n        //\n        // Two alignment paths (label emitted at t=0 or t=1), each with T+U emissions:\n        //   total_prob = T * (1/V)^(T+1) = 2 * (1/2)^3 = 1/4\n        //   loss = -ln(1/4) = 2*ln(2)\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().with_logits(false).init();\n        let time_steps = 2;\n        let target_len = 1;\n        let v = NUM_LABELS as f32;\n        let log_uniform = (1.0 / v).ln();\n\n        let loss = rnnt.forward(\n            Tensor::<B, 4>::full(\n                [1, time_steps, target_len + 1, NUM_LABELS],\n                log_uniform,\n                &dev,\n            ),\n            Tensor::<B, 2, Int>::from_data([[1_i64]], &dev),\n            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),\n            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),\n        );\n        // Each path: T-1 blanks + U labels + 1 final blank = T + U emissions\n        let num_paths = time_steps as f32;\n        let emissions_per_path = (time_steps + target_len) as f32;\n        let total_prob = num_paths * v.powf(-emissions_per_path);\n        let expected_loss = -total_prob.ln();\n        loss.into_data().assert_approx_eq::<f32>(\n            &TensorData::from([expected_loss]),\n            Tolerance::absolute(1e-4),\n        );\n    }\n\n    #[test]\n    fn empty_target() {\n        // B=1, T=3, U=0, V=2, uniform probs: only the all-blanks path exists.\n        //\n        // Single path with T emissions (T-1 blanks + 1 final blank, all at u=0):\n        //   total_prob = (1/V)^T = (1/2)^3 = 1/8\n        //   loss = T*ln(V) = 3*ln(2)\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().with_logits(false).init();\n        let time_steps = 3;\n        let target_len = 0;\n        let v = NUM_LABELS as f32;\n        let log_uniform = (1.0 / v).ln();\n\n        let loss = rnnt.forward(\n            Tensor::<B, 4>::full([1, time_steps, 2, NUM_LABELS], log_uniform, &dev),\n            Tensor::<B, 2, Int>::from_data([[1_i64]], &dev),\n            Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev),\n            Tensor::<B, 1, Int>::from_data([target_len as i64], &dev),\n        );\n        // T + U = T emissions total for U=0\n        let expected_loss = -v.powf(-((time_steps + target_len) as f32)).ln();\n        loss.into_data().assert_approx_eq::<f32>(\n            &TensorData::from([expected_loss]),\n            Tolerance::absolute(1e-4),\n        );\n    }\n\n    #[test]\n    fn logits_equivalence() {\n        // Verify that logits=true (internal log_softmax on raw logits)\n        // gives the same loss as logits=false with external log_softmax.\n        let dev = NdArrayDevice::Cpu;\n        let [bs, time_steps, up1, vocab] = [1, 2, 3, 4];\n        let num_elements = bs * time_steps * up1 * vocab;\n        let target_len = up1 - 1;\n\n        let data: Vec<f32> = (0..num_elements).map(|i| (i as f32 * 0.3).sin()).collect();\n        let logits = Tensor::<B, 4>::from_data(\n            burn_core::tensor::TensorData::new(data, [bs, time_steps, up1, vocab]),\n            &dev,\n        );\n        let targets = Tensor::<B, 2, Int>::from_data([[1_i64, 2]], &dev);\n        let logit_lengths = Tensor::<B, 1, Int>::from_data([time_steps as i64], &dev);\n        let target_lengths = Tensor::<B, 1, Int>::from_data([target_len as i64], &dev);\n\n        let vocab_dim = 3;\n        let fused = RNNTLossConfig::new().with_logits(true).init().forward(\n            logits.clone(),\n            targets.clone(),\n            logit_lengths.clone(),\n            target_lengths.clone(),\n        );\n\n        let log_probs = burn::tensor::activation::log_softmax(logits, vocab_dim);\n        let manual = RNNTLossConfig::new().with_logits(false).init().forward(\n            log_probs,\n            targets,\n            logit_lengths,\n            target_lengths,\n        );\n\n        fused\n            .into_data()\n            .assert_approx_eq::<f32>(&manual.into_data(), Tolerance::absolute(1e-4));\n    }\n}\n\n/// Tests comparing forward loss and backward gradients against torchaudio.functional.rnnt_loss.\n///\n/// Logits are generated deterministically via sin((b*11+t*7+u*13+v*3)*0.1) so the same\n/// values can be reproduced in a Python script for cross-checking.\n#[cfg(test)]\n#[allow(clippy::identity_op, clippy::too_many_arguments)]\nmod pytorch_comparison_tests {\n    use super::*;\n    use burn::tensor::{TensorData, Tolerance};\n    use burn_autodiff::Autodiff;\n    use burn_ndarray::{NdArray, NdArrayDevice};\n\n    type B = Autodiff<NdArray<f32>>;\n    fn tol() -> Tolerance<f32> {\n        Tolerance::absolute(1e-3)\n    }\n\n    /// Deterministic logits matching the Python reference generator.\n    /// Uses coprime coefficients to avoid repeating patterns across dimensions.\n    fn make_logits(bs: usize, t: usize, u: usize, v: usize, dev: &NdArrayDevice) -> Tensor<B, 4> {\n        let mut data = Vec::with_capacity(bs * t * u * v);\n        for bi in 0..bs {\n            for ti in 0..t {\n                for ui in 0..u {\n                    for vi in 0..v {\n                        let idx = bi * 11 + ti * 7 + ui * 13 + vi * 3;\n                        data.push((idx as f32 * 0.1).sin());\n                    }\n                }\n            }\n        }\n        Tensor::from_data(TensorData::new(data, [bs, t, u, v]), dev)\n    }\n\n    /// Checks that gradients along the vocab dim sum to ~0 at every (b, t, u) position.\n    /// This must hold because log_softmax is applied on the last dim,\n    /// and the Jacobian of log_softmax has the property that each row sums to zero.\n    fn check_vocab_grad_sums(grad: &[f32], bs: usize, t: usize, up1: usize, v: usize) {\n        for bi in 0..bs {\n            for ti in 0..t {\n                for ui in 0..up1 {\n                    let base = ((bi * t + ti) * up1 + ui) * v;\n                    let sum: f32 = (0..v).map(|vi| grad[base + vi]).sum();\n                    TensorData::from([sum])\n                        .assert_approx_eq::<f32>(&TensorData::from([0.0f32]), tol());\n                }\n            }\n        }\n    }\n\n    /// Returns the V-sized gradient slice at position (b, t, u) in a flattened [B, T, U+1, V] grad.\n    fn grad_at(\n        grad: &[f32],\n        b: usize,\n        t: usize,\n        u: usize,\n        max_t: usize,\n        up1: usize,\n        v: usize,\n    ) -> &[f32] {\n        let base = ((b * max_t + t) * up1 + u) * v;\n        &grad[base..base + v]\n    }\n\n    /// Asserts that a gradient slice at position (b, t, u) matches expected values.\n    fn assert_grad(\n        grad: &[f32],\n        b: usize,\n        t: usize,\n        u: usize,\n        max_t: usize,\n        up1: usize,\n        v: usize,\n        expected: &[f32],\n    ) {\n        TensorData::from(grad_at(grad, b, t, u, max_t, up1, v))\n            .assert_approx_eq::<f32>(&TensorData::from(expected), tol());\n    }\n\n    #[test]\n    fn basic_b1() {\n        // B=1, T=4, U+1=3, V=3, targets=[1,2]\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        let logits = make_logits(1, 4, 3, 3, &dev).require_grad();\n\n        let loss = rnnt.forward(\n            logits.clone(),\n            Tensor::<B, 2, Int>::from_data([[1_i64, 2]], &dev),\n            Tensor::<B, 1, Int>::from_data([4_i64], &dev),\n            Tensor::<B, 1, Int>::from_data([2_i64], &dev),\n        );\n        loss.clone()\n            .into_data()\n            .assert_approx_eq::<f32>(&TensorData::from([4.4491f32]), tol());\n\n        let grads = loss.sum().backward();\n        let grad = logits\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .to_vec::<f32>()\n            .unwrap();\n\n        // Spot-check first, middle, and last (t, u) positions against torchaudio\n        assert_grad(&grad, 0, 0, 0, 4, 3, 3, &[-0.2041, -0.2246, 0.4287]);\n        assert_grad(&grad, 0, 2, 0, 4, 3, 3, &[0.0079, -0.0640, 0.0561]);\n        assert_grad(&grad, 0, 3, 2, 4, 3, 3, &[-0.6899, 0.3231, 0.3667]);\n        check_vocab_grad_sums(&grad, 1, 4, 3, 3);\n    }\n\n    #[test]\n    fn batched_b2() {\n        // B=2, T=5, U+1=4, V=4, targets=[[1,2,3],[2,1,3]]\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();\n\n        let loss = rnnt.forward(\n            logits.clone(),\n            Tensor::<B, 2, Int>::from_data(\n                TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]),\n                &dev,\n            ),\n            Tensor::<B, 1, Int>::from_data([5_i64, 5], &dev),\n            Tensor::<B, 1, Int>::from_data([3_i64, 3], &dev),\n        );\n        loss.clone()\n            .into_data()\n            .assert_approx_eq::<f32>(&TensorData::from([7.9356f32, 7.2033]), tol());\n\n        let grads = loss.sum().backward();\n        let grad = logits\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .to_vec::<f32>()\n            .unwrap();\n\n        // Spot-check: first position of each sample, and last position\n        assert_grad(&grad, 0, 0, 0, 5, 4, 4, &[-0.3161, -0.3113, 0.2796, 0.3479]);\n        assert_grad(&grad, 1, 0, 0, 5, 4, 4, &[-0.2766, 0.2602, -0.2248, 0.2411]);\n        assert_grad(&grad, 0, 4, 3, 5, 4, 4, &[-0.8216, 0.2296, 0.2786, 0.3133]);\n        assert_grad(&grad, 1, 4, 3, 5, 4, 4, &[-0.7185, 0.2735, 0.2437, 0.2012]);\n        check_vocab_grad_sums(&grad, 2, 5, 4, 4);\n    }\n\n    #[test]\n    fn variable_lengths_b3() {\n        // B=3, T=6, U+1=4, V=5\n        // logit_lengths=[6,4,5], target_lengths=[3,2,1]\n        // Tests that masking works correctly for variable-length sequences.\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        let logits = make_logits(3, 6, 4, 5, &dev).require_grad();\n\n        let loss = rnnt.forward(\n            logits.clone(),\n            Tensor::<B, 2, Int>::from_data(\n                TensorData::new(vec![1_i64, 2, 3, 4, 1, 0, 2, 0, 0], [3, 3]),\n                &dev,\n            ),\n            Tensor::<B, 1, Int>::from_data([6_i64, 4, 5], &dev),\n            Tensor::<B, 1, Int>::from_data([3_i64, 2, 1], &dev),\n        );\n        loss.clone()\n            .into_data()\n            .assert_approx_eq::<f32>(&TensorData::from([10.7458f32, 8.0196, 8.3316]), tol());\n\n        let grads = loss.sum().backward();\n        let grad = logits\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .to_vec::<f32>()\n            .unwrap();\n        let stride = 4 * 5; // U+1 * V per time step\n        let zeros = vec![0.0f32; 5];\n\n        // Sample 0 (full length=6): spot-check first and last active positions\n        assert_grad(\n            &grad,\n            0,\n            0,\n            0,\n            6,\n            4,\n            5,\n            &[-0.4232, -0.3114, 0.1992, 0.2478, 0.2876],\n        );\n        assert_grad(\n            &grad,\n            0,\n            5,\n            3,\n            6,\n            4,\n            5,\n            &[-0.8016, 0.2170, 0.2172, 0.1991, 0.1683],\n        );\n\n        // Sample 1 (logit_length=4): gradients beyond t=3 should be zero\n        assert_grad(\n            &grad,\n            1,\n            0,\n            0,\n            6,\n            4,\n            5,\n            &[-0.2502, 0.2160, 0.2173, 0.2002, -0.3833],\n        );\n        let sample1_t4_start = 1 * 6 * stride + 4 * stride;\n        for i in 0..(2 * stride) {\n            // t=4 and t=5 should all be zero\n            assert!(\n                grad[sample1_t4_start + i].abs() < 1e-3,\n                \"sample 1, t>=4: grad[{}] = {} (expected 0)\",\n                i,\n                grad[sample1_t4_start + i]\n            );\n        }\n\n        // Sample 1 (target_length=2): u=3 positions should be zero within active time steps\n        for ti in 0..4 {\n            assert_grad(&grad, 1, ti, 3, 6, 4, 5, &zeros);\n        }\n\n        // Sample 2 (logit_length=5): t=5 should be zero\n        let sample2_t5_start = 2 * 6 * stride + 5 * stride;\n        for i in 0..stride {\n            assert!(\n                grad[sample2_t5_start + i].abs() < 1e-3,\n                \"sample 2, t=5: grad[{}] = {} (expected 0)\",\n                i,\n                grad[sample2_t5_start + i]\n            );\n        }\n\n        check_vocab_grad_sums(&grad, 3, 6, 4, 5);\n    }\n\n    #[test]\n    fn sum_reduction() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();\n        let tgt = Tensor::<B, 2, Int>::from_data(\n            TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]),\n            &dev,\n        );\n        let il = Tensor::<B, 1, Int>::from_data([5_i64, 5], &dev);\n        let tl = Tensor::<B, 1, Int>::from_data([3_i64, 3], &dev);\n\n        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Sum);\n        // 7.9356 + 7.2033 = 15.1389\n        loss.clone()\n            .into_data()\n            .assert_approx_eq::<f32>(&TensorData::from([15.1389f32]), tol());\n\n        let grads = loss.backward();\n        let g = logits\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .to_vec::<f32>()\n            .unwrap();\n        TensorData::from(&g[..4]).assert_approx_eq::<f32>(\n            &TensorData::from([-0.3161f32, -0.3113, 0.2796, 0.3479]),\n            tol(),\n        );\n    }\n\n    #[test]\n    fn mean_reduction() {\n        let dev = NdArrayDevice::Cpu;\n        let rnnt = RNNTLossConfig::new().init();\n        let logits = make_logits(2, 5, 4, 4, &dev).require_grad();\n        let tgt = Tensor::<B, 2, Int>::from_data(\n            TensorData::new(vec![1_i64, 2, 3, 2, 1, 3], [2, 3]),\n            &dev,\n        );\n        let il = Tensor::<B, 1, Int>::from_data([5_i64, 5], &dev);\n        let tl = Tensor::<B, 1, Int>::from_data([3_i64, 3], &dev);\n\n        let loss = rnnt.forward_with_reduction(logits.clone(), tgt, il, tl, Reduction::Mean);\n        // 15.1389 / 2 = 7.5694\n        loss.clone()\n            .into_data()\n            .assert_approx_eq::<f32>(&TensorData::from([7.5694f32]), tol());\n\n        // Gradients should be half the sum-reduction gradients (mean over batch of 2)\n        let grads = loss.backward();\n        let g = logits\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .to_vec::<f32>()\n            .unwrap();\n        TensorData::from(&g[..4]).assert_approx_eq::<f32>(\n            &TensorData::from([-0.1581f32, -0.1557, 0.1398, 0.1739]),\n            tol(),\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/loss/smooth_l1.rs",
    "content": "use super::Reduction;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::tensor::{Tensor, backend::Backend};\nuse burn_core as burn;\n\n/// Configuration for the [SmoothL1Loss](SmoothL1Loss) module.\n///\n/// Smooth L1 loss combines L1 and L2 loss, using L2 loss for small errors (below beta)\n/// and L1 loss for large errors (above beta). This makes it less sensitive to outliers\n/// than MSE while maintaining smooth gradients near zero.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};\n///\n/// // Create Smooth L1 loss with default beta=1.0\n/// let smooth_l1 = SmoothL1LossConfig::new().init();\n///\n/// // Create with custom beta\n/// let smooth_l1_custom = SmoothL1LossConfig::new().with_beta(0.5).init();\n/// ```\n#[derive(Config, Debug)]\npub struct SmoothL1LossConfig {\n    /// Specifies the threshold at which to change between L1 and L2 loss.\n    /// The value must be positive. Default: 1.0\n    #[config(default = 1.0)]\n    pub beta: f32,\n}\n\nimpl SmoothL1LossConfig {\n    /// Initializes a [Smooth L1 Loss](SmoothL1Loss) module.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `beta <= 0`.\n    pub fn init(&self) -> SmoothL1Loss {\n        self.assertions();\n        SmoothL1Loss { beta: self.beta }\n    }\n\n    fn assertions(&self) {\n        assert!(self.beta > 0.0, \"The parameter beta must be positive.\")\n    }\n}\n\n/// Computes the Smooth L1 Loss between predictions and targets.\n///\n/// This loss function uses L2 loss for small errors (below beta) and L1 loss for\n/// large errors (above beta), providing robustness to outliers while maintaining\n/// smooth gradients near |x - y| = 0.\n///\n/// # Mathematical Definition\n///\n/// For predictions `x` and targets `y`, the element-wise loss is:\n///\n/// - L_i = 0.5 * (x_i - y_i)² / beta   , if |x_i - y_i| < beta\n/// - L_i = |x_i - y_i| - 0.5 * beta    , otherwise\n///\n/// # Notes\n///\n/// Smooth L1 loss is closely related to HuberLoss since it is equivalent to HuberLoss\n/// scaled by `1/beta`:\n/// `SmoothL1(x, y, beta) = Huber(x, y, beta) / beta`\n///\n/// This leads to the following differences:\n///\n/// - As beta approaches 0, Smooth L1 loss converges to L1Loss, while HuberLoss converges to 0.\n///   When beta = 0, Smooth L1 loss is equivalent to L1 loss. Thus, the `beta`\n///   parameter in Burn must be positive. L1Loss should be used for beta = 0.\n/// - As beta approaches positive infinity, Smooth L1 loss converges to a constant 0 loss, while\n///   HuberLoss converges to L2Loss.\n///\n/// # Example\n///\n/// ```rust,ignore\n/// use burn_nn::loss::{SmoothL1LossConfig, Reduction};\n/// use burn::tensor::Tensor;\n///\n/// // Create Smooth L1 loss with the default beta=1.0\n/// let smooth_l1 = SmoothL1LossConfig::new().init();\n///\n/// let predictions: Tensor<Backend, 2> = /* model output */;\n/// let targets: Tensor<Backend, 2> = /* ground truth */;\n///\n/// // Compute element-wise loss without reduction\n/// let element_wise = smooth_l1.forward(predictions.clone(), targets.clone());\n///\n/// // Compute loss with mean reduction\n/// let loss = smooth_l1.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);\n///\n/// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]\n/// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);\n/// ```\n#[derive(Module, Clone, Debug)]\npub struct SmoothL1Loss {\n    /// Specifies the threshold at which to change between L1 and L2 loss.\n    /// The value must be positive. Default: 1.0\n    pub beta: f32,\n}\n\nimpl SmoothL1Loss {\n    /// Computes the element-wise smooth L1 loss without reduction.\n    ///\n    /// # Arguments\n    ///\n    /// - `predictions` - The model's predicted values.\n    /// - `targets` - The ground truth target values.\n    ///\n    /// # Returns\n    ///\n    /// A tensor of the same shape as the inputs, containing the smooth L1 loss\n    /// for each element.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[...dims]` - Any shape\n    /// - targets: `[...dims]` - Must match predictions shape\n    /// - output: `[...dims]` - Same shape as inputs\n    pub fn forward<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let error = predictions.sub(targets);\n        let abs_error = error.clone().abs();\n\n        // The L1 case: |error| - 0.5 * beta (when |error| >= beta)\n        let l1_loss = abs_error.clone().sub_scalar(0.5 * self.beta);\n\n        // The L2 case: 0.5 * (error)^2 / beta (when |error| < beta)\n        let l2_loss = error.square().mul_scalar(0.5).div_scalar(self.beta);\n\n        let l2_mask = abs_error.lower_elem(self.beta);\n        l1_loss.mask_where(l2_mask, l2_loss)\n    }\n\n    /// Computes the smooth L1 loss with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// - `predictions` - The model's predicted values.\n    /// - `targets` - The ground truth target values.\n    /// - `reduction` - Specifies how to reduce the element-wise losses:\n    ///   - `Reduction::Mean` or `Reduction::Auto`: Returns the mean of all element-wise losses.\n    ///   - `Reduction::Sum`: Returns the sum of all element-wise losses.\n    ///\n    /// # Returns\n    ///\n    /// A scalar tensor containing the reduced loss value.\n    ///\n    /// # Shapes\n    ///\n    /// - predictions: `[...dims]` - Any shape\n    /// - targets: `[...dims]` - Must match predictions shape\n    /// - output: `[1]` - Scalar loss value\n    pub fn forward_with_reduction<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let unreduced_loss = self.forward(predictions, targets);\n\n        match reduction {\n            Reduction::Mean | Reduction::Auto => unreduced_loss.mean(),\n            Reduction::Sum => unreduced_loss.sum(),\n            other => panic!(\"{other:?} reduction is not supported\"),\n        }\n    }\n\n    /// Computes the smooth L1 loss with reduction over specified dimensions.\n    ///\n    /// Calculates element-wise smooth L1 loss, then takes the mean\n    /// over the specified dimensions. Useful for per-sample or per-channel losses.\n    ///\n    /// Dimensions can be provided in any order. They are sorted internally and\n    /// reduced from highest to lowest to ensure indices remain valid.\n    ///\n    /// # Arguments\n    ///\n    /// - `predictions` - The model's predicted values.\n    /// - `targets` - The ground truth target values.\n    /// - `dims` - Dimensions to reduce over.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the specified dimensions reduced to size 1.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// // Consider image tensor with shape [batch, C, H, W]\n    /// let smooth_l1 = SmoothL1LossConfig::new().init();\n    ///\n    /// // Per-image loss: reduce over C, H, W → [batch, 1, 1, 1]\n    /// let loss_per_image = smooth_l1.forward_reduce_dims(predictions, targets, &[1, 2, 3]);\n    /// ```\n    pub fn forward_reduce_dims<const D: usize, B: Backend>(\n        &self,\n        predictions: Tensor<B, D>,\n        targets: Tensor<B, D>,\n        dims: &[usize],\n    ) -> Tensor<B, D> {\n        let error = self.forward(predictions, targets);\n\n        // Sort the dimensions to ascending order\n        let mut sorted_dims = dims.to_vec();\n        sorted_dims.sort();\n\n        // Reduce over specified dimensions\n        error.mean_dims(sorted_dims.as_slice())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n\n    type FT = FloatElem<TestBackend>;\n\n    // =========================================================================\n    // Configuration Tests\n    // =========================================================================\n\n    #[test]\n    fn test_smooth_l1_config_default_beta() {\n        let loss = SmoothL1LossConfig::new().init();\n        assert_eq!(loss.beta, 1.0);\n    }\n\n    #[test]\n    fn test_smooth_l1_config_custom_beta() {\n        let loss = SmoothL1LossConfig::new().with_beta(2.5).init();\n        assert_eq!(loss.beta, 2.5);\n    }\n\n    #[test]\n    #[should_panic(expected = \"The parameter beta must be positive\")]\n    fn test_smooth_l1_config_beta_zero_panics() {\n        SmoothL1LossConfig::new().with_beta(0.0).init();\n    }\n\n    #[test]\n    #[should_panic(expected = \"The parameter beta must be positive\")]\n    fn test_smooth_l1_config_beta_negative_panics() {\n        SmoothL1LossConfig::new().with_beta(-1.0).init();\n    }\n\n    // =========================================================================\n    // Forward Pass (Element-wise) Tests\n    // =========================================================================\n\n    #[test]\n    fn test_smooth_l1_forward_l2_region() {\n        // Beta = 1.0, errors = 0.0 and 0.5 (both < beta, use L2 formula)\n        // L2 formula: 0.5 * error^2 / beta\n        // error = 0.0  ->  loss = 0.5 * 0.0 / 1.0 = 0.0\n        // error = 0.5  ->  loss = 0.5 * 0.25 / 1.0 = 0.125\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.5]]), &device);\n        let targets =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([[0.0_f32, 0.125]]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_l1_region() {\n        // Beta = 1.0, errors = 0.0 and 2.0 (2.0 >= beta, use L1 formula)\n        // L1 formula: |error| - 0.5 * beta\n        // L2 formula: 0.5 * (error)^2 / beta\n        // error = 0.0  ->  loss = 0.0\n        // error = 2.0  ->  loss = 2.0 - 0.5 = 1.5\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 2.0]]), &device);\n        let targets =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([[0.0_f32, 1.5]]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_zero_error() {\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[1.0_f32, 2.0, 3.0]]), &device);\n        let targets = predictions.clone();\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([[0.0_f32, 0.0, 0.0]]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_negative_errors() {\n        // Ensure absolute value is used correctly\n        // L1 formula: |error| - 0.5 * beta\n        // L2 formula: 0.5 * (error)^2 / beta\n        // Beta = 1.0, error = -3.0 (L1: 3.0 - 0.5 = 2.5)\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([-3.0_f32]), &device);\n        let targets = Tensor::<TestBackend, 1>::zeros([1], &device);\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([2.5_f32]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_mixed_regions() {\n        // Test with errors in both L1 and L2 regions\n        // Beta = 1.0\n        // L1 formula: |error| - 0.5 * beta\n        // L2 formula: 0.5 * (error)^2 / beta\n        // error = 0.5 -> L2: 0.5 * 0.25 / 1 = 0.125\n        // error = 1.5 -> L1: 1.5 - 0.5 = 1.0\n        // error = 3.0 -> L1: 3.0 - 0.5 = 2.5\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.5_f32, 1.5, 3.0]), &device);\n        let targets = Tensor::<TestBackend, 1>::zeros([3], &device);\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([0.125_f32, 1.0, 2.5]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_custom_beta_values() {\n        // Test with beta = 0.5\n        // error = 0.25 (< beta): L2 = 0.5 * 0.0625 / 0.5 = 0.0625\n        // error = 1.0 (>= beta): L1 = 1.0 - 0.25 = 0.75\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().with_beta(0.5).init();\n\n        let predictions =\n            Tensor::<TestBackend, 1>::from_data(TensorData::from([0.25_f32, 1.0]), &device);\n        let targets = Tensor::<TestBackend, 1>::zeros([2], &device);\n\n        let output = loss.forward(predictions, targets);\n        let expected = TensorData::from([0.0625_f32, 0.75]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    // =========================================================================\n    // forward_with_reduction Tests\n    // =========================================================================\n\n    #[test]\n    fn test_smooth_l1_reduction_mean() {\n        // Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5)\n        // Mean: (0.125 + 1.5) / 2 = 0.8125\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);\n        let targets =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);\n\n        let output = loss.forward_with_reduction(predictions, targets, Reduction::Mean);\n        let expected = TensorData::from([0.8125_f32]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_reduction_sum() {\n        // Errors: 0.5 (L2: 0.125), 2.0 (L1: 1.5)\n        // Sum: 1.625\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.5_f32, 2.0]]), &device);\n        let targets =\n            Tensor::<TestBackend, 2>::from_data(TensorData::from([[0.0_f32, 0.0]]), &device);\n\n        let output = loss.forward_with_reduction(predictions, targets, Reduction::Sum);\n        let expected = TensorData::from([1.625_f32]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_reduction_auto_equals_mean() {\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions = Tensor::<TestBackend, 1>::from_data(TensorData::from([2.0_f32]), &device);\n        let targets = Tensor::<TestBackend, 1>::zeros([1], &device);\n\n        let mean_out =\n            loss.forward_with_reduction(predictions.clone(), targets.clone(), Reduction::Mean);\n        let auto_out = loss.forward_with_reduction(predictions, targets, Reduction::Auto);\n\n        mean_out.into_data().assert_eq(&auto_out.into_data(), false);\n    }\n\n    // =========================================================================\n    // Dimension Reduction Tests\n    // =========================================================================\n\n    #[test]\n    fn test_smooth_l1_forward_reduce_dims_single_dim() {\n        // Beta = 2.0\n        // L1 formula: |error| - 0.5 * beta\n        // L2 formula: 0.5 * (error)^2 / beta\n        // Row 0: errors [0.0, 1.0, 4.0]\n        //   error = 0.0 -> L2: 0.0\n        //   error = 1.0 -> L2: 0.5 * 1.0 / 2.0 = 0.25\n        //   error = 4.0 -> L1: 4.0 - 1.0 = 3.0\n        //   Mean = 3.25 / 3 = 1.083333...\n        // Row 1: errors [0.0, 0.0, 0.0] -> Mean = 0.0\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().with_beta(2.0).init();\n\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0_f32, 1.0, 4.0], [5.0_f32, 5.0, 5.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.0_f32, 0.0, 0.0], [5.0_f32, 5.0, 5.0]]),\n            &device,\n        );\n\n        let output = loss.forward_reduce_dims(predictions, targets, &[1]);\n        let expected = TensorData::from([[3.25_f32 / 3.0], [0.0]]); // 3.25/3 = 1.0833...\n        output\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_reduce_dims_image_batch() {\n        // Simulate per-image Smooth L1 loss for [batch, C, H, W] tensor\n        // (common in object detection like Fast R-CNN)\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init(); // beta = 1.0\n\n        // Shape: [2, 1, 2, 2] (batch=2, C=1, H=2, W=2)\n        let predictions = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[[0.5_f32, 2.0], [0.0, 3.0]]], // Image 1\n                [[[1.0_f32, 0.0], [0.5, 1.5]]], // Image 2\n            ]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::zeros([2, 1, 2, 2], &device);\n\n        // Reduce over C, H, W (dims 1, 2, 3) to get per-image loss\n        let output = loss.forward_reduce_dims(predictions, targets, &[1, 2, 3]);\n\n        // Image 1: losses [[0.125, 1.5], [0.0, 2.5]] -> mean: 4.125 / 4 = 1.03125\n        // Image 2: losses [[0.5, 0.0], [0.125, 1.0]] -> mean: 1.625 / 4 = 0.40625\n        let expected = TensorData::from([[[[1.03125_f32]]], [[[0.40625_f32]]]]);\n        output.into_data().assert_eq(&expected, false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_reduce_dims_unsorted() {\n        // Test that unsorted dimensions are handled correctly (sorted internally)\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[1.0_f32, 2.0], [3.0, 4.0]], [[5.0_f32, 6.0], [7.0, 8.0]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 3>::zeros([2, 2, 2], &device);\n\n        // Pass dims in reverse order\n        let output = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[2, 1]);\n        let expected_output = loss.forward_reduce_dims(predictions, targets, &[1, 2]);\n\n        output\n            .into_data()\n            .assert_eq(&expected_output.into_data(), false);\n    }\n\n    #[test]\n    fn test_smooth_l1_forward_reduce_dims_empty_dims() {\n        // Reducing over no dimensions should return the unreduced loss\n        let device = Default::default();\n        let loss = SmoothL1LossConfig::new().init();\n\n        let predictions = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.5_f32, 2.0], [0.0, 3.0]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 2>::zeros([2, 2], &device);\n\n        let loss_reduce_dims = loss.forward_reduce_dims(predictions.clone(), targets.clone(), &[]);\n        let loss_no_reduction = loss.forward(predictions, targets);\n\n        loss_reduce_dims\n            .into_data()\n            .assert_eq(&loss_no_reduction.into_data(), false);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/attention/cross_attention.rs",
    "content": "//! Cross-Attention Module for Burn\n//!\n//! Features:\n//! - Asymmetric Input Shapes (Query vs Context)\n//! - Grouped Query Attention (GQA) & Multi-Query Attention (MQA) support\n//! - Quantization-Safe Masking (min_float)\n//! - Sparse-Ready (quiet_softmax)\n//! - KV Caching for Streaming Inference\n\nuse crate::cache::TensorCache;\nuse crate::modules::{Linear, LinearConfig};\nuse crate::{Dropout, DropoutConfig};\nuse burn_core as burn;\n\nuse burn::{\n    config::Config,\n    module::{Initializer, Module},\n    tensor::{\n        Bool, Tensor,\n        activation::{quiet_softmax, softmax},\n        backend::Backend,\n    },\n};\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n#[derive(Config, Debug)]\n/// Configuration to create a [CrossAttention](CrossAttention) layer using the [init function](CrossAttentionConfig::init).\npub struct CrossAttentionConfig {\n    /// Dimension of the Query (e.g., Decoder state).\n    pub d_model: usize,\n    /// Dimension of the Context (e.g., Encoder audio embeddings).\n    pub d_context: usize,\n    /// Number of heads for the Query.\n    pub n_heads: usize,\n    /// Number of heads for Key/Value (Set to 1 for MQA, set to n_heads for MHA).\n    pub n_heads_kv: usize,\n    /// Dimension of a single head.\n    pub d_head: usize,\n    /// Dropout rate.\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    /// Masking value. Use -1.0e4 for f16/bf16 safety.\n    #[config(default = -1.0e4)]\n    pub min_float: f64,\n    /// Use quiet_softmax to allow zero-attention (good for sparse/quantized models).\n    #[config(default = false)]\n    pub quiet_softmax: bool,\n}\n\n#[derive(Module, Debug)]\n/// The Cross attention module\n///\n/// # Params\n///\n/// - `query`: [`Linear`] layer with `d_model` input and output features.\n/// - `key`: [`Linear`] layer with `d_model` input and output features.\n/// - `value`: [`Linear`] layer with `d_model` input and output features.\n/// - `output`: [`Linear`] layer with `d_model` input and output features.\n///\n/// Should be created with [CrossAttentionConfig].\npub struct CrossAttention<B: Backend> {\n    query: Linear<B>,\n    key: Linear<B>,\n    value: Linear<B>,\n    output: Linear<B>,\n    dropout: Dropout,\n\n    n_heads: usize,\n    n_heads_kv: usize,\n    d_head: usize,\n    scale: f64,\n    min_float: f64,\n    quiet_softmax: bool,\n}\n\n/// Cache for the [Cross Attention](CrossAttention) layer.\n///\n/// To be used during inference when context is constant.\npub struct CrossAttentionCache<B: Backend> {\n    /// Cached key tensor.\n    pub k: TensorCache<B, 4>,\n    /// Cached value tensor.\n    pub v: TensorCache<B, 4>,\n}\n\nimpl<B: Backend> CrossAttentionCache<B> {\n    /// Create a new empty cache.\n    pub fn new() -> Self {\n        Self {\n            k: TensorCache::empty(),\n            v: TensorCache::empty(),\n        }\n    }\n}\n\nimpl<B: Backend> Default for CrossAttentionCache<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl CrossAttentionConfig {\n    /// Initializes a new cross-attention module.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - The device on which to initialize the module.\n    ///\n    /// # Returns\n    ///\n    /// A new [CrossAttention] module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {\n        // Safety Rail for GQA\n        assert_eq!(\n            self.n_heads % self.n_heads_kv,\n            0,\n            \"Query heads must be divisible by KV heads\"\n        );\n\n        let init_linear = |in_dim, out_dim| {\n            LinearConfig::new(in_dim, out_dim)\n                .with_initializer(Initializer::KaimingUniform {\n                    gain: 1.0 / (self.d_head as f64).sqrt(),\n                    fan_out_only: false,\n                })\n                .init(device)\n        };\n\n        CrossAttention {\n            // ADVICE: Asymmetric Projections\n            query: init_linear(self.d_model, self.n_heads * self.d_head),\n            key: init_linear(self.d_context, self.n_heads_kv * self.d_head),\n            value: init_linear(self.d_context, self.n_heads_kv * self.d_head),\n            output: init_linear(self.n_heads * self.d_head, self.d_model),\n\n            dropout: DropoutConfig::new(self.dropout).init(),\n            n_heads: self.n_heads,\n            n_heads_kv: self.n_heads_kv,\n            d_head: self.d_head,\n            scale: (self.d_head as f64).sqrt().recip(),\n            min_float: self.min_float,\n            quiet_softmax: self.quiet_softmax,\n        }\n    }\n}\n\nimpl<B: Backend> CrossAttention<B> {\n    /// Applies cross-attention to query using context as key and value.\n    ///\n    /// # Arguments\n    ///\n    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.\n    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.\n    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.\n    ///\n    /// # Returns\n    ///\n    /// Output tensor of shape `[batch, seq_len_query, d_model]`.\n    pub fn forward(\n        &self,\n        query: Tensor<B, 3>,\n        context: Tensor<B, 3>,\n        mask: Option<Tensor<B, 2, Bool>>,\n    ) -> Tensor<B, 3> {\n        let [batch, l_q, _] = query.dims();\n        let [_, l_k, _] = context.dims();\n\n        // 1. Projections\n        let q = self.query.forward(query);\n        let k = self.key.forward(context.clone());\n        let v = self.value.forward(context);\n\n        // 2. Reshape Heads\n        // Q: [Batch, Heads, L_q, D_head]\n        let q = q\n            .reshape([batch, l_q, self.n_heads, self.d_head])\n            .swap_dims(1, 2);\n\n        // K, V: [Batch, Heads_KV, L_k, D_head]\n        let k = k\n            .reshape([batch, l_k, self.n_heads_kv, self.d_head])\n            .swap_dims(1, 2);\n        let v = v\n            .reshape([batch, l_k, self.n_heads_kv, self.d_head])\n            .swap_dims(1, 2);\n\n        // 3. GQA Expansion\n        // ADVICE: Handle GQA by repeating KV heads to match Query heads\n        let (k, v) = if self.n_heads != self.n_heads_kv {\n            let n_rep = self.n_heads / self.n_heads_kv;\n            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))\n        } else {\n            (k, v)\n        };\n\n        // 4. Score Calculation\n        let scores = q.matmul(k.transpose()) * self.scale;\n\n        // 5. Masking\n        // ADVICE: Use min_float for F16/FP8 safety\n        let scores = if let Some(mask) = mask {\n            let mask = mask.reshape([batch, 1, 1, l_k]);\n            scores.mask_fill(mask, self.min_float)\n        } else {\n            scores\n        };\n\n        // 6. Softmax\n        // ADVICE: Optional Quiet Softmax for sparse networks\n        let weights = if self.quiet_softmax {\n            quiet_softmax(scores, 3)\n        } else {\n            softmax(scores, 3)\n        };\n\n        let weights = self.dropout.forward(weights);\n\n        // 7. Aggregate & Output\n        let output = weights.matmul(v);\n        let output = output\n            .swap_dims(1, 2)\n            .reshape([batch, l_q, self.n_heads * self.d_head]);\n\n        self.output.forward(output)\n    }\n\n    /// Applies cross-attention to query using context as key and value.\n    ///\n    /// This method uses a cache to avoid recomputing key and value tensors when the context is the same.\n    ///\n    /// # Arguments\n    ///\n    /// * `query` - Query tensor of shape `[batch, seq_len_query, d_model]`.\n    /// * `context` - Context tensor of shape `[batch, seq_len_context, d_context]`.\n    /// * `mask` - Optional attention mask of shape `[batch, seq_len_context]` where `true` indicates positions to mask.\n    /// * `cache` - The cache to use.\n    ///\n    /// # Returns\n    ///\n    /// Output tensor of shape `[batch, seq_len_query, d_model]`.\n    pub fn forward_cache(\n        &self,\n        query: Tensor<B, 3>,\n        context: Tensor<B, 3>,\n        mask: Option<Tensor<B, 2, Bool>>,\n        cache: &mut CrossAttentionCache<B>,\n    ) -> Tensor<B, 3> {\n        let [batch, l_q, _] = query.dims();\n\n        // 1. Projections\n        let q = self.query.forward(query);\n\n        let k_compute = |context: Tensor<B, 3>| {\n            let [batch, l_k, _] = context.dims();\n            self.key\n                .forward(context)\n                .reshape([batch, l_k, self.n_heads_kv, self.d_head])\n                .swap_dims(1, 2)\n        };\n        let v_compute = |context: Tensor<B, 3>| {\n            let [batch, l_k, _] = context.dims();\n            self.value\n                .forward(context)\n                .reshape([batch, l_k, self.n_heads_kv, self.d_head])\n                .swap_dims(1, 2)\n        };\n\n        let k = cache.k.forward_full(context.clone(), k_compute);\n        let v = cache.v.forward_full(context, v_compute);\n\n        let [_, _, l_k, _] = k.dims();\n\n        // 2. Reshape Heads\n        // Q: [Batch, Heads, L_q, D_head]\n        let q = q\n            .reshape([batch, l_q, self.n_heads, self.d_head])\n            .swap_dims(1, 2);\n\n        // K, V are already in their correct shape from k_compute and v_compute\n\n        // 3. GQA Expansion\n        // ADVICE: Handle GQA by repeating KV heads to match Query heads\n        let (k, v) = if self.n_heads != self.n_heads_kv {\n            let n_rep = self.n_heads / self.n_heads_kv;\n            (self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))\n        } else {\n            (k, v)\n        };\n\n        // 4. Score Calculation\n        let scores = q.matmul(k.transpose()) * self.scale;\n\n        // 5. Masking\n        // ADVICE: Use min_float for F16/FP8 safety\n        let scores = if let Some(mask) = mask {\n            let mask = mask.reshape([batch, 1, 1, l_k]);\n            scores.mask_fill(mask, self.min_float)\n        } else {\n            scores\n        };\n\n        // 6. Softmax\n        // ADVICE: Optional Quiet Softmax for sparse networks\n        let weights = if self.quiet_softmax {\n            quiet_softmax(scores, 3)\n        } else {\n            softmax(scores, 3)\n        };\n\n        let weights = self.dropout.forward(weights);\n\n        // 7. Aggregate & Output\n        let output = weights.matmul(v);\n        let output = output\n            .swap_dims(1, 2)\n            .reshape([batch, l_q, self.n_heads * self.d_head]);\n\n        self.output.forward(output)\n    }\n\n    /// Helper for Grouped Query Attention\n    fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {\n        let [b, h, l, d] = x.dims();\n        x.reshape([b, h, 1, l, d])\n            .expand([b, h, n_rep, l, d])\n            .reshape([b, h * n_rep, l, d])\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};\n\n    #[test]\n    fn test_cross_attention_mha_shapes() {\n        let [\n            batch_size,\n            seq_len_query,\n            seq_len_context,\n            d_model,\n            d_context,\n            n_heads,\n            d_head,\n        ] = [7, 13, 15, 32, 40, 4, 8];\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv: n_heads, // MHA case\n            d_head,\n            dropout: 0.1,\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        let cross_attn = config.init::<TestBackend>(&device);\n\n        let query = Tensor::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let context = Tensor::random(\n            [batch_size, seq_len_context, d_context],\n            Distribution::Default,\n            &device,\n        );\n\n        let output = cross_attn.forward(query, context, None);\n\n        assert_eq!(\n            output.shape(),\n            Shape::new([batch_size, seq_len_query, d_model]),\n            \"Output should have the correct shape\",\n        );\n    }\n\n    #[test]\n    fn test_cross_attention_gqa_shapes() {\n        let [\n            batch_size,\n            seq_len_query,\n            seq_len_context,\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv,\n            d_head,\n        ] = [7, 13, 15, 32, 40, 4, 2, 8];\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv, // GQA case\n            d_head,\n            dropout: 0.1,\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        let cross_attn = config.init::<TestBackend>(&device);\n\n        let query = Tensor::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let context = Tensor::random(\n            [batch_size, seq_len_context, d_context],\n            Distribution::Default,\n            &device,\n        );\n\n        let output = cross_attn.forward(query, context, None);\n\n        assert_eq!(\n            output.shape(),\n            Shape::new([batch_size, seq_len_query, d_model]),\n            \"Output should have the correct shape\",\n        );\n    }\n\n    #[test]\n    fn test_cross_attention_mqa_shapes() {\n        let [\n            batch_size,\n            seq_len_query,\n            seq_len_context,\n            d_model,\n            d_context,\n            n_heads,\n            d_head,\n        ] = [7, 13, 15, 32, 40, 4, 8];\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv: 1, // MQA case\n            d_head,\n            dropout: 0.1,\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        let cross_attn = config.init::<TestBackend>(&device);\n\n        let query = Tensor::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let context = Tensor::random(\n            [batch_size, seq_len_context, d_context],\n            Distribution::Default,\n            &device,\n        );\n\n        let output = cross_attn.forward(query, context, None);\n\n        assert_eq!(\n            output.shape(),\n            Shape::new([batch_size, seq_len_query, d_model]),\n            \"Output should have the correct shape\",\n        );\n    }\n\n    #[test]\n    fn test_cross_attention_mask() {\n        let [\n            batch_size,\n            seq_len_query,\n            seq_len_context,\n            d_model,\n            d_context,\n            n_heads,\n            d_head,\n        ] = [3, 6, 8, 12, 16, 4, 3];\n        let num_padded = 2;\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv: n_heads,\n            d_head,\n            dropout: 0.0, // No dropout for deterministic test\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        let cross_attn = config.init::<TestBackend>(&device);\n\n        // Create a padding mask for the context\n        let mut mask: Tensor<TestBackend, 2, Int> =\n            Tensor::zeros([batch_size, seq_len_context], &device);\n        mask = mask.slice_assign(\n            [0..batch_size, seq_len_context - num_padded..seq_len_context],\n            Tensor::ones([batch_size, num_padded], &device),\n        );\n        let mask_bool = mask.equal_elem(1);\n\n        let query = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n\n        let context_1 = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_len_context, d_context],\n            Distribution::Default,\n            &device,\n        );\n\n        // Change the padded part of the context tensor\n        let context_2 = context_1.clone().slice_assign(\n            [\n                0..batch_size,\n                seq_len_context - num_padded..seq_len_context,\n                0..d_context,\n            ],\n            Tensor::random(\n                [batch_size, num_padded, d_context],\n                Distribution::Default,\n                &device,\n            ),\n        );\n\n        // The outputs should be the same since the changed part is masked.\n        let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));\n        let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));\n\n        output_1\n            .into_data()\n            .assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());\n    }\n\n    #[test]\n    #[should_panic]\n    fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model: 32,\n            d_context: 32,\n            n_heads: 5,\n            n_heads_kv: 2,\n            d_head: 8,\n            dropout: 0.1,\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        config.init::<TestBackend>(&device);\n    }\n\n    #[test]\n    fn test_cross_attention_cache() {\n        let [\n            batch_size,\n            seq_len_query,\n            seq_len_context,\n            d_model,\n            d_context,\n            n_heads,\n            d_head,\n        ] = [3, 6, 8, 12, 16, 4, 3];\n        let device = Default::default();\n        let config = CrossAttentionConfig {\n            d_model,\n            d_context,\n            n_heads,\n            n_heads_kv: n_heads,\n            d_head,\n            dropout: 0.0, // No dropout for deterministic test\n            min_float: -1.0e4,\n            quiet_softmax: false,\n        };\n        let cross_attn = config.init::<TestBackend>(&device);\n\n        let query1 = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let context = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_len_context, d_context],\n            Distribution::Default,\n            &device,\n        );\n\n        // First forward pass, no cache\n        let output1 = cross_attn.forward(query1.clone(), context.clone(), None);\n\n        // Second forward pass with cache\n        let mut cache = CrossAttentionCache::new();\n        let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);\n\n        // The two outputs should be identical\n        output1\n            .into_data()\n            .assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());\n\n        // Third forward pass with different query, but same context and cache\n        let query2 = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_len_query, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);\n\n        // For control, do a forward pass without cache with query2\n        let output4 = cross_attn.forward(query2.clone(), context.clone(), None);\n\n        // output3 and output4 should be identical\n        output3\n            .into_data()\n            .assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/attention/mask.rs",
    "content": "use burn_core as burn;\nuse burn_core::config::Config;\n\nuse alloc::vec::Vec;\nuse burn::tensor::ops::IntElem;\n\nuse burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};\n\n/// Generate an autoregressive attention mask.\n///\n/// The mask can be used in Transformer modules to train models to generate tensors sequentially.\npub fn generate_autoregressive_mask<B: Backend>(\n    batch_size: usize,\n    seq_length: usize,\n    device: &B::Device,\n) -> Tensor<B, 3, Bool> {\n    let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);\n    mask.expand([batch_size, seq_length, seq_length])\n}\n\n/// Generate a padding attention mask.\npub struct GeneratePaddingMask<B: Backend> {\n    /// The generated tensor.\n    pub tensor: Tensor<B, 2, Int>,\n\n    /// The generated mask.\n    pub mask: Tensor<B, 2, Bool>,\n}\n\n/// Defines an enumeration to specify sequence length options for padding\n#[derive(Config, Debug, Copy)]\npub enum SeqLengthOption {\n    /// No maximum length; use the longest sequence\n    NoMax,\n    /// Maximum length specified, truncate if necessary\n    Max(usize),\n    /// Fixed length, pad or truncate to this exact length\n    Fixed(usize),\n}\n\nimpl From<Option<usize>> for SeqLengthOption {\n    fn from(val: Option<usize>) -> Self {\n        match val {\n            Some(max) => SeqLengthOption::Max(max),\n            None => SeqLengthOption::NoMax,\n        }\n    }\n}\n\n/// Generates a padding attention mask for a batch of token sequences.\n///\n/// # Arguments\n///\n/// * `pad_token` - The token ID used for padding\n/// * `tokens_list` - Vector of token sequences (each sequence is a vector of token IDs)\n/// * `seq_length` - Sequence length option (NoMax, Max, or Fixed)\n/// * `device` - The device for tensor operations\n///\n/// # Returns\n///\n/// A `GeneratePaddingMask` containing the padded tensor and corresponding mask\npub fn generate_padding_mask<B: Backend>(\n    pad_token: usize,\n    tokens_list: Vec<Vec<usize>>,\n    seq_length: impl Into<SeqLengthOption>,\n    device: &B::Device,\n) -> GeneratePaddingMask<B> {\n    let tokens_max = || {\n        tokens_list\n            .iter()\n            .map(|tokens| tokens.len())\n            .max()\n            .unwrap_or(1)\n    };\n\n    let size = match seq_length.into() {\n        SeqLengthOption::NoMax => tokens_max(),\n        SeqLengthOption::Max(max) => usize::min(tokens_max(), max),\n        SeqLengthOption::Fixed(limit) => limit,\n    };\n    let batch_size = tokens_list.len();\n\n    let mut tensor = Tensor::zeros([batch_size, size], device);\n    tensor = tensor.add_scalar(pad_token as i64);\n\n    for (index, tokens) in tokens_list.into_iter().enumerate() {\n        let seq_length = tokens.len().min(size);\n        tensor = tensor.slice_assign(\n            [index..index + 1, 0..seq_length],\n            Tensor::from_data(\n                TensorData::new(\n                    tokens\n                        .into_iter()\n                        .take(size)\n                        .map(|e| (e as i64).elem::<IntElem<B>>())\n                        .collect(),\n                    Shape::new([1, seq_length]),\n                ),\n                device,\n            ),\n        );\n    }\n\n    let mask = tensor.clone().equal_elem(pad_token as i64);\n\n    GeneratePaddingMask { tensor, mask }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use alloc::vec;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn test_generate_autoregressive_mask() {\n        let device = <TestBackend as Backend>::Device::default();\n\n        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);\n\n        mask.into_data().assert_eq(\n            &TensorData::from([\n                [\n                    [false, true, true],\n                    [false, false, true],\n                    [false, false, false],\n                ],\n                [\n                    [false, true, true],\n                    [false, false, true],\n                    [false, false, false],\n                ],\n            ]),\n            false,\n        );\n    }\n\n    #[test]\n    fn test_generate_padding_mask() {\n        let device = <TestBackend as Backend>::Device::default();\n        let tokens = vec![\n            vec![3, 3, 3],\n            vec![3, 3, 3],\n            vec![3, 3, 3, 4],\n            vec![3, 3, 3, 4, 10, 15],\n        ];\n\n        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);\n\n        mask.mask.into_data().assert_eq(\n            &TensorData::from([\n                [false, false, false, true, true, true],\n                [false, false, false, true, true, true],\n                [false, false, false, false, true, true],\n                [false, false, false, false, false, false],\n            ]),\n            false,\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/attention/mha.rs",
    "content": "use burn_core as burn;\n\nuse crate::activation::Gelu;\nuse crate::cache::TensorCache;\nuse crate::{Dropout, DropoutConfig, Linear, LinearConfig};\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::{Bool, Tensor, backend::Backend};\n\nuse burn::tensor::activation::{quiet_softmax, softmax};\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer using the [init function](MultiHeadAttentionConfig::init).\n#[derive(Config, Debug)]\npub struct MultiHeadAttentionConfig {\n    /// The size of each linear layer.\n    pub d_model: usize,\n    /// The number of heads.\n    pub n_heads: usize,\n    /// The dropout rate. Default: 0.1\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    /// The minimum value a float can take. Default: -1.0e4\n    /// This is used to mask attention scores before calculating attention weights.\n    /// A value too low might result in NaN.\n    #[config(default = -1.0e4)]\n    pub min_float: f64,\n    /// Use \"quiet softmax\" instead of regular softmax.\n    ///\n    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).\n    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.\n    ///\n    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>\n    #[config(default = false)]\n    pub quiet_softmax: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// The multihead attention module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).\n///\n/// # Params\n///\n/// - `query`: [`Linear`] layer with `d_model` input and output features.\n/// - `key`: [`Linear`] layer with `d_model` input and output features.\n/// - `value`: [`Linear`] layer with `d_model` input and output features.\n/// - `output`: [`Linear`] layer with `d_model` input and output features.\n///\n/// Should be created with [MultiHeadAttentionConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct MultiHeadAttention<B: Backend> {\n    /// Linear layer to transform the input features into the query space.\n    pub query: Linear<B>,\n    /// Linear layer to transform the input features into the key space.\n    pub key: Linear<B>,\n    /// Linear layer to transform the input features into the value space.\n    pub value: Linear<B>,\n    /// Linear layer to transform the output features back to the original space.\n    pub output: Linear<B>,\n    /// Dropout layer.\n    pub dropout: Dropout,\n    /// Activation function.\n    pub activation: Gelu,\n    /// The size of each linear layer.\n    pub d_model: usize,\n    /// The number of heads.\n    pub n_heads: usize,\n    /// Size of the key and query vectors.\n    pub d_k: usize,\n    /// Minimum value a float can take.\n    pub min_float: f64,\n    /// Use \"quiet softmax\" instead of regular softmax.\n    pub quiet_softmax: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"d_model\", &self.d_model)\n            .add(\"n_heads\", &self.n_heads)\n            .add(\"d_k\", &self.d_k)\n            .add(\"dropout\", &self.dropout.prob)\n            .add(\"min_float\", &self.min_float)\n            .add(\"quiet_softmax\", &self.quiet_softmax)\n            .optional()\n    }\n}\n\n/// [Multihead attention](MultiHeadAttention) forward pass input argument.\n#[derive(Debug, Clone)]\npub struct MhaInput<B: Backend> {\n    /// Shape `[batch_size, seq_length_1, d_model]`\n    query: Tensor<B, 3>,\n    /// Shape `[batch_size, seq_length_2, d_model]`\n    key: Tensor<B, 3>,\n    /// Shape `[batch_size, seq_length_2, d_model]`\n    value: Tensor<B, 3>,\n    mask_pad: Option<Tensor<B, 2, Bool>>,\n    mask_attn: Option<Tensor<B, 3, Bool>>,\n}\n\nimpl MultiHeadAttentionConfig {\n    /// Initialize a new [multihead attention](MultiHeadAttention) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {\n        let linear = |config: &Self| {\n            LinearConfig::new(config.d_model, config.d_model)\n                .with_initializer(self.initializer.clone())\n                .init(device)\n        };\n\n        MultiHeadAttention {\n            query: linear(self),\n            key: linear(self),\n            value: linear(self),\n            output: linear(self),\n            dropout: DropoutConfig::new(self.dropout).init(),\n            activation: Gelu::new(),\n            n_heads: self.n_heads,\n            d_k: self.d_model / self.n_heads,\n            min_float: self.min_float,\n            quiet_softmax: self.quiet_softmax,\n            d_model: self.d_model,\n        }\n    }\n}\n\nimpl<B: Backend> MhaInput<B> {\n    /// Create a [multihead attention](MultiHeadAttention) input argument\n    /// by setting the query, key and value to the given tensor.\n    ///\n    /// # Shape\n    /// - tensor: `[batch_size, seq_length, d_model]`\n    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {\n        Self {\n            query: tensor.clone(),\n            key: tensor.clone(),\n            value: tensor,\n            mask_pad: None,\n            mask_attn: None,\n        }\n    }\n\n    /// Create a [multihead attention](MultiHeadAttention) input argument.\n    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {\n        Self {\n            query,\n            key,\n            value,\n            mask_pad: None,\n            mask_attn: None,\n        }\n    }\n\n    /// Register the padding mask.\n    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {\n        self.mask_pad = Some(mask_pad);\n        self\n    }\n\n    /// Register the attention mask.\n    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {\n        self.mask_attn = Some(mask_attn);\n        self\n    }\n}\n\n/// [Multihead attention](MultiHeadAttention) outputs.\n#[derive(Debug, Clone)]\npub struct MhaOutput<B: Backend> {\n    /// The attention weights `[batch_size, n_heads, seq_length_1, seq_length_2]`.\n    pub weights: Tensor<B, 4>,\n    /// The context tensor `[batch_size, seq_length_1, d_model]`.\n    pub context: Tensor<B, 3>,\n}\n\nimpl<B: Backend> MultiHeadAttention<B> {\n    /// Applies the forward pass on the input tensors.\n    ///\n    /// See [MultiHeadAttention](MultiHeadAttention) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - query: `[batch_size, seq_length_1, d_model]`\n    /// - key: `[batch_size, seq_length_2, d_model]`\n    /// - value: `[batch_size, seq_length_2, d_model]`\n    /// - output: `[batch_size, seq_length_1, d_model]`\n    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {\n        let [batch_size, seq_length_1, d_model] = input.query.dims();\n\n        let query = self.attention_linear(input.query, &self.query);\n        let key = self.attention_linear(input.key, &self.key);\n        let value = self.attention_linear(input.value, &self.value);\n\n        let attn_scores = self.attn_scores(query, key);\n        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);\n\n        let context = weights.clone().matmul(value);\n        let context = context\n            .swap_dims(1, 2)\n            .reshape([batch_size, seq_length_1, d_model]);\n        let context = self.output.forward(context);\n\n        MhaOutput { weights, context }\n    }\n\n    /// Applies the forward pass using a cache.\n    ///\n    /// # Shapes\n    ///\n    /// - query: `[batch_size, seq_length_1, d_model]`\n    /// - key: `[batch_size, seq_length_2, d_model]`\n    /// - value: `[batch_size, seq_length_2, d_model]`\n    /// - output: `[batch_size, seq_length_1, d_model]`\n    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {\n        let [batch_size, seq_length_1, d_model] = input.query.dims();\n\n        let query = cache\n            .query\n            .forward(input.query, |t| self.attention_linear(t, &self.query));\n        let key = cache\n            .key\n            .forward(input.key, |t| self.attention_linear(t, &self.key));\n        let value = cache\n            .value\n            .forward(input.value, |t| self.attention_linear(t, &self.value));\n\n        let attn_scores = self.attn_scores(query, key);\n        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);\n\n        let context = weights.clone().matmul(value);\n        let context = context\n            .swap_dims(1, 2)\n            .reshape([batch_size, seq_length_1, d_model]);\n\n        let context = cache.output.forward(context, |t| self.output.forward(t));\n\n        MhaOutput { weights, context }\n    }\n\n    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {\n        let attn_scores = query\n            .matmul(key.transpose())\n            .div_scalar((self.d_k as f32).sqrt());\n\n        self.dropout.forward(attn_scores)\n    }\n\n    fn attn_weights(\n        &self,\n        mut attn_scores: Tensor<B, 4>,\n        mask_pad: Option<Tensor<B, 2, Bool>>,\n        mask_attn: Option<Tensor<B, 3, Bool>>,\n    ) -> Tensor<B, 4> {\n        if let Some(mask_pad) = mask_pad {\n            let [batch_size, seq_length] = mask_pad.dims();\n\n            attn_scores = attn_scores.mask_fill(\n                mask_pad.reshape([batch_size, 1, 1, seq_length]),\n                self.min_float,\n            );\n        }\n\n        if let Some(mask_attn) = mask_attn {\n            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();\n\n            attn_scores = attn_scores.mask_fill(\n                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),\n                self.min_float,\n            );\n        }\n\n        if self.quiet_softmax {\n            quiet_softmax(attn_scores, 3)\n        } else {\n            softmax(attn_scores, 3)\n        }\n    }\n\n    fn attention_linear(&self, x: Tensor<B, 3>, linear: &Linear<B>) -> Tensor<B, 4> {\n        let [batch_size, seq_length, _d_model] = x.dims();\n        linear\n            .forward(x)\n            .reshape([batch_size, seq_length, self.n_heads, self.d_k])\n            .swap_dims(1, 2)\n    }\n}\n\n/// Cache for the [Multi Head Attention](MultiHeadAttention) layer.\n///\n/// To be used during inference when decoding tokens.\npub struct MhaCache<B: Backend> {\n    query: MhaLinearCache<B, 4>,\n    key: MhaLinearCache<B, 4>,\n    value: MhaLinearCache<B, 4>,\n    output: MhaLinearCache<B, 3>,\n}\n\nenum MhaLinearCache<B: Backend, const D: usize> {\n    Autoregressive(TensorCache<B, D>, usize),\n    Full(TensorCache<B, D>),\n}\n\nimpl<B: Backend> MhaCache<B> {\n    /// Initialize a cache for autoregressive inference.\n    pub fn autoregressive() -> Self {\n        Self {\n            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),\n            key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),\n            value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),\n            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),\n        }\n    }\n\n    /// Initialize a cache for autoregressive inference, but with a fixed memory used for keys and\n    /// values (cross-attention).\n    pub fn autoregressive_cross_attention() -> Self {\n        Self {\n            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),\n            key: MhaLinearCache::Full(TensorCache::empty()),\n            value: MhaLinearCache::Full(TensorCache::empty()),\n            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),\n        }\n    }\n}\n\nimpl<B: Backend, const D: usize> MhaLinearCache<B, D> {\n    pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(\n        &mut self,\n        tensor: Tensor<B, 3>,\n        func: F,\n    ) -> Tensor<B, D> {\n        match self {\n            MhaLinearCache::Autoregressive(cache, dim) => {\n                cache.forward_autoregressive(tensor, *dim, func)\n            }\n            MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, attention::generate_autoregressive_mask};\n    use alloc::vec::Vec;\n    use burn::tensor::Int;\n    use burn::tensor::Tolerance;\n    use burn::tensor::ops::FloatElem;\n    use burn::tensor::{Distribution, Shape};\n\n    #[test]\n    fn test_self_attention_shapes() {\n        let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];\n        let device = Default::default();\n        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);\n        let input = MhaInput::self_attn(Tensor::random(\n            [batch_size, seq_length, d_model],\n            Distribution::Default,\n            &device,\n        ));\n\n        let output = mha.forward(input);\n\n        assert_eq!(\n            output.context.shape(),\n            Shape::new([batch_size, seq_length, d_model]),\n            \"Context should have the correct shape\",\n        );\n        assert_eq!(\n            output.weights.shape(),\n            Shape::new([batch_size, n_heads, seq_length, seq_length]),\n            \"Weights should have the correct shape\",\n        );\n    }\n\n    #[test]\n    fn test_generic_mha_shapes() {\n        let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];\n        let mha = MultiHeadAttentionConfig::new(d_model, n_heads)\n            .init::<TestBackend>(&Default::default());\n        let device = Default::default();\n        let input = MhaInput::new(\n            Tensor::random(\n                [batch_size, seq_length_1, d_model],\n                Distribution::Default,\n                &device,\n            ),\n            Tensor::random(\n                [batch_size, seq_length_2, d_model],\n                Distribution::Default,\n                &device,\n            ),\n            Tensor::random(\n                [batch_size, seq_length_2, d_model],\n                Distribution::Default,\n                &device,\n            ),\n        );\n\n        let output = mha.forward(input);\n\n        assert_eq!(\n            output.context.shape(),\n            Shape::new([batch_size, seq_length_1, d_model]),\n            \"Context should have the correct shape\",\n        );\n        assert_eq!(\n            output.weights.shape(),\n            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),\n            \"Weights should have the correct shape\",\n        );\n    }\n\n    #[test]\n    fn test_self_attention_mask_pad() {\n        let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];\n        let device = Default::default();\n        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);\n\n        // Create a padding mask\n        let mask_pad: Tensor<TestBackend, 2, Int> =\n            Tensor::zeros([batch_size, seq_length], &device);\n        let mask_pad = mask_pad.slice_assign(\n            [0..batch_size, seq_length - num_padded..seq_length],\n            Tensor::ones([batch_size, num_padded], &device),\n        );\n        let mask_pad = mask_pad.equal_elem(1).to_device(&device);\n\n        let tensor_1 = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_length, d_model],\n            Distribution::Default,\n            &device,\n        );\n        // Change the end of the tensor\n        let tensor_2 = tensor_1.clone().slice_assign(\n            [\n                0..batch_size,\n                seq_length - num_padded..seq_length,\n                0..d_model,\n            ],\n            Tensor::random(\n                [batch_size, num_padded, d_model],\n                Distribution::Default,\n                &device,\n            ),\n        );\n\n        let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());\n        let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);\n\n        let output_1 = mha.forward(input_1);\n        let output_2 = mha.forward(input_2);\n\n        // Check that the beginning of each tensor is the same\n        output_1\n            .context\n            .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])\n            .into_data()\n            .assert_approx_eq(\n                &output_2\n                    .context\n                    .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])\n                    .into_data(),\n                Tolerance::<f32>::default(),\n            );\n    }\n\n    #[test]\n    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {\n        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];\n        let device = Default::default();\n        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);\n\n        let tensor = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_length, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());\n        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);\n\n        let output_1 = mha.forward(input);\n        let mut output_2 = Vec::new();\n        let mut cache = MhaCache::autoregressive();\n\n        for i in 1..seq_length + 1 {\n            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);\n            let input = MhaInput::self_attn(tensor);\n            let next_tok = mha.forward_cache(input, &mut cache).context.slice([\n                0..batch_size,\n                i - 1..i,\n                0..d_model,\n            ]);\n            output_2.push(next_tok);\n        }\n\n        let output_2 = Tensor::cat(output_2, 1);\n\n        output_1\n            .context\n            .into_data()\n            .assert_approx_eq::<FloatElem<TestBackend>>(\n                &output_2.into_data(),\n                Tolerance::default(),\n            );\n    }\n\n    #[test]\n    fn display() {\n        let config = MultiHeadAttentionConfig::new(2, 4);\n        let mha = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{mha}\"),\n            \"MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \\\n            dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/attention/mod.rs",
    "content": "mod cross_attention;\nmod mask;\nmod mha;\n\npub use cross_attention::*;\npub use mask::*;\npub use mha::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/cache/autoregressive.rs",
    "content": "use alloc::vec;\nuse burn_core as burn;\n\nuse super::{CacheState, TensorCache};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\nimpl<B: Backend, const D: usize> TensorCache<B, D> {\n    pub(crate) fn forward_autoregressive<F>(\n        &mut self,\n        tensor: Tensor<B, 3>,\n        dim_cat: usize,\n        func: F,\n    ) -> Tensor<B, D>\n    where\n        F: Fn(Tensor<B, 3>) -> Tensor<B, D>,\n    {\n        let mut tensor_old = CacheState::Empty;\n        core::mem::swap(&mut self.state, &mut tensor_old);\n\n        let tensor_new = match tensor_old {\n            CacheState::Value(tensor_old) => {\n                let [batch_size, seq_length, d_model] = tensor.dims();\n                let next_seq_token =\n                    tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);\n                let next_seq_token = func(next_seq_token);\n\n                Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)\n            }\n            _ => func(tensor),\n        };\n\n        self.state = CacheState::Value(tensor_new.clone());\n        tensor_new\n    }\n\n    pub(crate) fn forward_full<F>(&mut self, tensor: Tensor<B, 3>, func: F) -> Tensor<B, D>\n    where\n        F: Fn(Tensor<B, 3>) -> Tensor<B, D>,\n    {\n        let mut tensor_old = CacheState::Empty;\n        core::mem::swap(&mut self.state, &mut tensor_old);\n\n        let tensor_new = match tensor_old {\n            CacheState::Value(tensor_old) => tensor_old,\n            _ => func(tensor),\n        };\n\n        self.state = CacheState::Value(tensor_new.clone());\n        tensor_new\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/cache/base.rs",
    "content": "use burn_core as burn;\n\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\npub(crate) enum CacheState<T> {\n    Value(T),\n    Empty,\n}\n\n/// A cache for a tensor.\npub struct TensorCache<B: Backend, const D: usize> {\n    pub(crate) state: CacheState<Tensor<B, D>>,\n}\n\nimpl<B: Backend, const D: usize> TensorCache<B, D> {\n    /// Creates a new empty cache.\n    ///\n    /// # Returns\n    ///\n    /// The empty cache.\n    pub fn empty() -> Self {\n        Self {\n            state: CacheState::Empty,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/cache/mod.rs",
    "content": "mod autoregressive;\nmod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/checks.rs",
    "content": "pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize, groups: usize) {\n    let channels_in_div_by_group = channels_in.is_multiple_of(groups);\n    let channels_out_div_by_group = channels_out.is_multiple_of(groups);\n\n    if !channels_in_div_by_group || !channels_out_div_by_group {\n        panic!(\n            \"Both channels must be divisible by the number of groups. Got \\\n             channels_in={channels_in}, channels_out={channels_out}, groups={groups}\"\n        );\n    }\n}\n\n// https://github.com/tracel-ai/burn/issues/2676\n/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel\n/// size is not supported as it will not produce the same output size.\npub(crate) fn check_same_padding_support(kernel_size: &[usize]) {\n    for k in kernel_size.iter() {\n        if k % 2 == 0 {\n            unimplemented!(\"Same padding with an even kernel size is not supported\");\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv1d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::{PaddingConfig1d, conv::checks};\nuse burn::tensor::{Tensor, backend::Backend, module::conv1d, ops::PaddedConvOptions};\nuse burn::{\n    config::Config,\n    module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay, Param},\n};\n\n/// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init).\n#[derive(Config, Debug)]\npub struct Conv1dConfig {\n    /// The number of input channels.\n    pub channels_in: usize,\n    /// The number of output channels.\n    pub channels_out: usize,\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The stride of the convolution.\n    #[config(default = \"1\")]\n    pub stride: usize,\n    /// Spacing between kernel elements.\n    #[config(default = \"1\")]\n    pub dilation: usize,\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig1d::Valid\")]\n    pub padding: PaddingConfig1d,\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 1D convolution over input tensors.\n///\n/// Should be created with [Conv1dConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Conv1d<B: Backend> {\n    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]`\n    pub weight: Param<Tensor<B, 3>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: usize,\n    /// Size of the kernel.\n    pub kernel_size: usize,\n    /// Spacing between kernel elements.\n    pub dilation: usize,\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// Padding configuration.\n    pub padding: PaddingConfig1d,\n}\n\nimpl<B: Backend> ModuleDisplay for Conv1d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        // Format stride/dilation as strings\n        let stride = format!(\"{:?}\", self.stride);\n        let kernel_size = format!(\"{:?}\", self.kernel_size);\n        let dilation = format!(\"{:?}\", self.dilation);\n\n        // Extract channels in/out from weight dims\n        let [channels_out, group_channels_in, _] = self.weight.dims();\n        let channels_in = group_channels_in * self.groups;\n        let ch_out = format!(\"{:?}\", channels_out);\n        let ch_in = format!(\"{:?}\", channels_in);\n\n        content\n            .add(\"ch_in\", &ch_in)\n            .add(\"ch_out\", &ch_out)\n            .add(\"stride\", &stride)\n            .add(\"kernel_size\", &kernel_size)\n            .add(\"dilation\", &dilation)\n            .add(\"groups\", &self.groups)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .optional()\n    }\n}\nimpl Conv1dConfig {\n    /// Initialize a new [conv1d](Conv1d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv1d<B> {\n        checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups);\n\n        let shape = [\n            self.channels_out,\n            self.channels_in / self.groups,\n            self.kernel_size,\n        ];\n\n        let fan_in: usize = self.channels_in / self.groups * self.kernel_size;\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), None, device);\n        let mut bias = None;\n\n        if self.bias {\n            bias =\n                Some(\n                    self.initializer\n                        .init_with([self.channels_out], Some(fan_in), None, device),\n                );\n        }\n\n        Conv1d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            padding: self.padding.clone(),\n            dilation: self.dilation,\n            groups: self.groups,\n        }\n    }\n}\n\nimpl<B: Backend> Conv1d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [conv1d](burn::tensor::module::conv1d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, length_in]`\n    /// - output: `[batch_size, channels_out, length_out]`\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let length = input.dims()[2];\n\n        // Calculate padding as pair - handles Same, Valid, and Explicit uniformly\n        let (left, right) =\n            self.padding\n                .calculate_padding_1d_pair(length, self.kernel_size, self.stride);\n\n        let options = PaddedConvOptions::asymmetric(\n            [self.stride],\n            [left],\n            [right],\n            [self.dilation],\n            self.groups,\n        );\n\n        conv1d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            options,\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::{ElementConversion, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv1dConfig::new(5, 5, 5);\n        let k = (config.channels_in * config.kernel_size) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&device);\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv1dConfig::new(5, 5, 5).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight\n            .to_data()\n            .assert_eq(&TensorData::zeros::<FT, _>(conv.weight.shape()), false);\n    }\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = Conv1dConfig::new(4, 4, 2)\n            .with_padding(PaddingConfig1d::Same)\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=4, length=5]\n        let input = Tensor::<TestBackend, 3>::ones([1, 4, 5], &device);\n        let output = conv.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 4, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = Conv1dConfig::new(5, 5, 5);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{conv}\"),\n            \"Conv1d {ch_in: 5, ch_out: 5, stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = Conv1dConfig::new(5, 3, 3);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create conv with asymmetric padding: left=1, right=2\n        let config = Conv1dConfig::new(2, 3, 3)\n            .with_padding(PaddingConfig1d::Explicit(1, 2))\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = conv.forward(input);\n\n        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7\n        // Output length = (7 - 3) / 1 + 1 = 5\n        assert_eq!(output.dims(), [1, 3, 5]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create conv with symmetric explicit padding: left=2, right=2\n        let config = Conv1dConfig::new(2, 3, 3)\n            .with_padding(PaddingConfig1d::Explicit(2, 2))\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = conv.forward(input);\n\n        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8\n        // Output length = (8 - 3) / 1 + 1 = 6\n        assert_eq!(output.dims(), [1, 3, 6]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv2d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::PaddingConfig2d;\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::conv2d;\nuse burn::tensor::ops::PaddedConvOptions;\n\nuse crate::conv::checks;\n\n/// Configuration to create a [2D convolution](Conv2d) layer, using the [init function](Conv2dConfig::init).\n#[derive(Config, Debug)]\npub struct Conv2dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1]\")]\n    pub stride: [usize; 2],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1]\")]\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig2d::Valid\")]\n    pub padding: PaddingConfig2d,\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 2D convolution over input tensors.\n///\n/// Should be created with [Conv2dConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Conv2d<B: Backend> {\n    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`\n    pub weight: Param<Tensor<B, 4>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: [usize; 2],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// The padding configuration.\n    pub padding: PaddingConfig2d,\n}\n\nimpl Conv2dConfig {\n    /// Initialize a new [conv2d](Conv2d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);\n\n        let shape = [\n            self.channels[1],\n            self.channels[0] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n        ];\n\n        let k = self.kernel_size.iter().product::<usize>();\n        let fan_in = self.channels[0] / self.groups * k;\n        let fan_out = self.channels[1] / self.groups * k;\n\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), Some(fan_out), device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(self.initializer.init_with(\n                [self.channels[1]],\n                Some(fan_in),\n                Some(fan_out),\n                device,\n            ));\n        }\n\n        Conv2d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            padding: self.padding.clone(),\n            groups: self.groups,\n        }\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for Conv2d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.\n        let stride = format!(\"{:?}\", self.stride);\n        let kernel_size = format!(\"{:?}\", self.kernel_size);\n        let dilation = format!(\"{:?}\", self.dilation);\n        let [channels_out, group_channels_in, _, _] = self.weight.dims();\n        let channels_in = group_channels_in * self.groups;\n        let ch_out = format!(\"{:?}\", channels_out);\n        let ch_in = format!(\"{:?}\", channels_in);\n        content\n            .add(\"ch_in\", &ch_in)\n            .add(\"ch_out\", &ch_out)\n            .add(\"stride\", &stride)\n            .add(\"kernel_size\", &kernel_size)\n            .add(\"dilation\", &dilation)\n            .add(\"groups\", &self.groups)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .optional()\n    }\n}\n\nimpl<B: Backend> Conv2d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [conv2d](burn::tensor::module::conv2d) for more information.\n    ///\n    /// # Shapes\n    /// - `input`: `[batch_size, channels_in, height_in, width_in]`\n    /// - `output`: `[batch_size, channels_out, height_out, width_out]`\n    ///\n    /// # Example\n    /// ```rust,ignore\n    /// use burn::nn::conv::Conv2dConfig;\n    /// use burn::tensor::Tensor;\n    ///\n    /// // Assuming backend type alias `B`\n    /// let device = Default::default();\n    /// let conv = Conv2dConfig::new([3, 8], [3, 3]).init::<B>(&device);\n    ///\n    /// let x = Tensor::<B, 4>::zeros([1, 3, 28, 28], &device);\n    /// let y = conv.forward(x);\n    ///\n    /// println!(\"{:?}\", y.dims()); // [1, 8, 26, 26]\n    /// ```\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let [_batch_size, _channels_in, height_in, width_in] = input.dims();\n\n        // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly\n        let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(\n            height_in,\n            width_in,\n            &self.kernel_size,\n            &self.stride,\n        );\n\n        let options = PaddedConvOptions::asymmetric(\n            self.stride,\n            [top, left],\n            [bottom, right],\n            self.dilation,\n            self.groups,\n        );\n\n        conv2d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            options,\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::ops::FloatElem;\n    use burn::tensor::{ElementConversion, Tolerance};\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    type FT = FloatElem<TestBackend>; // Float test\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv2dConfig::new([5, 1], [5, 5]);\n        let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&device);\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<FT, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn initializer_fan_out() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true, // test that fan_out is passed to `init_with()`\n        };\n\n        let config = Conv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    fn initializer_fan_with_groups_is_valid() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true,\n        };\n\n        let config = Conv2dConfig::new([4, 4], [1, 1])\n            .with_initializer(init.clone())\n            .with_groups(4);\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    #[should_panic = \"Both channels must be divisible by the number of groups.\"]\n    fn channels_with_groups_is_invalid() {\n        let device = Default::default();\n        let config = Conv2dConfig::new([1, 4], [1, 1]).with_groups(4);\n        let _ = config.init::<TestBackend>(&device);\n    }\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = Conv2dConfig::new([4, 4], [2, 2])\n            .with_padding(PaddingConfig2d::Same)\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=4, height=5, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 4, 5, 5], &device);\n        let output = conv.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 4, 5, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = Conv2dConfig::new([5, 1], [5, 5]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{conv}\"),\n            \"Conv2d {ch_in: 5, ch_out: 1, stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = Conv2dConfig::new([5, 3], [3, 3]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create conv with asymmetric padding: top=1, left=2, bottom=3, right=4\n        let config = Conv2dConfig::new([2, 3], [3, 3])\n            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4))\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = conv.forward(input);\n\n        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9\n        assert_eq!(output.dims(), [1, 3, 6, 9]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create conv with symmetric explicit padding: top=2, left=2, bottom=2, right=2\n        let config = Conv2dConfig::new([2, 3], [3, 3])\n            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))\n            .with_initializer(Initializer::Constant { value: 1.0 })\n            .with_bias(false);\n        let conv = config.init::<TestBackend>(&device);\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = conv.forward(input);\n\n        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7\n        assert_eq!(output.dims(), [1, 3, 6, 7]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv3d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::PaddingConfig3d;\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::conv3d;\nuse burn::tensor::ops::ConvOptions;\n\nuse crate::conv::checks;\n\n/// Configuration to create a [3D convolution](Conv3d) layer, using the [init function](Conv3dConfig::init).\n#[derive(Config, Debug)]\npub struct Conv3dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 3],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1, 1]\")]\n    pub stride: [usize; 3],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1, 1]\")]\n    pub dilation: [usize; 3],\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    #[config(default = \"PaddingConfig3d::Valid\")]\n    pub padding: PaddingConfig3d,\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 3D convolution over input tensors.\n///\n/// Should be created with [Conv3dConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Conv3d<B: Backend> {\n    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2, kernel_size_3]`\n    pub weight: Param<Tensor<B, 5>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: [usize; 3],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 3],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 3],\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// The padding configuration.\n    pub padding: PaddingConfig3d,\n}\n\nimpl Conv3dConfig {\n    /// Initialize a new [conv3d](Conv3d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv3d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);\n        if self.padding == PaddingConfig3d::Same {\n            checks::check_same_padding_support(&self.kernel_size);\n        }\n\n        let shape = [\n            self.channels[1],\n            self.channels[0] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n            self.kernel_size[2],\n        ];\n\n        let k = self.kernel_size.iter().product::<usize>();\n        let fan_in = self.channels[0] / self.groups * k;\n        let fan_out = self.channels[1] / self.groups * k;\n\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), Some(fan_out), device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(self.initializer.init_with(\n                [self.channels[1]],\n                Some(fan_in),\n                Some(fan_out),\n                device,\n            ));\n        }\n\n        Conv3d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            padding: self.padding.clone(),\n            groups: self.groups,\n        }\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for Conv3d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        // Format arrays as strings (consistent with Conv2d/Conv1d).\n        let stride = format!(\"{:?}\", self.stride);\n        let kernel_size = format!(\"{:?}\", self.kernel_size);\n        let dilation = format!(\"{:?}\", self.dilation);\n\n        // Weight dims: [channels_out, channels_in/groups, k1, k2, k3]\n        let [channels_out, group_channels_in, _, _, _] = self.weight.dims();\n        let channels_in = group_channels_in * self.groups;\n        let ch_out = format!(\"{:?}\", channels_out);\n        let ch_in = format!(\"{:?}\", channels_in);\n\n        content\n            .add(\"ch_in\", &ch_in)\n            .add(\"ch_out\", &ch_out)\n            .add(\"stride\", &stride)\n            .add(\"kernel_size\", &kernel_size)\n            .add(\"dilation\", &dilation)\n            .add(\"groups\", &self.groups)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .optional()\n    }\n}\n\nimpl<B: Backend> Conv3d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [conv3d](burn::tensor::module::conv3d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`\n    /// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`\n    pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {\n        let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims();\n        let padding = self.padding.calculate_padding_3d(\n            depth_in,\n            height_in,\n            width_in,\n            &self.kernel_size,\n            &self.stride,\n        );\n        conv3d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            ConvOptions::new(self.stride, padding, self.dilation, self.groups),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv3dConfig::new([5, 1], [5, 5, 5]);\n        let k = (config.channels[0]\n            * config.kernel_size[0]\n            * config.kernel_size[1]\n            * config.kernel_size[2]) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&device);\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = Conv3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);\n        let device = Default::default();\n        let conv = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<f32, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn initializer_fan_out() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true, // test that fan_out is passed to `init_with()`\n        };\n        let config = Conv3dConfig::new([5, 1], [5, 5, 5]).with_initializer(init.clone());\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    fn initializer_fan_with_groups_is_valid() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true,\n        };\n\n        let config = Conv3dConfig::new([4, 4], [1, 1, 1])\n            .with_initializer(init.clone())\n            .with_groups(4);\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    #[should_panic = \"Same padding with an even kernel size is not supported\"]\n    fn same_with_even_kernel_is_invalid() {\n        let device = Default::default();\n        let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same);\n        let _ = config.init::<TestBackend>(&device);\n    }\n\n    #[test]\n    fn display() {\n        let config = Conv3dConfig::new([5, 1], [5, 5, 5]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{conv}\"),\n            \"Conv3d {ch_in: 5, ch_out: 1, stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: Valid, params: 626}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = Conv3dConfig::new([5, 3], [3, 3, 3]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv_transpose1d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::conv::checks;\nuse burn::config::Config;\nuse burn::module::Content;\nuse burn::module::DisplaySettings;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::ModuleDisplay;\nuse burn::module::Param;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::conv_transpose1d;\nuse burn::tensor::ops::ConvTransposeOptions;\n\n/// Configuration to create an [1D transposed convolution](ConvTranspose1d) layer\n/// using the [init function](ConvTranspose1dConfig::init).\n#[derive(Config, Debug)]\npub struct ConvTranspose1dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The stride of the convolution.\n    #[config(default = \"1\")]\n    pub stride: usize,\n    /// Spacing between kernel elements.\n    #[config(default = \"1\")]\n    pub dilation: usize,\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    #[config(default = \"0\")]\n    pub padding: usize,\n    /// The padding output configuration.\n    #[config(default = \"0\")]\n    pub padding_out: usize,\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 1D transposed convolution over input tensors.\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct ConvTranspose1d<B: Backend> {\n    /// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`\n    pub weight: Param<Tensor<B, 3>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: usize,\n    /// Size of the kernel.\n    pub kernel_size: usize,\n    /// Spacing between kernel elements.\n    pub dilation: usize,\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// The padding configuration.\n    pub padding: usize,\n    /// The padding output configuration.\n    pub padding_out: usize,\n    /// The number of channels.\n    pub channels: [usize; 2],\n}\n\nimpl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"channels\", &format!(\"{:?}\", &self.channels))\n            .add(\"stride\", &self.stride)\n            .add(\"kernel_size\", &self.kernel_size)\n            .add(\"dilation\", &self.dilation)\n            .add(\"groups\", &self.groups)\n            .add(\"padding\", &self.padding)\n            .add(\"padding_out\", &self.padding_out)\n            .optional()\n    }\n}\n\nimpl ConvTranspose1dConfig {\n    /// Initialize a new [conv transpose 1d](ConvTranspose1d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose1d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);\n\n        let shape = [\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size,\n        ];\n\n        let fan_in = self.channels[1] / self.groups * self.kernel_size;\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), None, device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(\n                self.initializer\n                    .init_with([self.channels[1]], Some(fan_in), None, device),\n            );\n        }\n\n        ConvTranspose1d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            groups: self.groups,\n            padding: self.padding,\n            padding_out: self.padding_out,\n            channels: self.channels,\n        }\n    }\n}\n\nimpl<B: Backend> ConvTranspose1d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See also [conv_transpose1d](burn::tensor::module::conv_transpose1d).\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, length_in]`\n    /// - output: `[batch_size, channels_out, length_out]`\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        conv_transpose1d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            ConvTransposeOptions::new(\n                [self.stride],\n                [self.padding],\n                [self.padding_out],\n                [self.dilation],\n                self.groups,\n            ),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::ops::FloatElem;\n    use burn::tensor::{ElementConversion, Tolerance};\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = ConvTranspose1dConfig::new([5, 1], 5);\n        let k = (config.channels[1] * config.kernel_size) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = ConvTranspose1dConfig::new([5, 2], 5).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<f32>(\n            &TensorData::zeros::<f32, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn display() {\n        let config = ConvTranspose1dConfig::new([5, 2], 5);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{conv}\"),\n            \"ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = ConvTranspose1dConfig::new([5, 3], 3);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 3>::zeros([1, 4, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv_transpose2d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::conv::checks;\nuse burn::config::Config;\nuse burn::module::Content;\nuse burn::module::DisplaySettings;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::ModuleDisplay;\nuse burn::module::Param;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::conv_transpose2d;\nuse burn::tensor::ops::ConvTransposeOptions;\n\n/// Configuration to create an [2D transposed convolution](ConvTranspose2d) layer\n/// using the [init function](ConvTranspose2dConfig::init).\n#[derive(Config, Debug)]\npub struct ConvTranspose2dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1]\")]\n    pub stride: [usize; 2],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1]\")]\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    #[config(default = \"[0, 0]\")]\n    pub padding: [usize; 2],\n    /// The padding output configuration.\n    #[config(default = \"[0, 0]\")]\n    pub padding_out: [usize; 2],\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 2D transposed convolution over input tensors.\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct ConvTranspose2d<B: Backend> {\n    /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`\n    pub weight: Param<Tensor<B, 4>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: [usize; 2],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// Padding configuration.\n    pub padding: [usize; 2],\n    /// Padding output configuration.\n    pub padding_out: [usize; 2],\n    /// Number of channels.\n    pub channels: [usize; 2],\n}\n\nimpl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"channels\", &format!(\"{:?}\", &self.channels))\n            .add(\"stride\", &format!(\"{:?}\", &self.stride))\n            .add(\"kernel_size\", &format!(\"{:?}\", &self.kernel_size))\n            .add(\"dilation\", &format!(\"{:?}\", &self.dilation))\n            .add(\"groups\", &self.groups)\n            .add(\"padding\", &format!(\"{:?}\", &self.padding))\n            .add(\"padding_out\", &format!(\"{:?}\", &self.padding_out))\n            .optional()\n    }\n}\n\nimpl ConvTranspose2dConfig {\n    /// Initialize a new [conv transpose 2d](ConvTranspose2d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose2d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);\n\n        let shape = [\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n        ];\n\n        let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), None, device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(\n                self.initializer\n                    .init_with([self.channels[1]], Some(fan_in), None, device),\n            );\n        }\n\n        ConvTranspose2d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            groups: self.groups,\n            padding: self.padding,\n            padding_out: self.padding_out,\n            channels: self.channels,\n        }\n    }\n}\n\nimpl<B: Backend> ConvTranspose2d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See also [conv_transpose2d](burn::tensor::module::conv_transpose2d).\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, height_in, width_in]`\n    /// - output: `[batch_size, channels_out, height_out, width_out]`\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        conv_transpose2d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            ConvTransposeOptions::new(\n                self.stride,\n                self.padding,\n                self.padding_out,\n                self.dilation,\n                self.groups,\n            ),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = ConvTranspose2dConfig::new([5, 1], [5, 5]);\n        let k = (config.channels[1] * config.kernel_size[0] * config.kernel_size[1]) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config =\n            ConvTranspose2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<f32, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn display() {\n        let config = ConvTranspose2dConfig::new([5, 2], [5, 5]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{conv}\"),\n            \"ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = ConvTranspose2dConfig::new([5, 3], [3, 3]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/conv_transpose3d.rs",
    "content": "use alloc::format;\n\nuse burn_core as burn;\n\nuse crate::conv::checks;\nuse burn::config::Config;\nuse burn::module::Content;\nuse burn::module::DisplaySettings;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::ModuleDisplay;\nuse burn::module::Param;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::conv_transpose3d;\nuse burn::tensor::ops::ConvTransposeOptions;\n\n/// Configuration to create an [3D transposed convolution](ConvTranspose3d) layer\n/// using the [init function](ConvTranspose3dConfig::init).\n#[derive(Config, Debug)]\npub struct ConvTranspose3dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 3],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1, 1]\")]\n    pub stride: [usize; 3],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1, 1]\")]\n    pub dilation: [usize; 3],\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub groups: usize,\n    /// The padding configuration.\n    #[config(default = \"[0, 0, 0]\")]\n    pub padding: [usize; 3],\n    /// The padding output configuration.\n    #[config(default = \"[0, 0, 0]\")]\n    pub padding_out: [usize; 3],\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a 3D transposed convolution over input tensors.\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct ConvTranspose3d<B: Backend> {\n    /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2, kernel_size_3]`\n    pub weight: Param<Tensor<B, 5>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: [usize; 3],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 3],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 3],\n    /// Controls the connections between input and output channels.\n    pub groups: usize,\n    /// Padding configuration.\n    pub padding: [usize; 3],\n    /// Padding output configuration.\n    pub padding_out: [usize; 3],\n    /// Number of channels.\n    pub channels: [usize; 2],\n}\n\nimpl<B: Backend> ModuleDisplay for ConvTranspose3d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"channels\", &format!(\"{:?}\", &self.channels))\n            .add(\"stride\", &format!(\"{:?}\", &self.stride))\n            .add(\"kernel_size\", &format!(\"{:?}\", &self.kernel_size))\n            .add(\"dilation\", &format!(\"{:?}\", &self.dilation))\n            .add(\"groups\", &self.groups)\n            .add(\"padding\", &format!(\"{:?}\", &self.padding))\n            .add(\"padding_out\", &format!(\"{:?}\", &self.padding_out))\n            .optional()\n    }\n}\n\nimpl ConvTranspose3dConfig {\n    /// Initialize a new [conv transpose 2d](ConvTranspose3d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> ConvTranspose3d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups);\n\n        let shape = [\n            self.channels[0],\n            self.channels[1] / self.groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n            self.kernel_size[2],\n        ];\n\n        let fan_in = self.channels[1] / self.groups * self.kernel_size.iter().product::<usize>();\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), None, device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(\n                self.initializer\n                    .init_with([self.channels[1]], Some(fan_in), None, device),\n            );\n        }\n\n        ConvTranspose3d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            groups: self.groups,\n            padding: self.padding,\n            padding_out: self.padding_out,\n            channels: self.channels,\n        }\n    }\n}\n\nimpl<B: Backend> ConvTranspose3d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See also [conv_transpose3d](burn::tensor::module::conv_transpose3d).\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, depth_in, height_in, width_in]`\n    /// - output: `[batch_size, channels_out, depth_out, height_out, width_out]`\n    pub fn forward(&self, input: Tensor<B, 5>) -> Tensor<B, 5> {\n        conv_transpose3d(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|bias| bias.val()),\n            ConvTransposeOptions::new(\n                self.stride,\n                self.padding,\n                self.padding_out,\n                self.dilation,\n                self.groups,\n            ),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = ConvTranspose3dConfig::new([5, 1], [5, 5, 5]);\n        let k = (config.channels[1]\n            * config.kernel_size[0]\n            * config.kernel_size[1]\n            * config.kernel_size[2]) as f64;\n        let k = (config.groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config =\n            ConvTranspose3dConfig::new([5, 2], [5, 5, 5]).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<f32>(\n            &TensorData::zeros::<f32, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn display() {\n        let config = ConvTranspose3dConfig::new([5, 2], [5, 5, 5]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{conv}\"),\n            \"ConvTranspose3d {channels: [5, 2], stride: [1, 1, 1], kernel_size: [5, 5, 5], dilation: [1, 1, 1], groups: 1, padding: [0, 0, 0], padding_out: [0, 0, 0], params: 1252}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = ConvTranspose3dConfig::new([5, 3], [3, 3, 3]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 5>::zeros([1, 4, 10, 10, 10], &Default::default());\n        let _ = conv.forward(input);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/deform_conv2d.rs",
    "content": "use alloc::format;\nuse burn::tensor::ops::DeformConvOptions;\n\nuse burn_core as burn;\n\nuse crate::PaddingConfig2d;\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::deform_conv2d;\n\nuse crate::conv::checks;\n\n/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).\n#[derive(Config, Debug)]\npub struct DeformConv2dConfig {\n    /// The number of channels.\n    pub channels: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1]\")]\n    pub stride: [usize; 2],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1]\")]\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    #[config(default = \"1\")]\n    pub weight_groups: usize,\n    /// Offset groups.\n    #[config(default = \"1\")]\n    pub offset_groups: usize,\n    /// The padding configuration.\n    ///\n    /// ### Warning\n    /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel\n    /// size is not supported as it will not produce the same output size.\n    #[config(default = \"PaddingConfig2d::Valid\")]\n    pub padding: PaddingConfig2d,\n    /// If bias should be added to the output.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n}\n\n/// Applies a deformable 2D convolution over input tensors.\n///\n/// Should be created with [DeformConv2dConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct DeformConv2d<B: Backend> {\n    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`\n    pub weight: Param<Tensor<B, 4>>,\n    /// Tensor of shape `[channels_out]`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n    /// Stride of the convolution.\n    pub stride: [usize; 2],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 2],\n    /// Controls the connections between input and output channels.\n    pub weight_groups: usize,\n    /// Offset groups.\n    pub offset_groups: usize,\n    /// The padding configuration.\n    pub padding: PaddingConfig2d,\n}\n\nimpl DeformConv2dConfig {\n    /// Initialize a new [DeformConv2d](DeformConv2d) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {\n        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);\n        if self.padding == PaddingConfig2d::Same {\n            checks::check_same_padding_support(&self.kernel_size);\n        }\n\n        let shape = [\n            self.channels[1],\n            self.channels[0] / self.weight_groups,\n            self.kernel_size[0],\n            self.kernel_size[1],\n        ];\n\n        let k = self.kernel_size.iter().product::<usize>();\n        let fan_in = self.channels[0] / self.weight_groups * k;\n        let fan_out = self.channels[1] / self.weight_groups * k;\n\n        let weight = self\n            .initializer\n            .init_with(shape, Some(fan_in), Some(fan_out), device);\n        let mut bias = None;\n\n        if self.bias {\n            bias = Some(self.initializer.init_with(\n                [self.channels[1]],\n                Some(fan_in),\n                Some(fan_out),\n                device,\n            ));\n        }\n\n        DeformConv2d {\n            weight,\n            bias,\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            dilation: self.dilation,\n            padding: self.padding.clone(),\n            weight_groups: self.weight_groups,\n            offset_groups: self.weight_groups,\n        }\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for DeformConv2d<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.\n        let stride = format!(\"{:?}\", self.stride);\n        let kernel_size = format!(\"{:?}\", self.kernel_size);\n        let dilation = format!(\"{:?}\", self.dilation);\n\n        content\n            .add(\"stride\", &stride)\n            .add(\"kernel_size\", &kernel_size)\n            .add(\"dilation\", &dilation)\n            .add(\"weight_groups\", &self.weight_groups)\n            .add(\"offset_groups\", &self.offset_groups)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .optional()\n    }\n}\n\nimpl<B: Backend> DeformConv2d<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [deform_conv2d](burn::tensor::module::deform_conv2d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels_in, height_in, width_in]`\n    /// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]`\n    /// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]`\n    /// - output: `[batch_size, channels_out, height_out, width_out]`\n    pub fn forward(\n        &self,\n        input: Tensor<B, 4>,\n        offset: Tensor<B, 4>,\n        mask: Option<Tensor<B, 4>>,\n    ) -> Tensor<B, 4> {\n        let [_batch_size, _channels_in, height_in, width_in] = input.dims();\n        let padding =\n            self.padding\n                .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);\n        deform_conv2d(\n            input,\n            offset,\n            self.weight.val(),\n            mask,\n            self.bias.as_ref().map(|bias| bias.val()),\n            DeformConvOptions::new(\n                self.stride,\n                padding,\n                self.dilation,\n                self.weight_groups,\n                self.offset_groups,\n            ),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = DeformConv2dConfig::new([5, 1], [5, 5]);\n        let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;\n        let k = (config.offset_groups as f64 / k).sqrt().elem::<FT>();\n        let conv = config.init::<TestBackend>(&device);\n\n        conv.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);\n        let conv = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        conv.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<f32, _>(conv.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn initializer_fan_out() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true, // test that fan_out is passed to `init_with()`\n        };\n\n        let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    fn initializer_fan_with_groups_is_valid() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let init = Initializer::KaimingUniform {\n            gain: 1.0 / 3.0f64.sqrt(),\n            fan_out_only: true,\n        };\n\n        let config = DeformConv2dConfig::new([4, 4], [1, 1])\n            .with_initializer(init.clone())\n            .with_weight_groups(4);\n        let _ = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, init);\n    }\n\n    #[test]\n    #[should_panic = \"Both channels must be divisible by the number of groups.\"]\n    fn channels_with_groups_is_invalid() {\n        let device = Default::default();\n        let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4);\n        let _ = config.init::<TestBackend>(&device);\n    }\n\n    #[test]\n    #[should_panic = \"Same padding with an even kernel size is not supported\"]\n    fn same_with_even_kernel_is_invalid() {\n        let device = Default::default();\n        let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same);\n        let _ = config.init::<TestBackend>(&device);\n    }\n\n    #[test]\n    fn display() {\n        let config = DeformConv2dConfig::new([5, 1], [5, 5]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{conv}\"),\n            \"DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}\"\n        );\n    }\n\n    #[test]\n    #[should_panic = \"Number of channels in input tensor and input channels of convolution must be equal. got: 4, expected: 5\"]\n    fn input_channels_mismatch() {\n        let config = DeformConv2dConfig::new([5, 3], [3, 3]);\n        let conv = config.init::<TestBackend>(&Default::default());\n\n        let input = Tensor::<TestBackend, 4>::zeros([1, 4, 10, 10], &Default::default());\n        let offset = Tensor::<TestBackend, 4>::zeros([1, 2 * 3 * 3, 10, 10], &Default::default());\n        let _ = conv.forward(input, offset, None);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/conv/mod.rs",
    "content": "mod conv1d;\nmod conv2d;\nmod conv3d;\nmod conv_transpose1d;\nmod conv_transpose2d;\nmod conv_transpose3d;\nmod deform_conv2d;\n\npub(crate) mod checks;\n\npub use conv_transpose1d::*;\npub use conv_transpose2d::*;\npub use conv_transpose3d::*;\npub use conv1d::*;\npub use conv2d::*;\npub use conv3d::*;\npub use deform_conv2d::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/dropout.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{Distribution, Tensor};\n\n/// Configuration to create a [Dropout](Dropout) layer using the [init function](DropoutConfig::init).\n#[derive(Config, Debug)]\npub struct DropoutConfig {\n    /// The probability of randomly zeroes some elements of the input tensor during training.\n    pub prob: f64,\n}\n\n/// Set at random some elements of the input tensor to zero during training.\n///\n/// This is an effective regularization technique as describe in the paper\n/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).\n///\n/// The input is also scaled during training to `1 / (1 - prob_keep)`.\n///\n/// Should be created with [DropoutConfig].\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Dropout {\n    /// The probability of randomly zeroes some elements of the input tensor during training.\n    pub prob: f64,\n}\n\nimpl DropoutConfig {\n    /// Initialize a new [dropout](Dropout) module.\n    pub fn init(&self) -> Dropout {\n        if self.prob < 0.0 || self.prob > 1.0 {\n            panic!(\n                \"Dropout probability should be between 0 and 1, but got {}\",\n                self.prob\n            );\n        }\n        Dropout { prob: self.prob }\n    }\n}\n\nimpl Dropout {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [Dropout](Dropout) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        if !B::ad_enabled(&input.device()) || self.prob == 0.0 {\n            return input;\n        }\n\n        let prob_keep = 1.0 - self.prob;\n        let random = input.random_like(Distribution::Bernoulli(prob_keep));\n        let x = input * random;\n\n        x * (1.0 / prob_keep)\n    }\n}\n\nimpl ModuleDisplay for Dropout {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"prob\", &self.prob).optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn::tensor::Shape;\n\n    #[cfg(feature = \"std\")]\n    use crate::{TestAutodiffBackend, TestBackend};\n\n    #[cfg(not(feature = \"std\"))]\n    use crate::TestBackend;\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn with_ad_backend_should_mark_input() {\n        let tensor =\n            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());\n        let dropout = DropoutConfig::new(0.5).init();\n\n        let output = dropout.forward(tensor.clone());\n\n        assert_ne!(tensor.to_data(), output.to_data());\n    }\n\n    #[test]\n    fn without_ad_backend_should_not_change_input() {\n        let tensor = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());\n        let dropout = DropoutConfig::new(0.5).init();\n\n        let output = dropout.forward(tensor.clone());\n\n        assert_eq!(tensor.to_data(), output.to_data());\n    }\n\n    #[test]\n    fn display() {\n        let config = DropoutConfig::new(0.5);\n        let layer = config.init();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"Dropout {prob: 0.5}\");\n    }\n\n    #[test]\n    #[should_panic = \"Dropout probability should be between 0 and 1,\"]\n    fn dropout_prob_invalid() {\n        let config = DropoutConfig::new(-10.);\n        let _layer = config.init();\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/embedding.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::Param;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Int;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\nuse burn::tensor::module::embedding;\n\n/// Configuration to create an [Embedding](Embedding) layer using the [init function](EmbeddingConfig::init).\n#[derive(Config, Debug)]\npub struct EmbeddingConfig {\n    /// The number of embedding vectors.\n    pub n_embedding: usize,\n    /// The size of each vector.\n    pub d_model: usize,\n    /// The type of function used to initialize neural network parameters\n    #[config(default = \"Initializer::Normal{mean:0.0, std:1.0}\")]\n    pub initializer: Initializer,\n}\n\n/// Lookup table to store a fix number of vectors.\n///\n/// Should be created with [EmbeddingConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Embedding<B: Backend> {\n    /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized\n    /// from a normal distribution `N(0, 1)`.\n    pub weight: Param<Tensor<B, 2>>,\n}\n\nimpl<B: Backend> ModuleDisplay for Embedding<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [n_embedding, d_model] = self.weight.shape().dims();\n        content\n            .add(\"n_embedding\", &n_embedding)\n            .add(\"d_model\", &d_model)\n            .optional()\n    }\n}\n\nimpl EmbeddingConfig {\n    /// Initialize a new [embedding](Embedding) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Embedding<B> {\n        let weight = self\n            .initializer\n            .init([self.n_embedding, self.d_model], device);\n\n        Embedding { weight }\n    }\n}\n\nimpl<B: Backend> Embedding<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See also [embedding](burn::tensor::module::embedding).\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, seq_length]`\n    /// - output: `[batch_size, seq_length, d_model]`\n    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {\n        embedding(self.weight.val(), input)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = EmbeddingConfig::new(5, 5).with_initializer(Initializer::Zeros);\n        let embed = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        embed.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<f32, _>(embed.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn display() {\n        let config = EmbeddingConfig::new(100, 10);\n        let embed = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{embed}\"),\n            \"Embedding {n_embedding: 100, d_model: 10, params: 1000}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/interpolate/interpolate1d.rs",
    "content": "use alloc::format;\n\nuse burn::tensor::module::interpolate;\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::InterpolateOptions;\n\nuse super::InterpolateMode;\n\n/// Configuration for the 1D interpolation module.\n///\n/// This struct defines the configuration options for the 1D interpolation operation.\n/// It allows specifying the output size, scale factor, and interpolation mode.\n#[derive(Config, Debug)]\npub struct Interpolate1dConfig {\n    /// Output size of the interpolated tensor.\n    /// If specified, this takes precedence over `scale_factor`.\n    #[config(default = \"None\")]\n    pub output_size: Option<usize>,\n\n    /// Scale factor for resizing the input tensor.\n    /// This is used when `output_size` is not specified.\n    #[config(default = \"None\")]\n    pub scale_factor: Option<f32>,\n\n    /// Interpolation mode to use for resizing.\n    /// Determines how the output values are calculated.\n    #[config(default = \"InterpolateMode::Nearest\")]\n    pub mode: InterpolateMode,\n\n    /// If `true`, the input and output tensors are aligned by their corner pixels.\n    /// If `false`, half-pixel coordinate mapping is used instead.\n    #[config(default = true)]\n    pub align_corners: bool,\n}\n\n/// Interpolate module for resizing 1D tensors with shape [N, C, L].\n///\n/// This struct represents a 1D interpolation module that can resize tensors\n/// using various interpolation methods. It provides flexibility in specifying\n/// either an output size or a scale factor for resizing, along with options\n/// for the interpolation mode.\n///\n/// The module can be used to upsample or downsample 1D tensors, preserving the\n/// number of channels and batch size while adjusting the length dimension.\n///\n/// The module can be created using the [Interpolate1dConfig] struct and the\n/// `init` method, which returns an instance of the [Interpolate1d] struct.\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Interpolate1d {\n    /// Output size of the interpolated tensor\n    pub output_size: Option<usize>,\n\n    /// Scale factor for resizing the input tensor\n    pub scale_factor: Option<f32>,\n\n    /// Interpolation mode used for resizing\n    pub mode: InterpolateMode,\n\n    /// Whether to align corner pixels\n    pub align_corners: bool,\n}\n\nimpl Interpolate1dConfig {\n    /// Initialize the interpolation module\n    pub fn init(self) -> Interpolate1d {\n        Interpolate1d {\n            output_size: self.output_size,\n            scale_factor: self.scale_factor,\n            mode: self.mode,\n            align_corners: self.align_corners,\n        }\n    }\n}\n\nimpl Interpolate1d {\n    /// Performs the forward pass of the 1D interpolation module\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - Input tensor with shape [N, C, L]\n    ///\n    /// # Returns\n    ///\n    /// Resized tensor with shape [N, C, L'], where L' is determined by\n    /// the output_size or scale_factor specified in the module configuration\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);\n    /// let interpolate = Interpolate1dConfig::new()\n    ///     .with_output_size(Some(128))\n    ///     .init();\n    /// let output = interpolate.forward(input);\n    /// assert_eq!(output.dims(), [1, 3, 128]);\n    /// ```\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);\n\n        // Use the interpolate operation to resize the temporal input tensor\n        // by adding a new dimension for the interpolation axis\n        let input = input.unsqueeze_dim(2);\n\n        let result = interpolate(\n            input,\n            [1, output_size],\n            InterpolateOptions::new(self.mode.clone().into())\n                .with_align_corners(self.align_corners),\n        );\n\n        result.squeeze_dims(&[2])\n    }\n}\n\n/// Calculate output size based on input dimensions, output size, and scale factor\n///\n/// # Arguments\n///\n/// * `input_dims` - Input dimensions of the tensor\n/// * `output_size` - Output size for the interpolated tensor\n/// * `scale_factor` - Scale factor for resizing the tensor\n///\n/// # Returns\n///\n/// Output size for the interpolated tensor\n///\n/// # Panics\n///\n/// Panics if neither output_size nor scale_factor is provided\n/// or if the scale factor is too large\nfn calculate_output_size(\n    input_dims: [usize; 3],\n    output_size: Option<usize>,\n    scale_factor: Option<f32>,\n) -> usize {\n    match (output_size, scale_factor) {\n        (Some(output_size), None) => {\n            // Use provided\n            output_size\n        }\n        (None, Some(scale_factor)) => {\n            // Calculate output size based on scale factor\n            let [_, _, l] = input_dims;\n\n            let new_dim = (l as f64) * (scale_factor as f64);\n\n            if new_dim > usize::MAX as f64 {\n                panic!(\"Scale factor is too large\");\n            }\n\n            new_dim as usize\n        }\n        _ => panic!(\"Either output_size or scale_factor must be provided\"),\n    }\n}\n\nimpl ModuleDisplay for Interpolate1d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add_debug_attribute(\"mode\", &self.mode)\n            .add(\"output_size\", &format!(\"{:?}\", self.output_size))\n            .add(\"scale_factor\", &self.scale_factor)\n            .optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n\n    use burn::tensor::Distribution;\n\n    use super::*;\n    use crate::TestBackend;\n    #[test]\n    fn test_calculate_output_size() {\n        let input_dims = [1, 1, 4];\n\n        let output_size = calculate_output_size(input_dims, Some(2), None);\n        assert_eq!(output_size, 2);\n\n        let output_size = calculate_output_size(input_dims, None, Some(2.0));\n        assert_eq!(output_size, 8);\n\n        let output_size = calculate_output_size(input_dims, None, Some(0.5));\n        assert_eq!(output_size, 2);\n\n        let output_size = calculate_output_size(input_dims, None, Some(1.5));\n        assert_eq!(output_size, 6);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Either output_size or scale_factor must be provided\")]\n    fn test_panic() {\n        let input_dims = [1, 1, 4];\n        calculate_output_size(input_dims, None, None);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Scale factor is too large\")]\n    fn test_large_scale_factor() {\n        let input_dims = [1, 1, usize::MAX - 1];\n        calculate_output_size(input_dims, None, Some(2.0));\n    }\n\n    #[test]\n    fn test_module() {\n        let input = Tensor::<TestBackend, 3>::random(\n            [2, 3, 4],\n            Distribution::Uniform(0.0, 1.0),\n            &Default::default(),\n        );\n\n        // Test with output_size\n        let config = Interpolate1dConfig::new().with_output_size(Some(8));\n        let interpolate = config.init();\n        let output = interpolate.forward(input.clone());\n        assert_eq!(output.dims(), [2, 3, 8]);\n\n        // Test with scale_factor\n        let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5));\n        let interpolate = config.init();\n        let output = interpolate.forward(input.clone());\n        assert_eq!(output.dims(), [2, 3, 2]);\n\n        // Test with different interpolation mode\n        let config = Interpolate1dConfig::new()\n            .with_output_size(Some(6))\n            .with_mode(InterpolateMode::Linear);\n        let interpolate = config.init();\n        let output = interpolate.forward(input);\n        assert_eq!(output.dims(), [2, 3, 6]);\n    }\n\n    #[test]\n    fn display() {\n        let config = Interpolate1dConfig::new().with_output_size(Some(20));\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"Interpolate1d {mode: Nearest, output_size: Some(20), \\\n            scale_factor: None}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/interpolate/interpolate2d.rs",
    "content": "use alloc::format;\n\nuse burn::tensor::module::interpolate;\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::InterpolateOptions;\n\nuse super::InterpolateMode;\n\n/// Configuration for the 2D interpolation module.\n///\n/// This struct defines the configuration options for the 2D interpolation operation.\n/// It allows specifying the output size, scale factor, and interpolation mode.\n#[derive(Config, Debug)]\npub struct Interpolate2dConfig {\n    /// Output size of the interpolated tensor.\n    /// If specified, this takes precedence over `scale_factor`.\n    #[config(default = \"None\")]\n    pub output_size: Option<[usize; 2]>,\n\n    /// Scale factor for resizing the input tensor.\n    /// This is used when `output_size` is not specified.\n    #[config(default = \"None\")]\n    pub scale_factor: Option<[f32; 2]>,\n\n    /// Interpolation mode to use for resizing.\n    /// Determines how the output values are calculated.\n    #[config(default = \"InterpolateMode::Nearest\")]\n    pub mode: InterpolateMode,\n\n    /// If `true`, the input and output tensors are aligned by their corner pixels.\n    /// If `false`, half-pixel coordinate mapping is used instead.\n    #[config(default = true)]\n    pub align_corners: bool,\n}\n\n/// Interpolate module for resizing tensors with shape [N, C, H, W].\n///\n/// This struct represents an interpolation module that can resize tensors\n/// using various interpolation methods. It provides flexibility in specifying\n/// either an output size or a scale factor for resizing, along with options\n/// for the interpolation mode.\n///\n/// The module can be used to upsample or downsample tensors, preserving the\n/// number of channels and batch size while adjusting the height and width\n/// dimensions.\n///\n/// The module can be created using the [Interpolate2dConfig] struct and the\n/// `init` method, which returns an instance of the [Interpolate2d] struct.\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Interpolate2d {\n    /// Output size of the interpolated tensor\n    pub output_size: Option<[usize; 2]>,\n\n    /// Scale factor for resizing the input tensor\n    pub scale_factor: Option<[f32; 2]>,\n\n    /// Interpolation mode used for resizing\n    pub mode: InterpolateMode,\n\n    /// Whether to align corner pixels\n    pub align_corners: bool,\n}\n\nimpl Interpolate2dConfig {\n    /// Initialize the interpolation module\n    pub fn init(self) -> Interpolate2d {\n        Interpolate2d {\n            output_size: self.output_size,\n            scale_factor: self.scale_factor,\n            mode: self.mode,\n            align_corners: self.align_corners,\n        }\n    }\n}\nimpl Interpolate2d {\n    /// Performs the forward pass of the interpolation module\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - Input tensor with shape [N, C, H, W]\n    ///\n    /// # Returns\n    ///\n    /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by\n    /// the output_size or scale_factor specified in the module configuration\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);\n    /// let interpolate = Interpolate2dConfig::new()\n    ///     .with_output_size(Some([128, 128]))\n    ///     .init();\n    /// let output = interpolate.forward(input);\n    /// assert_eq!(output.dims(), [1, 3, 128, 128]);\n    /// ```\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);\n        interpolate(\n            input,\n            output_size,\n            InterpolateOptions::new(self.mode.clone().into())\n                .with_align_corners(self.align_corners),\n        )\n    }\n}\n\n/// Calculates the output size for tensor interpolation.\n///\n/// # Arguments\n///\n/// * `input_dims` - The dimensions of the input tensor [N, C, H, W].\n/// * `output_size` - Optional desired output size [H', W'].\n/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w].\n///\n/// # Returns\n///\n/// A tuple [H', W'] representing the calculated output size.\n///\n/// # Panics\n///\n/// Panics if neither `output_size` nor `scale_factor` is provided,\n/// or if the scale factor results in dimensions exceeding usize::MAX.\nfn calculate_output_size(\n    input_dims: [usize; 4],\n    output_size: Option<[usize; 2]>,\n    scale_factor: Option<[f32; 2]>,\n) -> [usize; 2] {\n    match (output_size, scale_factor) {\n        (Some(output_size), None) => {\n            // Use provided\n            output_size\n        }\n        (None, Some(scale_factor)) => {\n            // Calculate output size based on scale factor\n            let [_, _, h, w] = input_dims;\n\n            let new_dim_h = (h as f64) * (scale_factor[0] as f64);\n\n            if new_dim_h > usize::MAX as f64 {\n                panic!(\"Scale factor for height is too large\");\n            }\n\n            let new_dim_w = (w as f64) * (scale_factor[1] as f64);\n\n            if new_dim_w > usize::MAX as f64 {\n                panic!(\"Scale factor for width is too large\");\n            }\n\n            [new_dim_h as usize, new_dim_w as usize]\n        }\n        _ => panic!(\"Either output_size or scale_factor must be provided\"),\n    }\n}\n\nimpl ModuleDisplay for Interpolate2d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add_debug_attribute(\"mode\", &self.mode)\n            .add(\"output_size\", &format!(\"{:?}\", self.output_size))\n            .add(\"scale_factor\", &self.scale_factor)\n            .optional()\n    }\n}\n#[cfg(test)]\nmod tests {\n    use burn::tensor::Distribution;\n\n    use crate::TestBackend;\n\n    use super::*;\n\n    #[test]\n    fn test_calculate_output_size() {\n        let input_dims = [1, 1, 4, 4];\n\n        let output_size = calculate_output_size(input_dims, Some([2, 2]), None);\n        assert_eq!(output_size, [2, 2]);\n\n        let output_size = calculate_output_size(input_dims, None, Some([2.0, 2.0]));\n        assert_eq!(output_size, [8, 8]);\n\n        let output_size = calculate_output_size([1, 1, 4, 4], None, Some([0.5, 0.5]));\n        assert_eq!(output_size, [2, 2]);\n\n        let output_size = calculate_output_size([1, 1, 4, 4], None, Some([2.0, 1.5]));\n        assert_eq!(output_size, [8, 6]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Either output_size or scale_factor must be provided\")]\n    fn test_missing_params() {\n        calculate_output_size([1, 1, 4, 4], None, None);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Scale factor for height is too large\")]\n    fn test_infinite_height() {\n        calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0]));\n    }\n\n    #[test]\n    #[should_panic(expected = \"Scale factor for width is too large\")]\n    fn test_infinite_width() {\n        calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0]));\n    }\n\n    #[test]\n    fn test_module() {\n        let input = Tensor::<TestBackend, 4>::random(\n            [2, 3, 4, 4],\n            Distribution::Uniform(0.0, 1.0),\n            &Default::default(),\n        );\n\n        // Test with output_size\n        let config = Interpolate2dConfig::new().with_output_size(Some([8, 8]));\n        let interpolate = config.init();\n        let output = interpolate.forward(input.clone());\n        assert_eq!(output.dims(), [2, 3, 8, 8]);\n\n        // Test with scale_factor\n        let config = Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5]));\n        let interpolate = config.init();\n        let output = interpolate.forward(input.clone());\n        assert_eq!(output.dims(), [2, 3, 2, 2]);\n\n        // Test with different interpolation mode\n        let config = Interpolate2dConfig::new()\n            .with_output_size(Some([6, 6]))\n            .with_mode(InterpolateMode::Linear);\n        let interpolate = config.init();\n        let output = interpolate.forward(input);\n        assert_eq!(output.dims(), [2, 3, 6, 6]);\n    }\n\n    #[test]\n    fn display() {\n        let config = Interpolate2dConfig::new().with_output_size(Some([20, 20]));\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \\\n            scale_factor: None}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/interpolate/mod.rs",
    "content": "mod interpolate1d;\nmod interpolate2d;\n\npub use interpolate1d::*;\npub use interpolate2d::*;\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::tensor::ops::InterpolateMode as OpsInterpolateMode;\n\n/// Algorithm used for downsampling and upsampling\n///\n/// This enum defines different interpolation modes for resampling data.\n#[derive(Config, Debug)]\npub enum InterpolateMode {\n    /// Nearest-neighbor interpolation\n    ///\n    /// This mode selects the value of the nearest sample point for each output pixel.\n    /// It is applicable for both temporal and spatial data.\n    Nearest,\n\n    /// Linear interpolation\n    ///\n    /// This mode calculates the output value using linear\n    /// interpolation between nearby sample points.\n    ///\n    /// It is applicable for both temporal and spatial data.\n    Linear,\n\n    /// Cubic interpolation\n    ///\n    /// This mode uses cubic interpolation to calculate the output value\n    /// based on surrounding sample points.\n    ///\n    /// It is applicable for both temporal and spatial data and generally\n    /// provides smoother results than linear interpolation.\n    Cubic,\n\n    /// Lanczos3 interpolation\n    ///\n    /// This mode uses a 6-tap sinc-based Lanczos filter (a=3) to calculate\n    /// the output value. It generally provides high-quality results,\n    /// especially for downsampling.\n    Lanczos,\n}\n\nimpl From<InterpolateMode> for OpsInterpolateMode {\n    fn from(mode: InterpolateMode) -> Self {\n        match mode {\n            InterpolateMode::Nearest => OpsInterpolateMode::Nearest,\n            InterpolateMode::Linear => OpsInterpolateMode::Bilinear,\n            InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,\n            InterpolateMode::Lanczos => OpsInterpolateMode::Lanczos3,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/linear.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Param;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::module::linear;\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Configuration to create a [`Linear`] layer using the [init function](LinearConfig::init).\n#[derive(Config, Debug)]\npub struct LinearConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the output features.\n    pub d_output: usize,\n    /// If a bias should be applied during the linear transformation.\n    #[config(default = true)]\n    pub bias: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n    /// The layout in which the linear parameters are stored.\n    #[config(default = \"LinearLayout::Row\")]\n    pub layout: LinearLayout,\n}\n\n#[derive(Config, Debug, Copy)]\n/// The layout in which the linear parameters are stored.\n///\n/// This can have performance impacts.\npub enum LinearLayout {\n    /// Parameters are stored in Row major.\n    Row,\n    /// Parameters are stored in Col major.\n    Col,\n}\n\n/// Applies a linear transformation to the input tensor.\n///\n/// Should be created with [LinearConfig]\n///\n/// `O = IW + b`\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Linear<B: Backend> {\n    /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:\n    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`\n    pub weight: Param<Tensor<B, 2>>,\n    /// Vector of size `d_output` initialized from a uniform distribution:\n    ///     `U(-k, k)`, where `k = sqrt(1 / d_input)`\n    pub bias: Option<Param<Tensor<B, 1>>>,\n}\n\nimpl LinearConfig {\n    /// Initialize a new [`Linear`] module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Linear<B> {\n        let weight = match self.layout {\n            LinearLayout::Row => {\n                let shape = [self.d_input, self.d_output];\n                self.initializer\n                    .init_with(shape, Some(self.d_input), Some(self.d_output), device)\n            }\n            LinearLayout::Col => {\n                let shape = [self.d_output, self.d_input];\n\n                self.initializer\n                    .init_with(shape, Some(self.d_output), Some(self.d_input), device)\n                    // The param is already transposed when init. We re-transpose to have\n                    // [d_output, d_input] while saving.\n                    .save_mapper(move |tensor| {\n                        B::sync(&tensor.device()).unwrap();\n                        let tensor = tensor.transpose();\n                        B::sync(&tensor.device()).unwrap();\n                        tensor\n                    })\n                    // When loading from record we have to transpose.\n                    .load_mapper(move |tensor| {\n                        B::sync(&tensor.device()).unwrap();\n                        let tensor = tensor.transpose();\n                        B::sync(&tensor.device()).unwrap();\n\n                        tensor\n                    })\n                    // When loading from initialization, we have to transpose.\n                    .init_mapper(|tensor| {\n                        B::sync(&tensor.device()).unwrap();\n                        let tensor = tensor.transpose();\n                        B::sync(&tensor.device()).unwrap();\n                        tensor\n                    })\n            }\n        };\n        let bias = if self.bias {\n            Some(self.initializer.init_with(\n                [self.d_output],\n                Some(self.d_input),\n                Some(self.d_output),\n                device,\n            ))\n        } else {\n            None\n        };\n\n        Linear { weight, bias }\n    }\n}\n\nimpl<B: Backend> Linear<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Arguments\n    ///\n    /// - `input` - The input tensor of shape `[..., d_input]`.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., d_input]`\n    /// - output: `[..., d_output]`\n    ///\n    /// # Returns\n    ///\n    /// The transformed tensor of shape `[..., d_output]`.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        linear(\n            input,\n            self.weight.val(),\n            self.bias.as_ref().map(|b| b.val()),\n        )\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for Linear<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, d_output] = self.weight.shape().dims();\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_output\", &d_output)\n            .add(\"bias\", &self.bias.is_some())\n            .optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::module::ParamId;\n    use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};\n    use burn::tensor::ElementConversion;\n    use burn::tensor::{Shape, TensorData};\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn initializer_default() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = LinearConfig::new(5, 5);\n        let k = (1.0 / config.d_input as f64).sqrt().elem::<FT>();\n        let linear = config.init::<TestBackend>(&device);\n\n        assert_eq!(\n            config.initializer,\n            Initializer::KaimingUniform {\n                gain: 1.0 / 3.0f64.sqrt(),\n                fan_out_only: false\n            }\n        );\n        linear.weight.to_data().assert_within_range(-k..k);\n    }\n\n    #[test]\n    fn initializer_zeros() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = LinearConfig::new(5, 5).with_initializer(Initializer::Zeros);\n        let linear = config.init::<TestBackend>(&device);\n\n        assert_eq!(config.initializer, Initializer::Zeros);\n        linear.weight.to_data().assert_approx_eq::<FT>(\n            &TensorData::zeros::<f32, _>(linear.weight.shape()),\n            Tolerance::default(),\n        );\n    }\n\n    #[test]\n    fn test_linear_forward_no_bias() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let value = 2.;\n        let config = LinearConfig::new(2, 3)\n            .with_initializer(Initializer::Constant { value })\n            .with_bias(false);\n        let linear = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);\n        let result = linear.forward(input);\n        let expected_result = Tensor::<TestBackend, 2>::from_data([[4., 4., 4.]], &device);\n\n        assert_eq!(result.into_data(), expected_result.into_data());\n    }\n\n    #[test]\n    fn test_linear_forward_with_bias() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let device = Default::default();\n\n        let value = 2.;\n        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });\n        let linear = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);\n        let result = linear.forward(input);\n        let expected_result = Tensor::<TestBackend, 2>::from_data([[6., 6., 6.]], &device);\n\n        assert_eq!(result.into_data(), expected_result.into_data());\n    }\n\n    #[test]\n    fn test_linear_1d() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let device = Default::default();\n\n        let value = 2.;\n        let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value });\n        let linear = config.init::<TestBackend>(&device);\n\n        let input_1d = Tensor::<TestBackend, 1>::ones(Shape::new([2]), &device);\n        let input_2d = Tensor::<TestBackend, 2>::ones(Shape::new([1, 2]), &device);\n\n        let result_1d = linear.forward(input_1d).unsqueeze::<2>();\n        let result_2d = linear.forward(input_2d);\n\n        assert_eq!(result_1d.into_data(), result_2d.into_data());\n    }\n\n    #[test]\n    fn display() {\n        let config = LinearConfig::new(3, 5);\n        let linear = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{linear}\"),\n            \"Linear {d_input: 3, d_output: 5, bias: true, params: 20}\"\n        );\n    }\n\n    #[test]\n    fn layout() {\n        let device = Default::default();\n        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);\n        let linear = config.init::<TestBackend>(&device);\n\n        assert_eq!(linear.weight.dims(), [6, 12], \"Shape is as configured\");\n\n        let recorder = BinBytesRecorder::<FullPrecisionSettings>::new();\n\n        // We go through serialization to trigger the mappers..\n        let record = linear.into_record();\n        let data = recorder.record(record, ()).unwrap();\n        let record = recorder.load(data.clone(), &device).unwrap();\n\n        let config = LinearConfig::new(12, 6).with_layout(LinearLayout::Row);\n        let linear_row = config.init::<TestBackend>(&device).load_record(record);\n\n        assert_eq!(\n            linear_row.weight.dims(),\n            [12, 6],\n            \"Shape should be transposed\"\n        );\n\n        let record = recorder.load(data.clone(), &device).unwrap();\n        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);\n        let linear_col = config.init::<TestBackend>(&device).load_record(record);\n\n        assert_eq!(\n            linear_col.weight.dims(),\n            [6, 12],\n            \"Shape should be as configured\"\n        );\n\n        // We go through serialization to trigger the mappers.\n        //\n        // The test will fail if the mapper is not correctly given to the module after loading a\n        // record.\n        let record = linear_col.into_record();\n        let data = recorder.record(record, ()).unwrap();\n\n        let record = recorder.load(data, &device).unwrap();\n        let config = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);\n        let linear_col = config.init::<TestBackend>(&device).load_record(record);\n\n        assert_eq!(\n            linear_col.weight.dims(),\n            [6, 12],\n            \"Shape should be as configured\"\n        );\n    }\n\n    #[test]\n    fn col_row_same_result() {\n        let device = Default::default();\n        let config_col = LinearConfig::new(6, 12).with_layout(LinearLayout::Col);\n        let linear_col = config_col.init::<TestBackend>(&device);\n        let signal = Tensor::<_, 2>::random([8, 6], burn::tensor::Distribution::Default, &device);\n        let value = linear_col.forward(signal.clone());\n\n        let data_1 = value.into_data();\n\n        let weights = linear_col.weight.val().into_data();\n        let weights = Tensor::from_data(weights, &device);\n\n        let linear = Linear {\n            weight: Param::initialized(ParamId::new(), weights),\n            bias: linear_col\n                .bias\n                .map(|b| Param::initialized(ParamId::new(), b.val())),\n        };\n\n        let value = linear.forward(signal);\n        let data_2 = value.into_data();\n\n        data_1.assert_approx_eq::<f32>(&data_2, Default::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/mod.rs",
    "content": "/// Attention module\npub mod attention;\n\n/// Cache module\npub mod cache;\n\n/// Convolution module\npub mod conv;\n\n/// Pooling module\npub mod pool;\n\n/// Transformer module\npub mod transformer;\n\n/// Interpolate module\npub mod interpolate;\n\nmod dropout;\nmod embedding;\nmod linear;\nmod noise;\nmod pos_encoding;\nmod rnn;\nmod rope_encoding;\nmod unfold;\n\npub mod norm;\npub use norm::{batch::*, group::*, instance::*, layer::*, rms::*};\n\npub use dropout::*;\npub use embedding::*;\npub use linear::*;\npub use noise::*;\npub use pos_encoding::*;\npub use rnn::*;\npub use rope_encoding::*;\npub use unfold::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/noise.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{Distribution, Tensor};\n\n/// Configuration to create a [GaussianNoise](GaussianNoise) layer using the [init function](GaussianNoiseConfig::init).\n#[derive(Config, Debug)]\npub struct GaussianNoiseConfig {\n    /// Standard deviation of the normal noise distribution.\n    pub std: f64,\n}\n\n/// Add pseudorandom Gaussian noise to an arbitrarily shaped tensor.\n///\n/// This is an effective regularization technique that also contributes to data augmentation.\n/// Please keep in mind that the value of [std](GaussianNoise::std) should be chosen with care in order to avoid\n/// distortion.\n///\n/// Should be created with [GaussianNoiseConfig].\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct GaussianNoise {\n    /// Standard deviation of the normal noise distribution.\n    pub std: f64,\n}\n\nimpl GaussianNoiseConfig {\n    /// Initialize a new [Gaussian noise](GaussianNoise) module.\n    pub fn init(&self) -> GaussianNoise {\n        if self.std.is_sign_negative() {\n            panic!(\n                \"Standard deviation is required to be non-negative, but got {}\",\n                self.std\n            );\n        }\n        GaussianNoise { std: self.std }\n    }\n}\n\nimpl GaussianNoise {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [GaussianNoise](GaussianNoise) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any]`\n    /// - output: `[..., any]`\n    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        if B::ad_enabled(&input.device()) && self.std != 0.0 {\n            let noise = Tensor::random(\n                input.shape(),\n                Distribution::Normal(0.0, self.std),\n                &input.device(),\n            );\n            input + noise\n        } else {\n            input\n        }\n    }\n}\n\nimpl ModuleDisplay for GaussianNoise {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"std\", &self.std).optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn::tensor::Shape;\n\n    #[cfg(feature = \"std\")]\n    use crate::{TestAutodiffBackend, TestBackend};\n\n    #[cfg(not(feature = \"std\"))]\n    use crate::TestBackend;\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn with_ad_backend_should_mark_input() {\n        let tensor =\n            Tensor::<TestAutodiffBackend, 2>::ones(Shape::new([100, 100]), &Default::default());\n        let noise = GaussianNoiseConfig::new(0.5).init();\n\n        let output = noise.forward(tensor.clone());\n\n        assert_ne!(tensor.to_data(), output.to_data());\n    }\n\n    #[test]\n    fn without_ad_backend_should_not_change_input() {\n        let tensor = Tensor::<TestBackend, 2>::ones(Shape::new([100, 100]), &Default::default());\n        let noise = GaussianNoiseConfig::new(0.5).init();\n\n        let output = noise.forward(tensor.clone());\n\n        assert_eq!(tensor.to_data(), output.to_data());\n    }\n\n    #[test]\n    #[should_panic(expected = \"Standard deviation is required to be non-negative\")]\n    fn negative_std_should_panic() {\n        GaussianNoiseConfig { std: -0.5 }.init();\n    }\n\n    #[test]\n    fn display() {\n        let config = GaussianNoiseConfig::new(0.5);\n        let layer = config.init();\n\n        assert_eq!(alloc::format!(\"{layer}\"), \"GaussianNoise {std: 0.5}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/batch.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::Initializer;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::{Tensor, backend::Backend};\nuse burn::{\n    config::Config,\n    module::{Module, Param, RunningState},\n};\n\n/// [`BatchNorm`] Configuration.\n///\n/// Used to create a [`BatchNorm`] layer using the [`BatchNormConfig::init`].\n#[derive(Config, Debug)]\npub struct BatchNormConfig {\n    /// The number of features.\n    pub num_features: usize,\n    /// A value required for numerical stability. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub epsilon: f64,\n    /// Momentum used to update the metrics. Default: 0.1\n    #[config(default = 0.1)]\n    pub momentum: f64,\n}\n\n/// Applies Batch Normalization over a tensor.\n///\n/// Based upon the paper [Batch Normalization](https://arxiv.org/abs/1502.03167).\n///\n/// Assumes input tensor is of shape ``[batch_size, channels, ...]``.\n///\n/// `Y = norm(X) * γ + β`\n///\n/// Where:\n/// - `X` is the input tensor\n/// - `Y` is the output tensor\n/// - `norm` is the normalization function\n/// - `γ` is the learnable weight\n/// - `β` is the learnable bias\n///\n/// Should be created using [`BatchNormConfig`].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct BatchNorm<B: Backend> {\n    /// The learnable weight gamma.\n    pub gamma: Param<Tensor<B, 1>>,\n    /// The learnable weight beta.\n    pub beta: Param<Tensor<B, 1>>,\n    /// The running mean.\n    pub running_mean: RunningState<Tensor<B, 1>>,\n    /// The running variance.\n    pub running_var: RunningState<Tensor<B, 1>>,\n    /// Momentum used to update the metrics.\n    pub momentum: f64,\n    /// A value required for numerical stability.\n    pub epsilon: f64,\n}\n\nimpl BatchNormConfig {\n    /// Initializes a new [batch norm](BatchNorm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> BatchNorm<B> {\n        let gamma = Initializer::Ones.init([self.num_features], device);\n        let beta = Initializer::Zeros.init([self.num_features], device);\n\n        let running_mean = Tensor::zeros([self.num_features], device);\n        let running_var = Tensor::ones([self.num_features], device);\n\n        BatchNorm {\n            gamma,\n            beta,\n            running_mean: RunningState::new(running_mean),\n            running_var: RunningState::new(running_var),\n            momentum: self.momentum,\n            epsilon: self.epsilon,\n        }\n    }\n}\n\nimpl<B: Backend> BatchNorm<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [`BatchNorm`] for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - `input`: ``[batch_size, channels, ...]``\n    /// - `output`: ``[batch_size, channels, ...]``\n    ///\n    /// # Panics\n    ///\n    /// This function will panic if the input tensor has rank < 2.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        // Should be move to a compilation error when const generic support that kind of\n        // validation. https://github.com/rust-lang/rust/issues/76560\n        if D < 2 {\n            panic!(\n                \"BatchNorm can only be applied on tensors of rank >= 2 with the following shape \\\n                 [batch_size, channels, ...], received {}D tensor\",\n                D\n            );\n        }\n\n        match B::ad_enabled(&input.device()) {\n            true => self.forward_train(input),\n            false => self.forward_inference(input),\n        }\n    }\n\n    fn forward_inference<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let device = input.device();\n        let channels = input.dims()[1];\n        let mean = self.running_mean.value().to_device(&device);\n        let var = self.running_var.value().to_device(&device);\n\n        let mut shape = [1; D];\n        shape[1] = channels;\n\n        self.forward_shared(input, mean.reshape(shape), var.reshape(shape))\n    }\n\n    fn forward_train<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let device = input.device();\n        let dims = input.dims();\n        let batch_size = dims[0];\n        let channels = dims[1];\n\n        let mut shape_unsqueeze = [1; D];\n        let mut flatten_size = batch_size;\n        shape_unsqueeze[1] = channels;\n\n        for dim in dims.iter().take(D).skip(2) {\n            flatten_size *= dim;\n        }\n\n        let mean = input\n            .clone()\n            .swap_dims(0, 1)\n            .reshape([channels, flatten_size])\n            .mean_dim(1)\n            .reshape(shape_unsqueeze);\n\n        let var = input\n            .clone()\n            .sub(mean.clone())\n            .square()\n            .swap_dims(0, 1)\n            .reshape([channels, flatten_size])\n            .mean_dim(1)\n            .reshape(shape_unsqueeze);\n\n        let running_mean = self.running_mean.value_sync().to_device(&device);\n        let running_var = self.running_var.value_sync().to_device(&device);\n\n        let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(\n            mean.clone()\n                .detach()\n                .mul_scalar(self.momentum)\n                .reshape([channels]),\n        );\n        let running_var = running_var.mul_scalar(1.0 - self.momentum).add(\n            var.clone()\n                .detach()\n                .mul_scalar(self.momentum)\n                .reshape([channels]),\n        );\n\n        self.running_mean.update(running_mean.detach());\n        self.running_var.update(running_var.detach());\n\n        self.forward_shared(input, mean, var)\n    }\n\n    fn forward_shared<const D: usize>(\n        &self,\n        x: Tensor<B, D>,\n        mean: Tensor<B, D>,\n        var: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        let channels = x.dims()[1];\n        let mut shape = [1; D];\n        shape[1] = channels;\n\n        let std = var.add_scalar(self.epsilon).sqrt();\n\n        let x = x.sub(mean);\n        let x = x.div(std);\n\n        let x = x.mul(self.gamma.val().reshape(shape));\n\n        x.add(self.beta.val().reshape(shape))\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for BatchNorm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [num_features] = self.beta.shape().dims();\n\n        content\n            .add(\"num_features\", &num_features)\n            .add(\"momentum\", &self.momentum)\n            .add(\"epsilon\", &self.epsilon)\n            .optional()\n    }\n}\n\n#[cfg(feature = \"std\")]\n#[cfg(test)]\nmod tests_1d {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use burn::module::AutodiffModule;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestAutodiffBackend>;\n\n    #[test]\n    fn batch_norm_forward_train() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        let output = module.forward(input_tensor(&device));\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_train(), Tolerance::rel_abs(0.1, 0.001));\n    }\n\n    #[test]\n    fn batch_norm_forward_inference() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        module.forward(input_tensor(&device));\n        let module = module.valid();\n        let output = module.forward(input_tensor(&device));\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());\n    }\n\n    fn expected_valid() -> TensorData {\n        TensorData::from([\n            [[0.9409, 0.6976], [0.5892, 0.8774], [0.9106, 0.6844]],\n            [[0.6012, 0.0782], [-0.0394, 0.9270], [0.6181, 0.5492]],\n        ])\n    }\n\n    fn expected_train() -> TensorData {\n        TensorData::from([\n            [\n                [1.1483e+00, 3.7521e-01],\n                [1.6272e-03, 7.5067e-01],\n                [1.6204e+00, -4.5168e-02],\n            ],\n            [\n                [6.8856e-02, -1.5923e+00],\n                [-1.6318e+00, 8.7949e-01],\n                [-5.3368e-01, -1.0416e+00],\n            ],\n        ])\n    }\n\n    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 3> {\n        Tensor::<B, 3>::from_floats(\n            [\n                [[0.9601, 0.7277], [0.6272, 0.9034], [0.9378, 0.7230]],\n                [[0.6356, 0.1362], [0.0249, 0.9509], [0.6600, 0.5945]],\n            ],\n            device,\n        )\n    }\n\n    #[test]\n    fn batch_norm_forward_train_inference() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        module.forward(input_tensor(&device));\n        let module = module.valid();\n        let output = module.forward(input_tensor(&device));\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_valid(), Tolerance::default());\n\n        let module = module.train::<TestAutodiffBackend>();\n        let output = module.forward(input_tensor(&device));\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_train(), Tolerance::default());\n    }\n}\n\n#[cfg(feature = \"std\")]\n#[cfg(test)]\nmod tests_2d {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use burn::module::AutodiffModule;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestAutodiffBackend>;\n\n    #[test]\n    fn batch_norm_forward_train() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        let output = module.forward(input_tensor(&device));\n\n        let expected = TensorData::from([\n            [\n                [[1.5136, 0.7506], [-1.2216, 0.1477]],\n                [[0.3135, 1.2252], [-0.4150, 0.6130]],\n                [[1.4186, 0.3372], [-1.5183, 1.5262]],\n            ],\n            [\n                [[0.4483, -1.1914], [-1.2010, 0.7537]],\n                [[-1.6752, 1.3822], [-0.5058, -0.9381]],\n                [[0.0200, -0.3097], [-0.5715, -0.9026]],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(0.1, 0.001));\n    }\n\n    #[test]\n    fn batch_norm_forward_inference() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        module.forward(input_tensor(&device));\n        let module = module.valid();\n        let output = module.forward(input_tensor(&device));\n\n        let expected = TensorData::from([\n            [\n                [[0.9538, 0.7103], [0.0808, 0.5179]],\n                [[0.6015, 0.8910], [0.3703, 0.6966]],\n                [[0.9171, 0.6912], [0.3037, 0.9395]],\n            ],\n            [\n                [[0.6138, 0.0904], [0.0874, 0.7113]],\n                [[-0.0297, 0.9408], [0.3415, 0.2042]],\n                [[0.6250, 0.5561], [0.5013, 0.4323]],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn batch_norm_running_mean() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        let _output = module.forward(input_tensor(&device));\n\n        let running_mean = module.running_mean.value_sync();\n\n        let expected = TensorData::from([0.0499, 0.0532, 0.0656]);\n        running_mean\n            .reshape([3])\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn batch_norm_running_var() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        let _output = module.forward(input_tensor(&device));\n\n        let running_var = module.running_var.value_sync();\n\n        let expected = TensorData::from([0.9106, 0.9105, 0.9045]);\n        running_var\n            .reshape([3])\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn batch_norm_running_mean_inner_module() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n\n        let _output = module.forward(input_tensor(&device));\n\n        let module_valid = module.valid();\n        let running_mean = module_valid.running_mean.value();\n        let running_mean_after = module.running_mean.value();\n\n        running_mean_after\n            .into_data()\n            .assert_approx_eq::<FT>(&running_mean.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn batch_norm_grads() {\n        let device = Default::default();\n        let module = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&device);\n        let input = input_tensor(&device).require_grad();\n\n        let output = module.forward(input.clone());\n\n        let grads = output.backward();\n\n        let tolerance = Tolerance::rel_abs(0.1, 0.001);\n        let expected = TensorData::from([0.0000e+00, -5.9035e-07, -6.0011e-07]);\n        module\n            .gamma\n            .grad(&grads)\n            .unwrap()\n            .reshape([3])\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        let expected = TensorData::from([8., 8., 8.]);\n        module\n            .beta\n            .grad(&grads)\n            .unwrap()\n            .reshape([3])\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        let expected = TensorData::from([\n            [\n                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],\n                [[7.6400e-08, 2.9848e-07], [-1.0110e-07, 1.4933e-07]],\n                [[5.3570e-07, 1.2732e-07], [-5.7336e-07, 5.7632e-07]],\n            ],\n            [\n                [[0.0000e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00]],\n                [[-4.0807e-07, 3.3673e-07], [-1.2323e-07, -2.2854e-07]],\n                [[7.5642e-09, -1.1695e-07], [-2.1582e-07, -3.4078e-07]],\n            ],\n        ]);\n        input\n            .grad(&grads)\n            .unwrap()\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n    }\n\n    fn input_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 4> {\n        Tensor::<B, 4>::from_floats(\n            [\n                [\n                    [[0.9601, 0.7277], [0.1270, 0.5441]],\n                    [[0.6272, 0.9034], [0.4066, 0.7179]],\n                    [[0.9378, 0.7230], [0.3544, 0.9591]],\n                ],\n                [\n                    [[0.6356, 0.1362], [0.1333, 0.7287]],\n                    [[0.0249, 0.9509], [0.3791, 0.2481]],\n                    [[0.6600, 0.5945], [0.5424, 0.4767]],\n                ],\n            ],\n            device,\n        )\n    }\n\n    #[test]\n    fn display() {\n        let batch_norm = BatchNormConfig::new(3).init::<TestAutodiffBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{batch_norm}\"),\n            \"BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/group.rs",
    "content": "use burn::module::Initializer;\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::Param;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Configuration to create a [GroupNorm](GroupNorm) layer using the [init function](GroupNormConfig::init).\n#[derive(Debug, Config)]\npub struct GroupNormConfig {\n    /// The number of groups to separate the channels into\n    pub num_groups: usize,\n    /// The number of channels expected in the input\n    pub num_channels: usize,\n    /// A value required for numerical stability. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub epsilon: f64,\n    /// A boolean value that when set to `true`, this module has learnable\n    /// per-channel affine parameters initialized to ones (for weights)\n    /// and zeros (for biases). Default: `true`\n    #[config(default = true)]\n    pub affine: bool,\n}\n\n/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).\n///\n/// `Y = groupnorm(X) * γ + β`\n///\n/// Where:\n/// - `X` is the input tensor\n/// - `Y` is the output tensor\n/// - `γ` is the learnable weight\n/// - `β` is the learnable bias\n///\n/// Should be created using [GroupNormConfig](GroupNormConfig).\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct GroupNorm<B: Backend> {\n    /// The learnable weight\n    pub gamma: Option<Param<Tensor<B, 1>>>,\n    /// The learnable bias\n    pub beta: Option<Param<Tensor<B, 1>>>,\n    /// The number of groups to separate the channels into\n    pub num_groups: usize,\n    /// The number of channels expected in the input\n    pub num_channels: usize,\n    /// A value required for numerical stability\n    pub epsilon: f64,\n    /// A boolean value that when set to `true`, this module has learnable\n    pub affine: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for GroupNorm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"num_groups\", &self.num_groups)\n            .add(\"num_channels\", &self.num_channels)\n            .add(\"epsilon\", &self.epsilon)\n            .add(\"affine\", &self.affine)\n            .optional()\n    }\n}\n\nimpl GroupNormConfig {\n    /// Initialize a new [group norm](GroupNorm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {\n        assert_eq!(\n            self.num_channels % self.num_groups,\n            0,\n            \"The number of channels must be divisible by the number of groups\"\n        );\n\n        let (gamma, beta) = if self.affine {\n            let gamma = Initializer::Ones.init([self.num_channels], device);\n            let beta = Initializer::Zeros.init([self.num_channels], device);\n\n            (Some(gamma), Some(beta))\n        } else {\n            (None, None)\n        };\n\n        GroupNorm {\n            num_groups: self.num_groups,\n            num_channels: self.num_channels,\n            gamma,\n            beta,\n            epsilon: self.epsilon,\n            affine: self.affine,\n        }\n    }\n}\n\nimpl<B: Backend> GroupNorm<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [GroupNorm](GroupNorm) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, num_channels, *]`\n    /// - output: `[batch_size, num_channels, *]`\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        if input.shape()[1] != self.num_channels {\n            panic!(\n                \"The number of channels in the input tensor should be equal to the number of channels in the GroupNorm module. Expected {}, got {}\",\n                self.num_channels,\n                input.shape()[1]\n            );\n        }\n\n        let gamma = self.gamma.as_ref().map(|x| x.val());\n        let beta = self.beta.as_ref().map(|x| x.val());\n\n        group_norm(\n            input,\n            gamma,\n            beta,\n            self.num_groups,\n            self.epsilon,\n            self.affine,\n        )\n    }\n}\n\n/// Applies Group Normalization over a mini-batch of inputs as described in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).\n///\n/// `Y = groupnorm(X) * γ + β`\n///\n/// Where:\n/// - `X` is the input tensor\n/// - `Y` is the output tensor\n/// - `γ` is the learnable weight\n/// - `β` is the learnable bias\n///\npub(crate) fn group_norm<B: Backend, const D: usize>(\n    input: Tensor<B, D>,\n    gamma: Option<Tensor<B, 1>>,\n    beta: Option<Tensor<B, 1>>,\n    num_groups: usize,\n    epsilon: f64,\n    affine: bool,\n) -> Tensor<B, D> {\n    if (beta.is_none() || gamma.is_none()) && affine {\n        panic!(\"Affine is set to true, but gamma or beta is None\");\n    }\n\n    let shape = input.shape();\n    if shape.num_elements() <= 2 {\n        panic!(\n            \"input rank for GroupNorm should be at least 3, but got {}\",\n            shape.num_elements()\n        );\n    }\n\n    let batch_size = shape[0];\n    let num_channels = shape[1];\n\n    let hidden_size = shape[2..].iter().product::<usize>() * num_channels / num_groups;\n    let input = input.reshape([batch_size, num_groups, hidden_size]);\n\n    let mean = input.clone().sum_dim(2) / hidden_size as f64;\n    let input = input.sub(mean);\n\n    let var = input.clone().square().sum_dim(2) / hidden_size as f64;\n    let input_normalized = input.div(var.add_scalar(epsilon).sqrt());\n\n    if affine {\n        let mut affine_shape = [1; D];\n        affine_shape[1] = num_channels;\n\n        input_normalized\n            .reshape(shape)\n            .mul(gamma.clone().unwrap().reshape(affine_shape))\n            .add(beta.clone().unwrap().reshape(affine_shape))\n    } else {\n        input_normalized.reshape(shape)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use alloc::format;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn group_norm_forward_affine_false() {\n        let device = Default::default();\n        let module = GroupNormConfig::new(2, 6)\n            .with_affine(false)\n            .init::<TestBackend>(&device);\n\n        assert!(module.gamma.is_none());\n        assert!(module.beta.is_none());\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([\n                [\n                    [-0.3034, 0.2726, -0.9659],\n                    [-1.1845, -1.3236, 0.0172],\n                    [1.9507, 1.2554, -0.8625],\n                    [1.0682, 0.3604, 0.3985],\n                    [-0.4957, -0.4461, -0.9721],\n                    [1.5157, -0.1546, -0.5596],\n                ],\n                [\n                    [-1.6698, -0.4040, -0.7927],\n                    [0.3736, -0.0975, -0.1351],\n                    [-0.9461, 0.5461, -0.6334],\n                    [-1.0919, -0.1158, 0.1213],\n                    [-0.9535, 0.1281, 0.4372],\n                    [-0.2845, 0.3488, 0.5641],\n                ],\n            ]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([\n            [\n                [-0.1653, 0.3748, -0.7866],\n                [-0.9916, -1.1220, 0.1353],\n                [1.9485, 1.2965, -0.6896],\n                [1.2769, 0.3628, 0.4120],\n                [-0.7427, -0.6786, -1.3578],\n                [1.8547, -0.3022, -0.8252],\n            ],\n            [\n                [-1.9342, 0.0211, -0.5793],\n                [1.2223, 0.4945, 0.4365],\n                [-0.8163, 1.4887, -0.3333],\n                [-1.7960, -0.0392, 0.3875],\n                [-1.5469, 0.3998, 0.9561],\n                [-0.3428, 0.7970, 1.1845],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn group_norm_forward_affine_true() {\n        let device = Default::default();\n        let module = GroupNormConfig::new(3, 6)\n            .with_affine(true)\n            .init::<TestBackend>(&device);\n\n        let tolerance = Tolerance::permissive();\n        module\n            .gamma\n            .as_ref()\n            .expect(\"gamma should not be None\")\n            .val()\n            .to_data()\n            .assert_approx_eq::<FT>(&TensorData::ones::<f32, _>([6]), tolerance);\n\n        module\n            .beta\n            .as_ref()\n            .expect(\"beta should not be None\")\n            .val()\n            .to_data()\n            .assert_approx_eq::<FT>(&TensorData::zeros::<f32, _>([6]), tolerance);\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([\n                [\n                    [0.3345, 0.4429, 0.6639],\n                    [0.5041, 0.4175, 0.8437],\n                    [0.6159, 0.3758, 0.4071],\n                    [0.5417, 0.5785, 0.7671],\n                    [0.3837, 0.9883, 0.0420],\n                    [0.4808, 0.8989, 0.6144],\n                ],\n                [\n                    [0.3930, 0.2098, 0.0602],\n                    [0.2298, 0.9425, 0.0333],\n                    [0.7409, 0.8172, 0.8879],\n                    [0.4846, 0.0486, 0.2029],\n                    [0.6741, 0.9765, 0.6864],\n                    [0.2827, 0.5534, 0.2125],\n                ],\n            ]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([\n            [\n                [-1.1694, -0.5353, 0.7572],\n                [-0.1775, -0.6838, 1.8087],\n                [0.5205, -1.3107, -1.0723],\n                [-0.0459, 0.2351, 1.6734],\n                [-0.5796, 1.3218, -1.6544],\n                [-0.2744, 1.0406, 0.1459],\n            ],\n            [\n                [0.2665, -0.3320, -0.8205],\n                [-0.2667, 2.0612, -0.9085],\n                [0.6681, 0.9102, 1.1345],\n                [-0.1453, -1.5287, -1.0389],\n                [0.4253, 1.5962, 0.4731],\n                [-1.0903, -0.0419, -1.3623],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n    }\n\n    #[test]\n    fn display() {\n        let config = GroupNormConfig::new(3, 6);\n        let group_norm = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{group_norm}\"),\n            \"GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/instance.rs",
    "content": "use burn_core as burn;\n\nuse crate::norm::group_norm;\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::module::{Module, Param};\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Configuration to create a [InstanceNorm](InstanceNorm) layer using the [init function](InstanceNormConfig::init).\n#[derive(Debug, Config)]\npub struct InstanceNormConfig {\n    /// The number of channels expected in the input\n    pub num_channels: usize,\n    /// A value required for numerical stability. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub epsilon: f64,\n    /// A boolean value that when set to `true`, this module has learnable\n    /// per-channel affine parameters initialized to ones (for weights)\n    /// and zeros (for biases). Default: `true`\n    #[config(default = true)]\n    pub affine: bool,\n}\n\n/// Applies Instance Normalization over a tensor as described in the paper [Instance Normalization](https://arxiv.org/abs/1607.08022)\n///\n/// Should be created using [InstanceNormConfig](InstanceNormConfig).\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct InstanceNorm<B: Backend> {\n    /// The learnable weight\n    pub gamma: Option<Param<Tensor<B, 1>>>,\n    /// The learnable bias\n    pub beta: Option<Param<Tensor<B, 1>>>,\n    /// The number of channels expected in the input\n    pub num_channels: usize,\n    /// A value required for numerical stability\n    pub epsilon: f64,\n    /// A boolean value that when set to `true`, this module has learnable\n    pub affine: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for InstanceNorm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"num_channels\", &self.num_channels)\n            .add(\"epsilon\", &self.epsilon)\n            .add(\"affine\", &self.affine)\n            .optional()\n    }\n}\n\nimpl InstanceNormConfig {\n    /// Initialize a new [instance norm](InstanceNorm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> InstanceNorm<B> {\n        let (gamma, beta) = if self.affine {\n            let gamma = Initializer::Ones.init([self.num_channels], device);\n            let beta = Initializer::Zeros.init([self.num_channels], device);\n\n            (Some(gamma), Some(beta))\n        } else {\n            (None, None)\n        };\n\n        InstanceNorm {\n            gamma,\n            beta,\n            num_channels: self.num_channels,\n            epsilon: self.epsilon,\n            affine: self.affine,\n        }\n    }\n}\n\nimpl<B: Backend> InstanceNorm<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See also [InstanceNormConfig](InstanceNormConfig) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, num_channels, *]`\n    /// - output: `[batch_size, num_channels, *]`\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        // Instance norm is equivalent to group norm when the number of groups is equal to the number of channels.\n        let num_groups = self.num_channels;\n\n        let gamma = self.gamma.as_ref().map(|x| x.val());\n        let beta = self.beta.as_ref().map(|x| x.val());\n\n        group_norm(input, gamma, beta, num_groups, self.epsilon, self.affine)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use alloc::format;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn instance_norm_forward_affine_false() {\n        let device = Default::default();\n        let module = InstanceNormConfig::new(6)\n            .with_affine(false)\n            .init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([\n                [\n                    [-0.3034, 0.2726, -0.9659],\n                    [-1.1845, 1.4078, 0.9774],\n                    [0.3963, -1.3738, 1.4125],\n                    [1.0682, 0.3604, 0.3985],\n                    [-0.4957, -0.4461, -0.9721],\n                    [1.5157, -0.1546, -0.5596],\n                ],\n                [\n                    [-1.6698, -0.4040, -0.7927],\n                    [0.3736, -0.0975, -0.1351],\n                    [-0.9461, 0.5461, -0.6334],\n                    [-1.0919, -0.1158, 0.1213],\n                    [-0.9535, 0.1281, 0.4372],\n                    [-0.2845, 0.3488, 0.5641],\n                ],\n            ]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([\n            [\n                [0.0569, 1.1952, -1.2522],\n                [-1.3971, 0.8883, 0.5088],\n                [0.2183, -1.3192, 1.1009],\n                [1.4126, -0.7649, -0.6477],\n                [0.5999, 0.8091, -1.409],\n                [1.39, -0.4696, -0.9205],\n            ],\n            [\n                [-1.3492, 1.0417, 0.3075],\n                [1.411, -0.6243, -0.7867],\n                [-0.9363, 1.386, -0.4497],\n                [-1.3899, 0.4692, 0.9208],\n                [-1.3822, 0.4319, 0.9503],\n                [-1.3714, 0.3868, 0.9846],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn instance_norm_forward_affine_true() {\n        let device = Default::default();\n        let module = InstanceNormConfig::new(6)\n            .with_affine(true)\n            .init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([\n                [\n                    [0.3345, 0.4429, 0.6639],\n                    [0.5041, 0.4175, 0.8437],\n                    [0.6159, 0.3758, 0.4071],\n                    [0.5417, 0.5785, 0.7671],\n                    [0.3837, 0.9883, 0.0420],\n                    [0.4808, 0.8989, 0.6144],\n                ],\n                [\n                    [0.3930, 0.2098, 0.0602],\n                    [0.2298, 0.9425, 0.0333],\n                    [0.7409, 0.8172, 0.8879],\n                    [0.4846, 0.0486, 0.2029],\n                    [0.6741, 0.9765, 0.6864],\n                    [0.2827, 0.5534, 0.2125],\n                ],\n            ]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([\n            [\n                [-1.06458, -0.2738, 1.33838],\n                [-0.45848, -0.92929, 1.38777],\n                [1.40388, -0.84877, -0.55511],\n                [-0.88515, -0.51245, 1.3976],\n                [-0.22397, 1.32124, -1.09727],\n                [-1.05468, 1.34316, -0.28848],\n            ],\n            [\n                [1.26372, -0.08229, -1.18144],\n                [-0.44049, 1.38403, -0.94354],\n                [-1.23828, 0.03109, 1.2072],\n                [1.32524, -1.08999, -0.23524],\n                [-0.75061, 1.4132, -0.66259],\n                [-0.45469, 1.38697, -0.93228],\n            ],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = InstanceNormConfig::new(6);\n        let instance_norm = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{instance_norm}\"),\n            \"InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/layer.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Content;\nuse burn::module::DisplaySettings;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::ModuleDisplay;\nuse burn::module::Param;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Configuration to create a [LayerNorm](LayerNorm) layer using the [init function](LayerNormConfig::init).\n#[derive(Debug, Config)]\npub struct LayerNormConfig {\n    /// The size of the input features.\n    pub d_model: usize,\n    /// A value required for numerical stability. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub epsilon: f64,\n    /// If a bias (beta) should be applied during the normalization. Default: true\n    #[config(default = true)]\n    pub bias: bool,\n}\n\n/// Applies Layer Normalization over an input tensor as described in the paper [Layer Normalization](https://arxiv.org/abs/1607.06450).\n///\n/// `Y = norm(X) * γ + β`\n///\n/// Where:\n/// - `X` is the input tensor\n/// - `Y` is the output tensor\n/// - `γ` is the learnable weight (scale)\n/// - `β` is the learnable bias (optional)\n///\n/// Should be created using [LayerNormConfig](LayerNormConfig).\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct LayerNorm<B: Backend> {\n    /// The learnable weight (scale).\n    pub gamma: Param<Tensor<B, 1>>,\n    /// The learnable bias (optional).\n    pub beta: Option<Param<Tensor<B, 1>>>,\n    /// A value required for numerical stability.\n    epsilon: f64,\n}\n\nimpl LayerNormConfig {\n    /// Initialize a new [layer norm](LayerNorm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> LayerNorm<B> {\n        let gamma = Initializer::Ones.init([self.d_model], device);\n        let beta = if self.bias {\n            Some(Initializer::Zeros.init([self.d_model], device))\n        } else {\n            None\n        };\n\n        LayerNorm {\n            gamma,\n            beta,\n            epsilon: self.epsilon,\n        }\n    }\n}\n\nimpl<B: Backend> LayerNorm<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See the [LayerNorm](LayerNorm) documentation for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any, d_model]`\n    /// - output: `[..., any, d_model]`\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let (var, mean) = input.clone().var_mean_bias(D - 1);\n\n        let input_normalized = input.sub(mean).div(var.add_scalar(self.epsilon).sqrt());\n\n        let output = input_normalized.mul(self.gamma.val().unsqueeze());\n\n        match &self.beta {\n            Some(beta) => output.add(beta.val().unsqueeze()),\n            None => output,\n        }\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for LayerNorm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_model] = self.gamma.shape().dims();\n        content\n            .add(\"d_model\", &d_model)\n            .add(\"epsilon\", &self.epsilon)\n            .add(\"bias\", &self.beta.is_some())\n            .optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use alloc::format;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[cfg(feature = \"std\")]\n    use crate::{TestAutodiffBackend, TestBackend};\n\n    #[cfg(not(feature = \"std\"))]\n    use crate::TestBackend;\n\n    #[test]\n    fn layer_norm_forward() {\n        let device = Default::default();\n        let module = LayerNormConfig::new(10).init::<TestBackend>(&device);\n        let input = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[\n                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,\n            ]]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([[\n            -0.4990, -1.9680, 1.6178, -0.7486, -0.6470, 0.8576, 0.0461, 1.1111, -0.2614, 0.4915,\n        ]]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn layer_norm_forward_large_epsilon() {\n        let device = Default::default();\n        let module = LayerNormConfig::new(10)\n            .with_epsilon(1e-1)\n            .init::<TestBackend>(&device);\n        let input = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[\n                -0.6897, -2.7106, 2.2222, -1.0330, -0.8933, 1.1765, 0.0601, 1.5252, -0.3630, 0.6728,\n            ]]),\n            &device,\n        );\n\n        let output = module.forward(input);\n\n        let expected = TensorData::from([[\n            -0.4863, -1.9180, 1.5766, -0.7295, -0.6305, 0.8358, 0.0449, 1.0828, -0.2548, 0.4790,\n        ]]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn layer_norm_backward() {\n        let device = Default::default();\n        let module = LayerNormConfig::new(2).init::<TestAutodiffBackend>(&device);\n        let tensor_1 = Tensor::<TestAutodiffBackend, 2>::from_data(\n            TensorData::from([[0.0, 1.0], [3.0, 4.0]]),\n            &device,\n        )\n        .require_grad();\n        let tensor_2 = Tensor::<TestAutodiffBackend, 2>::from_data(\n            TensorData::from([[6.0, 7.0], [9.0, 10.0]]),\n            &device,\n        )\n        .require_grad();\n\n        let x = tensor_1.clone().matmul(tensor_2.clone());\n\n        let output = module.forward(x);\n        let grads = output.backward();\n\n        let tensor_1_grad = tensor_1.grad(&grads).unwrap();\n        let tensor_2_grad = tensor_2.grad(&grads).unwrap();\n        let gamma_grad = module.gamma.grad(&grads).unwrap();\n        let beta_grad = module.beta.as_ref().unwrap().grad(&grads).unwrap();\n\n        let expected = TensorData::from([-2.0, 2.0]);\n        gamma_grad\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::from([2.0, 2.0]);\n        beta_grad\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::zeros::<f32, _>(tensor_1_grad.shape());\n        tensor_1_grad\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n\n        let expected = TensorData::zeros::<f32, _>(tensor_2_grad.shape());\n        tensor_2_grad\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = LayerNormConfig::new(6);\n        let layer_norm = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{layer_norm}\"),\n            \"LayerNorm {d_model: 6, epsilon: 0.00001, bias: true, params: 12}\"\n        );\n    }\n\n    #[test]\n    fn display_no_bias() {\n        let config = LayerNormConfig::new(6).with_bias(false);\n        let layer_norm = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{layer_norm}\"),\n            \"LayerNorm {d_model: 6, epsilon: 0.00001, bias: false, params: 6}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/mod.rs",
    "content": "//! # Normalization Layers\n//!\n//! Users who wish to provide an abstraction over swappable normalization\n//! layers can use the [`Normalization`] wrapper, with support for:\n//! * [`Normalization::Batch`] - [`BatchNorm`]\n//! * [`Normalization::Group`] - [`GroupNorm`]\n//! * [`Normalization::Instance`] - [`InstanceNorm`]\n//! * [`Normalization::Layer`] - [`LayerNorm`]\n//! * [`Normalization::Rms`] - [`RmsNorm`]\n//!\n//! [`NormalizationConfig`] can be used as a generic normalization policy:\n//! * Construct a config with arbitrary input features (we suggest `0`).\n//! * Clone and match that config to the target input layer,\n//!   using the [`NormalizationConfig::with_num_features()`] method.\npub(crate) mod batch;\npub(crate) mod group;\npub(crate) mod instance;\npub(crate) mod layer;\npub(crate) mod rms;\n\nmod normalization_wrapper;\n\npub use batch::*;\npub use group::*;\npub use instance::*;\npub use layer::*;\npub use normalization_wrapper::*;\npub use rms::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/normalization_wrapper.rs",
    "content": "use burn_core as burn;\n\nuse crate::{\n    BatchNorm, BatchNormConfig, GroupNorm, GroupNormConfig, InstanceNorm, InstanceNormConfig,\n    LayerNorm, LayerNormConfig, RmsNorm, RmsNormConfig,\n};\nuse burn::prelude::{Config, Module};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// ['Normalization'] Configuration.\n///\n/// The enum is non-exhaustive to prepare for future additions.\n///\n/// Can be used as a generic configuration for normalization layers:\n/// * Construct a config with arbitrary input features (we suggest `0`).\n/// * Clone and match that config to the target input layer,\n///   using the [`NormalizationConfig::with_num_features()`] method.\n#[derive(Config, Debug)]\n#[non_exhaustive]\npub enum NormalizationConfig {\n    /// ['BatchNorm'] Configuration.\n    Batch(BatchNormConfig),\n\n    /// ['GroupNorm'] Configuration.\n    Group(GroupNormConfig),\n\n    /// ['InstanceNorm'] Configuration.\n    Instance(InstanceNormConfig),\n\n    /// ['LayerNorm'] Configuration.\n    Layer(LayerNormConfig),\n\n    /// ['RmsNorm'] Configuration.\n    Rms(RmsNormConfig),\n}\n\nimpl From<BatchNormConfig> for NormalizationConfig {\n    fn from(config: BatchNormConfig) -> Self {\n        Self::Batch(config)\n    }\n}\n\nimpl From<GroupNormConfig> for NormalizationConfig {\n    fn from(config: GroupNormConfig) -> Self {\n        Self::Group(config)\n    }\n}\n\nimpl From<InstanceNormConfig> for NormalizationConfig {\n    fn from(config: InstanceNormConfig) -> Self {\n        Self::Instance(config)\n    }\n}\n\nimpl From<LayerNormConfig> for NormalizationConfig {\n    fn from(config: LayerNormConfig) -> Self {\n        Self::Layer(config)\n    }\n}\n\nimpl From<RmsNormConfig> for NormalizationConfig {\n    fn from(config: RmsNormConfig) -> Self {\n        Self::Rms(config)\n    }\n}\n\nimpl NormalizationConfig {\n    /// Initialize a ['Norm'] layer.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Normalization<B> {\n        match self {\n            NormalizationConfig::Batch(config) => config.init(device).into(),\n            NormalizationConfig::Group(config) => config.init(device).into(),\n            NormalizationConfig::Instance(config) => config.init(device).into(),\n            NormalizationConfig::Layer(config) => config.init(device).into(),\n            NormalizationConfig::Rms(config) => config.init(device).into(),\n        }\n    }\n\n    /// Set the number of features.\n    pub fn with_num_features(self, num_features: usize) -> Self {\n        match self {\n            NormalizationConfig::Batch(config) => BatchNormConfig {\n                num_features,\n                ..config\n            }\n            .into(),\n            NormalizationConfig::Group(config) => GroupNormConfig {\n                num_channels: num_features,\n                ..config\n            }\n            .into(),\n            NormalizationConfig::Instance(config) => InstanceNormConfig {\n                num_channels: num_features,\n                ..config\n            }\n            .into(),\n            NormalizationConfig::Layer(config) => LayerNormConfig {\n                d_model: num_features,\n                ..config\n            }\n            .into(),\n            NormalizationConfig::Rms(config) => RmsNormConfig {\n                d_model: num_features,\n                ..config\n            }\n            .into(),\n        }\n    }\n\n    /// Get the number of features.\n    pub fn num_features(&self) -> usize {\n        match self {\n            NormalizationConfig::Batch(config) => config.num_features,\n            NormalizationConfig::Group(config) => config.num_channels,\n            NormalizationConfig::Instance(config) => config.num_channels,\n            NormalizationConfig::Layer(config) => config.d_model,\n            NormalizationConfig::Rms(config) => config.d_model,\n        }\n    }\n}\n\n/// Normalization Layer Wrapper\n///\n/// Provides support for built-in ``burn::nn::norm`` norm layers:\n/// * [`Normalization::Batch`] - [`BatchNorm`]\n/// * [`Normalization::Group`] - [`GroupNorm`]\n/// * [`Normalization::Instance`] - [`InstanceNorm`]\n/// * [`Normalization::Layer`] - [`LayerNorm`]\n/// * [`Normalization::Rms`] - [`RmsNorm`]\n///\n/// The enum is non-exhaustive, to prepare for future additions.\n#[derive(Module, Debug)]\n#[non_exhaustive]\npub enum Normalization<B: Backend> {\n    /// [`BatchNorm`] layer.\n    Batch(BatchNorm<B>),\n\n    /// [`GroupNorm`] layer.\n    Group(GroupNorm<B>),\n\n    /// ['InstanceNorm'] layer.\n    Instance(InstanceNorm<B>),\n\n    /// [`LayerNorm`] layer.\n    Layer(LayerNorm<B>),\n\n    /// ['RmsNorm'] layer.\n    Rms(RmsNorm<B>),\n}\n\nimpl<B: Backend> From<BatchNorm<B>> for Normalization<B> {\n    fn from(layer: BatchNorm<B>) -> Self {\n        Self::Batch(layer)\n    }\n}\n\nimpl<B: Backend> From<GroupNorm<B>> for Normalization<B> {\n    fn from(layer: GroupNorm<B>) -> Self {\n        Self::Group(layer)\n    }\n}\n\nimpl<B: Backend> From<InstanceNorm<B>> for Normalization<B> {\n    fn from(layer: InstanceNorm<B>) -> Self {\n        Self::Instance(layer)\n    }\n}\n\nimpl<B: Backend> From<LayerNorm<B>> for Normalization<B> {\n    fn from(layer: LayerNorm<B>) -> Self {\n        Self::Layer(layer)\n    }\n}\n\nimpl<B: Backend> From<RmsNorm<B>> for Normalization<B> {\n    fn from(layer: RmsNorm<B>) -> Self {\n        Self::Rms(layer)\n    }\n}\n\nimpl<B: Backend> Normalization<B> {\n    /// Applies normalization to a tensor.\n    ///\n    /// The normalization contract depends upon the wrapped norm layer;\n    /// but all norm layers assume an input of at least rank 2;\n    /// and produce an output of the same rank and shape.\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        match self {\n            Normalization::Batch(norm) => norm.forward(input),\n            Normalization::Group(norm) => norm.forward(input),\n            Normalization::Instance(norm) => norm.forward(input),\n            Normalization::Layer(norm) => norm.forward(input),\n            Normalization::Rms(norm) => norm.forward(input),\n        }\n    }\n\n    /// Get the number of features.\n    pub fn num_features(&self) -> usize {\n        match self {\n            Normalization::Batch(norm) => norm.gamma.shape()[0],\n            Normalization::Group(norm) => norm.num_channels,\n            Normalization::Instance(norm) => norm.num_channels,\n            Normalization::Layer(norm) => norm.gamma.shape()[0],\n            Normalization::Rms(norm) => norm.gamma.shape()[0],\n        }\n    }\n}\n\n#[cfg(feature = \"std\")]\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestAutodiffBackend>;\n\n    #[test]\n    fn test_match_feature_size() {\n        let config: NormalizationConfig = BatchNormConfig::new(0).into();\n        assert_eq!(config.num_features(), 0);\n        let config = config.with_num_features(12);\n        assert_eq!(config.num_features(), 12);\n\n        let config: NormalizationConfig = GroupNormConfig::new(4, 0).into();\n        assert_eq!(config.num_features(), 0);\n        let config = config.with_num_features(12);\n        assert_eq!(config.num_features(), 12);\n\n        let config: NormalizationConfig = InstanceNormConfig::new(0).into();\n        assert_eq!(config.num_features(), 0);\n        let config = config.with_num_features(12);\n        assert_eq!(config.num_features(), 12);\n\n        let config: NormalizationConfig = LayerNormConfig::new(0).into();\n        assert_eq!(config.num_features(), 0);\n        let config = config.with_num_features(12);\n        assert_eq!(config.num_features(), 12);\n\n        let config: NormalizationConfig = RmsNormConfig::new(0).into();\n        assert_eq!(config.num_features(), 0);\n        let config = config.with_num_features(12);\n        assert_eq!(config.num_features(), 12);\n    }\n\n    #[test]\n    fn test_batch_norm() {\n        type B = TestAutodiffBackend;\n        let device = Default::default();\n\n        let num_features = 12;\n        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);\n\n        let config: NormalizationConfig = BatchNormConfig::new(12).into();\n\n        let layer: Normalization<B> = config.init(&device);\n        assert_eq!(layer.num_features(), 12);\n\n        let expected = match &layer {\n            Normalization::Batch(inner) => inner.forward(input.clone()),\n            _ => panic!(\"Unexpected layer type\"),\n        };\n\n        let output = layer.forward(input);\n\n        output.to_data().assert_eq(&expected.to_data(), true);\n    }\n\n    #[test]\n    fn test_group_norm() {\n        type B = TestAutodiffBackend;\n        let device = Default::default();\n\n        let num_features = 12;\n        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);\n\n        let config: NormalizationConfig = GroupNormConfig::new(3, num_features).into();\n\n        let layer: Normalization<B> = config.init(&device);\n        assert_eq!(layer.num_features(), 12);\n\n        let expected = match &layer {\n            Normalization::Group(inner) => inner.forward(input.clone()),\n            _ => panic!(\"Unexpected layer type\"),\n        };\n\n        let output = layer.forward(input);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_instance_norm() {\n        type B = TestAutodiffBackend;\n        let device = Default::default();\n\n        let num_features = 12;\n        let input: Tensor<B, 4> = Tensor::ones([2, num_features, 3, 4], &device);\n\n        let config: NormalizationConfig = InstanceNormConfig::new(num_features).into();\n\n        let layer: Normalization<B> = config.init(&device);\n        assert_eq!(layer.num_features(), 12);\n\n        let expected = match &layer {\n            Normalization::Instance(inner) => inner.forward(input.clone()),\n            _ => panic!(\"Unexpected layer type\"),\n        };\n\n        let output = layer.forward(input);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_layer_norm() {\n        type B = TestAutodiffBackend;\n        let device = Default::default();\n\n        let num_features = 12;\n        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);\n\n        let config: NormalizationConfig = LayerNormConfig::new(num_features).into();\n\n        let layer: Normalization<B> = config.init(&device);\n        assert_eq!(layer.num_features(), 12);\n\n        let expected = match &layer {\n            Normalization::Layer(inner) => inner.forward(input.clone()),\n            _ => panic!(\"Unexpected layer type\"),\n        };\n\n        let output = layer.forward(input);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_rms_norm() {\n        type B = TestAutodiffBackend;\n        let device = Default::default();\n\n        let num_features = 12;\n        let input: Tensor<B, 4> = Tensor::ones([2, 3, 4, num_features], &device);\n\n        let config: NormalizationConfig = RmsNormConfig::new(num_features).into();\n\n        let layer: Normalization<B> = config.init(&device);\n        assert_eq!(layer.num_features(), 12);\n\n        let expected = match &layer {\n            Normalization::Rms(inner) => inner.forward(input.clone()),\n            _ => panic!(\"Unexpected layer type\"),\n        };\n\n        let output = layer.forward(input);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/norm/rms.rs",
    "content": "use burn::tensor::DType;\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::Param;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).\n#[derive(Config, Debug)]\npub struct RmsNormConfig {\n    /// The size of the input features.\n    pub d_model: usize,\n    /// A value required for numerical stability. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub epsilon: f64,\n}\n\nimpl RmsNormConfig {\n    /// Initialize a new [RMS Norm](RmsNorm) module.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `epsilon` is not positive.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {\n        assert!(self.epsilon > 0.0, \"epsilon must be positive.\");\n\n        let gamma = Initializer::Ones.init([self.d_model], device);\n\n        RmsNorm {\n            gamma,\n            epsilon: self.epsilon,\n        }\n    }\n}\n\n/// Applies RMS Normalization over an input tensor along the last dimension.\n///\n/// `Y = X / sqrt(mean(X^2) + eps) * gamma`\n///\n/// Where:\n/// - `X` is the input tensor\n/// - `Y` is the output tensor\n/// - `gamma` is the learnable weight\n/// - `mean` is the mean operation\n/// - `eps` is a small value to avoid division by zero.\n///\n/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct RmsNorm<B: Backend> {\n    /// The learnable parameter to scale the normalized tensor\n    pub gamma: Param<Tensor<B, 1>>,\n    /// A value required for numerical stability\n    pub epsilon: f64,\n}\n\nimpl<B: Backend> RmsNorm<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See the [RmsNorm](RmsNorm) documentation for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[..., any, d_model]`\n    /// - output: `[..., any, d_model]`\n    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {\n        // Calculate the root-mean-square norm of the input tensor along the last dimension\n        let dtype = x.dtype();\n        let rms = (x.clone().cast(DType::F32).square().mean_dim(D - 1) + self.epsilon).sqrt();\n        (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for RmsNorm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_model] = self.gamma.shape().dims();\n        content\n            .add(\"d_model\", &d_model)\n            .add(\"epsilon\", &self.epsilon)\n            .optional()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use alloc::format;\n    use burn::tensor::TensorData;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn rms_norm_forward() {\n        let device = Default::default();\n        let module = RmsNormConfig::new(3)\n            .with_epsilon(1e-5)\n            .init::<TestBackend>(&device);\n\n        let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);\n        let output = module.forward(input);\n\n        let expected = TensorData::from([\n            [0.0000, 0.7746, 1.5492],\n            [0.7348, 0.9798, 1.2247],\n            [0.8514, 0.9933, 1.1352],\n        ]);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn display() {\n        let config = RmsNormConfig::new(6);\n        let layer_norm = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            format!(\"{layer_norm}\"),\n            \"RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/adaptive_avg_pool1d.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\nuse burn::tensor::module::adaptive_avg_pool1d;\n\n/// Configuration to create a [1D adaptive avg pooling](AdaptiveAvgPool1d) layer using the [init function](AdaptiveAvgPool1dConfig::init).\n#[derive(Config, Debug)]\npub struct AdaptiveAvgPool1dConfig {\n    /// The size of the output.\n    pub output_size: usize,\n}\n\n/// Applies a 1D adaptive avg pooling over input tensors.\n///\n/// Should be created with [AdaptiveAvgPool1dConfig].\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct AdaptiveAvgPool1d {\n    /// The size of the output.\n    pub output_size: usize,\n}\n\nimpl ModuleDisplay for AdaptiveAvgPool1d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content.add(\"output_size\", &self.output_size).optional()\n    }\n}\n\nimpl AdaptiveAvgPool1dConfig {\n    /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module.\n    pub fn init(&self) -> AdaptiveAvgPool1d {\n        AdaptiveAvgPool1d {\n            output_size: self.output_size,\n        }\n    }\n}\n\nimpl AdaptiveAvgPool1d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [adaptive_avg_pool1d](burn::tensor::module::adaptive_avg_pool1d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, length]`\n    /// - output: `[batch_size, channels, length_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        adaptive_avg_pool1d(input, self.output_size)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let config = AdaptiveAvgPool1dConfig::new(3);\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"AdaptiveAvgPool1d {output_size: 3}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/adaptive_avg_pool2d.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\nuse burn::tensor::module::adaptive_avg_pool2d;\n\n/// Configuration to create a [2D adaptive avg pooling](AdaptiveAvgPool2d) layer using the [init function](AdaptiveAvgPool2dConfig::init).\n#[derive(Config, Debug)]\npub struct AdaptiveAvgPool2dConfig {\n    /// The size of the output.\n    pub output_size: [usize; 2],\n}\n\n/// Applies a 2D adaptive avg pooling over input tensors.\n///\n/// Should be created with [AdaptiveAvgPool2dConfig].\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct AdaptiveAvgPool2d {\n    /// The size of the output.\n    pub output_size: [usize; 2],\n}\n\nimpl ModuleDisplay for AdaptiveAvgPool2d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let output_size = alloc::format!(\"{:?}\", self.output_size);\n\n        content.add(\"output_size\", &output_size).optional()\n    }\n}\n\nimpl AdaptiveAvgPool2dConfig {\n    /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module.\n    pub fn init(&self) -> AdaptiveAvgPool2d {\n        AdaptiveAvgPool2d {\n            output_size: self.output_size,\n        }\n    }\n}\n\nimpl AdaptiveAvgPool2d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [adaptive_avg_pool2d](burn::tensor::module::adaptive_avg_pool2d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, height_in, width_in]`\n    /// - output: `[batch_size, channels, height_out, width_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        adaptive_avg_pool2d(input, self.output_size)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let config = AdaptiveAvgPool2dConfig::new([3, 3]);\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"AdaptiveAvgPool2d {output_size: [3, 3]}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/avg_pool1d.rs",
    "content": "use burn_core as burn;\n\nuse crate::PaddingConfig1d;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::PadMode;\n\nuse burn::tensor::module::avg_pool1d;\n\n/// Configuration to create a [1D avg pooling](AvgPool1d) layer using the [init function](AvgPool1dConfig::init).\n#[derive(Config, Debug)]\npub struct AvgPool1dConfig {\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The stride.\n    #[config(default = \"kernel_size\")]\n    pub stride: usize,\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig1d::Valid\")]\n    pub padding: PaddingConfig1d,\n    /// If the padding is counted in the denominator when computing the average.\n    #[config(default = \"true\")]\n    pub count_include_pad: bool,\n    /// If true, use ceiling instead of floor for output size calculation.\n    #[config(default = \"false\")]\n    pub ceil_mode: bool,\n}\n\n/// Applies a 1D avg pooling over input tensors.\n///\n/// Should be created with [AvgPool1dConfig](AvgPool1dConfig).\n///\n/// # Remarks\n///\n/// The zero-padding values will be included in the calculation\n/// of the average. This means that the zeros are counted as\n/// legitimate values, and they contribute to the denominator\n/// when calculating the average. This is equivalent to\n/// `torch.nn.AvgPool2d` with `count_include_pad=True`.\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct AvgPool1d {\n    /// The stride.\n    pub stride: usize,\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The padding configuration.\n    pub padding: PaddingConfig1d,\n    /// If the padding is counted in the denominator when computing the average.\n    pub count_include_pad: bool,\n    /// If true, use ceiling instead of floor for output size calculation.\n    pub ceil_mode: bool,\n}\n\nimpl ModuleDisplay for AvgPool1d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"kernel_size\", &self.kernel_size)\n            .add(\"stride\", &self.stride)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .add(\"count_include_pad\", &self.count_include_pad)\n            .add(\"ceil_mode\", &self.ceil_mode)\n            .optional()\n    }\n}\n\nimpl AvgPool1dConfig {\n    /// Initialize a new [avg pool 1d](AvgPool1d) module.\n    pub fn init(&self) -> AvgPool1d {\n        AvgPool1d {\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            padding: self.padding.clone(),\n            count_include_pad: self.count_include_pad,\n            ceil_mode: self.ceil_mode,\n        }\n    }\n}\n\nimpl AvgPool1d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [avg_pool1d](burn::tensor::module::avg_pool1d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, length_in]`\n    /// - output: `[batch_size, channels, length_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let [_batch_size, _channels, length] = input.dims();\n\n        // Calculate padding as pair - handles Same, Valid, and Explicit uniformly\n        let (left, right) =\n            self.padding\n                .calculate_padding_1d_pair(length, self.kernel_size, self.stride);\n\n        // TODO: Move asymmetric padding to functional level via PoolOptions\n        // See: https://github.com/tracel-ai/burn/issues/4362\n        // Handle asymmetric padding by applying explicit pad operation first\n        if left != right {\n            // Burn's pad takes (left, right, top, bottom) for the last two dimensions\n            // For 1D (NCL format), we only pad L (last dim), so top/bottom = 0\n            let padded = input.pad((left, right, 0, 0), PadMode::Constant(0.0));\n            // Use zero padding for the pool operation since we already padded\n            avg_pool1d(\n                padded,\n                self.kernel_size,\n                self.stride,\n                0,\n                self.count_include_pad,\n                self.ceil_mode,\n            )\n        } else {\n            // Symmetric padding\n            avg_pool1d(\n                input,\n                self.kernel_size,\n                self.stride,\n                left,\n                self.count_include_pad,\n                self.ceil_mode,\n            )\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use rstest::rstest;\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = AvgPool1dConfig::new(2)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Same);\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=5]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);\n        let output = pool.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 2, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = AvgPool1dConfig::new(3);\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"AvgPool1d {kernel_size: 3, stride: 3, padding: Valid, count_include_pad: true, ceil_mode: false}\"\n        );\n    }\n\n    #[rstest]\n    #[case(1)]\n    #[case(2)]\n    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {\n        let config = AvgPool1dConfig::new(kernel_size);\n\n        assert_eq!(\n            config.stride, kernel_size,\n            \"Expected stride ({:?}) to match kernel size ({:?}) in default AvgPool1dConfig::new constructor\",\n            config.stride, config.kernel_size\n        );\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create avg pool with asymmetric padding: left=1, right=2\n        let config = AvgPool1dConfig::new(3)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Explicit(1, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = pool.forward(input);\n\n        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7\n        // Output length = (7 - 3) / 1 + 1 = 5\n        assert_eq!(output.dims(), [1, 2, 5]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create avg pool with symmetric explicit padding: left=2, right=2\n        let config = AvgPool1dConfig::new(3)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Explicit(2, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = pool.forward(input);\n\n        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8\n        // Output length = (8 - 3) / 1 + 1 = 6\n        assert_eq!(output.dims(), [1, 2, 6]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/avg_pool2d.rs",
    "content": "use burn_core as burn;\n\nuse crate::PaddingConfig2d;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::PadMode;\n\nuse burn::tensor::module::avg_pool2d;\n\n/// Configuration to create a [2D avg pooling](AvgPool2d) layer using the [init function](AvgPool2dConfig::init).\n#[derive(Config, Debug)]\npub struct AvgPool2dConfig {\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The strides.\n    #[config(default = \"kernel_size\")]\n    pub strides: [usize; 2],\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig2d::Valid\")]\n    pub padding: PaddingConfig2d,\n    /// If the padding is counted in the denominator when computing the average.\n    #[config(default = \"true\")]\n    pub count_include_pad: bool,\n    /// If true, use ceiling instead of floor for output size calculation.\n    #[config(default = \"false\")]\n    pub ceil_mode: bool,\n}\n\n/// Applies a 2D avg pooling over input tensors.\n///\n/// Should be created with [AvgPool2dConfig](AvgPool2dConfig).\n///\n/// # Remarks\n///\n/// The zero-padding values will be included in the calculation\n/// of the average. This means that the zeros are counted as\n/// legitimate values, and they contribute to the denominator\n/// when calculating the average. This is equivalent to\n/// `torch.nn.AvgPool2d` with `count_include_pad=True`.\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct AvgPool2d {\n    /// Stride of the pooling.\n    pub stride: [usize; 2],\n    /// Size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// Padding configuration.\n    pub padding: PaddingConfig2d,\n    /// If the padding is counted in the denominator when computing the average.\n    pub count_include_pad: bool,\n    /// If true, use ceiling instead of floor for output size calculation.\n    pub ceil_mode: bool,\n}\n\nimpl ModuleDisplay for AvgPool2d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"kernel_size\", &alloc::format!(\"{:?}\", &self.kernel_size))\n            .add(\"stride\", &alloc::format!(\"{:?}\", &self.stride))\n            .add_debug_attribute(\"padding\", &self.padding)\n            .add(\"count_include_pad\", &self.count_include_pad)\n            .add(\"ceil_mode\", &self.ceil_mode)\n            .optional()\n    }\n}\n\nimpl AvgPool2dConfig {\n    /// Initialize a new [avg pool 2d](AvgPool2d) module.\n    pub fn init(&self) -> AvgPool2d {\n        AvgPool2d {\n            stride: self.strides,\n            kernel_size: self.kernel_size,\n            padding: self.padding.clone(),\n            count_include_pad: self.count_include_pad,\n            ceil_mode: self.ceil_mode,\n        }\n    }\n}\n\nimpl AvgPool2d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [avg_pool2d](burn::tensor::module::avg_pool2d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, height_in, width_in]`\n    /// - output: `[batch_size, channels, height_out, width_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let [_batch_size, _channels_in, height_in, width_in] = input.dims();\n\n        // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly\n        let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(\n            height_in,\n            width_in,\n            &self.kernel_size,\n            &self.stride,\n        );\n\n        // TODO: Move asymmetric padding to functional level via PoolOptions\n        // See: https://github.com/tracel-ai/burn/issues/4362\n        // Handle asymmetric padding by applying explicit pad operation first\n        if top != bottom || left != right {\n            // Burn's pad takes (left, right, top, bottom) for the last two dimensions\n            let padded = input.pad((left, right, top, bottom), PadMode::Constant(0.0));\n            // Use zero padding for the pool operation since we already padded\n            avg_pool2d(\n                padded,\n                self.kernel_size,\n                self.stride,\n                [0, 0],\n                self.count_include_pad,\n                self.ceil_mode,\n            )\n        } else {\n            // Symmetric padding\n            avg_pool2d(\n                input,\n                self.kernel_size,\n                self.stride,\n                [top, left],\n                self.count_include_pad,\n                self.ceil_mode,\n            )\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use rstest::rstest;\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = AvgPool2dConfig::new([2, 2])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Same);\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=5, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);\n        let output = pool.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 2, 5, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = AvgPool2dConfig::new([3, 3]);\n\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"AvgPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}\"\n        );\n    }\n\n    #[rstest]\n    #[case([2, 2])]\n    #[case([1, 2])]\n    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {\n        let config = AvgPool2dConfig::new(kernel_size);\n\n        assert_eq!(\n            config.strides, kernel_size,\n            \"Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool2dConfig::new constructor\",\n            config.strides, config.kernel_size\n        );\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create avg pool with asymmetric padding: top=1, left=2, bottom=3, right=4\n        let config = AvgPool2dConfig::new([3, 3])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = pool.forward(input);\n\n        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9\n        assert_eq!(output.dims(), [1, 2, 6, 9]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create avg pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2\n        let config = AvgPool2dConfig::new([3, 3])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = pool.forward(input);\n\n        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7\n        assert_eq!(output.dims(), [1, 2, 6, 7]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/max_pool1d.rs",
    "content": "use burn_core as burn;\n\nuse crate::PaddingConfig1d;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::PadMode;\n\nuse burn::tensor::module::max_pool1d;\n\n/// Configuration to create a [1D max pooling](MaxPool1d) layer using the [init function](MaxPool1dConfig::init).\n#[derive(Config, Debug)]\npub struct MaxPool1dConfig {\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The stride.\n    #[config(default = \"kernel_size\")]\n    pub stride: usize,\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig1d::Valid\")]\n    pub padding: PaddingConfig1d,\n    /// The dilation.\n    #[config(default = \"1\")]\n    pub dilation: usize,\n    /// If true, use ceiling instead of floor for output size calculation.\n    #[config(default = \"false\")]\n    pub ceil_mode: bool,\n}\n\n/// Applies a 1D max pooling over input tensors.\n///\n/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct MaxPool1d {\n    /// The stride.\n    pub stride: usize,\n    /// The size of the kernel.\n    pub kernel_size: usize,\n    /// The padding configuration.\n    pub padding: PaddingConfig1d,\n    /// The dilation.\n    pub dilation: usize,\n    /// If true, use ceiling instead of floor for output size calculation.\n    pub ceil_mode: bool,\n}\n\nimpl ModuleDisplay for MaxPool1d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"kernel_size\", &self.kernel_size)\n            .add(\"stride\", &self.stride)\n            .add_debug_attribute(\"padding\", &self.padding)\n            .add(\"dilation\", &self.dilation)\n            .add(\"ceil_mode\", &self.ceil_mode)\n            .optional()\n    }\n}\n\nimpl MaxPool1dConfig {\n    /// Initialize a new [max pool 1d](MaxPool1d) module.\n    pub fn init(&self) -> MaxPool1d {\n        MaxPool1d {\n            stride: self.stride,\n            kernel_size: self.kernel_size,\n            padding: self.padding.clone(),\n            dilation: self.dilation,\n            ceil_mode: self.ceil_mode,\n        }\n    }\n}\n\nimpl MaxPool1d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [max_pool1d](burn::tensor::module::max_pool1d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, length_in]`\n    /// - output: `[batch_size, channels, length_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let [_batch_size, _channels, length] = input.dims();\n\n        // Calculate padding as pair - handles Same, Valid, and Explicit uniformly\n        let (left, right) =\n            self.padding\n                .calculate_padding_1d_pair(length, self.kernel_size, self.stride);\n\n        // TODO: Move asymmetric padding to functional level via PoolOptions\n        // See: https://github.com/tracel-ai/burn/issues/4362\n        // Handle asymmetric padding by applying explicit pad operation first\n        if left != right {\n            // For 1D (NCL format), pad the length dimension with (left, right)\n            // and no padding for channel dimension (top=0, bottom=0)\n            // Use -inf for max pooling so padded values don't affect the max\n            let padded = input.pad((left, right, 0, 0), PadMode::Constant(f32::NEG_INFINITY));\n            // Use zero padding for the pool operation since we already padded\n            max_pool1d(\n                padded,\n                self.kernel_size,\n                self.stride,\n                0,\n                self.dilation,\n                self.ceil_mode,\n            )\n        } else {\n            // Symmetric padding\n            max_pool1d(\n                input,\n                self.kernel_size,\n                self.stride,\n                left,\n                self.dilation,\n                self.ceil_mode,\n            )\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use rstest::rstest;\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = MaxPool1dConfig::new(2)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Same);\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=5]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 5], &device);\n        let output = pool.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 2, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = MaxPool1dConfig::new(3);\n\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"MaxPool1d {kernel_size: 3, stride: 3, padding: Valid, dilation: 1, ceil_mode: false}\"\n        );\n    }\n\n    #[rstest]\n    #[case(1)]\n    #[case(2)]\n    fn default_strides_match_kernel_size(#[case] kernel_size: usize) {\n        let config = MaxPool1dConfig::new(kernel_size);\n\n        assert_eq!(\n            config.stride, kernel_size,\n            \"Expected stride ({:?}) to match kernel size ({:?}) in default MaxPool1dConfig::new constructor\",\n            config.stride, config.kernel_size\n        );\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create max pool with asymmetric padding: left=1, right=2\n        let config = MaxPool1dConfig::new(3)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Explicit(1, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = pool.forward(input);\n\n        // With asymmetric padding (1, 2), input length 4 becomes 4+1+2=7\n        // Output length = (7 - 3) / 1 + 1 = 5\n        assert_eq!(output.dims(), [1, 2, 5]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create max pool with symmetric explicit padding: left=2, right=2\n        let config = MaxPool1dConfig::new(3)\n            .with_stride(1)\n            .with_padding(PaddingConfig1d::Explicit(2, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, length=4]\n        let input = Tensor::<TestBackend, 3>::ones([1, 2, 4], &device);\n        let output = pool.forward(input);\n\n        // With symmetric padding (2, 2), input length 4 becomes 4+2+2=8\n        // Output length = (8 - 3) / 1 + 1 = 6\n        assert_eq!(output.dims(), [1, 2, 6]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/max_pool2d.rs",
    "content": "use burn_core as burn;\n\nuse crate::PaddingConfig2d;\nuse burn::config::Config;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::ops::PadMode;\n\nuse burn::tensor::module::max_pool2d;\n\n/// Configuration to create a [2D max pooling](MaxPool2d) layer using the [init function](MaxPool2dConfig::init).\n#[derive(Debug, Config)]\npub struct MaxPool2dConfig {\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The strides.\n    #[config(default = \"kernel_size\")]\n    pub strides: [usize; 2],\n    /// The padding configuration.\n    ///\n    /// Supports symmetric and asymmetric padding. `Same` padding with even kernel sizes\n    /// will automatically use asymmetric padding to preserve input dimensions.\n    #[config(default = \"PaddingConfig2d::Valid\")]\n    pub padding: PaddingConfig2d,\n    /// The dilation.\n    #[config(default = \"[1, 1]\")]\n    pub dilation: [usize; 2],\n    /// If true, use ceiling instead of floor for output size calculation.\n    #[config(default = \"false\")]\n    pub ceil_mode: bool,\n}\n\n/// Applies a 2D max pooling over input tensors.\n///\n/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct MaxPool2d {\n    /// The strides.\n    pub stride: [usize; 2],\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The padding configuration.\n    pub padding: PaddingConfig2d,\n    /// The dilation.\n    pub dilation: [usize; 2],\n    /// If true, use ceiling instead of floor for output size calculation.\n    pub ceil_mode: bool,\n}\n\nimpl ModuleDisplay for MaxPool2d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"kernel_size\", &alloc::format!(\"{:?}\", &self.kernel_size))\n            .add(\"stride\", &alloc::format!(\"{:?}\", &self.stride))\n            .add_debug_attribute(\"padding\", &self.padding)\n            .add(\"dilation\", &alloc::format!(\"{:?}\", &self.dilation))\n            .add(\"ceil_mode\", &self.ceil_mode)\n            .optional()\n    }\n}\n\nimpl MaxPool2dConfig {\n    /// Initialize a new [max pool 2d](MaxPool2d) module.\n    pub fn init(&self) -> MaxPool2d {\n        MaxPool2d {\n            stride: self.strides,\n            kernel_size: self.kernel_size,\n            padding: self.padding.clone(),\n            dilation: self.dilation,\n            ceil_mode: self.ceil_mode,\n        }\n    }\n}\n\nimpl MaxPool2d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [max_pool2d](burn::tensor::module::max_pool2d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, channels, height_in, width_in]`\n    /// - output: `[batch_size, channels, height_out, width_out]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let [_batch_size, _channels_in, height_in, width_in] = input.dims();\n\n        // Calculate padding as pairs - handles Same, Valid, and Explicit uniformly\n        let ((top, bottom), (left, right)) = self.padding.calculate_padding_2d_pairs(\n            height_in,\n            width_in,\n            &self.kernel_size,\n            &self.stride,\n        );\n\n        // TODO: Move asymmetric padding to functional level via PoolOptions\n        // See: https://github.com/tracel-ai/burn/issues/4362\n        // Handle asymmetric padding by applying explicit pad operation first\n        if top != bottom || left != right {\n            // Burn's pad takes (left, right, top, bottom) for the last two dimensions\n            // Use -inf for max pooling so padded values don't affect the max\n            let padded = input.pad(\n                (left, right, top, bottom),\n                PadMode::Constant(f32::NEG_INFINITY),\n            );\n            // Use zero padding for the pool operation since we already padded\n            max_pool2d(\n                padded,\n                self.kernel_size,\n                self.stride,\n                [0, 0],\n                self.dilation,\n                self.ceil_mode,\n            )\n        } else {\n            // Symmetric padding\n            max_pool2d(\n                input,\n                self.kernel_size,\n                self.stride,\n                [top, left],\n                self.dilation,\n                self.ceil_mode,\n            )\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use rstest::rstest;\n\n    #[test]\n    fn same_with_even_kernel_uses_asymmetric_padding() {\n        let device = Default::default();\n        let config = MaxPool2dConfig::new([2, 2])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Same);\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=5, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 5, 5], &device);\n        let output = pool.forward(input);\n\n        // Same padding should preserve spatial dimensions\n        assert_eq!(output.dims(), [1, 2, 5, 5]);\n    }\n\n    #[test]\n    fn display() {\n        let config = MaxPool2dConfig::new([3, 3]);\n\n        let layer = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"MaxPool2d {kernel_size: [3, 3], stride: [3, 3], padding: Valid, dilation: [1, 1], ceil_mode: false}\"\n        );\n    }\n\n    #[rstest]\n    #[case([2, 2])]\n    #[case([1, 2])]\n    fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 2]) {\n        let config = MaxPool2dConfig::new(kernel_size);\n\n        assert_eq!(\n            config.strides, kernel_size,\n            \"Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool2dConfig::new constructor\",\n            config.strides, config.kernel_size\n        );\n    }\n\n    #[test]\n    fn asymmetric_padding_forward() {\n        let device = Default::default();\n        // Create max pool with asymmetric padding: top=1, left=2, bottom=3, right=4\n        let config = MaxPool2dConfig::new([3, 3])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Explicit(1, 2, 3, 4));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = pool.forward(input);\n\n        // Height: 4 + 1 + 3 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 4 = 11, output = (11 - 3) / 1 + 1 = 9\n        assert_eq!(output.dims(), [1, 2, 6, 9]);\n    }\n\n    #[test]\n    fn symmetric_explicit_padding_forward() {\n        let device = Default::default();\n        // Create max pool with symmetric explicit padding: top=2, left=2, bottom=2, right=2\n        let config = MaxPool2dConfig::new([3, 3])\n            .with_strides([1, 1])\n            .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2));\n        let pool = config.init();\n\n        // Input: [batch=1, channels=2, height=4, width=5]\n        let input = Tensor::<TestBackend, 4>::ones([1, 2, 4, 5], &device);\n        let output = pool.forward(input);\n\n        // Height: 4 + 2 + 2 = 8, output = (8 - 3) / 1 + 1 = 6\n        // Width: 5 + 2 + 2 = 9, output = (9 - 3) / 1 + 1 = 7\n        assert_eq!(output.dims(), [1, 2, 6, 7]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pool/mod.rs",
    "content": "mod adaptive_avg_pool1d;\nmod adaptive_avg_pool2d;\nmod avg_pool1d;\nmod avg_pool2d;\nmod max_pool1d;\nmod max_pool2d;\n\npub use adaptive_avg_pool1d::*;\npub use adaptive_avg_pool2d::*;\npub use avg_pool1d::*;\npub use avg_pool2d::*;\npub use max_pool1d::*;\npub use max_pool2d::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/pos_encoding.rs",
    "content": "use burn_core as burn;\n\nuse alloc::vec::Vec;\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\n\nuse burn::tensor::Tensor;\nuse burn::tensor::TensorData;\nuse burn::tensor::backend::Backend;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).\n#[derive(Config, Debug)]\npub struct PositionalEncodingConfig {\n    /// Maximum sequence size to use.\n    #[config(default = \"5_000\")]\n    pub max_sequence_size: usize,\n\n    /// The size of each vector.\n    pub d_model: usize,\n\n    /// Max time scale to use.\n    #[config(default = \"10_000\")]\n    pub max_timescale: usize,\n}\n\n/// Positional encoding layer for transformer models.\n///\n/// This layer adds positional information to the input embeddings, allowing the transformer model\n/// to take into account the order of the sequence. The positional encoding is added to the input\n/// embeddings by computing a set of sinusoidal functions with different frequencies and phases.\n///\n/// Sinusoids are used for positional embedding introduced in\n/// [Attention is all you need](https://arxiv.org/abs/1706.03762).\n///\n/// The reference implementation can be found here:\n/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT\n/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)\n///\n/// Should be created using [PositionalEncodingConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct PositionalEncoding<B: Backend> {\n    /// The sinusoids used to add positional information to the input embeddings.\n    pub sinusoids: Tensor<B, 3>,\n    /// The maximum sequence size to use.\n    pub max_sequence_size: usize,\n    /// Max time scale to use.\n    pub max_timescale: usize,\n}\n\nimpl<B: Backend> ModuleDisplay for PositionalEncoding<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [_, _, d_model] = self.sinusoids.shape().dims();\n        content\n            .add(\"d_model\", &d_model)\n            .add(\"max_sequence_size\", &self.max_sequence_size)\n            .add(\"max_timescale\", &self.max_timescale)\n            .optional()\n    }\n}\n\nimpl PositionalEncodingConfig {\n    /// Initialize a new [PositionalEncoding](PositionalEncoding) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {\n        let sinusoids = generate_sinusoids::<B>(\n            self.max_sequence_size,\n            self.d_model,\n            self.max_timescale,\n            device,\n        )\n        .unsqueeze::<3>();\n\n        PositionalEncoding {\n            sinusoids,\n            max_sequence_size: self.max_sequence_size,\n            max_timescale: self.max_timescale,\n        }\n    }\n}\n\nimpl<B: Backend> PositionalEncoding<B> {\n    /// Applies the forward pass on the input tensor by adding the sinusoids to the input.\n    ///\n    /// # Shapes\n    ///\n    /// * input: [batch_size, seq_length, d_model]\n    /// * output: [batch_size, seq_length, d_model]\n    ///\n    ///\n    /// # Panics\n    ///\n    /// * Panics if the input sequence length is greater than the maximum sequence size.\n    /// * Panics if the input d_model is not equal to the d_model of the sinusoids.\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let [_, seq_length, d_model_input] = input.dims();\n\n        let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();\n\n        assert!(\n            max_sequence_size >= seq_length,\n            \"max_sequence_size({max_sequence_size}) must be greater or equal than length({seq_length})\"\n        );\n\n        assert!(\n            d_model_input == d_model,\n            \"d_model({d_model_input}) of the input must be equal to d_model of encoding({d_model})\"\n        );\n\n        let slices = [0..batch_size, 0..seq_length, 0..d_model];\n\n        input.add(self.sinusoids.clone().slice(slices))\n    }\n}\n\n/// Returns sinusoids for positional embedding introduced in\n/// [Attention is all you need](https://arxiv.org/abs/1706.03762).\n///\n/// The reference implementation can be found here:\n/// [LANGUAGE MODELING WITH NN.TRANSFORMER AND TORCHTEXT\n/// ](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)\n///\n/// # Arguments\n///\n/// * `length` - The length of the sequence.\n/// * `d_model` - The size of each vector.\n/// * `max_timescale` - The maximum time scale to use.\n///\n/// # Returns\n///\n/// A tensor of shape [length, d_model] containing the sinusoids.\npub fn generate_sinusoids<B: Backend>(\n    length: usize,\n    d_model: usize,\n    max_timescale: usize,\n    device: &B::Device,\n) -> Tensor<B, 2> {\n    assert!(d_model.is_multiple_of(2), \"d_model must be even\");\n    assert!(\n        max_timescale >= length,\n        \"max_timescale must be greater than length\"\n    );\n\n    // Calculate the increment for the logarithmic timescale\n    let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;\n\n    // Create a vector to hold the sinusoids\n    let mut scaled_time_sin_cos = Vec::with_capacity(length);\n\n    // Loop over each position in the sequence\n    for i in 0..length {\n        // Create a vector to hold the sinusoids for this position\n        let mut row = Vec::with_capacity(d_model / 2);\n        // Loop over each dimension of the sinusoids\n        for k in (0..d_model).step_by(2) {\n            // Calculate the division term for this dimension\n            let div_term = (k as f32 * log_timescale_increment).exp();\n            // Calculate the sine and cosine values for this dimension and position\n            row.push((div_term * i as f32).sin());\n            row.push((div_term * i as f32).cos());\n        }\n\n        // Add the sinusoids for this position to the vector\n        scaled_time_sin_cos.push(row);\n    }\n\n    // Convert the sinusoids to a tensor and return it\n    let data = TensorData::new(\n        scaled_time_sin_cos.into_iter().flatten().collect(),\n        [length, d_model],\n    );\n\n    Tensor::<B, 2>::from_data(data, device)\n}\n\n#[cfg(test)]\nmod tests {\n\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_module() {\n        let d_model = 6;\n        let length = 3;\n\n        // expected to broadcast\n        let batch_size = 2;\n\n        let device = Default::default();\n        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);\n\n        // Use a tensor of zeros as input for easy verification of the output\n        // The output should be the sinusoids broadcasted to the input shape\n        let tensor = Tensor::zeros([batch_size, length, d_model], &device);\n\n        let output = pe.forward(tensor);\n\n        assert_eq!(&*output.shape(), [batch_size, length, d_model]);\n\n        let expected = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],\n                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],\n                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],\n                ],\n                [\n                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],\n                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],\n                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],\n                ],\n            ],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_generate_sinusoids() {\n        let device = Default::default();\n        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);\n\n        // The values are taken from the pytorch reference implementation\n        let expected = Tensor::<TestBackend, 2>::from_floats(\n            [\n                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],\n                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],\n                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],\n                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],\n                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],\n                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],\n                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],\n                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],\n                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],\n                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],\n                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],\n                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],\n            ],\n            &device,\n        );\n        sinusoids\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    #[should_panic]\n    fn d_model_input_should_match() {\n        let d_model = 8;\n        let device = Default::default();\n        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);\n        let input = Tensor::zeros([1, 5, 10], &device);\n        let _output = pe.forward(input);\n    }\n\n    #[test]\n    #[should_panic]\n    fn input_length_should_be_less_than_max_len() {\n        let d_model = 8;\n        let device = Default::default();\n        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);\n        let input = Tensor::zeros([1, 6_000, d_model], &device);\n        let _output = pe.forward(input);\n    }\n\n    #[test]\n    fn display() {\n        let config = PositionalEncodingConfig::new(4);\n        let pe = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{pe}\"),\n            \"PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rnn/basic.rs",
    "content": "use burn_core as burn;\n\nuse crate::GateController;\nuse crate::activation::{Activation, ActivationConfig};\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// A RnnState is used to store hidden state in RNN.\npub struct RnnState<B: Backend, const D: usize> {\n    /// The hidden state.\n    pub hidden: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> RnnState<B, D> {\n    /// Initialize a new [RNN State](RnnState).\n    pub fn new(hidden: Tensor<B, D>) -> Self {\n        Self { hidden }\n    }\n}\n\n/// Configuration to create a [Rnn](Rnn) module using the [init function](RnnConfig::init).\n#[derive(Config, Debug)]\npub struct RnnConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the Rnn transformation.\n    pub bias: bool,\n    /// Rnn initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.\n    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.\n    #[config(default = true)]\n    pub batch_first: bool,\n    /// If true, process the sequence in reverse order.\n    /// This is useful for implementing reverse-direction RNNs (e.g., ONNX reverse direction).\n    #[config(default = false)]\n    pub reverse: bool,\n    /// Optional hidden state clip threshold. If provided, hidden state values are clipped\n    /// to the range `[-clip, +clip]` after each timestep. This can help prevent\n    /// exploding values during inference.\n    pub clip: Option<f64>,\n    /// Activation function applied to the hidden state before computing hidden output.\n    /// Default is Tanh, which is standard for Rnn.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n}\n\n/// The Rnn module. This implementation is for a unidirectional, stateless, Rnn.\n/// Should be created with [RnnConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Rnn<B: Backend> {\n    /// gate controller for Rnn (has single gate).\n    pub gate: GateController<B>,\n    /// The hidden state of the Rnn.\n    pub d_hidden: usize,\n    /// If true, input is `[batch_size, seq_length, input_size]`.\n    /// If false, input is `[seq_length, batch_size, input_size]`.\n    pub batch_first: bool,\n    /// If true, process the sequence in reverse order.\n    pub reverse: bool,\n    /// Optional hidden state clip threshold.\n    pub clip: Option<f64>,\n    /// Activation function for hidden output.\n    pub hidden_activation: Activation<B>,\n}\n\nimpl<B: Backend> ModuleDisplay for Rnn<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self.gate.input_transform.weight.shape().dims();\n        let bias = self.gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .optional()\n    }\n}\n\nimpl RnnConfig {\n    /// Initialize a new [Rnn](Rnn) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Rnn<B> {\n        let d_output = self.d_hidden;\n\n        let new_gate = || {\n            GateController::new(\n                self.d_input,\n                d_output,\n                self.bias,\n                self.initializer.clone(),\n                device,\n            )\n        };\n\n        Rnn {\n            gate: new_gate(),\n            d_hidden: self.d_hidden,\n            batch_first: self.batch_first,\n            reverse: self.reverse,\n            clip: self.clip,\n            hidden_activation: self.hidden_activation.init(device),\n        }\n    }\n}\n\nimpl<B: Backend> Rnn<B> {\n    /// Applies the forward pass on the input tensor. This RNN implementation\n    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.\n    ///\n    /// ## Parameters:\n    /// - batched_input: The input tensor of shape:\n    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)\n    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false\n    /// - state: An optional `RnnState` representing the initial hidden state.\n    ///   The state tensor has shape `[batch_size, hidden_size]`.\n    ///   If no initial state is provided, these tensors are initialized to zeros.\n    ///\n    /// ## Returns:\n    /// - output: A tensor represents the output features of Rnn. Shape:\n    ///   - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true\n    ///   - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false\n    /// - state: A `RnnState` represents the final hidden state. The hidden state tensor has the shape\n    ///   `[batch_size, hidden_size]`.\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<RnnState<B, 2>>,\n    ) -> (Tensor<B, 3>, RnnState<B, 2>) {\n        // Convert to batch-first layout internally if needed\n        let batched_input = if self.batch_first {\n            batched_input\n        } else {\n            batched_input.swap_dims(0, 1)\n        };\n\n        let device = batched_input.device();\n        let [batch_size, seq_length, _] = batched_input.dims();\n\n        // Process sequence in forward or reverse order based on config\n        let (output, state) = if self.reverse {\n            self.forward_iter(\n                batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),\n                state,\n                batch_size,\n                seq_length,\n                &device,\n            )\n        } else {\n            self.forward_iter(\n                batched_input.iter_dim(1).zip(0..seq_length),\n                state,\n                batch_size,\n                seq_length,\n                &device,\n            )\n        };\n\n        // Convert output back to seq-first layout if needed\n        let output = if self.batch_first {\n            output\n        } else {\n            output.swap_dims(0, 1)\n        };\n\n        (output, state)\n    }\n\n    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(\n        &self,\n        input_timestep_iter: I,\n        state: Option<RnnState<B, 2>>,\n        batch_size: usize,\n        seq_length: usize,\n        device: &B::Device,\n    ) -> (Tensor<B, 3>, RnnState<B, 2>) {\n        let mut batched_hidden_state =\n            Tensor::empty([batch_size, seq_length, self.d_hidden], device);\n\n        let mut hidden_state = match state {\n            Some(state) => state.hidden,\n            None => Tensor::zeros([batch_size, self.d_hidden], device),\n        };\n\n        for (input_t, t) in input_timestep_iter {\n            let input_t = input_t.squeeze_dim(1);\n\n            // Compute gate output: h_t = activation(W_i @ x_t + W_h @ h_{t-1} + b)\n            let biased_gate_sum = self\n                .gate\n                .gate_product(input_t.clone(), hidden_state.clone());\n\n            let output_values = self.hidden_activation.forward(biased_gate_sum);\n\n            // Update hidden state\n            hidden_state = output_values;\n\n            // Apply hidden state clipping if configured\n            if let Some(clip) = self.clip {\n                hidden_state = hidden_state.clamp(-clip, clip);\n            }\n\n            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);\n\n            // store the hidden state for this timestep\n            batched_hidden_state = batched_hidden_state.slice_assign(\n                [0..batch_size, t..(t + 1), 0..self.d_hidden],\n                unsqueezed_hidden_state.clone(),\n            );\n        }\n\n        (batched_hidden_state, RnnState::new(hidden_state))\n    }\n}\n\n/// Configuration to create a [BiRnn](BiRnn) module using the [init function](BiRnnConfig::init).\n#[derive(Config, Debug)]\npub struct BiRnnConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the BiRnn transformation.\n    pub bias: bool,\n    /// BiRnn initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.\n    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.\n    #[config(default = true)]\n    pub batch_first: bool,\n    /// Optional hidden state clip threshold.\n    pub clip: Option<f64>,\n    /// Activation function applied to the hidden state before computing hidden output.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n}\n\n/// The BiRnn module. This implementation is for Bidirectional RNN.\n/// Should be created with [BiRnnConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct BiRnn<B: Backend> {\n    /// RNN for the forward direction.\n    pub forward: Rnn<B>,\n    /// RNN for the reverse direction.\n    pub reverse: Rnn<B>,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If true, input is `[batch_size, seq_length, input_size]`.\n    /// If false, input is `[seq_length, batch_size, input_size]`.\n    pub batch_first: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for BiRnn<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self.forward.gate.input_transform.weight.shape().dims();\n        let bias = self.forward.gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .optional()\n    }\n}\n\nimpl BiRnnConfig {\n    /// Initialize a new [Bidirectional RNN](BiRnn) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> BiRnn<B> {\n        // Internal RNNs always use batch_first=true; BiRnn handles layout conversion\n        let base_config = RnnConfig::new(self.d_input, self.d_hidden, self.bias)\n            .with_initializer(self.initializer.clone())\n            .with_batch_first(true)\n            .with_clip(self.clip)\n            .with_hidden_activation(self.hidden_activation.clone());\n\n        BiRnn {\n            forward: base_config.clone().init(device),\n            reverse: base_config.init(device),\n            d_hidden: self.d_hidden,\n            batch_first: self.batch_first,\n        }\n    }\n}\n\nimpl<B: Backend> BiRnn<B> {\n    /// Applies the forward pass on the input tensor. This Bidirectional RNN implementation\n    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.\n    ///\n    /// ## Parameters:\n    /// - batched_input: The input tensor of shape:\n    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)\n    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false\n    /// - state: An optional `RnnState` representing the hidden state.\n    ///   Each state tensor has shape `[2, batch_size, hidden_size]`.\n    ///   If no initial state is provided, these tensors are initialized to zeros.\n    ///\n    /// ## Returns:\n    /// - output: A tensor represents the output features of RNN. Shape:\n    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true\n    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false\n    /// - state: A `RnnState` represents the final forward and reverse states.\n    ///   The `state.hidden` have the shape `[2, batch_size, hidden_size]`.\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<RnnState<B, 3>>,\n    ) -> (Tensor<B, 3>, RnnState<B, 3>) {\n        // Convert to batch-first layout internally if needed\n        let batched_input = if self.batch_first {\n            batched_input\n        } else {\n            batched_input.swap_dims(0, 1)\n        };\n\n        let device = batched_input.clone().device();\n        let [batch_size, seq_length, _] = batched_input.shape().dims();\n\n        let [init_state_forward, init_state_reverse] = match state {\n            Some(state) => {\n                let hidden_state_forward = state\n                    .hidden\n                    .clone()\n                    .slice([0..1, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n                let hidden_state_reverse = state\n                    .hidden\n                    .slice([1..2, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n\n                [\n                    Some(RnnState::new(hidden_state_forward)),\n                    Some(RnnState::new(hidden_state_reverse)),\n                ]\n            }\n            None => [None, None],\n        };\n\n        // forward direction\n        let (batched_hidden_state_forward, final_state_forward) = self\n            .forward\n            .forward(batched_input.clone(), init_state_forward);\n\n        // reverse direction\n        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(\n            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),\n            init_state_reverse,\n            batch_size,\n            seq_length,\n            &device,\n        );\n\n        let output = Tensor::cat(\n            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),\n            2,\n        );\n\n        // Convert output back to seq-first layout if needed\n        let output = if self.batch_first {\n            output\n        } else {\n            output.swap_dims(0, 1)\n        };\n\n        let state = RnnState::new(Tensor::stack(\n            [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),\n            0,\n        ));\n\n        (output, state)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{LinearRecord, TestBackend};\n    use burn::module::Param;\n    use burn::tensor::{Device, Distribution, TensorData};\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[cfg(feature = \"std\")]\n    use crate::TestAutodiffBackend;\n\n    fn create_single_feature_gate_controller(\n        weights: f32,\n        biases: f32,\n        d_input: usize,\n        d_output: usize,\n        bias: bool,\n        initializer: Initializer,\n        device: &Device<TestBackend>,\n    ) -> GateController<TestBackend> {\n        let record_1 = LinearRecord {\n            weight: Param::from_data(TensorData::from([[weights]]), device),\n            bias: Some(Param::from_data(TensorData::from([biases]), device)),\n        };\n        let record_2 = LinearRecord {\n            weight: Param::from_data(TensorData::from([[weights]]), device),\n            bias: Some(Param::from_data(TensorData::from([biases]), device)),\n        };\n        GateController::create_with_weights(\n            d_input,\n            d_output,\n            bias,\n            initializer,\n            record_1,\n            record_2,\n        )\n    }\n\n    #[test]\n    fn test_with_uniform_initializer() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = RnnConfig::new(5, 5, false)\n            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });\n        let rnn = config.init::<TestBackend>(&Default::default());\n\n        let gate_to_data =\n            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();\n\n        gate_to_data(rnn.gate).assert_within_range::<FT>(0.elem()..1.elem());\n    }\n\n    /// Test forward pass with simple input vector.\n    ///\n    /// Simple RNN: h_t = tanh(W_input @ x_t + W_hidden @ h_{t-1} + b)\n    /// With input=0.1, weight_input=0.5, bias=0.0, h_0=0.0, weight_hidden=0.5\n    /// h_t = tanh(0.5*0.1 + 0.5*0) = tanh(0.05) = 0.04995\n    #[test]\n    fn test_forward_single_input_single_feature() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = RnnConfig::new(1, 1, false);\n        let device = Default::default();\n        let mut rnn = config.init::<TestBackend>(&device);\n\n        rnn.gate = create_single_feature_gate_controller(\n            0.5,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n\n        // single timestep with single feature\n        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);\n\n        let (output, state) = rnn.forward(input, None);\n\n        let tolerance = Tolerance::default();\n        let expected = TensorData::from([[0.04995]]);\n        state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        output\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0)\n            .to_data()\n            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);\n    }\n\n    #[test]\n    fn test_batched_forward_pass_batch_of_one() {\n        let device = Default::default();\n        let rnn = RnnConfig::new(64, 1024, true).init(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);\n\n        let (output, state) = rnn.forward(batched_input, None);\n        assert_eq!(output.dims(), [1, 2, 1024]);\n        assert_eq!(state.hidden.dims(), [1, 1024]);\n    }\n\n    #[test]\n    #[cfg(feature = \"std\")]\n    fn test_batched_backward_pass() {\n        use burn::tensor::Shape;\n        let device = Default::default();\n        let rnn = RnnConfig::new(64, 32, true).init(&device);\n        let shape: Shape = [8, 10, 64].into();\n        let batched_input =\n            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);\n\n        let (output, _) = rnn.forward(batched_input.clone(), None);\n        let fake_loss = output;\n        let grads = fake_loss.backward();\n\n        let some_gradient = rnn.gate.hidden_transform.weight.grad(&grads).unwrap();\n\n        // Asserts that the gradients exist and are non-zero\n        assert_ne!(\n            some_gradient\n                .any()\n                .into_data()\n                .iter::<f32>()\n                .next()\n                .unwrap(),\n            0.0\n        );\n    }\n\n    #[test]\n    fn test_bidirectional() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = BiRnnConfig::new(2, 3, true);\n        let mut rnn = config.init(&device);\n\n        fn create_gate_controller<const D1: usize, const D2: usize>(\n            input_weights: [[f32; D1]; D2],\n            input_biases: [f32; D1],\n            hidden_weights: [[f32; D1]; D1],\n            hidden_biases: [f32; D1],\n            device: &Device<TestBackend>,\n        ) -> GateController<TestBackend> {\n            let d_input = input_weights[0].len();\n            let d_output = input_weights.len();\n\n            let input_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(input_weights), device),\n                bias: Some(Param::from_data(TensorData::from(input_biases), device)),\n            };\n            let hidden_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(hidden_weights), device),\n                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                true,\n                Initializer::XavierUniform { gain: 1.0 },\n                input_record,\n                hidden_record,\n            )\n        }\n\n        // [batch_size=1, seq_length=4, input_size=2]\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[\n                [0.949, -0.861],\n                [0.892, 0.927],\n                [-0.173, -0.301],\n                [-0.081, 0.992],\n            ]]),\n            &device,\n        );\n\n        // [2, batch_size=1, hidden_size=3]\n        let h0 = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),\n            &device,\n        );\n\n        rnn.forward.gate = create_gate_controller(\n            // input_weights: [input_size=2, hidden_size=3]\n            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],\n            // input_biases: [hidden_size=3]\n            [-0.196, 0.354, 0.209],\n            // hidden_weights: [hidden_size=3, hidden_size=3]\n            [\n                [-0.320, 0.232, -0.165],\n                [0.093, -0.572, -0.315],\n                [-0.467, 0.325, 0.046],\n            ],\n            // hidden_biases: [hidden_size=3]\n            [0.181, -0.190, -0.245],\n            &device,\n        );\n\n        rnn.reverse.gate = create_gate_controller(\n            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],\n            [0.540, -0.164, 0.033],\n            [\n                [0.159, 0.180, -0.037],\n                [-0.443, 0.485, -0.488],\n                [0.098, -0.085, -0.140],\n            ],\n            [-0.510, 0.105, 0.114],\n            &device,\n        );\n\n        // [batch_size=1, sequence_length=4, hidden_size * 2 = 6]\n        // The expected output values were computed from PyTorch\n        let expected_output_with_init_state = TensorData::from([[\n            [0.5226, -0.6370, 0.0210, 0.0685, 0.3867, 0.3602],\n            [0.3580, 0.8431, 0.4129, -0.3175, 0.4374, 0.1766],\n            [-0.3837, -0.2703, -0.3957, -0.1542, -0.1122, 0.0725],\n            [0.5059, 0.5527, 0.1244, -0.6779, 0.3725, -0.3387],\n        ]]);\n        let expected_output_without_init_state = TensorData::from([[\n            [0.0560, -0.2056, 0.2334, 0.0892, 0.3912, 0.3607],\n            [0.4340, 0.7378, 0.3714, -0.2394, 0.4235, 0.2002],\n            [-0.3962, -0.2097, -0.3798, 0.0532, -0.2067, 0.1727],\n            [0.5075, 0.5298, 0.1083, -0.3200, 0.0764, -0.1282],\n        ]]);\n\n        //`[2, batch_size=1, hidden_size=3]`\n        let expected_hn_with_init_state =\n            TensorData::from([[[0.5059, 0.5527, 0.1244]], [[0.0685, 0.3867, 0.3602]]]);\n        let expected_hn_without_init_state =\n            TensorData::from([[[0.5075, 0.5298, 0.1083]], [[0.0892, 0.3912, 0.3607]]]);\n\n        let (output_with_init_state, state_with_init_state) =\n            rnn.forward(input.clone(), Some(RnnState::new(h0)));\n        let (output_without_init_state, state_without_init_state) = rnn.forward(input, None);\n\n        let tolerance = Tolerance::permissive();\n        output_with_init_state\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);\n        output_without_init_state\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);\n        state_with_init_state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);\n        state_without_init_state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);\n    }\n\n    #[test]\n    fn display_rnn() {\n        let config = RnnConfig::new(2, 3, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"Rnn {d_input: 2, d_hidden: 3, bias: true, params: 21}\"\n        );\n    }\n\n    #[test]\n    fn display_birnn() {\n        let config = BiRnnConfig::new(2, 3, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"BiRnn {d_input: 2, d_hidden: 3, bias: true, params: 42}\"\n        );\n    }\n\n    #[test]\n    fn test_rnn_clipping() {\n        let device = Default::default();\n\n        // Create Rnn with clipping enabled\n        let clip_value = 0.3;\n        let config = RnnConfig::new(4, 8, true).with_clip(Some(clip_value));\n        let rnn = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);\n        let (_, state) = rnn.forward(input, None);\n\n        // Verify output values are within the clip range\n        let hidden_state: Vec<f32> = state.hidden.to_data().to_vec().unwrap();\n        for val in hidden_state {\n            assert!(\n                val >= -clip_value as f32 && val <= clip_value as f32,\n                \"Value {} is outside clip range [-{}, {}]\",\n                val,\n                clip_value,\n                clip_value\n            );\n        }\n    }\n\n    #[test]\n    fn test_forward_reverse_sequence() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        // Create RNN with reverse=true to process sequence in reverse order\n        let config = RnnConfig::new(1, 1, false).with_reverse(true);\n        let mut rnn = config.init::<TestBackend>(&device);\n\n        rnn.gate = create_single_feature_gate_controller(\n            0.5,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n\n        // Create input with 3 timesteps: [0.1, 0.2, 0.3]\n        // Shape: [batch_size=1, seq_length=3, input_features=1]\n        let input =\n            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);\n\n        let (output, state) = rnn.forward(input, None);\n\n        // With reverse=true and weight=0.5, sequence is processed in reverse:\n        // t=2 (last): h = tanh(0.5*0.3 + 0.5*0) = tanh(0.15) ≈ 0.1488850\n        // t=1 (mid):  h = tanh(0.5*0.2 + 0.5*0.1488850) ≈ 0.17269433\n        // t=0 (first): h = tanh(0.5*0.1 + 0.5*0.17269433) ≈ 0.135508\n        let expected_final_hidden = TensorData::from([[0.135508]]);\n\n        let tolerance = Tolerance::default();\n        state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_final_hidden, tolerance);\n\n        // Verify output tensor has correct shape and matches state at final timestep\n        assert_eq!(output.dims(), [1, 3, 1]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rnn/gate_controller.rs",
    "content": "use burn_core as burn;\n\nuse crate::{Linear, LinearConfig, LinearLayout};\nuse burn::module::{Initializer, Module};\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// A GateController represents a gate in an LSTM cell. An\n/// LSTM cell generally contains three gates: an input gate,\n/// forget gate, and output gate. Additionally, cell gate\n/// is just used to compute the cell state.\n///\n/// An Lstm gate is modeled as two linear transformations.\n/// The results of these transformations are used to calculate\n/// the gate's output.\n#[derive(Module, Debug)]\npub struct GateController<B: Backend> {\n    /// Represents the affine transformation applied to input vector\n    pub input_transform: Linear<B>,\n    /// Represents the affine transformation applied to the hidden state\n    pub hidden_transform: Linear<B>,\n}\n\nimpl<B: Backend> GateController<B> {\n    /// Initialize a new [gate_controller](GateController) module.\n    pub fn new(\n        d_input: usize,\n        d_output: usize,\n        bias: bool,\n        initializer: Initializer,\n        device: &B::Device,\n    ) -> Self {\n        Self {\n            input_transform: LinearConfig {\n                d_input,\n                d_output,\n                bias,\n                initializer: initializer.clone(),\n                layout: LinearLayout::Row,\n            }\n            .init(device),\n            hidden_transform: LinearConfig {\n                d_input: d_output,\n                d_output,\n                bias,\n                initializer,\n                layout: LinearLayout::Row,\n            }\n            .init(device),\n        }\n    }\n\n    /// Helper function for performing weighted matrix product for a gate and adds\n    /// bias, if any.\n    ///\n    ///  Mathematically, performs `Wx*X + Wh*H + b`, where:\n    ///     Wx = weight matrix for the connection to input vector X\n    ///     Wh = weight matrix for the connection to hidden state H\n    ///     X = input vector\n    ///     H = hidden state\n    ///     b = bias terms\n    pub fn gate_product(&self, input: Tensor<B, 2>, hidden: Tensor<B, 2>) -> Tensor<B, 2> {\n        self.input_transform.forward(input) + self.hidden_transform.forward(hidden)\n    }\n\n    /// Used to initialize a gate controller with known weight layers,\n    /// allowing for predictable behavior. Used only for testing in\n    /// lstm.\n    #[cfg(test)]\n    pub fn create_with_weights(\n        d_input: usize,\n        d_output: usize,\n        bias: bool,\n        initializer: Initializer,\n        input_record: crate::LinearRecord<B>,\n        hidden_record: crate::LinearRecord<B>,\n    ) -> Self {\n        let l1 = LinearConfig {\n            d_input,\n            d_output,\n            bias,\n            initializer: initializer.clone(),\n            layout: LinearLayout::Row,\n        }\n        .init(&input_record.weight.device())\n        .load_record(input_record);\n        let l2 = LinearConfig {\n            d_input,\n            d_output,\n            bias,\n            initializer,\n            layout: LinearLayout::Row,\n        }\n        .init(&hidden_record.weight.device())\n        .load_record(hidden_record);\n\n        Self {\n            input_transform: l1,\n            hidden_transform: l2,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rnn/gru.rs",
    "content": "use burn_core as burn;\n\nuse super::gate_controller::GateController;\nuse crate::activation::{Activation, ActivationConfig};\nuse burn::config::Config;\nuse burn::module::Initializer;\nuse burn::module::Module;\nuse burn::module::{Content, DisplaySettings, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Configuration to create a [gru](Gru) module using the [init function](GruConfig::init).\n#[derive(Config, Debug)]\npub struct GruConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the Gru transformation.\n    pub bias: bool,\n    /// If reset gate should be applied after weight multiplication.\n    ///\n    /// This configuration option controls how the reset gate is applied to the hidden state.\n    /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for\n    ///   Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by\n    ///   the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU).\n    /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine\n    ///   Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication.\n    ///\n    /// The differing implementations can give slightly different numerical results and have different efficiencies. For more\n    /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs).\n    ///\n    /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`).\n    #[config(default = \"true\")]\n    pub reset_after: bool,\n    /// Gru initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// Activation function for the update and reset gates.\n    /// Default is Sigmoid, which is standard for GRU gates.\n    #[config(default = \"ActivationConfig::Sigmoid\")]\n    pub gate_activation: ActivationConfig,\n    /// Activation function for the new/candidate gate.\n    /// Default is Tanh, which is standard for GRU.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n    /// Optional hidden state clip threshold. If provided, hidden state values are clipped\n    /// to the range `[-clip, +clip]` after each timestep. This can help prevent\n    /// exploding values during inference.\n    pub clip: Option<f64>,\n}\n\n/// The Gru (Gated recurrent unit) module. This implementation is for a unidirectional, stateless, Gru.\n///\n/// Introduced in the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).\n///\n/// Should be created with [GruConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Gru<B: Backend> {\n    /// The update gate controller.\n    pub update_gate: GateController<B>,\n    /// The reset gate controller.\n    pub reset_gate: GateController<B>,\n    /// The new gate controller.\n    pub new_gate: GateController<B>,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If reset gate should be applied after weight multiplication.\n    pub reset_after: bool,\n    /// Activation function for gates (update, reset).\n    pub gate_activation: Activation<B>,\n    /// Activation function for new/candidate gate.\n    pub hidden_activation: Activation<B>,\n    /// Optional hidden state clip threshold.\n    pub clip: Option<f64>,\n}\n\nimpl<B: Backend> ModuleDisplay for Gru<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();\n        let bias = self.update_gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .add(\"reset_after\", &self.reset_after)\n            .optional()\n    }\n}\n\nimpl GruConfig {\n    /// Initialize a new [gru](Gru) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {\n        let d_output = self.d_hidden;\n\n        let update_gate = GateController::new(\n            self.d_input,\n            d_output,\n            self.bias,\n            self.initializer.clone(),\n            device,\n        );\n        let reset_gate = GateController::new(\n            self.d_input,\n            d_output,\n            self.bias,\n            self.initializer.clone(),\n            device,\n        );\n        let new_gate = GateController::new(\n            self.d_input,\n            d_output,\n            self.bias,\n            self.initializer.clone(),\n            device,\n        );\n\n        Gru {\n            update_gate,\n            reset_gate,\n            new_gate,\n            d_hidden: self.d_hidden,\n            reset_after: self.reset_after,\n            gate_activation: self.gate_activation.init(device),\n            hidden_activation: self.hidden_activation.init(device),\n            clip: self.clip,\n        }\n    }\n}\n\nimpl<B: Backend> Gru<B> {\n    /// Applies the forward pass on the input tensor. This GRU implementation\n    /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.\n    ///\n    /// # Parameters\n    /// - batched_input: `[batch_size, sequence_length, input_size]`.\n    /// - state: An optional tensor representing an initial cell state with dimensions\n    ///   `[batch_size, hidden_size]`. If none is provided, an empty state will be used.\n    ///\n    /// # Returns\n    /// - output: `[batch_size, sequence_length, hidden_size]`\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<Tensor<B, 2>>,\n    ) -> Tensor<B, 3> {\n        let device = batched_input.device();\n        let [batch_size, seq_length, _] = batched_input.shape().dims();\n\n        self.forward_iter(\n            batched_input.iter_dim(1).zip(0..seq_length),\n            state,\n            batch_size,\n            seq_length,\n            &device,\n        )\n        .0\n    }\n\n    /// Forward pass variant that accepts an iterator over timesteps.\n    /// Used by BiGru to process sequences in either direction.\n    ///\n    /// # Parameters\n    /// - input_timestep_iter: Iterator yielding (input_tensor, timestep_index) pairs.\n    ///   The timestep_index determines where in the output tensor to store results.\n    /// - state: Optional initial hidden state with shape `[batch_size, hidden_size]`.\n    /// - batch_size: Batch size of the input.\n    /// - seq_length: Sequence length of the input.\n    /// - device: Device to create tensors on.\n    ///\n    /// # Returns\n    /// - output: `[batch_size, sequence_length, hidden_size]`\n    /// - final_hidden: Final hidden state `[batch_size, hidden_size]`\n    pub(crate) fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(\n        &self,\n        input_timestep_iter: I,\n        state: Option<Tensor<B, 2>>,\n        batch_size: usize,\n        seq_length: usize,\n        device: &B::Device,\n    ) -> (Tensor<B, 3>, Tensor<B, 2>) {\n        let mut batched_hidden_state =\n            Tensor::empty([batch_size, seq_length, self.d_hidden], device);\n\n        let mut hidden_t = match state {\n            Some(state) => state,\n            None => Tensor::zeros([batch_size, self.d_hidden], device),\n        };\n\n        for (input_t, t) in input_timestep_iter {\n            let input_t = input_t.squeeze_dim(1);\n\n            // u(pdate)g(ate) tensors\n            let biased_ug_input_sum =\n                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);\n            let update_values = self.gate_activation.forward(biased_ug_input_sum);\n\n            // r(eset)g(ate) tensors\n            let biased_rg_input_sum =\n                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);\n            let reset_values = self.gate_activation.forward(biased_rg_input_sum);\n\n            // n(ew)g(ate) tensor\n            let biased_ng_input_sum = if self.reset_after {\n                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)\n            } else {\n                let reset_t = hidden_t.clone().mul(reset_values);\n                self.gate_product(&input_t, &reset_t, None, &self.new_gate)\n            };\n            let candidate_state = self.hidden_activation.forward(biased_ng_input_sum);\n\n            // calculate linear interpolation between previous hidden state and candidate state:\n            // h_t = (1 - z_t) * g_t + z_t * h_{t-1}\n            let one_minus_z = update_values.clone().neg().add_scalar(1.0);\n            hidden_t = candidate_state.mul(one_minus_z) + update_values.mul(hidden_t);\n\n            // Apply hidden state clipping if configured\n            if let Some(clip) = self.clip {\n                hidden_t = hidden_t.clamp(-clip, clip);\n            }\n\n            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);\n\n            batched_hidden_state = batched_hidden_state.slice_assign(\n                [0..batch_size, t..(t + 1), 0..self.d_hidden],\n                unsqueezed_hidden_state,\n            );\n        }\n\n        (batched_hidden_state, hidden_t)\n    }\n\n    /// Helper function for performing weighted matrix product for a gate and adds\n    /// bias, if any, and optionally applies reset to hidden state.\n    ///\n    ///  Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where:\n    ///     Wx = weight matrix for the connection to input vector X\n    ///     Wh = weight matrix for the connection to hidden state H\n    ///     X = input vector\n    ///     H = hidden state\n    ///     b = bias terms\n    ///     r = reset state\n    fn gate_product(\n        &self,\n        input: &Tensor<B, 2>,\n        hidden: &Tensor<B, 2>,\n        reset: Option<&Tensor<B, 2>>,\n        gate: &GateController<B>,\n    ) -> Tensor<B, 2> {\n        let input_product = input.clone().matmul(gate.input_transform.weight.val());\n        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());\n\n        let input_part = match &gate.input_transform.bias {\n            Some(bias) => input_product + bias.val().unsqueeze(),\n            None => input_product,\n        };\n\n        let hidden_part = match &gate.hidden_transform.bias {\n            Some(bias) => hidden_product + bias.val().unsqueeze(),\n            None => hidden_product,\n        };\n\n        match reset {\n            Some(r) => input_part + r.clone().mul(hidden_part),\n            None => input_part + hidden_part,\n        }\n    }\n}\n\n/// Configuration to create a [BiGru](BiGru) module using the [init function](BiGruConfig::init).\n#[derive(Config, Debug)]\npub struct BiGruConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the BiGru transformation.\n    pub bias: bool,\n    /// If reset gate should be applied after weight multiplication.\n    #[config(default = \"true\")]\n    pub reset_after: bool,\n    /// BiGru initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.\n    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.\n    #[config(default = true)]\n    pub batch_first: bool,\n    /// Activation function for the update and reset gates.\n    #[config(default = \"ActivationConfig::Sigmoid\")]\n    pub gate_activation: ActivationConfig,\n    /// Activation function for the new/candidate gate.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n    /// Optional hidden state clip threshold.\n    pub clip: Option<f64>,\n}\n\n/// The BiGru module. This implementation is for Bidirectional GRU.\n///\n/// Based on the paper: [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078).\n///\n/// Should be created with [BiGruConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct BiGru<B: Backend> {\n    /// GRU for the forward direction.\n    pub forward: Gru<B>,\n    /// GRU for the reverse direction.\n    pub reverse: Gru<B>,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If true, input is `[batch_size, seq_length, input_size]`.\n    /// If false, input is `[seq_length, batch_size, input_size]`.\n    pub batch_first: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for BiGru<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self\n            .forward\n            .update_gate\n            .input_transform\n            .weight\n            .shape()\n            .dims();\n        let bias = self.forward.update_gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .optional()\n    }\n}\n\nimpl BiGruConfig {\n    /// Initialize a new [Bidirectional GRU](BiGru) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> BiGru<B> {\n        // Internal GRUs always use batch_first=true; BiGru handles layout conversion\n        let base_config = GruConfig::new(self.d_input, self.d_hidden, self.bias)\n            .with_initializer(self.initializer.clone())\n            .with_reset_after(self.reset_after)\n            .with_gate_activation(self.gate_activation.clone())\n            .with_hidden_activation(self.hidden_activation.clone())\n            .with_clip(self.clip);\n\n        BiGru {\n            forward: base_config.clone().init(device),\n            reverse: base_config.init(device),\n            d_hidden: self.d_hidden,\n            batch_first: self.batch_first,\n        }\n    }\n}\n\nimpl<B: Backend> BiGru<B> {\n    /// Applies the forward pass on the input tensor. This Bidirectional GRU implementation\n    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.\n    ///\n    /// ## Parameters:\n    /// - batched_input: The input tensor of shape:\n    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)\n    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false\n    /// - state: An optional tensor representing the initial hidden state with shape\n    ///   `[2, batch_size, hidden_size]`. If no initial state is provided, it is initialized to zeros.\n    ///\n    /// ## Returns:\n    /// - output: A tensor representing the output features. Shape:\n    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true\n    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false\n    /// - state: The final forward and reverse hidden states stacked along dimension 0\n    ///   with shape `[2, batch_size, hidden_size]`.\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<Tensor<B, 3>>,\n    ) -> (Tensor<B, 3>, Tensor<B, 3>) {\n        // Convert to batch-first layout internally if needed\n        let batched_input = if self.batch_first {\n            batched_input\n        } else {\n            batched_input.swap_dims(0, 1)\n        };\n\n        let device = batched_input.clone().device();\n        let [batch_size, seq_length, _] = batched_input.shape().dims();\n\n        let [init_state_forward, init_state_reverse] = match state {\n            Some(state) => {\n                let hidden_state_forward = state\n                    .clone()\n                    .slice([0..1, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n                let hidden_state_reverse = state\n                    .slice([1..2, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n\n                [Some(hidden_state_forward), Some(hidden_state_reverse)]\n            }\n            None => [None, None],\n        };\n\n        // forward direction\n        let (batched_hidden_state_forward, final_state_forward) = self.forward.forward_iter(\n            batched_input.clone().iter_dim(1).zip(0..seq_length),\n            init_state_forward,\n            batch_size,\n            seq_length,\n            &device,\n        );\n\n        // reverse direction\n        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(\n            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),\n            init_state_reverse,\n            batch_size,\n            seq_length,\n            &device,\n        );\n\n        let output = Tensor::cat(\n            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),\n            2,\n        );\n\n        // Convert output back to seq-first layout if needed\n        let output = if self.batch_first {\n            output\n        } else {\n            output.swap_dims(0, 1)\n        };\n\n        let state = Tensor::stack([final_state_forward, final_state_reverse].to_vec(), 0);\n\n        (output, state)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{LinearRecord, TestBackend};\n    use burn::module::Param;\n    use burn::tensor::{Distribution, TensorData};\n    use burn::tensor::{Tolerance, ops::FloatElem};\n\n    type FT = FloatElem<TestBackend>;\n\n    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {\n        fn create_gate_controller<B: Backend>(\n            weights: f32,\n            biases: f32,\n            d_input: usize,\n            d_output: usize,\n            bias: bool,\n            initializer: Initializer,\n            device: &B::Device,\n        ) -> GateController<B> {\n            let record_1 = LinearRecord {\n                weight: Param::from_data(TensorData::from([[weights]]), device),\n                bias: Some(Param::from_data(TensorData::from([biases]), device)),\n            };\n            let record_2 = LinearRecord {\n                weight: Param::from_data(TensorData::from([[weights]]), device),\n                bias: Some(Param::from_data(TensorData::from([biases]), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                bias,\n                initializer,\n                record_1,\n                record_2,\n            )\n        }\n\n        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);\n        let mut gru = config.init::<B>(device);\n\n        gru.update_gate = create_gate_controller(\n            0.5,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierNormal { gain: 1.0 },\n            device,\n        );\n        gru.reset_gate = create_gate_controller(\n            0.6,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierNormal { gain: 1.0 },\n            device,\n        );\n        gru.new_gate = create_gate_controller(\n            0.7,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierNormal { gain: 1.0 },\n            device,\n        );\n        gru\n    }\n\n    /// Test forward pass with simple input vector.\n    ///\n    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125\n    /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150\n    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699\n    ///\n    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341\n    #[test]\n    fn tests_forward_single_input_single_feature() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let mut gru = init_gru::<TestBackend>(false, &device);\n\n        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);\n        let expected = TensorData::from([[0.034]]);\n\n        // Reset gate applied to hidden state before the matrix multiplication\n        let state = gru.forward(input.clone(), None);\n\n        let output = state\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0);\n\n        let tolerance = Tolerance::default();\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        // Reset gate applied to hidden state after the matrix multiplication\n        gru.reset_after = true; // override forward behavior\n        let state = gru.forward(input, None);\n\n        let output = state\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n    }\n\n    #[test]\n    fn tests_forward_seq_len_3() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n        let mut gru = init_gru::<TestBackend>(true, &device);\n\n        let input =\n            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);\n        let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);\n\n        let result = gru.forward(input.clone(), None);\n        let output = result\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0);\n\n        let tolerance = Tolerance::default();\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        // Reset gate applied to hidden state before the matrix multiplication\n        gru.reset_after = false; // override forward behavior\n        let state = gru.forward(input, None);\n\n        let output = state\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0);\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n    }\n\n    #[test]\n    fn test_batched_forward_pass() {\n        let device = Default::default();\n        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);\n\n        let hidden_state = gru.forward(batched_input, None);\n\n        assert_eq!(&*hidden_state.shape(), [8, 10, 1024]);\n    }\n\n    #[test]\n    fn display() {\n        let config = GruConfig::new(2, 8, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}\"\n        );\n    }\n\n    #[test]\n    fn test_bigru_batched_forward_pass() {\n        let device = Default::default();\n        let bigru = BiGruConfig::new(64, 1024, true).init::<TestBackend>(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);\n\n        let (output, state) = bigru.forward(batched_input, None);\n\n        // Output should have hidden_size * 2 features (forward + reverse concatenated)\n        assert_eq!(&*output.shape(), [8, 10, 2048]);\n        // State should have shape [2, batch_size, hidden_size]\n        assert_eq!(&*state.shape(), [2, 8, 1024]);\n    }\n\n    #[test]\n    fn test_bigru_with_initial_state() {\n        let device = Default::default();\n        let bigru = BiGruConfig::new(32, 64, true).init::<TestBackend>(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([4, 5, 32], Distribution::Default, &device);\n        let initial_state =\n            Tensor::<TestBackend, 3>::random([2, 4, 64], Distribution::Default, &device);\n\n        let (output, state) = bigru.forward(batched_input, Some(initial_state));\n\n        assert_eq!(&*output.shape(), [4, 5, 128]);\n        assert_eq!(&*state.shape(), [2, 4, 64]);\n    }\n\n    #[test]\n    fn test_bigru_seq_first() {\n        let device = Default::default();\n        let bigru = BiGruConfig::new(32, 64, true)\n            .with_batch_first(false)\n            .init::<TestBackend>(&device);\n        // Input shape: [seq_length, batch_size, input_size] when batch_first=false\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([5, 4, 32], Distribution::Default, &device);\n\n        let (output, state) = bigru.forward(batched_input, None);\n\n        // Output shape: [seq_length, batch_size, hidden_size * 2]\n        assert_eq!(&*output.shape(), [5, 4, 128]);\n        assert_eq!(&*state.shape(), [2, 4, 64]);\n    }\n\n    /// Test BiGru against PyTorch reference implementation.\n    /// Expected values computed with PyTorch nn.GRU(bidirectional=True).\n    #[test]\n    fn test_bigru_against_pytorch() {\n        use burn::tensor::Device;\n\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = BiGruConfig::new(2, 3, true);\n        let mut bigru = config.init::<TestBackend>(&device);\n\n        fn create_gate_controller<const D1: usize, const D2: usize>(\n            input_weights: [[f32; D1]; D2],\n            input_biases: [f32; D1],\n            hidden_weights: [[f32; D1]; D1],\n            hidden_biases: [f32; D1],\n            device: &Device<TestBackend>,\n        ) -> GateController<TestBackend> {\n            let d_input = input_weights[0].len();\n            let d_output = input_weights.len();\n\n            let input_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(input_weights), device),\n                bias: Some(Param::from_data(TensorData::from(input_biases), device)),\n            };\n            let hidden_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(hidden_weights), device),\n                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                true,\n                Initializer::XavierUniform { gain: 1.0 },\n                input_record,\n                hidden_record,\n            )\n        }\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[\n                [0.949, -0.861],\n                [0.892, 0.927],\n                [-0.173, -0.301],\n                [-0.081, 0.992],\n            ]]),\n            &device,\n        );\n        let h0 = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),\n            &device,\n        );\n\n        // Forward GRU gates (weights from PyTorch with seed 42, transposed for burn)\n        bigru.forward.update_gate = create_gate_controller(\n            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],\n            [0.2932, -0.3519, -0.5715],\n            [\n                [-0.3471, 0.5214, 0.0961],\n                [0.0545, -0.4904, -0.1875],\n                [-0.5702, 0.4457, 0.3568],\n            ],\n            [-0.0100, 0.4518, -0.4102],\n            &device,\n        );\n\n        bigru.forward.reset_gate = create_gate_controller(\n            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],\n            [-0.2524, 0.3333, 0.1033],\n            [\n                [-0.2695, -0.0677, -0.4557],\n                [0.1472, -0.2345, -0.2662],\n                [-0.2660, 0.3830, -0.1630],\n            ],\n            [0.1663, 0.2391, 0.1826],\n            &device,\n        );\n\n        bigru.forward.new_gate = create_gate_controller(\n            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],\n            [-0.2231, -0.4428, 0.4737],\n            [\n                [0.0900, -0.1821, 0.2430],\n                [0.4665, 0.1551, 0.5155],\n                [0.0631, -0.1566, 0.3337],\n            ],\n            [0.0364, -0.3941, 0.1780],\n            &device,\n        );\n\n        // Reverse GRU gates\n        bigru.reverse.update_gate = create_gate_controller(\n            [[-0.3444, 0.1924, -0.4765], [0.5193, 0.5556, -0.5727]],\n            [0.1090, 0.1779, -0.5385],\n            [\n                [0.1221, 0.3925, 0.5287],\n                [-0.1472, -0.4187, -0.1948],\n                [0.3441, -0.3082, -0.2047],\n            ],\n            [0.0016, -0.2148, -0.0400],\n            &device,\n        );\n\n        bigru.reverse.reset_gate = create_gate_controller(\n            [[-0.1988, -0.1203, -0.3422], [0.1769, 0.4788, -0.3443]],\n            [-0.5053, -0.3676, 0.5771],\n            [\n                [-0.3936, 0.3504, -0.4486],\n                [0.3063, -0.1370, -0.2914],\n                [-0.2334, 0.3303, 0.1760],\n            ],\n            [-0.5080, -0.2488, -0.3456],\n            &device,\n        );\n\n        bigru.reverse.new_gate = create_gate_controller(\n            [[-0.4517, 0.2339, 0.4797], [-0.3884, 0.2067, -0.2982]],\n            [-0.3792, -0.1922, 0.0903],\n            [\n                [-0.5586, -0.0762, -0.3944],\n                [-0.3306, -0.4191, -0.4898],\n                [0.1442, 0.0135, -0.3179],\n            ],\n            [-0.3912, -0.3963, -0.3368],\n            &device,\n        );\n\n        // Expected values from PyTorch\n        let expected_output_with_init = TensorData::from([[\n            [0.24537, 0.14018, 0.19449, -0.49777, -0.15647, 0.48392],\n            [0.27468, -0.14514, 0.56205, -0.60381, -0.04986, 0.15683],\n            [-0.04062, -0.33486, 0.52330, -0.42244, -0.12644, -0.12034],\n            [-0.11743, -0.53873, 0.54429, -0.64943, 0.30127, -0.41943],\n        ]]);\n\n        let expected_hn_with_init = TensorData::from([\n            [[-0.11743, -0.53873, 0.54429]],\n            [[-0.49777, -0.15647, 0.48392]],\n        ]);\n\n        let expected_output_without_init = TensorData::from([[\n            [0.07452, -0.08247, 0.46677, -0.46770, -0.18086, 0.47519],\n            [0.15843, -0.27144, 0.65781, -0.50286, -0.12806, 0.14884],\n            [-0.10704, -0.41573, 0.53954, -0.24794, -0.24003, -0.10294],\n            [-0.16505, -0.57952, 0.53565, -0.23598, -0.07137, -0.28937],\n        ]]);\n\n        let expected_hn_without_init = TensorData::from([\n            [[-0.16505, -0.57952, 0.53565]],\n            [[-0.46770, -0.18086, 0.47519]],\n        ]);\n\n        let (output_with_init, hn_with_init) = bigru.forward(input.clone(), Some(h0));\n        let (output_without_init, hn_without_init) = bigru.forward(input, None);\n\n        let tolerance = Tolerance::permissive();\n        output_with_init\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_with_init, tolerance);\n        output_without_init\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_without_init, tolerance);\n        hn_with_init\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_with_init, tolerance);\n        hn_without_init\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_without_init, tolerance);\n    }\n\n    #[test]\n    fn bigru_display() {\n        let config = BiGruConfig::new(2, 8, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"BiGru {d_input: 2, d_hidden: 8, bias: true, params: 576}\"\n        );\n    }\n\n    #[test]\n    fn test_gru_custom_activations() {\n        let device = Default::default();\n\n        // Create GRU with custom activations (ReLU instead of Sigmoid/Tanh)\n        let config = GruConfig::new(4, 8, true)\n            .with_gate_activation(ActivationConfig::Relu)\n            .with_hidden_activation(ActivationConfig::Relu);\n        let gru = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);\n\n        // Should run without panicking and produce valid output\n        let output = gru.forward(input, None);\n        assert_eq!(&*output.shape(), [2, 3, 8]);\n    }\n\n    #[test]\n    fn test_bigru_custom_activations() {\n        let device = Default::default();\n\n        // Create BiGRU with custom activations\n        let config = BiGruConfig::new(4, 8, true)\n            .with_gate_activation(ActivationConfig::Relu)\n            .with_hidden_activation(ActivationConfig::Relu);\n        let bigru = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);\n\n        let (output, state) = bigru.forward(input, None);\n        assert_eq!(&*output.shape(), [2, 3, 16]); // hidden_size * 2\n        assert_eq!(&*state.shape(), [2, 2, 8]);\n    }\n\n    #[test]\n    fn test_gru_clipping() {\n        let device = Default::default();\n\n        // Create GRU with clipping enabled\n        let clip_value = 0.5;\n        let config = GruConfig::new(4, 8, true).with_clip(Some(clip_value));\n        let gru = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);\n\n        let output = gru.forward(input, None);\n\n        // Verify output values are within the clip range\n        let output_data: Vec<f32> = output.to_data().to_vec().unwrap();\n        for val in output_data {\n            assert!(\n                val >= -clip_value as f32 && val <= clip_value as f32,\n                \"Value {} is outside clip range [-{}, {}]\",\n                val,\n                clip_value,\n                clip_value\n            );\n        }\n    }\n\n    #[test]\n    fn test_bigru_clipping() {\n        let device = Default::default();\n\n        // Create BiGRU with clipping enabled\n        let clip_value = 0.3;\n        let config = BiGruConfig::new(4, 8, true).with_clip(Some(clip_value));\n        let bigru = config.init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);\n\n        let (output, state) = bigru.forward(input, None);\n\n        // Verify output values are within the clip range\n        let output_data: Vec<f32> = output.to_data().to_vec().unwrap();\n        for val in output_data {\n            assert!(\n                val >= -clip_value as f32 && val <= clip_value as f32,\n                \"Output value {} is outside clip range [-{}, {}]\",\n                val,\n                clip_value,\n                clip_value\n            );\n        }\n\n        // Verify state values are within the clip range\n        let state_data: Vec<f32> = state.to_data().to_vec().unwrap();\n        for val in state_data {\n            assert!(\n                val >= -clip_value as f32 && val <= clip_value as f32,\n                \"State value {} is outside clip range [-{}, {}]\",\n                val,\n                clip_value,\n                clip_value\n            );\n        }\n    }\n\n    /// Test Gru against PyTorch reference implementation.\n    /// Expected values computed with PyTorch nn.GRU (seed=42 for weights, seed=123 for input).\n    #[test]\n    fn test_gru_against_pytorch() {\n        use burn::tensor::Device;\n\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = GruConfig::new(2, 3, true);\n        let mut gru = config.init::<TestBackend>(&device);\n\n        fn create_gate_controller<const D1: usize, const D2: usize>(\n            input_weights: [[f32; D1]; D2],\n            input_biases: [f32; D1],\n            hidden_weights: [[f32; D1]; D1],\n            hidden_biases: [f32; D1],\n            device: &Device<TestBackend>,\n        ) -> GateController<TestBackend> {\n            let d_input = input_weights[0].len();\n            let d_output = input_weights.len();\n\n            let input_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(input_weights), device),\n                bias: Some(Param::from_data(TensorData::from(input_biases), device)),\n            };\n            let hidden_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(hidden_weights), device),\n                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                true,\n                Initializer::XavierUniform { gain: 1.0 },\n                input_record,\n                hidden_record,\n            )\n        }\n\n        // Input: [batch=1, seq=4, input=2]\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[\n                [-0.11147, 0.12036],\n                [-0.36963, -0.24042],\n                [-1.19692, 0.20927],\n                [-0.97236, -0.75505],\n            ]]),\n            &device,\n        );\n\n        // Initial hidden state: [batch=1, hidden=3]\n        let h0 = Tensor::<TestBackend, 2>::from_data(\n            TensorData::from([[0.3239, -0.10852, 0.21033]]),\n            &device,\n        );\n\n        // Update gate (z) - weights from PyTorch, transposed for Burn's Row layout\n        gru.update_gate = create_gate_controller(\n            [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],\n            [0.2932, -0.3519, -0.5715],\n            [\n                [-0.3471, 0.5214, 0.0961],\n                [0.0545, -0.4904, -0.1875],\n                [-0.5702, 0.4457, 0.3568],\n            ],\n            [-0.0100, 0.4518, -0.4102],\n            &device,\n        );\n\n        // Reset gate (r)\n        gru.reset_gate = create_gate_controller(\n            [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],\n            [-0.2524, 0.3333, 0.1033],\n            [\n                [-0.2695, -0.0677, -0.4557],\n                [0.1472, -0.2345, -0.2662],\n                [-0.2660, 0.3830, -0.1630],\n            ],\n            [0.1663, 0.2391, 0.1826],\n            &device,\n        );\n\n        // New gate (n)\n        gru.new_gate = create_gate_controller(\n            [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],\n            [-0.2231, -0.4428, 0.4737],\n            [\n                [0.0900, -0.1821, 0.2430],\n                [0.4665, 0.1551, 0.5155],\n                [0.0631, -0.1566, 0.3337],\n            ],\n            [0.0364, -0.3941, 0.1780],\n            &device,\n        );\n\n        // Expected values from PyTorch\n        let expected_output_with_h0 = TensorData::from([[\n            [0.05665, -0.34932, 0.43267],\n            [-0.1737, -0.49246, 0.38099],\n            [-0.35401, -0.68099, 0.05061],\n            [-0.47854, -0.70427, -0.13648],\n        ]]);\n\n        let expected_output_no_h0 = TensorData::from([[\n            [-0.0985, -0.31661, 0.36126],\n            [-0.24563, -0.47784, 0.34609],\n            [-0.39497, -0.67659, 0.03083],\n            [-0.50146, -0.70066, -0.14894],\n        ]]);\n\n        let output_with_h0 = gru.forward(input.clone(), Some(h0));\n        let output_no_h0 = gru.forward(input, None);\n\n        let tolerance = Tolerance::permissive();\n        output_with_h0\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_with_h0, tolerance);\n        output_no_h0\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_no_h0, tolerance);\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rnn/lstm.rs",
    "content": "use burn_core as burn;\n\nuse crate::GateController;\nuse crate::activation::{Activation, ActivationConfig};\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// A LstmState is used to store cell state and hidden state in LSTM.\npub struct LstmState<B: Backend, const D: usize> {\n    /// The cell state.\n    pub cell: Tensor<B, D>,\n    /// The hidden state.\n    pub hidden: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> LstmState<B, D> {\n    /// Initialize a new [LSTM State](LstmState).\n    pub fn new(cell: Tensor<B, D>, hidden: Tensor<B, D>) -> Self {\n        Self { cell, hidden }\n    }\n}\n\n/// Configuration to create a [Lstm](Lstm) module using the [init function](LstmConfig::init).\n#[derive(Config, Debug)]\npub struct LstmConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the Lstm transformation.\n    pub bias: bool,\n    /// Lstm initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.\n    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.\n    #[config(default = true)]\n    pub batch_first: bool,\n    /// If true, process the sequence in reverse order.\n    /// This is useful for implementing reverse-direction LSTMs (e.g., ONNX reverse direction).\n    #[config(default = false)]\n    pub reverse: bool,\n    /// Optional cell state clip threshold. If provided, cell state values are clipped\n    /// to the range `[-clip, +clip]` after each timestep. This can help prevent\n    /// exploding values during inference.\n    pub clip: Option<f64>,\n    /// If true, couples the input and forget gates: `f_t = 1 - i_t`.\n    /// This reduces the number of parameters and is based on GRU-style simplification.\n    #[config(default = false)]\n    pub input_forget: bool,\n    /// Activation function for the input, forget, and output gates.\n    /// Default is Sigmoid, which is standard for LSTM gates.\n    #[config(default = \"ActivationConfig::Sigmoid\")]\n    pub gate_activation: ActivationConfig,\n    /// Activation function for the cell gate (candidate cell state).\n    /// Default is Tanh, which is standard for LSTM.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub cell_activation: ActivationConfig,\n    /// Activation function applied to the cell state before computing hidden output.\n    /// Default is Tanh, which is standard for LSTM.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n}\n\n/// The Lstm module. This implementation is for a unidirectional, stateless, Lstm.\n///\n/// Introduced in the paper: [Long Short-Term Memory](https://www.researchgate.net/publication/13853244).\n///\n/// Should be created with [LstmConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Lstm<B: Backend> {\n    /// The input gate regulates which information to update and store in the cell state at each time step.\n    pub input_gate: GateController<B>,\n    /// The forget gate is used to control which information to discard or keep in the memory cell at each time step.\n    /// Note: When `input_forget` is true, this gate is not used (forget = 1 - input).\n    pub forget_gate: GateController<B>,\n    /// The output gate determines which information from the cell state to output at each time step.\n    pub output_gate: GateController<B>,\n    /// The cell gate is used to compute the cell state that stores and carries information through time.\n    pub cell_gate: GateController<B>,\n    /// The hidden state of the LSTM.\n    pub d_hidden: usize,\n    /// If true, input is `[batch_size, seq_length, input_size]`.\n    /// If false, input is `[seq_length, batch_size, input_size]`.\n    pub batch_first: bool,\n    /// If true, process the sequence in reverse order.\n    pub reverse: bool,\n    /// Optional cell state clip threshold.\n    pub clip: Option<f64>,\n    /// If true, couples input and forget gates: f_t = 1 - i_t.\n    pub input_forget: bool,\n    /// Activation function for gates (input, forget, output).\n    pub gate_activation: Activation<B>,\n    /// Activation function for cell gate (candidate cell state).\n    pub cell_activation: Activation<B>,\n    /// Activation function for hidden output.\n    pub hidden_activation: Activation<B>,\n}\n\nimpl<B: Backend> ModuleDisplay for Lstm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self.input_gate.input_transform.weight.shape().dims();\n        let bias = self.input_gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .optional()\n    }\n}\n\nimpl LstmConfig {\n    /// Initialize a new [lstm](Lstm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {\n        let d_output = self.d_hidden;\n\n        let new_gate = || {\n            GateController::new(\n                self.d_input,\n                d_output,\n                self.bias,\n                self.initializer.clone(),\n                device,\n            )\n        };\n\n        Lstm {\n            input_gate: new_gate(),\n            forget_gate: new_gate(),\n            output_gate: new_gate(),\n            cell_gate: new_gate(),\n            d_hidden: self.d_hidden,\n            batch_first: self.batch_first,\n            reverse: self.reverse,\n            clip: self.clip,\n            input_forget: self.input_forget,\n            gate_activation: self.gate_activation.init(device),\n            cell_activation: self.cell_activation.init(device),\n            hidden_activation: self.hidden_activation.init(device),\n        }\n    }\n}\n\nimpl<B: Backend> Lstm<B> {\n    /// Applies the forward pass on the input tensor. This LSTM implementation\n    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.\n    ///\n    /// ## Parameters:\n    /// - batched_input: The input tensor of shape:\n    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)\n    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false\n    /// - state: An optional `LstmState` representing the initial cell state and hidden state.\n    ///   Each state tensor has shape `[batch_size, hidden_size]`.\n    ///   If no initial state is provided, these tensors are initialized to zeros.\n    ///\n    /// ## Returns:\n    /// - output: A tensor represents the output features of LSTM. Shape:\n    ///   - `[batch_size, sequence_length, hidden_size]` if `batch_first` is true\n    ///   - `[sequence_length, batch_size, hidden_size]` if `batch_first` is false\n    /// - state: A `LstmState` represents the final states. Both `state.cell` and `state.hidden` have the shape\n    ///   `[batch_size, hidden_size]`.\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<LstmState<B, 2>>,\n    ) -> (Tensor<B, 3>, LstmState<B, 2>) {\n        // Convert to batch-first layout internally if needed\n        let batched_input = if self.batch_first {\n            batched_input\n        } else {\n            batched_input.swap_dims(0, 1)\n        };\n\n        let device = batched_input.device();\n        let [batch_size, seq_length, _] = batched_input.dims();\n\n        // Process sequence in forward or reverse order based on config\n        let (output, state) = if self.reverse {\n            self.forward_iter(\n                batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),\n                state,\n                batch_size,\n                seq_length,\n                &device,\n            )\n        } else {\n            self.forward_iter(\n                batched_input.iter_dim(1).zip(0..seq_length),\n                state,\n                batch_size,\n                seq_length,\n                &device,\n            )\n        };\n\n        // Convert output back to seq-first layout if needed\n        let output = if self.batch_first {\n            output\n        } else {\n            output.swap_dims(0, 1)\n        };\n\n        (output, state)\n    }\n\n    fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(\n        &self,\n        input_timestep_iter: I,\n        state: Option<LstmState<B, 2>>,\n        batch_size: usize,\n        seq_length: usize,\n        device: &B::Device,\n    ) -> (Tensor<B, 3>, LstmState<B, 2>) {\n        let mut batched_hidden_state =\n            Tensor::empty([batch_size, seq_length, self.d_hidden], device);\n\n        let (mut cell_state, mut hidden_state) = match state {\n            Some(state) => (state.cell, state.hidden),\n            None => (\n                Tensor::zeros([batch_size, self.d_hidden], device),\n                Tensor::zeros([batch_size, self.d_hidden], device),\n            ),\n        };\n\n        for (input_t, t) in input_timestep_iter {\n            let input_t = input_t.squeeze_dim(1);\n\n            // i(nput)g(ate) tensors\n            let biased_ig_input_sum = self\n                .input_gate\n                .gate_product(input_t.clone(), hidden_state.clone());\n            let input_values = self.gate_activation.forward(biased_ig_input_sum);\n\n            // f(orget)g(ate) tensors - either computed or coupled to input gate\n            let forget_values = if self.input_forget {\n                // Coupled mode: f_t = 1 - i_t\n                input_values.clone().neg().add_scalar(1.0)\n            } else {\n                let biased_fg_input_sum = self\n                    .forget_gate\n                    .gate_product(input_t.clone(), hidden_state.clone());\n                self.gate_activation.forward(biased_fg_input_sum)\n            };\n\n            // o(output)g(ate) tensors\n            let biased_og_input_sum = self\n                .output_gate\n                .gate_product(input_t.clone(), hidden_state.clone());\n            let output_values = self.gate_activation.forward(biased_og_input_sum);\n\n            // c(ell)g(ate) tensors\n            let biased_cg_input_sum = self\n                .cell_gate\n                .gate_product(input_t.clone(), hidden_state.clone());\n            let candidate_cell_values = self.cell_activation.forward(biased_cg_input_sum);\n\n            cell_state = forget_values * cell_state.clone() + input_values * candidate_cell_values;\n\n            // Apply cell state clipping if configured\n            if let Some(clip) = self.clip {\n                cell_state = cell_state.clamp(-clip, clip);\n            }\n\n            hidden_state = output_values * self.hidden_activation.forward(cell_state.clone());\n\n            let unsqueezed_hidden_state = hidden_state.clone().unsqueeze_dim(1);\n\n            // store the hidden state for this timestep\n            batched_hidden_state = batched_hidden_state.slice_assign(\n                [0..batch_size, t..(t + 1), 0..self.d_hidden],\n                unsqueezed_hidden_state.clone(),\n            );\n        }\n\n        (\n            batched_hidden_state,\n            LstmState::new(cell_state, hidden_state),\n        )\n    }\n}\n\n/// Configuration to create a [BiLstm](BiLstm) module using the [init function](BiLstmConfig::init).\n#[derive(Config, Debug)]\npub struct BiLstmConfig {\n    /// The size of the input features.\n    pub d_input: usize,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If a bias should be applied during the BiLstm transformation.\n    pub bias: bool,\n    /// BiLstm initializer\n    #[config(default = \"Initializer::XavierNormal{gain:1.0}\")]\n    pub initializer: Initializer,\n    /// If true, the input tensor is expected to be `[batch_size, seq_length, input_size]`.\n    /// If false, the input tensor is expected to be `[seq_length, batch_size, input_size]`.\n    #[config(default = true)]\n    pub batch_first: bool,\n    /// Optional cell state clip threshold.\n    pub clip: Option<f64>,\n    /// If true, couples the input and forget gates.\n    #[config(default = false)]\n    pub input_forget: bool,\n    /// Activation function for the input, forget, and output gates.\n    #[config(default = \"ActivationConfig::Sigmoid\")]\n    pub gate_activation: ActivationConfig,\n    /// Activation function for the cell gate (candidate cell state).\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub cell_activation: ActivationConfig,\n    /// Activation function applied to the cell state before computing hidden output.\n    #[config(default = \"ActivationConfig::Tanh\")]\n    pub hidden_activation: ActivationConfig,\n}\n\n/// The BiLstm module. This implementation is for Bidirectional LSTM.\n///\n/// Introduced in the paper: [Framewise phoneme classification with bidirectional LSTM and other neural network architectures](https://www.cs.toronto.edu/~graves/ijcnn_2005.pdf).\n///\n/// Should be created with [BiLstmConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct BiLstm<B: Backend> {\n    /// LSTM for the forward direction.\n    pub forward: Lstm<B>,\n    /// LSTM for the reverse direction.\n    pub reverse: Lstm<B>,\n    /// The size of the hidden state.\n    pub d_hidden: usize,\n    /// If true, input is `[batch_size, seq_length, input_size]`.\n    /// If false, input is `[seq_length, batch_size, input_size]`.\n    pub batch_first: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for BiLstm<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_input, _] = self\n            .forward\n            .input_gate\n            .input_transform\n            .weight\n            .shape()\n            .dims();\n        let bias = self.forward.input_gate.input_transform.bias.is_some();\n\n        content\n            .add(\"d_input\", &d_input)\n            .add(\"d_hidden\", &self.d_hidden)\n            .add(\"bias\", &bias)\n            .optional()\n    }\n}\n\nimpl BiLstmConfig {\n    /// Initialize a new [Bidirectional LSTM](BiLstm) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> BiLstm<B> {\n        // Internal LSTMs always use batch_first=true; BiLstm handles layout conversion\n        let base_config = LstmConfig::new(self.d_input, self.d_hidden, self.bias)\n            .with_initializer(self.initializer.clone())\n            .with_batch_first(true)\n            .with_clip(self.clip)\n            .with_input_forget(self.input_forget)\n            .with_gate_activation(self.gate_activation.clone())\n            .with_cell_activation(self.cell_activation.clone())\n            .with_hidden_activation(self.hidden_activation.clone());\n\n        BiLstm {\n            forward: base_config.clone().init(device),\n            reverse: base_config.init(device),\n            d_hidden: self.d_hidden,\n            batch_first: self.batch_first,\n        }\n    }\n}\n\nimpl<B: Backend> BiLstm<B> {\n    /// Applies the forward pass on the input tensor. This Bidirectional LSTM implementation\n    /// returns the state for each element in a sequence (i.e., across seq_length) and a final state.\n    ///\n    /// ## Parameters:\n    /// - batched_input: The input tensor of shape:\n    ///   - `[batch_size, sequence_length, input_size]` if `batch_first` is true (default)\n    ///   - `[sequence_length, batch_size, input_size]` if `batch_first` is false\n    /// - state: An optional `LstmState` representing the initial cell state and hidden state.\n    ///   Each state tensor has shape `[2, batch_size, hidden_size]`.\n    ///   If no initial state is provided, these tensors are initialized to zeros.\n    ///\n    /// ## Returns:\n    /// - output: A tensor represents the output features of LSTM. Shape:\n    ///   - `[batch_size, sequence_length, hidden_size * 2]` if `batch_first` is true\n    ///   - `[sequence_length, batch_size, hidden_size * 2]` if `batch_first` is false\n    /// - state: A `LstmState` represents the final forward and reverse states. Both `state.cell` and\n    ///   `state.hidden` have the shape `[2, batch_size, hidden_size]`.\n    pub fn forward(\n        &self,\n        batched_input: Tensor<B, 3>,\n        state: Option<LstmState<B, 3>>,\n    ) -> (Tensor<B, 3>, LstmState<B, 3>) {\n        // Convert to batch-first layout internally if needed\n        let batched_input = if self.batch_first {\n            batched_input\n        } else {\n            batched_input.swap_dims(0, 1)\n        };\n\n        let device = batched_input.clone().device();\n        let [batch_size, seq_length, _] = batched_input.shape().dims();\n\n        let [init_state_forward, init_state_reverse] = match state {\n            Some(state) => {\n                let cell_state_forward = state\n                    .cell\n                    .clone()\n                    .slice([0..1, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n                let hidden_state_forward = state\n                    .hidden\n                    .clone()\n                    .slice([0..1, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n                let cell_state_reverse = state\n                    .cell\n                    .slice([1..2, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n                let hidden_state_reverse = state\n                    .hidden\n                    .slice([1..2, 0..batch_size, 0..self.d_hidden])\n                    .squeeze_dim(0);\n\n                [\n                    Some(LstmState::new(cell_state_forward, hidden_state_forward)),\n                    Some(LstmState::new(cell_state_reverse, hidden_state_reverse)),\n                ]\n            }\n            None => [None, None],\n        };\n\n        // forward direction\n        let (batched_hidden_state_forward, final_state_forward) = self\n            .forward\n            .forward(batched_input.clone(), init_state_forward);\n\n        // reverse direction\n        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(\n            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),\n            init_state_reverse,\n            batch_size,\n            seq_length,\n            &device,\n        );\n\n        let output = Tensor::cat(\n            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),\n            2,\n        );\n\n        // Convert output back to seq-first layout if needed\n        let output = if self.batch_first {\n            output\n        } else {\n            output.swap_dims(0, 1)\n        };\n\n        let state = LstmState::new(\n            Tensor::stack(\n                [final_state_forward.cell, final_state_reverse.cell].to_vec(),\n                0,\n            ),\n            Tensor::stack(\n                [final_state_forward.hidden, final_state_reverse.hidden].to_vec(),\n                0,\n            ),\n        );\n\n        (output, state)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{LinearRecord, TestBackend};\n    use burn::module::Param;\n    use burn::tensor::{Device, Distribution, TensorData};\n    use burn::tensor::{ElementConversion, Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[cfg(feature = \"std\")]\n    use crate::TestAutodiffBackend;\n\n    #[test]\n    fn test_with_uniform_initializer() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = LstmConfig::new(5, 5, false)\n            .with_initializer(Initializer::Uniform { min: 0.0, max: 1.0 });\n        let lstm = config.init::<TestBackend>(&Default::default());\n\n        let gate_to_data =\n            |gate: GateController<TestBackend>| gate.input_transform.weight.val().to_data();\n\n        gate_to_data(lstm.input_gate).assert_within_range::<FT>(0.elem()..1.elem());\n        gate_to_data(lstm.forget_gate).assert_within_range::<FT>(0.elem()..1.elem());\n        gate_to_data(lstm.output_gate).assert_within_range::<FT>(0.elem()..1.elem());\n        gate_to_data(lstm.cell_gate).assert_within_range::<FT>(0.elem()..1.elem());\n    }\n\n    /// Test forward pass with simple input vector.\n    ///\n    /// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5173928\n    /// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5123725\n    /// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5274723\n    /// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0892937\n    /// C_t = f_t * 0 + i_t * c_t = 0 + 0.5123725 * 0.0892937 = 0.04575243\n    /// h_t = o_t * tanh(C_t) = 0.5274723 * tanh(0.04575243) = 0.5274723 * 0.04568173 = 0.024083648\n    #[test]\n    fn test_forward_single_input_single_feature() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = LstmConfig::new(1, 1, false);\n        let device = Default::default();\n        let mut lstm = config.init::<TestBackend>(&device);\n\n        fn create_gate_controller(\n            weights: f32,\n            biases: f32,\n            d_input: usize,\n            d_output: usize,\n            bias: bool,\n            initializer: Initializer,\n            device: &Device<TestBackend>,\n        ) -> GateController<TestBackend> {\n            let record_1 = LinearRecord {\n                weight: Param::from_data(TensorData::from([[weights]]), device),\n                bias: Some(Param::from_data(TensorData::from([biases]), device)),\n            };\n            let record_2 = LinearRecord {\n                weight: Param::from_data(TensorData::from([[weights]]), device),\n                bias: Some(Param::from_data(TensorData::from([biases]), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                bias,\n                initializer,\n                record_1,\n                record_2,\n            )\n        }\n\n        lstm.input_gate = create_gate_controller(\n            0.5,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n        lstm.forget_gate = create_gate_controller(\n            0.7,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n        lstm.cell_gate = create_gate_controller(\n            0.9,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n        lstm.output_gate = create_gate_controller(\n            1.1,\n            0.0,\n            1,\n            1,\n            false,\n            Initializer::XavierUniform { gain: 1.0 },\n            &device,\n        );\n\n        // single timestep with single feature\n        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);\n\n        let (output, state) = lstm.forward(input, None);\n\n        let expected = TensorData::from([[0.046]]);\n        let tolerance = Tolerance::default();\n        state\n            .cell\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        let expected = TensorData::from([[0.0242]]);\n        state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected, tolerance);\n\n        output\n            .select(0, Tensor::arange(0..1, &device))\n            .squeeze_dim::<2>(0)\n            .to_data()\n            .assert_approx_eq::<FT>(&state.hidden.to_data(), tolerance);\n    }\n\n    #[test]\n    fn test_batched_forward_pass() {\n        let device = Default::default();\n        let lstm = LstmConfig::new(64, 1024, true).init(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);\n\n        let (output, state) = lstm.forward(batched_input, None);\n\n        assert_eq!(output.dims(), [8, 10, 1024]);\n        assert_eq!(state.cell.dims(), [8, 1024]);\n        assert_eq!(state.hidden.dims(), [8, 1024]);\n    }\n\n    #[test]\n    fn test_batched_forward_pass_batch_of_one() {\n        let device = Default::default();\n        let lstm = LstmConfig::new(64, 1024, true).init(&device);\n        let batched_input =\n            Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);\n\n        let (output, state) = lstm.forward(batched_input, None);\n\n        assert_eq!(output.dims(), [1, 2, 1024]);\n        assert_eq!(state.cell.dims(), [1, 1024]);\n        assert_eq!(state.hidden.dims(), [1, 1024]);\n    }\n\n    #[test]\n    #[cfg(feature = \"std\")]\n    fn test_batched_backward_pass() {\n        use burn::tensor::Shape;\n        let device = Default::default();\n        let lstm = LstmConfig::new(64, 32, true).init(&device);\n        let shape: Shape = [8, 10, 64].into();\n        let batched_input =\n            Tensor::<TestAutodiffBackend, 3>::random(shape, Distribution::Default, &device);\n\n        let (output, _) = lstm.forward(batched_input.clone(), None);\n        let fake_loss = output;\n        let grads = fake_loss.backward();\n\n        let some_gradient = lstm\n            .output_gate\n            .hidden_transform\n            .weight\n            .grad(&grads)\n            .unwrap();\n\n        // Asserts that the gradients exist and are non-zero\n        assert_ne!(\n            some_gradient\n                .any()\n                .into_data()\n                .iter::<f32>()\n                .next()\n                .unwrap(),\n            0.0\n        );\n    }\n\n    #[test]\n    fn test_bidirectional() {\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        let config = BiLstmConfig::new(2, 3, true);\n        let device = Default::default();\n        let mut lstm = config.init(&device);\n\n        fn create_gate_controller<const D1: usize, const D2: usize>(\n            input_weights: [[f32; D1]; D2],\n            input_biases: [f32; D1],\n            hidden_weights: [[f32; D1]; D1],\n            hidden_biases: [f32; D1],\n            device: &Device<TestBackend>,\n        ) -> GateController<TestBackend> {\n            let d_input = input_weights[0].len();\n            let d_output = input_weights.len();\n\n            let input_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(input_weights), device),\n                bias: Some(Param::from_data(TensorData::from(input_biases), device)),\n            };\n            let hidden_record = LinearRecord {\n                weight: Param::from_data(TensorData::from(hidden_weights), device),\n                bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),\n            };\n            GateController::create_with_weights(\n                d_input,\n                d_output,\n                true,\n                Initializer::XavierUniform { gain: 1.0 },\n                input_record,\n                hidden_record,\n            )\n        }\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[\n                [0.949, -0.861],\n                [0.892, 0.927],\n                [-0.173, -0.301],\n                [-0.081, 0.992],\n            ]]),\n            &device,\n        );\n        let h0 = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),\n            &device,\n        );\n        let c0 = Tensor::<TestBackend, 3>::from_data(\n            TensorData::from([[[0.723, 0.397, -0.262]], [[0.471, 0.613, 1.885]]]),\n            &device,\n        );\n\n        lstm.forward.input_gate = create_gate_controller(\n            [[0.367, 0.091, 0.342], [0.322, 0.533, 0.059]],\n            [-0.196, 0.354, 0.209],\n            [\n                [-0.320, 0.232, -0.165],\n                [0.093, -0.572, -0.315],\n                [-0.467, 0.325, 0.046],\n            ],\n            [0.181, -0.190, -0.245],\n            &device,\n        );\n\n        lstm.forward.forget_gate = create_gate_controller(\n            [[-0.342, -0.084, -0.420], [-0.432, 0.119, 0.191]],\n            [0.315, -0.413, -0.041],\n            [\n                [0.453, 0.063, 0.561],\n                [0.211, 0.149, 0.213],\n                [-0.499, -0.158, 0.068],\n            ],\n            [-0.431, -0.535, 0.125],\n            &device,\n        );\n\n        lstm.forward.cell_gate = create_gate_controller(\n            [[-0.046, -0.382, 0.321], [-0.533, 0.558, 0.004]],\n            [-0.358, 0.282, -0.078],\n            [\n                [-0.358, 0.109, 0.139],\n                [-0.345, 0.091, -0.368],\n                [-0.508, 0.221, -0.507],\n            ],\n            [0.502, -0.509, -0.247],\n            &device,\n        );\n\n        lstm.forward.output_gate = create_gate_controller(\n            [[-0.577, -0.359, 0.216], [-0.550, 0.268, 0.243]],\n            [-0.227, -0.274, 0.039],\n            [\n                [-0.383, 0.449, 0.222],\n                [-0.357, -0.093, 0.449],\n                [-0.106, 0.236, 0.360],\n            ],\n            [-0.361, -0.209, -0.454],\n            &device,\n        );\n\n        lstm.reverse.input_gate = create_gate_controller(\n            [[-0.055, 0.506, 0.247], [-0.369, 0.178, -0.258]],\n            [0.540, -0.164, 0.033],\n            [\n                [0.159, 0.180, -0.037],\n                [-0.443, 0.485, -0.488],\n                [0.098, -0.085, -0.140],\n            ],\n            [-0.510, 0.105, 0.114],\n            &device,\n        );\n\n        lstm.reverse.forget_gate = create_gate_controller(\n            [[-0.154, -0.432, -0.547], [-0.369, -0.310, -0.175]],\n            [0.141, 0.004, 0.055],\n            [\n                [-0.005, -0.277, -0.515],\n                [-0.011, -0.101, -0.365],\n                [0.426, 0.379, 0.337],\n            ],\n            [-0.382, 0.331, -0.176],\n            &device,\n        );\n\n        lstm.reverse.cell_gate = create_gate_controller(\n            [[-0.571, 0.228, -0.287], [-0.331, 0.110, 0.219]],\n            [-0.206, -0.546, 0.462],\n            [\n                [0.449, -0.240, 0.071],\n                [-0.045, 0.131, 0.124],\n                [0.138, -0.201, 0.191],\n            ],\n            [-0.030, 0.211, -0.352],\n            &device,\n        );\n\n        lstm.reverse.output_gate = create_gate_controller(\n            [[0.491, -0.442, 0.333], [0.313, -0.121, -0.070]],\n            [-0.387, -0.250, 0.066],\n            [\n                [-0.030, 0.268, 0.299],\n                [-0.019, -0.280, -0.314],\n                [0.466, -0.365, -0.248],\n            ],\n            [-0.398, -0.199, -0.566],\n            &device,\n        );\n\n        let expected_output_with_init_state = TensorData::from([[\n            [0.23764, -0.03442, 0.04414, -0.15635, -0.03366, -0.05798],\n            [0.00473, -0.02254, 0.02988, -0.16510, -0.00306, 0.08742],\n            [0.06210, -0.06509, -0.05339, -0.01710, 0.02091, 0.16012],\n            [-0.03420, 0.07774, -0.09774, -0.02604, 0.12584, 0.20872],\n        ]]);\n        let expected_output_without_init_state = TensorData::from([[\n            [0.08679, -0.08776, -0.00528, -0.15969, -0.05322, -0.08863],\n            [-0.02577, -0.05057, 0.00033, -0.17558, -0.03679, 0.03142],\n            [0.02942, -0.07411, -0.06044, -0.03601, -0.09998, 0.04846],\n            [-0.04026, 0.07178, -0.10189, -0.07349, -0.04576, 0.05550],\n        ]]);\n        let expected_hn_with_init_state = TensorData::from([\n            [[-0.03420, 0.07774, -0.09774]],\n            [[-0.15635, -0.03366, -0.05798]],\n        ]);\n        let expected_cn_with_init_state = TensorData::from([\n            [[-0.13593, 0.17125, -0.22395]],\n            [[-0.45425, -0.11206, -0.12908]],\n        ]);\n        let expected_hn_without_init_state = TensorData::from([\n            [[-0.04026, 0.07178, -0.10189]],\n            [[-0.15969, -0.05322, -0.08863]],\n        ]);\n        let expected_cn_without_init_state = TensorData::from([\n            [[-0.15839, 0.15923, -0.23569]],\n            [[-0.47407, -0.17493, -0.19643]],\n        ]);\n\n        let (output_with_init_state, state_with_init_state) =\n            lstm.forward(input.clone(), Some(LstmState::new(c0, h0)));\n        let (output_without_init_state, state_without_init_state) = lstm.forward(input, None);\n\n        let tolerance = Tolerance::permissive();\n        output_with_init_state\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_with_init_state, tolerance);\n        output_without_init_state\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output_without_init_state, tolerance);\n        state_with_init_state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_with_init_state, tolerance);\n        state_with_init_state\n            .cell\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_cn_with_init_state, tolerance);\n        state_without_init_state\n            .hidden\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_hn_without_init_state, tolerance);\n        state_without_init_state\n            .cell\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_cn_without_init_state, tolerance);\n    }\n\n    #[test]\n    fn display_lstm() {\n        let config = LstmConfig::new(2, 3, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}\"\n        );\n    }\n\n    #[test]\n    fn display_bilstm() {\n        let config = BiLstmConfig::new(2, 3, true);\n\n        let layer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{layer}\"),\n            \"BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rnn/mod.rs",
    "content": "mod gate_controller;\n\n/// Basic RNN.\npub mod basic;\n\n/// Gated Recurrent Unit module.\npub mod gru;\n\n/// Long Short-Term Memory module.\npub mod lstm;\n\npub use basic::*;\npub use gate_controller::*;\npub use gru::*;\npub use lstm::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/rope_encoding.rs",
    "content": "use burn_core as burn;\n\nuse alloc::vec;\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::Int;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse core::ops::Range;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).\n#[derive(Config, Debug)]\npub struct RotaryEncodingConfig {\n    /// Maximum sequence length of input\n    pub max_sequence_length: usize,\n\n    /// Size of the input embedding or hidden dimension\n    pub d_model: usize,\n\n    /// Scaling factor for frequency computation. Defaults to 10000.0\n    #[config(default = \"10000.0\")]\n    pub theta: f32,\n}\n\nimpl RotaryEncodingConfig {\n    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the size of input embedding dimension is not even.\n    /// Panics if the theta parameter is not positive.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {\n        self.initialize(|x| x, device)\n    }\n\n    /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.\n    /// This is useful to apply different RoPE extensions.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the size of input embedding dimension is not even.\n    /// Panics if the theta parameter is not positive.\n    pub fn init_with_frequency_scaling<B: Backend>(\n        &self,\n        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,\n        device: &B::Device,\n    ) -> RotaryEncoding<B> {\n        self.initialize(scaling, device)\n    }\n\n    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the size of input embedding dimension is not even.\n    /// Panics if the theta parameter is not positive.\n    fn initialize<B: Backend>(\n        &self,\n        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,\n        device: &B::Device,\n    ) -> RotaryEncoding<B> {\n        assert_eq!(\n            self.d_model % 2,\n            0,\n            \"The input embedding dimension must be even\"\n        );\n        assert!(\n            self.theta > 0.0,\n            \"Theta parameter must be positive (default: 10000).\"\n        );\n\n        // Calculate the rotation frequencies for positional embeddings based on the formula\n        // `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`\n        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)\n            .float()\n            .div_scalar(self.d_model as f32);\n\n        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`\n        // This is done since burn doesn't support exponentiation of scalar to tensor\n        let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();\n\n        let theta = scaling(theta);\n\n        let freq_complex =\n            RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());\n\n        RotaryEncoding {\n            freq_complex,\n            theta,\n            start_offset: 0,\n        }\n    }\n}\n\n/// A module that applies rotary positional encoding to a tensor.\n/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes\n/// absolute positional information with rotation matrix and naturally incorporates\n/// explicit relative position dependency in self-attention formulation.\n///\n/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)\n///\n/// Should be created using [RotaryEncodingConfig].\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct RotaryEncoding<B: Backend> {\n    /// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components\n    // Essentially a cache of pre-computed RoPE values.\n    pub freq_complex: Tensor<B, 3>,\n    /// Frequency vector used to compute/apply the complex rotations.\n    pub theta: Tensor<B, 1>,\n    start_offset: usize,\n}\n\nimpl<B: Backend> ModuleDisplay for RotaryEncoding<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();\n        content\n            .add(\"d_model\", &d_model)\n            .add(\"max_sequence_length\", &max_sequence_length)\n            .optional()\n    }\n}\n\n#[allow(clippy::single_range_in_vec_init)]\nimpl<B: Backend> RotaryEncoding<B> {\n    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)\n    ///\n    /// # Arguments:\n    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors\n    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)\n    ///   respectively.\n    ///\n    /// # Returns:\n    /// Output tensor with the same shape as input tensor after applying rotary encoding.\n    ///\n    /// # Panics\n    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.\n    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {\n        self.apply(x, 0)\n    }\n\n    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)\n    ///\n    /// # Arguments:\n    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors\n    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)\n    ///   respectively.\n    /// * `start` - Sequence start position index.\n    ///\n    /// # Returns:\n    /// Output tensor with the same shape as input tensor after applying rotary encoding.\n    ///\n    /// # Panics\n    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.\n    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {\n        assert!(\n            D >= 2,\n            \"Input tensor must have at least 2 dimensions for sequence length and hidden dimension\"\n        );\n\n        let device = x.device();\n        let input_shape = x.shape();\n\n        // Extract the sequence length and embedding dimension, other dimensions are kept generic\n        // to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)\n        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);\n        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);\n\n        // Create a dummy tensor with signed ones based on the 2D rotation matrix\n        // [[cos, -sin], [sin, cos]]\n        let sign_tensor =\n            Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);\n\n        // Rotate input using the frequency tensor. Slice the frequencies till input sequence length\n        let out: Tensor<B, 4> = x\n            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])\n            .matmul(sign_tensor.unsqueeze())\n            .reshape([dummy_dim_size, seq_len, d_model, 2])\n            * self\n                .freq_complex\n                .clone()\n                .slice([start..start + seq_len])\n                .unsqueeze();\n\n        // Sum the real and imaginary components to get output tensor and reshape to original shape\n        out.sum_dim(-1).reshape(input_shape)\n    }\n\n    /// Shifts the pre-computed rotary frequency to cover a new range of positions.\n    ///\n    /// This method updates the internal frequency tensor `freq_complex` to store\n    /// the rotary positional encodings for a new window of positions starting at `start`.\n    pub fn shift(&mut self, start: usize) {\n        let max_seq_len = self.freq_complex.dims()[0];\n        assert!(\n            start > self.start_offset,\n            \"Shift start position must be monotonically increasing\"\n        );\n\n        let current_end = self.start_offset + max_seq_len;\n\n        if start >= current_end {\n            // Overwrite the whole buffer\n            let new_freqs =\n                Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());\n            self.freq_complex\n                .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));\n        } else {\n            // Shift the tail\n            let num_keep = current_end - start;\n            let start_rel = start - self.start_offset;\n            let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);\n            self.freq_complex\n                .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));\n            // Compute the rest and assign\n            let new_freqs = Self::compute_rotary_frequencies(\n                current_end..start + max_seq_len,\n                self.theta.clone(),\n            );\n            self.freq_complex\n                .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));\n        }\n        self.start_offset = start;\n    }\n\n    /// Computes the positional rotation frequencies (cosine and sine values) used in RoPE.\n    ///\n    /// # Arguments\n    /// - `range`: Range of position indices `[start, end)`.\n    /// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies.\n    ///\n    /// # Returns\n    /// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency.\n    fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {\n        let d_model = theta.dims()[0] * 2;\n        let num_positions = range.end - range.start;\n\n        // Generate frequency values for positional embeddings\n        let frequencies: Tensor<B, 2> =\n            Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())\n                .float()\n                .unsqueeze()\n                .transpose()\n                .repeat_dim(1, d_model / 2)\n                * theta.unsqueeze();\n\n        // Convert frequency values to complex numbers (polar form)\n        let p_cos = frequencies.clone().cos();\n        let p_sin = frequencies.sin();\n\n        Tensor::cat(vec![p_cos, p_sin], 1)\n            .reshape([num_positions, 2, d_model / 2])\n            .transpose()\n            .unsqueeze_dim::<4>(2)\n            .repeat_dim(2, 2)\n            .reshape([num_positions, d_model, 2])\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_rotary_encoding_forward() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],\n                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],\n            ],\n            &device,\n        );\n\n        // Input = [Batch size, Num of heads, Seq_len, d_model]\n        let input = input.unsqueeze::<4>();\n\n        let output = rotary_encoding.forward(input);\n        let expected_output = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [1.0000, 2.0000, 3.0000, 4.0000],\n                    [-2.3473, 7.4492, 6.9197, 8.0696],\n                ],\n                [\n                    [9.0000, 10.0000, 11.0000, 12.0000],\n                    [-4.7567, 18.5034, 14.8393, 16.1492],\n                ],\n            ],\n            &device,\n        );\n\n        output\n            .squeeze_dim::<3>(0)\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_rotary_encoding_3d() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);\n\n        let input = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],\n                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],\n            ],\n            &device,\n        );\n\n        // Input = [Batch size, Num of heads, Seq_len, d_model]\n        // let input = input.unsqueeze::<4>();\n\n        let output = rotary_encoding.forward(input);\n        let expected_output = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [1.0000, 2.0000, 3.0000, 4.0000],\n                    [-2.3473, 7.4492, 6.9197, 8.0696],\n                ],\n                [\n                    [9.0000, 10.0000, 11.0000, 12.0000],\n                    [-4.7567, 18.5034, 14.8393, 16.1492],\n                ],\n            ],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_zero_input_rotary_encoding_forward() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);\n\n        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well\n        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);\n\n        let output = rotary_encoding.forward(input);\n        let expected_output = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [0.0000, 0.0000, 0.0000, 0.0000],\n                    [0.0000, 0.0000, 0.0000, 0.0000],\n                ],\n                [\n                    [0.0000, 0.0000, 0.0000, 0.0000],\n                    [0.0000, 0.0000, 0.0000, 0.0000],\n                ],\n            ],\n            &device,\n        );\n\n        output\n            .squeeze_dim::<3>(0)\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    #[should_panic]\n    fn test_valid_input_hidden_dim() {\n        // Hidden dimension must be even to be able to split into real and imaginary components\n        // for rotation\n        let d_model = 15;\n        let device = Default::default();\n        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);\n        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);\n        let _output = pe.forward(input);\n    }\n\n    #[test]\n    fn test_rotary_encoding_frequencies() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);\n\n        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                ],\n                [\n                    [5.4030e-01, 8.4147e-01],\n                    [9.9500e-01, 9.9833e-02],\n                    [9.9995e-01, 9.9998e-03],\n                    [9.9999e-01, 9.9999e-04],\n                ],\n            ],\n            &device,\n        )\n        .unsqueeze_dim::<4>(2)\n        .repeat_dim(2, 2)\n        .reshape([2, 8, 2]);\n\n        rotary_encoding\n            .freq_complex\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());\n    }\n\n    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {\n        // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45\n        let scale_factor = 8.;\n        let low_freq_factor = 1.;\n        let high_freq_factor = 4.;\n        let old_context_len = 8192.;\n\n        let low_freq_wavelen = old_context_len / low_freq_factor;\n        let high_freq_wavelen = old_context_len / high_freq_factor;\n\n        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);\n\n        // if wavelen >= high_freq_wavelen\n        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);\n        let smooth = wavelen\n            .clone()\n            .recip()\n            .mul_scalar(old_context_len)\n            .sub_scalar(low_freq_factor)\n            .div_scalar(high_freq_factor - low_freq_factor);\n        // (1 - smooth) * freq / scale_factor + smooth * freq\n        let new_freqs = smooth\n            .clone()\n            .neg()\n            .add_scalar(1.)\n            .mul(freqs.clone().div_scalar(scale_factor))\n            .add(smooth.clone().mul(freqs.clone()));\n        let new_freqs = freqs.clone().mask_where(cond, new_freqs);\n\n        // if wavelen > low_freq_wavelen\n        let cond = wavelen.clone().greater_elem(low_freq_wavelen);\n        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));\n\n        // if wavelen < high_freq_wavelen\n        let cond = wavelen.lower_elem(high_freq_wavelen);\n        new_freqs.mask_where(cond, freqs)\n    }\n\n    #[test]\n    fn test_rotary_encoding_with_frequency_scaling() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(2, 8)\n            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);\n\n        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                    [1.0000, 0.0000],\n                ],\n                [\n                    [5.4030e-01, 8.4148e-01],\n                    [9.9500e-01, 9.9833e-02],\n                    [9.9995e-01, 9.9998e-03],\n                    [1.0000, 2.1361e-04],\n                ],\n            ],\n            &device,\n        )\n        .unsqueeze_dim::<4>(2)\n        .repeat_dim(2, 2)\n        .reshape([2, 8, 2]);\n\n        rotary_encoding\n            .freq_complex\n            .to_data()\n            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_rotary_encoding_shift_full() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);\n\n        // Input = [Batch size, Num of heads, Seq_len, d_model]\n        let input = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],\n                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],\n            ],\n            &device,\n        )\n        .unsqueeze::<4>();\n\n        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result\n        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same\n        // initial position\n        let expected_output = rotary_encoding.apply(input.clone(), 6);\n\n        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);\n        rotary_encoding.shift(6); // start > 4 will perform a full re-compute\n\n        let output = rotary_encoding.apply(input, 0);\n\n        output\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_rotary_encoding_shift() {\n        let device = Default::default();\n        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);\n\n        // Input = [Batch size, Num of heads, Seq_len, d_model]\n        let input = Tensor::<TestBackend, 3>::from_floats(\n            [\n                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],\n                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],\n            ],\n            &device,\n        )\n        .unsqueeze::<4>();\n\n        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result\n        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same\n        // initial position\n        let expected_output = rotary_encoding.apply(input.clone(), 2);\n\n        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);\n        rotary_encoding.shift(2); // start < 4 will shift the (current_end - start) freqs and compute the rest\n\n        let output = rotary_encoding.apply(input, 0);\n\n        output\n            .into_data()\n            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_rotary_encoding_shift_multiple() {\n        let device = Default::default();\n        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);\n        rotary_encoding.shift(2);\n        rotary_encoding.shift(5);\n    }\n\n    #[test]\n    #[should_panic = \"Shift start position must be monotonically increasing\"]\n    fn test_rotary_encoding_shift_should_increase() {\n        let device = Default::default();\n        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);\n        rotary_encoding.shift(6);\n        rotary_encoding.shift(4); // should be monotonically increasing\n    }\n\n    #[test]\n    fn display() {\n        let config = RotaryEncodingConfig::new(10, 4);\n        let pe = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{pe}\"),\n            \"RotaryEncoding {d_model: 4, max_sequence_length: 10}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/transformer/decoder.rs",
    "content": "use burn_core as burn;\n\nuse alloc::vec::Vec;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::{Bool, Tensor, backend::Backend};\n\nuse crate::activation::ActivationConfig;\nuse crate::cache::TensorCache;\nuse crate::{\n    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,\n    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},\n};\n\nuse super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};\n\n/// Configuration to create a [Transformer Decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).\n#[derive(Config, Debug)]\npub struct TransformerDecoderConfig {\n    /// The size of the model.\n    pub d_model: usize,\n    /// The size of the position-wise feed-forward network.\n    pub d_ff: usize,\n    /// The number of attention heads.\n    pub n_heads: usize,\n    /// The number of layers.\n    pub n_layers: usize,\n    /// The dropout rate. Default: 0.1\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    /// Layer norm will be applied first instead of after the other modules.\n    #[config(default = false)]\n    pub norm_first: bool,\n    /// Use \"quiet softmax\" instead of regular softmax.\n    ///\n    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).\n    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.\n    ///\n    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>\n    #[config(default = false)]\n    pub quiet_softmax: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n    /// The activation function used in the position-wise feed-forward network. Default: Gelu\n    #[config(default = \"ActivationConfig::Gelu\")]\n    pub activation: ActivationConfig,\n    /// The epsilon value for layer normalization. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub layer_norm_eps: f64,\n}\n\n/// The transformer decoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).\n///\n/// # Params\n///\n/// - layers: transformer decoder layers with `d_model` input and output features.\n///\n/// Should be created using [TransformerDecoderConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct TransformerDecoder<B: Backend> {\n    /// Transformer decoder layers.\n    pub layers: Vec<TransformerDecoderLayer<B>>,\n\n    /// The size of the model.\n    pub d_model: usize,\n\n    /// The size of the position-wise feed-forward network.\n    pub d_ff: usize,\n\n    /// The number of attention heads.\n    pub n_heads: usize,\n\n    /// The number of layers.\n    pub n_layers: usize,\n\n    /// The dropout rate. Default: 0.1\n    pub dropout: f64,\n\n    /// Layer norm will be applied first instead of after the other modules.\n    pub norm_first: bool,\n\n    /// Use \"quiet softmax\" instead of regular softmax.\n    pub quiet_softmax: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for TransformerDecoder<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"d_model\", &self.d_model)\n            .add(\"d_ff\", &self.d_ff)\n            .add(\"n_heads\", &self.n_heads)\n            .add(\"n_layers\", &self.n_layers)\n            .add(\"dropout\", &self.dropout)\n            .add(\"norm_first\", &self.norm_first)\n            .add(\"quiet_softmax\", &self.quiet_softmax)\n            .optional()\n    }\n}\n\nimpl TransformerDecoderConfig {\n    /// Initialize a new [Transformer Decoder](TransformerDecoder) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {\n        let layers = (0..self.n_layers)\n            .map(|_| TransformerDecoderLayer::new(self, device))\n            .collect::<Vec<_>>();\n\n        TransformerDecoder {\n            layers,\n            d_model: self.d_model,\n            d_ff: self.d_ff,\n            n_heads: self.n_heads,\n            n_layers: self.n_layers,\n            dropout: self.dropout,\n            norm_first: self.norm_first,\n            quiet_softmax: self.quiet_softmax,\n        }\n    }\n}\n\n/// [Transformer Decoder](TransformerDecoder) forward pass input argument.\n#[derive(Debug)]\npub struct TransformerDecoderInput<B: Backend> {\n    target: Tensor<B, 3>,\n    target_mask_pad: Option<Tensor<B, 2, Bool>>,\n    target_mask_attn: Option<Tensor<B, 3, Bool>>,\n    memory: Tensor<B, 3>,\n    memory_mask_pad: Option<Tensor<B, 2, Bool>>,\n    memory_mask_attn: Option<Tensor<B, 3, Bool>>,\n}\n\nimpl<B: Backend> TransformerDecoderInput<B> {\n    /// Create a [transformer decoder](TransformerDecoder) input argument.\n    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {\n        Self {\n            target,\n            target_mask_pad: None,\n            target_mask_attn: None,\n            memory,\n            memory_mask_pad: None,\n            memory_mask_attn: None,\n        }\n    }\n\n    /// Register the memory padding mask.\n    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {\n        self.memory_mask_pad = Some(mask_pad);\n        self\n    }\n\n    /// Register the memory attention mask.\n    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {\n        self.memory_mask_attn = Some(mask_attn);\n        self\n    }\n\n    /// Register the target padding mask.\n    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {\n        self.target_mask_pad = Some(mask_pad);\n        self\n    }\n\n    /// Register the target attention mask.\n    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {\n        self.target_mask_attn = Some(mask_attn);\n        self\n    }\n}\n\n/// [Transformer Decoder](TransformerDecoder) layer module.\n#[derive(Module, Debug)]\npub struct TransformerDecoderLayer<B: Backend> {\n    /// Cross-attention module.\n    pub cross_attn: MultiHeadAttention<B>,\n    /// Self-attention module.\n    pub self_attn: MultiHeadAttention<B>,\n    /// Position-wise feed-forward module.\n    pub pwff: PositionWiseFeedForward<B>,\n    /// First layer norm.\n    pub norm_1: LayerNorm<B>,\n    /// Second layer norm.\n    pub norm_2: LayerNorm<B>,\n    /// Third layer norm.\n    pub norm_3: LayerNorm<B>,\n    /// Dropout.\n    pub dropout: Dropout,\n    /// Whether to apply norm first.\n    pub norm_first: bool,\n}\n\n/// Autoregressive cache for a single [Transformer Decoder Layer](TransformerDecoderLayer).\npub struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {\n    /// Cross-attention cache.\n    pub cross_attn: MhaCache<B>,\n    /// Self-attention cache.\n    pub self_attn: MhaCache<B>,\n    /// Position-wise feed-forward cache.\n    pub pwff: TensorCache<B, 3>,\n    /// First layer norm cache.\n    pub norm_1: TensorCache<B, 3>,\n    /// Second layer norm cache.\n    pub norm_2: TensorCache<B, 3>,\n    /// Third layer norm cache.\n    pub norm_3: TensorCache<B, 3>,\n}\n\nimpl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {\n    /// Create an empty cache.\n    pub fn empty() -> Self {\n        Self {\n            cross_attn: MhaCache::autoregressive_cross_attention(),\n            self_attn: MhaCache::autoregressive(),\n            pwff: TensorCache::empty(),\n            norm_1: TensorCache::empty(),\n            norm_2: TensorCache::empty(),\n            norm_3: TensorCache::empty(),\n        }\n    }\n}\n\n/// Autoregressive cache for the [Transformer Decoder](TransformerDecoder) layer.\n///\n/// To be used during inference when decoding tokens.\npub struct TransformerDecoderAutoregressiveCache<B: Backend> {\n    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,\n}\n\nimpl<B: Backend> TransformerDecoderAutoregressiveCache<B> {\n    fn empty(num_layers: usize) -> Self {\n        Self {\n            layers: (0..num_layers)\n                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())\n                .collect(),\n        }\n    }\n}\n\nimpl<B: Backend> TransformerDecoderLayer<B> {\n    /// Create a new [TransformerDecoderLayer](TransformerDecoderLayer).\n    pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {\n        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)\n            .with_initializer(config.initializer.clone())\n            .with_dropout(config.dropout)\n            .with_quiet_softmax(config.quiet_softmax)\n            .init(device);\n\n        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)\n            .with_initializer(config.initializer.clone())\n            .with_dropout(config.dropout)\n            .with_quiet_softmax(config.quiet_softmax)\n            .init(device);\n        let norm_1 = LayerNormConfig::new(config.d_model)\n            .with_epsilon(config.layer_norm_eps)\n            .init(device);\n        let norm_2 = LayerNormConfig::new(config.d_model)\n            .with_epsilon(config.layer_norm_eps)\n            .init(device);\n        let norm_3 = LayerNormConfig::new(config.d_model)\n            .with_epsilon(config.layer_norm_eps)\n            .init(device);\n        let dropout = DropoutConfig::new(config.dropout).init();\n        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)\n            .with_initializer(config.initializer.clone())\n            .with_dropout(config.dropout)\n            .with_activation(config.activation.clone())\n            .init(device);\n\n        Self {\n            cross_attn,\n            self_attn,\n            norm_1,\n            norm_2,\n            norm_3,\n            pwff,\n            dropout,\n            norm_first: config.norm_first,\n        }\n    }\n\n    /// Applies the TransformerDecoder forward pass to the input tensor.\n    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {\n        // Self attention residual path.\n        let x = input.target;\n        let mut residual_path = x.clone();\n\n        // Normalize.\n        if self.norm_first {\n            residual_path = self.norm_3.forward(residual_path);\n        }\n\n        // Self attention.\n        let mut self_attn_input = MhaInput::self_attn(residual_path);\n        if let Some(mask_pad) = &input.target_mask_pad {\n            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());\n        }\n        if let Some(mask_attn) = &input.target_mask_attn {\n            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());\n        }\n        let residual_path = self.self_attn.forward(self_attn_input).context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Cross attention residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            self.norm_1.forward(x.clone())\n        } else {\n            x = self.norm_1.forward(x);\n            x.clone()\n        };\n\n        // Cross attention.\n        let mut cross_attn_input =\n            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());\n        if let Some(mask_pad) = &input.memory_mask_pad {\n            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());\n        }\n        if let Some(mask_attn) = &input.memory_mask_attn {\n            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());\n        }\n        let residual_path = self.cross_attn.forward(cross_attn_input).context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Feed forward residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            self.norm_2.forward(x.clone())\n        } else {\n            x = self.norm_2.forward(x);\n            x.clone()\n        };\n\n        let residual_path = self.pwff.forward(residual_path);\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Main path.\n        // Normalize.\n        if !self.norm_first {\n            x = self.norm_3.forward(x)\n        }\n\n        input.target = x;\n        input\n    }\n\n    /// Applies the forward pass using an autoregressive cache.\n    pub fn forward_autoregressive_inference(\n        &self,\n        mut input: TransformerDecoderInput<B>,\n        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,\n    ) -> TransformerDecoderInput<B> {\n        // Self attention residual path.\n        let x = input.target;\n        let mut residual_path = x.clone();\n\n        // Normalize.\n        if self.norm_first {\n            residual_path = cache\n                .norm_3\n                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));\n        }\n\n        // Self attention.\n        let mut self_attn_input = MhaInput::self_attn(residual_path);\n        if let Some(mask_pad) = &input.target_mask_pad {\n            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());\n        }\n        if let Some(mask_attn) = &input.target_mask_attn {\n            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());\n        }\n        let residual_path = self\n            .self_attn\n            .forward_cache(self_attn_input, &mut cache.self_attn)\n            .context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Cross attention residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            cache\n                .norm_1\n                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))\n        } else {\n            x = cache\n                .norm_1\n                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));\n            x.clone()\n        };\n\n        // Cross attention.\n        let mut cross_attn_input =\n            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());\n        if let Some(mask_pad) = &input.memory_mask_pad {\n            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());\n        }\n        if let Some(mask_attn) = &input.memory_mask_attn {\n            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());\n        }\n        let residual_path = self\n            .cross_attn\n            .forward_cache(cross_attn_input, &mut cache.cross_attn)\n            .context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Feed forward residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            cache\n                .norm_2\n                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))\n        } else {\n            x = cache\n                .norm_2\n                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));\n            x.clone()\n        };\n\n        let residual_path = cache\n            .pwff\n            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Main path.\n        // Normalize.\n        if !self.norm_first {\n            x = cache\n                .norm_3\n                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x))\n        }\n\n        input.target = x;\n        input\n    }\n}\n\nimpl<B: Backend> TransformerDecoder<B> {\n    /// Applies the forward pass.\n    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {\n        for layer in self.layers.iter() {\n            input = layer.forward(input);\n        }\n\n        input.target\n    }\n\n    /// Applies the forward pass on the input using autoregressive cache.\n    pub fn forward_autoregressive_inference(\n        &self,\n        mut input: TransformerDecoderInput<B>,\n        cache: &mut TransformerDecoderAutoregressiveCache<B>,\n    ) -> Tensor<B, 3> {\n        for i in 0..self.layers.len() {\n            let layer = self.layers.get(i).unwrap();\n            let cache = cache.layers.get_mut(i).unwrap();\n\n            input = layer.forward_autoregressive_inference(input, cache);\n        }\n\n        input.target\n    }\n    /// Create an empty autoregressive cache.\n    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {\n        TransformerDecoderAutoregressiveCache::empty(self.layers.len())\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::Device;\n\n    use super::*;\n    use crate::{TestBackend, attention::generate_autoregressive_mask};\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_autoregressive_norm_last() {\n        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        test_autoregressive(\n            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)\n                .with_norm_first(false),\n        )\n    }\n\n    #[test]\n    fn test_autoregressive_norm_first() {\n        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];\n        let device = Default::default();\n        TestBackend::seed(&device, 0);\n\n        test_autoregressive(\n            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),\n        )\n    }\n\n    fn test_autoregressive(config: TransformerDecoderConfig) {\n        let device: Device<TestBackend> = Default::default();\n        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];\n        let transformer = config.init::<TestBackend>(&device);\n\n        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)\n            .float()\n            .reshape([batch_size, seq_length, d_model]);\n        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)\n            .float()\n            .reshape([batch_size, seq_length, d_model]);\n        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());\n        let input = TransformerDecoderInput::new(target.clone(), memory.clone())\n            .target_mask_attn(mask_attn);\n\n        // Normal forward using masking.\n        let output_1 = transformer.forward(input);\n\n        // Forward using the autoregressive cache.\n        let mut output_2 = Vec::new();\n        let mut cache = transformer.new_autoregressive_cache();\n\n        for i in 1..seq_length + 1 {\n            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);\n\n            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());\n            let input = TransformerDecoderInput::new(target.clone(), memory.clone())\n                .target_mask_attn(mask_attn);\n            let next_tok = transformer // Greedy sampling\n                .forward_autoregressive_inference(input, &mut cache)\n                .slice([0..batch_size, i - 1..i, 0..d_model]);\n            output_2.push(next_tok);\n        }\n\n        let output_2 = Tensor::cat(output_2, 1);\n\n        // Should produce the same tokens.\n        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);\n        output_1\n            .into_data()\n            .assert_approx_eq::<FT>(&output_2.into_data(), tolerance);\n    }\n\n    #[test]\n    fn display() {\n        let config = TransformerDecoderConfig::new(2, 4, 2, 3);\n        let transformer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{transformer}\"),\n            \"TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \\\n            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/transformer/encoder.rs",
    "content": "use burn_core as burn;\n\nuse alloc::vec::Vec;\n\nuse super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};\nuse crate::{\n    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,\n    activation::ActivationConfig,\n    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},\n    cache::TensorCache,\n};\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::{Bool, Tensor, backend::Backend};\n\n/// Configuration to create a [Transformer Encoder](TransformerEncoder) layer using the [init function](TransformerEncoderConfig::init).\n#[derive(Config, Debug)]\npub struct TransformerEncoderConfig {\n    /// The size of the model.\n    pub d_model: usize,\n    /// The size of the position-wise feed-forward network.\n    pub d_ff: usize,\n    /// The number of attention heads.\n    pub n_heads: usize,\n    /// The number of layers.\n    pub n_layers: usize,\n    /// The dropout rate. Default: 0.1\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    /// Layer norm will be applied first instead of after the other modules.\n    #[config(default = false)]\n    pub norm_first: bool,\n    /// Use \"quiet softmax\" instead of regular softmax.\n    ///\n    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).\n    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.\n    ///\n    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>\n    #[config(default = false)]\n    pub quiet_softmax: bool,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n    /// The activation function used in the position-wise feed-forward network. Default: Gelu\n    #[config(default = \"ActivationConfig::Gelu\")]\n    pub activation: ActivationConfig,\n    /// The epsilon value for layer normalization. Default: 1e-5\n    #[config(default = 1e-5)]\n    pub layer_norm_eps: f64,\n}\n\n/// The transformer encoder module as describe in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).\n///\n/// # Params\n///\n/// - layers: transformer encoder layers with `d_model` input and output features.\n///\n/// Should be created using [TransformerEncoderConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct TransformerEncoder<B: Backend> {\n    /// The transformer encoder layers.\n    pub layers: Vec<TransformerEncoderLayer<B>>,\n\n    /// The size of the model.\n    pub d_model: usize,\n\n    /// The size of the position-wise feed-forward network.\n    pub d_ff: usize,\n\n    /// The number of attention heads.\n    pub n_heads: usize,\n\n    /// The number of layers.\n    pub n_layers: usize,\n\n    /// The dropout rate. Default: 0.1\n    pub dropout: f64,\n\n    /// Layer norm will be applied first instead of after the other modules.\n    pub norm_first: bool,\n\n    /// Use \"quiet softmax\" instead of regular softmax.\n    pub quiet_softmax: bool,\n}\n\nimpl<B: Backend> ModuleDisplay for TransformerEncoder<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"d_model\", &self.d_model)\n            .add(\"d_ff\", &self.d_ff)\n            .add(\"n_heads\", &self.n_heads)\n            .add(\"n_layers\", &self.n_layers)\n            .add(\"dropout\", &self.dropout)\n            .add(\"norm_first\", &self.norm_first)\n            .add(\"quiet_softmax\", &self.quiet_softmax)\n            .optional()\n    }\n}\n\n/// [Transformer Encoder](TransformerEncoder) forward pass input argument.\n#[derive(Debug)]\npub struct TransformerEncoderInput<B: Backend> {\n    tensor: Tensor<B, 3>,\n    mask_pad: Option<Tensor<B, 2, Bool>>,\n    mask_attn: Option<Tensor<B, 3, Bool>>,\n}\n\nimpl<B: Backend> TransformerEncoderInput<B> {\n    /// Create a [transformer encoder](TransformerEncoder) input argument.\n    pub fn new(tensor: Tensor<B, 3>) -> Self {\n        Self {\n            tensor,\n            mask_pad: None,\n            mask_attn: None,\n        }\n    }\n\n    /// Register the padding mask.\n    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {\n        self.mask_pad = Some(mask_pad);\n        self\n    }\n\n    /// Register the attention mask.\n    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {\n        self.mask_attn = Some(mask_attn);\n        self\n    }\n}\nimpl TransformerEncoderConfig {\n    /// Initialize a new [transformer encoder](TransformerEncoder) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {\n        let layers = (0..self.n_layers)\n            .map(|_| TransformerEncoderLayer::new(self, device))\n            .collect::<Vec<_>>();\n\n        TransformerEncoder {\n            layers,\n            d_model: self.d_model,\n            d_ff: self.d_ff,\n            n_heads: self.n_heads,\n            n_layers: self.n_layers,\n            dropout: self.dropout,\n            norm_first: self.norm_first,\n            quiet_softmax: self.quiet_softmax,\n        }\n    }\n}\n\nimpl<B: Backend> TransformerEncoder<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - tensor: `[batch_size, seq_length, d_model]`\n    /// - output: `[batch_size, seq_length, d_model]`\n    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {\n        let mut x = input.tensor;\n\n        for layer in self.layers.iter() {\n            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());\n        }\n\n        x\n    }\n    /// Applies the forward pass on the input tensor using autoregressive cache.\n    ///\n    /// # Shapes\n    ///\n    /// - tensor: `[batch_size, seq_length, d_model]`\n    /// - output: `[batch_size, seq_length, d_model]`\n    pub fn forward_autoregressive_inference(\n        &self,\n        input: TransformerEncoderInput<B>,\n        cache: &mut TransformerEncoderAutoregressiveCache<B>,\n    ) -> Tensor<B, 3> {\n        let mut x = input.tensor;\n\n        for i in 0..self.layers.len() {\n            let layer = self.layers.get(i).unwrap();\n            let cache = cache.layers.get_mut(i).unwrap();\n\n            x = layer.forward_autoregressive_inference(\n                x,\n                input.mask_pad.clone(),\n                input.mask_attn.clone(),\n                cache,\n            );\n        }\n\n        x\n    }\n\n    /// Create an empty autoregressive cache.\n    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {\n        TransformerEncoderAutoregressiveCache::empty(self.layers.len())\n    }\n}\n\n/// Transformer encoder layer module.\n#[derive(Module, Debug)]\npub struct TransformerEncoderLayer<B: Backend> {\n    /// Multi-head self-attention sub-layer.\n    pub mha: MultiHeadAttention<B>,\n    /// Position-wise feed-forward sub-layer.\n    pub pwff: PositionWiseFeedForward<B>,\n    /// Layer normalization applied around the feed-forward sub-layer.\n    pub norm_1: LayerNorm<B>,\n    /// Layer normalization applied around the attention sub-layer.\n    pub norm_2: LayerNorm<B>,\n    /// Dropout module applied to residual connections.\n    pub dropout: Dropout,\n    /// If `true`, apply layer normalization before sub-layers (pre-norm),\n    /// otherwise apply it after (post-norm).\n    pub norm_first: bool,\n}\n\nimpl<B: Backend> TransformerEncoderLayer<B> {\n    /// Create a new transformer encoder layer from the given configuration.\n    pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {\n        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)\n            .with_initializer(config.initializer.clone())\n            .with_dropout(config.dropout)\n            .with_quiet_softmax(config.quiet_softmax)\n            .init(device);\n        let norm_1 = LayerNormConfig::new(config.d_model)\n            .with_epsilon(config.layer_norm_eps)\n            .init(device);\n        let norm_2 = LayerNormConfig::new(config.d_model)\n            .with_epsilon(config.layer_norm_eps)\n            .init(device);\n        let dropout = DropoutConfig::new(config.dropout).init();\n        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)\n            .with_initializer(config.initializer.clone())\n            .with_dropout(config.dropout)\n            .with_activation(config.activation.clone())\n            .init(device);\n\n        Self {\n            mha,\n            norm_1,\n            norm_2,\n            pwff,\n            dropout,\n            norm_first: config.norm_first,\n        }\n    }\n\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, seq_length, d_model]`\n    /// - output: `[batch_size, seq_length, d_model]`\n    pub fn forward(\n        &self,\n        input: Tensor<B, 3>,\n        mask_pad: Option<Tensor<B, 2, Bool>>,\n        mask_attn: Option<Tensor<B, 3, Bool>>,\n    ) -> Tensor<B, 3> {\n        // Multi-head attention residual path.\n        let x = input;\n        let mut residual_path = x.clone();\n\n        // Normalize.\n        if self.norm_first {\n            residual_path = self.norm_2.forward(residual_path)\n        }\n\n        // Multi-head attention.\n        let mut input_mhs = MhaInput::self_attn(residual_path);\n        if let Some(mask_pad) = mask_pad {\n            input_mhs = input_mhs.mask_pad(mask_pad);\n        }\n        if let Some(mask_attn) = mask_attn {\n            input_mhs = input_mhs.mask_attn(mask_attn);\n        }\n        let residual_path = self.mha.forward(input_mhs).context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Feed forward residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            self.norm_1.forward(x.clone())\n        } else {\n            x = self.norm_1.forward(x);\n            x.clone()\n        };\n\n        // Feed forward.\n        let residual_path = self.pwff.forward(residual_path);\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Main path.\n        // Normalize.\n        if !self.norm_first {\n            x = self.norm_2.forward(x)\n        }\n\n        x\n    }\n\n    /// Applies the forward pass using an autoregressive cache.\n    pub fn forward_autoregressive_inference(\n        &self,\n        input: Tensor<B, 3>,\n        mask_pad: Option<Tensor<B, 2, Bool>>,\n        mask_attn: Option<Tensor<B, 3, Bool>>,\n        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,\n    ) -> Tensor<B, 3> {\n        // Multi-head attention residual path.\n        let x = input;\n        let mut residual_path = x.clone();\n\n        // Normalize.\n        if self.norm_first {\n            residual_path = cache\n                .norm_2\n                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))\n        }\n\n        // Multi-head attention.\n        let mut input_mhs = MhaInput::self_attn(residual_path);\n        if let Some(mask_pad) = mask_pad {\n            input_mhs = input_mhs.mask_pad(mask_pad);\n        }\n        if let Some(mask_attn) = mask_attn {\n            input_mhs = input_mhs.mask_attn(mask_attn);\n        }\n        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;\n\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Feed forward residual path.\n        // Normalize.\n        let residual_path = if self.norm_first {\n            cache\n                .norm_1\n                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))\n        } else {\n            x = cache\n                .norm_1\n                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));\n            x.clone()\n        };\n\n        // Feed forward.\n        let residual_path = cache\n            .pwff\n            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));\n        let residual_path = self.dropout.forward(residual_path);\n        let mut x = x + residual_path;\n\n        // Main path.\n        // Normalize.\n        if !self.norm_first {\n            x = cache\n                .norm_2\n                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))\n        }\n\n        x\n    }\n}\n\n/// Autoregressive cache for a single [Transformer Encoder Layer](TransformerEncoderLayer).\npub struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {\n    /// Multi-head attention cache.\n    pub mha: MhaCache<B>,\n    /// Position-wise feed-forward cache.\n    pub pwff: TensorCache<B, 3>,\n    /// First layer norm cache.\n    pub norm_1: TensorCache<B, 3>,\n    /// Second layer norm cache.\n    pub norm_2: TensorCache<B, 3>,\n}\n\nimpl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {\n    /// Create an empty cache.\n    pub fn empty() -> Self {\n        Self {\n            mha: MhaCache::autoregressive(),\n            pwff: TensorCache::empty(),\n            norm_1: TensorCache::empty(),\n            norm_2: TensorCache::empty(),\n        }\n    }\n}\n\n/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder) layer.\n///\n/// To be used during inference when decoding tokens.\npub struct TransformerEncoderAutoregressiveCache<B: Backend> {\n    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,\n}\n\nimpl<B: Backend> TransformerEncoderAutoregressiveCache<B> {\n    fn empty(num_layers: usize) -> Self {\n        Self {\n            layers: (0..num_layers)\n                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())\n                .collect(),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, attention::generate_autoregressive_mask};\n    use burn::tensor::Distribution;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    type FT = FloatElem<TestBackend>;\n\n    #[test]\n    fn test_autoregressive_norm_last() {\n        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];\n        test_autoregressive(\n            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)\n                .with_norm_first(false),\n        )\n    }\n\n    #[test]\n    fn test_autoregressive_norm_first() {\n        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];\n        test_autoregressive(\n            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),\n        )\n    }\n\n    fn test_autoregressive(config: TransformerEncoderConfig) {\n        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];\n        let device = Default::default();\n        let transformer = config.init(&device);\n\n        let tensor = Tensor::<TestBackend, 3>::random(\n            [batch_size, seq_length, d_model],\n            Distribution::Default,\n            &device,\n        );\n        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());\n        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);\n\n        let output_1 = transformer.forward(input);\n        let mut output_2 = Vec::new();\n        let mut cache = transformer.new_autoregressive_cache();\n\n        for i in 1..seq_length + 1 {\n            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);\n            let input = TransformerEncoderInput::new(tensor.clone());\n            let next_tok = transformer\n                .forward_autoregressive_inference(input, &mut cache)\n                .slice([0..batch_size, i - 1..i, 0..d_model]);\n            output_2.push(next_tok);\n        }\n\n        let output_2 = Tensor::cat(output_2, 1);\n\n        output_1\n            .into_data()\n            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());\n    }\n\n    #[test]\n    fn display() {\n        let config = TransformerEncoderConfig::new(2, 4, 2, 3);\n        let transformer = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{transformer}\"),\n            \"TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \\\n            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/transformer/mod.rs",
    "content": "mod decoder;\nmod encoder;\nmod pwff;\n\npub use decoder::*;\npub use encoder::*;\npub use pwff::*;\n"
  },
  {
    "path": "crates/burn-nn/src/modules/transformer/pwff.rs",
    "content": "use burn_core as burn;\n\nuse crate::activation::{Activation, ActivationConfig};\nuse crate::{Dropout, DropoutConfig, Linear, LinearConfig};\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Configuration to create a [position-wise feed-forward](PositionWiseFeedForward) layer using the [init function](PositionWiseFeedForwardConfig::init).\n#[derive(Config, Debug)]\npub struct PositionWiseFeedForwardConfig {\n    /// The size of the input and output features.\n    pub d_model: usize,\n    /// The size of the hidden inner features.\n    pub d_ff: usize,\n    /// The dropout rate. Default: 0.1\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    /// The type of function used to initialize neural network parameters\n    #[config(\n        default = \"Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}\"\n    )]\n    pub initializer: Initializer,\n    /// The activation function used between the two linear layers. Default: Gelu\n    #[config(default = \"ActivationConfig::Gelu\")]\n    pub activation: ActivationConfig,\n}\n\n/// Applies the position-wise feed-forward network to the input tensor from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762v7).\n///\n/// # Params\n///\n/// - linear inner: Linear layer with `d_model` input features and `d_ff` output features.\n/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.\n///\n/// `FFN(x) = max(0, xW1 + b1)W2 + b2`\n///\n/// Should be created using [PositionWiseFeedForwardConfig]\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct PositionWiseFeedForward<B: Backend> {\n    /// Linear layer with `d_model` input features and `d_ff` output features.\n    pub linear_inner: Linear<B>,\n    /// Linear layer with `d_ff` input features and `d_model` output features.\n    pub linear_outer: Linear<B>,\n    /// Dropout layer.\n    pub dropout: Dropout,\n    /// Activation function.\n    pub activation: Activation<B>,\n}\n\nimpl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let [d_model, dff] = self.linear_inner.weight.shape().dims();\n\n        content\n            .add(\"d_model\", &d_model)\n            .add(\"d_ff\", &dff)\n            .add(\"prob\", &self.dropout.prob)\n            .optional()\n    }\n}\n\nimpl PositionWiseFeedForwardConfig {\n    /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {\n        PositionWiseFeedForward {\n            linear_inner: LinearConfig::new(self.d_model, self.d_ff)\n                .with_initializer(self.initializer.clone())\n                .init(device),\n            linear_outer: LinearConfig::new(self.d_ff, self.d_model)\n                .with_initializer(self.initializer.clone())\n                .init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n            activation: self.activation.init(device),\n        }\n    }\n}\n\nimpl<B: Backend> PositionWiseFeedForward<B> {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - tensor: `[batch_size, seq_length, d_model]`\n    /// - output: `[batch_size, seq_length, d_model]`\n    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {\n        let x = self.linear_inner.forward(input);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.linear_outer.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn display() {\n        let config = PositionWiseFeedForwardConfig::new(2, 4);\n        let pwff = config.init::<TestBackend>(&Default::default());\n\n        assert_eq!(\n            alloc::format!(\"{pwff}\"),\n            \"PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/modules/unfold.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\n\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::module::unfold4d;\nuse burn::tensor::ops::UnfoldOptions;\n\n/// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init).\n#[derive(Config, Debug)]\npub struct Unfold4dConfig {\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The stride of the convolution.\n    #[config(default = \"[1, 1]\")]\n    pub stride: [usize; 2],\n    /// Spacing between kernel elements.\n    #[config(default = \"[1, 1]\")]\n    pub dilation: [usize; 2],\n    /// The padding configuration.\n    #[config(default = \"[0, 0]\")]\n    pub padding: [usize; 2],\n}\n\n/// Four-dimensional unfolding.\n///\n/// Should be created with [Unfold4dConfig].\n#[derive(Module, Clone, Debug)]\n#[module(custom_display)]\npub struct Unfold4d {\n    /// The size of the kernel.\n    pub kernel_size: [usize; 2],\n    /// The stride of the convolution.\n    pub stride: [usize; 2],\n    /// Spacing between kernel elements.\n    pub dilation: [usize; 2],\n    /// The padding configuration.\n    pub padding: [usize; 2],\n}\n\nimpl ModuleDisplay for Unfold4d {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"kernel_size\", &alloc::format!(\"{:?}\", &self.kernel_size))\n            .add(\"stride\", &alloc::format!(\"{:?}\", &self.stride))\n            .add(\"dilation\", &alloc::format!(\"{:?}\", &self.dilation))\n            .add(\"padding\", &alloc::format!(\"{:?}\", &self.padding))\n            .optional()\n    }\n}\n\nimpl Unfold4dConfig {\n    /// Initializes a new [Unfold4d] module.\n    pub fn init(&self) -> Unfold4d {\n        Unfold4d {\n            kernel_size: self.kernel_size,\n            stride: self.stride,\n            dilation: self.dilation,\n            padding: self.padding,\n        }\n    }\n}\n\nimpl Unfold4d {\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// See [unfold4d](burn::tensor::module::unfold4d) for more information.\n    ///\n    /// # Shapes\n    ///\n    /// input:   `[batch_size, channels_in, height, width]`\n    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`\n    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 3> {\n        unfold4d(\n            input,\n            self.kernel_size,\n            UnfoldOptions::new(self.stride, self.padding, self.dilation),\n        )\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn display() {\n        let config = Unfold4dConfig::new([3, 3]);\n        let unfold = config.init();\n\n        assert_eq!(\n            alloc::format!(\"{unfold}\"),\n            \"Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/src/padding.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\n\n/// Calculate asymmetric padding for \"same\" convolution.\n/// Returns (start_padding, end_padding) where start is applied first (top/left).\n/// For odd total padding, the extra pad goes to the end (bottom/right) following ONNX convention.\nfn calculate_same_padding(kernel_size: usize, stride: usize, size_in: usize) -> (usize, usize) {\n    let size_out = size_in.div_ceil(stride); // ceil division for same padding\n    let total_padding = if size_out > 0 {\n        let needed = (size_out - 1) * stride + kernel_size;\n        needed.saturating_sub(size_in)\n    } else {\n        0\n    };\n    let pad_start = total_padding / 2;\n    let pad_end = total_padding - pad_start;\n    (pad_start, pad_end)\n}\n\n/// Padding configuration for 1D operators.\n#[derive(Config, Debug, PartialEq)]\npub enum PaddingConfig1d {\n    /// Dynamically calculates padding to ensure output size matches input size.\n    Same,\n    /// No padding applied.\n    Valid,\n    /// Applies explicit padding values.\n    /// Format: (left, right)\n    /// For symmetric padding, use the same value for both (e.g., `Explicit(1, 1)`).\n    Explicit(usize, usize),\n}\n\nimpl PaddingConfig1d {\n    /// Calculate padding as (left, right) pair for 1D operations.\n    /// For `Same` padding, this computes the actual asymmetric padding if needed.\n    pub(crate) fn calculate_padding_1d_pair(\n        &self,\n        length: usize,\n        kernel_size: usize,\n        stride: usize,\n    ) -> (usize, usize) {\n        match self {\n            Self::Valid => (0, 0),\n            Self::Same => calculate_same_padding(kernel_size, stride, length),\n            Self::Explicit(left, right) => (*left, *right),\n        }\n    }\n}\n\n/// Padding configuration for 2D operators.\n#[derive(Config, Debug, PartialEq)]\npub enum PaddingConfig2d {\n    /// Dynamically calculates padding to preserve input dimensions in output.\n    Same,\n    /// No padding applied.\n    Valid,\n    /// Applies explicit padding values.\n    /// Format: (top, left, bottom, right)\n    /// For symmetric padding, use matching values (e.g., `Explicit(1, 1, 1, 1)`).\n    Explicit(usize, usize, usize, usize),\n}\n\nimpl PaddingConfig2d {\n    /// Calculate padding as ((top, bottom), (left, right)) pairs for 2D operations.\n    /// For `Same` padding, this computes the actual asymmetric padding if needed.\n    pub(crate) fn calculate_padding_2d_pairs(\n        &self,\n        height: usize,\n        width: usize,\n        kernel_size: &[usize; 2],\n        stride: &[usize; 2],\n    ) -> ((usize, usize), (usize, usize)) {\n        match self {\n            Self::Valid => ((0, 0), (0, 0)),\n            Self::Same => {\n                let (top, bottom) = calculate_same_padding(kernel_size[0], stride[0], height);\n                let (left, right) = calculate_same_padding(kernel_size[1], stride[1], width);\n                ((top, bottom), (left, right))\n            }\n            Self::Explicit(top, left, bottom, right) => ((*top, *bottom), (*left, *right)),\n        }\n    }\n\n    /// Calculate symmetric padding for 2D operations.\n    /// Returns padding values [height, width] (same for both sides).\n    /// Panics if asymmetric padding is detected.\n    pub(crate) fn calculate_padding_2d(\n        &self,\n        height: usize,\n        width: usize,\n        kernel_size: &[usize; 2],\n        stride: &[usize; 2],\n    ) -> [usize; 2] {\n        let ((top, bottom), (left, right)) =\n            self.calculate_padding_2d_pairs(height, width, kernel_size, stride);\n        if top != bottom || left != right {\n            panic!(\"Asymmetric padding should be handled via calculate_padding_2d_pairs()\")\n        }\n        [top, left]\n    }\n}\n\n/// Padding configuration for 3D operators.\n#[derive(Config, Debug, PartialEq)]\npub enum PaddingConfig3d {\n    /// Dynamically calculates padding to preserve input dimensions in output.\n    Same,\n    /// No padding applied.\n    Valid,\n    /// Applies explicit symmetric padding values.\n    /// Format: (depth, height, width) — same padding on both sides of each dimension.\n    Explicit(usize, usize, usize),\n}\n\nimpl PaddingConfig3d {\n    /// Calculate symmetric padding for 3D operations.\n    /// Returns padding values [depth, height, width] (same for both sides).\n    pub(crate) fn calculate_padding_3d(\n        &self,\n        depth: usize,\n        height: usize,\n        width: usize,\n        kernel_size: &[usize; 3],\n        stride: &[usize; 3],\n    ) -> [usize; 3] {\n        match self {\n            Self::Valid => [0, 0, 0],\n            Self::Same => {\n                let (front, back) = calculate_same_padding(kernel_size[0], stride[0], depth);\n                let (top, bottom) = calculate_same_padding(kernel_size[1], stride[1], height);\n                let (left, right) = calculate_same_padding(kernel_size[2], stride[2], width);\n                if front != back || top != bottom || left != right {\n                    panic!(\n                        \"Asymmetric 3D 'Same' padding is not supported. \\\n                        Use odd kernel sizes for symmetric padding.\"\n                    )\n                }\n                [front, top, left]\n            }\n            Self::Explicit(depth, height, width) => [*depth, *height, *width],\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    // ==================== PaddingConfig1d Tests ====================\n\n    #[test]\n    fn test_padding_config_1d_calculate_pair_valid() {\n        let padding = PaddingConfig1d::Valid;\n        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (0, 0));\n    }\n\n    #[test]\n    fn test_padding_config_1d_calculate_pair_explicit() {\n        let padding = PaddingConfig1d::Explicit(1, 2);\n        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 2));\n    }\n\n    #[test]\n    fn test_padding_config_1d_calculate_pair_same() {\n        let padding = PaddingConfig1d::Same;\n        // kernel=3, stride=1, length=10: total=2, start=1, end=1\n        assert_eq!(padding.calculate_padding_1d_pair(10, 3, 1), (1, 1));\n    }\n\n    // ==================== PaddingConfig2d Tests ====================\n\n    #[test]\n    fn test_padding_config_2d_calculate_pairs_valid() {\n        let padding = PaddingConfig2d::Valid;\n        assert_eq!(\n            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),\n            ((0, 0), (0, 0))\n        );\n    }\n\n    #[test]\n    fn test_padding_config_2d_calculate_pairs_explicit() {\n        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);\n        assert_eq!(\n            padding.calculate_padding_2d_pairs(10, 10, &[3, 3], &[1, 1]),\n            ((1, 3), (2, 4))\n        );\n    }\n\n    #[test]\n    fn test_padding_config_2d_calculate_symmetric_valid() {\n        let padding = PaddingConfig2d::Valid;\n        assert_eq!(\n            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),\n            [0, 0]\n        );\n    }\n\n    #[test]\n    fn test_padding_config_2d_calculate_symmetric_explicit() {\n        let padding = PaddingConfig2d::Explicit(2, 3, 2, 3);\n        assert_eq!(\n            padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]),\n            [2, 3]\n        );\n    }\n\n    #[test]\n    #[should_panic(\n        expected = \"Asymmetric padding should be handled via calculate_padding_2d_pairs\"\n    )]\n    fn test_padding_config_2d_calculate_symmetric_asymmetric_panics() {\n        let padding = PaddingConfig2d::Explicit(1, 2, 3, 4);\n        let _ = padding.calculate_padding_2d(10, 10, &[3, 3], &[1, 1]);\n    }\n\n    // ==================== PaddingConfig3d Tests ====================\n\n    #[test]\n    fn test_padding_config_3d_calculate_valid() {\n        let padding = PaddingConfig3d::Valid;\n        assert_eq!(\n            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),\n            [0, 0, 0]\n        );\n    }\n\n    #[test]\n    fn test_padding_config_3d_calculate_explicit() {\n        let padding = PaddingConfig3d::Explicit(1, 2, 3);\n        assert_eq!(\n            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),\n            [1, 2, 3]\n        );\n    }\n\n    #[test]\n    fn test_padding_config_3d_calculate_same_odd_kernel() {\n        let padding = PaddingConfig3d::Same;\n        // kernel=3, stride=1: total=2, symmetric (1,1) per dim\n        assert_eq!(\n            padding.calculate_padding_3d(10, 10, 10, &[3, 3, 3], &[1, 1, 1]),\n            [1, 1, 1]\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-nn/tests/quantize.rs",
    "content": "use burn_core as burn;\n\nuse burn::module::{Module, Quantizer};\nuse burn::tensor::{\n    Device, Distribution, Tensor, Tolerance,\n    ops::{FloatElem, QuantizedTensor},\n    quantization::{\n        Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantScheme, QuantValue,\n    },\n};\nuse burn_nn::{\n    Linear, LinearConfig,\n    transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},\n};\n\n#[cfg(all(\n    test,\n    not(feature = \"test-wgpu\"),\n    not(feature = \"test-cuda\"),\n    not(feature = \"test-rocm\")\n))]\npub type B = burn_ndarray::NdArray<f32>;\n\n#[cfg(all(test, feature = \"test-wgpu\"))]\n/// Backend for test cases\npub type B = burn_wgpu::Wgpu;\n\n#[cfg(all(test, feature = \"test-cuda\"))]\n/// Backend for test cases\npub type B = burn_cuda::Cuda;\n\n#[cfg(all(test, feature = \"test-rocm\"))]\n/// Backend for test cases\npub type B = burn_rocm::Rocm;\n\nfn should_quantize_module<M: Module<B>, const D: usize, F: Fn(&M) -> Tensor<B, D>>(\n    module: M,\n    scheme: QuantScheme,\n    func: F,\n    tolerance: Tolerance<FloatElem<B>>,\n) {\n    let result = func(&module);\n\n    let calibration = Calibration::MinMax;\n    let mut quantizer = Quantizer {\n        calibration,\n        scheme,\n    };\n    let q_module = module.quantize_weights(&mut quantizer);\n    let q_result = func(&q_module);\n\n    result\n        .into_data()\n        .assert_approx_eq::<f32>(&q_result.into_data(), tolerance);\n}\n\n#[test]\nfn should_quantize_transformer() {\n    let device: Device<B> = Default::default();\n    let transformer: TransformerEncoder<B> =\n        TransformerEncoderConfig::new(128, 256, 2, 2).init(&device);\n    let signal = Tensor::random([2, 32, 128], Distribution::Default, &device);\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([32]))\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.forward(TransformerEncoderInput::new(signal.clone())),\n        Tolerance::rel_abs(1e-2, 2e-2), // slightly higher abs tolerance (permissive: 1e-2)\n    );\n}\n\n#[test]\nfn should_quantize_linear_128_256() {\n    let device: Device<B> = Default::default();\n    let transformer: Linear<B> = LinearConfig::new(128, 256).with_bias(false).init(&device);\n    let signal = Tensor::<B, 2>::random([1, 128], Distribution::Default, &device);\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::Tensor)\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.forward(signal.clone()),\n        Tolerance::permissive(),\n    );\n}\n\n#[test]\nfn should_quantize_linear() {\n    let device: Device<B> = Default::default();\n    let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);\n    let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);\n    // Default scheme should select supported QuantStore default\n    // TODO: set native if dtype is supported by the test backend\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::Tensor)\n        // .with_store(QuantStore::Native)\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.forward(signal.clone()),\n        Tolerance::permissive(),\n    );\n}\n\n#[test]\nfn should_quantize_linear_weights() {\n    let device: Device<B> = Default::default();\n    let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::Tensor)\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.weight.val().dequantize(),\n        Tolerance::permissive(),\n    );\n}\n\n#[test]\nfn should_quantize_linear_blocks() {\n    let device: Device<B> = Default::default();\n    let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);\n    let signal = Tensor::<B, 2>::random([1, 32], Distribution::Default, &device);\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([16]))\n        // .with_store(QuantStore::Native)\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.forward(signal.clone()),\n        Tolerance::permissive(),\n    );\n}\n\n#[test]\nfn should_quantize_linear_weights_blocks() {\n    let device: Device<B> = Default::default();\n    let transformer: Linear<B> = LinearConfig::new(32, 32).with_bias(false).init(&device);\n    let scheme = <QuantizedTensor<B> as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([16]))\n        // .with_store(QuantStore::Native)\n        .with_param(QuantParam::F32);\n\n    should_quantize_module(\n        transformer,\n        scheme,\n        |tr| tr.weight.val().dequantize(),\n        Tolerance::permissive(),\n    );\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/Cargo.toml",
    "content": "[package]\nauthors = [\n  \"nathanielsimard <nathaniel.simard.42@gmail.com>\",\n  \"Dilshod Tadjibaev (@antimora)\",\n]\nedition.workspace = true\nlicense.workspace = true\nname = \"burn-no-std-tests\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-no-std-tests\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = []\n\ntracing = [\n    \"burn/tracing\",\n    \"burn-ndarray/tracing\",\n    \"burn-store/tracing\",\n]\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std **\n\nburn = { path = \"../burn\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", default-features = false }\n\nburn-store = { path = \"../burn-store\", version = \"=0.21.0-pre.2\", default-features = false, features = [\"safetensors\", \"burnpack\"]}\n"
  },
  {
    "path": "crates/burn-no-std-tests/README.md",
    "content": "The `burn-no-std-tests` contains integration tests aimed to check `no_std` compatibility of `burn`, `burn-core`, `burn-tensor` and `burn-ndarray` packages.\n\nCurrently there is only a minimal test that checks if mnist model can be built with `no_std`. More tests should be added to check completeness.\n\nThe continuous integration (CI) should build with additional targets:\n\n * `wasm32-unknown-unknown` - WebAssembly\n * `thumbv7m-none-eabi` - ARM Cortex-M3\n * `thumbv6m-none-eabi` - ARM Cortex-M0+\n\nShell commands to build and test the package:\n\n```sh\n\n# install the new targets if not installed previously\nrustup target add thumbv6m-none-eabi\nrustup target add thumbv7m-none-eabi\nrustup target add wasm32-unknown-unknown\n\n# build for various targets \ncargo build # regular build\ncargo build --target thumbv7m-none-eabi\ncargo build --target wasm32-unknown-unknown\nRUSTFLAGS=\"--cfg portable_atomic_unsafe_assume_single_core\" cargo build --target thumbv6m-none-eabi\n\n# test\ncargo test\n\n ```"
  },
  {
    "path": "crates/burn-no-std-tests/src/burnpack.rs",
    "content": "// Test Burnpack storage in no-std environment\n\nuse burn::{\n    module::Module,\n    nn,\n    tensor::{Tensor, backend::Backend},\n};\n\nuse burn_store::{BurnpackStore, ModuleSnapshot, PathFilter};\n\n/// Simple model for testing Burnpack storage\n#[derive(Module, Debug)]\npub struct TestModel<B: Backend> {\n    linear1: nn::Linear<B>,\n    linear2: nn::Linear<B>,\n    batch_norm: nn::BatchNorm<B>,\n}\n\nimpl<B: Backend> TestModel<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            linear1: nn::LinearConfig::new(10, 20).init(device),\n            linear2: nn::LinearConfig::new(20, 10).init(device),\n            batch_norm: nn::BatchNormConfig::new(10).init(device),\n        }\n    }\n\n    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {\n        let x = self.linear1.forward(x);\n        let x = self.linear2.forward(x);\n        // Apply batch norm (expand to 3D, apply, then squeeze back)\n        let x: Tensor<B, 3> = x.unsqueeze_dim(2);\n        let x = self.batch_norm.forward(x);\n        x.squeeze_dim(2)\n    }\n}\n\n/// Test basic Burnpack save and load in no-std\npub fn test_burnpack_basic<B: Backend>(device: &B::Device) {\n    // Create a model\n    let model = TestModel::<B>::new(device);\n\n    // Save to bytes (no file I/O in no-std)\n    let mut save_store = BurnpackStore::from_bytes(None);\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save model\");\n\n    // Get the serialized bytes\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load from bytes\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut loaded_model = TestModel::<B>::new(device);\n    let result = loaded_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load model\");\n\n    // Verify all tensors were loaded\n    assert!(result.is_success(), \"Should have no errors\");\n    assert!(!result.applied.is_empty(), \"Should have loaded tensors\");\n\n    // Test that the model still works\n    let input = Tensor::<B, 2>::ones([2, 10], device);\n    let _output = loaded_model.forward(input);\n}\n\n/// Test Burnpack with filtering in no-std\npub fn test_burnpack_filtering<B: Backend>(device: &B::Device) {\n    let model = TestModel::<B>::new(device);\n\n    // Save only linear1 weights\n    let filter = PathFilter::new()\n        .with_full_path(\"linear1.weight\")\n        .with_full_path(\"linear1.bias\");\n    let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter);\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save filtered model\");\n\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load with partial loading allowed\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true);\n    let mut partial_model = TestModel::<B>::new(device);\n    let result = partial_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load partial model\");\n\n    // Verify that only linear1 was loaded\n    assert_eq!(result.applied.len(), 2, \"Should have loaded 2 tensors\");\n    assert!(!result.missing.is_empty(), \"Should have missing tensors\");\n}\n\n/// Test Burnpack with metadata in no-std\npub fn test_burnpack_metadata<B: Backend>(device: &B::Device) {\n    let model = TestModel::<B>::new(device);\n\n    // Save with metadata\n    let mut save_store = BurnpackStore::from_bytes(None)\n        .metadata(\"version\", \"1.0.0\")\n        .metadata(\"environment\", \"no-std\")\n        .metadata(\"model_type\", \"test\");\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save model with metadata\");\n\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load and verify it works\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut loaded_model = TestModel::<B>::new(device);\n    let result = loaded_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load model with metadata\");\n\n    assert!(result.is_success(), \"Should load successfully\");\n}\n\n// Note: Key remapping test is omitted as KeyRemapper requires std feature\n\n// Note: Regex filtering test is omitted as with_regex requires std feature\n\n/// Test Burnpack with match_all in no-std\npub fn test_burnpack_match_all<B: Backend>(device: &B::Device) {\n    let model = TestModel::<B>::new(device);\n\n    // Save with match_all (should save everything)\n    let mut save_store = BurnpackStore::from_bytes(None).match_all();\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save model\");\n\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load everything\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut loaded_model = TestModel::<B>::new(device);\n    let result = loaded_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load model\");\n\n    assert!(result.is_success(), \"Should load successfully\");\n    // linear1 (weight, bias) + linear2 (weight, bias) + batch_norm (4 params)\n    assert_eq!(result.applied.len(), 8, \"Should load all 8 tensors\");\n    assert!(result.missing.is_empty(), \"Should have no missing tensors\");\n    assert!(result.unused.is_empty(), \"Should have no unused tensors\");\n}\n\n/// Run all Burnpack no-std tests\npub fn run_all_tests<B: Backend>(device: &B::Device) {\n    test_burnpack_basic::<B>(device);\n    test_burnpack_filtering::<B>(device);\n    test_burnpack_metadata::<B>(device);\n    // test_burnpack_remapping requires KeyRemapper which needs std\n    // test_burnpack_regex_filter requires with_regex which needs std\n    test_burnpack_match_all::<B>(device);\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/src/conv.rs",
    "content": "// Originally copied from the burn/examples/mnist package\n\nuse burn::{\n    config::Config,\n    module::Module,\n    nn,\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct ConvBlock<B: Backend> {\n    conv: nn::conv::Conv2d<B>,\n    pool: nn::pool::MaxPool2d,\n    activation: nn::Gelu,\n}\n\n#[derive(Config, Debug)]\npub struct ConvBlockConfig {\n    channels: [usize; 2],\n    #[config(default = \"[3, 3]\")]\n    kernel_size: [usize; 2],\n}\n\nimpl<B: Backend> ConvBlock<B> {\n    pub fn new(config: &ConvBlockConfig, device: &B::Device) -> Self {\n        let conv = nn::conv::Conv2dConfig::new(config.channels, config.kernel_size)\n            .with_padding(nn::PaddingConfig2d::Same)\n            .init(device);\n        let pool = nn::pool::MaxPool2dConfig::new(config.kernel_size)\n            .with_strides([1, 1])\n            .with_padding(nn::PaddingConfig2d::Same)\n            .init();\n        let activation = nn::Gelu::new();\n\n        Self {\n            conv,\n            pool,\n            activation,\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv.forward(input.clone());\n        let x = self.pool.forward(x);\n        let x = self.activation.forward(x);\n\n        (x + input) / 2.0\n    }\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/src/lib.rs",
    "content": "#![no_std]\n\npub mod burnpack;\npub mod conv;\npub mod mlp;\npub mod model;\npub mod safetensors;\n\nextern crate alloc;\n"
  },
  {
    "path": "crates/burn-no-std-tests/src/mlp.rs",
    "content": "// Originally copied from the burn/examples/mnist package\n\nuse alloc::vec::Vec;\n\nuse burn::{\n    config::Config,\n    module::Module,\n    nn,\n    tensor::{Tensor, backend::Backend},\n};\n\n/// Configuration to create a [Multilayer Perceptron](Mlp) layer.\n#[derive(Config, Debug)]\npub struct MlpConfig {\n    /// The number of layers.\n    #[config(default = 3)]\n    pub num_layers: usize,\n    /// The dropout rate.\n    #[config(default = 0.5)]\n    pub dropout: f64,\n    /// The size of each layer.\n    #[config(default = 256)]\n    pub d_model: usize,\n}\n\n/// Multilayer Perceptron module.\n#[derive(Module, Debug)]\npub struct Mlp<B: Backend> {\n    linears: Vec<nn::Linear<B>>,\n    dropout: nn::Dropout,\n    activation: nn::Relu,\n}\n\nimpl<B: Backend> Mlp<B> {\n    /// Create the module from the given configuration.\n    pub fn new(config: &MlpConfig, device: &B::Device) -> Self {\n        let mut linears = Vec::with_capacity(config.num_layers);\n\n        for _ in 0..config.num_layers {\n            linears.push(nn::LinearConfig::new(config.d_model, config.d_model).init(device));\n        }\n\n        Self {\n            linears,\n            dropout: nn::DropoutConfig::new(0.3).init(),\n            activation: nn::Relu::new(),\n        }\n    }\n\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, d_model]`\n    /// - output: `[batch_size, d_model]`\n    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        let mut x = input;\n\n        for linear in self.linears.iter() {\n            x = linear.forward(x);\n            x = self.dropout.forward(x);\n            x = self.activation.forward(x);\n        }\n\n        x\n    }\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/src/model.rs",
    "content": "// Originally copied from the burn/examples/mnist package\n\nuse crate::{\n    conv::{ConvBlock, ConvBlockConfig},\n    mlp::{Mlp, MlpConfig},\n};\n\nuse burn::{\n    config::Config,\n    module::Module,\n    nn,\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Config, Debug)]\npub struct MnistConfig {\n    #[config(default = 42)]\n    pub seed: u64,\n\n    pub mlp: MlpConfig,\n\n    #[config(default = 784)]\n    pub input_size: usize,\n\n    #[config(default = 10)]\n    pub output_size: usize,\n}\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    mlp: Mlp<B>,\n    conv: ConvBlock<B>,\n    input: nn::Linear<B>,\n    output: nn::Linear<B>,\n    num_classes: usize,\n}\n\nimpl<B: Backend> Model<B> {\n    pub fn new(config: &MnistConfig, device: &B::Device) -> Self {\n        let mlp = Mlp::new(&config.mlp, device);\n        let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init(device);\n        let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init(device);\n        let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1]), device);\n\n        Self {\n            mlp,\n            conv,\n            output,\n            input,\n            num_classes: config.output_size,\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = input.dims();\n\n        let x = input.reshape([batch_size, 1, height, width]).detach();\n        let x = self.conv.forward(x);\n        let x = x.reshape([batch_size, height * width]);\n\n        let x = self.input.forward(x);\n        let x = self.mlp.forward(x);\n\n        self.output.forward(x)\n    }\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/src/safetensors.rs",
    "content": "// Test SafeTensors storage in no-std environment\n\nuse burn::{\n    module::Module,\n    nn,\n    tensor::{Tensor, backend::Backend},\n};\n\nuse burn_store::{ModuleSnapshot, SafetensorsStore};\n\n/// Simple model for testing SafeTensors storage\n#[derive(Module, Debug)]\npub struct TestModel<B: Backend> {\n    linear1: nn::Linear<B>,\n    linear2: nn::Linear<B>,\n}\n\nimpl<B: Backend> TestModel<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            linear1: nn::LinearConfig::new(10, 20).init(device),\n            linear2: nn::LinearConfig::new(20, 10).init(device),\n        }\n    }\n\n    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {\n        let x = self.linear1.forward(x);\n        self.linear2.forward(x)\n    }\n}\n\n/// Test basic SafeTensors save and load in no-std\npub fn test_safetensors_basic<B: Backend>(device: &B::Device) {\n    // Create a model\n    let model = TestModel::<B>::new(device);\n\n    // Save to bytes (no file I/O in no-std)\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save model\");\n\n    // Get the serialized bytes\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load from bytes\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n    let mut loaded_model = TestModel::<B>::new(device);\n    loaded_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load model\");\n\n    // Test that the model still works\n    let input = Tensor::<B, 2>::ones([2, 10], device);\n    let _output = loaded_model.forward(input);\n}\n\n/// Test SafeTensors with filtering in no-std\npub fn test_safetensors_filtering<B: Backend>(device: &B::Device) {\n    let model = TestModel::<B>::new(device);\n\n    // Save only linear1 weights\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .with_full_path(\"linear1.weight\")\n        .with_full_path(\"linear1.bias\");\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save filtered model\");\n\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load with partial loading allowed\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes)).allow_partial(true);\n    let mut partial_model = TestModel::<B>::new(device);\n    let result = partial_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load partial model\");\n\n    // Verify that only linear1 was loaded\n    assert_eq!(result.applied.len(), 2, \"Should have loaded 2 tensors\");\n    assert!(!result.missing.is_empty(), \"Should have missing tensors\");\n}\n\n/// Test SafeTensors with metadata in no-std\npub fn test_safetensors_metadata<B: Backend>(device: &B::Device) {\n    let model = TestModel::<B>::new(device);\n\n    // Save with metadata\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .metadata(\"version\", \"1.0.0\")\n        .metadata(\"environment\", \"no-std\");\n    model\n        .save_into(&mut save_store)\n        .expect(\"Failed to save model with metadata\");\n\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load and verify it works\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n    let mut loaded_model = TestModel::<B>::new(device);\n    loaded_model\n        .load_from(&mut load_store)\n        .expect(\"Failed to load model with metadata\");\n}\n\n/// Run all SafeTensors no-std tests\npub fn run_all_tests<B: Backend>(device: &B::Device) {\n    test_safetensors_basic::<B>(device);\n    test_safetensors_filtering::<B>(device);\n    test_safetensors_metadata::<B>(device);\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/tests/burnpack_tests.rs",
    "content": "extern crate alloc;\n\n#[test]\nfn test_burnpack_no_std() {\n    use burn_ndarray::NdArray;\n    use burn_no_std_tests::burnpack;\n    type Backend = NdArray<f32>;\n    let device = Default::default();\n\n    // Run all Burnpack tests\n    burnpack::run_all_tests::<Backend>(&device);\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/tests/safetensors_tests.rs",
    "content": "extern crate alloc;\n\n#[test]\nfn test_safetensors_no_std() {\n    use burn_ndarray::NdArray;\n    use burn_no_std_tests::safetensors;\n    type Backend = NdArray<f32>;\n    let device = Default::default();\n\n    // Run all SafeTensors tests\n    safetensors::run_all_tests::<Backend>(&device);\n}\n"
  },
  {
    "path": "crates/burn-no-std-tests/tests/test_integration.rs",
    "content": "#![no_std] // Must keep it for testing\n\nuse burn_no_std_tests::mlp::*;\nuse burn_no_std_tests::model::*;\n\nuse burn::tensor::{Distribution, Tensor, backend::Backend};\nuse burn_ndarray::NdArray;\n\n#[test]\nfn test_mnist_model_with_random_input() {\n    type Backend = NdArray<f32>;\n\n    // Model configurations\n    let device = Default::default();\n    let mlp_config = MlpConfig::new();\n    let mnist_config = MnistConfig::new(mlp_config);\n    let mnist_model: Model<Backend> = Model::new(&mnist_config, &device);\n\n    // Pass a fixed seed for random, otherwise a build generated random seed is used\n    Backend::seed(&device, mnist_config.seed);\n\n    // Some random input\n    let input_shape = [1, 28, 28];\n    let input = Tensor::<Backend, 3>::random(input_shape, Distribution::Default, &device);\n\n    // Run through the model\n    let output = mnist_model.forward(input);\n\n    assert_eq!(&*output.shape(), [1, 10]);\n    assert!(output.to_data().iter::<f32>().all(|x| x <= 1.0));\n}\n"
  },
  {
    "path": "crates/burn-optim/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Optimizer building blocks for the Burn deep learning framework\"\ndocumentation = \"https://docs.rs/burn-optim\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-optim\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-optim\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\n    \"std\",\n    \"burn-core/default\",\n]\ndoc = [\n    \"std\",\n    # Doc features\n    \"burn-core/doc\",\n]\nstd = [\n    \"burn-core/std\",\n    \"num-traits/std\",\n    \"serde/std\",\n    \"log\",\n]\ntracing = [\n    \"burn-collective?/tracing\",\n    \"burn-core/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-fusion?/tracing\",\n    \"burn-remote?/tracing\",\n    \"burn-rocm?/tracing\",\n    \"burn-router?/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-wgpu?/tracing\",\n]\n\ncollective = [\"burn-collective\"]\n\ntest-cuda = [\n    \"burn-cuda/default\",\n] # To use cuda during testing, default uses ndarray.\ntest-rocm = [\n    \"burn-rocm/default\",\n] # To use hip during testing, default uses ndarray.\ntest-tch = [\n    \"burn-tch/default\",\n] # To use tch during testing, default uses ndarray.\ntest-wgpu = [\n    \"burn-wgpu/default\",\n] # To use wgpu during testing, default uses ndarray.\ntest-vulkan = [\n    \"test-wgpu\",\n    \"burn-wgpu/vulkan\",\n] # To use wgpu-spirv during testing, default uses ndarray.\ntest-metal = [\n    \"test-wgpu\",\n    \"burn-wgpu/metal\",\n] # To use wgpu-spirv during testing, default uses ndarray.\n\n# Memory checks are disabled by default\ntest-memory-checks = [\"burn-fusion/memory-checks\"]\n\n[dependencies]\n\n# ** Please make sure all dependencies support no_std when std is disabled **\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-collective = { path = \"../burn-collective\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\n\nnum-traits = { workspace = true }\nderive-new = { workspace = true }\nlog = { workspace = true, optional = true }\nserde = { workspace = true, features = [\"derive\"] }\n\n# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)\nhashbrown = { workspace = true, features = [\"serde\"] } # no_std compatible\n\n# FOR TESTING\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-rocm = { path = \"../burn-rocm\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-remote = { path = \"../burn-remote\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", default-features = false, optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\n\n[dev-dependencies]\nburn-nn = { path = \"../burn-nn\", version = \"=0.21.0-pre.2\" }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\" }\nrstest = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-optim/README.md",
    "content": "# Burn Optimizers\n\nCore building blocks for Burn optimizers."
  },
  {
    "path": "crates/burn-optim/src/grad_clipping/base.rs",
    "content": "use burn_core as burn;\n\nuse burn::tensor::backend::Backend;\nuse burn::{config::Config, tensor::Tensor};\n\n/// Gradient Clipping provides a way to mitigate exploding gradients\n#[derive(Config, Debug)]\npub enum GradientClippingConfig {\n    /// Clip the gradient by value.\n    Value(f32),\n\n    /// Clip the gradient by norm.\n    Norm(f32),\n}\n\nimpl GradientClippingConfig {\n    /// Initialize the gradient clipping.\n    ///\n    /// # Returns\n    ///\n    /// The gradient clipping.\n    pub fn init(&self) -> GradientClipping {\n        match self {\n            GradientClippingConfig::Value(val) => GradientClipping::Value(*val),\n            GradientClippingConfig::Norm(val) => GradientClipping::Norm(*val),\n        }\n    }\n}\n\n/// Gradient Clipping provides a way to mitigate exploding gradients\n/// by clipping every component of the gradient by value or by norm during\n/// backpropagation.\n#[derive(Clone)]\npub enum GradientClipping {\n    /// Clip the gradient by value.\n    Value(f32),\n\n    /// Clip the gradient by norm.\n    Norm(f32),\n}\n\nimpl GradientClipping {\n    /// Clip the gradient.\n    ///\n    /// # Arguments\n    ///\n    /// * `grad` - The gradient to clip.\n    ///\n    /// # Returns\n    ///\n    /// The clipped gradient.\n    pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {\n        match self {\n            GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),\n            GradientClipping::Norm(max_norm) => self.clip_by_norm(grad, *max_norm),\n        }\n    }\n\n    fn clip_by_value<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        threshold: f32,\n    ) -> Tensor<B, D> {\n        let greater_mask = grad.clone().greater_elem(threshold);\n        let lower_mask = grad.clone().lower_elem(-threshold);\n\n        let clipped_grad = grad.mask_fill(greater_mask, threshold);\n\n        clipped_grad.mask_fill(lower_mask, -threshold)\n    }\n\n    fn clip_by_norm<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        threshold: f32,\n    ) -> Tensor<B, D> {\n        let norm = Self::l2_norm(grad.clone());\n        let clip_coef = threshold / norm.add_scalar(1e-6); // avoid div by zero\n        let clip_coef_clamped = clip_coef.clamp_max(1.0);\n        grad.mul(clip_coef_clamped.unsqueeze())\n    }\n\n    fn l2_norm<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, 1> {\n        let squared = tensor.square();\n        let sum = squared.sum();\n        sum.sqrt()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n    use burn::tensor::Tensor;\n\n    #[test]\n    fn test_clip_by_value() {\n        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &Default::default(),\n        );\n\n        let clipped_gradient = GradientClipping::Value(0.5).clip_gradient(gradient);\n        let clipped_gradient_data = clipped_gradient.into_data();\n\n        for value in clipped_gradient_data.iter::<f32>() {\n            assert!(value <= 0.5);\n        }\n    }\n\n    #[test]\n    fn test_clip_by_norm() {\n        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &Default::default(),\n        );\n\n        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient);\n        let clipped_gradient_data = clipped_gradient.into_data();\n\n        for value in clipped_gradient_data.iter::<f32>() {\n            assert!(value <= 0.88);\n        }\n    }\n    #[test]\n    fn test_clip_by_norm_no_clipping() {\n        let gradient: Tensor<TestBackend, 2> = Tensor::from_floats(\n            [[0.3, 0.4, 0.5, 0.2], [0.1, 0.6, 0.3, 0.4]],\n            &Default::default(),\n        );\n\n        let clipped_gradient = GradientClipping::Norm(2.2).clip_gradient(gradient.clone());\n\n        clipped_gradient\n            .into_data()\n            .assert_eq(&gradient.into_data(), true);\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/grad_clipping/mod.rs",
    "content": "mod base;\npub use base::*;\n"
  },
  {
    "path": "crates/burn-optim/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![recursion_limit = \"256\"]\n\n//! Burn optimizers.\n\n#[macro_use]\nextern crate derive_new;\n\nextern crate alloc;\n\n/// Optimizer module.\npub mod optim;\npub use optim::*;\n\n/// Gradient clipping module.\npub mod grad_clipping;\n\n/// Learning rate scheduler module.\n#[cfg(feature = \"std\")]\npub mod lr_scheduler;\n\n/// Type alias for the learning rate.\n///\n/// LearningRate also implements [learning rate scheduler](crate::lr_scheduler::LrScheduler) so it\n/// can be used for constant learning rate.\npub type LearningRate = f64; // We could potentially change the type.\n\n/// Backend for test cases\n#[cfg(all(\n    test,\n    not(feature = \"test-tch\"),\n    not(feature = \"test-wgpu\"),\n    not(feature = \"test-cuda\"),\n    not(feature = \"test-rocm\")\n))]\npub type TestBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(all(test, feature = \"test-tch\"))]\n/// Backend for test cases\npub type TestBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(all(test, feature = \"test-wgpu\"))]\n/// Backend for test cases\npub type TestBackend = burn_wgpu::Wgpu;\n\n#[cfg(all(test, feature = \"test-cuda\"))]\n/// Backend for test cases\npub type TestBackend = burn_cuda::Cuda;\n\n#[cfg(all(test, feature = \"test-rocm\"))]\n/// Backend for test cases\npub type TestBackend = burn_rocm::Rocm;\n\n/// Backend for autodiff test cases\n#[cfg(test)]\npub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;\n\n#[cfg(all(test, feature = \"test-memory-checks\"))]\nmod tests {\n    burn_fusion::memory_checks!();\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/base.rs",
    "content": "pub(super) use alloc::string::String;\nuse burn_core as burn;\n\nuse burn::record::Record;\nuse burn::tensor::backend::Backend;\n\nuse crate::LearningRate;\n\n/// Learning rate scheduler defines how the learning rate will evolve during training.\npub trait LrScheduler: Clone + Send + Sync {\n    /// Scheduler associative type to be used when saving and loading the state.\n    type Record<B: Backend>: Record<B>;\n\n    /// Perform the scheduler step, potentially updating its state, and returning the effective\n    /// learning rate.\n    fn step(&mut self) -> LearningRate;\n\n    /// Get the current state of the scheduler as a [record](Record).\n    fn to_record<B: Backend>(&self) -> Self::Record<B>;\n\n    /// Load the state of the scheduler as a [record](Record).\n    fn load_record<B: Backend>(self, record: Self::Record<B>) -> Self;\n}\n\n#[cfg(test)]\npub(super) mod test_utils {\n    use super::*;\n    use crate::TestBackend;\n\n    // A small tolerance for learning rate comparisons. Depending on how learning rates are\n    // computed, floating-point arithmetic error might exceed f64::EPSILON, so a larger value is\n    // used here.\n    const LOOSE_EPSILON: LearningRate = 1e-10;\n\n    pub fn check_lr_sequence<I, S>(mut scheduler: S, expected_lrs: I)\n    where\n        I: IntoIterator<Item = LearningRate>,\n        S: LrScheduler,\n    {\n        expected_lrs\n            .into_iter()\n            .enumerate()\n            .for_each(|(i, expected)| {\n                let lr = scheduler.step();\n                assert!(\n                    (lr - expected).abs() < LOOSE_EPSILON,\n                    \"Scheduled learning rate {lr} is not approximately equal to the expected value \\\n                     {expected} at step {i}\",\n                );\n            });\n    }\n\n    // save_at_step is the number of steps to run the scheduler before saving and loading back its\n    // state.\n    pub fn check_save_load<S>(mut scheduler: S, save_at_step: usize)\n    where\n        S: Clone + LrScheduler,\n    {\n        let mut truth = scheduler.clone();\n        // Consume some steps before saving and loading back\n        (0..save_at_step).for_each(|_| {\n            truth.step();\n            scheduler.step();\n        });\n        let rec = scheduler.to_record::<TestBackend>();\n        scheduler = scheduler.load_record::<TestBackend>(rec);\n\n        // Validate that the scheduler resumes from where it left off.\n        compare_steps(&mut scheduler, &mut truth, save_at_step);\n    }\n\n    // Check if two schedulers produce the same learning rate sequences over the specified number of\n    // steps.\n    pub fn compare_steps<S: LrScheduler>(a: &mut S, b: &mut S, num_steps: usize) {\n        (0..num_steps).for_each(|i| {\n            let lr_a = a.step();\n            let lr_b = b.step();\n            assert!(\n                (lr_a - lr_b).abs() < LOOSE_EPSILON,\n                \"The two learning rates ({lr_a}, {lr_b}) at position {i} in the remaining \\\n                 sequences are not approximately equal\",\n            );\n        });\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/composed.rs",
    "content": "use burn_core as burn;\n\nuse super::cosine::{CosineAnnealingLrScheduler, CosineAnnealingLrSchedulerConfig};\nuse super::exponential::{ExponentialLrScheduler, ExponentialLrSchedulerConfig};\nuse super::linear::{LinearLrScheduler, LinearLrSchedulerConfig};\nuse super::noam::{NoamLrScheduler, NoamLrSchedulerConfig};\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\n\nuse burn::config::Config;\nuse burn::record::Record;\nuse burn::tensor::backend::Backend;\n\n/// Compose multiple [learning rate schedulers](LrScheduler) together.\n#[derive(Config, Debug)]\npub struct ComposedLrSchedulerConfig {\n    #[config(default = \"Vec::new()\")]\n    schedulers: Vec<LrSchedulerConfig>,\n    #[config(default = \"SchedulerReduction::Prod\")]\n    reduction: SchedulerReduction,\n}\n\n/// Compose multiple [learning rate schedulers](LrScheduler) together.\n#[derive(Clone)]\npub struct ComposedLrScheduler {\n    schedulers: Vec<LrSchedulerItem>,\n    reduction: SchedulerReduction,\n}\n\n/// Defines how the learning rates generated by the schedulers are combined.\n#[derive(Config, Debug, Copy)]\npub enum SchedulerReduction {\n    /// All learning rates are averaged.\n    Avg,\n    /// All learning rates are summed.\n    Sum,\n    /// All learning rates are multiplied.\n    Prod,\n}\n\nimpl ComposedLrSchedulerConfig {\n    /// Initialize the learning rate scheduler.\n    pub fn init(&self) -> Result<ComposedLrScheduler, String> {\n        let mut schedulers = Vec::with_capacity(self.schedulers.len());\n        for config in self.schedulers.iter() {\n            let config = match config {\n                LrSchedulerConfig::Linear(config) => LrSchedulerItem::Linear(config.init()?),\n                LrSchedulerConfig::Cosine(config) => LrSchedulerItem::Cosine(config.init()?),\n                LrSchedulerConfig::Exponential(config) => {\n                    LrSchedulerItem::Exponential(config.init()?)\n                }\n                LrSchedulerConfig::Noam(config) => LrSchedulerItem::Noam(config.init()?),\n            };\n            schedulers.push(config);\n        }\n\n        Ok(ComposedLrScheduler {\n            schedulers,\n            reduction: self.reduction,\n        })\n    }\n\n    /// Appends a [linear scheduler](LinearLrScheduler).\n    pub fn linear(mut self, config: LinearLrSchedulerConfig) -> Self {\n        self.schedulers.push(LrSchedulerConfig::Linear(config));\n        self\n    }\n\n    /// Appends a [cosine scheduler](ComposedLrSchedulerConfig).\n    pub fn cosine(mut self, config: CosineAnnealingLrSchedulerConfig) -> Self {\n        self.schedulers.push(LrSchedulerConfig::Cosine(config));\n        self\n    }\n\n    /// Appends an [exponential scheduler](ExponentialLrScheduler).\n    pub fn exponential(mut self, config: ExponentialLrSchedulerConfig) -> Self {\n        self.schedulers.push(LrSchedulerConfig::Exponential(config));\n        self\n    }\n\n    /// Appends a [noam scheduler](NoamLrScheduler).\n    pub fn noam(mut self, config: NoamLrSchedulerConfig) -> Self {\n        self.schedulers.push(LrSchedulerConfig::Noam(config));\n        self\n    }\n}\n\n#[derive(Config, Debug)]\nenum LrSchedulerConfig {\n    Linear(LinearLrSchedulerConfig),\n    Cosine(CosineAnnealingLrSchedulerConfig),\n    Exponential(ExponentialLrSchedulerConfig),\n    Noam(NoamLrSchedulerConfig),\n}\n\n#[derive(Clone)]\nenum LrSchedulerItem {\n    Linear(LinearLrScheduler),\n    Cosine(CosineAnnealingLrScheduler),\n    Exponential(ExponentialLrScheduler),\n    Noam(NoamLrScheduler),\n}\n\n#[derive(Record)]\n/// Record item for the [composed learning rate scheduler](ComposedLrScheduler).\npub enum LrSchedulerRecord<B: Backend> {\n    /// The linear variant.\n    Linear(<LinearLrScheduler as LrScheduler>::Record<B>),\n    /// The cosine variant.\n    Cosine(<CosineAnnealingLrScheduler as LrScheduler>::Record<B>),\n    /// The exponential variant.\n    Exponential(<ExponentialLrScheduler as LrScheduler>::Record<B>),\n    /// The noam variant.\n    Noam(<NoamLrScheduler as LrScheduler>::Record<B>),\n}\n\n#[derive(Record)]\n/// Records for the [composed learning rate scheduler](ComposedLrScheduler).\npub struct ComposedLrSchedulerRecord<B: Backend> {\n    schedulers: Vec<LrSchedulerRecord<B>>,\n}\n\nimpl LrScheduler for ComposedLrScheduler {\n    type Record<B: Backend> = ComposedLrSchedulerRecord<B>;\n\n    fn step(&mut self) -> LearningRate {\n        let mut step = match self.reduction {\n            SchedulerReduction::Avg => 0.0,\n            SchedulerReduction::Sum => 0.0,\n            SchedulerReduction::Prod => 1.0,\n        };\n        let num_scheduler = self.schedulers.len() as f64;\n\n        for lr in self.schedulers.iter_mut().map(|s| match s {\n            LrSchedulerItem::Linear(item) => item.step(),\n            LrSchedulerItem::Cosine(item) => item.step(),\n            LrSchedulerItem::Exponential(item) => item.step(),\n            LrSchedulerItem::Noam(item) => item.step(),\n        }) {\n            step = match self.reduction {\n                SchedulerReduction::Avg => step + (lr / num_scheduler),\n                SchedulerReduction::Sum => step + lr,\n                SchedulerReduction::Prod => step * lr,\n            }\n        }\n\n        step\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        ComposedLrSchedulerRecord::<B> {\n            schedulers: self\n                .schedulers\n                .iter()\n                .map(|s| match s {\n                    LrSchedulerItem::Linear(item) => {\n                        LrSchedulerRecord::Linear(item.to_record::<B>())\n                    }\n                    LrSchedulerItem::Cosine(item) => {\n                        LrSchedulerRecord::Cosine(item.to_record::<B>())\n                    }\n                    LrSchedulerItem::Exponential(item) => {\n                        LrSchedulerRecord::Exponential(item.to_record::<B>())\n                    }\n                    LrSchedulerItem::Noam(item) => LrSchedulerRecord::Noam(item.to_record::<B>()),\n                })\n                .collect(),\n        }\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.schedulers = self\n            .schedulers\n            .into_iter()\n            .zip(record.schedulers)\n            .map(|scheduler| match scheduler {\n                (LrSchedulerItem::Linear(item), LrSchedulerRecord::Linear(record)) => {\n                    LrSchedulerItem::Linear(item.load_record::<B>(record))\n                }\n                (LrSchedulerItem::Cosine(item), LrSchedulerRecord::Cosine(record)) => {\n                    LrSchedulerItem::Cosine(item.load_record::<B>(record))\n                }\n                (LrSchedulerItem::Exponential(item), LrSchedulerRecord::Exponential(record)) => {\n                    LrSchedulerItem::Exponential(item.load_record::<B>(record))\n                }\n                (LrSchedulerItem::Noam(item), LrSchedulerRecord::Noam(record)) => {\n                    LrSchedulerItem::Noam(item.load_record::<B>(record))\n                }\n                _ => panic!(\"Invalid state\"),\n            })\n            .collect();\n\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/constant.rs",
    "content": "use burn_core as burn;\n\nuse burn::tensor::backend::Backend;\n\nuse super::LrScheduler;\nuse crate::LearningRate;\n\n/// Constant learning rate implementing [learning rate scheduler](LrScheduler).\n///\n/// # Notes\n///\n/// You can also use [learning rate](LearningRate) which the same effect.\n#[derive(new, Clone, Debug)]\npub struct ConstantLr {\n    lr: LearningRate,\n}\n\nimpl From<LearningRate> for ConstantLr {\n    fn from(lr: LearningRate) -> Self {\n        Self { lr }\n    }\n}\n\nimpl LrScheduler for ConstantLr {\n    type Record<B: Backend> = ();\n\n    fn step(&mut self) -> LearningRate {\n        self.lr\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {}\n\n    fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {\n        self\n    }\n}\n\nimpl LrScheduler for LearningRate {\n    type Record<B: Backend> = ();\n\n    fn step(&mut self) -> LearningRate {\n        *self\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {}\n\n    fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/cosine.rs",
    "content": "use burn_core as burn;\n\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\n\n/// The configuration for creating a [Cosine Annealing learning rate scheduler with warm\n/// restarts](CosineAnnealingLrScheduler).\n///\n/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by\n/// following a cosine function. After `num_iters` iterations, the learning rate is reset to\n/// `initial_lr`.\n#[derive(Config, Debug)]\npub struct CosineAnnealingLrSchedulerConfig {\n    // The initial learning rate.\n    initial_lr: LearningRate,\n    // The final learning rate.\n    #[config(default = 0.0)]\n    min_lr: LearningRate,\n    // The number of iterations between two restarts. The two restart iterations themselves are not\n    // included.\n    num_iters: usize,\n}\n\nimpl CosineAnnealingLrSchedulerConfig {\n    /// Initializes a [Cosine learning rate scheduler](CosineAnnealingLrScheduler).\n    ///\n    /// # Errors\n    ///\n    /// An error will be returned if any of the following conditions is true:\n    ///\n    /// * `initial_lr` is out of range (0.0, 1.0]\n    /// * `min_lr` is out of range [0.0, `initial_lr`]\n    /// * `num_iters` is 0\n    pub fn init(&self) -> Result<CosineAnnealingLrScheduler, String> {\n        if self.initial_lr <= 0. || self.initial_lr > 1. {\n            return Err(\"Initial learning rate must be greater than 0 and at most 1\".into());\n        }\n        if self.min_lr < 0.0 || self.min_lr > self.initial_lr {\n            return Err(\n                \"Minimum learning rate must be at least 0 and at most equal to the initial \\\n                 learning rate\"\n                    .into(),\n            );\n        }\n        if self.num_iters == 0 {\n            return Err(\"Number of iterations must be at least 1\".into());\n        }\n\n        Ok(CosineAnnealingLrScheduler {\n            min_lr: self.min_lr,\n            max_lr: self.initial_lr,\n            num_iters: self.num_iters,\n            current_iter: usize::MAX,\n        })\n    }\n}\n\n/// A Cosine Annealing learning rate scheduler.\n///\n/// This scheduler is described in [SGDR: Stochastic Gradient Descent with Warm\n/// Restarts](https://arxiv.org/abs/1608.03983). See [CosineAnnealingLrSchedulerConfig] for more\n/// information.\n#[derive(Clone, Copy, Debug)]\npub struct CosineAnnealingLrScheduler {\n    min_lr: LearningRate,\n    max_lr: LearningRate,\n    num_iters: usize,\n    current_iter: usize,\n}\n\nimpl LrScheduler for CosineAnnealingLrScheduler {\n    type Record<B: Backend> = usize;\n\n    fn step(&mut self) -> LearningRate {\n        // Make current_iter overflow from usize::MAX to 0 to get the initial learning rate on the\n        // first call. We could've used i64 with an initial value -1, but keeping it in usize saves\n        // us from some type casting here.\n        self.current_iter = self.current_iter.wrapping_add(1) % (self.num_iters + 1);\n        self.min_lr\n            + 0.5\n                * (self.max_lr - self.min_lr)\n                * (1.0\n                    + (self.current_iter as f64 / self.num_iters as f64 * std::f64::consts::PI)\n                        .cos())\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        self.current_iter\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.current_iter = record;\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::super::test_utils;\n    use super::*;\n\n    #[test]\n    fn config_initial_lr_too_low() {\n        let r = CosineAnnealingLrSchedulerConfig::new(0., 10).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_initial_lr_too_high() {\n        let r = CosineAnnealingLrSchedulerConfig::new(1.5, 10).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_min_lr_too_low() {\n        let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)\n            .with_min_lr(-0.1)\n            .init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Minimum learning rate must be at least 0 and at most equal to the initial learning \\\n             rate\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_min_lr_too_high() {\n        let r = CosineAnnealingLrSchedulerConfig::new(0.5, 10)\n            .with_min_lr(0.6)\n            .init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Minimum learning rate must be at least 0 and at most equal to the initial learning \\\n             rate\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_num_iters_too_low() {\n        let r = CosineAnnealingLrSchedulerConfig::new(0.5, 0).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Number of iterations must be at least 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn test_lr_change() {\n        const INITIAL_LR: LearningRate = 0.5;\n        const MIN_LR: LearningRate = 0.1;\n\n        let scheduler = CosineAnnealingLrSchedulerConfig::new(INITIAL_LR, 2)\n            .with_min_lr(MIN_LR)\n            .init()\n            .unwrap();\n        let expected_lrs = [\n            INITIAL_LR,                  // cos(0)\n            (INITIAL_LR + MIN_LR) * 0.5, // cos(PI/2)\n            MIN_LR,                      // cos(PI)\n            INITIAL_LR,                  // restart\n        ];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_save_and_load() {\n        const NUM_ITERS: usize = 9;\n        let scheduler = CosineAnnealingLrSchedulerConfig::new(1.0, NUM_ITERS)\n            .init()\n            .unwrap();\n        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/exponential.rs",
    "content": "use burn_core as burn;\n\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\n\n/// The configuration for creating an [exponential learning rate scheduler](ExponentialLrScheduler).\n///\n/// This scheduler returns the learning rate `initial_lr` at the first step, then multiplies it by\n/// a constant `gamma` at every iteration. At any iteration `i` (which starts from 0), the learning\n/// rate is given by `initial_lr * gamma^i`.\n#[derive(Config, Debug)]\npub struct ExponentialLrSchedulerConfig {\n    // The initial learning rate.\n    initial_lr: LearningRate,\n    // The constant that the learning rate is multiplied by on each iteration.\n    gamma: f64,\n}\n\nimpl ExponentialLrSchedulerConfig {\n    /// Initializes a [exponential learning rate scheduler](ExponentialLrScheduler).\n    ///\n    /// # Errors\n    ///\n    /// An error will be returned if any of the following conditions is true:\n    ///\n    /// * `initial_lr` is out of range (0.0, 1.0]\n    /// * `gamma` is out of range (0.0, 1.0]\n    pub fn init(&self) -> Result<ExponentialLrScheduler, String> {\n        if self.initial_lr <= 0. || self.initial_lr > 1. {\n            return Err(\"Initial learning rate must be greater than 0 and at most 1\".into());\n        }\n        if self.gamma <= 0. || self.gamma > 1. {\n            return Err(\"Gamma must be greater than 0 and at most 1\".into());\n        }\n\n        Ok(ExponentialLrScheduler {\n            // Such an initial value eliminates the need for special-case handling of the first\n            // learning rate.\n            previous_lr: self.initial_lr / self.gamma,\n            gamma: self.gamma,\n        })\n    }\n}\n\n/// A exponential learning rate scheduler.\n///\n/// See [ExponentialLrSchedulerConfig] for more information.\n#[derive(Clone, Copy, Debug)]\npub struct ExponentialLrScheduler {\n    // The previous iteration's learning rate.\n    previous_lr: LearningRate,\n    // The constant that the learning rate is multiplied by on each iteration.\n    gamma: f64,\n}\n\nimpl LrScheduler for ExponentialLrScheduler {\n    type Record<B: Backend> = LearningRate;\n\n    fn step(&mut self) -> LearningRate {\n        self.previous_lr *= self.gamma;\n        self.previous_lr\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        self.previous_lr\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.previous_lr = record;\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::super::test_utils;\n    use super::*;\n\n    #[test]\n    fn config_initial_lr_too_low() {\n        let r = ExponentialLrSchedulerConfig::new(0., 0.5).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_initial_lr_too_high() {\n        let r = ExponentialLrSchedulerConfig::new(1.5, 0.5).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_gamma_too_low() {\n        let r = ExponentialLrSchedulerConfig::new(0.5, 0.0).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Gamma must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_gamma_too_high() {\n        let r = ExponentialLrSchedulerConfig::new(0.5, 1.5).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Gamma must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn test_lr_change() {\n        let scheduler = ExponentialLrSchedulerConfig::new(0.8, 0.1).init().unwrap();\n        let expected_lrs = [0.8, 0.08, 0.008, 0.0008, 0.00008];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_save_and_load() {\n        let scheduler = ExponentialLrSchedulerConfig::new(0.083, 0.3)\n            .init()\n            .unwrap();\n        test_utils::check_save_load(scheduler, 7);\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/linear.rs",
    "content": "use burn_core as burn;\n\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\n\n/// The configuration for creating a [linear learning rate scheduler](LinearLrScheduler).\n///\n/// This scheduler returns the learning rate `initial_lr` at the first step, then changes it by a\n/// constant amount on each iteration until reaching a final learning rate `final_lr`. The\n/// `num_iters` parameter controls how many iterations are needed to go from `initial_lr` to\n/// `final_lr`.\n#[derive(Config, Debug)]\npub struct LinearLrSchedulerConfig {\n    // The initial learning rate.\n    initial_lr: LearningRate,\n    // The final learning rate.\n    final_lr: LearningRate,\n    // The number of iterations before reaching the final learning rate.\n    num_iters: usize,\n}\n\nimpl LinearLrSchedulerConfig {\n    /// Initializes a [linear learning rate scheduler](LinearLrScheduler).\n    ///\n    /// # Errors\n    ///\n    /// An error will be returned if any of the following conditions is true:\n    ///\n    /// * `initial_lr` is out of range (0.0, 1.0]\n    /// * `final_lr` is out of range [0.0, 1.0]\n    /// * `num_iters` is 0\n    pub fn init(&self) -> Result<LinearLrScheduler, String> {\n        if self.initial_lr <= 0. || self.initial_lr > 1. {\n            return Err(\"Initial learning rate must be greater than 0 and at most 1\".into());\n        }\n        if self.final_lr < 0. || self.final_lr > 1. {\n            return Err(\"Final learning rate must be at least 0 and at most 1\".into());\n        }\n        if self.num_iters == 0 {\n            return Err(\"Number of iterations must be at least 1\".into());\n        }\n\n        Ok(LinearLrScheduler {\n            final_lr: self.final_lr,\n            step_size: (self.final_lr - self.initial_lr) / self.num_iters as f64,\n            remaining_iters: self.num_iters + 1,\n        })\n    }\n}\n\n/// A linear learning rate scheduler.\n///\n/// See [LinearLrSchedulerConfig] for more information.\n#[derive(Clone, Copy, Debug)]\npub struct LinearLrScheduler {\n    // The final learning rate after the linear changing process stops.\n    final_lr: LearningRate,\n    // The amount that the learning rate changes by on each iteration.\n    step_size: f64,\n    // The number of iterations left before reaching the final learning rate.\n    remaining_iters: usize,\n}\n\nimpl LrScheduler for LinearLrScheduler {\n    type Record<B: Backend> = usize;\n\n    fn step(&mut self) -> LearningRate {\n        self.remaining_iters -= (self.remaining_iters != 0) as usize;\n        self.final_lr - self.step_size * self.remaining_iters as f64\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        self.remaining_iters\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.remaining_iters = record;\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::super::test_utils;\n    use super::*;\n\n    #[test]\n    fn config_initial_lr_too_low() {\n        let r = LinearLrSchedulerConfig::new(0., 0.5, 100).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_initial_lr_too_high() {\n        let r = LinearLrSchedulerConfig::new(1.5, 0.5, 100).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Initial learning rate must be greater than 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_final_lr_too_low() {\n        let r = LinearLrSchedulerConfig::new(0.5, -0.5, 100).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Final learning rate must be at least 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_final_lr_too_high() {\n        let r = LinearLrSchedulerConfig::new(0.5, 1.5, 100).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Final learning rate must be at least 0 and at most 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn config_num_iters_too_low() {\n        let r = LinearLrSchedulerConfig::new(0.9, 0.1, 0).init();\n        assert!(r.is_err(), \"Should return an error\");\n        assert_eq!(\n            r.unwrap_err(),\n            \"Number of iterations must be at least 1\",\n            \"Error messages should match\",\n        );\n    }\n\n    #[test]\n    fn test_lr_decreasing() {\n        let scheduler = LinearLrSchedulerConfig::new(0.9, 0.5, 4).init().unwrap();\n        let expected_lrs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.5];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_lr_increasing() {\n        let scheduler = LinearLrSchedulerConfig::new(0.01, 0.04, 3).init().unwrap();\n        let expected_lrs = [0.01, 0.02, 0.03, 0.04, 0.04];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_lr_unchanging() {\n        let scheduler = LinearLrSchedulerConfig::new(0.3, 0.3, 2).init().unwrap();\n        let expected_lrs = [0.3, 0.3, 0.3, 0.3];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_save_and_load() {\n        const NUM_ITERS: usize = 6;\n        let scheduler = LinearLrSchedulerConfig::new(1.0, 0.01, NUM_ITERS)\n            .init()\n            .unwrap();\n        test_utils::check_save_load(scheduler, NUM_ITERS / 3 * 2);\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/mod.rs",
    "content": "/// Constant learning rate scheduler\npub mod constant;\n\n/// Composed learning rate scheduler\npub mod composed;\n\n/// Linear learning rate scheduler\npub mod linear;\n\n/// Noam learning rate scheduler\npub mod noam;\n\n/// Exponential learning rate scheduler\npub mod exponential;\n\n/// Cosine learning rate scheduler\npub mod cosine;\n\n/// Step learning rate scheduler\npub mod step;\n\nmod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/noam.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\n\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\n\n/// Configuration to create a [noam](NoamLrScheduler) learning rate scheduler.\n#[derive(Config, Debug)]\npub struct NoamLrSchedulerConfig {\n    /// The overall scale factor for the learning rate decay.\n    factor: f64,\n    /// The number of steps before the exponential decay stats.\n    #[config(default = 4000)]\n    warmup_steps: usize,\n    /// The size of the model.\n    #[config(default = 512)]\n    model_size: usize,\n}\n\n/// Noam learning rate scheduler as described in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).\n#[derive(Clone, Debug)]\npub struct NoamLrScheduler {\n    warmup_steps: f64,\n    embedding_size: f64,\n    factor: f64,\n    step: f64,\n}\n\nimpl NoamLrSchedulerConfig {\n    /// Initialize a new [noam](NoamLrScheduler) learning rate scheduler.\n    ///\n    /// # Errors\n    ///\n    /// An error will be returned if any of the following conditions is true:\n    ///\n    /// * `warmup_steps` is 0\n    /// * `model_size` is 0\n    pub fn init(&self) -> Result<NoamLrScheduler, String> {\n        if self.warmup_steps == 0 {\n            return Err(\n                \"Number of steps before exponential decay starts must be greater than 0\".into(),\n            );\n        }\n        if self.model_size == 0 {\n            return Err(\"Model size must be greater than 0\".into());\n        }\n\n        Ok(NoamLrScheduler {\n            warmup_steps: self.warmup_steps as f64,\n            embedding_size: self.model_size as f64,\n            factor: self.factor,\n            step: 0.0,\n        })\n    }\n}\n\nimpl LrScheduler for NoamLrScheduler {\n    type Record<B: Backend> = usize;\n\n    fn step(&mut self) -> LearningRate {\n        self.step += 1.0;\n\n        let arg1 = self.step.powf(-0.5);\n        let arg2 = self.step * self.warmup_steps.powf(-1.5);\n\n        self.factor * self.embedding_size.powf(-0.5) * f64::min(arg1, arg2)\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        self.step as usize\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.step = record as f64;\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_config_warmup_steps_invalid() {\n        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(0).init();\n        assert!(r.is_err(), \"Should return an error\");\n    }\n\n    #[test]\n    fn test_config_warmup_steps_valid() {\n        let r = NoamLrSchedulerConfig::new(0.1).with_warmup_steps(1).init();\n        assert!(r.is_ok(), \"Should return a success value\");\n    }\n\n    #[test]\n    fn test_config_model_size_invalid() {\n        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(0).init();\n        assert!(r.is_err(), \"Should return an error\");\n    }\n\n    #[test]\n    fn test_config_model_size_valid() {\n        let r = NoamLrSchedulerConfig::new(0.1).with_model_size(1).init();\n        assert!(r.is_ok(), \"Should return a success value\");\n    }\n\n    #[test]\n    fn test_function_increase_and_decrease() {\n        let warmup_steps = 100;\n        let mut scheduler = NoamLrSchedulerConfig::new(10.0)\n            .with_warmup_steps(warmup_steps)\n            .init()\n            .unwrap();\n        let mut lr_current = 0.0;\n\n        for _ in 0..warmup_steps {\n            let lr = scheduler.step();\n            assert!(\n                lr > lr_current,\n                \"Learning rate should increase before the warmup_steps is reached.\"\n            );\n            lr_current = lr;\n        }\n\n        for _ in 0..warmup_steps {\n            let lr = scheduler.step();\n            assert!(\n                lr < lr_current,\n                \"Learning rate should decrease after the warmup_steps is reached.\"\n            );\n            lr_current = lr;\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/lr_scheduler/step.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\n\nuse super::{LrScheduler, String};\nuse crate::LearningRate;\n\n/// The configuration for create a [step learning rate scheduler](StepLrScheduler).\n///\n/// This scheduler returns the learning rate `initial_lr` from the start, and keeps doing so until\n/// the same value has been given for `step_size` times. Then it multiplies the learning rate by\n/// `gamma` before repeating the process.\n///\n/// Gamma values out of range (0.0, 1.0) and non-positive initial learning rates are acceptable, but\n/// a warning log will be output for such a value in case of mistyping.\n///\n/// ## Notes\n///\n/// The [step](StepLrScheduler::step) method of the scheduler panics if it is called more than\n/// `i32::MAX + 1` times.\n#[derive(Config, Debug)]\npub struct StepLrSchedulerConfig {\n    // The learning rate at the initial step.\n    initial_lr: LearningRate,\n    // The number of iterations over which the learning rate remains unchanged before the next\n    // update.\n    step_size: usize,\n    /// The factor by which the learning rate is multiplied with each update. Default: 0.1.\n    #[config(default = 0.1)]\n    gamma: f64,\n}\n\nimpl StepLrSchedulerConfig {\n    /// Initializes a [step learning rate scheduler](StepLrScheduler).\n    ///\n    /// # Errors\n    ///\n    /// An error will be returned if `step_size` is 0.\n    pub fn init(&self) -> Result<StepLrScheduler, String> {\n        if self.step_size == 0 {\n            return Err(\"Step size must be greater than 0\".into());\n        }\n\n        // Atypical values of `initial_lr` and `gamma` are not rejected because they might be useful\n        // in some cases like debugging (e.g., https://datascience.stackexchange.com/q/89518).\n        if self.initial_lr <= 0.0 {\n            log::warn!(\n                \"Initial learning rate value of {} is not a positive number. Ignore this warning \\\n                 if it is intended.\",\n                self.initial_lr\n            );\n        }\n        if self.gamma <= 0.0 || self.gamma >= 1.0 {\n            log::warn!(\n                \"Gamma value of {} is out of range (0.0, 1.0). Ignore this warning if it is \\\n                 intended.\",\n                self.gamma\n            );\n        }\n\n        Ok(StepLrScheduler {\n            init_lr: self.initial_lr,\n            step_size: self.step_size,\n            gamma: self.gamma,\n            iter_idx: -1,\n        })\n    }\n}\n\n/// Step learning rate scheduler.\n#[derive(Clone, Debug)]\npub struct StepLrScheduler {\n    init_lr: LearningRate,\n    step_size: usize,\n    gamma: f64,\n    // The index of the current iteration.\n    // `i32` is used for avoiding truncating the exponent when taking powers of `gamma`.\n    iter_idx: i32,\n}\n\nimpl LrScheduler for StepLrScheduler {\n    type Record<B: Backend> = i32;\n\n    fn step(&mut self) -> LearningRate {\n        self.iter_idx = self\n            .iter_idx\n            .checked_add(1)\n            .expect(\"`.step()` should be called no more than `i32::MAX + 1` times\");\n        // Type casting below causes no truncation, as all the values fall within the ranges.\n        self.init_lr\n            * self\n                .gamma\n                .powi((self.iter_idx as usize / self.step_size) as i32)\n    }\n\n    fn to_record<B: Backend>(&self) -> Self::Record<B> {\n        self.iter_idx\n    }\n\n    fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {\n        self.iter_idx = record;\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::super::test_utils;\n    use super::*;\n    use crate::TestBackend;\n\n    // Warning logs for initial LR and gamma are not tested because there seems no straightforward\n    // way to do it.\n    //\n    // Creating a mock logger that collects logs into `String` for later examination seems a possible\n    // solution, but unit tests run in the same process in parallel, where the single logger would\n    // be shared by multiple tests, so logs from different tests would be mixed up with no easy way\n    // to separate them.\n    // Using \"--test-threads=1\" could prevent mixup, but whether the ability to test logging is\n    // worth the slowdown would be a question. Also, using a primitive provided by `std` to\n    // synchronize the logger across tests is not an option since we need to support `no-std`.\n    // Maybe the mocking approach can be reconsidered after we are given an option to run tests in\n    // separate processes like what the issue below is proposing:\n    //     https://github.com/rust-lang/rust/issues/47506\n    //\n    // As a side note, a helper crate exists for the exact purpose:\n    //     https://crates.io/crates/testing_logger\n    // but the crate has been unmaintained and using it would introduce another dependency.\n\n    #[test]\n    fn test_config_step_size_zero() {\n        let r = StepLrSchedulerConfig::new(1.0, 0).init();\n        assert!(r.is_err(), \"Should return an error\");\n    }\n\n    #[test]\n    fn test_config_step_size_nonzero() {\n        let r = StepLrSchedulerConfig::new(1.0, 1).init();\n        assert!(r.is_ok(), \"Should return a success value\");\n    }\n\n    #[test]\n    fn test_config_default_gamma() {\n        const INIT_LR: LearningRate = 0.4;\n        const STEP_SIZE: usize = 2;\n\n        let mut default = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)\n            .init()\n            .unwrap();\n        let mut explicit = StepLrSchedulerConfig::new(INIT_LR, STEP_SIZE)\n            .with_gamma(0.1)\n            .init()\n            .unwrap();\n        test_utils::compare_steps(&mut default, &mut explicit, 3 * STEP_SIZE);\n    }\n\n    #[test]\n    fn test_lr_decreasing() {\n        let scheduler = StepLrSchedulerConfig::new(0.5, 3)\n            .with_gamma(0.1)\n            .init()\n            .unwrap();\n        let expected_lrs = [0.5, 0.5, 0.5, 0.05, 0.05, 0.05, 0.005, 0.005, 0.005];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_lr_increasing() {\n        let scheduler = StepLrSchedulerConfig::new(0.1, 2)\n            .with_gamma(2.0)\n            .init()\n            .unwrap();\n        let expected_lrs = [0.1, 0.1, 0.2, 0.2, 0.4, 0.4];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_lr_unchanging() {\n        let scheduler = StepLrSchedulerConfig::new(3.1, 1)\n            .with_gamma(1.0)\n            .init()\n            .unwrap();\n        let expected_lrs = [3.1, 3.1, 3.1];\n        test_utils::check_lr_sequence(scheduler, expected_lrs);\n    }\n\n    #[test]\n    fn test_save_and_load() {\n        const STEP_SIZE: usize = 10;\n\n        let scheduler = StepLrSchedulerConfig::new(0.007, STEP_SIZE)\n            .with_gamma(0.03)\n            .init()\n            .unwrap();\n        test_utils::check_save_load(scheduler, 3 * STEP_SIZE / 2);\n    }\n\n    // It's too time consuming to actually run a scheduler `i32::MAX` steps, so an approach that\n    // depends on private fields is used to implement the test.\n    #[test]\n    fn test_number_of_calls_within_limit() {\n        // Create a scheduler that has already run `i32::MAX` steps\n        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();\n        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);\n        scheduler.step();\n    }\n\n    #[test]\n    #[should_panic = \"i32::MAX\"]\n    fn test_number_of_calls_over_limit() {\n        // Create a scheduler that has already run `i32::MAX` steps\n        let mut scheduler = StepLrSchedulerConfig::new(0.1, 2).init().unwrap();\n        scheduler = scheduler.load_record::<TestBackend>(i32::MAX - 1);\n        scheduler.step();\n        scheduler.step();\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/adagrad.rs",
    "content": "use burn_core as burn;\n\nuse burn::{module::AutodiffModule, record::Record};\n\nuse burn::config::Config;\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse burn::tensor::{backend::Backend, ops::Device};\n\nuse super::{\n    SimpleOptimizer,\n    adaptor::OptimizerAdaptor,\n    decay::{WeightDecay, WeightDecayConfig},\n};\nuse crate::{LearningRate, grad_clipping::GradientClippingConfig};\n\n/// AdaGrad configuration.\n#[derive(Config, Debug)]\npub struct AdaGradConfig {\n    #[config(default = 0.)]\n    lr_decay: f64,\n    #[config(default = 1e-5)]\n    epsilon: f32,\n    /// [Weight decay](WeightDecayConfig) config.\n    weight_decay: Option<WeightDecayConfig>,\n    /// [Gradient Clipping](GradientClippingConfig) config.\n    grad_clipping: Option<GradientClippingConfig>,\n}\n\n/// AdaGrad optimizer\n#[derive(Clone)]\npub struct AdaGrad {\n    lr_decay: LrDecay,\n    weight_decay: Option<WeightDecay>,\n}\n\n/// AdaGrad state.\n#[derive(Record, Clone, new)]\npub struct AdaGradState<B: Backend, const D: usize> {\n    lr_decay: LrDecayState<B, D>,\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for AdaGrad {\n    type State<const D: usize> = AdaGradState<B, D>;\n\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        mut grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        let mut state_lr_decay = None;\n\n        if let Some(state) = state {\n            state_lr_decay = Some(state.lr_decay);\n        }\n\n        if let Some(weight_decay) = &self.weight_decay {\n            grad = weight_decay.transform(grad, tensor.clone());\n        }\n\n        let (grad, state_lr_decay) = self.lr_decay.transform(grad, lr, state_lr_decay);\n\n        let state = AdaGradState::new(state_lr_decay);\n\n        (tensor - grad, Some(state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {\n        state.lr_decay = state.lr_decay.to_device(device);\n        state\n    }\n}\n\nimpl AdaGradConfig {\n    /// Initialize AdaGrad optimizer.\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer that can be used to optimize a module.\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(\n        &self,\n    ) -> OptimizerAdaptor<AdaGrad, M, B> {\n        let optim = AdaGrad {\n            lr_decay: LrDecay {\n                lr_decay: self.lr_decay,\n                epsilon: self.epsilon,\n            },\n            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),\n        };\n\n        let mut optim = OptimizerAdaptor::from(optim);\n        if let Some(config) = &self.grad_clipping {\n            optim = optim.with_grad_clipping(config.init());\n        }\n        optim\n    }\n}\n\n/// Learning rate decay state (also includes sum state).\n#[derive(Record, new, Clone)]\npub struct LrDecayState<B: Backend, const D: usize> {\n    time: usize,\n    sum: Tensor<B, D>,\n}\n\n#[derive(Clone)]\nstruct LrDecay {\n    lr_decay: f64,\n    epsilon: f32,\n}\n\nimpl LrDecay {\n    pub fn transform<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        lr: LearningRate,\n        lr_decay_state: Option<LrDecayState<B, D>>,\n    ) -> (Tensor<B, D>, LrDecayState<B, D>) {\n        let state = if let Some(mut state) = lr_decay_state {\n            state.sum = state.sum.add(grad.clone().square());\n            state.time += 1;\n            state\n        } else {\n            LrDecayState::new(1, grad.clone().square())\n        };\n\n        let new_lr = lr / (1. + (state.time as f64 - 1.) * self.lr_decay);\n\n        let grad = grad\n            .div(state.sum.clone().sqrt().add_scalar(self.epsilon))\n            .mul_scalar(new_lr);\n\n        (grad, state)\n    }\n}\n\nimpl<B: Backend, const D: usize> LrDecayState<B, D> {\n    /// Move state to device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move state to.\n    ///\n    /// # Returns\n    ///\n    /// Returns state moved to device.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.sum = self.sum.to_device(device);\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::Tolerance;\n    use burn::tensor::ops::FloatElem;\n\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use crate::{GradientsParams, Optimizer};\n    use burn::module::{Module, Param};\n    use burn::tensor::{Distribution, Tensor, TensorData};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    const LEARNING_RATE: LearningRate = 0.01;\n\n    #[test]\n    fn test_adagrad_optimizer_save_load_state() {\n        let device = Default::default();\n        let linear = LinearConfig::new(6, 6).init(&device);\n        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);\n        let mut optimizer = create_adagrad();\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let _linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        #[cfg(feature = \"std\")]\n        {\n            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};\n\n            BinFileRecorder::<FullPrecisionSettings>::default()\n                .record(\n                    optimizer.to_record(),\n                    std::env::temp_dir().as_path().join(\"test_optim_adagrad\"),\n                )\n                .unwrap();\n        }\n        #[cfg(not(feature = \"std\"))]\n        {\n            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};\n\n            let result = BinBytesRecorder::<FullPrecisionSettings>::default()\n                .record(optimizer.to_record(), ())\n                .unwrap();\n            assert!(!result.is_empty());\n        }\n\n        let state_optim_before = optimizer.to_record();\n        let state_optim_before_copy = optimizer.to_record();\n        let optimizer = create_adagrad();\n        let optimizer = optimizer.load_record(state_optim_before_copy);\n        let state_optim_after = optimizer.to_record();\n\n        assert_eq!(state_optim_before.len(), state_optim_after.len());\n    }\n\n    #[test]\n    fn test_adagrad_optimizer_with_numbers() {\n        let device = Default::default();\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = AdaGradConfig::new()\n            .with_epsilon(1e-8)\n            .with_lr_decay(0.5)\n            .init();\n\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        let weights_expected = TensorData::from([\n            [-0.334989, 0.123011, 0.389911, 0.305611, 0.071511, 0.052711],\n            [\n                0.066144, -0.030056, -0.378256, 0.243444, 0.183944, -0.303756,\n            ],\n            [\n                -0.033462, 0.020138, -0.310662, 0.233938, -0.292462, 0.298538,\n            ],\n            [\n                -0.312636, -0.236036, -0.386136, -0.312736, -0.090736, 0.147964,\n            ],\n            [\n                0.315896, -0.232304, 0.357596, -0.187004, 0.365496, -0.044504,\n            ],\n            [-0.030305, -0.026405, 0.111395, 0.177695, 0.014895, 0.368895],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.405214, 0.073686, -0.111714, 0.102886, 0.121886, -0.001714,\n        ]);\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.val().into_data(),\n            state_updated.bias.unwrap().val().into_data(),\n        );\n\n        type FT = FloatElem<TestAutodiffBackend>;\n        let tolerance = Tolerance::absolute(1e-6);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: Some(Param::from_data(bias, &device)),\n        };\n\n        LinearConfig::new(6, 6).init(&device).load_record(record)\n    }\n\n    fn create_adagrad()\n    -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {\n        let config = AdaGradConfig::new();\n        AdaGrad {\n            lr_decay: LrDecay {\n                lr_decay: config.lr_decay,\n                epsilon: config.epsilon,\n            },\n            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),\n        }\n        .into()\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/adam.rs",
    "content": "use burn_core as burn;\n\nuse burn::{module::AutodiffModule, record::Record};\n\nuse burn::config::Config;\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse burn::tensor::{backend::Backend, ops::Device};\n\nuse super::{\n    SimpleOptimizer,\n    adaptor::OptimizerAdaptor,\n    decay::{WeightDecay, WeightDecayConfig},\n};\nuse crate::{LearningRate, grad_clipping::GradientClippingConfig};\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Adam configuration.\n#[derive(Config, Debug)]\npub struct AdamConfig {\n    /// Parameter for Adam.\n    #[config(default = 0.9)]\n    beta_1: f32,\n    /// Parameter for Adam.\n    #[config(default = 0.999)]\n    beta_2: f32,\n    /// A value required for numerical stability.\n    #[config(default = 1e-5)]\n    epsilon: f32,\n    /// Whether to use AMSGrad algorithm\n    #[config(default = false)]\n    amsgrad: bool,\n    /// [Weight decay](WeightDecayConfig) config.\n    weight_decay: Option<WeightDecayConfig>,\n    /// [Gradient Clipping](GradientClippingConfig) config.\n    grad_clipping: Option<GradientClippingConfig>,\n}\n\n/// Adam optimizer.\n///\n/// See:\n/// - [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).\n/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)\n#[derive(Clone)]\npub struct Adam {\n    momentum: AdaptiveMomentum,\n    weight_decay: Option<WeightDecay>,\n}\n\n/// Adam state.\n#[derive(Record, Clone, new)]\npub struct AdamState<B: Backend, const D: usize> {\n    /// The current adaptive momentum.\n    pub momentum: AdaptiveMomentumState<B, D>,\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for Adam {\n    type State<const D: usize> = AdamState<B, D>;\n\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        mut grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        let mut state_momentum = None;\n\n        if let Some(state) = state {\n            state_momentum = Some(state.momentum);\n        }\n\n        if let Some(weight_decay) = &self.weight_decay {\n            grad = weight_decay.transform(grad, tensor.clone());\n        }\n\n        let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);\n\n        let state = AdamState::new(state_momentum);\n        let delta = grad.mul_scalar(lr);\n\n        (tensor - delta, Some(state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {\n        state.momentum = state.momentum.to_device(device);\n        state\n    }\n}\n\nimpl AdamConfig {\n    /// Initialize Adam optimizer.\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer that can be used to optimize a module.\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {\n        let optim = Adam {\n            momentum: AdaptiveMomentum {\n                beta_1: self.beta_1,\n                beta_2: self.beta_2,\n                epsilon: self.epsilon,\n                amsgrad: self.amsgrad,\n            },\n            weight_decay: self.weight_decay.as_ref().map(WeightDecay::new),\n        };\n\n        let mut optim = OptimizerAdaptor::from(optim);\n        if let Some(config) = &self.grad_clipping {\n            optim = optim.with_grad_clipping(config.init());\n        }\n        optim\n    }\n}\n\n/// Adaptive momentum state.\n#[derive(Record, new, Clone)]\npub struct AdaptiveMomentumState<B: Backend, const D: usize> {\n    /// The number of iterations aggregated.\n    pub time: usize,\n    /// The first order momentum.\n    pub moment_1: Tensor<B, D>,\n    /// The second order momentum.\n    pub moment_2: Tensor<B, D>,\n    /// Max of second  order momentum (for AMSGrad)\n    #[new(default)]\n    pub max_moment_2: Option<Tensor<B, D>>,\n}\n\n#[derive(Clone)]\nstruct AdaptiveMomentum {\n    beta_1: f32,\n    beta_2: f32,\n    epsilon: f32,\n    amsgrad: bool,\n}\n\nimpl AdaptiveMomentum {\n    pub fn transform<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        momentum_state: Option<AdaptiveMomentumState<B, D>>,\n    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {\n        let state = if let Some(mut state) = momentum_state {\n            let factor = 1.0 - self.beta_1;\n            state.moment_1 = state\n                .moment_1\n                .mul_scalar(self.beta_1)\n                .add(grad.clone().mul_scalar(factor));\n\n            let factor = 1.0 - self.beta_2;\n            state.moment_2 = state\n                .moment_2\n                .mul_scalar(self.beta_2)\n                .add(grad.square().mul_scalar(factor));\n            if self.amsgrad {\n                let max_v = state\n                    .max_moment_2\n                    .take()\n                    .unwrap_or_else(|| state.moment_2.clone());\n\n                let new_max = max_v.max_pair(state.moment_2.clone());\n                state.max_moment_2 = Some(new_max);\n            }\n\n            state.time += 1;\n\n            state\n        } else {\n            let factor = 1.0 - self.beta_1;\n            let moment_1 = grad.clone().mul_scalar(factor);\n\n            let factor = 1.0 - self.beta_2;\n            let moment_2 = grad.square().mul_scalar(factor);\n            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());\n            AdaptiveMomentumState {\n                time: 1,\n                moment_1,\n                moment_2,\n                max_moment_2,\n            }\n        };\n\n        let time = state.time as i32;\n        let bias_correction2_sqrt = (1.0 - self.beta_2.powi(time)).sqrt();\n        let combined_factor = bias_correction2_sqrt / (1.0 - self.beta_1.powi(time));\n\n        let v_to_use = if self.amsgrad {\n            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)\n        } else {\n            &state.moment_2\n        };\n\n        let grad = state.moment_1.clone().mul_scalar(combined_factor).div(\n            v_to_use\n                .clone()\n                .sqrt()\n                .add_scalar(self.epsilon * bias_correction2_sqrt),\n        );\n        (grad, state)\n    }\n}\n\nimpl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {\n    /// Move state to device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move state to.\n    ///\n    /// # Returns\n    ///\n    /// Returns state moved to device.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.moment_1 = self.moment_1.to_device(device);\n        self.moment_2 = self.moment_2.to_device(device);\n        self.max_moment_2 = self.max_moment_2.map(|tensor| tensor.to_device(device));\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::Tolerance;\n    use burn::tensor::ops::FloatElem;\n\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use crate::{GradientsParams, Optimizer};\n    use burn::module::{Module, Param};\n    use burn::tensor::{Distribution, Tensor, TensorData};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    const LEARNING_RATE: LearningRate = 0.01;\n\n    #[test]\n    fn test_adam_optimizer_save_load_state() {\n        let device = Default::default();\n        let linear = LinearConfig::new(6, 6).init(&device);\n        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);\n        let mut optimizer = create_adam();\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let _linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        #[cfg(feature = \"std\")]\n        {\n            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};\n\n            BinFileRecorder::<FullPrecisionSettings>::default()\n                .record(\n                    optimizer.to_record(),\n                    std::env::temp_dir().as_path().join(\"test_optim_adam\"),\n                )\n                .unwrap();\n        }\n        #[cfg(not(feature = \"std\"))]\n        {\n            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};\n\n            let result = BinBytesRecorder::<FullPrecisionSettings>::default()\n                .record(optimizer.to_record(), ())\n                .unwrap();\n            assert!(!result.is_empty());\n        }\n\n        let state_optim_before = optimizer.to_record();\n        let state_optim_before_copy = optimizer.to_record();\n        let optimizer = create_adam();\n        let optimizer = optimizer.load_record(state_optim_before_copy);\n        let state_optim_after = optimizer.to_record();\n\n        assert_eq!(state_optim_before.len(), state_optim_after.len());\n    }\n    #[test]\n    fn test_adam_optimizer_with_amsgrad_50_steps() {\n        let device = Default::default();\n        let mut linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n\n        let mut optimizer = AdamConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_amsgrad(true)\n            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))\n            .init();\n\n        for i in 1..=50 {\n            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)\n                .mul_scalar(i as f32 * 0.1)\n                .require_grad();\n\n            let grads = linear.forward(x).backward();\n            let grads = GradientsParams::from_grads(grads, &linear);\n            linear = optimizer.step(LEARNING_RATE, linear, grads);\n        }\n\n        let state_updated = linear.into_record();\n        let weight_updated = state_updated.weight.to_data();\n        let bias_updated = state_updated.bias.unwrap().to_data();\n\n        let weights_expected = TensorData::from([\n            [\n                -0.9125810265541077,\n                -0.45855265855789185,\n                -0.1915993094444275,\n                -0.2759990692138672,\n                -0.5099529027938843,\n                -0.5287043452262878,\n            ],\n            [\n                -0.5181325674057007,\n                -0.6139854788780212,\n                -0.9574727416038513,\n                -0.34102925658226013,\n                -0.400514155626297,\n                -0.8847861886024475,\n            ],\n            [\n                -0.614483118057251,\n                -0.5611032247543335,\n                -0.8887064456939697,\n                -0.34762972593307495,\n                -0.8708556890487671,\n                -0.2830044627189636,\n            ],\n            [\n                -0.8904699683189392,\n                -0.8151527643203735,\n                -0.9621278643608093,\n                -0.8905676603317261,\n                -0.671261191368103,\n                -0.4333854615688324,\n            ],\n            [\n                -0.26599061489105225,\n                -0.8119961023330688,\n                -0.22424538433551788,\n                -0.7672406435012817,\n                -0.2163349837064743,\n                -0.6258266568183899,\n            ],\n            [\n                -0.611397922039032,\n                -0.6075160503387451,\n                -0.4701341986656189,\n                -0.4039117991924286,\n                -0.5663845539093018,\n                -0.21262989938259125,\n            ],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.8817203044891357,\n            -0.4038999378681183,\n            -0.5889149308204651,\n            -0.37475723028182983,\n            -0.3557940721511841,\n            -0.47914788126945496,\n        ]);\n\n        type FT = FloatElem<TestAutodiffBackend>;\n        let tolerance = Tolerance::absolute(1e-5);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n    }\n    #[test]\n    fn test_adam_optimizer_with_numbers() {\n        let device = Default::default();\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = AdamConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))\n            .init();\n\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        let weights_expected = TensorData::from([\n            [-0.340528, 0.118929, 0.384336, 0.300010, 0.066034, 0.047154],\n            [\n                0.057757, -0.036690, -0.386649, 0.235010, 0.175624, -0.312133,\n            ],\n            [\n                -0.038940, 0.016306, -0.316151, 0.228410, -0.297819, 0.293047,\n            ],\n            [\n                -0.317929, -0.239100, -0.391449, -0.318087, -0.095948, 0.142651,\n            ],\n            [\n                0.310050, -0.235909, 0.351736, -0.192888, 0.359710, -0.050343,\n            ],\n            [-0.035840, -0.030203, 0.105840, 0.172110, 0.009440, 0.363346],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.410499, 0.068401, -0.116999, 0.097601, 0.116601, -0.006999,\n        ]);\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.to_data(),\n            state_updated.bias.unwrap().to_data(),\n        );\n\n        type FT = FloatElem<TestAutodiffBackend>;\n        let tolerance = Tolerance::absolute(1e-2);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    #[test]\n    fn test_adam_optimizer_no_nan() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n\n        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &Default::default(),\n        )\n        .require_grad();\n\n        let mut optimizer = AdamConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_weight_decay(Some(WeightDecayConfig::new(0.5)))\n            .init();\n\n        let grads = linear.forward(x.clone()).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());\n    }\n\n    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: Some(Param::from_data(bias, &device)),\n        };\n\n        LinearConfig::new(6, 6).init(&device).load_record(record)\n    }\n\n    fn create_adam() -> OptimizerAdaptor<Adam, Linear<TestAutodiffBackend>, TestAutodiffBackend> {\n        let config = AdamConfig::new();\n        Adam {\n            momentum: AdaptiveMomentum {\n                beta_1: config.beta_1,\n                beta_2: config.beta_2,\n                epsilon: config.epsilon,\n                amsgrad: config.amsgrad,\n            },\n            weight_decay: config.weight_decay.as_ref().map(WeightDecay::new),\n        }\n        .into()\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/adamw.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse burn::tensor::{backend::Backend, ops::Device};\nuse burn::{module::AutodiffModule, record::Record};\n\nuse super::{AdaptiveMomentumState, SimpleOptimizer, adaptor::OptimizerAdaptor};\nuse crate::{LearningRate, grad_clipping::GradientClippingConfig};\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// [`AdamW`] Configuration.\n#[derive(Config, Debug)]\npub struct AdamWConfig {\n    /// Parameter for AdamW.\n    #[config(default = 0.9)]\n    beta_1: f32,\n    /// Parameter for AdamW.\n    #[config(default = 0.999)]\n    beta_2: f32,\n    /// A value required for numerical stability.\n    #[config(default = 1e-5)]\n    epsilon: f32,\n    /// Weight decay config.\n    #[config(default = 1e-4)]\n    weight_decay: f32,\n\n    /// Cautious weight decay config.\n    ///\n    /// See: <https://arxiv.org/abs/2510.12402>\n    #[config(default = false)]\n    cautious_weight_decay: bool,\n\n    /// Whether to use AMSGrad algorithm\n    #[config(default = false)]\n    amsgrad: bool,\n    /// [Gradient Clipping](GradientClippingConfig) config.\n    grad_clipping: Option<GradientClippingConfig>,\n}\n\n/// AdamW optimizer.\n///\n/// See:\n/// - [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).\n/// - [Cautious Weight Decay, 2025](https://arxiv.org/abs/2510.12402)\n/// - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)\n///\n/// Configured by [`AdamWConfig`].\n#[derive(Clone)]\npub struct AdamW {\n    momentum: AdaptiveMomentumW,\n    weight_decay: f32,\n    cautious_weight_decay: bool,\n}\n\n/// AdamW state.\n#[derive(Record, Clone, new)]\npub struct AdamWState<B: Backend, const D: usize> {\n    /// Th current adaptive momentum state.\n    pub momentum: AdaptiveMomentumState<B, D>,\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for AdamW {\n    type State<const D: usize> = AdamWState<B, D>;\n\n    /// A single optimization step for any tensor that represents the parameters of a model.\n    fn step<const D: usize>(\n        &self,\n        // Learning rate.\n        lr: LearningRate,\n        // Any tensor that represents the parameters of a model.\n        tensor: Tensor<B, D>,\n        // Gradient of the loss w.r.t. the parameters.\n        grad: Tensor<B, D>,\n        // State of the optimizer.\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        let (raw_delta, momentum_state) = self.momentum.transform(grad, state.map(|s| s.momentum));\n\n        let decay_rate = lr * (self.weight_decay as f64);\n\n        let decayed_tensor = if decay_rate == 0.0 {\n            tensor.clone()\n        } else if self.cautious_weight_decay {\n            // Cautious weight decay.\n            // See: https://arxiv.org/abs/2510.12402\n            let tensor_pos = tensor.clone().greater_equal_elem(0.0);\n            let grad_pos = momentum_state.moment_1.clone().greater_equal_elem(0.0);\n            let differ = tensor_pos.not_equal(grad_pos);\n\n            // Zero out the decay where the decay is counter to the update direction.\n            tensor.clone() - tensor.mul_scalar(decay_rate).mask_fill(differ, 0.0)\n        } else {\n            tensor.clone().mul_scalar(1.0 - decay_rate)\n        };\n\n        let tensor_updated = decayed_tensor - raw_delta.mul_scalar(lr);\n\n        let state = AdamWState {\n            momentum: momentum_state,\n        };\n\n        (tensor_updated, Some(state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {\n        state.momentum = state.momentum.to_device(device);\n        state\n    }\n}\n\nimpl AdamWConfig {\n    /// Initialize AdamW optimizer.\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer that can be used to optimize a module.\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {\n        let optim = AdamW {\n            momentum: AdaptiveMomentumW {\n                beta_1: self.beta_1,\n                beta_2: self.beta_2,\n                epsilon: self.epsilon,\n                amsgrad: self.amsgrad,\n            },\n            weight_decay: self.weight_decay,\n            cautious_weight_decay: self.cautious_weight_decay,\n        };\n\n        let mut optim = OptimizerAdaptor::from(optim);\n        if let Some(config) = &self.grad_clipping {\n            optim = optim.with_grad_clipping(config.init());\n        }\n        optim\n    }\n}\n\n#[derive(Clone)]\nstruct AdaptiveMomentumW {\n    beta_1: f32,\n    beta_2: f32,\n    epsilon: f32,\n    amsgrad: bool,\n}\n\nimpl AdaptiveMomentumW {\n    pub fn transform<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        state: Option<AdaptiveMomentumState<B, D>>,\n    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {\n        let factor_1 = 1.0 - self.beta_1;\n        let factor_2 = 1.0 - self.beta_2;\n\n        let state = if let Some(mut state) = state {\n            // Update first moment estimate.\n            state.moment_1 = state\n                .moment_1\n                .mul_scalar(self.beta_1)\n                .add(grad.clone().mul_scalar(factor_1));\n\n            // Update second moment estimate.\n            state.moment_2 = state\n                .moment_2\n                .mul_scalar(self.beta_2)\n                .add(grad.square().mul_scalar(factor_2));\n\n            if self.amsgrad {\n                let max_v = state\n                    .max_moment_2\n                    .take()\n                    .unwrap_or_else(|| state.moment_2.clone());\n                state.max_moment_2 = Some(max_v.max_pair(state.moment_2.clone()));\n            }\n\n            // Update time.\n            state.time += 1;\n\n            state\n        } else {\n            // Initialize first moment estimate.\n            let moment_1 = grad.clone().mul_scalar(factor_1);\n\n            // Initialize second moment estimate.\n            let moment_2 = grad.square().mul_scalar(factor_2);\n            let max_moment_2 = self.amsgrad.then(|| moment_2.clone());\n            AdaptiveMomentumState {\n                time: 1,\n                moment_1,\n                moment_2,\n                max_moment_2,\n            }\n        };\n\n        let time: i32 = state.time as i32;\n\n        // Compute bias-corrected first and second moment estimates.\n        let moment_1_corrected = state\n            .moment_1\n            .clone()\n            .div_scalar(1f32 - self.beta_1.powi(time));\n\n        let v_to_use = if self.amsgrad {\n            state.max_moment_2.as_ref().unwrap_or(&state.moment_2)\n        } else {\n            &state.moment_2\n        };\n\n        let moment_2_corrected = v_to_use.clone().div_scalar(1f32 - self.beta_2.powi(time));\n\n        let update_delta =\n            moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));\n\n        (update_delta, state)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use crate::{GradientsParams, Optimizer};\n    use burn::module::{Module, Param};\n    use burn::tensor::{Distribution, Tensor, TensorData};\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    type FT = FloatElem<TestAutodiffBackend>;\n\n    const LEARNING_RATE: LearningRate = 0.01;\n\n    #[test]\n    fn test_adamw_optimizer_save_load_state() {\n        let device = Default::default();\n        let linear = LinearConfig::new(6, 6).init(&device);\n        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);\n        let mut optimizer = create_adamw();\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let _linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        #[cfg(feature = \"std\")]\n        {\n            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};\n\n            BinFileRecorder::<FullPrecisionSettings>::default()\n                .record(\n                    optimizer.to_record(),\n                    std::env::temp_dir().as_path().join(\"test_optim_adamw\"),\n                )\n                .unwrap();\n        }\n        #[cfg(not(feature = \"std\"))]\n        {\n            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};\n\n            let result = BinBytesRecorder::<FullPrecisionSettings>::default()\n                .record(optimizer.to_record(), ())\n                .unwrap();\n            assert!(!result.is_empty());\n        }\n\n        let state_optim_before = optimizer.to_record();\n        let state_optim_before_copy = optimizer.to_record();\n        let optimizer = create_adamw();\n        let optimizer = optimizer.load_record(state_optim_before_copy);\n        let state_optim_after = optimizer.to_record();\n\n        assert_eq!(state_optim_before.len(), state_optim_after.len());\n    }\n    #[test]\n    fn test_adamw_optimizer_with_amsgrad_50_steps() {\n        let device = Default::default();\n        let mut linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n\n        let mut optimizer = AdamWConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_amsgrad(true)\n            .with_weight_decay(0.5)\n            .init();\n\n        for i in 1..=50 {\n            let x = Tensor::<TestAutodiffBackend, 2>::ones([2, 6], &device)\n                .mul_scalar(i as f32 * 0.1)\n                .require_grad();\n\n            let grads = linear.forward(x).backward();\n            let grads = GradientsParams::from_grads(grads, &linear);\n            linear = optimizer.step(LEARNING_RATE, linear, grads);\n        }\n\n        let state_updated = linear.into_record();\n        let weight_updated = state_updated.weight.to_data();\n        let bias_updated = state_updated.bias.unwrap().to_data();\n\n        let weights_expected = TensorData::from([\n            [\n                -0.7822558283805847,\n                -0.42578864097595215,\n                -0.21805696189403534,\n                -0.28366872668266296,\n                -0.46587175130844116,\n                -0.4805040955543518,\n            ],\n            [\n                -0.4722539782524109,\n                -0.5471276640892029,\n                -0.8181359767913818,\n                -0.33425918221473694,\n                -0.3805687427520752,\n                -0.7601516842842102,\n            ],\n            [\n                -0.5475167632102966,\n                -0.5057991743087769,\n                -0.763265073299408,\n                -0.3393959403038025,\n                -0.7490996718406677,\n                -0.28911691904067993,\n            ],\n            [\n                -0.7646660208702087,\n                -0.7050473093986511,\n                -0.8218720555305481,\n                -0.7647438049316406,\n                -0.5919585227966309,\n                -0.40617525577545166,\n            ],\n            [\n                -0.27588561177253723,\n                -0.7025567889213562,\n                -0.24343004822731018,\n                -0.6672990918159485,\n                -0.23728127777576447,\n                -0.556389570236206,\n            ],\n            [\n                -0.5451040267944336,\n                -0.5420684814453125,\n                -0.4348171353340149,\n                -0.3832150399684906,\n                -0.5099242925643921,\n                -0.23440153896808624,\n            ],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.7473056316375732,\n            -0.3745720386505127,\n            -0.5188710689544678,\n            -0.35184532403945923,\n            -0.33705732226371765,\n            -0.4332566559314728,\n        ]);\n\n        type FT = FloatElem<TestAutodiffBackend>;\n        let tolerance = Tolerance::absolute(1e-5);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n    }\n    #[test]\n    fn test_adamw_optimizer_with_numbers() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n        let device = Default::default();\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = AdamWConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_weight_decay(0.5)\n            .init();\n\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        let weights_expected = TensorData::from([\n            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],\n            [\n                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,\n            ],\n            [\n                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,\n            ],\n            [\n                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,\n            ],\n            [\n                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,\n            ],\n            [-0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.359580],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,\n        ]);\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.to_data(),\n            state_updated.bias.unwrap().to_data(),\n        );\n\n        let tolerance = Tolerance::absolute(1e-2);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    #[test]\n    fn test_adamw_optimizer_with_numbers_cautious() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n        let device = Default::default();\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, -0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = AdamWConfig::new()\n            .with_cautious_weight_decay(true)\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_weight_decay(0.5)\n            .init();\n\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        let weights_expected = TensorData::from([\n            [-0.337295, 0.117827, 0.380358, 0.296868, 0.065232, 0.046534],\n            [\n                0.057032, -0.036518, -0.382951, 0.232516, 0.173738, -0.309182,\n            ],\n            [\n                -0.038703, 0.016052, -0.313155, 0.225982, -0.295039, 0.289981,\n            ],\n            [\n                -0.314920, -0.237394, -0.387704, -0.315067, -0.095153, 0.141081,\n            ],\n            [\n                0.306815, -0.234226, 0.348083, -0.191115, 0.356002, -0.049993,\n            ],\n            [\n                -0.035634, -0.030083, 0.104636, 0.170244, 0.009196, 0.37061332,\n            ],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,\n        ]);\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.to_data(),\n            state_updated.bias.unwrap().to_data(),\n        );\n\n        let tolerance = Tolerance::absolute(1e-2);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    #[test]\n    fn test_adam_optimizer_no_nan() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n\n        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &Default::default(),\n        )\n        .require_grad();\n\n        let mut optimizer = AdamWConfig::new()\n            .with_epsilon(1e-8)\n            .with_beta_1(0.9)\n            .with_beta_2(0.999)\n            .with_weight_decay(0.5)\n            .init();\n\n        let grads = linear.forward(x.clone()).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        assert!(!state_updated.weight.to_data().as_slice::<f32>().unwrap()[0].is_nan());\n    }\n\n    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: Some(Param::from_data(bias, &device)),\n        };\n\n        LinearConfig::new(6, 6).init(&device).load_record(record)\n    }\n\n    fn create_adamw() -> OptimizerAdaptor<AdamW, Linear<TestAutodiffBackend>, TestAutodiffBackend> {\n        let config = AdamWConfig::new();\n        AdamW {\n            momentum: AdaptiveMomentumW {\n                beta_1: config.beta_1,\n                beta_2: config.beta_2,\n                epsilon: config.epsilon,\n                amsgrad: config.amsgrad,\n            },\n            weight_decay: config.weight_decay,\n            cautious_weight_decay: false,\n        }\n        .into()\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/base.rs",
    "content": "use burn_core::{self as burn, Tensor};\n\nuse burn_core::module::ParamId;\nuse burn_core::prelude::{Backend, DeviceOps};\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::DeviceId;\n\nuse super::GradientsParams;\nuse crate::LearningRate;\nuse alloc::vec::Vec;\nuse burn::module::AutodiffModule;\nuse burn::record::Record;\nuse burn::tensor::backend::AutodiffBackend;\n\n#[derive(Default)]\n/// Exposes multiple gradients for each parameter.\npub struct MultiGradientsParams {\n    /// Each [GradientsParams] has its associated [DeviceId].\n    pub grads: Vec<(GradientsParams, DeviceId)>,\n}\n\nimpl MultiGradientsParams {\n    /// Removes the gradients for the given [parameter id](ParamId).\n    ///\n    /// Potentially accumulates the gradients from multiple sources using a device associated with\n    /// a parameter id. The same parameter will be accumulated using the same device during\n    /// all training.\n    pub fn remove<B: Backend, const D: usize>(\n        &mut self,\n        id: ParamId,\n    ) -> Option<(Tensor<B, D>, Device<B>)> {\n        let (mut tensor, device, index) = self.select(id)?;\n\n        for (i, (grads, _)) in self.grads.iter_mut().enumerate() {\n            if i == index {\n                continue;\n            }\n\n            if let Some(grad) = grads.remove::<B, D>(id) {\n                tensor = tensor + grad.to_device(&device);\n            }\n        }\n\n        Some((tensor, device))\n    }\n\n    fn select<B: Backend, const D: usize>(\n        &mut self,\n        id: ParamId,\n    ) -> Option<(Tensor<B, D>, Device<B>, usize)> {\n        let id_val = id.val() as usize;\n        for i in 0..self.grads.len() {\n            let selected_device_index = (id_val + i) % self.grads.len();\n\n            if let Some(acc) = self.grads[selected_device_index].0.remove::<B, D>(id) {\n                let device_id = self.grads[selected_device_index].1;\n                let device = <B::Device as DeviceOps>::from_id(device_id);\n                return Some((acc.to_device(&device), device, selected_device_index));\n            }\n        }\n\n        None\n    }\n}\n\n/// General trait to optimize [module](AutodiffModule).\npub trait Optimizer<M, B>: Send + Clone\nwhere\n    M: AutodiffModule<B>,\n    B: AutodiffBackend,\n{\n    /// Optimizer associative type to be used when saving and loading the state.\n    type Record: Record<B>;\n\n    /// Perform the optimizer step using the given learning rate and gradients.\n    /// The updated module is returned.\n    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M;\n\n    /// Perform the optimizer step using the given learning rate and gradients.\n    /// The updated module is returned.\n    fn step_multi(&mut self, lr: LearningRate, module: M, grads: MultiGradientsParams) -> M;\n\n    /// Get the current state of the optimizer as a [record](Record).\n    fn to_record(&self) -> Self::Record;\n\n    /// Load the state of the optimizer as a [record](Record).\n    fn load_record(self, record: Self::Record) -> Self;\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/decay.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::record::Record;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\n\n/// Configuration to create [weight decay](WeightDecay).\n#[derive(Config, Debug)]\npub struct WeightDecayConfig {\n    /// L2 penalty.\n    pub penalty: f32,\n}\n\n/// State of [weight decay](WeightDecay).\n#[derive(Record, Clone, new)]\npub struct WeightDecayState<B: Backend, const D: usize> {\n    pub(crate) grad_last_step: Tensor<B, D>,\n}\n\n/// Weight decay implementation that transforms gradients.\n#[derive(Clone)]\npub struct WeightDecay {\n    penalty: f32,\n}\n\nimpl WeightDecay {\n    /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).\n    pub fn new(config: &WeightDecayConfig) -> Self {\n        Self {\n            penalty: config.penalty,\n        }\n    }\n\n    /// Transforms a gradient.\n    ///\n    /// # Arguments\n    ///\n    /// * `grad` - Gradient to transform.\n    /// * `tensor` - Tensor param of the last iteration.\n    ///\n    /// # Returns\n    ///\n    /// * `grad` - Transformed gradient.\n    pub fn transform<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        tensor: Tensor<B, D>,\n    ) -> Tensor<B, D> {\n        tensor.mul_scalar(self.penalty).add(grad)\n    }\n}\n\nimpl<B: Backend, const D: usize> WeightDecayState<B, D> {\n    /// Moves the state to a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move the state to.\n    ///\n    /// # Returns\n    ///\n    /// * `self` - Moved state.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.grad_last_step = self.grad_last_step.to_device(device);\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/grad_accum.rs",
    "content": "use burn_core as burn;\n\nuse core::marker::PhantomData;\n\nuse burn::module::{AutodiffModule, ModuleVisitor, Param};\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\n\nuse super::GradientsParams;\n\n/// Accumulate gradients into a single [Gradients](AutodiffBackend::Gradients) object.\npub struct GradientsAccumulator<M> {\n    grads: GradientsParams,\n    phantom: PhantomData<M>,\n}\n\nimpl<M> Default for GradientsAccumulator<M> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<M> GradientsAccumulator<M> {\n    /// Create a new gradients accumulator.\n    pub fn new() -> Self {\n        Self {\n            grads: GradientsParams::new(),\n            phantom: PhantomData,\n        }\n    }\n}\n\nimpl<M> GradientsAccumulator<M> {\n    /// Accumulate the given gradients for each parameter in the given module.\n    pub fn accumulate<B: AutodiffBackend>(&mut self, module: &M, grads: GradientsParams)\n    where\n        M: AutodiffModule<B>,\n    {\n        let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);\n        module.visit(&mut visitor);\n    }\n\n    /// Return the accumulated gradients and reset the accumulator state.\n    pub fn grads(&mut self) -> GradientsParams {\n        let mut grads = GradientsParams::new();\n        core::mem::swap(&mut self.grads, &mut grads);\n\n        grads\n    }\n}\n\n#[derive(new)]\nstruct ModuleGradsAccumulator<'a, M> {\n    grads: &'a mut GradientsParams,\n    grads_new: GradientsParams,\n    phantom: PhantomData<M>,\n}\n\nimpl<B: AutodiffBackend, M: AutodiffModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'_, M> {\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(param.id) {\n            Some(new) => match self.grads.remove::<B::InnerBackend, D>(param.id) {\n                Some(grad) => grad.add(new),\n                None => new,\n            },\n            None => match self.grads.remove::<B::InnerBackend, D>(param.id) {\n                Some(grad) => grad,\n                None => return,\n            },\n        };\n\n        self.grads\n            .register::<B::InnerBackend, D>(param.id, grad_updated);\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use burn::tensor::{Distribution, backend::Backend};\n    use burn_nn::{Linear, LinearConfig};\n\n    #[test]\n    fn test_accumulate_gradients_one_step() {\n        let device = Default::default();\n        let mut accumulator = GradientsAccumulator::new();\n        let layer = layer::<TestAutodiffBackend>(&device);\n        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));\n        let grads = GradientsParams::from_grads(loss.backward(), &layer);\n\n        accumulator.accumulate(&layer, grads);\n\n        let grads = accumulator.grads();\n        assert!(!grads.is_empty())\n    }\n\n    #[test]\n    fn test_accumulate_gradients_two_steps() {\n        let device = Default::default();\n        let mut accumulator = GradientsAccumulator::new();\n        let layer = layer::<TestAutodiffBackend>(&device);\n        let loss_1 = layer.forward(random_tensor(&device));\n        let loss_2 = layer.forward(random_tensor(&device));\n        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer);\n        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer);\n\n        accumulator.accumulate(&layer, grads_1);\n        accumulator.accumulate(&layer, grads_2);\n\n        let grads = accumulator.grads();\n        assert_eq!(grads.len(), 2)\n    }\n\n    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {\n        LinearConfig::new(20, 20).init(device)\n    }\n\n    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {\n        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/grads.rs",
    "content": "use burn_core as burn;\n\n#[cfg(feature = \"collective\")]\nuse burn_collective::{CollectiveError, PeerId, ReduceOperation, all_reduce};\n\nuse burn::{\n    Tensor,\n    tensor::{\n        backend::{AutodiffBackend, Backend},\n        container::TensorContainer,\n    },\n};\n\nuse burn::module::{AutodiffModule, ParamId};\n\nuse super::visitor::{GradientsParamsChangeDevice, GradientsParamsConverter};\n\n/// Data type that contains gradients for parameters.\n#[derive(Default, Debug)]\npub struct GradientsParams {\n    container: TensorContainer<ParamId>,\n}\n\nimpl GradientsParams {\n    /// Creates a new [GradientsParams](GradientsParams).\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Extract each tensor gradients for the given [module](AutodiffModule).\n    ///\n    /// Note: This consumes the gradients. See ['from_module'] to extract gradients only for\n    ///  a specific module.\n    pub fn from_grads<B: AutodiffBackend, M: AutodiffModule<B>>(\n        grads: B::Gradients,\n        module: &M,\n    ) -> Self {\n        let mut grads = grads;\n        Self::from_module(&mut grads, module)\n    }\n\n    /// Extract each tensor gradients for the given [module](AutodiffModule).\n    pub fn from_module<B: AutodiffBackend, M: AutodiffModule<B>>(\n        grads: &mut B::Gradients,\n        module: &M,\n    ) -> Self {\n        let mut grads_params = GradientsParams::new();\n        let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, None);\n        module.visit(&mut visitor);\n        grads_params\n    }\n\n    /// Extract tensor gradients for the given [module](AutodiffModule) and given parameters.\n    pub fn from_params<B: AutodiffBackend, M: AutodiffModule<B>>(\n        grads: &mut B::Gradients,\n        module: &M,\n        params: &[ParamId],\n    ) -> Self {\n        let mut grads_params = GradientsParams::new();\n        let mut visitor =\n            GradientsParamsConverter::<M, B>::new(grads, &mut grads_params, Some(params.to_vec()));\n        module.visit(&mut visitor);\n        grads_params\n    }\n\n    /// Get the gradients for the given [parameter id](ParamId).\n    ///\n    /// # Notes\n    ///\n    /// You should use [remove](GradientsParams::remove) if you want to get the gradients\n    /// only one time.\n    pub fn get<B, const D: usize>(&self, id: ParamId) -> Option<Tensor<B, D>>\n    where\n        B: Backend,\n    {\n        self.container.get(&id).map(Tensor::from_primitive)\n    }\n\n    /// Remove the gradients for the given [parameter id](ParamId).\n    pub fn remove<B, const D: usize>(&mut self, id: ParamId) -> Option<Tensor<B, D>>\n    where\n        B: Backend,\n    {\n        self.container.remove(&id).map(Tensor::from_primitive)\n    }\n\n    /// Register a gradients tensor for the given [parameter id](ParamId).\n    ///\n    /// # Notes\n    ///\n    /// If a tensor is already registered for the given [parameter id](ParamId), it will be replaced.\n    pub fn register<B, const D: usize>(&mut self, id: ParamId, value: Tensor<B, D>)\n    where\n        B: Backend,\n    {\n        self.container.register(id, value.into_primitive())\n    }\n\n    /// The number of gradients tensors registered.\n    pub fn len(&self) -> usize {\n        self.container.len()\n    }\n\n    /// If any tensor is contained.\n    pub fn is_empty(&self) -> bool {\n        self.len() == 0\n    }\n\n    /// Change the device of each tensor gradients registered for the given [module](AutodiffModule).\n    pub fn to_device<B: AutodiffBackend, M: AutodiffModule<B>>(\n        mut self,\n        device: &B::Device,\n        module: &M,\n    ) -> Self {\n        let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);\n        module.visit(&mut visitor);\n        self\n    }\n\n    /// Syncs the gradient params with the other peers in the collective.\n    #[cfg(feature = \"collective\")]\n    pub fn all_reduce<B: Backend>(\n        mut self,\n        peer_id: PeerId,\n        op: ReduceOperation,\n    ) -> Result<Self, CollectiveError> {\n        let mut ids = self\n            .container\n            .ids()\n            .into_iter()\n            .copied()\n            .collect::<Vec<ParamId>>();\n        // This is crucial, since the all-reduce operations need to happen in the same order for the same parameters on all nodes!\n        ids.sort();\n\n        for id in ids {\n            let Some(grad) = self.container.remove::<B>(&id) else {\n                todo!()\n            };\n\n            let grad = match grad {\n                burn::tensor::TensorPrimitive::Float(grad) => {\n                    let grad = all_reduce::<B>(peer_id, grad, op)?;\n                    burn::tensor::TensorPrimitive::Float(grad)\n                }\n                burn::tensor::TensorPrimitive::QFloat(_grad) => {\n                    unimplemented!(\"quantized all-reduce unimplemented\")\n                }\n            };\n\n            self.container.register::<B>(id, grad);\n        }\n\n        Ok(self)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use burn::module::{Module, list_param_ids};\n    use burn::tensor::{Distribution, backend::Backend};\n    use burn_nn::{Linear, LinearConfig};\n\n    #[test]\n    fn test_convert_grads() {\n        let device = Default::default();\n        let layer_1 = layer::<TestAutodiffBackend>(&device);\n        let mut layer_2 = layer_1.clone();\n        layer_2 = layer_2.fork(&device);\n        let loss_1 = layer_1.forward(random_tensor(&device));\n        let loss_2 = layer_2.forward(random_tensor(&device));\n        let grads_1 = GradientsParams::from_grads(loss_1.backward(), &layer_1);\n        let grads_2 = GradientsParams::from_grads(loss_2.backward(), &layer_2);\n\n        let param_ids_1 = list_param_ids(&layer_1);\n        let param_ids_2 = list_param_ids(&layer_2);\n\n        assert_eq!(param_ids_1, param_ids_2);\n        assert_eq!(grads_1.len(), param_ids_1.len());\n        assert_eq!(grads_2.len(), param_ids_2.len());\n    }\n\n    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {\n        LinearConfig::new(20, 20).init(device)\n    }\n\n    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {\n        Tensor::<B, 2>::random([2, 20], Distribution::Default, device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/lbfgs.rs",
    "content": "#![allow(clippy::excessive_precision)]\n\nuse burn_core as burn;\n\nuse super::GradientsParams;\nuse crate::LearningRate;\nuse burn::config::Config;\nuse burn::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param};\nuse burn::prelude::ToElement;\nuse burn::record::Record;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse serde::{Deserialize, Serialize};\n\nuse alloc::vec;\nuse alloc::vec::Vec;\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Cubic Interpolate\n///\n/// Uses two points (x1, f1), (x2, f2) and their first derivatives g1,g2 to construct\n/// a cubic interpolant and return its minimum within the given bounds.\nfn cubic_interpolate(\n    x1: f64,\n    f1: f64,\n    g1: f64,\n    x2: f64,\n    f2: f64,\n    g2: f64,\n    bounds: Option<(f64, f64)>,\n) -> f64 {\n    // Compute bounds of interpolation area\n    let (min_bound, max_bound) = bounds.unwrap_or(if x1 <= x2 { (x1, x2) } else { (x2, x1) });\n    // Code for most common case: cubic interpolation of 2 points\n    // with function and derivative values for both\n    // Solution in this case (where x2 is the farthest point)\n    // d1 = g1 + g2 - 3*(f1 - f2) / (x1-x2);\n    // d2 = sqrt(d1^2 - g1 * g2);\n    // min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));\n    // t_new = min(max(min_pos,min_bound), max_bound);\n    let d1 = g1 + g2 - 3.0 * (f1 - f2) / (x1 - x2);\n    let d2_square = d1 * d1 - g1 * g2;\n\n    if d2_square >= 0.0 {\n        let d2 = d2_square.sqrt();\n        let min_pos = if x1 <= x2 {\n            x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2.0 * d2))\n        } else {\n            x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2.0 * d2))\n        };\n        min_pos.max(min_bound).min(max_bound)\n    } else {\n        (min_bound + max_bound) / 2.0\n    }\n}\n/// Auxiliary Struct For Strong_Wolfe\nstruct LineSearchSample<B: Backend> {\n    // step size\n    t: f64,\n    // loss\n    f: f64,\n    // gradient\n    g: Tensor<B, 1>,\n    // directional derivative\n    gtd: f64,\n}\n\n#[allow(clippy::too_many_arguments)]\nfn strong_wolfe<B: Backend, F>(\n    // obj_func(x,step size,direction) -> (loss,grad)\n    obj_func: &mut F,\n    x: &Tensor<B, 1>,\n    // initial step size\n    mut t: f64,\n    d: &Tensor<B, 1>,\n    f: f64,\n    g: Tensor<B, 1>,\n    gtd: f64,\n    c1: f64,\n    c2: f64,\n    tolerance_change: f64,\n    max_ls: usize,\n) -> (f64, Tensor<B, 1>, f64, usize)\nwhere\n    F: FnMut(&Tensor<B, 1>, f64, &Tensor<B, 1>) -> (f64, Tensor<B, 1>),\n{\n    let d_norm = d.clone().abs().max().into_scalar().to_f64();\n\n    // evaluate objective and gradient using initial step\n    let (mut f_new, mut g_new) = obj_func(x, t, d);\n    let mut ls_func_evals = 1;\n    let mut gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();\n\n    // bracket an interval [t_prev,t] containing a point satisfying the Wolfe criteria\n    let (mut t_prev, mut f_prev, mut g_prev, mut gtd_prev) = (0.0, f, g.clone(), gtd);\n    let mut done = false;\n    let mut ls_iter = 0;\n\n    // the interval [low,high] using for Zoom phase\n    let mut bracket: Option<[LineSearchSample<B>; 2]> = None;\n    // point which satisfy the wolfe condition\n    let mut wolfe_bracket: Option<LineSearchSample<B>> = None;\n    while ls_iter < max_ls {\n        // Checking Conditions.\n\n        // Checking the Armijo Condition and function value increasing condition.\n        // Armijo: f(x+t*d) <= f(x) + c_1 t gtd\n        if f_new > (f + c1 * t * gtd) || (ls_iter > 1 && f_new >= f_prev) {\n            bracket = Some([\n                LineSearchSample {\n                    t: t_prev,\n                    f: f_prev,\n                    g: g_prev,\n                    gtd: gtd_prev,\n                },\n                LineSearchSample {\n                    t,\n                    f: f_new,\n                    g: g_new.clone(),\n                    gtd: gtd_new,\n                },\n            ]);\n            break;\n        }\n\n        // Checking Strong Wolfe Condition\n        // |gtd_new| <= -c_2 gtd\n        if gtd_new.abs() <= -c2 * gtd {\n            wolfe_bracket = Some(LineSearchSample {\n                t,\n                f: f_new,\n                g: g_new.clone(),\n                gtd: gtd_new,\n            });\n            done = true;\n            break;\n        }\n\n        // gtd_new >=0 , there must be a local minimum in the interval.\n        if gtd_new >= 0.0 {\n            bracket = Some([\n                LineSearchSample {\n                    t: t_prev,\n                    f: f_prev,\n                    g: g_prev,\n                    gtd: gtd_prev,\n                },\n                LineSearchSample {\n                    t,\n                    f: f_new,\n                    g: g_new.clone(),\n                    gtd: gtd_new,\n                },\n            ]);\n            break;\n        }\n\n        // interpolate\n        let min_step = t + 0.01 * (t - t_prev);\n        let max_step = t * 10.0;\n        let t_next = cubic_interpolate(\n            t_prev,\n            f_prev,\n            gtd_prev,\n            t,\n            f_new,\n            gtd_new,\n            Some((min_step, max_step)),\n        );\n        t_prev = t;\n        f_prev = f_new;\n        g_prev = g_new;\n        gtd_prev = gtd_new;\n\n        // next step\n        t = t_next;\n        (f_new, g_new) = obj_func(x, t, d);\n        ls_func_evals += 1;\n        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();\n        ls_iter += 1;\n    }\n    if let Some(sample) = wolfe_bracket {\n        return (sample.f, sample.g, sample.t, ls_func_evals);\n    }\n\n    let mut bracket = bracket.unwrap_or_else(|| {\n        [\n            LineSearchSample {\n                t: 0.0,\n                f,\n                g: g.clone(),\n                gtd,\n            },\n            LineSearchSample {\n                t,\n                f: f_new,\n                g: g_new.clone(),\n                gtd: gtd_new,\n            },\n        ]\n    });\n\n    // zoom phase\n    let mut insuf_progress = false;\n\n    // find high and low points in bracket\n    let (mut low_idx, mut high_idx) = if bracket[0].f <= bracket[1].f {\n        (0, 1)\n    } else {\n        (1, 0)\n    };\n\n    while !done && ls_iter < max_ls {\n        let diff = (bracket[1].t - bracket[0].t).abs();\n        // line-search bracket is so small\n        if diff * d_norm < tolerance_change {\n            break;\n        }\n\n        // compute new trial value\n        t = cubic_interpolate(\n            bracket[0].t,\n            bracket[0].f,\n            bracket[0].gtd,\n            bracket[1].t,\n            bracket[1].f,\n            bracket[1].gtd,\n            None,\n        );\n\n        let b_min = bracket[0].t.min(bracket[1].t);\n        let b_max = bracket[0].t.max(bracket[1].t);\n        let eps = 0.1 * (b_max - b_min);\n\n        if (b_max - t).min(t - b_min) < eps {\n            // interpolation close to boundary\n            if insuf_progress || t >= b_max || t <= b_min {\n                t = if (t - b_max).abs() < (t - b_min).abs() {\n                    b_max - eps\n                } else {\n                    b_min + eps\n                };\n                insuf_progress = false;\n            } else {\n                insuf_progress = true;\n            }\n        } else {\n            insuf_progress = false;\n        }\n\n        // Evaluate new point\n        (f_new, g_new) = obj_func(x, t, d);\n\n        ls_func_evals += 1;\n        gtd_new = g_new.clone().dot(d.clone()).into_scalar().to_f64();\n        ls_iter += 1;\n\n        let armijo_holds = f_new <= (f + c1 * t * gtd) && f_new < bracket[low_idx].f;\n\n        if !armijo_holds {\n            bracket[high_idx] = LineSearchSample {\n                t,\n                f: f_new,\n                g: g_new,\n                gtd: gtd_new,\n            };\n        } else {\n            if gtd_new.abs() <= -c2 * gtd {\n                return (f_new, g_new, t, ls_func_evals);\n            }\n\n            if gtd_new * (bracket[high_idx].t - bracket[low_idx].t) >= 0.0 {\n                bracket[high_idx] = LineSearchSample {\n                    t: bracket[low_idx].t,\n                    f: bracket[low_idx].f,\n                    g: bracket[low_idx].g.clone(),\n                    gtd: bracket[low_idx].gtd,\n                };\n            }\n            bracket[low_idx] = LineSearchSample {\n                t,\n                f: f_new,\n                g: g_new,\n                gtd: gtd_new,\n            };\n        }\n\n        if bracket[0].f <= bracket[1].f {\n            low_idx = 0;\n            high_idx = 1;\n        } else {\n            low_idx = 1;\n            high_idx = 0;\n        }\n    }\n    // return stuff\n    (\n        bracket[low_idx].f,\n        bracket[low_idx].g.clone(),\n        bracket[low_idx].t,\n        ls_func_evals,\n    )\n}\n\n/// Strategy for the line search optimization phase\n#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]\npub enum LineSearchFn {\n    /// No line search performed\n    #[default]\n    None,\n    /// strong wolfe conditions\n    ///\n    /// See: <https://en.wikipedia.org/wiki/Wolfe_conditions>\n    StrongWolfe,\n}\n\n/// LBFGS Configuration.\n#[derive(Config, Debug)]\npub struct LBFGSConfig {\n    /// Maximal number of iterations per optimization step (default: 20)\n    #[config(default = 20)]\n    pub max_iter: usize,\n    /// Update history size (default: 100).\n    #[config(default = 100)]\n    pub history_size: usize,\n    /// Termination tolerance on first order optimality (default: 1e-7).\n    #[config(default = 1e-7)]\n    pub tolerance_grad: f64,\n    /// Termination tolerance on function value/parameter changes (default: 1e-9).\n    #[config(default = 1e-9)]\n    pub tolerance_change: f64,\n    /// Maximal number of function evaluations per optimization step (default: max_iter * 1.25).\n    #[config(default = \"None\")]\n    pub max_eval: Option<usize>,\n    /// Either ‘strong_wolfe’ or None (default: None).\n    #[config(default = \"LineSearchFn::None\")]\n    pub line_search_fn: LineSearchFn,\n}\n\nimpl LBFGSConfig {\n    /// Initialize AdamW optimizer\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer that can be used to optimize a module\n    pub fn init<B: AutodiffBackend>(&self) -> LBFGS<B> {\n        // by default max_eval = max_iter * 5/4\n        let max_eval = self.max_eval.unwrap_or(self.max_iter * 5 / 4);\n        LBFGS {\n            config: LBFGSConfig {\n                max_iter: self.max_iter,\n                history_size: self.history_size,\n                tolerance_grad: self.tolerance_grad,\n                tolerance_change: self.tolerance_change,\n                max_eval: Some(max_eval),\n                line_search_fn: self.line_search_fn,\n            },\n            state: Default::default(),\n        }\n    }\n}\n\n/// Collects gradients in module visit order.\nstruct FlattenGradsVisitorInner<'a, B: AutodiffBackend> {\n    grads: &'a GradientsParams,\n    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,\n}\n\nimpl<B: AutodiffBackend> ModuleVisitor<B> for FlattenGradsVisitorInner<'_, B> {\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        if let Some(g) = self.grads.get::<B::InnerBackend, D>(param.id) {\n            let numel = g.shape().num_elements();\n            self.tensors.push(g.reshape([numel]));\n        }\n    }\n}\n\n/// Flatten params to inner backend 1D tensor.\nfn flatten_params_inner<B: AutodiffBackend, M: Module<B>>(\n    module: &M,\n) -> Tensor<B::InnerBackend, 1> {\n    let mut tensors = Vec::new();\n    let mut visitor = FlattenParamsVisitorInner::<B> {\n        tensors: &mut tensors,\n    };\n    module.visit(&mut visitor);\n    if tensors.is_empty() {\n        return Tensor::empty([0], &module.devices()[0]);\n    }\n    Tensor::cat(tensors, 0)\n}\n\nstruct FlattenParamsVisitorInner<'a, B: AutodiffBackend> {\n    tensors: &'a mut Vec<Tensor<B::InnerBackend, 1>>,\n}\n\nimpl<B: AutodiffBackend> ModuleVisitor<B> for FlattenParamsVisitorInner<'_, B> {\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        let t = param.val().inner();\n        let numel = t.shape().num_elements();\n        self.tensors.push(t.reshape([numel]));\n    }\n}\n\n/// Flatten gradients for a module.\nfn flatten_grads_inner<B: AutodiffBackend, M: Module<B>>(\n    module: &M,\n    grads: &GradientsParams,\n) -> Tensor<B::InnerBackend, 1> {\n    let mut tensors = Vec::new();\n    let mut visitor = FlattenGradsVisitorInner {\n        grads,\n        tensors: &mut tensors,\n    };\n    module.visit(&mut visitor);\n    if tensors.is_empty() {\n        return Tensor::empty([0], &module.devices()[0]);\n    }\n    Tensor::cat(tensors, 0)\n}\n\n/// Mapper that assigns each float param from a flat inner-backend 1D tensor.\nstruct ParamsFromFlatMapperInner<'a, B: AutodiffBackend> {\n    flat: &'a Tensor<B::InnerBackend, 1>,\n    offset: &'a mut usize,\n}\n\nimpl<B: AutodiffBackend> ParamsFromFlatMapperInner<'_, B> {\n    fn take_slice(&mut self, numel: usize) -> Tensor<B::InnerBackend, 1> {\n        let start = *self.offset;\n        *self.offset += numel;\n        self.flat.clone().slice(start..*self.offset)\n    }\n}\n\nimpl<B: AutodiffBackend> ModuleMapper<B> for ParamsFromFlatMapperInner<'_, B> {\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        let numel = tensor.shape().num_elements();\n        let slice_1d = self.take_slice(numel);\n        let new_inner = slice_1d.reshape(tensor.shape());\n        let new_tensor = Tensor::from_inner(new_inner).require_grad();\n        Param::from_mapped_value(id, new_tensor, mapper)\n    }\n}\n\n/// Overwrite module parameters from a flat inner-backend 1D tensor\nfn set_params_from_flat_inner<B: AutodiffBackend, M: Module<B>>(\n    module: M,\n    flat: Tensor<B::InnerBackend, 1>,\n) -> M {\n    let mut offset = 0;\n    let mut mapper = ParamsFromFlatMapperInner {\n        flat: &flat,\n        offset: &mut offset,\n    };\n    module.map(&mut mapper)\n}\n\n/// L-BFGS optimizer state\n#[derive(Clone, Record)]\npub struct LBFGSState<B: Backend> {\n    /// Historical displacement vectors\n    pub history_s: Vec<Tensor<B, 1>>,\n    /// Historical gradient difference vectors\n    pub history_y: Vec<Tensor<B, 1>>,\n    /// Search direction\n    pub d: Option<Tensor<B, 1>>,\n    /// Step size from the previous iteration\n    pub t: Option<f64>,\n    /// Flattened gradient from the previous iteration\n    pub prev_flat_grad: Option<Tensor<B, 1>>,\n    /// Loss value from the previous iteration\n    pub prev_loss: Option<f64>,\n    /// Global iteration count\n    pub g_iter: usize,\n}\n\nimpl<B: Backend> LBFGSState<B> {\n    /// Moves all historical tensors to the target device.\n    pub fn to_device(self, device: &B::Device) -> Self {\n        Self {\n            history_s: self\n                .history_s\n                .into_iter()\n                .map(|t| t.to_device(device))\n                .collect(),\n            history_y: self\n                .history_y\n                .into_iter()\n                .map(|t| t.to_device(device))\n                .collect(),\n            d: self.d.map(|t| t.to_device(device)),\n            t: self.t,\n            prev_flat_grad: self.prev_flat_grad.map(|t| t.to_device(device)),\n            prev_loss: self.prev_loss,\n            g_iter: self.g_iter,\n        }\n    }\n}\nimpl<B: Backend> Default for LBFGSState<B> {\n    fn default() -> Self {\n        Self {\n            history_s: Vec::new(),\n            history_y: Vec::new(),\n            d: None,\n            t: Some(1.0),\n            prev_flat_grad: None,\n            prev_loss: None,\n            g_iter: 0,\n        }\n    }\n}\n\n/// L-BFGS optimizer.\n///\n/// Ported from [pytorch](https://github.com/pytorch/pytorch/torch/optim/lbfgs.py). Heavily inspired by [miniFunc](https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html)\n///\n/// See also:\n/// - [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS)\n///\n/// # Note\n/// This optimizer is memory intensive\n#[derive(Clone)]\npub struct LBFGS<B: Backend + AutodiffBackend> {\n    config: LBFGSConfig,\n    state: LBFGSState<B::InnerBackend>,\n}\n\nimpl<B: Backend + AutodiffBackend> LBFGS<B> {\n    /// A single optimization step for any tensor that represents the parameters of a model.\n    pub fn step<M, F>(&mut self, lr: LearningRate, mut module: M, mut closure: F) -> (M, f64)\n    where\n        M: AutodiffModule<B> + Clone,\n        F: FnMut(M) -> (f64, GradientsParams),\n    {\n        // evaluate initial f(x) and df/dx\n        let (mut loss, grads) = closure(module.clone());\n        let mut current_evals = 1;\n\n        let mut flat_grad = flatten_grads_inner::<B, M>(&module, &grads);\n        let mut x_flat = flatten_params_inner::<B, M>(&module);\n\n        let opt_cond =\n            flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad;\n        // optimal condition\n        if opt_cond {\n            return (module, loss);\n        }\n\n        // tensors cached in state\n        let mut d = self\n            .state\n            .d\n            .take()\n            .unwrap_or_else(|| flat_grad.clone().neg());\n        let mut t = self.state.t.unwrap_or(lr);\n        let mut prev_flat_grad = self.state.prev_flat_grad.take();\n\n        let mut n_iter = 0;\n\n        // optimize for a max of max_iter iterations\n        while n_iter < self.config.max_iter {\n            // keep track of nb of iterations\n            n_iter += 1;\n            self.state.g_iter += 1;\n\n            // compute gradient descent direction\n            if self.state.g_iter == 1 {\n                d = flat_grad.clone().neg();\n                self.state.history_s.clear();\n                self.state.history_y.clear();\n            } else {\n                // do lbfgs update (update memory)\n                if let Some(pg) = prev_flat_grad.as_ref() {\n                    let y = flat_grad.clone().sub(pg.clone());\n                    let s = d.clone().mul_scalar(t);\n\n                    let ys = y.clone().dot(s.clone()).into_scalar().to_f64();\n\n                    if ys > 1e-10 {\n                        // updating memory\n                        if self.state.history_s.len() >= self.config.history_size {\n                            // shift history by one (limited-memory)\n                            self.state.history_s.remove(0);\n                            self.state.history_y.remove(0);\n                        }\n                        self.state.history_s.push(s);\n                        self.state.history_y.push(y);\n                    }\n                }\n\n                // compute the approximate (L-BFGS) inverse Hessian\n                // multiplied by the gradient\n                let num_old = self.state.history_s.len();\n                let mut q = flat_grad.clone().neg();\n                let mut alphas: Vec<Tensor<B::InnerBackend, 1>> =\n                    vec![Tensor::zeros([1], &flat_grad.device()); num_old];\n\n                if num_old > 0 {\n                    // multiply by initial Hessian\n                    // r/d is the final direction\n                    for i in (0..num_old).rev() {\n                        let s = &self.state.history_s[i];\n                        let y = &self.state.history_y[i];\n                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);\n                        let alpha = rho.clone().mul(s.clone().dot(q.clone()));\n                        alphas[i] = alpha.clone();\n                        q = q.sub(y.clone().mul(alpha));\n                    }\n\n                    let last_s = &self.state.history_s[num_old - 1];\n                    let last_y = &self.state.history_y[num_old - 1];\n                    let ys = last_y.clone().dot(last_s.clone());\n                    let yy = last_y.clone().dot(last_y.clone());\n                    let h_diag = ys.div(yy);\n\n                    let mut r = q.mul(h_diag);\n\n                    for ((s, y), alpha) in self\n                        .state\n                        .history_s\n                        .iter()\n                        .zip(self.state.history_y.iter())\n                        .zip(alphas.into_iter())\n                        .take(num_old)\n                    {\n                        let rho = y.clone().dot(s.clone()).powf_scalar(-1.0);\n\n                        let beta = rho.mul(y.clone().dot(r.clone()));\n\n                        r = r.add(s.clone().mul(alpha.sub(beta)));\n                    }\n                    d = r;\n                } else {\n                    d = q;\n                }\n            }\n\n            prev_flat_grad = Some(flat_grad.clone());\n            let prev_loss_iter = loss;\n\n            // compute step len\n            if self.state.g_iter == 1 {\n                let grad_l1 = flat_grad.clone().abs().sum().into_scalar().to_f64();\n                t = (1.0f64 / grad_l1).min(1.0) * lr;\n            } else {\n                t = lr;\n            }\n\n            // directional derivative\n            let gtd = flat_grad.clone().dot(d.clone()).into_scalar().to_f64();\n\n            if gtd > -self.config.tolerance_change {\n                break;\n            }\n\n            let ls_func_evals;\n\n            if let LineSearchFn::StrongWolfe = self.config.line_search_fn {\n                // perform line search, using user function\n                let mut obj_func =\n                    |current_x: &Tensor<B::InnerBackend, 1>,\n                     step: f64,\n                     dir: &Tensor<B::InnerBackend, 1>| {\n                        let update = dir.clone().mul_scalar(step);\n                        let new_x = current_x.clone().add(update);\n                        let tmp_module = set_params_from_flat_inner::<B, M>(module.clone(), new_x);\n                        let (l, g) = closure(tmp_module);\n                        (l, flatten_grads_inner::<B, M>(&module, &g))\n                    };\n\n                let (ls_f, ls_g, ls_t, evals) = strong_wolfe(\n                    &mut obj_func,\n                    &x_flat,\n                    t,\n                    &d,\n                    loss,\n                    flat_grad.clone(),\n                    gtd,\n                    1e-4,\n                    0.9,\n                    self.config.tolerance_change,\n                    self.config.max_eval.unwrap() - current_evals,\n                );\n\n                loss = ls_f;\n                flat_grad = ls_g;\n                t = ls_t;\n                ls_func_evals = evals;\n\n                x_flat = x_flat.add(d.clone().mul_scalar(t));\n                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());\n            } else {\n                // no line search, simply move with fixed-step\n                let step_vec = d.clone().mul_scalar(t);\n                x_flat = x_flat.add(step_vec);\n                module = set_params_from_flat_inner::<B, M>(module, x_flat.clone());\n                // re-evaluate function only if not in last iteration\n                // the reason we do this: in a stochastic setting,\n                // no use to re-evaluate that function here\n                let (new_loss, new_grads) = closure(module.clone());\n                loss = new_loss;\n                flat_grad = flatten_grads_inner::<B, M>(&module, &new_grads);\n                ls_func_evals = 1;\n            }\n\n            // update func eval\n            current_evals += ls_func_evals;\n\n            // check conditions\n\n            if current_evals >= self.config.max_eval.unwrap() {\n                break;\n            }\n\n            if flat_grad.clone().abs().max().into_scalar().to_f64() <= self.config.tolerance_grad {\n                break;\n            }\n\n            if d.clone().mul_scalar(t).abs().max().into_scalar().to_f64()\n                <= self.config.tolerance_change\n            {\n                break;\n            }\n\n            if (loss - prev_loss_iter).abs() < self.config.tolerance_change {\n                break;\n            }\n        }\n        self.state.d = Some(d);\n        self.state.t = Some(t);\n        self.state.prev_flat_grad = prev_flat_grad;\n        self.state.prev_loss = Some(loss);\n        (module, loss)\n    }\n    /// Moves the optimizer state to the specified device.\n    pub fn to_device(self, device: &B::Device) -> Self {\n        Self {\n            config: self.config,\n            // History tensors reside in InnerBackend, so we convert the device accordingly\n            state: self.state.to_device(device),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n\n    use super::*;\n    use crate::GradientsParams;\n    use crate::TestAutodiffBackend;\n    use burn::module::{Module, Param};\n    use burn::tensor::{Tensor, TensorData};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: Some(Param::from_data(bias, &device)),\n        };\n\n        LinearConfig::new(6, 6).init(&device).load_record(record)\n    }\n    #[test]\n    fn test_cubic_interpolate() {\n        let tolerance = 1e-8;\n\n        // basic\n        let (x1, f1, g1, x2, f2, g2) = (-1.0, 1.0, -2.0, 1.0, 1.0, 2.0);\n        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);\n        assert!(\n            (result - 0.00000).abs() < tolerance,\n            \"Basic: Result {} should be close to 0.0\",\n            result\n        );\n\n        // bound\n        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.25, -1.0, 1.0, 0.25, 1.0);\n        let bounds = Some((0.6, 1.0));\n        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds);\n        assert!(\n            (result - 0.6000000000).abs() < tolerance,\n            \"Bound: Result {} should be clamped to 0.6\",\n            result\n        );\n\n        // d2_square < 0,should return mid value\n        let (x1, f1, g1, x2, f2, g2) = (0.0, 0.0, 10.0, 1.0, 5.0, 10.0);\n        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((0.0, 1.0)));\n        assert!(\n            (result - 0.5000000).abs() < tolerance,\n            \"Fallback: Result {} should be midpoint 0.5\",\n            result\n        );\n\n        // asymmetric\n        let (x1, f1, g1, x2, f2, g2) = (0.0, 1.0, -5.0, 1.0, 0.5, 1.0);\n        let result = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);\n        assert!(\n            (result - 0.4606553370833684).abs() < tolerance,\n            \"Asymmetric: Result {} should be 0.4606553370833684\",\n            result\n        );\n\n        // not good value\n        let (x1, f1, g1, x2, f2, g2) = (\n            1.231232145,\n            -0.12567458754,\n            9.1231243007,\n            8.239105015,\n            -100.9012398021,\n            123201321.0293982,\n        );\n        let result_1 = cubic_interpolate(x1, f1, g1, x2, f2, g2, None);\n        let result_2 = cubic_interpolate(x1, f1, g1, x2, f2, g2, Some((-4.4, 4.4)));\n        assert!(\n            (result_1 - 5.9031480234724434).abs() < tolerance,\n            \"not good value 1: Result {} should be 5.9031480234724434\",\n            result\n        );\n        assert!(\n            (result_2 - 4.4000000000000004).abs() < tolerance,\n            \"not good value 2: Result {} should be 4.4000000000000004\",\n            result\n        );\n    }\n    #[test]\n    fn test_strong_wolfe_direct_comparison() {\n        let device = Default::default();\n        let tol = 1e-8;\n\n        {\n            let x = Tensor::<TestAutodiffBackend, 1>::from_floats([2.1321912957_f64], &device);\n            let d = Tensor::<TestAutodiffBackend, 1>::from_floats([0.91312321_f64], &device);\n            let t_initial = 1.213132_f64;\n            fn func<B: Backend>(\n                x_base: &Tensor<B, 1>,\n                t_val: f64,\n                d_vec: &Tensor<B, 1>,\n            ) -> (f64, Tensor<B, 1>) {\n                let curr_x = x_base.clone().add(d_vec.clone().mul_scalar(t_val));\n                let x2 = curr_x.clone().mul(curr_x.clone());\n                let x3 = x2.clone().mul(curr_x.clone());\n                let x4 = x2.clone().mul(x2.clone());\n\n                // f(x) = x^4 - 2*x^2 + x\n                let f_elements = x4 - x2.mul_scalar(2.0) + curr_x.clone();\n\n                let f_val = f_elements.sum().into_scalar().to_f64();\n\n                // g(x) = 4*x^3 - 4*x + 1\n                let g = x3.mul_scalar(4.0) - curr_x.clone().mul_scalar(4.0)\n                    + Tensor::ones_like(&curr_x);\n\n                (f_val, g)\n            }\n            let (f_init, g_init) = func(&x, 0.0, &d);\n            let gtd_init = g_init.clone().dot(d.clone()).into_scalar().to_f64();\n            println!(\"Initial State: f={},gtd = {}\", f_init, gtd_init);\n            assert!((f_init - 13.7080059052).abs() < tol);\n            assert!((gtd_init - 28.5305728912).abs() < tol);\n            let mut obj_func =\n                |xb: &Tensor<TestAutodiffBackend, 1>,\n                 tv: f64,\n                 dv: &Tensor<TestAutodiffBackend, 1>| func(xb, tv, dv);\n\n            let (f_final, _g_final, t_final, evals) = strong_wolfe(\n                &mut obj_func,\n                &x,\n                t_initial,\n                &d,\n                f_init,\n                g_init,\n                gtd_init,\n                1e-4, // c1\n                0.9,  // c2\n                1e-9, // tolerance_change\n                10,   // max_ls\n            );\n            let g_f = _g_final.into_scalar().to_f64();\n            println!(\n                \"f_final:{:?},_g_final:{:?},t_final:{:?},evals:{:?}\",\n                f_final, g_f, t_final, evals\n            );\n            assert!((f_final - 13.708005905151367).abs() < tol);\n            assert!((g_f - 31.2450428009).abs() < tol);\n            assert!((t_final.to_f64() - 0.0).abs() < tol);\n            assert!((evals == 11));\n        }\n    }\n    #[test]\n    fn test_lbfgs_strong_wolfe_comparison() {\n        let device = Default::default();\n        let tol = 1e-5;\n        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);\n        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);\n        let weight = TensorData::from([[0.5f64]]);\n        let bias = TensorData::from([0.1f64]);\n        let module = given_linear_layer(weight, bias);\n\n        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()\n            .with_line_search_fn(LineSearchFn::StrongWolfe)\n            .init();\n        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {\n            let output = mod_in.forward(x_data.clone());\n            let loss = burn_nn::loss::MseLoss::new().forward(\n                output,\n                y_true.clone(),\n                burn_nn::loss::Reduction::Sum,\n            );\n\n            let grads = loss.backward();\n            let grads_params = GradientsParams::from_grads(grads, &mod_in);\n\n            (loss.into_scalar().to_f64(), grads_params)\n        };\n        let initial_loss = closure(module.clone()).0;\n        assert!((initial_loss - 50.1300048828).abs() < tol);\n        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);\n        assert!((final_loss - 0.0234732367).abs() < tol);\n        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();\n        let optimized_bias: f64 = updated_module\n            .bias\n            .as_ref()\n            .unwrap()\n            .val()\n            .into_scalar()\n            .to_f64();\n        assert!((optimized_data - 2.0570652485).abs() < tol);\n        assert!((optimized_bias - 0.8106800914).abs() < tol);\n    }\n    #[test]\n    fn test_lbfgs_no_strong_wolfe_comparison() {\n        let device = Default::default();\n        let tol = 1e-5;\n        let x_data = Tensor::<TestAutodiffBackend, 2>::from_data([[1.0], [2.0], [3.0]], &device);\n        let y_true = Tensor::<TestAutodiffBackend, 2>::from_data([[3.0], [5.0], [7.0]], &device);\n        let weight = TensorData::from([[0.5f64]]);\n        let bias = TensorData::from([0.1f64]);\n        let module = given_linear_layer(weight, bias);\n\n        let mut optimizer: LBFGS<TestAutodiffBackend> = LBFGSConfig::new()\n            .with_line_search_fn(LineSearchFn::None)\n            .init();\n        let mut closure = |mod_in: Linear<TestAutodiffBackend>| {\n            let output = mod_in.forward(x_data.clone());\n            let loss = burn_nn::loss::MseLoss::new().forward(\n                output,\n                y_true.clone(),\n                burn_nn::loss::Reduction::Sum,\n            );\n\n            let grads = loss.backward();\n            let grads_params = GradientsParams::from_grads(grads, &mod_in);\n\n            (loss.into_scalar().to_f64(), grads_params)\n        };\n        let initial_loss = closure(module.clone()).0;\n        assert!((initial_loss - 50.1300048828).abs() < tol);\n        let (updated_module, final_loss) = optimizer.step(0.001, module, &mut closure);\n        assert!((final_loss - 48.2181930542).abs() < tol);\n        let optimized_data: f64 = updated_module.weight.val().into_scalar().to_f64();\n        let optimized_bias: f64 = updated_module\n            .bias\n            .as_ref()\n            .unwrap()\n            .val()\n            .into_scalar()\n            .to_f64();\n\n        assert!((optimized_data - 0.5302446192).abs() < tol);\n        assert!((optimized_bias - 0.1142520783).abs() < tol);\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/mod.rs",
    "content": "/// Weight decay module for optimizers.\npub mod decay;\n\n/// Momentum module for optimizers.\npub mod momentum;\n\nmod adagrad;\nmod adam;\nmod adamw;\nmod base;\nmod grad_accum;\nmod grads;\nmod lbfgs;\nmod muon;\nmod rmsprop;\nmod sgd;\nmod simple;\nmod visitor;\n\npub use adagrad::*;\npub use adam::*;\npub use adamw::*;\npub use base::*;\npub use grad_accum::*;\npub use grads::*;\npub use lbfgs::*;\npub use muon::*;\npub use rmsprop::*;\npub use sgd::*;\npub use simple::*;\n"
  },
  {
    "path": "crates/burn-optim/src/optim/momentum.rs",
    "content": "use burn_core as burn;\n\nuse burn::config::Config;\nuse burn::record::Record;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{ElementConversion, Tensor};\n\n/// Configuration to create [momentum](Momentum).\n#[derive(Config, Debug)]\npub struct MomentumConfig {\n    /// Momentum factor\n    #[config(default = 0.9)]\n    pub momentum: f64,\n    /// Dampening factor.\n    #[config(default = 0.1)]\n    pub dampening: f64,\n    /// Enables Nesterov momentum, see [On the importance of initialization and\n    /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).\n    #[config(default = false)]\n    pub nesterov: bool,\n}\n\n/// State of [momentum](Momentum).\n#[derive(Record, Clone, new)]\npub struct MomentumState<B: Backend, const D: usize> {\n    velocity: Tensor<B, D>,\n}\n\n/// Momentum implementation that transforms gradients.\n#[derive(Clone)]\npub struct Momentum<B: Backend> {\n    momentum: B::FloatElem,\n    dampening: f64,\n    nesterov: bool,\n}\n\nimpl<B: Backend> Momentum<B> {\n    /// Creates a new [momentum](Momentum) from a [config](MomentumConfig).\n    pub fn new(config: &MomentumConfig) -> Self {\n        Self {\n            momentum: config.momentum.elem(),\n            dampening: config.dampening,\n            nesterov: config.nesterov,\n        }\n    }\n\n    /// Transforms a gradient.\n    ///\n    /// # Arguments\n    ///\n    /// * `grad` - Gradient to transform.\n    /// * `state` - State of the optimizer.\n    ///\n    /// # Returns\n    ///\n    /// * `grad` - Transformed gradient.\n    /// * `state` - State of the optimizer.\n    pub fn transform<const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        state: Option<MomentumState<B, D>>,\n    ) -> (Tensor<B, D>, MomentumState<B, D>) {\n        let velocity = if let Some(state) = state {\n            grad.clone()\n                .mul_scalar(1.0 - self.dampening)\n                .add(state.velocity.mul_scalar(self.momentum))\n        } else {\n            grad.clone()\n        };\n\n        let grad = match self.nesterov {\n            true => velocity.clone().mul_scalar(self.momentum).add(grad),\n            false => velocity.clone(),\n        };\n\n        (grad, MomentumState::new(velocity))\n    }\n}\n\nimpl<B: Backend, const D: usize> MomentumState<B, D> {\n    /// Moves the state to a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move the state to.\n    ///\n    /// # Returns\n    ///\n    /// * `self` - Moved state.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.velocity = self.velocity.to_device(device);\n        self\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/muon.rs",
    "content": "use burn_core as burn;\n\nuse burn::{module::AutodiffModule, record::Record};\n\nuse burn::config::Config;\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse burn::tensor::{backend::Backend, ops::Device};\nuse serde::{Deserialize, Serialize};\n\nuse super::{\n    SimpleOptimizer,\n    adaptor::OptimizerAdaptor,\n    decay::WeightDecayConfig,\n    momentum::{Momentum, MomentumConfig, MomentumState},\n};\nuse crate::LearningRate;\n\n#[cfg(not(feature = \"std\"))]\n#[allow(unused_imports)]\nuse num_traits::Float as _;\n\n/// Learning rate adjustment method for Muon optimizer.\n///\n/// Muon adjusts the learning rate based on parameter shape to maintain consistent\n/// RMS across rectangular matrices.\n///\n/// # References\n///\n/// - Original: [Muon: An optimizer for hidden layers](https://kellerjordan.github.io/posts/muon/)\n/// - Moonshot: [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)\n#[derive(Clone, Default, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]\npub enum AdjustLrFn {\n    /// Keller Jordan's original method: `lr * sqrt(max(1, A/B))`\n    ///\n    /// This scales the learning rate based on the aspect ratio of the weight matrix,\n    /// ensuring that tall matrices (more rows than columns) get proportionally larger\n    /// learning rates.\n    ///\n    /// # Example\n    ///\n    /// For a [1024, 512] matrix: `lr * sqrt(1024/512) = lr * 1.414`\n    #[default]\n    Original,\n\n    /// Moonshot's method: `lr * 0.2 * sqrt(max(A, B))`\n    ///\n    /// This method is designed to match AdamW's RMS, allowing Muon to directly reuse\n    /// learning rates and weight decay values tuned for AdamW without retuning.\n    ///\n    /// # Example\n    ///\n    /// For a [1024, 512] matrix: `lr * 0.2 * sqrt(1024) = lr * 6.4`\n    MatchRmsAdamW,\n}\n\nimpl AdjustLrFn {\n    /// Calculate the learning rate adjustment ratio for a given parameter shape.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - Parameter shape (uses first two dimensions)\n    ///\n    /// # Returns\n    ///\n    /// Adjustment ratio to multiply with the base learning rate\n    fn adjustment_ratio(&self, shape: &[usize]) -> f64 {\n        if shape.len() < 2 {\n            return 1.0;\n        }\n\n        let a = shape[0] as f64;\n        let b = shape[1] as f64;\n\n        match self {\n            Self::Original => {\n                // sqrt(max(1, A/B))\n                let ratio = a / b;\n                ratio.max(1.0).sqrt()\n            }\n            Self::MatchRmsAdamW => {\n                // 0.2 * sqrt(max(A, B))\n                0.2 * a.max(b).sqrt()\n            }\n        }\n    }\n}\n\n/// Muon configuration.\n///\n/// Muon is an optimizer specifically designed for 2D parameters of neural network\n/// hidden layers (weight matrices). Other parameters such as biases and embeddings\n/// should be optimized using a standard method such as AdamW.\n///\n/// # Learning Rate Adjustment\n///\n/// Muon adjusts the learning rate based on parameter shape to maintain consistent\n/// RMS across rectangular matrices. Two methods are available:\n///\n/// - **Original**: Uses `sqrt(max(1, A/B))` where A and B are the first two dimensions.\n///   This is Keller Jordan's method and is the default.\n///\n/// - **MatchRmsAdamW**: Uses `0.2 * sqrt(max(A, B))`. This is Moonshot's method\n///   designed to match AdamW's RMS, allowing direct reuse of AdamW hyperparameters.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_optim::{MuonConfig, AdjustLrFn};\n///\n/// // Using default (Original) method\n/// let optimizer = MuonConfig::new().init();\n///\n/// // Using MatchRmsAdamW for AdamW-compatible hyperparameters\n/// let optimizer = MuonConfig::new()\n///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)\n///     .init();\n/// ```\n///\n/// # References\n///\n/// - [Muon: An optimizer for hidden layers in neural networks](https://kellerjordan.github.io/posts/muon/)\n/// - [Muon is Scalable for LLM Training](https://arxiv.org/pdf/2502.16982)\n/// - [PyTorch Implementation](https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py)\n/// - [Original Implementation](https://github.com/KellerJordan/Muon)\n#[derive(Config, Debug)]\npub struct MuonConfig {\n    /// [Weight decay](WeightDecayConfig) config.\n    weight_decay: Option<WeightDecayConfig>,\n\n    /// [Momentum](MomentumConfig) config.\n    ///\n    /// Muon always uses momentum. Default configuration:\n    /// - momentum: 0.95\n    /// - dampening: 0.0\n    /// - nesterov: true\n    #[config(default = \"MomentumConfig { momentum: 0.95, dampening: 0.0, nesterov: true }\")]\n    momentum: MomentumConfig,\n\n    /// Newton-Schulz iteration coefficients (a, b, c).\n    ///\n    /// These coefficients are selected to maximize the slope at zero for the\n    /// quintic iteration. Default values are from Keller Jordan's implementation.\n    #[config(default = \"(3.4445, -4.775, 2.0315)\")]\n    ns_coefficients: (f32, f32, f32),\n\n    /// Epsilon for numerical stability.\n    #[config(default = 1e-7)]\n    epsilon: f32,\n\n    /// Number of Newton-Schulz iteration steps.\n    #[config(default = 5)]\n    ns_steps: usize,\n\n    /// Learning rate adjustment method.\n    ///\n    /// Controls how the learning rate is adjusted based on parameter shape.\n    /// See [`AdjustLrFn`] for available methods.\n    #[config(default = \"AdjustLrFn::Original\")]\n    adjust_lr_fn: AdjustLrFn,\n}\n\nimpl MuonConfig {\n    /// Initialize Muon optimizer.\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer adaptor that can be used to optimize a module.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};\n    ///\n    /// // Basic configuration with default (Original) LR adjustment\n    /// let optimizer = MuonConfig::new()\n    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.01)))\n    ///     .init();\n    ///\n    /// // With AdamW-compatible settings using MatchRmsAdamW\n    /// let optimizer = MuonConfig::new()\n    ///     .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)\n    ///     .with_weight_decay(Some(WeightDecayConfig::new(0.1)))\n    ///     .init();\n    ///\n    /// // Custom momentum and NS settings\n    /// let optimizer = MuonConfig::new()\n    ///     .with_momentum(MomentumConfig {\n    ///         momentum: 0.9,\n    ///         dampening: 0.1,\n    ///         nesterov: false,\n    ///     })\n    ///     .with_ns_steps(7)\n    ///     .init();\n    /// ```\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(\n        &self,\n    ) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {\n        let momentum = Momentum::new(&self.momentum);\n        let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);\n\n        let optim = Muon {\n            momentum,\n            ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),\n            weight_decay_penalty,\n            epsilon: self.epsilon,\n            adjust_lr_fn: self.adjust_lr_fn,\n        };\n\n        OptimizerAdaptor::from(optim)\n    }\n}\n\n/// Parameters for Newton-Schulz orthogonalization.\n#[derive(Clone, Copy)]\nstruct NewtonSchulzParams {\n    a: f32,\n    b: f32,\n    c: f32,\n    steps: usize,\n}\n\nimpl NewtonSchulzParams {\n    fn new(coefficients: (f32, f32, f32), steps: usize) -> Self {\n        Self {\n            a: coefficients.0,\n            b: coefficients.1,\n            c: coefficients.2,\n            steps,\n        }\n    }\n}\n\n/// Muon optimizer.\n///\n/// Muon internally runs standard SGD-momentum, and then performs an orthogonalization\n/// post-processing step, in which each 2D parameter's update is replaced with the\n/// nearest orthogonal matrix. For efficient orthogonalization we use a Newton-Schulz\n/// iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.\n///\n/// # Important Notes\n///\n/// 1. **Only for 2D+ parameters**: Muon is designed for weight matrices. Use AdamW\n///    or SGD for biases, embeddings, and layer norms.\n///\n/// 2. **Learning rate adjustment**: Muon automatically adjusts the learning rate based\n///    on parameter shape. See [`AdjustLrFn`] for details.\n///\n/// 3. **Weight decay timing**: Unlike typical optimizers, Muon applies weight decay\n///    AFTER orthogonalization but uses the original (unadjusted) learning rate for it.\n#[derive(Clone)]\npub struct Muon<B: Backend> {\n    momentum: Momentum<B>,\n    ns_params: NewtonSchulzParams,\n    weight_decay_penalty: Option<f32>,\n    epsilon: f32,\n    adjust_lr_fn: AdjustLrFn,\n}\n\nimpl<B: Backend> Muon<B> {\n    /// Adjust learning rate based on parameter shape.\n    ///\n    /// # Arguments\n    ///\n    /// * `lr` - Base learning rate\n    /// * `shape` - Parameter shape (uses first two dimensions)\n    ///\n    /// # Returns\n    ///\n    /// Adjusted learning rate\n    ///\n    /// ```ignore\n    /// // For a [1024, 512] weight matrix with lr=0.01:\n    /// // Original: 0.01 * sqrt(1024/512) = 0.01 * 1.414 = 0.01414\n    /// // MatchRmsAdamW: 0.01 * 0.2 * sqrt(1024) = 0.01 * 0.2 * 32 = 0.064\n    /// ```\n    fn adjust_lr(&self, lr: LearningRate, shape: &[usize]) -> LearningRate {\n        lr * self.adjust_lr_fn.adjustment_ratio(shape)\n    }\n\n    /// Perform Newton-Schulz orthogonalization on a gradient tensor.\n    ///\n    /// This computes the zeroth power (orthogonalization) of the input matrix G\n    /// using a quintic Newton-Schulz iteration.\n    ///\n    /// # Algorithm\n    ///\n    /// 1. Transpose if tall matrix (A > B)\n    /// 2. Normalize: X = X / ||X||\n    /// 3. For k steps:\n    ///    - A = X @ X^T\n    ///    - B = b*A + c*A^2\n    ///    - X = a*X + B@X\n    /// 4. Transpose back if needed\n    ///\n    /// # References\n    ///\n    /// - Original: https://github.com/KellerJordan/Muon/blob/master/muon.py\n    /// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py\n    fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {\n        let shape = g.shape();\n        let dim_m2 = shape[D - 2];\n        let dim_m1 = shape[D - 1];\n\n        // Step 1: Transpose if tall matrix (more rows than columns)\n        let (mut x, needs_transpose) = if dim_m2 > dim_m1 {\n            (g.swap_dims(D - 2, D - 1), true)\n        } else {\n            (g, false)\n        };\n\n        // Step 2: Normalize by Frobenius norm\n        // X = X / (||X|| + epsilon)\n        let norm = x\n            .clone()\n            .powf_scalar(2.0)\n            .sum()\n            .sqrt()\n            .clamp_min(self.epsilon)\n            .unsqueeze();\n\n        x = x.div(norm);\n\n        // Step 3: Newton-Schulz iteration\n        // This is the quintic iteration with coefficients (a, b, c)\n        let NewtonSchulzParams { a, b, c, steps } = self.ns_params;\n\n        for _ in 0..steps {\n            // A = X @ X^T\n            let x_t = x.clone().swap_dims(D - 2, D - 1);\n            let a_matrix = x.clone().matmul(x_t);\n\n            // B = b*A + c*A@A\n            let a_squared = a_matrix.clone().matmul(a_matrix.clone());\n            let b_matrix = a_matrix.mul_scalar(b).add(a_squared.mul_scalar(c));\n\n            // X = a*X + B@X\n            x = x.clone().mul_scalar(a).add(b_matrix.matmul(x.clone()));\n        }\n\n        // Step 4: Restore transpose if it was a tall matrix\n        if needs_transpose {\n            x = x.swap_dims(D - 2, D - 1);\n        }\n\n        x\n    }\n}\n\n/// Muon state.\n#[derive(Record, Clone, new)]\npub struct MuonState<B: Backend, const D: usize> {\n    /// Current momentum state\n    pub momentum: MomentumState<B, D>,\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for Muon<B> {\n    type State<const D: usize> = MuonState<B, D>;\n\n    /// Perform a single Muon optimization step.\n    ///\n    /// # Algorithm\n    ///\n    /// 1. Apply momentum to gradient\n    /// 2. Orthogonalize update via Newton-Schulz\n    /// 3. Adjust learning rate based on parameter shape\n    /// 4. Apply weight decay (using original lr)\n    /// 5. Update parameter (using adjusted lr)\n    ///\n    /// # Notes\n    ///\n    /// Unlike typical optimizers, the weight decay and parameter update use\n    /// different learning rates:\n    /// - Weight decay uses the original `lr`\n    /// - Parameter update uses the shape-adjusted `lr`\n    ///\n    /// # Panics\n    /// This function will panic if the input tensors are not 2D.\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        assert!(\n            D == 2,\n            \"Newton-Schulz iteration requires 2D tensors, got {}D\",\n            D\n        );\n\n        // Step 1: Apply momentum\n        let state_momentum = state.map(|s| s.momentum);\n        let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);\n\n        // Step 2: Orthogonalize via Newton-Schulz\n        let update = self.zeropower_via_newtonschulz(grad);\n\n        // Step 3: Adjust learning rate based on parameter shape\n        let adjusted_lr = self.adjust_lr(lr, &tensor.shape());\n\n        // Step 4: Apply weight decay (using ORIGINAL lr, not adjusted)\n        // Muon applies weight decay AFTER orthogonalization\n        let tensor = if let Some(penalty) = self.weight_decay_penalty {\n            let decay_factor = 1.0 - lr * penalty as f64;\n            tensor.mul_scalar(decay_factor)\n        } else {\n            tensor\n        };\n\n        // Step 5: Update parameter (using ADJUSTED lr)\n        let delta = update.mul_scalar(adjusted_lr);\n        let new_state = MuonState::new(new_momentum_state);\n\n        (tensor - delta, Some(new_state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {\n        state.momentum = state.momentum.to_device(device);\n        state\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use crate::{GradientsParams, Optimizer};\n    use burn::module::{Module, Param};\n    use burn::tensor::{Distribution, Tensor, TensorData};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    type TestBackend = burn_ndarray::NdArray<f32>;\n\n    const TOLERANCE: f64 = 1e-8;\n\n    fn given_linear_layer_no_bias(weight: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: None, //No bias for Muon optimizer\n        };\n\n        LinearConfig::new(4, 4)\n            .with_bias(false)\n            .init(&device)\n            .load_record(record)\n    }\n\n    #[test]\n    fn test_adjust_lr_fn_original() {\n        let method = AdjustLrFn::Original;\n\n        // Square matrix [512, 512] -> sqrt(1) = 1.0\n        let ratio = method.adjustment_ratio(&[512, 512]);\n        assert!((ratio - 1.0).abs() < TOLERANCE);\n\n        // Tall matrix [1024, 512] -> sqrt(2) ≈ 1.414\n        let ratio = method.adjustment_ratio(&[1024, 512]);\n        let expected = (2.0f64).sqrt();\n        assert!((ratio - expected).abs() < TOLERANCE);\n\n        // Wide matrix [512, 1024] -> max(1, 0.5) = 1.0\n        let ratio = method.adjustment_ratio(&[512, 1024]);\n        assert!((ratio - 1.0).abs() < TOLERANCE);\n    }\n\n    #[test]\n    fn test_adjust_lr_fn_match_rms_adamw() {\n        let method = AdjustLrFn::MatchRmsAdamW;\n\n        // [1024, 512] -> 0.2 * sqrt(1024) = 6.4\n        let ratio = method.adjustment_ratio(&[1024, 512]);\n        let expected = 0.2 * 1024.0f64.sqrt();\n        assert!((ratio - expected).abs() < TOLERANCE);\n\n        // [512, 512] -> 0.2 * sqrt(512) ≈ 4.525\n        let ratio = method.adjustment_ratio(&[512, 512]);\n        let expected = 0.2 * 512.0f64.sqrt();\n        assert!((ratio - expected).abs() < TOLERANCE);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Newton-Schulz iteration requires 2D tensors, got 1D\")]\n    fn test_1d_tensor_panics() {\n        let device = Default::default();\n        let config = MuonConfig::new();\n        let optim: Muon<TestBackend> = Muon {\n            momentum: Momentum::new(&config.momentum),\n            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),\n            weight_decay_penalty: None,\n            epsilon: config.epsilon,\n            adjust_lr_fn: config.adjust_lr_fn,\n        };\n\n        let tensor_1d = Tensor::<TestBackend, 1>::zeros([512], &device);\n        let grad_1d = Tensor::<TestBackend, 1>::ones([512], &device);\n\n        let _ = optim.step(0.01, tensor_1d, grad_1d, None);\n    }\n\n    #[test]\n    fn test_muon_optimizer_save_load_state() {\n        let device = Default::default();\n        // Use Linear layer WITHOUT bias for Muon optimizer\n        let linear = LinearConfig::new(6, 6)\n            .with_bias(false) // No bias - only 2D weight matrix\n            .init::<TestAutodiffBackend>(&device);\n\n        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);\n\n        let mut optimizer =\n            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let _linear = optimizer.step(0.01, linear, grads);\n\n        let state_before = optimizer.to_record();\n        let state_before_copy = optimizer.to_record();\n\n        let optimizer_new =\n            MuonConfig::new().init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();\n        let optimizer_loaded = optimizer_new.load_record(state_before_copy);\n        let state_after = optimizer_loaded.to_record();\n\n        assert_eq!(state_before.len(), state_after.len());\n    }\n\n    #[test]\n    fn test_muon_with_weight_decay() {\n        let device = Default::default();\n        // Create Linear layer WITHOUT bias for Muon\n        let linear = given_linear_layer_no_bias(TensorData::from([\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n            [1.0, 1.0, 1.0, 1.0],\n        ]));\n\n        let x = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = MuonConfig::new()\n            .with_weight_decay(Some(WeightDecayConfig::new(0.01)))\n            .init::<TestAutodiffBackend, Linear<TestAutodiffBackend>>();\n\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(0.01, linear, grads);\n\n        let state = linear.into_record();\n        let weight = state.weight.to_data();\n\n        for val in weight.as_slice::<f32>().unwrap() {\n            assert!(\n                *val < 1.0,\n                \"Weight should be reduced by weight decay, got {}\",\n                val\n            );\n        }\n    }\n\n    #[test]\n    fn test_newton_schulz_orthogonalization() {\n        let device = Default::default();\n        let matrix = Tensor::<TestBackend, 2>::from_floats([[1.0, 0.5], [0.5, 1.0]], &device);\n\n        let config = MuonConfig::new();\n        let muon: Muon<TestBackend> = Muon {\n            momentum: Momentum::new(&config.momentum),\n            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),\n            weight_decay_penalty: None,\n            epsilon: config.epsilon,\n            adjust_lr_fn: config.adjust_lr_fn,\n        };\n\n        let orthogonalized = muon.zeropower_via_newtonschulz(matrix);\n        let o_t = orthogonalized.clone().transpose();\n        let product = orthogonalized.matmul(o_t);\n\n        let data = product.into_data();\n        let values = data.as_slice::<f32>().unwrap();\n\n        assert!(\n            (values[0] - 1.0).abs() < 0.1,\n            \"Product[0,0] should be ~1.0, got {}\",\n            values[0]\n        );\n        assert!(\n            (values[3] - 1.0).abs() < 0.1,\n            \"Product[1,1] should be ~1.0, got {}\",\n            values[3]\n        );\n    }\n\n    #[test]\n    fn test_tall_matrix_transpose() {\n        // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration\n        // and then transposed back\n        let device = Default::default();\n\n        // Create a tall matrix: [8, 4] (more rows than columns)\n        let tall_matrix = Tensor::<TestBackend, 2>::from_floats(\n            [\n                [1.0, 0.5, 0.3, 0.2],\n                [0.5, 1.0, 0.4, 0.1],\n                [0.3, 0.4, 1.0, 0.5],\n                [0.2, 0.1, 0.5, 1.0],\n                [0.1, 0.2, 0.3, 0.4],\n                [0.4, 0.3, 0.2, 0.1],\n                [0.2, 0.4, 0.1, 0.3],\n                [0.3, 0.1, 0.4, 0.2],\n            ],\n            &device,\n        );\n\n        let config = MuonConfig::new();\n        let muon: Muon<TestBackend> = Muon {\n            momentum: Momentum::new(&config.momentum),\n            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),\n            weight_decay_penalty: None,\n            epsilon: config.epsilon,\n            adjust_lr_fn: config.adjust_lr_fn,\n        };\n\n        // Perform Newton-Schulz orthogonalization\n        let orthogonalized = muon.zeropower_via_newtonschulz(tall_matrix.clone());\n\n        // Verify shape is preserved (should be transposed internally but returned in original shape)\n        let original_shape = tall_matrix.shape();\n        let result_shape = orthogonalized.shape();\n        assert_eq!(\n            original_shape.dims::<2>(),\n            result_shape.dims::<2>(),\n            \"Shape should be preserved: [8, 4]\"\n        );\n\n        // Verify output is different from input (orthogonalization happened)\n        let original_data = tall_matrix.into_data();\n        let result_data = orthogonalized.into_data();\n        assert_ne!(\n            original_data.as_slice::<f32>().unwrap(),\n            result_data.as_slice::<f32>().unwrap(),\n            \"Orthogonalized matrix should differ from input\"\n        );\n\n        // For comparison, test a wide matrix [4, 8] should NOT be transposed\n        let wide_matrix = Tensor::<TestBackend, 2>::from_floats(\n            [\n                [1.0, 0.5, 0.3, 0.2, 0.1, 0.4, 0.2, 0.3],\n                [0.5, 1.0, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1],\n                [0.3, 0.4, 1.0, 0.5, 0.3, 0.2, 0.1, 0.4],\n                [0.2, 0.1, 0.5, 1.0, 0.4, 0.1, 0.3, 0.2],\n            ],\n            &device,\n        );\n\n        let orthogonalized_wide = muon.zeropower_via_newtonschulz(wide_matrix.clone());\n\n        // Verify wide matrix shape is also preserved\n        let wide_original_shape = wide_matrix.shape();\n        let wide_result_shape = orthogonalized_wide.shape();\n        assert_eq!(\n            wide_original_shape.dims::<2>(),\n            wide_result_shape.dims::<2>(),\n            \"Wide matrix shape should be preserved: [4, 8]\"\n        );\n    }\n\n    #[test]\n    fn test_zero_gradient() {\n        // Test that Muon handles zero gradients gracefully\n        let device = Default::default();\n\n        let tensor = Tensor::<TestBackend, 2>::from_floats(\n            [\n                [1.0, 0.5, 0.3, 0.2],\n                [0.5, 1.0, 0.4, 0.1],\n                [0.3, 0.4, 1.0, 0.5],\n                [0.2, 0.1, 0.5, 1.0],\n            ],\n            &device,\n        );\n\n        // Zero gradient - all zeros\n        let zero_grad = Tensor::<TestBackend, 2>::zeros([4, 4], &device);\n\n        let config = MuonConfig::new();\n        let muon: Muon<TestBackend> = Muon {\n            momentum: Momentum::new(&config.momentum),\n            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),\n            weight_decay_penalty: None,\n            epsilon: config.epsilon,\n            adjust_lr_fn: config.adjust_lr_fn,\n        };\n\n        // Should not panic or produce NaN\n        let (updated_tensor, state) = muon.step(0.01, tensor.clone(), zero_grad, None);\n\n        // Verify state was created\n        assert!(state.is_some());\n\n        // With zero gradient and no weight decay, tensor should remain unchanged\n        let original_data = tensor.into_data();\n        let updated_data = updated_tensor.clone().into_data();\n\n        let original_vals = original_data.as_slice::<f32>().unwrap();\n        let updated_vals = updated_data.as_slice::<f32>().unwrap();\n\n        for (orig, upd) in original_vals.iter().zip(updated_vals.iter()) {\n            assert!(\n                (orig - upd).abs() < 1e-6,\n                \"With zero gradient, tensor should remain unchanged (or very close)\"\n            );\n        }\n\n        // Verify no NaN values\n        for val in updated_vals {\n            assert!(\n                !val.is_nan(),\n                \"Result should not contain NaN values with zero gradient\"\n            );\n        }\n\n        // Test with weight decay - should still work\n        let muon_with_decay: Muon<TestBackend> = Muon {\n            momentum: Momentum::new(&config.momentum),\n            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),\n            weight_decay_penalty: Some(0.01),\n            epsilon: config.epsilon,\n            adjust_lr_fn: config.adjust_lr_fn,\n        };\n\n        let tensor2 = Tensor::<TestBackend, 2>::from_floats(\n            [\n                [1.0, 0.5, 0.3, 0.2],\n                [0.5, 1.0, 0.4, 0.1],\n                [0.3, 0.4, 1.0, 0.5],\n                [0.2, 0.1, 0.5, 1.0],\n            ],\n            &device,\n        );\n        let zero_grad2 = Tensor::<TestBackend, 2>::zeros([4, 4], &device);\n\n        let (updated_tensor_decay, _) =\n            muon_with_decay.step(0.01, tensor2.clone(), zero_grad2, None);\n\n        // With zero gradient but with weight decay, tensor should be slightly reduced\n        let updated_decay_data = updated_tensor_decay.into_data();\n        let updated_decay_vals = updated_decay_data.as_slice::<f32>().unwrap();\n\n        for val in updated_decay_vals {\n            assert!(\n                !val.is_nan(),\n                \"Result should not contain NaN with zero gradient and weight decay\"\n            );\n        }\n\n        // With weight decay, values should be slightly smaller than original\n        let original_vals2 = tensor2.into_data().as_slice::<f32>().unwrap().to_vec();\n        for (orig, upd) in original_vals2.iter().zip(updated_decay_vals.iter()) {\n            if orig.abs() > 1e-6 {\n                // Non-zero values should be reduced by weight decay\n                assert!(\n                    upd.abs() < orig.abs(),\n                    \"Weight decay should reduce magnitude: original={}, updated={}\",\n                    orig,\n                    upd\n                );\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/rmsprop.rs",
    "content": "use burn_core as burn;\n\nuse burn::{module::AutodiffModule, record::Record};\n\nuse super::{\n    SimpleOptimizer,\n    adaptor::OptimizerAdaptor,\n    decay::{WeightDecay, WeightDecayConfig},\n};\nuse crate::{LearningRate, grad_clipping::GradientClippingConfig};\n\nuse burn::config::Config;\nuse burn::tensor::backend::Backend;\nuse burn::tensor::{Tensor, backend::AutodiffBackend, ops::Device};\n\n/// Configuration to create the [RmsProp](RmsProp) optimizer.\n#[derive(Config, Debug)]\npub struct RmsPropConfig {\n    /// Smoothing constant.\n    #[config(default = 0.99)]\n    alpha: f32,\n    /// momentum for RmsProp.\n    #[config(default = 0.9)]\n    momentum: f32,\n    /// A value required for numerical stability.\n    #[config(default = 1e-5)]\n    epsilon: f32,\n    /// if True, compute the centered RmsProp, the gradient is normalized by an estimation of its variance\n    #[config(default = false)]\n    centered: bool,\n    /// [Weight decay](WeightDecayConfig) config.\n    weight_decay: Option<WeightDecayConfig>,\n    /// [Gradient Clipping](GradientClippingConfig) config.\n    grad_clipping: Option<GradientClippingConfig>,\n}\n\nimpl RmsPropConfig {\n    /// Initialize RmsProp optimizer.\n    ///\n    /// # Returns\n    ///\n    /// Returns an optimizer that can be used to optimize a module.\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(\n        &self,\n    ) -> OptimizerAdaptor<RmsProp, M, B> {\n        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);\n\n        let mut optim = OptimizerAdaptor::from(RmsProp {\n            alpha: self.alpha,\n            centered: self.centered,\n            weight_decay,\n            momentum: RmsPropMomentum {\n                momentum: self.momentum,\n                epsilon: self.epsilon,\n            },\n        });\n\n        if let Some(config) = &self.grad_clipping {\n            optim = optim.with_grad_clipping(config.init());\n        }\n\n        optim\n    }\n}\n\n/// Optimizer that implements stochastic gradient descent with momentum.\n/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).\n#[derive(Clone)]\npub struct RmsProp {\n    alpha: f32,\n    // epsilon: f32,\n    centered: bool,\n    // momentum: Option<Momentum<B>>,\n    momentum: RmsPropMomentum,\n    weight_decay: Option<WeightDecay>,\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for RmsProp {\n    type State<const D: usize> = RmsPropState<B, D>;\n\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        mut grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        // fetch state for params\n        let mut state_square_avg = None;\n        let mut state_centered = None;\n        let mut state_momentum = None;\n        if let Some(state) = state {\n            state_square_avg = Some(state.square_avg);\n            state_centered = Some(state.centered);\n            state_momentum = state.momentum;\n        }\n\n        // weight_decay transform\n        if let Some(weight_decay) = &self.weight_decay {\n            grad = weight_decay.transform(grad, tensor.clone());\n        }\n\n        // square_avg transform\n        let (grad, state_square_avg) =\n            SquareAvgState::transform(self.alpha, grad, state_square_avg);\n\n        // centered transform\n        let (grad, state_square_avg, state_centered) = CenteredState::transform(\n            self.alpha,\n            self.centered,\n            grad,\n            state_square_avg,\n            state_centered,\n        );\n\n        // momentum transform\n        let (grad, state_centered, state_momentum) =\n            self.momentum\n                .transform(grad, state_centered, state_momentum);\n\n        // transition state\n        let state = RmsPropState::new(state_square_avg, state_centered, state_momentum);\n\n        // tensor param transform\n        let delta = grad.mul_scalar(lr);\n        (tensor - delta, Some(state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &Device<B>) -> Self::State<D> {\n        state.square_avg = state.square_avg.to_device(device);\n        state.centered = state.centered.to_device(device);\n        state.momentum = state.momentum.map(|momentum| momentum.to_device(device));\n        state\n    }\n}\n\n/// State of [RmsProp](RmsProp)\n#[derive(Record, Clone, new)]\npub struct RmsPropState<B: Backend, const D: usize> {\n    /// Current squared average state.\n    pub square_avg: SquareAvgState<B, D>,\n    /// Current centered state\n    pub centered: CenteredState<B, D>,\n    /// Current gradient momentum, if any.\n    pub momentum: Option<RmsPropMomentumState<B, D>>,\n}\n\n/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.\n#[derive(Record, Clone, new)]\npub struct SquareAvgState<B: Backend, const D: usize> {\n    /// Current squared average.\n    pub square_avg: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> SquareAvgState<B, D> {\n    /// transform [SquareAvgState] to the next step\n    fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {\n        match state {\n            Some(state) => {\n                let square_avg = state\n                    .square_avg\n                    .mul_scalar(alpha)\n                    .add(grad.clone().square().mul_scalar(1. - alpha));\n                (grad, Self { square_avg })\n            }\n            _ => {\n                let square_avg = grad.clone().square().mul_scalar(1. - alpha);\n                (grad, Self { square_avg })\n            }\n        }\n    }\n\n    /// Moves the state to a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move the state to.\n    ///\n    /// # Returns\n    ///\n    /// * `self` - Moved state.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.square_avg = self.square_avg.to_device(device);\n        self\n    }\n}\n\n/// [CenteredState](CenteredState) is to store and pass optimizer step params.\n#[derive(Record, Clone, new)]\npub struct CenteredState<B: Backend, const D: usize> {\n    /// The averaged gradient to calculate the centered gradient, if available.\n    pub grad_avg: Option<Tensor<B, D>>,\n    /// The current average value.\n    pub avg: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> CenteredState<B, D> {\n    /// transform [CenteredState] to the next step\n    fn transform(\n        alpha: f32,\n        centered: bool,\n        grad: Tensor<B, D>,\n        square_avg_state: SquareAvgState<B, D>,\n        centered_state: Option<Self>,\n    ) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {\n        if centered {\n            let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);\n            let grad_avg = match centered_state {\n                Some(state) => state\n                    .grad_avg\n                    .map_or(grad_avg_constant.clone(), move |grad_avg| {\n                        grad_avg.mul_scalar(alpha).add(grad_avg_constant)\n                    }),\n                _ => grad_avg_constant,\n            };\n            let avg = square_avg_state\n                .square_avg\n                .clone()\n                .sub(grad_avg.clone().square());\n\n            (\n                grad,\n                square_avg_state,\n                Self {\n                    grad_avg: Some(grad_avg),\n                    avg,\n                },\n            )\n        } else {\n            (\n                grad,\n                square_avg_state.clone(),\n                Self {\n                    grad_avg: None,\n                    avg: square_avg_state.square_avg,\n                },\n            )\n        }\n    }\n\n    /// Moves the state to a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move the state to.\n    ///\n    /// # Returns\n    ///\n    /// * `self` - Moved state.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));\n        self.avg = self.avg.to_device(device);\n        self\n    }\n}\n\n/// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.\n/// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)\n#[derive(Clone)]\npub struct RmsPropMomentum {\n    momentum: f32,\n    epsilon: f32,\n}\n\nimpl RmsPropMomentum {\n    /// transform [grad](Tensor) and [RmsPropMomentumState] to the next step\n    fn transform<B: Backend, const D: usize>(\n        &self,\n        grad: Tensor<B, D>,\n        centered_state: CenteredState<B, D>,\n        momentum_state: Option<RmsPropMomentumState<B, D>>,\n    ) -> (\n        Tensor<B, D>,\n        CenteredState<B, D>,\n        Option<RmsPropMomentumState<B, D>>,\n    ) {\n        let grad = grad.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));\n\n        if self.momentum > 0. {\n            let buf = match momentum_state {\n                Some(state) => state.buf.mul_scalar(self.momentum).add(grad),\n                _ => grad,\n            };\n            (\n                buf.clone(),\n                centered_state,\n                Some(RmsPropMomentumState { buf }),\n            )\n        } else {\n            (grad, centered_state, None)\n        }\n    }\n}\n\n/// [RmsPropMomentumState](RmsPropMomentumState) is to store and pass optimizer step params.\n#[derive(Record, Clone, new)]\npub struct RmsPropMomentumState<B: Backend, const D: usize> {\n    buf: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> RmsPropMomentumState<B, D> {\n    /// Moves the state to a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to move the state to.\n    ///\n    /// # Returns\n    ///\n    /// * `self` - Moved state.\n    pub fn to_device(mut self, device: &B::Device) -> Self {\n        self.buf = self.buf.to_device(device);\n        self\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn::tensor::ops::FloatElem;\n    use burn::tensor::{Shape, Tolerance};\n\n    use super::*;\n    use crate::TestAutodiffBackend;\n    use crate::optim::{GradientsParams, Optimizer};\n    use burn::module::{Module, Param};\n    use burn::tensor::{Distribution, Tensor, TensorData};\n    use burn_nn::{Linear, LinearConfig, LinearRecord};\n\n    type FT = FloatElem<TestAutodiffBackend>;\n\n    const LEARNING_RATE: LearningRate = 0.01;\n\n    #[test]\n    fn test_rmsprop_optimizer_save_load_state() {\n        let device = Default::default();\n        let linear = LinearConfig::new(6, 6).init(&device);\n        let x = Tensor::<TestAutodiffBackend, 2>::random([2, 6], Distribution::Default, &device);\n        let mut optimizer = create_rmsprop();\n        let grads = linear.forward(x).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let _linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        #[cfg(feature = \"std\")]\n        {\n            use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};\n\n            BinFileRecorder::<FullPrecisionSettings>::default()\n                .record(\n                    optimizer.to_record(),\n                    std::env::temp_dir().as_path().join(\"test_optim_rmsprop\"),\n                )\n                .unwrap();\n        }\n        #[cfg(not(feature = \"std\"))]\n        {\n            use burn::record::{BinBytesRecorder, FullPrecisionSettings, Recorder};\n\n            let result = BinBytesRecorder::<FullPrecisionSettings>::default()\n                .record(optimizer.to_record(), ())\n                .unwrap();\n            assert!(!result.is_empty());\n        }\n\n        let state_optim_before = optimizer.to_record();\n        let state_optim_before_copy = optimizer.to_record();\n        let optimizer = create_rmsprop();\n        let optimizer = optimizer.load_record(state_optim_before_copy);\n        let state_optim_after = optimizer.to_record();\n\n        assert_eq!(state_optim_before.len(), state_optim_after.len());\n    }\n\n    /// used for test differences and debug\n    #[test]\n    fn test_rmsprop_optimizer_with_numbers_basic() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [1., 1., 1., 1., 1., 1.],\n                [1., 1., 1., 1., 1., 1.],\n                [1., 1., 1., 1., 1., 1.],\n                [1., 1., 1., 1., 1., 1.],\n                [1., 1., 1., 1., 1., 1.],\n                [1., 1., 1., 1., 1., 1.],\n            ]),\n            TensorData::from([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),\n        );\n        let device = Default::default();\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = RmsPropConfig::new()\n            .with_alpha(0.99)\n            .with_epsilon(1e-8)\n            .with_weight_decay(WeightDecayConfig::new(0.05).into())\n            .with_momentum(0.9)\n            .with_centered(false)\n            .init();\n\n        // println!(\"linear is {:?}\", linear);\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        // println!(\"linear is {:?}\", linear);\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        // println!(\"linear is {:?}\", linear);\n        let state_updated = linear.into_record();\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.to_data(),\n            state_updated.bias.unwrap().to_data(),\n        );\n\n        // println!(\"\\nweight_updated\\n{:?}\", weight_updated);\n        // println!(\"\\nbias_updated\\n{:?}\", bias_updated);\n\n        let weights_expected = TensorData::from([\n            [0.743937, 0.743937, 0.743937, 0.743937, 0.743937, 0.743937],\n            [0.783809, 0.783809, 0.783809, 0.783809, 0.783809, 0.783809],\n            [0.742881, 0.742881, 0.742881, 0.742881, 0.742881, 0.742881],\n            [0.740366, 0.740366, 0.740366, 0.740366, 0.740366, 0.740366],\n            [0.748005, 0.748005, 0.748005, 0.748005, 0.748005, 0.748005],\n            [0.743710, 0.743710, 0.743710, 0.743710, 0.743710, 0.743710],\n        ]);\n        let bias_expected =\n            TensorData::from([0.239199, 0.239199, 0.239199, 0.239199, 0.239199, 0.239199]);\n\n        let tolerance = Tolerance::absolute(1e-6);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    #[test]\n    fn test_rmsprop_optimizer_with_numbers() {\n        let linear = given_linear_layer(\n            TensorData::from([\n                [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],\n                [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],\n                [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],\n                [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],\n                [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],\n                [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],\n            ]),\n            TensorData::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),\n        );\n        let device = Default::default();\n        let x_1 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],\n                [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],\n            ],\n            &device,\n        )\n        .require_grad();\n        let x_2 = Tensor::<TestAutodiffBackend, 2>::from_floats(\n            [\n                [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],\n                [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],\n            ],\n            &device,\n        )\n        .require_grad();\n\n        let mut optimizer = RmsPropConfig::new()\n            .with_alpha(0.99)\n            .with_epsilon(1e-8)\n            .with_weight_decay(WeightDecayConfig::new(0.05).into())\n            .with_momentum(0.9)\n            .with_centered(false)\n            .init();\n\n        let grads = linear.forward(x_1).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let grads = linear.forward(x_2).backward();\n        let grads = GradientsParams::from_grads(grads, &linear);\n        let linear = optimizer.step(LEARNING_RATE, linear, grads);\n\n        let state_updated = linear.into_record();\n        let weights_expected = TensorData::from([\n            [\n                -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,\n            ],\n            [\n                -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,\n            ],\n            [\n                -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,\n            ],\n            [\n                -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,\n            ],\n            [\n                0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,\n            ],\n            [\n                -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,\n            ],\n        ]);\n        let bias_expected = TensorData::from([\n            -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,\n        ]);\n\n        let (weight_updated, bias_updated) = (\n            state_updated.weight.to_data(),\n            state_updated.bias.unwrap().to_data(),\n        );\n\n        // println!(\"\\nweight_updated\\n{:?}\", weight_updated);\n        // println!(\"\\nbias_updated\\n{:?}\", bias_updated);\n\n        let tolerance = Tolerance::absolute(1e-6);\n        bias_updated.assert_approx_eq::<FT>(&bias_expected, tolerance);\n        weight_updated.assert_approx_eq::<FT>(&weights_expected, tolerance);\n    }\n\n    fn given_linear_layer(weight: TensorData, bias: TensorData) -> Linear<TestAutodiffBackend> {\n        let device = Default::default();\n        let record = LinearRecord {\n            weight: Param::from_data(weight, &device),\n            bias: Some(Param::from_data(bias, &device)),\n        };\n\n        LinearConfig::new(6, 6).init(&device).load_record(record)\n    }\n\n    #[allow(dead_code)]\n    fn create_random_tensor() -> Tensor<TestAutodiffBackend, 2> {\n        Tensor::<TestAutodiffBackend, 2>::random(\n            Shape::new([2, 20]),\n            Distribution::Default,\n            &Default::default(),\n        )\n    }\n\n    fn create_rmsprop()\n    -> OptimizerAdaptor<RmsProp, Linear<TestAutodiffBackend>, TestAutodiffBackend> {\n        RmsPropConfig {\n            alpha: 0.99,\n            epsilon: 1e-9,\n            centered: false,\n            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),\n            momentum: 0.9,\n            grad_clipping: None,\n        }\n        .init()\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/sgd.rs",
    "content": "use burn_core as burn;\n\nuse super::SimpleOptimizer;\nuse super::adaptor::OptimizerAdaptor;\nuse super::decay::{WeightDecay, WeightDecayConfig};\nuse super::momentum::{Momentum, MomentumConfig, MomentumState};\nuse crate::LearningRate;\nuse crate::grad_clipping::GradientClippingConfig;\nuse burn::config::Config;\nuse burn::module::AutodiffModule;\nuse burn::record::Record;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::{AutodiffBackend, Backend};\n\n/// Configuration to create the [Sgd](Sgd) optimizer.\n#[derive(Config, Debug)]\npub struct SgdConfig {\n    /// [Weight decay](WeightDecayConfig) config.\n    weight_decay: Option<WeightDecayConfig>,\n    /// [Momentum](MomentumConfig) config.\n    momentum: Option<MomentumConfig>,\n    /// [Gradient Clipping](GradientClippingConfig) config.\n    gradient_clipping: Option<GradientClippingConfig>,\n}\n\n/// Optimizer that implements stochastic gradient descent with momentum.\n///\n/// The optimizer can be configured with [SgdConfig](SgdConfig).\n#[derive(Clone)]\npub struct Sgd<B: Backend> {\n    momentum: Option<Momentum<B>>,\n    weight_decay: Option<WeightDecay>,\n}\n\n/// State of [Sgd](Sgd).\n#[derive(Record, Clone, new)]\npub struct SgdState<B: Backend, const D: usize> {\n    /// The current state of the momentum (if any).\n    pub momentum: Option<MomentumState<B, D>>,\n}\n\nimpl SgdConfig {\n    /// Creates a new [SgdConfig](SgdConfig) with default values.\n    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(\n        &self,\n    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {\n        let momentum = self.momentum.as_ref().map(Momentum::new);\n        let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);\n\n        let mut optim = OptimizerAdaptor::from(Sgd {\n            momentum,\n            weight_decay,\n        });\n        if let Some(config) = &self.gradient_clipping {\n            optim = optim.with_grad_clipping(config.init());\n        }\n        optim\n    }\n}\n\nimpl<B: Backend> SimpleOptimizer<B> for Sgd<B> {\n    type State<const D: usize> = SgdState<B, D>;\n\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        mut grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>) {\n        let mut state_momentum = None;\n\n        if let Some(state) = state {\n            state_momentum = state.momentum;\n        }\n\n        if let Some(weight_decay) = &self.weight_decay {\n            grad = weight_decay.transform(grad, tensor.clone());\n        }\n\n        if let Some(momentum) = &self.momentum {\n            let (grad_out, state) = momentum.transform(grad, state_momentum);\n            state_momentum = Some(state);\n            grad = grad_out;\n        }\n\n        let state = SgdState::new(state_momentum);\n        let delta = grad.mul_scalar(lr);\n\n        (tensor - delta, Some(state))\n    }\n\n    fn to_device<const D: usize>(mut state: Self::State<D>, device: &B::Device) -> Self::State<D> {\n        state.momentum = state.momentum.map(|state| state.to_device(device));\n        state\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{\n        TestAutodiffBackend, TestBackend,\n        grad_clipping::GradientClipping,\n        optim::{GradientsParams, Optimizer},\n    };\n    use burn::tensor::{Distribution, Shape};\n    use burn_nn::{Linear, LinearConfig};\n\n    const LEARNING_RATE: LearningRate = 0.02;\n\n    #[test]\n    fn with_updated_params_should_have_state() {\n        let device = Default::default();\n        let layer = layer::<TestAutodiffBackend>(&device);\n        let mut optim = sgd_with_all();\n        let loss = layer.forward(random_tensor::<TestAutodiffBackend>(&device));\n        let grads = loss.backward();\n        let grads = GradientsParams::from_grads(grads, &layer);\n        let _layer = optim.step(LEARNING_RATE, layer, grads);\n\n        let record = optim.to_record();\n\n        assert!(!record.is_empty());\n    }\n\n    #[test]\n    fn without_updated_params_should_not_have_state() {\n        let optim = sgd_with_all();\n        let record = optim.to_record();\n        assert!(record.is_empty());\n    }\n\n    #[test]\n    fn can_attach_gradient_clipping() {\n        let optim = sgd_with_all().with_grad_clipping(GradientClipping::Value(0.5));\n        assert!(optim.has_gradient_clipping());\n    }\n\n    #[test]\n    fn should_load_state() {\n        let device = Default::default();\n        let layer = layer::<TestAutodiffBackend>(&device);\n        let mut optim = sgd_with_all();\n        let loss = layer.forward(random_tensor(&device));\n        let grads = loss.backward();\n        let grads = GradientsParams::from_grads(grads, &layer);\n        let _layer = optim.step(LEARNING_RATE, layer, grads);\n\n        let record = optim.to_record();\n        let optim_new = sgd_with_all();\n        let record_new = optim_new.to_record();\n        let optim_new = optim_new.load_record(record.clone());\n        let state_restored = optim_new.to_record();\n\n        assert_ne!(record.len(), record_new.len());\n        assert_eq!(record.len(), state_restored.len());\n    }\n\n    fn random_tensor<B: Backend>(device: &B::Device) -> Tensor<B, 2> {\n        Tensor::<B, 2>::random(Shape::new([2, 20]), Distribution::Default, device)\n    }\n\n    fn layer<B: Backend>(device: &B::Device) -> Linear<B> {\n        LinearConfig::new(20, 20).init(device)\n    }\n\n    fn sgd_with_all()\n    -> OptimizerAdaptor<Sgd<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend> {\n        SgdConfig {\n            weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),\n            momentum: Some(MomentumConfig {\n                momentum: 0.9,\n                dampening: 0.1,\n                nesterov: true,\n            }),\n            gradient_clipping: None,\n        }\n        .init()\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/adaptor.rs",
    "content": "use burn_core::{self as burn, prelude::Backend, tensor::Device};\n\nuse super::{SimpleOptimizer, record::AdaptorRecord};\nuse crate::{\n    LearningRate, MultiGradientsParams,\n    grad_clipping::GradientClipping,\n    optim::{GradientsParams, Optimizer},\n};\n\nuse burn::module::{AutodiffModule, ModuleMapper, Param, ParamId};\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse core::marker::PhantomData;\nuse hashbrown::HashMap;\n\n/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into\n/// an [optimizer](Optimizer).\n#[derive(Clone)]\npub struct OptimizerAdaptor<O, M, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    M: AutodiffModule<B>,\n    B: AutodiffBackend,\n{\n    optim: O,\n    records: HashMap<ParamId, AdaptorRecord<O, B>>,\n    module: PhantomData<M>,\n    grad_clipping: Option<GradientClipping>,\n}\n\nimpl<O, B, M> From<O> for OptimizerAdaptor<O, M, B>\nwhere\n    B: AutodiffBackend,\n    M: AutodiffModule<B>,\n    O: SimpleOptimizer<B::InnerBackend>,\n{\n    fn from(optim: O) -> Self {\n        Self {\n            optim,\n            records: HashMap::new(),\n            module: PhantomData,\n            grad_clipping: None,\n        }\n    }\n}\n\nimpl<O, M, B> OptimizerAdaptor<O, M, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    M: AutodiffModule<B>,\n    B: AutodiffBackend,\n{\n    /// Sets the gradient clipping.\n    ///\n    /// # Arguments\n    ///\n    /// * `gradient_clipping` - The gradient clipping.\n    ///\n    /// # Returns\n    ///\n    /// The optimizer.\n    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {\n        self.grad_clipping = Some(gradient_clipping);\n        self\n    }\n\n    #[cfg(test)]\n    pub(crate) fn has_gradient_clipping(&self) -> bool {\n        self.grad_clipping.is_some()\n    }\n}\n\nimpl<O, B, M> Optimizer<M, B> for OptimizerAdaptor<O, M, B>\nwhere\n    B: AutodiffBackend,\n    M: AutodiffModule<B>,\n    O: SimpleOptimizer<B::InnerBackend>,\n{\n    type Record = HashMap<ParamId, AdaptorRecord<O, B>>;\n\n    fn step(&mut self, lr: LearningRate, module: M, grads: GradientsParams) -> M {\n        let mut grads = GradAdaptor::Single(grads);\n\n        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(\n            &self.optim,\n            &mut self.records,\n            &mut grads,\n            lr,\n            self.grad_clipping.as_ref(),\n        );\n        module.map(&mut mapper)\n    }\n\n    fn step_multi(&mut self, lr: LearningRate, module: M, grads: crate::MultiGradientsParams) -> M {\n        let mut grads = GradAdaptor::Multi(grads);\n\n        let mut mapper = SimpleOptimizerMapper::<M, B, O>::new(\n            &self.optim,\n            &mut self.records,\n            &mut grads,\n            lr,\n            self.grad_clipping.as_ref(),\n        );\n        module.map(&mut mapper)\n    }\n\n    fn to_record(&self) -> Self::Record {\n        self.records.clone()\n    }\n\n    fn load_record(mut self, record: Self::Record) -> Self {\n        self.records = record;\n        self\n    }\n}\n\nenum GradAdaptor {\n    Single(GradientsParams),\n    Multi(MultiGradientsParams),\n}\n\nimpl GradAdaptor {\n    fn remove<B: Backend, const D: usize>(\n        &mut self,\n        id: ParamId,\n    ) -> Option<(Tensor<B, D>, Device<B>)> {\n        match self {\n            GradAdaptor::Single(grads) => grads.remove(id).map(|t| {\n                let device = t.device();\n                (t, device)\n            }),\n            GradAdaptor::Multi(grads) => grads.remove(id),\n        }\n    }\n}\n\n#[derive(new)]\nstruct SimpleOptimizerMapper<'a, M, B, O>\nwhere\n    M: AutodiffModule<B>,\n    B: AutodiffBackend,\n    O: SimpleOptimizer<B::InnerBackend>,\n{\n    optimizer: &'a O,\n    records: &'a mut HashMap<ParamId, AdaptorRecord<O, B>>,\n    grads: &'a mut GradAdaptor,\n    lr: LearningRate,\n    phantom: PhantomData<M>,\n    grad_clipping: Option<&'a GradientClipping>,\n}\n\nimpl<M, B, O> ModuleMapper<B> for SimpleOptimizerMapper<'_, M, B, O>\nwhere\n    M: AutodiffModule<B>,\n    B: AutodiffBackend,\n    O: SimpleOptimizer<B::InnerBackend>,\n{\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        let grad = self.grads.remove(id);\n\n        let tensor = if let Some((grad, device)) = grad {\n            let is_require_grad = tensor.is_require_grad();\n            let (key, record) = self.records.remove_entry(&id).unzip();\n            let tensor = if tensor.device() != device {\n                tensor.to_device(&device)\n            } else {\n                tensor\n            };\n\n            debug_assert_eq!(\n                grad.device(),\n                device,\n                \"The gradient is on the provided device\"\n            );\n            let clipped_grad = if let Some(g_clipping) = self.grad_clipping {\n                g_clipping.clip_gradient(grad)\n            } else {\n                grad\n            };\n\n            debug_assert_eq!(\n                tensor.device(),\n                device,\n                \"Tensor and gradients are on the same device.\"\n            );\n\n            let (tensor, state) = self.optimizer.step(\n                self.lr,\n                tensor.inner(),\n                clipped_grad,\n                record.map(|record| O::to_device(record.into_state(), &device)),\n            );\n\n            if let Some(state) = state {\n                self.records\n                    .insert(key.unwrap_or(id), AdaptorRecord::from_state(state));\n            }\n\n            let mut tensor = Tensor::from_inner(tensor);\n            if is_require_grad {\n                tensor = tensor.require_grad();\n            }\n            tensor\n        } else {\n            tensor\n        };\n\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/base.rs",
    "content": "use burn_core as burn;\n\nuse crate::LearningRate;\nuse burn::record::Record;\nuse burn::tensor::{Tensor, backend::Backend};\n\n/// Simple optimizer is an opinionated trait to simplify the process of implementing an\n/// optimizer.\n///\n/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the\n/// module parameter structure, handle tracked and untracked tensors, and the likes.\npub trait SimpleOptimizer<B>: Send + Sync + Clone\nwhere\n    B: Backend,\n{\n    /// The state of the optimizer. It also implements [record](Record), so that it can be saved.\n    type State<const D: usize>: Record<B> + Clone + 'static;\n\n    /// The optimizer step is performed for one tensor at a time with its gradient and state.\n    ///\n    /// Note that the state is passed as parameter, so implementations don't have to handle\n    /// the saving and loading of recorded states.\n    fn step<const D: usize>(\n        &self,\n        lr: LearningRate,\n        tensor: Tensor<B, D>,\n        grad: Tensor<B, D>,\n        state: Option<Self::State<D>>,\n    ) -> (Tensor<B, D>, Option<Self::State<D>>);\n\n    /// Change the device of the state.\n    ///\n    /// This function will be called accordingly to have the state on the same device as the\n    /// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.\n    fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/mod.rs",
    "content": "mod base;\npub use base::*;\n\n/// Adaptor module for optimizers.\npub mod adaptor;\n\n/// Record module for optimizers.\npub mod record;\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/record/base.rs",
    "content": "use burn_core as burn;\n\nuse super::{AdaptorRecordItemV1, AdaptorRecordV1};\nuse crate::optim::SimpleOptimizer;\nuse burn::record::{PrecisionSettings, Record};\nuse burn::tensor::backend::AutodiffBackend;\nuse serde::{Deserialize, Serialize};\n\n/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record.\n///\n/// Records are versioned for backward compatibility, so old records can be loaded.\npub enum AdaptorRecord<O, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    B: AutodiffBackend,\n{\n    /// Version 1.\n    V1(AdaptorRecordV1<O, B::InnerBackend>),\n}\n\n/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.\n#[derive(Serialize, Deserialize, Clone)]\n#[serde(bound = \"\")]\npub enum AdaptorRecordItem<\n    O: SimpleOptimizer<B::InnerBackend>,\n    B: AutodiffBackend,\n    S: PrecisionSettings,\n> {\n    /// Version 1.\n    V1(AdaptorRecordItemV1<O, B::InnerBackend, S>),\n}\n\nimpl<O, B> Record<B> for AdaptorRecord<O, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    B: AutodiffBackend,\n{\n    type Item<S: PrecisionSettings> = AdaptorRecordItem<O, B, S>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        match self {\n            AdaptorRecord::V1(record) => AdaptorRecordItem::V1(record.into_item()),\n        }\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        match item {\n            AdaptorRecordItem::V1(item) => Self::V1(AdaptorRecordV1::from_item(item, device)),\n        }\n    }\n}\n\nimpl<O, B> Clone for AdaptorRecord<O, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    B: AutodiffBackend,\n{\n    fn clone(&self) -> Self {\n        match self {\n            AdaptorRecord::V1(record) => Self::V1(record.clone()),\n        }\n    }\n}\n\nimpl<O, B> AdaptorRecord<O, B>\nwhere\n    O: SimpleOptimizer<B::InnerBackend>,\n    B: AutodiffBackend,\n{\n    /// Converts the record into the optimizer state.\n    ///\n    /// # Returns\n    ///\n    /// The optimizer state.\n    pub fn into_state<const D: usize>(self) -> O::State<D> {\n        match self {\n            AdaptorRecord::V1(record) => record.into_state(),\n        }\n    }\n\n    /// Converts the optimizer state into the record.\n    ///\n    /// # Arguments\n    ///\n    /// * `state`: The optimizer state.\n    ///\n    /// # Returns\n    ///\n    /// The record.\n    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {\n        Self::V1(AdaptorRecordV1::from_state(state))\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/record/mod.rs",
    "content": "mod base;\nmod v1;\n\npub use base::*;\npub use v1::*;\n"
  },
  {
    "path": "crates/burn-optim/src/optim/simple/record/v1.rs",
    "content": "use burn_core as burn;\n\nuse crate::optim::SimpleOptimizer;\nuse burn::record::{PrecisionSettings, Record};\nuse burn::tensor::backend::Backend;\nuse core::any::Any;\nuse serde::{Deserialize, Serialize};\n\n#[cfg(not(feature = \"std\"))]\nuse alloc::boxed::Box;\n\n/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.\npub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {\n    /// Rank 0.\n    Rank0(O::State<0>),\n\n    /// Rank 1.\n    Rank1(O::State<1>),\n\n    /// Rank 2.\n    Rank2(O::State<2>),\n\n    /// Rank 3.\n    Rank3(O::State<3>),\n\n    /// Rank 4.\n    Rank4(O::State<4>),\n\n    /// Rank 5.\n    Rank5(O::State<5>),\n\n    /// Rank 6.\n    Rank6(O::State<6>),\n\n    /// Rank 7.\n    Rank7(O::State<7>),\n\n    /// Rank 8.\n    Rank8(O::State<8>),\n}\n\nimpl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {\n    fn clone(&self) -> Self {\n        match self {\n            AdaptorRecordV1::Rank0(record) => AdaptorRecordV1::Rank0(record.clone()),\n            AdaptorRecordV1::Rank1(record) => AdaptorRecordV1::Rank1(record.clone()),\n            AdaptorRecordV1::Rank2(record) => AdaptorRecordV1::Rank2(record.clone()),\n            AdaptorRecordV1::Rank3(record) => AdaptorRecordV1::Rank3(record.clone()),\n            AdaptorRecordV1::Rank4(record) => AdaptorRecordV1::Rank4(record.clone()),\n            AdaptorRecordV1::Rank5(record) => AdaptorRecordV1::Rank5(record.clone()),\n            AdaptorRecordV1::Rank6(record) => AdaptorRecordV1::Rank6(record.clone()),\n            AdaptorRecordV1::Rank7(record) => AdaptorRecordV1::Rank7(record.clone()),\n            AdaptorRecordV1::Rank8(record) => AdaptorRecordV1::Rank8(record.clone()),\n        }\n    }\n}\n\n/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.\n#[derive(Serialize, Deserialize, Clone)]\n#[serde(bound = \"\")]\npub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {\n    /// Rank 0.\n    Rank0(<O::State<0> as Record<B>>::Item<S>),\n\n    /// Rank 1.\n    Rank1(<O::State<1> as Record<B>>::Item<S>),\n\n    /// Rank 2.\n    Rank2(<O::State<2> as Record<B>>::Item<S>),\n\n    /// Rank 3.\n    Rank3(<O::State<3> as Record<B>>::Item<S>),\n\n    /// Rank 4.\n    Rank4(<O::State<4> as Record<B>>::Item<S>),\n\n    /// Rank 5.\n    Rank5(<O::State<5> as Record<B>>::Item<S>),\n\n    /// Rank 6.\n    Rank6(<O::State<6> as Record<B>>::Item<S>),\n\n    /// Rank 7.\n    Rank7(<O::State<7> as Record<B>>::Item<S>),\n\n    /// Rank 8.\n    Rank8(<O::State<8> as Record<B>>::Item<S>),\n}\n\nimpl<O, B> AdaptorRecordV1<O, B>\nwhere\n    O: SimpleOptimizer<B>,\n    B: Backend,\n{\n    /// Convert the record into the state.\n    ///\n    /// # Returns\n    ///\n    /// The state.\n    ///\n    /// # Panics\n    ///\n    /// Panics if the state dimension is not supported.\n    pub fn into_state<const D: usize>(self) -> O::State<D> {\n        let boxed_state: Box<dyn Any> = match self {\n            AdaptorRecordV1::Rank0(s) => Box::new(s),\n            AdaptorRecordV1::Rank1(s) => Box::new(s),\n            AdaptorRecordV1::Rank2(s) => Box::new(s),\n            AdaptorRecordV1::Rank3(s) => Box::new(s),\n            AdaptorRecordV1::Rank4(s) => Box::new(s),\n            AdaptorRecordV1::Rank5(s) => Box::new(s),\n            AdaptorRecordV1::Rank6(s) => Box::new(s),\n            AdaptorRecordV1::Rank7(s) => Box::new(s),\n            AdaptorRecordV1::Rank8(s) => Box::new(s),\n        };\n        let state = boxed_state\n            .downcast::<O::State<D>>()\n            .expect(\"Unsupported state dimension, dimension up to 8 are supported.\");\n        *state\n    }\n\n    /// Convert the state into the record.\n    ///\n    /// # Arguments\n    ///\n    /// * `state`: The state.\n    ///\n    /// # Returns\n    ///\n    /// The record.\n    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {\n        let state: Box<dyn Any> = Box::new(state);\n\n        match D {\n            0 => AdaptorRecordV1::Rank0(*state.downcast().unwrap()),\n            1 => AdaptorRecordV1::Rank1(*state.downcast().unwrap()),\n            2 => AdaptorRecordV1::Rank2(*state.downcast().unwrap()),\n            3 => AdaptorRecordV1::Rank3(*state.downcast().unwrap()),\n            4 => AdaptorRecordV1::Rank4(*state.downcast().unwrap()),\n            5 => AdaptorRecordV1::Rank5(*state.downcast().unwrap()),\n            6 => AdaptorRecordV1::Rank6(*state.downcast().unwrap()),\n            7 => AdaptorRecordV1::Rank7(*state.downcast().unwrap()),\n            8 => AdaptorRecordV1::Rank8(*state.downcast().unwrap()),\n            _ => panic!(\"Unsupported state dimension, dimension up to 8 are supported.\"),\n        }\n    }\n}\n\nimpl<O, B> Record<B> for AdaptorRecordV1<O, B>\nwhere\n    O: SimpleOptimizer<B>,\n    B: Backend,\n{\n    type Item<S: PrecisionSettings> = AdaptorRecordItemV1<O, B, S>;\n\n    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {\n        match self {\n            AdaptorRecordV1::Rank0(record) => AdaptorRecordItemV1::Rank0(record.into_item()),\n            AdaptorRecordV1::Rank1(record) => AdaptorRecordItemV1::Rank1(record.into_item()),\n            AdaptorRecordV1::Rank2(record) => AdaptorRecordItemV1::Rank2(record.into_item()),\n            AdaptorRecordV1::Rank3(record) => AdaptorRecordItemV1::Rank3(record.into_item()),\n            AdaptorRecordV1::Rank4(record) => AdaptorRecordItemV1::Rank4(record.into_item()),\n            AdaptorRecordV1::Rank5(record) => AdaptorRecordItemV1::Rank5(record.into_item()),\n            AdaptorRecordV1::Rank6(record) => AdaptorRecordItemV1::Rank6(record.into_item()),\n            AdaptorRecordV1::Rank7(record) => AdaptorRecordItemV1::Rank7(record.into_item()),\n            AdaptorRecordV1::Rank8(record) => AdaptorRecordItemV1::Rank8(record.into_item()),\n        }\n    }\n\n    fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {\n        match item {\n            AdaptorRecordItemV1::Rank0(item) => {\n                AdaptorRecordV1::Rank0(<O::State<0> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank1(item) => {\n                AdaptorRecordV1::Rank1(<O::State<1> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank2(item) => {\n                AdaptorRecordV1::Rank2(<O::State<2> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank3(item) => {\n                AdaptorRecordV1::Rank3(<O::State<3> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank4(item) => {\n                AdaptorRecordV1::Rank4(<O::State<4> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank5(item) => {\n                AdaptorRecordV1::Rank5(<O::State<5> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank6(item) => {\n                AdaptorRecordV1::Rank6(<O::State<6> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank7(item) => {\n                AdaptorRecordV1::Rank7(<O::State<7> as Record<B>>::from_item(item, device))\n            }\n            AdaptorRecordItemV1::Rank8(item) => {\n                AdaptorRecordV1::Rank8(<O::State<8> as Record<B>>::from_item(item, device))\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-optim/src/optim/visitor.rs",
    "content": "use burn_core as burn;\n\nuse super::GradientsParams;\nuse burn::module::{AutodiffModule, ModuleVisitor, Param, ParamId};\nuse burn::tensor::{Tensor, backend::AutodiffBackend};\nuse core::marker::PhantomData;\n\n#[cfg(not(feature = \"std\"))]\nuse alloc::vec::Vec;\n\n#[derive(new)]\npub struct GradientsParamsConverter<'a, M: AutodiffModule<B>, B: AutodiffBackend> {\n    grads: &'a mut B::Gradients,\n    grads_params: &'a mut GradientsParams,\n    phatom: PhantomData<M>,\n    filter: Option<Vec<ParamId>>,\n}\n\n#[derive(new)]\npub struct GradientsParamsChangeDevice<'a, M: AutodiffModule<B>, B: AutodiffBackend> {\n    device: &'a B::Device,\n    grads: &'a mut GradientsParams,\n    phatom: PhantomData<M>,\n}\n\nimpl<B, M> ModuleVisitor<B> for GradientsParamsConverter<'_, M, B>\nwhere\n    B: AutodiffBackend,\n    M: AutodiffModule<B>,\n{\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        if let Some(filter) = self.filter.as_ref()\n            && !filter.contains(&param.id)\n        {\n            return;\n        }\n\n        let Some(grad) = param.val().grad_remove(self.grads) else {\n            return;\n        };\n\n        self.grads_params\n            .register::<B::InnerBackend, D>(param.id, grad);\n    }\n}\n\nimpl<B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'_, M, B>\nwhere\n    B: AutodiffBackend,\n    M: AutodiffModule<B>,\n{\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        let Some(grad) = self.grads.remove::<B::InnerBackend, D>(param.id) else {\n            return;\n        };\n\n        self.grads\n            .register::<B::InnerBackend, D>(param.id, grad.to_device(self.device));\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Backend router decorator over the network.\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-remote\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-router-remote\"\ndocumentation = \"https://docs.rs/burn-router-remote\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"client\", \"server\"]\ndoc = []\ntracing = [\n    \"burn-communication/tracing\",\n    \"burn-ir/tracing\",\n    \"burn-router/tracing\",\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n]\n\nclient = [\"tokio-tungstenite\", \"async-channel\", \"tokio/sync\"]\nserver = [\n    \"tokio-tungstenite\",\n    \"async-channel\",\n    \"tokio/sync\",\n    \"axum\",\n    \"tracing-core/default\",\n    \"tracing-subscriber/default\",\n]\n\n\n[dependencies]\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-router = { path = \"../burn-router\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-communication = { path = \"../burn-communication\", version = \"=0.21.0-pre.2\", features = [\n    \"data-service\",\n    \"websocket\",\n] }\n\nbytes = { workspace = true }\n\n# Basic dependencies\nderive-new = { workspace = true }\nlog = { workspace = true }\n\n# Shared dependencies\ntokio = { workspace = true, features = [\"rt-multi-thread\"] }\nserde = { workspace = true, features = [\"derive\"] }\nserde_bytes = { workspace = true }\nrmp-serde = { workspace = true }\nfutures-util = { workspace = true }\n\n# Client dependencies\nasync-channel = { workspace = true, optional = true }\ntokio-tungstenite = { workspace = true, optional = true }\n\n# Server dependencies\naxum = { workspace = true, features = [\"ws\"], optional = true }\ntracing-core = { workspace = true, optional = true }\ntracing-subscriber = { workspace = true, optional = true }\ntokio-util = { workspace = true }\n\n[dev-dependencies]\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-remote/README.md",
    "content": ""
  },
  {
    "path": "crates/burn-remote/src/client/base.rs",
    "content": "pub use super::RemoteDevice;\nuse super::worker::{ClientRequest, ClientWorker};\nuse crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponseContent};\nuse async_channel::{RecvError, SendError, Sender};\nuse burn_communication::ProtocolClient;\nuse burn_ir::TensorId;\nuse burn_std::id::StreamId;\nuse std::{\n    future::Future,\n    sync::{Arc, atomic::AtomicU64},\n};\n\n#[derive(Clone)]\npub struct RemoteClient {\n    pub(crate) device: RemoteDevice,\n    pub(crate) sender: Arc<RemoteSender>,\n    pub(crate) runtime: Arc<tokio::runtime::Runtime>,\n}\n\nimpl RemoteClient {\n    pub fn init<C: ProtocolClient>(device: RemoteDevice) -> Self {\n        ClientWorker::<C>::start(device)\n    }\n\n    pub(crate) fn new(\n        device: RemoteDevice,\n        sender: Sender<ClientRequest>,\n        runtime: Arc<tokio::runtime::Runtime>,\n        session_id: SessionId,\n    ) -> Self {\n        Self {\n            device,\n            runtime,\n            sender: Arc::new(RemoteSender {\n                sender,\n                position_counter: AtomicU64::new(0),\n                tensor_id_counter: AtomicU64::new(0),\n                session_id,\n            }),\n        }\n    }\n}\n\npub(crate) struct RemoteSender {\n    sender: Sender<ClientRequest>,\n    position_counter: AtomicU64,\n    tensor_id_counter: AtomicU64,\n    session_id: SessionId,\n}\n\n#[allow(unused)]\n#[derive(Debug)]\npub enum RemoteSendError {\n    SendError(SendError<ClientRequest>),\n    RecvError(RecvError),\n}\n\nimpl RemoteSender {\n    /// Generate a new unique (for this [`RemoteSender`] [`TensorId`].\n    pub(crate) fn new_tensor_id(&self) -> TensorId {\n        TensorId::new(\n            self.tensor_id_counter\n                .fetch_add(1, std::sync::atomic::Ordering::Relaxed),\n        )\n    }\n\n    /// Give the next operation sequence number.\n    fn next_position(&self) -> u64 {\n        self.position_counter\n            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)\n    }\n\n    pub(crate) fn send(&self, task: ComputeTask) {\n        self.sender\n            .send_blocking(ClientRequest::WithoutCallback(Task::Compute(\n                task,\n                ConnectionId::new(self.next_position(), StreamId::current()),\n            )))\n            .unwrap();\n    }\n\n    pub(crate) fn send_async(\n        &self,\n        task: ComputeTask,\n    ) -> impl Future<Output = Result<TaskResponseContent, RemoteSendError>> + Send + use<> {\n        let stream_id = StreamId::current();\n        let position = self.next_position();\n        let sender = self.sender.clone();\n\n        async move {\n            let (tx, rx) = async_channel::bounded(1);\n\n            if let Err(e) = sender\n                .send(ClientRequest::WithSyncCallback(\n                    Task::Compute(task, ConnectionId::new(position, stream_id)),\n                    tx,\n                ))\n                .await\n            {\n                return Err(RemoteSendError::SendError(e));\n            }\n\n            match rx.recv().await {\n                Ok(response) => Ok(response),\n                Err(e) => Err(RemoteSendError::RecvError(e)),\n            }\n        }\n    }\n\n    pub(crate) fn close(&mut self) {\n        let sender = self.sender.clone();\n\n        let close_task = ClientRequest::WithoutCallback(Task::Close(self.session_id));\n\n        sender.send_blocking(close_task).unwrap();\n    }\n}\n\nimpl Drop for RemoteSender {\n    fn drop(&mut self) {\n        self.close();\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/client/channel.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn_backend::Shape;\nuse burn_communication::ProtocolClient;\nuse burn_ir::TensorIr;\nuse burn_router::{RouterTensor, RunnerChannel, get_client};\n\nuse super::{\n    RemoteClient,\n    runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},\n};\n\n/// A local channel with direct connection to the backend runner clients.\npub struct RemoteChannel<C: ProtocolClient> {\n    _p: PhantomData<C>,\n}\n\nimpl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {\n    type Device = RemoteDevice;\n    type Bridge = RemoteBridge<C>;\n    type Client = RemoteClient;\n\n    type FloatElem = f32;\n\n    type IntElem = i32;\n\n    type BoolElem = u32;\n\n    fn name(device: &Self::Device) -> String {\n        format!(\"remote-{device:?}\")\n    }\n\n    fn init_client(device: &Self::Device) -> Self::Client {\n        RemoteClient::init::<C>(device.clone())\n    }\n\n    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {\n        RemoteTensorHandle {\n            client: client.clone(),\n            tensor: tensor.clone(),\n            _p: PhantomData,\n        }\n    }\n\n    fn register_tensor(\n        _client: &Self::Client,\n        _handle: RemoteTensorHandle<C>,\n        _shape: Shape,\n        _dtype: burn_backend::DType,\n    ) -> RouterTensor<Self::Client> {\n        // This function is normally only used to move a tensor from a device to another.\n        //\n        // In other words, to change the client.\n        panic!(\"Can't register manually a tensor on a remote channel.\");\n    }\n\n    fn change_client_backend(\n        tensor: RouterTensor<Self::Client>,\n        target_device: &Self::Device, // target device\n    ) -> RouterTensor<Self::Client> {\n        // Get tensor handle from current client\n        let original_client = tensor.client.clone();\n        let desc = tensor.into_ir();\n        let handle = Self::get_tensor_handle(&desc, &original_client);\n\n        let handle = handle.change_backend(target_device);\n\n        let id = handle.tensor.id;\n\n        let target_client = get_client::<Self>(target_device);\n        let router_tensor: RouterTensor<RemoteClient> =\n            RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client);\n\n        router_tensor\n    }\n}\n\nimpl<C: ProtocolClient> Clone for RemoteChannel<C> {\n    fn clone(&self) -> Self {\n        RemoteChannel { _p: PhantomData }\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/client/mod.rs",
    "content": "mod base;\nmod channel;\nmod runner;\nmod worker;\n\npub use base::*;\npub use channel::*;\npub use runner::RemoteDevice;\n"
  },
  {
    "path": "crates/burn-remote/src/client/runner.rs",
    "content": "use super::{RemoteChannel, RemoteClient};\nuse crate::shared::{ComputeTask, TaskResponseContent, TensorRemote};\nuse burn_backend::{DeviceId, DeviceOps, ExecutionError, TensorData};\nuse burn_communication::{Address, ProtocolClient, data_service::TensorTransferId};\nuse burn_ir::TensorIr;\nuse burn_router::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};\nuse burn_std::{backtrace::BackTrace, future::DynFut};\nuse std::sync::OnceLock;\nuse std::{collections::HashMap, marker::PhantomData, str::FromStr, sync::Mutex};\n\n// TODO: we should work with the parsed structure of Address, not the string.\nstatic ADDRESS_REGISTRY: OnceLock<Mutex<HashMap<String, u32>>> = OnceLock::new();\n\nfn get_address_registry() -> &'static Mutex<HashMap<String, u32>> {\n    ADDRESS_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))\n}\n\n/// Map a string network address to a (local runtime) global unique u32.\n///\n/// Globally stable over the lifetime of the process, shared between threads,\n/// If the address has never been seen, a new id will be created.\n/// If the address has been seen, the previous id will be returned.\npub fn address_to_id<S: AsRef<str>>(address: S) -> u32 {\n    let registry = get_address_registry();\n    let mut registry = registry.lock().unwrap();\n    let next_id = registry.len() as u32;\n    *registry\n        .entry(address.as_ref().to_string())\n        .or_insert_with(|| next_id)\n}\n\n/// Look up an address by id.\n///\n/// Returns the same address given ids by [`address_to_id`].\npub fn id_to_address(id: u32) -> Option<String> {\n    let registry = get_address_registry();\n    let registry = registry.lock().unwrap();\n    for entry in registry.iter() {\n        if entry.1 == &id {\n            return Some(entry.0.clone());\n        }\n    }\n    None\n}\n\n// It is very important to block on any request made with the sender, since ordering is crucial\n// when registering operation or creating tensors.\n//\n// The overhead is minimal, since we only wait for the task to be sent to the async\n// channel, but not sent to the server and even less processed by the server.\nimpl RunnerClient for RemoteClient {\n    type Device = RemoteDevice;\n\n    fn register_op(&self, op: burn_ir::OperationIr) {\n        self.sender\n            .send(ComputeTask::RegisterOperation(Box::new(op)));\n    }\n\n    fn read_tensor_async(\n        &self,\n        tensor: burn_ir::TensorIr,\n    ) -> DynFut<Result<TensorData, ExecutionError>> {\n        // Important for ordering to call the creation of the future sync.\n        let fut = self.sender.send_async(ComputeTask::ReadTensor(tensor));\n\n        Box::pin(async move {\n            match fut.await {\n                Ok(response) => match response {\n                    TaskResponseContent::ReadTensor(res) => res,\n                    _ => panic!(\"Invalid message type\"),\n                },\n                Err(e) => Err(ExecutionError::Generic {\n                    reason: format!(\"Failed to read tensor: {:?}\", e),\n                    backtrace: BackTrace::capture(),\n                }),\n            }\n        })\n    }\n\n    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {\n        let id = self.sender.new_tensor_id();\n        let shape = data.shape.clone();\n        let dtype = data.dtype;\n\n        self.sender.send(ComputeTask::RegisterTensor(id, data));\n\n        RouterTensor::new(id, shape, dtype, self.clone())\n    }\n\n    fn device(&self) -> Self::Device {\n        self.device.clone()\n    }\n\n    fn sync(&self) -> Result<(), ExecutionError> {\n        // Important for ordering to call the creation of the future sync.\n        let fut = self.sender.send_async(ComputeTask::SyncBackend);\n\n        match self.runtime.block_on(fut) {\n            Ok(response) => match response {\n                TaskResponseContent::SyncBackend(res) => res,\n                _ => panic!(\"Invalid message type\"),\n            },\n            Err(e) => Err(ExecutionError::Generic {\n                reason: format!(\"Failed to sync: {:?}\", e),\n                backtrace: BackTrace::capture(),\n            }),\n        }\n    }\n\n    fn seed(&self, seed: u64) {\n        self.sender.send(ComputeTask::Seed(seed));\n    }\n\n    fn create_empty_handle(&self) -> burn_ir::TensorId {\n        self.sender.new_tensor_id()\n    }\n\n    fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {\n        let fut = self.sender.send_async(ComputeTask::SupportsDType(dtype));\n\n        match self.runtime.block_on(fut) {\n            Ok(_response) => panic!(\"Invalid message type\"),\n            Err(e) => panic!(\"Failed to check dtype support: {:?}\", e),\n        }\n    }\n}\n\n#[derive(Clone, PartialEq, Eq, Debug)]\n/// The device contains the connection information of the server.\npub struct RemoteDevice {\n    pub(crate) address: Address,\n    /// The id of the device in the local registry, see [`address_to_id`].\n    pub(crate) id: u32,\n}\n\nimpl RemoteDevice {\n    /// Create a device from an url.\n    pub fn new(address: &str) -> Self {\n        let id = address_to_id(address);\n        Self {\n            address: Address::from_str(address).unwrap(),\n            id,\n        }\n    }\n}\n\nimpl Default for RemoteDevice {\n    fn default() -> Self {\n        let address = match std::env::var(\"BURN_REMOTE_ADDRESS\") {\n            Ok(address) => address,\n            Err(_) => String::from(\"ws://127.0.0.1:3000\"),\n        };\n\n        Self::new(&address)\n    }\n}\n\nimpl burn_std::device::Device for RemoteDevice {\n    fn from_id(device_id: DeviceId) -> Self {\n        if device_id.type_id != 0 {\n            panic!(\"Invalid device id: {device_id} (expected type 0)\");\n        }\n        let address = id_to_address(device_id.index_id)\n            .unwrap_or_else(|| panic!(\"Invalid device id: {device_id}\"));\n        Self::new(&address)\n    }\n\n    fn to_id(&self) -> DeviceId {\n        DeviceId {\n            type_id: 0,\n            index_id: self.id,\n        }\n    }\n\n    fn device_count(_type_id: u16) -> usize {\n        1\n    }\n}\n\nimpl DeviceOps for RemoteDevice {}\n\npub struct RemoteBridge<C: ProtocolClient> {\n    _p: PhantomData<C>,\n}\n\npub struct RemoteTensorHandle<C: ProtocolClient> {\n    pub(crate) client: RemoteClient,\n    pub(crate) tensor: TensorIr,\n    pub(crate) _p: PhantomData<C>,\n}\n\nstatic TRANSFER_COUNTER: Mutex<Option<TensorTransferId>> = Mutex::new(None);\n\nfn get_next_transfer_id() -> TensorTransferId {\n    let mut transfer_counter = TRANSFER_COUNTER.lock().unwrap();\n    if transfer_counter.is_none() {\n        *transfer_counter = Some(0.into());\n\n        transfer_counter.unwrap()\n    } else {\n        let mut transfer_counter = transfer_counter.unwrap();\n        transfer_counter.next();\n\n        transfer_counter\n    }\n}\n\nimpl<C: ProtocolClient> RemoteTensorHandle<C> {\n    /// Changes the backend of the tensor via a dWebSocket.\n    /// We ask the original server to expose the tensor, then ask the target server to fetch\n    /// the tensor. The target server will open a new network connection to the original server\n    /// to download the data.\n    /// This way the client never sees the tensor's data, and we avoid a bottleneck.\n    pub(crate) fn change_backend(mut self, target_device: &RemoteDevice) -> Self {\n        let transfer_id = get_next_transfer_id();\n        self.client.sender.send(ComputeTask::ExposeTensorRemote {\n            tensor: self.tensor.clone(),\n            count: 1,\n            transfer_id,\n        });\n\n        let target_client = get_client::<RemoteChannel<C>>(target_device);\n\n        let new_id = target_client.sender.new_tensor_id();\n\n        let remote_tensor = TensorRemote {\n            transfer_id,\n            address: self.client.device.address.clone(),\n        };\n        target_client\n            .sender\n            .send(ComputeTask::RegisterTensorRemote(remote_tensor, new_id));\n\n        self.tensor.id = new_id;\n        self.client = target_client;\n\n        self\n    }\n}\n\nimpl<C: ProtocolClient> MultiBackendBridge for RemoteBridge<C> {\n    type TensorHandle = RemoteTensorHandle<C>;\n    type Device = RemoteDevice;\n\n    fn change_backend_float(\n        tensor: Self::TensorHandle,\n        _shape: burn_backend::Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle {\n        tensor.change_backend(target_device)\n    }\n\n    fn change_backend_int(\n        tensor: Self::TensorHandle,\n        _shape: burn_backend::Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle {\n        tensor.change_backend(target_device)\n    }\n\n    fn change_backend_bool(\n        tensor: Self::TensorHandle,\n        _shape: burn_backend::Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle {\n        tensor.change_backend(target_device)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_address_to_id() {\n        let address1 = \"ws://127.0.0.1:3000\";\n        let address2 = \"ws://127.0.0.1:3001\";\n\n        let id1 = address_to_id(address1);\n        let id2 = address_to_id(address2);\n\n        assert_ne!(id1, id2);\n\n        assert_eq!(address_to_id(address1), id1);\n        assert_eq!(id_to_address(id1), Some(address1.to_string()));\n\n        assert_eq!(address_to_id(address2), id2);\n        assert_eq!(id_to_address(id2), Some(address2.to_string()));\n\n        let unused_id = u32::MAX;\n\n        assert_eq!(id_to_address(unused_id), None);\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/client/worker.rs",
    "content": "use super::{RemoteClient, runner::RemoteDevice};\nuse crate::shared::{ConnectionId, SessionId, Task, TaskResponse, TaskResponseContent};\nuse burn_communication::{CommunicationChannel, Message, ProtocolClient};\nuse std::{collections::HashMap, marker::PhantomData, sync::Arc};\n\npub type CallbackSender = async_channel::Sender<TaskResponseContent>;\n\n#[derive(Debug)]\npub enum ClientRequest {\n    WithSyncCallback(Task, CallbackSender),\n    WithoutCallback(Task),\n}\n\npub(crate) struct ClientWorker<C: ProtocolClient> {\n    requests: HashMap<ConnectionId, CallbackSender>,\n    _p: PhantomData<C>,\n}\n\nimpl<C: ProtocolClient> ClientWorker<C> {\n    async fn on_response(&mut self, response: TaskResponse) {\n        match self.requests.remove(&response.id) {\n            Some(request) => {\n                request.send(response.content).await.unwrap();\n            }\n            None => {\n                panic!(\"Can't ignore message from the server.\");\n            }\n        }\n    }\n\n    fn register_callback(&mut self, id: ConnectionId, callback: CallbackSender) {\n        self.requests.insert(id, callback);\n    }\n}\n\nimpl<C: ProtocolClient> ClientWorker<C> {\n    pub fn start(device: RemoteDevice) -> RemoteClient {\n        let runtime = Arc::new(\n            tokio::runtime::Builder::new_multi_thread()\n                .enable_io()\n                .build()\n                .unwrap(),\n        );\n\n        let (sender, rec) = async_channel::bounded(10);\n\n        let session_id = SessionId::new();\n        let address = device.address.clone();\n\n        #[allow(deprecated)]\n        runtime.spawn(async move {\n            log::info!(\"Connecting to {} ...\", address.clone());\n            let mut stream_request = C::connect(address.clone(), \"request\")\n                .await\n                .expect(\"Server to be accessible\");\n            let mut stream_response = C::connect(address, \"response\")\n                .await\n                .expect(\"Server to be accessible\");\n\n            let state = Arc::new(tokio::sync::Mutex::new(ClientWorker::<C>::default()));\n\n            // Init the connection.\n            let bytes: bytes::Bytes = rmp_serde::to_vec(&Task::Init(session_id))\n                .expect(\"Can serialize tasks to bytes.\")\n                .into();\n            stream_request\n                .send(Message::new(bytes.clone()))\n                .await\n                .expect(\"Can send the message over the comms channel.\");\n            stream_response\n                .send(Message::new(bytes))\n                .await\n                .expect(\"Can send the message on the websocket.\");\n\n            // Async worker loading callbacks from the server.\n            let state_ws = state.clone();\n            tokio::spawn(async move {\n                while let Ok(msg) = stream_response.recv().await {\n                    let msg = match msg {\n                        Some(msg) => msg,\n                        None => {\n                            log::warn!(\"Closed connection\");\n                            return;\n                        }\n                    };\n\n                    let response: TaskResponse = rmp_serde::from_slice(&msg.data)\n                        .expect(\"Can deserialize messages from the websocket.\");\n                    let mut state = state_ws.lock().await;\n                    state.on_response(response).await;\n                }\n            });\n\n            // Channel async worker sending operations to the server.\n            tokio::spawn(async move {\n                while let Ok(req) = rec.recv().await {\n                    let task = match req {\n                        ClientRequest::WithSyncCallback(task, callback) => {\n                            if let Task::Compute(_content, id) = &task {\n                                let mut state = state.lock().await;\n                                state.register_callback(*id, callback);\n                            }\n                            task\n                        }\n                        ClientRequest::WithoutCallback(task) => task,\n                    };\n                    let bytes = rmp_serde::to_vec(&task)\n                        .expect(\"Can serialize tasks to bytes.\")\n                        .into();\n                    stream_request\n                        .send(Message::new(bytes))\n                        .await\n                        .expect(\"Can send the message on the websocket.\");\n                }\n            });\n        });\n\n        RemoteClient::new(device, sender, runtime, session_id)\n    }\n}\n\nimpl<C: ProtocolClient> Default for ClientWorker<C> {\n    fn default() -> Self {\n        Self {\n            requests: Default::default(),\n            _p: PhantomData,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/lib.rs",
    "content": "#[macro_use]\nextern crate derive_new;\n\n#[cfg(feature = \"client\")]\npub(crate) mod client;\n\n#[cfg(feature = \"server\")]\npub mod server;\n\npub(crate) mod shared;\n\n#[cfg(feature = \"client\")]\nmod __client {\n    use super::*;\n\n    use crate::{client::RemoteChannel, shared::RemoteProtocol};\n    use burn_communication::Protocol;\n    use burn_router::BackendRouter;\n\n    /// The remote backend allows you to run computation on a remote device.\n    ///\n    /// Make sure there is a running server before trying to connect to it.\n    ///\n    /// ```rust, ignore\n    /// fn main() {\n    ///     let device = Default::default();\n    ///     let port = 3000;\n    ///\n    ///     // You need to activate the `server` feature flag to have access to this function.\n    ///     burn::server::start::<burn::backend::Wgpu>(device, port);\n    /// }\n    ///```\n    pub type RemoteBackend = BackendRouter<RemoteChannel<<RemoteProtocol as Protocol>::Client>>;\n\n    pub use client::RemoteDevice;\n}\n#[cfg(feature = \"client\")]\npub use __client::*;\n\n#[cfg(all(test, feature = \"client\", feature = \"server\"))]\nmod tests {\n    use crate::RemoteBackend;\n    use burn_ndarray::NdArray;\n    use burn_tensor::{Distribution, Tensor};\n\n    #[test]\n    pub fn test_to_device_over_websocket() {\n        let rt = tokio::runtime::Builder::new_multi_thread()\n            .enable_io()\n            .build()\n            .unwrap();\n\n        rt.spawn(crate::server::start_websocket_async::<NdArray>(\n            Default::default(),\n            3000,\n        ));\n        rt.spawn(crate::server::start_websocket_async::<NdArray>(\n            Default::default(),\n            3010,\n        ));\n\n        let remote_device_1 = super::RemoteDevice::new(\"ws://localhost:3000\");\n        let remote_device_2 = super::RemoteDevice::new(\"ws://localhost:3010\");\n\n        // Some random input\n        let input_shape = [1, 28, 28];\n        let input = Tensor::<RemoteBackend, 3>::random(\n            input_shape,\n            Distribution::Default,\n            &remote_device_1,\n        );\n        let numbers_expected: Vec<f32> = input.to_data().to_vec().unwrap();\n\n        // Move tensor to device 2\n        let input = input.to_device(&remote_device_2);\n        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();\n        assert_eq!(numbers, numbers_expected);\n\n        // Move tensor back to device 1\n        let input = input.to_device(&remote_device_1);\n        let numbers: Vec<f32> = input.to_data().to_vec().unwrap();\n        assert_eq!(numbers, numbers_expected);\n\n        rt.shutdown_background();\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/server/base.rs",
    "content": "use burn_communication::{\n    CommunicationChannel, Message, Protocol, ProtocolServer,\n    data_service::{TensorDataServer, TensorDataService},\n    util::os_shutdown_signal,\n    websocket::{WebSocket, WsServer},\n};\nuse std::{marker::PhantomData, sync::Arc};\nuse tokio_util::sync::CancellationToken;\n\nuse burn_backend::tensor::Device;\nuse burn_ir::BackendIr;\n\nuse crate::shared::{ComputeTask, Task};\n\nuse super::session::SessionManager;\n\npub struct RemoteServer<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    _b: PhantomData<B>,\n    _n: PhantomData<P>,\n}\n\nimpl<B, P> RemoteServer<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    /// Start the server on the given address.\n    pub async fn start(device: Device<B>, server: P::Server) {\n        let cancel_token = CancellationToken::new();\n        let data_service = Arc::new(TensorDataService::<B, P>::new(cancel_token));\n        let session_manager = Arc::new(SessionManager::<B, P>::new(device, data_service.clone()));\n\n        let _server = server\n            .route(\"/response\", {\n                let session_manager = session_manager.clone();\n                move |stream| Self::handle_socket_response(session_manager, stream)\n            })\n            .route(\"/request\", {\n                let session_manager = session_manager.clone();\n                move |stream| Self::handle_socket_request(session_manager, stream)\n            })\n            .route_tensor_data_service(data_service)\n            .serve(os_shutdown_signal())\n            .await;\n    }\n\n    async fn handle_socket_response(\n        session_manager: Arc<SessionManager<B, P>>,\n        mut socket: <P::Server as ProtocolServer>::Channel,\n    ) {\n        log::info!(\"[Response Handler] On new connection.\");\n\n        let packet = socket.recv().await;\n        let msg = match packet {\n            Ok(Some(msg)) => msg,\n            Ok(None) => {\n                log::info!(\"Response stream closed\");\n                return;\n            }\n            Err(e) => {\n                log::info!(\"Response stream error on init: {e:?}\");\n                return;\n            }\n        };\n\n        let id = match rmp_serde::from_slice::<Task>(&msg.data) {\n            Ok(Task::Init(session_id)) => session_id,\n            msg => {\n                log::error!(\"Message is not a valid initialization task {msg:?}\");\n                return;\n            }\n        };\n\n        let mut receiver = session_manager.register_responder(id).await;\n\n        log::info!(\"Response handler connection active\");\n\n        while let Some(mut callback) = receiver.recv().await {\n            let response = callback.recv().await.unwrap();\n            let bytes = rmp_serde::to_vec(&response).unwrap();\n\n            socket.send(Message::new(bytes.into())).await.unwrap();\n        }\n    }\n\n    async fn handle_socket_request(\n        session_manager: Arc<SessionManager<B, P>>,\n        mut socket: <P::Server as ProtocolServer>::Channel,\n    ) {\n        log::info!(\"[Request Handler] On new connection.\");\n        let mut session_id = None;\n\n        loop {\n            let packet = socket.recv().await;\n            let msg = match packet {\n                Ok(Some(msg)) => msg,\n                Ok(None) => {\n                    log::info!(\"Request stream closed\");\n                    break;\n                }\n                Err(e) => {\n                    log::info!(\"Request stream error: {e:?}, Closing.\");\n                    break;\n                }\n            };\n\n            let task = match rmp_serde::from_slice::<Task>(&msg.data) {\n                Ok(val) => val,\n                Err(err) => {\n                    log::info!(\"Only bytes message in the json format are supported {err:?}\");\n                    break;\n                }\n            };\n\n            if let Task::Close(id) = task {\n                session_id = Some(id);\n                break;\n            }\n\n            let (stream, connection_id, task) =\n                match session_manager.stream(&mut session_id, task).await {\n                    Some(val) => val,\n                    None => {\n                        log::info!(\"Ops session activated {session_id:?}\");\n                        continue;\n                    }\n                };\n\n            match task {\n                ComputeTask::RegisterOperation(op) => {\n                    stream.register_operation(op).await;\n                }\n                ComputeTask::RegisterTensor(id, data) => {\n                    stream.register_tensor(id, data).await;\n                }\n                ComputeTask::ReadTensor(tensor) => {\n                    stream.read_tensor(connection_id, tensor).await;\n                }\n                ComputeTask::SyncBackend => {\n                    stream.sync(connection_id).await;\n                }\n                ComputeTask::RegisterTensorRemote(tensor, new_id) => {\n                    stream.register_tensor_remote(tensor, new_id).await;\n                }\n                ComputeTask::ExposeTensorRemote {\n                    tensor,\n                    count,\n                    transfer_id,\n                } => {\n                    stream\n                        .expose_tensor_remote(tensor, count, transfer_id)\n                        .await;\n                }\n                ComputeTask::Seed(seed) => {\n                    stream.seed(seed).await;\n                }\n                ComputeTask::SupportsDType(dtype) => {\n                    stream.supports_dtype(connection_id, dtype).await\n                }\n            }\n        }\n\n        log::info!(\"Closing session {session_id:?}\");\n        session_manager.close(session_id).await;\n    }\n}\n\n/// Start the server on the given port and [device](Device).\npub async fn start_websocket_async<B: BackendIr>(device: Device<B>, port: u16) {\n    let server = WsServer::new(port);\n    RemoteServer::<B, WebSocket>::start(device, server).await;\n}\n\n#[tokio::main]\n/// Start the server on the given port and [device](Device).\npub async fn start_websocket<B: BackendIr>(device: Device<B>, port: u16) {\n    start_websocket_async::<B>(device, port).await;\n}\n"
  },
  {
    "path": "crates/burn-remote/src/server/mod.rs",
    "content": "pub(crate) mod processor;\npub(crate) mod session;\npub(crate) mod stream;\n\nmod base;\n\npub use base::{start_websocket, start_websocket_async};\n"
  },
  {
    "path": "crates/burn-remote/src/server/processor.rs",
    "content": "use burn_backend::TensorData;\nuse burn_communication::{\n    Protocol,\n    data_service::{TensorDataService, TensorTransferId},\n};\nuse burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};\nuse burn_router::{Runner, RunnerClient};\nuse burn_std::DType;\nuse core::marker::PhantomData;\nuse std::sync::Arc;\nuse tokio::sync::mpsc::Sender;\n\nuse crate::shared::{ConnectionId, TaskResponse, TaskResponseContent, TensorRemote};\n\n/// The goal of the processor is to asynchronously process compute tasks on it own thread.\npub struct Processor<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    p: PhantomData<B>,\n    n: PhantomData<P>,\n}\n\npub type Callback<M> = Sender<M>;\n\npub enum ProcessorTask {\n    RegisterOperation(Box<OperationIr>),\n    RegisterTensor(TensorId, TensorData),\n    RegisterTensorRemote(TensorRemote, TensorId),\n    ExposeTensorRemote {\n        tensor: TensorIr,\n        transfer_id: TensorTransferId,\n        count: u32,\n    },\n    ReadTensor(ConnectionId, TensorIr, Callback<TaskResponse>),\n    Sync(ConnectionId, Callback<TaskResponse>),\n    Seed(u64),\n    SupportsDType(ConnectionId, DType, Callback<TaskResponse>),\n    Close,\n}\n\nimpl<B: BackendIr, P> Processor<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    pub async fn start(\n        runner: Runner<B>,\n        data_service: Arc<TensorDataService<B, P>>,\n    ) -> Sender<ProcessorTask> {\n        // channel for tasks to execute\n        let (task_sender, mut task_rec) = tokio::sync::mpsc::channel(1);\n\n        tokio::spawn(async move {\n            while let Some(item) = task_rec.recv().await {\n                match item {\n                    ProcessorTask::RegisterOperation(op) => {\n                        runner.register_op(*op);\n                    }\n                    ProcessorTask::Sync(id, callback) => {\n                        let result = runner.sync();\n                        callback\n                            .send(TaskResponse {\n                                content: TaskResponseContent::SyncBackend(result),\n                                id,\n                            })\n                            .await\n                            .unwrap();\n                    }\n                    ProcessorTask::RegisterTensor(id, data) => {\n                        runner.register_tensor_data_id(id, data);\n                    }\n                    ProcessorTask::RegisterTensorRemote(remote_tensor, new_id) => {\n                        log::info!(\n                            \"Registering remote tensor...(id: {:?})\",\n                            remote_tensor.transfer_id\n                        );\n                        let data = data_service\n                            .download_tensor(remote_tensor.address, remote_tensor.transfer_id)\n                            .await\n                            .expect(\"Can't download tensor: error\"); // TODO all these panics should be server errors\n                        runner.register_tensor_data_id(new_id, data);\n                    }\n                    ProcessorTask::ExposeTensorRemote {\n                        tensor,\n                        transfer_id,\n                        count,\n                    } => {\n                        log::info!(\"Exposing tensor: (id: {transfer_id:?})\");\n                        let data = runner.read_tensor_async(tensor).await;\n                        data_service\n                            .expose_data(data.unwrap(), count, transfer_id)\n                            .await;\n                    }\n                    ProcessorTask::ReadTensor(id, tensor, callback) => {\n                        let tensor = runner.read_tensor_async(tensor).await;\n                        callback\n                            .send(TaskResponse {\n                                content: TaskResponseContent::ReadTensor(tensor),\n                                id,\n                            })\n                            .await\n                            .unwrap();\n                    }\n                    ProcessorTask::Close => {\n                        let device = runner.device();\n                        runner.sync().unwrap();\n                        core::mem::drop(runner);\n                        B::sync(&device).unwrap();\n                        break;\n                    }\n                    ProcessorTask::Seed(seed) => runner.seed(seed),\n                    ProcessorTask::SupportsDType(id, dtype, callback) => {\n                        let _result = runner.dtype_usage(dtype);\n                        callback\n                            .send(TaskResponse {\n                                // content: TaskResponseContent::SupportsDType(result),\n                                // TODO: Update to result.\n                                content: TaskResponseContent::SupportsDType(()),\n                                id,\n                            })\n                            .await\n                            .unwrap();\n                    }\n                }\n            }\n        });\n\n        task_sender\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/server/session.rs",
    "content": "use burn_backend::tensor::Device;\nuse burn_communication::{Protocol, data_service::TensorDataService};\nuse burn_ir::BackendIr;\nuse burn_router::Runner;\nuse burn_std::id::StreamId;\nuse std::{collections::HashMap, sync::Arc};\nuse tokio::sync::{\n    Mutex,\n    mpsc::{Receiver, Sender},\n};\n\nuse crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse};\n\nuse super::stream::Stream;\n\n/// A session manager control the creation of sessions.\n///\n/// Each session manages its own stream, spawning one thread per stream to mimic the same behavior\n/// a native backend would have.\npub struct SessionManager<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    runner: Runner<B>,\n    sessions: Mutex<HashMap<SessionId, Session<B, P>>>,\n    data_service: Arc<TensorDataService<B, P>>,\n}\n\nstruct Session<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    runner: Runner<B>,\n    streams: HashMap<StreamId, Stream<B, P>>,\n    sender: Sender<Receiver<TaskResponse>>,\n    receiver: Option<Receiver<Receiver<TaskResponse>>>,\n    data_service: Arc<TensorDataService<B, P>>,\n}\n\nimpl<B, P> SessionManager<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    pub fn new(device: Device<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {\n        Self {\n            runner: Runner::new(device),\n            sessions: Mutex::new(Default::default()),\n            data_service,\n        }\n    }\n\n    /// Register a new responder for the session. Only one responder can exist for a session for\n    /// now.\n    pub async fn register_responder(\n        &self,\n        session_id: SessionId,\n    ) -> Receiver<Receiver<TaskResponse>> {\n        log::info!(\"Register responder for session {session_id}\");\n        let mut sessions = self.sessions.lock().await;\n        self.register_session(&mut sessions, session_id);\n\n        let session = sessions.get_mut(&session_id).unwrap();\n        session.init_responder()\n    }\n\n    /// Get the stream for the current session and task.\n    pub async fn stream(\n        &self,\n        session_id: &mut Option<SessionId>,\n        task: Task,\n    ) -> Option<(Stream<B, P>, ConnectionId, ComputeTask)> {\n        let mut sessions = self.sessions.lock().await;\n\n        let session_id = match session_id {\n            Some(id) => *id,\n            None => match task {\n                Task::Init(id) => {\n                    log::info!(\"Init requester for session {id}\");\n                    *session_id = Some(id);\n                    self.register_session(&mut sessions, id);\n                    return None;\n                }\n                _ => panic!(\"The first message should initialize the session\"),\n            },\n        };\n\n        match sessions.get_mut(&session_id) {\n            Some(session) => {\n                let (task, connection_id) = match task {\n                    Task::Compute(task, connection_id) => (task, connection_id),\n                    _ => panic!(\"Only support compute tasks.\"),\n                };\n                let stream = session.select(connection_id.stream_id).await;\n                Some((stream, connection_id, task))\n            }\n            None => panic!(\"To be initialized\"),\n        }\n    }\n\n    /// Close the session with the given id.\n    pub async fn close(&self, session_id: Option<SessionId>) {\n        if let Some(id) = session_id {\n            let mut sessions = self.sessions.lock().await;\n            if let Some(session) = sessions.get_mut(&id) {\n                session.close().await;\n            }\n        }\n    }\n\n    fn register_session(&self, sessions: &mut HashMap<SessionId, Session<B, P>>, id: SessionId) {\n        sessions.entry(id).or_insert_with(|| {\n            log::info!(\"Creating a new session {id}\");\n\n            Session::new(self.runner.clone(), self.data_service.clone())\n        });\n    }\n}\n\nimpl<B, P> Session<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    fn new(runner: Runner<B>, data_service: Arc<TensorDataService<B, P>>) -> Self {\n        let (sender, receiver) = tokio::sync::mpsc::channel(1);\n\n        Self {\n            runner,\n            streams: Default::default(),\n            sender,\n            receiver: Some(receiver),\n            data_service,\n        }\n    }\n\n    fn init_responder(&mut self) -> Receiver<Receiver<TaskResponse>> {\n        let mut receiver = None;\n        core::mem::swap(&mut receiver, &mut self.receiver);\n        receiver.expect(\"Only one responder per session is possible.\")\n    }\n\n    /// Select the current [stream](Stream) based on the given task.\n    async fn select(&mut self, stream_id: StreamId) -> Stream<B, P> {\n        // We return the stream.\n        match self.streams.get(&stream_id) {\n            Some(stream) => stream.clone(),\n            None => {\n                let stream = Stream::<B, P>::new(\n                    self.runner.clone(),\n                    self.sender.clone(),\n                    self.data_service.clone(),\n                )\n                .await;\n                self.streams.insert(stream_id, stream.clone());\n                stream\n            }\n        }\n    }\n\n    // Close all streams created in the session.\n    async fn close(&mut self) {\n        for (id, stream) in self.streams.drain() {\n            log::info!(\"Closing stream {id}\");\n            stream.close().await;\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/server/stream.rs",
    "content": "use core::marker::PhantomData;\nuse std::sync::Arc;\n\nuse crate::shared::{ConnectionId, TaskResponse, TensorRemote};\n\nuse super::processor::{Processor, ProcessorTask};\nuse burn_backend::TensorData;\nuse burn_communication::{\n    Protocol,\n    data_service::{TensorDataService, TensorTransferId},\n};\nuse burn_ir::{BackendIr, OperationIr, TensorId, TensorIr};\nuse burn_router::Runner;\nuse burn_std::DType;\nuse tokio::sync::mpsc::{Receiver, Sender};\n\n/// A stream makes sure all operations registered are executed in the order they were sent to the\n/// server, potentially waiting to reconstruct consistency.\n#[derive(Clone)]\npub struct Stream<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    compute_sender: Sender<ProcessorTask>,\n    writer_sender: Sender<Receiver<TaskResponse>>,\n    _p: PhantomData<B>,\n    _n: PhantomData<P>,\n}\n\nimpl<B, P> Stream<B, P>\nwhere\n    B: BackendIr,\n    P: Protocol,\n{\n    pub async fn new(\n        runner: Runner<B>,\n        writer_sender: Sender<Receiver<TaskResponse>>,\n        data_service: Arc<TensorDataService<B, P>>,\n    ) -> Self {\n        let sender = Processor::<B, P>::start(runner, data_service).await;\n\n        Self {\n            compute_sender: sender,\n            writer_sender,\n            _p: PhantomData,\n            _n: PhantomData,\n        }\n    }\n\n    pub async fn register_operation(&self, op: Box<OperationIr>) {\n        self.compute_sender\n            .send(ProcessorTask::RegisterOperation(op))\n            .await\n            .unwrap();\n    }\n\n    pub async fn register_tensor(&self, tensor_id: TensorId, data: TensorData) {\n        self.compute_sender\n            .send(ProcessorTask::RegisterTensor(tensor_id, data))\n            .await\n            .unwrap();\n    }\n\n    pub async fn register_tensor_remote(&self, tensor: TensorRemote, new_id: TensorId) {\n        self.compute_sender\n            .send(ProcessorTask::RegisterTensorRemote(tensor, new_id))\n            .await\n            .unwrap();\n    }\n\n    pub async fn expose_tensor_remote(\n        &self,\n        tensor: TensorIr,\n        count: u32,\n        transfer_id: TensorTransferId,\n    ) {\n        self.compute_sender\n            .send(ProcessorTask::ExposeTensorRemote {\n                tensor,\n                count,\n                transfer_id,\n            })\n            .await\n            .unwrap();\n    }\n\n    pub async fn read_tensor(&self, id: ConnectionId, desc: TensorIr) {\n        let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);\n\n        self.compute_sender\n            .send(ProcessorTask::ReadTensor(id, desc, callback_sender))\n            .await\n            .unwrap();\n\n        self.writer_sender.send(callback_rec).await.unwrap();\n    }\n\n    pub async fn sync(&self, id: ConnectionId) {\n        let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);\n\n        self.compute_sender\n            .send(ProcessorTask::Sync(id, callback_sender))\n            .await\n            .unwrap();\n\n        self.writer_sender.send(callback_rec).await.unwrap();\n    }\n\n    pub async fn close(&self) {\n        self.compute_sender\n            .send(ProcessorTask::Close)\n            .await\n            .unwrap();\n    }\n\n    pub async fn seed(&self, seed: u64) {\n        self.compute_sender\n            .send(ProcessorTask::Seed(seed))\n            .await\n            .unwrap();\n    }\n\n    pub async fn supports_dtype(&self, id: ConnectionId, dtype: DType) {\n        let (callback_sender, callback_rec) = tokio::sync::mpsc::channel(1);\n\n        self.compute_sender\n            .send(ProcessorTask::SupportsDType(id, dtype, callback_sender))\n            .await\n            .unwrap();\n\n        self.writer_sender.send(callback_rec).await.unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-remote/src/shared/mod.rs",
    "content": "mod task;\n\n#[allow(unused_imports)]\npub(crate) use task::*;\n\n/// We define the communication protocol here\npub(crate) type RemoteProtocol = burn_communication::websocket::WebSocket;\n"
  },
  {
    "path": "crates/burn-remote/src/shared/task.rs",
    "content": "use burn_backend::{ExecutionError, TensorData};\nuse burn_communication::{Address, data_service::TensorTransferId};\nuse burn_ir::{OperationIr, TensorId, TensorIr};\nuse burn_std::{\n    DType,\n    id::{IdGenerator, StreamId},\n};\nuse serde::{Deserialize, Serialize};\nuse std::fmt::Display;\n\n#[allow(missing_docs)]\n#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]\npub struct ConnectionId {\n    pub position: u64,\n    pub stream_id: StreamId,\n}\n\n/// Unique identifier that can represent a session.\n#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]\npub struct SessionId {\n    id: u64,\n}\n\nimpl Display for SessionId {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        writeln!(f, \"SessionId({})\", self.id)\n    }\n}\n\nimpl SessionId {\n    /// Create a new [session id](SessionId).\n    #[allow(dead_code)]\n    pub fn new() -> Self {\n        Self {\n            id: IdGenerator::generate(),\n        }\n    }\n}\n\n#[allow(missing_docs)]\n#[derive(Serialize, Deserialize, Debug)]\npub enum Task {\n    Compute(ComputeTask, ConnectionId),\n    Init(SessionId),\n    Close(SessionId),\n}\n\n#[allow(missing_docs)]\n#[derive(Serialize, Deserialize, Debug, Clone)]\npub struct TensorRemote {\n    pub transfer_id: TensorTransferId,\n    pub address: Address,\n}\n\n#[allow(missing_docs)]\n#[derive(Serialize, Deserialize, Debug)]\npub enum ComputeTask {\n    Seed(u64),\n    RegisterOperation(Box<OperationIr>),\n    RegisterTensor(TensorId, TensorData),\n    RegisterTensorRemote(TensorRemote, TensorId),\n    ExposeTensorRemote {\n        tensor: TensorIr,\n        count: u32,\n        transfer_id: TensorTransferId,\n    },\n    ReadTensor(TensorIr),\n    SyncBackend,\n    SupportsDType(DType),\n}\n\n#[allow(missing_docs)]\n#[derive(Serialize, Deserialize, Debug)]\npub struct TaskResponse {\n    pub content: TaskResponseContent,\n    pub id: ConnectionId,\n}\n\n#[allow(missing_docs)]\n#[derive(Serialize, Deserialize, Debug)]\npub enum TaskResponseContent {\n    ReadTensor(Result<TensorData, ExecutionError>),\n    SyncBackend(Result<(), ExecutionError>),\n    // SupportsDType(DTypeUsageSet),\n    // TODO: Update to `DTypeUsageSet` when it implements `serde`.\n    SupportsDType(()),\n}\n"
  },
  {
    "path": "crates/burn-rl/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"RL crate for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-rl\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-rl\"\ndocumentation = \"https://docs.rs/burn-rl\"\nversion.workspace = true\n\n[dependencies]\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", features = [\n    \"dataset\",\n    \"std\",\n], default-features = false }\nburn-optim = { path = \"../burn-optim\", version = \"=0.21.0-pre.2\", features = [\n    \"std\",\n], default-features = false }\n\nderive-new.workspace = true\nlog = { workspace = true }\nrand.workspace = true\n\n[dev-dependencies]\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\n\n[lints]\nworkspace = true\n"
  },
  {
    "path": "crates/burn-rl/README.md",
    "content": "# Burn RL\n\n<!-- This crate should be used with [burn](https://github.com/tracel-ai/burn). -->\n\n<!-- [![Current Crates.io Version](https://img.shields.io/crates/v/burn-rl.svg)](https://crates.io/crates/burn-rl)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-rl/blob/master/README.md) -->\n"
  },
  {
    "path": "crates/burn-rl/src/environment/base.rs",
    "content": "/// The result of taking a step in an environment.\npub struct StepResult<S> {\n    /// The updated state.\n    pub next_state: S,\n    /// The reward.\n    pub reward: f64,\n    /// If the environment reached a terminal state.\n    pub done: bool,\n    /// If the environment reached its max length.\n    pub truncated: bool,\n}\n\n/// Trait to be implemented for a RL environment.\npub trait Environment {\n    /// The type of the state.\n    type State;\n    /// The type of actions.\n    type Action;\n\n    /// The maximum number of step for one episode.\n    const MAX_STEPS: usize;\n\n    /// Returns the current state.\n    fn state(&self) -> Self::State;\n    /// Take a step in the environment given an action.\n    fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;\n    /// Reset the environment to an initial state.\n    fn reset(&mut self);\n}\n\n/// Trait to define how to initialize an environment.\n/// By default, any function returning an environment implements it.\npub trait EnvironmentInit<E: Environment>: Clone {\n    /// Initialize the environment.\n    fn init(&self) -> E;\n}\n\nimpl<F, E> EnvironmentInit<E> for F\nwhere\n    F: Fn() -> E + Clone,\n    E: Environment,\n{\n    fn init(&self) -> E {\n        (self)()\n    }\n}\n"
  },
  {
    "path": "crates/burn-rl/src/environment/mod.rs",
    "content": "mod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-rl/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! A library for training reinforcement learning agents.\n\n/// Module for implementing an environment.\npub mod environment;\n/// Module for implementing a policy.\npub mod policy;\n/// Transition buffer.\npub mod transition_buffer;\n\npub use environment::*;\npub use policy::*;\npub use transition_buffer::*;\n\n#[cfg(test)]\npub(crate) type TestBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(test)]\npub(crate) mod tests {\n    use crate::{Batchable, Policy, PolicyState, TestBackend};\n\n    use burn_core::record::Record;\n    use burn_core::{self as burn};\n\n    /// Mock policy for testing\n    ///\n    /// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution)\n    /// containing a list of 0s of the same length as the observation.\n    ///\n    /// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockAction](MockAction) with a list of actions of the same length as the observation.\n    /// The actions are all 1 if the call is requested as deterministic, or else 0.\n    #[derive(Clone)]\n    pub(crate) struct MockPolicy {}\n\n    impl MockPolicy {\n        pub fn new() -> Self {\n            Self {}\n        }\n    }\n\n    impl Policy<TestBackend> for MockPolicy {\n        type Observation = MockObservation;\n        type ActionDistribution = MockActionDistribution;\n        type Action = MockAction;\n        type ActionContext = MockActionContext;\n        type PolicyState = MockPolicyState;\n\n        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {\n            let mut dists = vec![];\n\n            for _ in obs.0 {\n                dists.push(MockActionDistribution(vec![0.]));\n            }\n            MockActionDistribution::batch(dists)\n        }\n\n        fn action(\n            &mut self,\n            obs: Self::Observation,\n            deterministic: bool,\n        ) -> (Self::Action, Vec<Self::ActionContext>) {\n            let mut actions = vec![];\n            let mut contexts = vec![];\n\n            for _ in obs.0 {\n                if deterministic {\n                    actions.push(MockAction(vec![1]));\n                } else {\n                    actions.push(MockAction(vec![0]));\n                }\n                contexts.push(MockActionContext);\n            }\n\n            (MockAction::batch(actions), contexts)\n        }\n\n        fn update(&mut self, _update: Self::PolicyState) {}\n\n        fn state(&self) -> Self::PolicyState {\n            MockPolicyState\n        }\n\n        fn load_record(\n            self,\n            _record: <Self::PolicyState as PolicyState<TestBackend>>::Record,\n        ) -> Self {\n            self\n        }\n    }\n\n    /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockObservation(pub Vec<f32>);\n\n    /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockAction(pub Vec<i32>);\n\n    /// Mock action distribution for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockActionDistribution(Vec<f32>);\n\n    #[derive(Clone)]\n    pub(crate) struct MockActionContext;\n\n    /// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy.\n    #[derive(Clone)]\n    pub(crate) struct MockPolicyState;\n\n    #[derive(Clone, Record)]\n    pub(crate) struct MockRecord {\n        item: usize,\n    }\n\n    impl PolicyState<TestBackend> for MockPolicyState {\n        type Record = MockRecord;\n\n        fn into_record(self) -> Self::Record {\n            MockRecord { item: 0 }\n        }\n\n        fn load_record(&self, _record: Self::Record) -> Self {\n            self.clone()\n        }\n    }\n\n    impl Batchable for MockObservation {\n        fn batch(items: Vec<Self>) -> Self {\n            MockObservation(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            vec![MockObservation(self.0)]\n        }\n    }\n\n    impl Batchable for MockAction {\n        fn batch(items: Vec<Self>) -> Self {\n            MockAction(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            let mut actions = vec![];\n            for a in self.0 {\n                actions.push(MockAction(vec![a]));\n            }\n            actions\n        }\n    }\n\n    impl Batchable for MockActionDistribution {\n        fn batch(items: Vec<Self>) -> Self {\n            MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            let mut dists = vec![];\n            for _ in self.0 {\n                dists.push(MockActionDistribution(vec![0.]));\n            }\n            dists\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-rl/src/policy/async_policy.rs",
    "content": "use std::{\n    sync::{\n        Arc,\n        atomic::{AtomicUsize, Ordering},\n        mpsc::{self, Sender},\n    },\n    thread::spawn,\n};\n\nuse burn_core::prelude::Backend;\n\nuse crate::{ActionContext, Batchable, Policy, PolicyState};\n\n#[derive(Clone)]\nstruct PolicyInferenceServer<B: Backend, P: Policy<B>> {\n    // `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.\n    num_agents: Arc<AtomicUsize>,\n    max_autobatch_size: usize,\n    inner_policy: P,\n    batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,\n    batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,\n}\n\nimpl<B, P> PolicyInferenceServer<B, P>\nwhere\n    B: Backend,\n    P: Policy<B>,\n    P::Observation: Clone + Batchable,\n    P::ActionDistribution: Clone + Batchable,\n    P::Action: Clone + Batchable,\n    P::ActionContext: Clone,\n{\n    pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {\n        Self {\n            num_agents: Arc::new(AtomicUsize::new(0)),\n            max_autobatch_size,\n            inner_policy,\n            batch_action: vec![],\n            batch_logits: vec![],\n        }\n    }\n\n    pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {\n        self.batch_action.push(item);\n        if self.len_actions()\n            >= self\n                .num_agents\n                .load(Ordering::Relaxed)\n                .min(self.max_autobatch_size)\n        {\n            self.flush_actions();\n        }\n    }\n\n    pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {\n        self.batch_logits.push(item);\n        if self.len_logits()\n            >= self\n                .num_agents\n                .load(Ordering::Relaxed)\n                .min(self.max_autobatch_size)\n        {\n            self.flush_logits();\n        }\n    }\n\n    pub fn len_actions(&self) -> usize {\n        self.batch_action.len()\n    }\n\n    pub fn len_logits(&self) -> usize {\n        self.batch_logits.len()\n    }\n\n    pub fn flush_actions(&mut self) {\n        if self.len_actions() == 0 {\n            return;\n        }\n        let input: Vec<_> = self\n            .batch_action\n            .iter()\n            .map(|m| m.inference_state.clone())\n            .collect();\n        // Only deterministic if all actions are requested as deterministic.\n        let deterministic = self.batch_action.iter().all(|item| item.deterministic);\n        let (actions, context) = self\n            .inner_policy\n            .action(P::Observation::batch(input), deterministic);\n        let actions: Vec<_> = actions.unbatch();\n\n        for (i, item) in self.batch_action.iter().enumerate() {\n            item.sender\n                .send(ActionContext {\n                    context: vec![context[i].clone()],\n                    action: actions[i].clone(),\n                })\n                .expect(\"Autobatcher should be able to send resulting actions.\");\n        }\n        self.batch_action.clear();\n    }\n\n    pub fn flush_logits(&mut self) {\n        if self.len_logits() == 0 {\n            return;\n        }\n        let input: Vec<_> = self\n            .batch_logits\n            .iter()\n            .map(|m| m.inference_state.clone())\n            .collect();\n        let output = self.inner_policy.forward(P::Observation::batch(input));\n        let logits: Vec<_> = output.unbatch();\n        for (i, item) in self.batch_logits.iter().enumerate() {\n            item.sender\n                .send(logits[i].clone())\n                .expect(\"Autobatcher should be able to send resulting probabilities.\");\n        }\n        self.batch_logits.clear();\n    }\n\n    pub fn update_policy(&mut self, policy_update: P::PolicyState) {\n        if self.len_actions() > 0 {\n            self.flush_actions();\n        }\n        if self.len_logits() > 0 {\n            self.flush_logits();\n        }\n        self.inner_policy.update(policy_update);\n    }\n\n    pub fn state(&self) -> P::PolicyState {\n        self.inner_policy.state()\n    }\n\n    pub fn increment_agents(&mut self, num: usize) {\n        self.num_agents.fetch_add(num, Ordering::Relaxed);\n    }\n\n    pub fn decrement_agents(&mut self, num: usize) {\n        self.num_agents.fetch_sub(num, Ordering::Relaxed);\n        if self.len_actions()\n            >= self\n                .num_agents\n                .load(Ordering::Relaxed)\n                .min(self.max_autobatch_size)\n        {\n            self.flush_actions();\n        }\n        if self.len_logits()\n            >= self\n                .num_agents\n                .load(Ordering::Relaxed)\n                .min(self.max_autobatch_size)\n        {\n            self.flush_logits();\n        }\n    }\n}\n\nenum InferenceMessage<B: Backend, P: Policy<B>> {\n    ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),\n    ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),\n    PolicyUpdate(P::PolicyState),\n    PolicyRequest(Sender<P::PolicyState>),\n    IncrementAgents(usize),\n    DecrementAgents(usize),\n}\n\n#[derive(Clone)]\nstruct ActionItem<S, A, C> {\n    sender: Sender<ActionContext<A, Vec<C>>>,\n    inference_state: S,\n    deterministic: bool,\n}\n\n#[derive(Clone)]\nstruct ForwardItem<S, O> {\n    sender: Sender<O>,\n    inference_state: S,\n}\n\n/// An asynchronous policy using an inference server with autobatching.\n#[derive(Clone)]\npub struct AsyncPolicy<B: Backend, P: Policy<B>> {\n    inference_state_sender: Sender<InferenceMessage<B, P>>,\n}\n\nimpl<B, P> AsyncPolicy<B, P>\nwhere\n    B: Backend,\n    P: Policy<B> + Clone + Send + 'static,\n    P::ActionContext: Clone + Send,\n    P::PolicyState: Send,\n    P::Observation: Clone + Send + Batchable,\n    P::ActionDistribution: Clone + Send + Batchable,\n    P::Action: Clone + Send + Batchable,\n{\n    /// Create the policy.\n    ///\n    /// # Arguments\n    ///\n    /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.\n    /// * `inner_policy` - The policy used to take actions.\n    pub fn new(autobatch_size: usize, inner_policy: P) -> Self {\n        let (sender, receiver) = std::sync::mpsc::channel();\n        let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone());\n        spawn(move || {\n            loop {\n                match receiver.recv() {\n                    Ok(msg) => match msg {\n                        InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),\n                        InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),\n                        InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),\n                        InferenceMessage::PolicyRequest(sender) => sender\n                            .send(autobatcher.state())\n                            .expect(\"Autobatcher should be able to send current policy state.\"),\n                        InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),\n                        InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),\n                    },\n                    Err(err) => {\n                        log::error!(\"Error in AsyncPolicy : {}\", err);\n                        break;\n                    }\n                }\n            }\n        });\n\n        Self {\n            inference_state_sender: sender,\n        }\n    }\n\n    /// Increment the number of agents using the inference server.\n    pub fn increment_agents(&self, num: usize) {\n        self.inference_state_sender\n            .send(InferenceMessage::IncrementAgents(num))\n            .expect(\"Can send message to autobatcher.\")\n    }\n\n    /// Decrement the number of agents using the inference server.\n    pub fn decrement_agents(&self, num: usize) {\n        self.inference_state_sender\n            .send(InferenceMessage::DecrementAgents(num))\n            .expect(\"Can send message to autobatcher.\")\n    }\n}\n\nimpl<B, P> Policy<B> for AsyncPolicy<B, P>\nwhere\n    B: Backend,\n    P: Policy<B> + Send + 'static,\n{\n    type ActionContext = P::ActionContext;\n    type PolicyState = P::PolicyState;\n\n    type Observation = P::Observation;\n    type ActionDistribution = P::ActionDistribution;\n    type Action = P::Action;\n\n    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {\n        let (action_sender, action_receiver) = std::sync::mpsc::channel();\n        let item = ForwardItem {\n            sender: action_sender,\n            inference_state: states,\n        };\n        self.inference_state_sender\n            .send(InferenceMessage::ForwardMessage(item))\n            .expect(\"Should be able to send message to inference_server\");\n        action_receiver\n            .recv()\n            .expect(\"AsyncPolicy should receive queued probabilities.\")\n    }\n\n    fn action(\n        &mut self,\n        states: Self::Observation,\n        deterministic: bool,\n    ) -> (Self::Action, Vec<Self::ActionContext>) {\n        let (action_sender, action_receiver) = std::sync::mpsc::channel();\n        let item = ActionItem {\n            sender: action_sender,\n            inference_state: states,\n            deterministic,\n        };\n        self.inference_state_sender\n            .send(InferenceMessage::ActionMessage(item))\n            .expect(\"should be able to send message to inference_server.\");\n        let action = action_receiver\n            .recv()\n            .expect(\"AsyncPolicy should receive queued actions.\");\n        (action.action, action.context)\n    }\n\n    fn update(&mut self, update: Self::PolicyState) {\n        self.inference_state_sender\n            .send(InferenceMessage::PolicyUpdate(update))\n            .expect(\"AsyncPolicy should be able to send policy state.\")\n    }\n\n    fn state(&self) -> Self::PolicyState {\n        let (sender, receiver) = mpsc::channel();\n        self.inference_state_sender\n            .send(InferenceMessage::PolicyRequest(sender))\n            .expect(\"should be able to send message to inference_server.\");\n        receiver\n            .recv()\n            .expect(\"AsyncPolicy should be able to receive policy state.\")\n    }\n\n    fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {\n        // Not needed for now\n        todo!()\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::needless_range_loop)]\nmod tests {\n    use std::thread::JoinHandle;\n    use std::time::Duration;\n\n    use crate::TestBackend;\n    use crate::tests::{MockAction, MockObservation, MockPolicy};\n\n    use super::*;\n\n    #[test]\n    fn test_multiple_actions_before_flush() {\n        fn launch_thread(\n            policy: &AsyncPolicy<TestBackend, MockPolicy>,\n            handles: &mut Vec<JoinHandle<()>>,\n        ) {\n            let mut thread_policy = policy.clone();\n            let handle = spawn(move || {\n                thread_policy.action(MockObservation(vec![0.]), false);\n            });\n            handles.push(handle);\n        }\n\n        let policy = AsyncPolicy::new(8, MockPolicy::new());\n        policy.increment_agents(1000);\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n\n        for _ in 0..6 {\n            launch_thread(&policy, &mut handles);\n        }\n        std::thread::sleep(Duration::from_millis(10));\n        for i in 0..7 {\n            assert!(!handles[i].is_finished());\n        }\n\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        for i in 0..8 {\n            assert!(handles[i].is_finished());\n        }\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n    }\n\n    #[test]\n    fn test_multiple_forward_before_flush() {\n        fn launch_thread(\n            policy: &AsyncPolicy<TestBackend, MockPolicy>,\n            handles: &mut Vec<JoinHandle<()>>,\n        ) {\n            let mut thread_policy = policy.clone();\n            let handle = spawn(move || {\n                thread_policy.forward(MockObservation(vec![0.]));\n            });\n            handles.push(handle);\n        }\n\n        let policy = AsyncPolicy::new(8, MockPolicy::new());\n        policy.increment_agents(1000);\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n\n        for _ in 0..6 {\n            launch_thread(&policy, &mut handles);\n        }\n        std::thread::sleep(Duration::from_millis(10));\n        for i in 0..7 {\n            assert!(!handles[i].is_finished());\n        }\n\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        for i in 0..8 {\n            assert!(handles[i].is_finished());\n        }\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n    }\n\n    #[test]\n    fn test_async_policy_deterministic_behaviour() {\n        fn launch_thread(\n            policy: &AsyncPolicy<TestBackend, MockPolicy>,\n            handles: &mut Vec<JoinHandle<MockAction>>,\n            deterministic: bool,\n        ) {\n            let mut thread_policy = policy.clone();\n            let handle = spawn(move || {\n                let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic);\n                action\n            });\n            handles.push(handle);\n        }\n\n        let policy = AsyncPolicy::new(2, MockPolicy::new());\n        policy.increment_agents(1000);\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles, true);\n        launch_thread(&policy, &mut handles, false);\n        for _ in 0..2 {\n            let action = handles.pop().unwrap().join().unwrap();\n            assert_eq!(action.0, vec![0]);\n        }\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles, true);\n        launch_thread(&policy, &mut handles, true);\n        for _ in 0..2 {\n            let action = handles.pop().unwrap().join().unwrap();\n            assert_eq!(action.0, vec![1]);\n        }\n    }\n\n    #[test]\n    fn flush_when_running_agents_smaller_than_autobatch_size() {\n        fn launch_thread(\n            policy: &AsyncPolicy<TestBackend, MockPolicy>,\n            handles: &mut Vec<JoinHandle<()>>,\n        ) {\n            let mut thread_policy = policy.clone();\n            let handle = spawn(move || {\n                thread_policy.action(MockObservation(vec![0.]), false);\n            });\n            handles.push(handle);\n        }\n\n        let policy = AsyncPolicy::new(8, MockPolicy::new());\n        policy.increment_agents(3);\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n        assert!(!handles[1].is_finished());\n\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        for i in 0..3 {\n            assert!(handles[i].is_finished());\n        }\n\n        let mut handles = vec![];\n        launch_thread(&policy, &mut handles);\n        launch_thread(&policy, &mut handles);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(!handles[0].is_finished());\n        assert!(!handles[1].is_finished());\n\n        policy.decrement_agents(1);\n        std::thread::sleep(Duration::from_millis(10));\n        assert!(handles[0].is_finished());\n        assert!(handles[1].is_finished());\n    }\n}\n"
  },
  {
    "path": "crates/burn-rl/src/policy/base.rs",
    "content": "use derive_new::new;\n\nuse burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};\n\nuse crate::TransitionBatch;\n\n/// An action along with additional context about the decision.\n#[derive(Clone, new)]\npub struct ActionContext<A, C> {\n    /// The context.\n    pub context: C,\n    /// The action.\n    pub action: A,\n}\n\n/// The state of a policy.\npub trait PolicyState<B: Backend> {\n    /// The type of the record.\n    type Record: Record<B>;\n\n    /// Convert the state to a record.\n    fn into_record(self) -> Self::Record;\n    /// Load the state from a record.\n    fn load_record(&self, record: Self::Record) -> Self;\n}\n\n/// Trait for a RL policy.\npub trait Policy<B: Backend>: Clone {\n    /// The observation given as input to the policy.\n    type Observation;\n    /// The action distribution parameters defining how the action will be sampled.\n    type ActionDistribution;\n    /// The action.\n    type Action;\n\n    /// Additional context on the policy's decision.\n    type ActionContext;\n    /// The current parameterization of the policy.\n    type PolicyState: PolicyState<B>;\n\n    /// Produces the action distribution from a batch of observations.\n    fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;\n    /// Gives the action from a batch of observations.\n    fn action(\n        &mut self,\n        obs: Self::Observation,\n        deterministic: bool,\n    ) -> (Self::Action, Vec<Self::ActionContext>);\n\n    /// Update the policy's parameters.\n    fn update(&mut self, update: Self::PolicyState);\n    /// Returns the current parameterization.\n    fn state(&self) -> Self::PolicyState;\n\n    /// Loads the policy parameters from a record.\n    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;\n}\n\n/// Trait for a type that can be batched and unbatched (split).\npub trait Batchable: Sized {\n    /// Create a batch from a list of items.\n    fn batch(value: Vec<Self>) -> Self;\n    /// Create a list from batched items.\n    fn unbatch(self) -> Vec<Self>;\n}\n\n/// A training output.\npub struct RLTrainOutput<TO, P> {\n    /// The policy.\n    pub policy: P,\n    /// The item.\n    pub item: TO,\n}\n\n/// Batched transitions for a PolicyLearner.\npub type LearnerTransitionBatch<B, P> =\n    TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;\n\n/// Learner for a policy.\npub trait PolicyLearner<B>\nwhere\n    B: AutodiffBackend,\n    <Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,\n    <Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,\n    <Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,\n{\n    /// Additional context of a training step.\n    type TrainContext;\n    /// The policy to train.\n    type InnerPolicy: Policy<B>;\n    /// The record of the learner.\n    type Record: Record<B>;\n\n    /// Execute a training step on the policy.\n    fn train(\n        &mut self,\n        input: LearnerTransitionBatch<B, Self::InnerPolicy>,\n    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;\n    /// Returns the learner's current policy.\n    fn policy(&self) -> Self::InnerPolicy;\n    /// Update the learner's policy.\n    fn update_policy(&mut self, update: Self::InnerPolicy);\n\n    /// Convert the learner's state into a record.\n    fn record(&self) -> Self::Record;\n    /// Load the learner's state from a record.\n    fn load_record(self, record: Self::Record) -> Self;\n}\n"
  },
  {
    "path": "crates/burn-rl/src/policy/mod.rs",
    "content": "mod async_policy;\nmod base;\n\npub use async_policy::*;\npub use base::*;\n"
  },
  {
    "path": "crates/burn-rl/src/transition_buffer/base.rs",
    "content": "use burn_core::{Tensor, prelude::Backend, tensor::Distribution};\nuse derive_new::new;\n\nuse super::SliceAccess;\n\n/// A state transition in an environment.\n#[derive(Clone, new)]\npub struct Transition<B: Backend, S, A> {\n    /// The initial state.\n    pub state: S,\n    /// The state after the step was taken.\n    pub next_state: S,\n    /// The action taken in the step.\n    pub action: A,\n    /// The reward.\n    pub reward: Tensor<B, 1>,\n    /// If the environment has reached a terminal state.\n    pub done: Tensor<B, 1>,\n}\n\n/// A batch of transitions.\npub struct TransitionBatch<B: Backend, SB, AB> {\n    /// Batched initial states.\n    pub states: SB,\n    /// Batched resulting states.\n    pub next_states: SB,\n    /// Batched actions.\n    pub actions: AB,\n    /// Batched rewards.\n    pub rewards: Tensor<B, 2>,\n    /// Batched flags for terminal states.\n    pub dones: Tensor<B, 2>,\n}\n\n/// A tensor-backed circular buffer for transitions.\n///\n/// Uses [`SliceAccess`] to store state and action batches in contiguous\n/// tensor storage, enabling efficient random sampling via `select`.\n/// The buffer lazily initializes its storage on the first `push` call.\npub struct TransitionBuffer<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> {\n    states: Option<SB>,\n    next_states: Option<SB>,\n    actions: Option<AB>,\n    rewards: Option<Tensor<B, 2>>,\n    dones: Option<Tensor<B, 2>>,\n    capacity: usize,\n    write_head: usize,\n    len: usize,\n    device: B::Device,\n}\n\nimpl<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> TransitionBuffer<B, SB, AB> {\n    /// Creates a new buffer. Storage is lazily allocated on the first `push`.\n    pub fn new(capacity: usize, device: &B::Device) -> Self {\n        Self {\n            states: None,\n            next_states: None,\n            actions: None,\n            rewards: None,\n            dones: None,\n            capacity,\n            write_head: 0,\n            len: 0,\n            device: device.clone(),\n        }\n    }\n\n    fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) {\n        if self.states.is_none() {\n            self.states = Some(SB::zeros_like(state, self.capacity, &self.device));\n            self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device));\n            self.actions = Some(AB::zeros_like(action, self.capacity, &self.device));\n            self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device));\n            self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device));\n        }\n    }\n\n    /// Add a transition, overwriting the oldest if full.\n    pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) {\n        self.ensure_init(&state, &next_state, &action);\n\n        let idx = self.write_head % self.capacity;\n\n        self.states\n            .as_mut()\n            .unwrap()\n            .slice_assign_inplace(idx, state);\n        self.next_states\n            .as_mut()\n            .unwrap()\n            .slice_assign_inplace(idx, next_state);\n        self.actions\n            .as_mut()\n            .unwrap()\n            .slice_assign_inplace(idx, action);\n\n        let reward = Tensor::from_data([[reward]], &self.device);\n        self.rewards\n            .as_mut()\n            .unwrap()\n            .inplace(|r| r.slice_assign(idx..idx + 1, reward));\n\n        let done_val = if done { 1.0f32 } else { 0.0 };\n        let done = Tensor::from_data([[done_val]], &self.device);\n        self.dones\n            .as_mut()\n            .unwrap()\n            .inplace(|d| d.slice_assign(idx..idx + 1, done));\n\n        self.write_head += 1;\n        if self.len < self.capacity {\n            self.len += 1;\n        }\n    }\n\n    /// Sample a random batch of transitions.\n    pub fn sample(&self, batch_size: usize) -> TransitionBatch<B, SB, AB> {\n        assert!(batch_size <= self.len, \"batch_size exceeds buffer length\");\n\n        let indices = Tensor::<B, 1>::random(\n            [batch_size],\n            Distribution::Uniform(0.0, self.len as f64),\n            &self.device,\n        )\n        .int();\n\n        TransitionBatch {\n            states: self\n                .states\n                .as_ref()\n                .unwrap()\n                .clone()\n                .select(0, indices.clone()),\n            next_states: self\n                .next_states\n                .as_ref()\n                .unwrap()\n                .clone()\n                .select(0, indices.clone()),\n            actions: self\n                .actions\n                .as_ref()\n                .unwrap()\n                .clone()\n                .select(0, indices.clone()),\n            rewards: self\n                .rewards\n                .as_ref()\n                .unwrap()\n                .clone()\n                .select(0, indices.clone()),\n            dones: self.dones.as_ref().unwrap().clone().select(0, indices),\n        }\n    }\n\n    /// Current number of stored transitions.\n    pub fn len(&self) -> usize {\n        self.len\n    }\n\n    /// Whether the buffer is empty.\n    pub fn is_empty(&self) -> bool {\n        self.len == 0\n    }\n\n    /// Buffer capacity.\n    pub fn capacity(&self) -> usize {\n        self.capacity\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    type TB = Tensor<TestBackend, 2>;\n\n    fn push_transition(\n        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,\n        device: &<TestBackend as Backend>::Device,\n        val: f32,\n    ) {\n        let state = Tensor::<TestBackend, 2>::from_data([[val, val]], device);\n        let next_state = Tensor::<TestBackend, 2>::from_data([[val + 1.0, val + 1.0]], device);\n        let action = Tensor::<TestBackend, 2>::from_data([[val]], device);\n        buffer.push(state, next_state, action, val, false);\n    }\n\n    #[test]\n    fn push_increment_len() {\n        let device = Default::default();\n        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);\n\n        assert_eq!(buffer.len(), 0);\n        assert!(buffer.is_empty());\n\n        push_transition(&mut buffer, &device, 1.0);\n        assert_eq!(buffer.len(), 1);\n\n        push_transition(&mut buffer, &device, 2.0);\n        assert_eq!(buffer.len(), 2);\n    }\n\n    #[test]\n    fn push_overwrites_when_full() {\n        let device = Default::default();\n        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);\n\n        for i in 0..5 {\n            push_transition(&mut buffer, &device, i as f32);\n        }\n\n        assert_eq!(buffer.len(), 3);\n        assert_eq!(buffer.capacity(), 3);\n    }\n\n    #[test]\n    fn sample_returns_correct_shapes() {\n        let device = Default::default();\n        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);\n\n        for i in 0..5 {\n            push_transition(&mut buffer, &device, i as f32);\n        }\n\n        let batch = buffer.sample(3);\n        assert_eq!(batch.states.dims(), [3, 2]);\n        assert_eq!(batch.next_states.dims(), [3, 2]);\n        assert_eq!(batch.actions.dims(), [3, 1]);\n        assert_eq!(batch.rewards.dims(), [3, 1]);\n        assert_eq!(batch.dones.dims(), [3, 1]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"batch_size exceeds buffer length\")]\n    fn sample_panics_when_batch_too_large() {\n        let device = Default::default();\n        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);\n\n        push_transition(&mut buffer, &device, 1.0);\n        buffer.sample(5);\n    }\n}\n"
  },
  {
    "path": "crates/burn-rl/src/transition_buffer/mod.rs",
    "content": "mod base;\nmod slice_access;\n\npub use base::*;\npub use slice_access::*;\n"
  },
  {
    "path": "crates/burn-rl/src/transition_buffer/slice_access.rs",
    "content": "use burn_core::prelude::*;\n\n/// Trait for types that support tensor-like slice operations,\n/// enabling storage in a [`TransitionBuffer`](super::TransitionBuffer).\n///\n/// Implement this trait for any type that wraps tensors and can be stored\n/// in a replay buffer. The buffer uses these operations for:\n/// - Pre-allocating storage (`zeros_like`)\n/// - Writing transitions (`slice_assign_inplace`)\n/// - Sampling batches (`select`)\npub trait SliceAccess<B: Backend>: Clone + Sized {\n    /// Create zeroed storage matching the shape of `sample` but with `capacity` rows\n    /// along the first dimension.\n    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self;\n\n    /// Select rows at the given indices along the specified dimension.\n    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self;\n\n    /// Assign `value` at row `index` along the first dimension, in place.\n    fn slice_assign_inplace(&mut self, index: usize, value: Self);\n}\n\nimpl<B: Backend> SliceAccess<B> for Tensor<B, 2> {\n    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {\n        let feature_dim = sample.dims()[1];\n        Tensor::zeros([capacity, feature_dim], device)\n    }\n\n    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {\n        Tensor::select(self, dim, indices)\n    }\n\n    fn slice_assign_inplace(&mut self, index: usize, value: Self) {\n        self.inplace(|t| t.slice_assign(index..index + 1, value));\n    }\n}\n"
  },
  {
    "path": "crates/burn-rocm/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"ROCm HIP backend for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\", \"rocm\", \"hip\"]\nlicense.workspace = true\nname = \"burn-rocm\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-rocm\"\ndocumentation = \"https://docs.rs/burn-rocm\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"fusion\", \"burn-cubecl/default\", \"cubecl/default\"]\ntracing = [\n    \"cubecl/tracing\",\n    \"burn-cubecl/tracing\",\n    \"burn-backend/tracing\",\n    \"burn-fusion?/tracing\",\n]\n\nfusion = [\"burn-fusion\", \"burn-cubecl/fusion\"]\nautotune = [\"burn-cubecl/autotune\"]\nautotune-checks = [\"burn-cubecl/autotune-checks\"]\ndoc = [\"burn-cubecl/doc\"]\nstd = [\"burn-cubecl/std\", \"cubecl/std\"]\n\n[dependencies]\ncubecl = { workspace = true, features = [\"hip\"] }\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", default-features = true }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", features = [\n    \"cubecl-hip\",\n] }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-rocm/README.md",
    "content": "# burn-rocm\n\nBackend using ROCm HIP runtime.\n\nTo execute the tests for this backend set an environment variable called `ROCM_PATH` or `CUBECL_ROCM_PATH` to the installation path of ROCm. It is often `/opt/rocm`.\n\nFor now this backend requires the version `6.2.2` of ROCm or a compatible version.\n"
  },
  {
    "path": "crates/burn-rocm/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\nextern crate alloc;\n\nuse burn_cubecl::CubeBackend;\n\npub use cubecl::hip::AmdDevice as RocmDevice;\n\nuse cubecl::hip::HipRuntime;\n\n#[cfg(not(feature = \"fusion\"))]\npub type Rocm<F = f32, I = i32, B = u8> = CubeBackend<HipRuntime, F, I, B>;\n\n#[cfg(feature = \"fusion\")]\npub type Rocm<F = f32, I = i32, B = u8> = burn_fusion::Fusion<CubeBackend<HipRuntime, F, I, B>>;\n"
  },
  {
    "path": "crates/burn-router/Cargo.toml",
    "content": "[package]\nauthors = [\n    \"laggui <lagrange.guillaume.1@gmail.com>\",\n    \"nathanielsimard <nathaniel.simard.42@gmail.com>\",\n]\ncategories = [\"science\"]\ndescription = \"Multi-backend router decorator for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-router\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-router\"\ndocumentation = \"https://docs.rs/burn-router\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\nstd = [\"burn-backend/std\", \"burn-std/std\", \"burn-ir/std\"]\ndoc = [\"default\"]\ntracing = [\n    \"burn-backend/tracing\",\n    \"burn-ir/tracing\",\n    \"burn-std/tracing\",\n]\n\n[dependencies]\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\nhashbrown = { workspace = true }\nspin = { workspace = true }\nlog = { workspace = true }\n\n[dev-dependencies]\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"std\",\n] }\n\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-router/README.md",
    "content": "# Burn Router\n\nA multi-backend extension that forwards the tensor operations to the appropriate backend.\n"
  },
  {
    "path": "crates/burn-router/src/backend.rs",
    "content": "use super::{RouterTensor, RunnerChannel, RunnerClient, get_client};\nuse alloc::{format, string::String};\nuse burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive, quantization::QuantScheme};\nuse core::marker::PhantomData;\n\n/// A backend that forwards the tensor operations to the appropriate backend (given multiple backends).\npub struct BackendRouter<R: RunnerChannel> {\n    r: PhantomData<R>,\n}\n\nimpl<R: RunnerChannel> core::fmt::Debug for BackendRouter<R> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_fmt(format_args!(\"router\"))\n    }\n}\n\nimpl<R: RunnerChannel> Clone for BackendRouter<R> {\n    fn clone(&self) -> Self {\n        Self { r: PhantomData }\n    }\n}\n\nimpl<R: RunnerChannel> Default for BackendRouter<R> {\n    fn default() -> Self {\n        Self { r: PhantomData }\n    }\n}\n\nimpl<R: RunnerClient> QTensorPrimitive for RouterTensor<R> {\n    fn scheme(&self) -> &QuantScheme {\n        if let DType::QFloat(scheme) = &self.dtype {\n            scheme\n        } else {\n            // TODO: maybe `tensor.scheme()` should return an option\n            panic!(\"Expected quantized float dtype, got {:?}\", self.dtype)\n        }\n    }\n}\n\nimpl<R: RunnerChannel> Backend for BackendRouter<R> {\n    type Device = R::Device;\n\n    type FloatTensorPrimitive = RouterTensor<R::Client>;\n\n    type FloatElem = R::FloatElem;\n\n    type IntTensorPrimitive = RouterTensor<R::Client>;\n\n    type IntElem = R::IntElem;\n\n    type BoolTensorPrimitive = RouterTensor<R::Client>;\n\n    type BoolElem = R::BoolElem;\n\n    type QuantizedTensorPrimitive = RouterTensor<R::Client>;\n\n    fn name(device: &Self::Device) -> String {\n        format!(\"router<{}>\", R::name(device))\n    }\n\n    fn seed(device: &Self::Device, seed: u64) {\n        let client = get_client::<R>(device);\n        client.seed(seed);\n    }\n\n    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {\n        let client = get_client::<R>(device);\n        client.sync()\n    }\n\n    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {\n        let client = get_client::<R>(device);\n        client.dtype_usage(dtype)\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/bridge/base.rs",
    "content": "use burn_backend::{Shape, backend::DeviceOps};\n\n/// Allows tensors to be transferred between multiple backends.\npub trait MultiBackendBridge: Send + Sync + 'static {\n    /// The type that can be used to point to a tensor of any kind.\n    type TensorHandle;\n    /// Device type used by the backends.\n    type Device: DeviceOps;\n\n    /// Change the backend of the given float tensor.\n    fn change_backend_float(\n        tensor: Self::TensorHandle,\n        shape: Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle;\n\n    /// Change the backend of the given int tensor.\n    fn change_backend_int(\n        tensor: Self::TensorHandle,\n        shape: Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle;\n\n    /// Change the backend of the given bool tensor.\n    fn change_backend_bool(\n        tensor: Self::TensorHandle,\n        shape: Shape,\n        target_device: &Self::Device,\n    ) -> Self::TensorHandle;\n\n    // TODO: change_backend_quantized\n}\n"
  },
  {
    "path": "crates/burn-router/src/bridge/byte.rs",
    "content": "use core::marker::PhantomData;\n\n/// Simply transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData).\npub struct ByteBridge<Backends> {\n    backends: PhantomData<Backends>,\n}\n"
  },
  {
    "path": "crates/burn-router/src/bridge/mod.rs",
    "content": "mod base;\nmod byte;\n\npub use base::*;\npub use byte::*;\n"
  },
  {
    "path": "crates/burn-router/src/channel/base.rs",
    "content": "use alloc::string::String;\nuse burn_backend::{DType, Element, Shape, backend::DeviceOps};\nuse burn_ir::TensorIr;\n\nuse crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};\n\n/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.\npub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;\n\n/// Defines the connection channel and operations for a setup with multiple backend runner clients.\npub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {\n    /// Device type.\n    type Device: DeviceOps;\n    /// A bridge that can transfer tensors between multiple backends.\n    type Bridge: MultiBackendBridge<Device = Self::Device>;\n    /// Client type.\n    type Client: RunnerClient<Device = Self::Device>;\n    /// Float element type.\n    type FloatElem: Element;\n    /// Int element type.\n    type IntElem: Element;\n    /// Bool element type.\n    type BoolElem: Element;\n\n    /// Name of the channel.\n    fn name(device: &Self::Device) -> String;\n\n    /// Initialize a new client for the given device.\n    fn init_client(device: &Self::Device) -> Self::Client;\n\n    /// Get the tensor handle corresponding to the [tensor representation](TensorIr).\n    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;\n\n    /// Create a tensor with the given handle and shape.\n    fn register_tensor(\n        client: &Self::Client,\n        handle: TensorHandle<Self::Bridge>,\n        shape: Shape,\n        dtype: DType,\n    ) -> RouterTensor<Self::Client>;\n\n    /// Change the tensor to a different client backend.\n    fn change_client_backend(\n        tensor: RouterTensor<Self::Client>,\n        device: &Self::Device, // target device\n    ) -> RouterTensor<Self::Client> {\n        // Get tensor handle from current client\n        let original_client = tensor.client.clone();\n        let desc = tensor.into_ir();\n        let mut handle = Self::get_tensor_handle(&desc, &original_client);\n\n        if desc.dtype.is_float() {\n            handle = Self::Bridge::change_backend_float(handle, desc.shape.clone(), device);\n        } else if desc.dtype.is_int() {\n            handle = Self::Bridge::change_backend_int(handle, desc.shape.clone(), device);\n        } else if desc.dtype.is_bool() {\n            handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device);\n        } else {\n            unimplemented!()\n        }\n\n        // Register tensor handle on target client\n        let target_client = get_client::<Self>(device);\n        Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/channel/direct.rs",
    "content": "use core::marker::PhantomData;\n\n/// A local channel with direct connection to the backend runner clients.\npub struct DirectChannel<Backends, Bridge> {\n    backends: PhantomData<Backends>,\n    bridge: PhantomData<Bridge>,\n}\n\nimpl<Backends, Bridge> Clone for DirectChannel<Backends, Bridge> {\n    fn clone(&self) -> Self {\n        Self {\n            backends: self.backends,\n            bridge: self.bridge,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/channel/mod.rs",
    "content": "mod base;\nmod direct;\n\npub use base::*;\npub use direct::*;\n"
  },
  {
    "path": "crates/burn-router/src/client/base.rs",
    "content": "use crate::{RouterTensor, RunnerChannel};\nuse alloc::boxed::Box;\nuse alloc::vec::Vec;\nuse burn_backend::{\n    DType, TensorData,\n    backend::{DeviceId, DeviceOps, ExecutionError},\n};\nuse burn_ir::{OperationIr, TensorId, TensorIr};\nuse burn_std::future::DynFut;\nuse core::ops::DerefMut;\nuse hashbrown::HashMap;\nuse spin::Mutex;\n\n/// Type alias for `<R as RunnerChannel>::Client`.\npub type Client<R> = <R as RunnerChannel>::Client;\npub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();\n\ntype Key = (core::any::TypeId, DeviceId);\n\n/// Define how to interact with the runner.\npub trait RunnerClient: Clone + Send + Sync + Sized {\n    /// Device type.\n    type Device: DeviceOps;\n\n    /// Register a new tensor operation to be executed by the (runner) server.\n    fn register_op(&self, op: OperationIr);\n    /// Register a new tensor operation to be executed by the (runner) server.\n    ///\n    /// Returns the new (uninitialized) output tensor(s) generated by the registered operation.\n    fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {\n        let out = op\n            .outputs()\n            .map(|output| {\n                RouterTensor::new(output.id, output.shape.clone(), output.dtype, self.clone())\n            })\n            .collect();\n        self.register_op(op);\n\n        out\n    }\n    /// Read the values contained by a tensor.\n    fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;\n    /// Sync the runner, ensure that all computations are finished.\n    fn sync(&self) -> Result<(), ExecutionError>;\n    /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId).\n    fn create_empty_handle(&self) -> TensorId;\n    /// Create a new [RouterTensor] from the tensor data.\n    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;\n    /// Get the current device used by all operations handled by this client.\n    fn device(&self) -> Self::Device;\n    /// Seed the runner.\n    fn seed(&self, seed: u64);\n    /// Returns the supported data type usage set\n    fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet;\n}\n\npub(crate) struct RunnerClientLocator {\n    clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,\n}\n\n/// Get the client for the given device\npub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {\n    CLIENTS.client::<R>(device)\n}\n\n/// Initialize a new client for the given device.\n///\n/// If a (global) seed was previously set, the client seed is set.\nfn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {\n    R::init_client(device)\n}\n\nimpl RunnerClientLocator {\n    /// Create a new client locator.\n    pub const fn new() -> Self {\n        Self {\n            clients: Mutex::new(None),\n        }\n    }\n\n    /// Get the runner client for the given device.\n    ///\n    /// If a client isn't already initialized, it is created.\n    pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {\n        let device_id = device.id();\n        let client_id = (core::any::TypeId::of::<R>(), device_id);\n        let mut clients = self.clients.lock();\n\n        if clients.is_none() {\n            let client = new_client::<R>(device);\n            Self::register_inner::<R>(client_id, client, &mut clients);\n        }\n\n        match clients.deref_mut() {\n            Some(clients) => match clients.get(&client_id) {\n                Some(client) => {\n                    let client: &Client<R> = client.downcast_ref().unwrap();\n                    client.clone()\n                }\n                None => {\n                    let client = new_client::<R>(device);\n                    let any = Box::new(client.clone());\n                    clients.insert(client_id, any);\n                    client\n                }\n            },\n            _ => unreachable!(),\n        }\n    }\n\n    fn register_inner<R: RunnerChannel + 'static>(\n        key: Key,\n        client: Client<R>,\n        clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,\n    ) {\n        if clients.is_none() {\n            *clients = Some(HashMap::new());\n        }\n\n        if let Some(clients) = clients {\n            if clients.contains_key(&key) {\n                panic!(\"Client already created for device {key:?}\");\n            }\n\n            clients.insert(key, Box::new(client));\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/client/mod.rs",
    "content": "mod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-router/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![recursion_limit = \"138\"]\n\n//! Burn multi-backend router.\n\nmod backend;\nmod bridge;\nmod channel;\nmod client;\nmod ops;\nmod runner;\nmod tensor;\nmod types;\n\npub use backend::*;\npub use bridge::*;\npub use channel::*;\npub use client::*;\npub use runner::*;\npub use tensor::*;\npub use types::*;\n\n/// A local channel with a simple byte bridge between backends.\n/// It transfers tensors between backends via the underlying [tensor data](burn_backend::TensorData).\npub type DirectByteChannel<Backends> = DirectChannel<Backends, ByteBridge<Backends>>;\n\n/// Router backend.\n///\n/// # Example\n///\n/// ```ignore\n/// type MyBackend = Router<(NdArray, Wgpu)>;\n/// ```\npub type Router<Backends> = BackendRouter<DirectByteChannel<Backends>>;\n\nextern crate alloc;\n\n#[cfg(test)]\n#[allow(unused)]\nmod tests {\n    use crate::BackendRouter;\n    use crate::DirectByteChannel;\n\n    pub type TestBackend1 = burn_ndarray::NdArray<f32, i32>;\n    pub type TestBackend2 = burn_wgpu::Wgpu<f32, i32>;\n    pub type TestBackend = BackendRouter<DirectByteChannel<(TestBackend1, TestBackend2)>>;\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/activation.rs",
    "content": "use crate::{BackendRouter, RunnerChannel};\nuse burn_backend::ops::ActivationOps;\n\nimpl<R: RunnerChannel> ActivationOps<Self> for BackendRouter<R> {}\n"
  },
  {
    "path": "crates/burn-router/src/ops/binary.rs",
    "content": "#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_float_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let rhs = $handles.get_float_tensor::<B>(&$desc.rhs);\n        let output = $ops(lhs, rhs);\n\n        $handles.register_float_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_float_cmp_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let rhs = $handles.get_float_tensor::<B>(&$desc.rhs);\n        let output = $ops(lhs, rhs);\n\n        $handles.register_bool_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_int_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);\n        let rhs = $handles.get_int_tensor::<B>(&$desc.rhs);\n        let output = $ops(lhs, rhs);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_int_cmp_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);\n        let rhs = $handles.get_int_tensor::<B>(&$desc.rhs);\n        let output = $ops(lhs, rhs);\n\n        $handles.register_bool_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! binary_bool_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_bool_tensor::<B>(&$desc.lhs);\n        let rhs = $handles.get_bool_tensor::<B>(&$desc.rhs);\n        let output = $ops(lhs, rhs);\n\n        $handles.register_bool_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/bool_tensor.rs",
    "content": "use alloc::vec::Vec;\nuse burn_backend::backend::ExecutionError;\n\nuse crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};\nuse burn_backend::ops::BoolTensorOps;\nuse burn_backend::tensor::{\n    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,\n};\nuse burn_backend::{Element, Scalar, Shape, Slice, TensorData};\nuse burn_ir::{\n    BaseOperationIr, BinaryOpIr, BoolOperationIr, CastOpIr, CatOpIr, CreationOpIr, FlipOpIr,\n    GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, OperationIr, OperationOutput,\n    PermuteOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr,\n    SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,\n};\n\nimpl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {\n    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc =\n            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Empty(desc)))\n            .output()\n    }\n\n    fn bool_zeros(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc =\n            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Zeros(desc)))\n            .output()\n    }\n\n    fn bool_ones(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc =\n            CreationOpIr::create(shape, R::BoolElem::dtype(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Ones(desc)))\n            .output()\n    }\n\n    async fn bool_into_data(tensor: BoolTensor<Self>) -> Result<TensorData, ExecutionError> {\n        tensor.into_data().await\n    }\n\n    fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {\n        let client = get_client::<R>(device);\n        let out = client.register_tensor_data(data);\n        let desc = InitOperationIr {\n            out: out.to_ir_out(),\n        };\n\n        // Call register op when output is already initialized\n        client.register_op(OperationIr::Init(desc));\n\n        out\n    }\n\n    fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Bool(BoolOperationIr::IntoInt(desc)))\n            .output()\n    }\n\n    fn bool_into_float(tensor: BoolTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Bool(BoolOperationIr::IntoFloat(desc)))\n            .output()\n    }\n\n    fn bool_device(tensor: &BoolTensor<Self>) -> Device<Self> {\n        tensor.client.device()\n    }\n\n    fn bool_to_device(tensor: BoolTensor<Self>, device: &Device<Self>) -> BoolTensor<Self> {\n        if &tensor.client.device() == device {\n            return tensor;\n        }\n        R::change_client_backend(tensor, device)\n    }\n\n    fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Reshape(desc)))\n            .output()\n    }\n\n    fn bool_slice(tensor: BoolTensor<Self>, slices: &[Slice]) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Slice(desc)))\n            .output()\n    }\n\n    fn bool_slice_assign(\n        tensor: BoolTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::SliceAssign(desc)))\n            .output()\n    }\n\n    fn bool_equal(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Equal(desc)))\n            .output()\n    }\n\n    fn bool_not(tensor: BoolTensor<Self>) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Bool(BoolOperationIr::Not(desc)))\n            .output()\n    }\n\n    fn bool_and(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Bool(BoolOperationIr::And(desc)))\n            .output()\n    }\n\n    fn bool_or(lhs: BoolTensor<Self>, rhs: BoolTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Bool(BoolOperationIr::Or(desc)))\n            .output()\n    }\n\n    fn bool_swap_dims(tensor: BoolTensor<Self>, dim1: usize, dim2: usize) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::SwapDims(desc)))\n            .output()\n    }\n\n    fn bool_permute(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Permute(desc)))\n            .output()\n    }\n\n    fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Flip(desc)))\n            .output()\n    }\n\n    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Expand(desc)))\n            .output()\n    }\n\n    fn bool_cat(tensors: Vec<BoolTensor<Self>>, dim: usize) -> BoolTensor<Self> {\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Cat(desc)))\n            .output()\n    }\n\n    fn bool_repeat_dim(tensor: BoolTensor<Self>, dim: usize, times: usize) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::RepeatDim(desc)))\n            .output()\n    }\n\n    fn bool_unfold(\n        tensor: BoolTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)))\n            .output()\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::MaskWhere(desc)))\n            .output()\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::MaskFill(desc)))\n            .output()\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Gather(desc)))\n            .output()\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::Scatter(desc)))\n            .output()\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseBool(BaseOperationIr::EqualElem(desc)))\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/int_tensor.rs",
    "content": "use alloc::vec::Vec;\nuse burn_backend::backend::{Backend, ExecutionError};\n\nuse crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};\nuse burn_backend::tensor::{\n    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,\n};\nuse burn_backend::{\n    Distribution, Element, IntDType, Scalar, Shape, Slice, TensorData, ops::IntTensorOps,\n};\nuse burn_ir::{\n    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, DimOpIr, FlipOpIr,\n    GatherOpIr, InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, MatmulOpIr,\n    NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr, ReduceDimOpIr,\n    ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr, SelectAssignOpIr,\n    SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,\n};\n\nimpl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {\n    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Empty(desc)))\n            .output()\n    }\n\n    async fn int_into_data(tensor: IntTensor<Self>) -> Result<TensorData, ExecutionError> {\n        Ok(tensor\n            .into_data()\n            .await?\n            // Since underlying backends can have different data types, we convert to the current elem\n            .convert::<<Self as Backend>::IntElem>())\n    }\n\n    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {\n        let client = get_client::<R>(device);\n        let out = client.register_tensor_data(data);\n        let desc = InitOperationIr {\n            out: out.to_ir_out(),\n        };\n\n        // Call register op when output is already initialized\n        client.register_op(OperationIr::Init(desc));\n\n        out\n    }\n\n    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {\n        tensor.client.device()\n    }\n\n    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {\n        if &tensor.client.device() == device {\n            return tensor;\n        }\n        R::change_client_backend(tensor, device)\n    }\n\n    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Reshape(desc)))\n            .output()\n    }\n\n    fn int_slice(tensor: IntTensor<Self>, slices: &[Slice]) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Slice(desc)))\n            .output()\n    }\n\n    fn int_slice_assign(\n        tensor: IntTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::SliceAssign(desc)))\n            .output()\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::Matmul(desc)))\n            .output()\n    }\n\n    fn int_mask_where(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::MaskWhere(desc)))\n            .output()\n    }\n\n    fn int_mask_fill(\n        tensor: IntTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::MaskFill(desc)))\n            .output()\n    }\n\n    fn int_gather(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Gather(desc)))\n            .output()\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: IntTensor<Self>,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Scatter(desc)))\n            .output()\n    }\n\n    fn int_select(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Select(desc)))\n            .output()\n    }\n\n    fn int_select_add(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: IntTensor<Self>,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SelectAssignOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::SelectAssign(desc)))\n            .output()\n    }\n\n    fn int_cat(tensors: Vec<IntTensor<Self>>, dim: usize) -> IntTensor<Self> {\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Cat(desc)))\n            .output()\n    }\n\n    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Equal(desc)))\n            .output()\n    }\n\n    fn int_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::EqualElem(desc)))\n            .output()\n    }\n\n    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::Greater(desc),\n            ))\n            .output()\n    }\n\n    fn int_greater_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterElem(desc),\n            ))\n            .output()\n    }\n\n    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterEqual(desc),\n            ))\n            .output()\n    }\n\n    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterEqualElem(desc),\n            ))\n            .output()\n    }\n\n    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::Lower(desc),\n            ))\n            .output()\n    }\n\n    fn int_lower_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerElem(desc),\n            ))\n            .output()\n    }\n\n    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerEqual(desc),\n            ))\n            .output()\n    }\n\n    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerEqualElem(desc),\n            ))\n            .output()\n    }\n\n    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Add(desc),\n            ))\n            .output()\n    }\n\n    fn int_add_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::AddScalar(desc),\n            ))\n            .output()\n    }\n\n    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Sub(desc),\n            ))\n            .output()\n    }\n\n    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::SubScalar(desc),\n            ))\n            .output()\n    }\n\n    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Mul(desc),\n            ))\n            .output()\n    }\n\n    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MulScalar(desc),\n            ))\n            .output()\n    }\n\n    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Div(desc),\n            ))\n            .output()\n    }\n\n    fn int_div_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::DivScalar(desc),\n            ))\n            .output()\n    }\n\n    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Rem(desc),\n            ))\n            .output()\n    }\n\n    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::RemScalar(desc),\n            ))\n            .output()\n    }\n\n    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Zeros(desc)))\n            .output()\n    }\n\n    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Ones(desc)))\n            .output()\n    }\n\n    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Sum(desc),\n            ))\n            .output()\n    }\n\n    fn int_sum_dim(tensor: IntTensor<Self>, axis: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::SumDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Prod(desc),\n            ))\n            .output()\n    }\n\n    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::ProdDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_mean(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Mean(desc),\n            ))\n            .output()\n    }\n\n    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MeanDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::CumSum(desc),\n            ))\n            .output()\n    }\n\n    fn int_cumprod(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::CumProd(desc),\n            ))\n            .output()\n    }\n\n    fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::CumMin(desc),\n            ))\n            .output()\n    }\n\n    fn int_cummax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::CumMax(desc),\n            ))\n            .output()\n    }\n\n    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::ArgMax(desc),\n            ))\n            .output()\n    }\n\n    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::ArgMin(desc),\n            ))\n            .output()\n    }\n\n    fn int_clamp(tensor: IntTensor<Self>, min: Scalar, max: Scalar) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let min = min.into();\n        let max = max.into();\n        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Clamp(desc),\n            ))\n            .output()\n    }\n\n    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Abs(desc),\n            ))\n            .output()\n    }\n\n    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), FloatElem::<Self>::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::IntoFloat(desc)))\n            .output()\n    }\n\n    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::SwapDims(desc)))\n            .output()\n    }\n\n    fn int_max(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Max(desc),\n            ))\n            .output()\n    }\n\n    fn int_max_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MaxDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_max_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        let client = tensor.client.clone();\n        let desc = ReduceDimWithIndicesOpIr::create(\n            tensor.into_ir(),\n            dim,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.tensor.dtype,\n                NumericOperationIr::MaxDimWithIndices(desc),\n            ))\n            .outputs()\n            .into()\n    }\n\n    fn int_max_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MaxAbs(desc),\n            ))\n            .output()\n    }\n\n    fn int_max_abs_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MaxAbsDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_min(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::Min(desc),\n            ))\n            .output()\n    }\n\n    fn int_min_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MinDim(desc),\n            ))\n            .output()\n    }\n\n    fn int_min_dim_with_indices(\n        tensor: IntTensor<Self>,\n        dim: usize,\n    ) -> (IntTensor<Self>, IntTensor<Self>) {\n        let client = tensor.client.clone();\n        let desc = ReduceDimWithIndicesOpIr::create(\n            tensor.into_ir(),\n            dim,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericInt(\n                desc.out.dtype,\n                NumericOperationIr::MinDimWithIndices(desc),\n            ))\n            .outputs()\n            .into()\n    }\n\n    fn int_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> IntTensor<Self> {\n        let client = get_client::<R>(device);\n        let dtype = IntElem::<Self>::dtype();\n        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericInt(\n                dtype,\n                NumericOperationIr::IntRandom(desc),\n            ))\n            .output()\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Permute(desc)))\n            .output()\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Expand(desc)))\n            .output()\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Flip(desc)))\n            .output()\n    }\n\n    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::RepeatDim(desc)))\n            .output()\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseAnd(desc)))\n            .output()\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseOr(desc)))\n            .output()\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseXor(desc)))\n            .output()\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseNot(desc)))\n            .output()\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseAndScalar(desc)))\n            .output()\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseOrScalar(desc)))\n            .output()\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseXorScalar(desc)))\n            .output()\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShift(desc)))\n            .output()\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseLeftShiftScalar(\n                desc,\n            )))\n            .output()\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseRightShift(desc)))\n            .output()\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Int(IntOperationIr::BitwiseRightShiftScalar(\n                desc,\n            )))\n            .output()\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: burn_backend::IntDType) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Cast(desc)))\n            .output()\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)))\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/mod.rs",
    "content": "mod activation;\nmod binary;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\nmod unary;\n"
  },
  {
    "path": "crates/burn-router/src/ops/module.rs",
    "content": "use alloc::boxed::Box;\n\nuse burn_backend::Element;\nuse burn_backend::ops::{\n    AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,\n    DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices,\n    MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,\n};\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntElem, IntTensor};\nuse burn_ir::*;\n\nuse crate::{BackendRouter, RunnerChannel, RunnerClient};\n\nimpl<R: RunnerChannel> ModuleOps<Self> for BackendRouter<R> {\n    fn conv1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv1dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv1d(desc)))\n            .output()\n    }\n\n    fn conv1d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv1dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv1dXBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv1d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<1>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv1dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(\n                ModuleOperationIr::Conv1dWeightBackward(desc),\n            ))\n            .output()\n    }\n\n    fn conv1d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv1dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv1dBiasBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv2dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv2d(desc)))\n            .output()\n    }\n\n    fn conv2d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv2dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv2dXBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv2d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv2dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(\n                ModuleOperationIr::Conv2dWeightBackward(desc),\n            ))\n            .output()\n    }\n\n    fn conv2d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv2dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv2dBiasBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv3dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv3d(desc)))\n            .output()\n    }\n\n    fn conv3d_x_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv3dXBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv3dXBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv3d_weight_backward(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n        options: ConvOptions<3>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv3dWeightBackwardOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(\n                ModuleOperationIr::Conv3dWeightBackward(desc),\n            ))\n            .output()\n    }\n\n    fn conv3d_bias_backward(\n        x: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n        output_grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = Conv3dBiasBackwardOpIr::create(\n            x.into_ir(),\n            bias.into_ir(),\n            output_grad.into_ir(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Conv3dBiasBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv_transpose1d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<1>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = ConvTranspose1dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose1d(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv_transpose2d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<2>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = ConvTranspose2dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose2d(\n                desc,\n            )))\n            .output()\n    }\n\n    fn conv_transpose3d(\n        x: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        bias: Option<FloatTensor<Self>>,\n        options: ConvTransposeOptions<3>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = ConvTranspose3dOpIr::create(\n            x.into_ir(),\n            weight.into_ir(),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::ConvTranspose3d(\n                desc,\n            )))\n            .output()\n    }\n\n    fn avg_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = AvgPool1dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AvgPool1d(desc)))\n            .output()\n    }\n\n    fn avg_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = AvgPool2dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AvgPool2d(desc)))\n            .output()\n    }\n\n    fn avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = AvgPool1dBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AvgPool1dBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = AvgPool2dBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            count_include_pad,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AvgPool2dBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn max_pool1d(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = MaxPool1dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::MaxPool1d(desc)))\n            .output()\n    }\n\n    fn max_pool2d(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = MaxPool2dOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::MaxPool2d(desc)))\n            .output()\n    }\n\n    fn max_pool1d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<Self> {\n        let client = x.client.clone();\n        let desc = MaxPool1dWithIndicesOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        let [out, out_indices] = client\n            .register(OperationIr::Module(\n                ModuleOperationIr::MaxPool1dWithIndices(desc),\n            ))\n            .outputs();\n\n        MaxPool1dWithIndices::new(out, out_indices)\n    }\n\n    fn max_pool2d_with_indices(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        let client = x.client.clone();\n        let desc = MaxPool2dWithIndicesOpIr::create(\n            x.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        let [out, out_indices] = client\n            .register(OperationIr::Module(\n                ModuleOperationIr::MaxPool2dWithIndices(desc),\n            ))\n            .outputs();\n\n        MaxPool2dWithIndices::new(out, out_indices)\n    }\n\n    fn max_pool1d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool1dBackward<Self> {\n        let client = x.client.clone();\n\n        let desc = MaxPool1dWithIndicesBackwardOpIr::create(\n            x.into_ir(),\n            output_grad.into_ir(),\n            indices.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        let out = client\n            .register(OperationIr::Module(\n                ModuleOperationIr::MaxPool1dWithIndicesBackward(desc),\n            ))\n            .output();\n\n        MaxPool1dBackward::new(out)\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: FloatTensor<Self>,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> MaxPool2dBackward<Self> {\n        let client = x.client.clone();\n\n        let desc = MaxPool2dWithIndicesBackwardOpIr::create(\n            x.into_ir(),\n            output_grad.into_ir(),\n            indices.into_ir(),\n            kernel_size,\n            stride,\n            padding,\n            dilation,\n            ceil_mode,\n            || client.create_empty_handle(),\n        );\n\n        let out = client\n            .register(OperationIr::Module(\n                ModuleOperationIr::MaxPool2dWithIndicesBackward(desc),\n            ))\n            .output();\n\n        MaxPool2dBackward::new(out)\n    }\n\n    fn adaptive_avg_pool1d(x: FloatTensor<Self>, output_size: usize) -> FloatTensor<Self> {\n        let client = x.client.clone();\n\n        let desc = AdaptiveAvgPool1dOpIr::create(x.into_ir(), output_size, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool1d(\n                desc,\n            )))\n            .output()\n    }\n\n    fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {\n        let client = x.client.clone();\n\n        let desc = AdaptiveAvgPool2dOpIr::create(x.into_ir(), output_size, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::AdaptiveAvgPool2d(\n                desc,\n            )))\n            .output()\n    }\n\n    fn adaptive_avg_pool1d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n\n        let desc = AdaptiveAvgPool1dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Module(\n                ModuleOperationIr::AdaptiveAvgPool1dBackward(desc),\n            ))\n            .output()\n    }\n\n    fn adaptive_avg_pool2d_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n\n        let desc = AdaptiveAvgPool2dBackwardOpIr::create(x.into_ir(), grad.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Module(\n                ModuleOperationIr::AdaptiveAvgPool2dBackward(desc),\n            ))\n            .output()\n    }\n\n    fn interpolate(\n        x: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = InterpolateOpIr::create(x.into_ir(), output_size, options.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Interpolate(desc)))\n            .output()\n    }\n\n    fn interpolate_backward(\n        x: FloatTensor<Self>,\n        grad: FloatTensor<Self>,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = InterpolateBackwardOpIr::create(\n            x.into_ir(),\n            grad.into_ir(),\n            output_size,\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::InterpolateBackward(\n                desc,\n            )))\n            .output()\n    }\n\n    fn deform_conv2d(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        options: DeformConvOptions<2>,\n    ) -> FloatTensor<Self> {\n        let client = x.client.clone();\n        let desc = DeformConv2dOpIr::create(\n            x.into_ir(),\n            offset.into_ir(),\n            weight.into_ir(),\n            mask.map(|mask| mask.into_ir()),\n            bias.map(|bias| bias.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::DeformableConv2d(\n                Box::new(desc),\n            )))\n            .output()\n    }\n\n    fn deform_conv2d_backward(\n        x: FloatTensor<Self>,\n        offset: FloatTensor<Self>,\n        weight: FloatTensor<Self>,\n        mask: Option<FloatTensor<Self>>,\n        bias: Option<FloatTensor<Self>>,\n        output_grad: FloatTensor<Self>,\n        options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        let client = x.client.clone();\n        let has_bias = bias.is_some();\n        let has_mask = mask.is_some();\n\n        let desc = DeformConv2dBackwardOpIr::create(\n            x.into_ir(),\n            offset.into_ir(),\n            weight.into_ir(),\n            mask.map(|mask| mask.into_ir()),\n            bias.map(|bias| bias.into_ir()),\n            output_grad.into_ir(),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n        let mut outputs = client\n            .register(OperationIr::Module(\n                ModuleOperationIr::DeformableConv2dBackward(Box::new(desc)),\n            ))\n            .into_iter();\n\n        // When the number of outputs is variable, the order is important\n        let input_grad = outputs.next().unwrap();\n        let offset_grad = outputs.next().unwrap();\n        let weight_grad = outputs.next().unwrap();\n        let mask_grad = has_mask.then(|| outputs.next().unwrap());\n        let bias_grad = has_bias.then(|| outputs.next().unwrap());\n\n        DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)\n    }\n\n    fn attention(\n        query: FloatTensor<Self>,\n        key: FloatTensor<Self>,\n        value: FloatTensor<Self>,\n        mask: Option<BoolTensor<Self>>,\n        attn_bias: Option<FloatTensor<Self>>,\n        options: AttentionModuleOptions,\n    ) -> FloatTensor<Self> {\n        let client = query.client.clone();\n        let desc = AttentionOpIr::create(\n            query.into_ir(),\n            key.into_ir(),\n            value.into_ir(),\n            mask.map(|m: BoolTensor<Self>| m.into_ir()),\n            attn_bias.map(|ab| ab.into_ir()),\n            options.into(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::Module(ModuleOperationIr::Attention(desc)))\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, Shape, Slice, TensorData,\n    ops::QTensorOps,\n    quantization::{QuantScheme, QuantizationParametersPrimitive},\n    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},\n};\n\nuse crate::{BackendRouter, RunnerChannel};\n\nimpl<R: RunnerChannel> QTensorOps<Self> for BackendRouter<R> {\n    fn q_from_data(_data: TensorData, _device: &Device<Self>) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn quantize(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n        _qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn quantize_dynamic(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_device(_tensor: &QuantizedTensor<Self>) -> Device<Self> {\n        unimplemented!()\n    }\n\n    fn q_to_device(\n        _tensor: QuantizedTensor<Self>,\n        _device: &Device<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unimplemented!()\n    }\n\n    fn q_swap_dims(\n        _tensor: QuantizedTensor<Self>,\n        _dim1: usize,\n        _dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_gather(\n        _dim: usize,\n        _tensor: QuantizedTensor<Self>,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_select(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/tensor.rs",
    "content": "use alloc::vec::Vec;\nuse burn_backend::Scalar;\nuse burn_backend::backend::{Backend, ExecutionError};\n\nuse crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};\nuse burn_backend::tensor::{\n    BoolTensor, Device, FloatElem, FloatTensor, IndexingUpdateOp, IntElem, IntTensor,\n};\nuse burn_backend::{\n    Distribution, Element, FloatDType, Shape, Slice, TensorData, ops::FloatTensorOps,\n};\nuse burn_ir::{\n    BaseOperationIr, BinaryOpIr, CastOpIr, CatOpIr, ClampOpIr, CreationOpIr, CrossOpIr, DimOpIr,\n    FlipOpIr, FloatOperationIr, FullOpIr, GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr,\n    MatmulOpIr, NumericOperationIr, OperationIr, OperationOutput, PermuteOpIr, RandomOpIr,\n    ReduceDimOpIr, ReduceDimWithIndicesOpIr, ReduceOpIr, RepeatDimOpIr, ScalarOpIr, ScatterOpIr,\n    SelectAssignOpIr, SelectOpIr, ShapeOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,\n    UnfoldOpIr,\n};\n\nimpl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {\n    fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let out = client.register_tensor_data(data);\n        let desc = InitOperationIr {\n            out: out.to_ir_out(),\n        };\n\n        // Call register op when output is already initialized\n        client.register_op(OperationIr::Init(desc));\n\n        out\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &Device<Self>,\n    ) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let dtype = FloatElem::<Self>::dtype();\n        let desc = RandomOpIr::create(shape, dtype, distribution, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(dtype, FloatOperationIr::Random(desc)))\n            .output()\n    }\n\n    fn float_zeros(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Zeros(desc)))\n            .output()\n    }\n\n    fn float_ones(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Ones(desc)))\n            .output()\n    }\n\n    fn float_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &Device<Self>,\n        dtype: FloatDType,\n    ) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let dtype = dtype.into();\n        let value = fill_value.into();\n        let desc = FullOpIr::create(shape, dtype, value, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Full(desc),\n            ))\n            .output()\n    }\n\n    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, ExecutionError> {\n        Ok(tensor\n            .into_data()\n            .await?\n            // Since underlying backends can have different data types, we convert to the current elem\n            .convert::<<Self as Backend>::FloatElem>())\n    }\n\n    fn float_device(tensor: &FloatTensor<Self>) -> Device<Self> {\n        tensor.client.device()\n    }\n\n    fn float_to_device(tensor: FloatTensor<Self>, device: &Device<Self>) -> FloatTensor<Self> {\n        if &tensor.client.device() == device {\n            return tensor;\n        }\n        R::change_client_backend(tensor, device)\n    }\n\n    fn float_into_int(tensor: FloatTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), IntElem::<Self>::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Float(\n                desc.input.dtype,\n                FloatOperationIr::IntoInt(desc),\n            ))\n            .output()\n    }\n\n    fn float_empty(shape: Shape, device: &Device<Self>, dtype: FloatDType) -> FloatTensor<Self> {\n        let client = get_client::<R>(device);\n        let desc = CreationOpIr::create(shape, dtype.into(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Empty(desc)))\n            .output()\n    }\n\n    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Add(desc),\n            ))\n            .output()\n    }\n\n    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::AddScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let min = min.into();\n        let max = max.into();\n        let desc = ClampOpIr::create(tensor.into_ir(), min, max, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Clamp(desc),\n            ))\n            .output()\n    }\n\n    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Sub(desc),\n            ))\n            .output()\n    }\n\n    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::SubScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Mul(desc),\n            ))\n            .output()\n    }\n\n    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::MulScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Div(desc),\n            ))\n            .output()\n    }\n\n    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::DivScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Rem(desc),\n            ))\n            .output()\n    }\n\n    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::RemScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = MatmulOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Matmul(desc),\n            ))\n            .output()\n    }\n\n    fn float_cross(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        dim: usize,\n    ) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = CrossOpIr::create(lhs.into_ir(), rhs.into_ir(), dim, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Cross(desc),\n            ))\n            .output()\n    }\n\n    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SwapDimsOpIr::create(tensor.into_ir(), dim1, dim2, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc)))\n            .output()\n    }\n\n    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::reshape(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Reshape(desc)))\n            .output()\n    }\n\n    fn float_gather(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = GatherOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Gather(desc)))\n            .output()\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: FloatTensor<Self>,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ScatterOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Scatter(desc)))\n            .output()\n    }\n\n    fn float_select(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SelectOpIr::create(tensor.into_ir(), dim, indices.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Select(desc)))\n            .output()\n    }\n\n    fn float_select_add(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        indices: IntTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SelectAssignOpIr::create(\n            tensor.into_ir(),\n            dim,\n            indices.into_ir(),\n            value.into_ir(),\n            IndexingUpdateOp::Add,\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::SelectAssign(desc)))\n            .output()\n    }\n\n    fn float_slice(tensor: FloatTensor<Self>, slices: &[Slice]) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = SliceOpIr::create(tensor.into_ir(), slices.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Slice(desc)))\n            .output()\n    }\n\n    fn float_slice_assign(\n        tensor: FloatTensor<Self>,\n        slices: &[burn_backend::Slice],\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc =\n            SliceAssignOpIr::create(tensor.into_ir(), slices.into(), value.into_ir(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::SliceAssign(desc)))\n            .output()\n    }\n\n    fn float_mask_where(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = MaskWhereOpIr::create(tensor.into_ir(), mask.into_ir(), value.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::MaskWhere(desc)))\n            .output()\n    }\n\n    fn float_mask_fill(\n        tensor: FloatTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let value = value.into();\n        let desc = MaskFillOpIr::create(tensor.into_ir(), mask.into_ir(), value, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::MaskFill(desc)))\n            .output()\n    }\n\n    fn float_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Equal(desc)))\n            .output()\n    }\n\n    fn float_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::EqualElem(desc)))\n            .output()\n    }\n\n    fn float_greater(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::Greater(desc),\n            ))\n            .output()\n    }\n\n    fn float_greater_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterElem(desc),\n            ))\n            .output()\n    }\n\n    fn float_greater_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterEqual(desc),\n            ))\n            .output()\n    }\n\n    fn float_greater_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::GreaterEqualElem(desc),\n            ))\n            .output()\n    }\n\n    fn float_lower(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::Lower(desc),\n            ))\n            .output()\n    }\n\n    fn float_lower_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerElem(desc),\n            ))\n            .output()\n    }\n\n    fn float_lower_equal(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create_comparison(\n            lhs.into_ir(),\n            rhs.into_ir(),\n            R::BoolElem::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerEqual(desc),\n            ))\n            .output()\n    }\n\n    fn float_lower_equal_elem(lhs: FloatTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create_comparison(lhs.into_ir(), rhs, R::BoolElem::dtype(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.lhs.dtype,\n                NumericOperationIr::LowerEqualElem(desc),\n            ))\n            .output()\n    }\n\n    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Sum(desc),\n            ))\n            .output()\n    }\n\n    fn float_sum_dim(tensor: FloatTensor<Self>, axis: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), axis, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::SumDim(desc),\n            ))\n            .output()\n    }\n\n    fn float_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Prod(desc),\n            ))\n            .output()\n    }\n\n    fn float_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::ProdDim(desc),\n            ))\n            .output()\n    }\n\n    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Mean(desc),\n            ))\n            .output()\n    }\n\n    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::MeanDim(desc),\n            ))\n            .output()\n    }\n\n    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::CumSum(desc),\n            ))\n            .output()\n    }\n\n    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::CumProd(desc),\n            ))\n            .output()\n    }\n\n    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::CumMin(desc),\n            ))\n            .output()\n    }\n\n    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = DimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::CumMax(desc),\n            ))\n            .output()\n    }\n\n    fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = UnaryOpIr::create(lhs.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Exp(desc),\n            ))\n            .output()\n    }\n\n    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Log(desc),\n            ))\n            .output()\n    }\n\n    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Log1p(desc),\n            ))\n            .output()\n    }\n\n    fn float_powf_scalar_impl(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let rhs = rhs.into();\n        let desc = ScalarOpIr::create(lhs.into_ir(), rhs, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::PowfScalar(desc),\n            ))\n            .output()\n    }\n\n    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Sqrt(desc),\n            ))\n            .output()\n    }\n\n    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Abs(desc),\n            ))\n            .output()\n    }\n\n    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Cos(desc),\n            ))\n            .output()\n    }\n\n    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Cosh(desc),\n            ))\n            .output()\n    }\n\n    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Sin(desc),\n            ))\n            .output()\n    }\n\n    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Sinh(desc),\n            ))\n            .output()\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Tan(desc),\n            ))\n            .output()\n    }\n\n    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Tanh(desc),\n            ))\n            .output()\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcCos(desc),\n            ))\n            .output()\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcCosh(desc),\n            ))\n            .output()\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcSin(desc),\n            ))\n            .output()\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcSinh(desc),\n            ))\n            .output()\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcTan(desc),\n            ))\n            .output()\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcTanh(desc),\n            ))\n            .output()\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::ArcTan2(desc),\n            ))\n            .output()\n    }\n\n    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Round(desc),\n            ))\n            .output()\n    }\n\n    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Floor(desc),\n            ))\n            .output()\n    }\n\n    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Ceil(desc),\n            ))\n            .output()\n    }\n\n    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Trunc(desc),\n            ))\n            .output()\n    }\n\n    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Recip(desc),\n            ))\n            .output()\n    }\n\n    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnaryOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Erf(desc),\n            ))\n            .output()\n    }\n\n    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {\n        let client = tensors.first().unwrap().client.clone();\n        let tensors = tensors.into_iter().map(|t| t.into_ir()).collect();\n        let desc = CatOpIr::create(tensors, dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Cat(desc)))\n            .output()\n    }\n\n    fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc =\n            ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.input.dtype,\n                NumericOperationIr::ArgMax(desc),\n            ))\n            .output()\n    }\n\n    fn float_repeat_dim(tensor: FloatTensor<Self>, dim: usize, times: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = RepeatDimOpIr::create(tensor.into_ir(), dim, times, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::RepeatDim(desc)))\n            .output()\n    }\n\n    fn float_argmin(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<Self> {\n        let client = tensor.client.clone();\n        let desc =\n            ReduceDimOpIr::create_arg(tensor.into_ir(), dim, IntElem::<Self>::dtype(), || {\n                client.create_empty_handle()\n            });\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.input.dtype,\n                NumericOperationIr::ArgMin(desc),\n            ))\n            .output()\n    }\n\n    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Max(desc),\n            ))\n            .output()\n    }\n\n    fn float_max_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::MaxDim(desc),\n            ))\n            .output()\n    }\n\n    fn float_max_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        let client = tensor.client.clone();\n        let desc = ReduceDimWithIndicesOpIr::create(\n            tensor.into_ir(),\n            dim,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.tensor.dtype,\n                NumericOperationIr::MaxDimWithIndices(desc),\n            ))\n            .outputs()\n            .into()\n    }\n\n    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceOpIr::create(tensor.into_ir(), || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::Min(desc),\n            ))\n            .output()\n    }\n\n    fn float_min_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ReduceDimOpIr::create(tensor.into_ir(), dim, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.out.dtype,\n                NumericOperationIr::MinDim(desc),\n            ))\n            .output()\n    }\n\n    fn float_min_dim_with_indices(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n    ) -> (FloatTensor<Self>, IntTensor<Self>) {\n        let client = tensor.client.clone();\n        let desc = ReduceDimWithIndicesOpIr::create(\n            tensor.into_ir(),\n            dim,\n            IntElem::<Self>::dtype(),\n            || client.create_empty_handle(),\n        );\n\n        client\n            .register(OperationIr::NumericFloat(\n                desc.tensor.dtype,\n                NumericOperationIr::MinDimWithIndices(desc),\n            ))\n            .outputs()\n            .into()\n    }\n\n    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        let client = lhs.client.clone();\n        let desc = BinaryOpIr::create(lhs.into_ir(), rhs.into_ir(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::Float(\n                desc.out.dtype,\n                FloatOperationIr::Powf(desc),\n            ))\n            .output()\n    }\n\n    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = PermuteOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Permute(desc)))\n            .output()\n    }\n\n    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = ShapeOpIr::expand(tensor.into_ir(), shape, || client.create_empty_handle());\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Expand(desc)))\n            .output()\n    }\n\n    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = FlipOpIr::create(tensor.into_ir(), axes.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Flip(desc)))\n            .output()\n    }\n\n    fn float_cast(tensor: FloatTensor<Self>, dtype: burn_backend::FloatDType) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = CastOpIr::create(tensor.into_ir(), dtype.into(), || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Cast(desc)))\n            .output()\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        let client = tensor.client.clone();\n        let desc = UnfoldOpIr::create(tensor.into_ir(), dim, size, step, || {\n            client.create_empty_handle()\n        });\n\n        client\n            .register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)))\n            .output()\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/ops/transaction.rs",
    "content": "use burn_backend::ops::TransactionOps;\n\nuse crate::{BackendRouter, RunnerChannel};\n\nimpl<R: RunnerChannel> TransactionOps<Self> for BackendRouter<R> {}\n"
  },
  {
    "path": "crates/burn-router/src/ops/unary.rs",
    "content": "#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs.into());\n\n        $handles.register_float_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float_dim_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs);\n\n        $handles.register_float_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_float_dim_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let input = $handles.get_float_tensor::<B>(&$desc.input);\n        let output = $ops(input, $desc.axis);\n\n        $handles.register_float_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_float2int_dim_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let input = $handles.get_float_tensor::<B>(&$desc.input);\n        let output = $ops(input, $desc.axis);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! reduce_int_dim_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let input = $handles.get_int_tensor::<B>(&$desc.input);\n        let output = $ops(input, $desc.axis);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float2int_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_float_cmp_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs.into());\n\n        $handles.register_bool_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! unary_float_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_float_tensor::<B>(&$desc.input);\n        let output = $ops(lhs);\n\n        $handles.register_float_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_int_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs.into());\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_int_dim_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! scalar_int_cmp_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.lhs);\n        let output = $ops(lhs, $desc.rhs.into());\n\n        $handles.register_bool_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n\n#[allow(missing_docs)]\n#[macro_export(local_inner_macros)]\nmacro_rules! unary_int_ops {\n    (\n        $handles:expr, $desc:expr, $ops:expr\n    ) => {{\n        let lhs = $handles.get_int_tensor::<B>(&$desc.input);\n        let output = $ops(lhs);\n\n        $handles.register_int_tensor::<B>(&$desc.out.id, output);\n    }};\n}\n"
  },
  {
    "path": "crates/burn-router/src/runner.rs",
    "content": "use core::sync::atomic::{AtomicU64, Ordering};\n\nuse super::{RouterTensor, RunnerClient};\nuse crate::{\n    binary_bool_ops, binary_float_cmp_ops, binary_float_ops, binary_int_cmp_ops, binary_int_ops,\n    reduce_float_dim_ops, reduce_float2int_dim_ops, reduce_int_dim_ops, scalar_float_cmp_ops,\n    scalar_float_ops, scalar_int_cmp_ops, scalar_int_ops, unary_float_ops, unary_int_ops,\n};\nuse alloc::boxed::Box;\nuse alloc::sync::Arc;\nuse burn_backend::{Backend, DType, ExecutionError, Shape, TensorData, tensor::IndexingUpdateOp};\nuse burn_ir::{\n    BackendIr, BaseOperationIr, BoolOperationIr, FloatOperationIr, HandleContainer, IntOperationIr,\n    ModuleOperationIr, NumericOperationIr, OperationIr, TensorId, TensorIr, TensorStatus,\n};\nuse burn_std::{future::DynFut, stub::Mutex};\n\n/// A runner's context contains a [handle container](HandleContainer) to manage\n/// (i.e., fetch and update) existing tensors.\npub struct RunnerContext<B: BackendIr> {\n    /// Handle container to retrieve tensors based on their intermediate representation.\n    handles: HandleContainer<B::Handle>,\n}\n\nstatic COUNTER: AtomicU64 = AtomicU64::new(0);\n\nimpl<B: BackendIr> RunnerContext<B> {\n    /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId).\n    fn create_empty_handle(&mut self) -> TensorId {\n        let value = COUNTER.fetch_add(1, Ordering::Relaxed);\n        TensorId::new(value)\n    }\n}\n\n/// A runner is responsible for executing tensor operations for a given [intermediate backend](BackendIr).\n#[derive(Clone)]\npub struct Runner<B: BackendIr> {\n    // Mutex for the mutable handles\n    context: Arc<Mutex<RunnerContext<B>>>,\n    device: B::Device,\n}\n\nimpl<B: BackendIr> core::fmt::Debug for Runner<B> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.debug_struct(\"Runner\")\n            .field(\"device\", &self.device)\n            .finish()\n    }\n}\n\nimpl<B: BackendIr> Runner<B> {\n    /// Create a new runner.\n    pub fn new(device: B::Device) -> Self {\n        Self {\n            context: Arc::new(Mutex::new(RunnerContext {\n                handles: HandleContainer::new(),\n            })),\n            device,\n        }\n    }\n\n    /// Get the tensor handle for the given [tensor representation](TensorIr).\n    pub fn get_tensor_handle(&self, tensor: &TensorIr) -> B::Handle {\n        let handles = &mut self.context.lock().unwrap().handles;\n        handles.get_tensor_handle(tensor).handle\n    }\n\n    /// Create a tensor with the given handle and shape.\n    pub fn register_tensor<C: RunnerClient>(\n        &self,\n        handle: B::Handle,\n        shape: Shape,\n        dtype: DType,\n        client: C,\n    ) -> RouterTensor<C> {\n        let mut ctx = self.context.lock().unwrap();\n        let id = ctx.create_empty_handle();\n\n        ctx.handles.register_handle(id, handle);\n        core::mem::drop(ctx);\n\n        RouterTensor::new(id, shape, dtype, client)\n    }\n\n    /// Register a tensor from its data and id.\n    pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) {\n        let mut ctx = self.context.lock().unwrap();\n        let dtype = data.dtype;\n\n        if dtype.is_float() {\n            let tensor = B::float_from_data(data, &self.device);\n            ctx.handles.register_float_tensor::<B>(&id, tensor)\n        } else if dtype.is_int() {\n            let tensor = B::int_from_data(data, &self.device);\n            ctx.handles.register_int_tensor::<B>(&id, tensor)\n        } else if dtype.is_bool() {\n            let tensor = B::bool_from_data(data, &self.device);\n            ctx.handles.register_bool_tensor::<B>(&id, tensor)\n        } else if let DType::QFloat(_) = dtype {\n            todo!();\n        }\n\n        core::mem::drop(ctx);\n    }\n\n    /// Register a tensor and returns its intermediate representation.\n    pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorIr {\n        let mut ctx = self.context.lock().unwrap();\n        let id = ctx.create_empty_handle();\n        let shape = data.shape.clone();\n        let dtype = data.dtype;\n\n        if dtype.is_float() {\n            let tensor = B::float_from_data(data, &self.device);\n            ctx.handles.register_float_tensor::<B>(&id, tensor)\n        } else if dtype.is_int() {\n            let tensor = B::int_from_data(data, &self.device);\n            ctx.handles.register_int_tensor::<B>(&id, tensor)\n        } else if dtype.is_bool() {\n            let tensor = B::bool_from_data(data, &self.device);\n            ctx.handles.register_bool_tensor::<B>(&id, tensor)\n        } else if let DType::QFloat(_) = dtype {\n            todo!();\n        }\n\n        core::mem::drop(ctx);\n\n        TensorIr {\n            id,\n            shape,\n            status: TensorStatus::ReadWrite,\n            dtype,\n        }\n    }\n}\n\n// This is a Remote Runner\nimpl<B: BackendIr> RunnerClient for Runner<B> {\n    type Device = B::Device;\n\n    /// Execute a tensor operation.\n    fn register_op(&self, op: OperationIr) {\n        // Remove unused tensor handles\n        let mut ctx = self.context.lock().unwrap();\n\n        let handles = &mut ctx.handles;\n        match &op {\n            // For every op: get the input(s), execute the operation and register the output(s)\n            OperationIr::BaseFloat(op) => match op {\n                BaseOperationIr::Reshape(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_reshape(tensor, desc.out.shape.clone());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SwapDims(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_swap_dims(tensor, desc.dim1, desc.dim2);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Permute(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_permute(tensor, &desc.axes);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Flip(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_flip(tensor, &desc.axes);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Expand(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_expand(tensor, desc.out.shape.clone());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Unfold(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_unfold(tensor, desc.dim, desc.size, desc.step);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Slice(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n\n                    let output = B::float_slice(tensor, &desc.ranges);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SliceAssign(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let value = handles.get_float_tensor::<B>(&desc.value);\n\n                    let output = B::float_slice_assign(tensor, &desc.ranges, value);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Gather(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::float_gather(desc.dim, tensor, indices);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Scatter(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_float_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::float_scatter_add(desc.dim, tensor, indices, value)\n                        }\n                    };\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Select(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::float_select(tensor, desc.dim, indices);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SelectAssign(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_float_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::float_select_add(tensor, desc.dim, indices, value)\n                        }\n                    };\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskWhere(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n                    let value = handles.get_float_tensor::<B>(&desc.value);\n\n                    let output = B::float_mask_where(tensor, mask, value);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskFill(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n\n                    let output = B::float_mask_fill(tensor, mask, desc.value.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Equal(desc) => {\n                    binary_float_cmp_ops!(handles, desc, B::float_equal)\n                }\n                BaseOperationIr::EqualElem(desc) => {\n                    scalar_float_cmp_ops!(handles, desc, B::float_equal_elem)\n                }\n                BaseOperationIr::RepeatDim(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n\n                    let output = B::float_repeat_dim(tensor, desc.dim, desc.times);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cat(desc) => {\n                    let tensors = desc\n                        .tensors\n                        .iter()\n                        .map(|tensor| handles.get_float_tensor::<B>(tensor))\n                        .collect();\n\n                    let output = B::float_cat(tensors, desc.dim);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cast(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n                    let output = B::float_cast(tensor, desc.out.dtype.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Empty(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::float_empty(shape, &self.device, desc.out.dtype.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Ones(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::float_ones(shape, &self.device, desc.out.dtype.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Zeros(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::float_zeros(shape, &self.device, desc.out.dtype.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::BaseInt(op) => match op {\n                BaseOperationIr::Reshape(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_reshape(tensor, desc.out.shape.clone());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SwapDims(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_swap_dims(tensor, desc.dim1, desc.dim2);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Permute(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_permute(tensor, &desc.axes);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Flip(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_flip(tensor, &desc.axes);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Expand(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_expand(tensor, desc.out.shape.clone());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Unfold(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_unfold(tensor, desc.dim, desc.size, desc.step);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Slice(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n\n                    let output = B::int_slice(tensor, &desc.ranges);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SliceAssign(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let value = handles.get_int_tensor::<B>(&desc.value);\n\n                    let output = B::int_slice_assign(tensor, &desc.ranges, value);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Gather(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::int_gather(desc.dim, tensor, indices);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Scatter(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_int_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::int_scatter_add(desc.dim, tensor, indices, value)\n                        }\n                    };\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Select(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::int_select(tensor, desc.dim, indices);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SelectAssign(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_int_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::int_select_add(tensor, desc.dim, indices, value)\n                        }\n                    };\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskWhere(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n                    let value = handles.get_int_tensor::<B>(&desc.value);\n\n                    let output = B::int_mask_where(tensor, mask, value);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskFill(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n\n                    let output = B::int_mask_fill(tensor, mask, desc.value.into());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Equal(desc) => {\n                    binary_int_cmp_ops!(handles, desc, B::int_equal)\n                }\n                BaseOperationIr::EqualElem(desc) => {\n                    scalar_int_cmp_ops!(handles, desc, B::int_equal_elem)\n                }\n                BaseOperationIr::RepeatDim(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n\n                    let output = B::int_repeat_dim(tensor, desc.dim, desc.times);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cat(desc) => {\n                    let tensors = desc\n                        .tensors\n                        .iter()\n                        .map(|tensor| handles.get_int_tensor::<B>(tensor))\n                        .collect();\n\n                    let output = B::int_cat(tensors, desc.dim);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cast(_) => unreachable!(),\n                BaseOperationIr::Empty(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::int_empty(shape, &self.device, desc.out.dtype.into());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Ones(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::int_ones(shape, &self.device, desc.out.dtype.into());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Zeros(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::int_zeros(shape, &self.device, desc.out.dtype.into());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::BaseBool(op) => match op {\n                BaseOperationIr::Reshape(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_reshape(tensor, desc.out.shape.clone());\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SwapDims(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_swap_dims(tensor, desc.dim1, desc.dim2);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Permute(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_permute(tensor, &desc.axes);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Flip(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_flip(tensor, &desc.axes);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Expand(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_expand(tensor, desc.out.shape.clone());\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Unfold(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_unfold(tensor, desc.dim, desc.size, desc.step);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Slice(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n\n                    let output = B::bool_slice(tensor, &desc.ranges);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SliceAssign(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let value = handles.get_bool_tensor::<B>(&desc.value);\n\n                    let output = B::bool_slice_assign(tensor, &desc.ranges, value);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Gather(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::bool_gather(desc.dim, tensor, indices);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Scatter(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_bool_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::bool_scatter_or(desc.dim, tensor, indices, value)\n                        }\n                    };\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Select(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::bool_select(tensor, desc.dim, indices);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::SelectAssign(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let value = handles.get_bool_tensor::<B>(&desc.value);\n\n                    let output = match desc.update {\n                        IndexingUpdateOp::Add => {\n                            B::bool_select_or(tensor, desc.dim, indices, value)\n                        }\n                    };\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskWhere(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n                    let value = handles.get_bool_tensor::<B>(&desc.value);\n\n                    let output = B::bool_mask_where(tensor, mask, value);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::MaskFill(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n                    let mask = handles.get_bool_tensor::<B>(&desc.mask);\n\n                    let output = B::bool_mask_fill(tensor, mask, desc.value.into());\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Equal(desc) => {\n                    let lhs = handles.get_bool_tensor::<B>(&desc.lhs);\n                    let rhs = handles.get_bool_tensor::<B>(&desc.rhs);\n\n                    let output = B::bool_equal(lhs, rhs);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::EqualElem(desc) => {\n                    let lhs = handles.get_bool_tensor::<B>(&desc.lhs);\n\n                    let output = B::bool_equal_elem(lhs, desc.rhs.into());\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::RepeatDim(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.tensor);\n\n                    let output = B::bool_repeat_dim(tensor, desc.dim, desc.times);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cat(desc) => {\n                    let tensors = desc\n                        .tensors\n                        .iter()\n                        .map(|tensor| handles.get_bool_tensor::<B>(tensor))\n                        .collect();\n\n                    let output = B::bool_cat(tensors, desc.dim);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Cast(_) => unreachable!(),\n                BaseOperationIr::Empty(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::bool_empty(shape, &self.device);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Zeros(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::bool_zeros(shape, &self.device);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BaseOperationIr::Ones(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::bool_ones(shape, &self.device);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::NumericFloat(_dtype, op) => match op {\n                NumericOperationIr::Add(desc) => {\n                    binary_float_ops!(handles, desc, B::float_add)\n                }\n                NumericOperationIr::AddScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_add_scalar)\n                }\n                NumericOperationIr::Sub(desc) => {\n                    binary_float_ops!(handles, desc, B::float_sub)\n                }\n                NumericOperationIr::SubScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_sub_scalar)\n                }\n                NumericOperationIr::Div(desc) => {\n                    binary_float_ops!(handles, desc, B::float_div)\n                }\n                NumericOperationIr::DivScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_div_scalar)\n                }\n                NumericOperationIr::Rem(desc) => {\n                    binary_float_ops!(handles, desc, B::float_remainder)\n                }\n                NumericOperationIr::RemScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_remainder_scalar)\n                }\n                NumericOperationIr::Mul(desc) => {\n                    binary_float_ops!(handles, desc, B::float_mul)\n                }\n                NumericOperationIr::MulScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_mul_scalar)\n                }\n                NumericOperationIr::Abs(desc) => {\n                    unary_float_ops!(handles, desc, B::float_abs)\n                }\n                NumericOperationIr::Full(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::float_full(\n                        shape,\n                        desc.value.into(),\n                        &self.device,\n                        desc.out.dtype.into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::MeanDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_mean_dim)\n                }\n                NumericOperationIr::Mean(desc) => {\n                    unary_float_ops!(handles, desc, B::float_mean)\n                }\n                NumericOperationIr::Sum(desc) => {\n                    unary_float_ops!(handles, desc, B::float_sum)\n                }\n                NumericOperationIr::SumDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_sum_dim)\n                }\n                NumericOperationIr::Prod(desc) => {\n                    unary_float_ops!(handles, desc, B::float_prod)\n                }\n                NumericOperationIr::ProdDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_prod_dim)\n                }\n                NumericOperationIr::Greater(desc) => {\n                    binary_float_cmp_ops!(handles, desc, B::float_greater)\n                }\n                NumericOperationIr::GreaterElem(desc) => {\n                    scalar_float_cmp_ops!(handles, desc, B::float_greater_elem)\n                }\n                NumericOperationIr::GreaterEqual(desc) => {\n                    binary_float_cmp_ops!(handles, desc, B::float_greater_equal)\n                }\n                NumericOperationIr::GreaterEqualElem(desc) => {\n                    scalar_float_cmp_ops!(handles, desc, B::float_greater_equal_elem)\n                }\n                NumericOperationIr::Lower(desc) => {\n                    binary_float_cmp_ops!(handles, desc, B::float_lower)\n                }\n                NumericOperationIr::LowerElem(desc) => {\n                    scalar_float_cmp_ops!(handles, desc, B::float_lower_elem)\n                }\n                NumericOperationIr::LowerEqual(desc) => {\n                    binary_float_cmp_ops!(handles, desc, B::float_lower_equal)\n                }\n                NumericOperationIr::LowerEqualElem(desc) => {\n                    scalar_float_cmp_ops!(handles, desc, B::float_lower_equal_elem)\n                }\n                NumericOperationIr::ArgMax(desc) => {\n                    reduce_float2int_dim_ops!(handles, desc, B::float_argmax)\n                }\n                NumericOperationIr::ArgMin(desc) => {\n                    reduce_float2int_dim_ops!(handles, desc, B::float_argmin)\n                }\n                NumericOperationIr::Max(desc) => {\n                    unary_float_ops!(handles, desc, B::float_max)\n                }\n                NumericOperationIr::MaxDimWithIndices(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n\n                    let (output, output_idx) = B::float_max_dim_with_indices(tensor, desc.dim);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output_idx);\n                }\n                NumericOperationIr::MinDimWithIndices(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n\n                    let (output, output_idx) = B::float_min_dim_with_indices(tensor, desc.dim);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output_idx);\n                }\n                NumericOperationIr::Min(desc) => {\n                    unary_float_ops!(handles, desc, B::float_min)\n                }\n                NumericOperationIr::MaxDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_max_dim)\n                }\n                NumericOperationIr::MinDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_min_dim)\n                }\n                NumericOperationIr::MaxAbs(desc) => {\n                    unary_float_ops!(handles, desc, B::float_max_abs)\n                }\n                NumericOperationIr::MaxAbsDim(desc) => {\n                    reduce_float_dim_ops!(handles, desc, B::float_max_abs_dim)\n                }\n                NumericOperationIr::Clamp(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n\n                    let output = B::float_clamp(tensor, desc.min.into(), desc.max.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::IntRandom(_) => unreachable!(),\n                NumericOperationIr::Powi(desc) => {\n                    let lhs = handles.get_float_tensor::<B>(&desc.lhs);\n                    let rhs = handles.get_int_tensor::<B>(&desc.rhs);\n                    let output = (B::float_powi)(lhs, rhs);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumSum(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n                    let output = B::float_cumsum(tensor, desc.axis);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumProd(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n                    let output = B::float_cumprod(tensor, desc.axis);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumMin(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n                    let output = B::float_cummin(tensor, desc.axis);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumMax(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n                    let output = B::float_cummax(tensor, desc.axis);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::NumericInt(_dtype, op) => match op {\n                NumericOperationIr::Add(desc) => {\n                    binary_int_ops!(handles, desc, B::int_add)\n                }\n                NumericOperationIr::AddScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::int_add_scalar)\n                }\n                NumericOperationIr::Sub(desc) => {\n                    binary_int_ops!(handles, desc, B::int_sub)\n                }\n                NumericOperationIr::SubScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::int_sub_scalar)\n                }\n                NumericOperationIr::Div(desc) => {\n                    binary_int_ops!(handles, desc, B::int_div)\n                }\n                NumericOperationIr::DivScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::int_div_scalar)\n                }\n                NumericOperationIr::Rem(desc) => {\n                    binary_int_ops!(handles, desc, B::int_remainder)\n                }\n                NumericOperationIr::RemScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::int_remainder_scalar)\n                }\n                NumericOperationIr::Mul(desc) => {\n                    binary_int_ops!(handles, desc, B::int_mul)\n                }\n                NumericOperationIr::MulScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::int_mul_scalar)\n                }\n                NumericOperationIr::Abs(desc) => {\n                    unary_int_ops!(handles, desc, B::int_abs)\n                }\n                NumericOperationIr::Full(desc) => {\n                    let shape = desc.out.shape.clone();\n                    let output = B::int_full(\n                        shape,\n                        desc.value.into(),\n                        &self.device,\n                        desc.out.dtype.into(),\n                    );\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::MeanDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_mean_dim)\n                }\n                NumericOperationIr::Mean(desc) => {\n                    unary_int_ops!(handles, desc, B::int_mean)\n                }\n                NumericOperationIr::Sum(desc) => {\n                    unary_int_ops!(handles, desc, B::int_sum)\n                }\n                NumericOperationIr::SumDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_sum_dim)\n                }\n                NumericOperationIr::Prod(desc) => {\n                    unary_int_ops!(handles, desc, B::int_prod)\n                }\n                NumericOperationIr::ProdDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_prod_dim)\n                }\n                NumericOperationIr::Greater(desc) => {\n                    binary_int_cmp_ops!(handles, desc, B::int_greater)\n                }\n                NumericOperationIr::GreaterElem(desc) => {\n                    scalar_int_cmp_ops!(handles, desc, B::int_greater_elem)\n                }\n                NumericOperationIr::GreaterEqual(desc) => {\n                    binary_int_cmp_ops!(handles, desc, B::int_greater_equal)\n                }\n                NumericOperationIr::GreaterEqualElem(desc) => {\n                    scalar_int_cmp_ops!(handles, desc, B::int_greater_equal_elem)\n                }\n                NumericOperationIr::Lower(desc) => {\n                    binary_int_cmp_ops!(handles, desc, B::int_lower)\n                }\n                NumericOperationIr::LowerElem(desc) => {\n                    scalar_int_cmp_ops!(handles, desc, B::int_lower_elem)\n                }\n                NumericOperationIr::LowerEqual(desc) => {\n                    binary_int_cmp_ops!(handles, desc, B::int_lower_equal)\n                }\n                NumericOperationIr::LowerEqualElem(desc) => {\n                    scalar_int_cmp_ops!(handles, desc, B::int_lower_equal_elem)\n                }\n                NumericOperationIr::ArgMax(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_argmax)\n                }\n                NumericOperationIr::ArgMin(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_argmin)\n                }\n                NumericOperationIr::Max(desc) => {\n                    unary_int_ops!(handles, desc, B::int_max)\n                }\n                NumericOperationIr::MaxDimWithIndices(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n\n                    let (output, output_idx) = B::int_max_dim_with_indices(tensor, desc.dim);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output_idx);\n                }\n                NumericOperationIr::MinDimWithIndices(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n\n                    let (output, output_idx) = B::int_min_dim_with_indices(tensor, desc.dim);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output_idx);\n                }\n                NumericOperationIr::Min(desc) => {\n                    unary_int_ops!(handles, desc, B::int_min)\n                }\n                NumericOperationIr::MaxDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_max_dim)\n                }\n                NumericOperationIr::MinDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_min_dim)\n                }\n                NumericOperationIr::MaxAbs(desc) => {\n                    unary_int_ops!(handles, desc, B::int_max_abs)\n                }\n                NumericOperationIr::MaxAbsDim(desc) => {\n                    reduce_int_dim_ops!(handles, desc, B::int_max_abs_dim)\n                }\n                NumericOperationIr::Clamp(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.tensor);\n\n                    let output = B::int_clamp(tensor, desc.min.into(), desc.max.into());\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::IntRandom(desc) => {\n                    let shape = desc.out.shape.clone();\n\n                    let output = B::int_random(shape, desc.distribution, &self.device);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::Powi(desc) => {\n                    let lhs = handles.get_int_tensor::<B>(&desc.lhs);\n                    let rhs = handles.get_int_tensor::<B>(&desc.rhs);\n\n                    let output = B::int_powi(lhs, rhs);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumSum(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n                    let output = B::int_cumsum(tensor, desc.axis);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumProd(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n                    let output = B::int_cumprod(tensor, desc.axis);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumMin(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n                    let output = B::int_cummin(tensor, desc.axis);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                NumericOperationIr::CumMax(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n                    let output = B::int_cummax(tensor, desc.axis);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::Bool(op) => match op {\n                BoolOperationIr::IntoFloat(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_into_float(tensor);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                BoolOperationIr::IntoInt(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_into_int(tensor);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                BoolOperationIr::Not(desc) => {\n                    let tensor = handles.get_bool_tensor::<B>(&desc.input);\n\n                    let output = B::bool_not(tensor);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                BoolOperationIr::And(desc) => {\n                    binary_bool_ops!(handles, desc, B::bool_and)\n                }\n                BoolOperationIr::Or(desc) => {\n                    binary_bool_ops!(handles, desc, B::bool_or)\n                }\n            },\n            OperationIr::Int(op) => match op {\n                IntOperationIr::IntoFloat(desc) => {\n                    let tensor = handles.get_int_tensor::<B>(&desc.input);\n\n                    let output = B::int_into_float(tensor);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                IntOperationIr::Matmul(desc) => {\n                    binary_int_ops!(handles, desc, B::int_matmul)\n                }\n                IntOperationIr::BitwiseAnd(desc) => {\n                    binary_int_ops!(handles, desc, B::bitwise_and)\n                }\n                IntOperationIr::BitwiseAndScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::bitwise_and_scalar)\n                }\n                IntOperationIr::BitwiseOr(desc) => {\n                    binary_int_ops!(handles, desc, B::bitwise_or)\n                }\n                IntOperationIr::BitwiseOrScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::bitwise_or_scalar)\n                }\n                IntOperationIr::BitwiseXor(desc) => {\n                    binary_int_ops!(handles, desc, B::bitwise_xor)\n                }\n                IntOperationIr::BitwiseXorScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::bitwise_xor_scalar)\n                }\n                IntOperationIr::BitwiseNot(desc) => {\n                    unary_int_ops!(handles, desc, B::bitwise_not)\n                }\n                IntOperationIr::BitwiseLeftShift(desc) => {\n                    binary_int_ops!(handles, desc, B::bitwise_left_shift)\n                }\n                IntOperationIr::BitwiseRightShift(desc) => {\n                    binary_int_ops!(handles, desc, B::bitwise_right_shift)\n                }\n                IntOperationIr::BitwiseLeftShiftScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar)\n                }\n                IntOperationIr::BitwiseRightShiftScalar(desc) => {\n                    scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar)\n                }\n            },\n            OperationIr::Float(_dtype, op) => match op {\n                FloatOperationIr::Exp(desc) => {\n                    unary_float_ops!(handles, desc, B::float_exp)\n                }\n                FloatOperationIr::Powf(desc) => {\n                    binary_float_ops!(handles, desc, B::float_powf)\n                }\n                FloatOperationIr::Log(desc) => {\n                    unary_float_ops!(handles, desc, B::float_log)\n                }\n                FloatOperationIr::Log1p(desc) => {\n                    unary_float_ops!(handles, desc, B::float_log1p)\n                }\n                FloatOperationIr::Erf(desc) => {\n                    unary_float_ops!(handles, desc, B::float_erf)\n                }\n                FloatOperationIr::PowfScalar(desc) => {\n                    scalar_float_ops!(handles, desc, B::float_powf_scalar)\n                }\n                FloatOperationIr::Sqrt(desc) => {\n                    unary_float_ops!(handles, desc, B::float_sqrt)\n                }\n                FloatOperationIr::Cos(desc) => {\n                    unary_float_ops!(handles, desc, B::float_cos)\n                }\n                FloatOperationIr::Sin(desc) => {\n                    unary_float_ops!(handles, desc, B::float_sin)\n                }\n                FloatOperationIr::Tanh(desc) => {\n                    unary_float_ops!(handles, desc, B::float_tanh)\n                }\n                FloatOperationIr::Tan(desc) => unary_float_ops!(handles, desc, B::float_tan),\n                FloatOperationIr::Cosh(desc) => unary_float_ops!(handles, desc, B::float_cosh),\n                FloatOperationIr::Sinh(desc) => unary_float_ops!(handles, desc, B::float_sinh),\n                FloatOperationIr::ArcCos(desc) => unary_float_ops!(handles, desc, B::float_acos),\n                FloatOperationIr::ArcCosh(desc) => unary_float_ops!(handles, desc, B::float_acosh),\n                FloatOperationIr::ArcSin(desc) => unary_float_ops!(handles, desc, B::float_asin),\n                FloatOperationIr::ArcSinh(desc) => unary_float_ops!(handles, desc, B::float_asinh),\n                FloatOperationIr::ArcTan(desc) => unary_float_ops!(handles, desc, B::float_atan),\n                FloatOperationIr::ArcTanh(desc) => unary_float_ops!(handles, desc, B::float_atanh),\n                FloatOperationIr::ArcTan2(desc) => binary_float_ops!(handles, desc, B::float_atan2),\n                FloatOperationIr::Round(desc) => {\n                    unary_float_ops!(handles, desc, B::float_round)\n                }\n                FloatOperationIr::Floor(desc) => {\n                    unary_float_ops!(handles, desc, B::float_floor)\n                }\n                FloatOperationIr::Ceil(desc) => {\n                    unary_float_ops!(handles, desc, B::float_ceil)\n                }\n                FloatOperationIr::Trunc(desc) => {\n                    unary_float_ops!(handles, desc, B::float_trunc)\n                }\n                FloatOperationIr::IntoInt(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_into_int(tensor);\n                    handles.register_int_tensor::<B>(&desc.out.id, output);\n                }\n                FloatOperationIr::Matmul(desc) => {\n                    binary_float_ops!(handles, desc, B::float_matmul)\n                }\n                FloatOperationIr::Cross(desc) => {\n                    let lhs = handles.get_float_tensor::<B>(&desc.lhs);\n                    let rhs = handles.get_float_tensor::<B>(&desc.rhs);\n                    let output = B::float_cross(lhs, rhs, desc.dim);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                FloatOperationIr::Random(desc) => {\n                    let shape = desc.out.shape.clone();\n\n                    let output = B::float_random(shape, desc.distribution, &self.device);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                FloatOperationIr::Recip(desc) => {\n                    unary_float_ops!(handles, desc, B::float_recip)\n                }\n                FloatOperationIr::Quantize(_) => todo!(),\n                FloatOperationIr::Dequantize(_) => todo!(),\n                FloatOperationIr::IsNan(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_is_nan(tensor);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                FloatOperationIr::IsInf(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.input);\n\n                    let output = B::float_is_inf(tensor);\n                    handles.register_bool_tensor::<B>(&desc.out.id, output);\n                }\n                FloatOperationIr::GridSample2d(desc) => {\n                    let tensor = handles.get_float_tensor::<B>(&desc.tensor);\n                    let grid = handles.get_float_tensor::<B>(&desc.grid);\n\n                    let output = B::float_grid_sample_2d(tensor, grid, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::Module(op) => match op {\n                ModuleOperationIr::Embedding(desc) => {\n                    let weights = handles.get_float_tensor::<B>(&desc.weights);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::embedding(weights, indices);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::EmbeddingBackward(desc) => {\n                    let weights = handles.get_float_tensor::<B>(&desc.weights);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.out_grad);\n\n                    let output = B::embedding_backward(weights, output_grad, indices);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv1d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv1d(x, weight, bias, desc.clone().options.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv1dXBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output =\n                        B::conv1d_x_backward(x, weight, output_grad, desc.clone().options.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv1dWeightBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv1d_weight_backward(\n                        x,\n                        weight,\n                        output_grad,\n                        desc.clone().options.into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv1dBiasBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let bias = handles.get_float_tensor::<B>(&desc.bias);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv1d_bias_backward(x, bias, output_grad);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv2d(x, weight, bias, desc.clone().options.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv2dXBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output =\n                        B::conv2d_x_backward(x, weight, output_grad, desc.clone().options.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv2dWeightBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv2d_weight_backward(\n                        x,\n                        weight,\n                        output_grad,\n                        desc.clone().options.into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv2dBiasBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let bias = handles.get_float_tensor::<B>(&desc.bias);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv2d_bias_backward(x, bias, output_grad);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv3d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv3d(x, weight, bias, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv3dXBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output =\n                        B::conv3d_x_backward(x, weight, output_grad, desc.clone().options.into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv3dWeightBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv3d_weight_backward(\n                        x,\n                        weight,\n                        output_grad,\n                        desc.clone().options.into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Conv3dBiasBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let bias = handles.get_float_tensor::<B>(&desc.bias);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.output_grad);\n\n                    let output = B::conv3d_bias_backward(x, bias, output_grad);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::DeformableConv2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let offset = handles.get_float_tensor::<B>(&desc.offset);\n                    let mask = desc\n                        .mask\n                        .as_ref()\n                        .map(|mask| handles.get_float_tensor::<B>(mask));\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::deform_conv2d(\n                        x,\n                        offset,\n                        weight,\n                        mask,\n                        bias,\n                        desc.options.clone().into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::DeformableConv2dBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let offset = handles.get_float_tensor::<B>(&desc.offset);\n                    let mask = desc\n                        .mask\n                        .as_ref()\n                        .map(|mask| handles.get_float_tensor::<B>(mask));\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n                    let output_grad = handles.get_float_tensor::<B>(&desc.out_grad);\n\n                    let output = B::deform_conv2d_backward(\n                        x,\n                        offset,\n                        weight,\n                        mask,\n                        bias,\n                        output_grad,\n                        desc.options.clone().into(),\n                    );\n\n                    handles.register_float_tensor::<B>(&desc.input_grad.id, output.x_grad);\n                    handles.register_float_tensor::<B>(&desc.offset_grad.id, output.offset_grad);\n                    handles.register_float_tensor::<B>(&desc.weight_grad.id, output.weight_grad);\n                    if let Some((mask_grad, field)) = output.mask_grad.zip(desc.mask_grad.as_ref())\n                    {\n                        handles.register_float_tensor::<B>(&field.id, mask_grad);\n                    }\n                    if let Some((bias_grad, field)) = output.bias_grad.zip(desc.bias_grad.as_ref())\n                    {\n                        handles.register_float_tensor::<B>(&field.id, bias_grad);\n                    }\n                }\n                ModuleOperationIr::ConvTranspose1d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv_transpose1d(x, weight, bias, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::ConvTranspose2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv_transpose2d(x, weight, bias, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::ConvTranspose3d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let weight = handles.get_float_tensor::<B>(&desc.weight);\n                    let bias = desc\n                        .bias\n                        .as_ref()\n                        .map(|bias| handles.get_float_tensor::<B>(bias));\n\n                    let output = B::conv_transpose3d(x, weight, bias, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AvgPool1d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::avg_pool1d(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.count_include_pad,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AvgPool2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::avg_pool2d(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.count_include_pad,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AvgPool1dBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let grad = handles.get_float_tensor::<B>(&desc.grad);\n\n                    let output = B::avg_pool1d_backward(\n                        x,\n                        grad,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.count_include_pad,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AvgPool2dBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let grad = handles.get_float_tensor::<B>(&desc.grad);\n\n                    let output = B::avg_pool2d_backward(\n                        x,\n                        grad,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.count_include_pad,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AdaptiveAvgPool1d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::adaptive_avg_pool1d(x, desc.output_size);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AdaptiveAvgPool2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::adaptive_avg_pool2d(x, desc.output_size);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AdaptiveAvgPool1dBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let grad = handles.get_float_tensor::<B>(&desc.grad);\n\n                    let output = B::adaptive_avg_pool1d_backward(x, grad);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::AdaptiveAvgPool2dBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let grad = handles.get_float_tensor::<B>(&desc.grad);\n\n                    let output = B::adaptive_avg_pool2d_backward(x, grad);\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::MaxPool1d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::max_pool1d(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::MaxPool1dWithIndices(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::max_pool1d_with_indices(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output.output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output.indices);\n                }\n                ModuleOperationIr::MaxPool1dWithIndicesBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.grad);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::max_pool1d_with_indices_backward(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                        output_grad,\n                        indices,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output.x_grad);\n                }\n                ModuleOperationIr::MaxPool2d(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::max_pool2d(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::MaxPool2dWithIndices(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::max_pool2d_with_indices(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output.output);\n                    handles.register_int_tensor::<B>(&desc.out_indices.id, output.indices);\n                }\n                ModuleOperationIr::MaxPool2dWithIndicesBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let output_grad = handles.get_float_tensor::<B>(&desc.grad);\n                    let indices = handles.get_int_tensor::<B>(&desc.indices);\n\n                    let output = B::max_pool2d_with_indices_backward(\n                        x,\n                        desc.kernel_size,\n                        desc.stride,\n                        desc.padding,\n                        desc.dilation,\n                        desc.ceil_mode,\n                        output_grad,\n                        indices,\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output.x_grad);\n                }\n                ModuleOperationIr::Interpolate(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n\n                    let output = B::interpolate(x, desc.output_size, desc.options.clone().into());\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::InterpolateBackward(desc) => {\n                    let x = handles.get_float_tensor::<B>(&desc.x);\n                    let grad = handles.get_float_tensor::<B>(&desc.grad);\n\n                    let output = B::interpolate_backward(\n                        x,\n                        grad,\n                        desc.output_size,\n                        desc.options.clone().into(),\n                    );\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n                ModuleOperationIr::Attention(desc) => {\n                    let query = handles.get_float_tensor::<B>(&desc.query);\n                    let key = handles.get_float_tensor::<B>(&desc.key);\n                    let value = handles.get_float_tensor::<B>(&desc.value);\n                    let mask = desc.mask.as_ref().map(|m| handles.get_bool_tensor::<B>(m));\n                    let attn_bias = desc\n                        .attn_bias\n                        .as_ref()\n                        .map(|ab| handles.get_float_tensor::<B>(ab));\n\n                    let output = B::attention(\n                        query,\n                        key,\n                        value,\n                        mask,\n                        attn_bias,\n                        desc.options.clone().into(),\n                    );\n\n                    handles.register_float_tensor::<B>(&desc.out.id, output);\n                }\n            },\n            OperationIr::Custom(_) => {\n                panic!(\"Can't execute custom operation here\")\n            }\n            OperationIr::Init(_) => {\n                // Nothing to do.\n            }\n            OperationIr::Drop(repr) => {\n                handles.remove_handle(repr.id);\n            }\n        }\n    }\n\n    fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {\n        let mut ctx = self.context.lock().unwrap();\n\n        enum Output<B: Backend> {\n            Float(B::FloatTensorPrimitive),\n            Int(B::IntTensorPrimitive),\n            Bool(B::BoolTensorPrimitive),\n        }\n\n        let tensor = if tensor.dtype.is_float() {\n            let tensor = ctx.handles.get_float_tensor::<B>(&tensor);\n            Output::<B>::Float(tensor)\n        } else if tensor.dtype.is_int() {\n            let tensor = ctx.handles.get_int_tensor::<B>(&tensor);\n            Output::Int(tensor)\n        } else if tensor.dtype.is_bool() {\n            let tensor = ctx.handles.get_bool_tensor::<B>(&tensor);\n            Output::Bool(tensor)\n        } else if let DType::QFloat(_) = tensor.dtype {\n            todo!()\n        } else {\n            unimplemented!()\n        };\n\n        match tensor {\n            Output::Float(val) => Box::pin(B::float_into_data(val)),\n            Output::Int(val) => Box::pin(B::int_into_data(val)),\n            Output::Bool(val) => Box::pin(B::bool_into_data(val)),\n        }\n    }\n\n    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {\n        let desc = self.register_tensor_data_desc(data);\n        RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())\n    }\n\n    fn device(&self) -> Self::Device {\n        self.device.clone()\n    }\n\n    fn sync(&self) -> Result<(), ExecutionError> {\n        B::sync(&self.device)\n    }\n\n    fn seed(&self, seed: u64) {\n        B::seed(&self.device, seed)\n    }\n\n    fn create_empty_handle(&self) -> TensorId {\n        let mut ctx = self.context.lock().unwrap();\n        ctx.create_empty_handle()\n    }\n\n    fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet {\n        B::dtype_usage(&self.device, dtype)\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/tensor.rs",
    "content": "use core::sync::atomic::{AtomicU32, Ordering};\n\nuse alloc::format;\nuse alloc::{sync::Arc, vec::Vec};\n\nuse super::RunnerClient;\nuse burn_backend::{DType, Shape, TensorData, TensorMetadata, backend::ExecutionError};\nuse burn_ir::{TensorId, TensorIr, TensorStatus};\n\n/// Tensor primitive for the [router backend](crate::BackendRouter).\npub struct RouterTensor<C: RunnerClient> {\n    pub(crate) id: TensorId,\n    pub(crate) shape: Shape,\n    pub(crate) dtype: DType,\n    /// The client that has this tensor\n    pub client: C,\n    pub(crate) count: Arc<AtomicU32>,\n}\n\nimpl<C: RunnerClient> TensorMetadata for RouterTensor<C> {\n    fn dtype(&self) -> DType {\n        self.dtype\n    }\n\n    fn shape(&self) -> Shape {\n        self.shape.clone()\n    }\n\n    fn rank(&self) -> usize {\n        self.shape.num_dims()\n    }\n}\n\nimpl<C: RunnerClient> RouterTensor<C> {\n    /// Create a new router tensor.\n    pub fn new(id: TensorId, shape: Shape, dtype: DType, client: C) -> Self {\n        Self {\n            id,\n            shape,\n            dtype,\n            client,\n            count: Arc::new(AtomicU32::new(1)),\n        }\n    }\n\n    pub(crate) async fn into_data(self) -> Result<TensorData, ExecutionError> {\n        self.client.clone().read_tensor_async(self.into_ir()).await\n    }\n\n    /// Get the ir for this tensor\n    pub fn into_ir(mut self) -> TensorIr {\n        let count = self.count.load(Ordering::Relaxed);\n        let status = self.status(count);\n        let mut shape_out = Shape::from(Vec::<usize>::new());\n        core::mem::swap(&mut self.shape, &mut shape_out);\n\n        if let TensorStatus::ReadWrite = status {\n            // Avoids an unwanted drop on the same thread.\n            //\n            // Since `drop` is called after `into_ir`, we must not register a drop if the tensor\n            // was consumed with a `ReadWrite` status.\n            self.count.fetch_add(1, Ordering::Relaxed);\n        }\n\n        TensorIr {\n            status,\n            shape: shape_out,\n            id: self.id,\n            dtype: self.dtype,\n        }\n    }\n\n    pub(crate) fn to_ir_out(&self) -> TensorIr {\n        TensorIr {\n            status: TensorStatus::NotInit,\n            shape: self.shape.clone(),\n            id: self.id,\n            dtype: self.dtype,\n        }\n    }\n\n    pub(crate) fn status(&self, count: u32) -> TensorStatus {\n        if count <= 1 {\n            TensorStatus::ReadWrite\n        } else {\n            TensorStatus::ReadOnly\n        }\n    }\n}\n\nimpl<C: RunnerClient> core::fmt::Debug for RouterTensor<C> {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(\n            format!(\n                \"{{ id: {:?}, shape: {:?}, dtype: {:?}, device: {:?} }}\",\n                self.id,\n                self.shape,\n                self.dtype,\n                self.client.device().clone(),\n            )\n            .as_str(),\n        )\n    }\n}\n\nimpl<C: RunnerClient> Clone for RouterTensor<C> {\n    fn clone(&self) -> Self {\n        self.count.fetch_add(1, Ordering::Relaxed);\n\n        Self {\n            id: self.id,\n            shape: self.shape.clone(),\n            client: self.client.clone(),\n            dtype: self.dtype,\n            count: self.count.clone(),\n        }\n    }\n}\n\nimpl<C: RunnerClient> Drop for RouterTensor<C> {\n    fn drop(&mut self) {\n        let count = self.count.fetch_sub(1, Ordering::Relaxed);\n\n        match self.status(count) {\n            TensorStatus::ReadWrite => {\n                let id = self.id;\n                let mut shape = Shape::from(Vec::<usize>::new());\n                core::mem::swap(&mut shape, &mut self.shape);\n\n                let ir = TensorIr {\n                    id,\n                    shape,\n                    status: TensorStatus::ReadWrite,\n                    dtype: self.dtype,\n                };\n                self.client.register_op(burn_ir::OperationIr::Drop(ir));\n            }\n            TensorStatus::ReadOnly => {}\n            TensorStatus::NotInit => {}\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-router/src/types.rs",
    "content": "use alloc::format;\nuse alloc::string::String;\nuse burn_backend::{\n    DType, Shape, TensorData,\n    backend::{Backend, DeviceId, DeviceOps, ExecutionError},\n    try_read_sync,\n};\nuse burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};\nuse burn_std::future::DynFut;\n\nuse crate::{\n    ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,\n    RunnerClient,\n};\n\n/// Implement multi backend types, with enums having one variant per backend.\nmacro_rules! impl_multi_backend_types {\n    // Match the default backend and at least one other backend, with rest being optional\n    ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {\n        /// Module containing the essential types for multi-backend operations.\n        ///\n        /// - `Handle`: the type used to point to a tensor (defined for all backends).\n        /// - `MultiRunnerClient`: a client for multiple runners (each responsible to execute tensor operations on a given backend).\n        /// - `DirectChannel`: a local channel with direct connection to the backend runner clients.\n        /// - `ByteBridge`: a simple multi-backend bridge that transfers tensors via the underlying [tensor data](burn_backend::TensorData).\n        ///\n        /// Each enum type is defined with backend identifiers as variant names (e.g., `B1` and `B2` for dual backends).\n        pub mod $module_name {\n            use super::*;\n\n            /// The type that can be used to point to a tensor of any kind.\n            /// Each backend has its own variant.\n            pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {\n                #[allow(missing_docs)]\n                $DefaultBackend($DefaultBackend::Handle),\n                $(\n                    #[allow(missing_docs)]\n                    $OtherBackend($OtherBackend::Handle),\n                )+\n            }\n\n            /// The device type used by a backend.\n            /// Each backend has its own variant.\n            #[derive(Clone, Debug)]\n            pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {\n                #[allow(missing_docs)]\n                $DefaultBackend($DefaultBackend::Device),\n                $(\n                    #[allow(missing_docs)]\n                    $OtherBackend($OtherBackend::Device),\n                )+\n            }\n            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {\n                fn eq(&self, other: &Self) -> bool {\n                    match (self, other) {\n                        (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,\n                        $(\n                            (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,\n                        )+\n                        _ => false,\n                    }\n                }\n            }\n\n            // Default implementation always returns the first backend's device\n            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {\n                fn default() -> Self {\n                    Self::$DefaultBackend($DefaultBackend::Device::default())\n                }\n            }\n\n            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {\n                fn from_id(_device_id: DeviceId) -> Self {\n                    // TODO: Should be fix with the new router backend.\n                    Default::default()\n                }\n\n                fn to_id(&self) -> DeviceId {\n                    match self {\n                        Self::$DefaultBackend(device) => device.id(),\n                        $(\n                            Self::$OtherBackend(device) => device.id(),\n                        )+\n                    }\n                }\n\n                fn device_count(_type_id: u16) -> usize {\n                    1\n                }\n            }\n\n            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}\n\n            /// A local client with multiple runners (each responsible to execute tensor operations on a given backend).\n            #[derive(Clone)]\n            pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {\n                #[allow(missing_docs)]\n                $DefaultBackend(Runner<$DefaultBackend>),\n                $(\n                    #[allow(missing_docs)]\n                    $OtherBackend(Runner<$OtherBackend>),\n                )+\n            }\n\n            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>\n            {\n               type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;\n\n                fn register_op(&self, op: OperationIr) {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.register_op(op),\n                        $(\n                            Self::$OtherBackend(runner) => runner.register_op(op),\n                        )+\n                    }\n                }\n\n                fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),\n                        $(\n                            Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),\n                        )+\n                    }\n                }\n\n                fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {\n                    match self {\n                        Self::$DefaultBackend(runner) => {\n                            let desc = runner.register_tensor_data_desc(data);\n                            RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())\n                        }\n                        $(\n                            Self::$OtherBackend(runner) => {\n                                let desc = runner.register_tensor_data_desc(data);\n                                RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())\n                            }\n                        )+\n                    }\n                }\n\n                fn device(&self) -> Self::Device {\n                    match self {\n                        Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),\n                        $(\n                            Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),\n                        )+\n                    }\n                }\n\n                fn sync(&self) -> Result<(), ExecutionError> {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.sync(),\n                        $(\n                            Self::$OtherBackend(runner) => runner.sync(),\n                        )+\n                    }\n                }\n\n                fn seed(&self, seed: u64) {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.seed(seed),\n                        $(\n                            Self::$OtherBackend(runner) => runner.seed(seed),\n                        )+\n                    }\n                }\n\n                fn create_empty_handle(&self) -> TensorId {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.create_empty_handle(),\n                        $(\n                            Self::$OtherBackend(runner) => runner.create_empty_handle(),\n                        )+\n                    }\n                }\n\n                fn dtype_usage(&self, dtype: burn_std::DType) -> burn_backend::DTypeUsageSet {\n                    match self {\n                        Self::$DefaultBackend(runner) => runner.dtype_usage(dtype),\n                        $(\n                            Self::$OtherBackend(runner) => runner.dtype_usage(dtype),\n                        )+\n                    }\n                }\n            }\n\n            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>\n            where\n                Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,\n            {\n                type Device = Br::Device;\n\n                type Bridge = Br;\n\n                type FloatElem = $DefaultBackend::FloatElem;\n                type IntElem = $DefaultBackend::IntElem;\n                type BoolElem = $DefaultBackend::BoolElem;\n\n                type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;\n\n                fn init_client(device: &Self::Device) -> Self::Client {\n                    match device {\n                        MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),\n                        $(\n                            MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),\n                        )+\n                    }\n                }\n\n                fn get_tensor_handle(\n                    tensor: &TensorIr,\n                    client: &Self::Client,\n                ) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {\n                    match client {\n                        MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),\n                        $(\n                            MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),\n                        )+\n                    }\n                }\n\n                fn register_tensor(\n                    client: &Self::Client,\n                    handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,\n                    shape: Shape,\n                    dtype: DType,\n                ) -> RouterTensor<Self::Client> {\n                    match client {\n                        MultiRunnerClient::$DefaultBackend(runner) => match handle {\n                            Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),\n                            _ => unreachable!(\"Can't register tensor handle for another backend.\"),\n                        },\n                        $(\n                            MultiRunnerClient::$OtherBackend(runner) =>  match handle {\n                                Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),\n                                _ => unreachable!(\"Can't register tensor handle for another backend.\"),\n                            },\n                        )+\n                    }\n                }\n\n                fn name(_device: &Self::Device) -> String {\n                    let mut name = format!(\"{}\", $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default()));\n                    $(\n                        name.push_str(&format!(\", {}\", $OtherBackend::name(&<$OtherBackend::Device as Default>::default())));\n                    )+\n                    format!(\"direct<({})>\", name)\n                }\n            }\n\n            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {\n                type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;\n                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;\n\n                fn change_backend_float(\n                    tensor: Self::TensorHandle,\n                    shape: Shape,\n                    target_device: &Self::Device,\n                ) -> Self::TensorHandle {\n                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)\n                }\n\n                fn change_backend_int(\n                    tensor: Self::TensorHandle,\n                    shape: Shape,\n                    target_device: &Self::Device,\n                ) -> Self::TensorHandle {\n                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)\n                }\n\n                fn change_backend_bool(\n                    tensor: Self::TensorHandle,\n                    shape: Shape,\n                    target_device: &Self::Device,\n                ) -> Self::TensorHandle {\n                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)\n                }\n\n            }\n        }\n    };\n}\n\nmacro_rules! bridge {\n    ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{\n        // Bridge for the same backend\n        let tensor = $Backend::float_tensor(TensorHandle {\n            handle: $handle,\n            shape: $shape,\n        });\n        let tensor = $Backend::float_to_device(tensor, $device);\n        let handle = $Backend::float_tensor_handle(tensor);\n        Handle::$Backend(handle)\n    }};\n    ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{\n        // Byte bridge between two backends\n        let tensor = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape });\n        let data = try_read_sync($BackendA::float_into_data(tensor)).unwrap().expect(\n            \"Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM.\"\n        );\n        let tensor = $BackendB::float_from_data(data, $device);\n        let handle = $BackendB::float_tensor_handle(tensor);\n        Handle::$BackendB(handle)\n    }};\n}\n\nmacro_rules! multi_backend_match {\n    ($shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => {\n        multi_backend_match! (\n            @step\n            $shape,\n            ($handle, $device);\n            {\n                (Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($DefaultBackend, handle, device, $shape),\n                $(\n                    (Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($DefaultBackend, $OtherBackend, handle, device, $shape),\n                    (Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($OtherBackend, $DefaultBackend, handle, device, $shape),\n                    (Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($OtherBackend, handle, device, $shape),\n                )+\n            };\n            $($OtherBackend),+\n        )\n    };\n\n    (@step\n        $shape:expr,\n        $pats:tt;\n        { $($arms:tt)* };\n        $BackendA:ident,\n        $($OtherBackend:ident),+\n    ) => {\n        multi_backend_match! (\n            @step\n            $shape,\n            $pats;\n            {\n                $($arms)*\n                $(\n                    (Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($BackendA, $OtherBackend, handle, device, $shape),\n                    (Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($OtherBackend, $BackendA, handle, device, $shape),\n                )*\n            };\n            $($OtherBackend),*\n        )\n    };\n\n    (@step\n        $shape:expr,\n        ($handle:expr, $device:expr);\n        { $($arms:tt)* };\n        $($BackendA:ident)?\n    ) => {\n        match ($handle, $device) {\n            $($arms)*\n        }\n    };\n}\n\n// Implement multi-backend types and byte bridge for up to 4 backends\nimpl_multi_backend_types!(duo, B1, B2);\nimpl_multi_backend_types!(trio, B1, B2, B3);\nimpl_multi_backend_types!(quad, B1, B2, B3, B4);\n\n#[cfg(not(target_os = \"windows\"))] // cannot find a wgpu adapter on windows CI\n#[cfg(test)]\nmod tests {\n    use burn_tensor::{Tensor, backend::Backend};\n\n    use super::*;\n    use crate::tests::{TestBackend, TestBackend1, TestBackend2};\n\n    #[test]\n    fn should_support_dual_byte_bridge() {\n        let device1 = duo::MultiDevice::B1(<TestBackend1 as Backend>::Device::default());\n        let device2 = duo::MultiDevice::B2(<TestBackend2 as Backend>::Device::default());\n        let tensor1 = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device1);\n        let tensor2 = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device2);\n\n        let tensor1_2 = tensor1.clone().to_device(&device2);\n        tensor1.into_data().assert_eq(&tensor1_2.into_data(), true);\n\n        let tensor2_1 = tensor2.clone().to_device(&device1);\n        tensor2.into_data().assert_eq(&tensor2_1.into_data(), true);\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/Cargo.toml",
    "content": "[package]\nauthors = [\"Dilshod Tadjibaev (@antimora)\"]\ncategories = []\ndescription = \"Core types and utilities shared across the Burn ecosystem.\"\ndocumentation = \"https://docs.rs/burn-std\"\nedition.workspace = true\nkeywords = []\nlicense.workspace = true\nname = \"burn-std\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-std\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ncubecl = [\"dep:cubecl\"]\ndefault = [\"std\", \"cubecl-common/default\"]\ndoc = [\"default\"]\nstd = [\"cubecl-common/std\", \"num-traits/std\"]\ntracing = [\"cubecl?/tracing\", \"cubecl-common/tracing\"]\n\nnetwork = [\"dep:indicatif\", \"dep:reqwest\", \"dep:tokio\"]\n\n[dependencies]\nbytemuck = { workspace = true, features = [\"extern_crate_alloc\"] }\nhalf = { workspace = true, features = [\"bytemuck\"] }\nnum-traits = { workspace = true }\nserde = { workspace = true }\nsmallvec = { workspace = true, features = [\"serde\"] }\n\ncubecl = { workspace = true, optional = true, default-features = false }\ncubecl-common = { workspace = true, default-features = false, features = [\n    \"serde\",\n    \"shared-bytes\",\n] }\ncubecl-zspace = { workspace = true, default-features = false }\n# Enable extra-platforms for portable-atomic support on targets without native atomics (e.g., thumbv6m)\n# This is needed because cubecl-common's shared-bytes feature pulls in bytes\nbytes = { workspace = true }\n\n# Network downloader\nindicatif = { workspace = true, optional = true }\nreqwest = { workspace = true, optional = true }\ntokio = { workspace = true, optional = true }\n\n[dev-dependencies]\ndashmap = { workspace = true }\n\n# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)\n[target.'cfg(not(target_has_atomic = \"ptr\"))'.dependencies]\nbytes = { workspace = true, features = [\"extra-platforms\"] }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-std/README.md",
    "content": "# Burn Standard Library\n\n`burn-std` provides the core types and utilities shared across the Burn ecosystem.  \nIt includes foundational definitions for shapes, indexing, and data types.\n\nThis crate supports both `std` and `no_std` environments and must compile with\n`cargo build --no-default-features` as well.\n"
  },
  {
    "path": "crates/burn-std/src/id.rs",
    "content": "//! # Unique Identifiers\nuse crate::rand::gen_random;\n\n/// Simple ID generator.\npub struct IdGenerator {}\n\nimpl IdGenerator {\n    /// Generates a new ID.\n    pub fn generate() -> u64 {\n        // Generate a random u64 (18,446,744,073,709,551,615 combinations)\n        let random_bytes: [u8; 8] = gen_random();\n        u64::from_le_bytes(random_bytes)\n    }\n}\n\npub use cubecl_common::stream_id::StreamId;\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    use alloc::collections::BTreeSet;\n\n    #[cfg(feature = \"std\")]\n    use dashmap::DashSet; //Concurrent HashMap\n    #[cfg(feature = \"std\")]\n    use std::{sync::Arc, thread};\n\n    #[test]\n    fn uniqueness_test() {\n        const IDS_CNT: usize = 10_000;\n\n        let mut set: BTreeSet<u64> = BTreeSet::new();\n\n        for _i in 0..IDS_CNT {\n            assert!(set.insert(IdGenerator::generate()));\n        }\n\n        assert_eq!(set.len(), IDS_CNT);\n    }\n\n    #[cfg(feature = \"std\")]\n    #[test]\n    fn thread_safety_test() {\n        const NUM_THREADS: usize = 10;\n        const NUM_REPEATS: usize = 1_000;\n        const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;\n\n        let set: Arc<DashSet<u64>> = Arc::new(DashSet::new());\n\n        let mut handles = vec![];\n\n        for _ in 0..NUM_THREADS {\n            let set = set.clone();\n\n            let handle = thread::spawn(move || {\n                for _i in 0..NUM_REPEATS {\n                    assert!(set.insert(IdGenerator::generate()));\n                }\n            });\n            handles.push(handle);\n        }\n\n        for handle in handles {\n            handle.join().unwrap();\n        }\n        assert_eq!(set.len(), EXPECTED_TOTAL_IDS);\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! # Burn Standard Library\n//!\n//! This library contains core types and utilities shared across Burn, including shapes, indexing,\n//! and data types.\n\nextern crate alloc;\n\n/// Id module contains types for unique identifiers.\npub mod id;\n\n/// Tensor utilities.\npub mod tensor;\npub use tensor::*;\n\n/// Common Errors.\npub use cubecl_zspace::errors::{self, *};\n\n/// Network utilities.\n#[cfg(feature = \"network\")]\npub mod network;\n\n// Re-exported types\npub use cubecl_common::bytes::*;\npub use cubecl_common::device_handle::DeviceHandle;\npub use cubecl_common::*;\npub use half::{bf16, f16};\n\n#[cfg(feature = \"cubecl\")]\npub use cubecl::flex32;\n\n#[cfg(feature = \"cubecl\")]\nmod cube {\n    use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};\n    use cubecl_common::quant::scheme::QuantScheme;\n\n    use crate::tensor::DType;\n    use crate::tensor::quantization::{QuantStore, QuantValue};\n\n    impl From<DType> for cubecl::ir::ElemType {\n        fn from(dtype: DType) -> Self {\n            match dtype {\n                DType::F64 => ElemType::Float(FloatKind::F64),\n                DType::F32 => ElemType::Float(FloatKind::F32),\n                DType::Flex32 => ElemType::Float(FloatKind::Flex32),\n                DType::F16 => ElemType::Float(FloatKind::F16),\n                DType::BF16 => ElemType::Float(FloatKind::BF16),\n                DType::I64 => ElemType::Int(IntKind::I64),\n                DType::I32 => ElemType::Int(IntKind::I32),\n                DType::I16 => ElemType::Int(IntKind::I16),\n                DType::I8 => ElemType::Int(IntKind::I8),\n                DType::U64 => ElemType::UInt(UIntKind::U64),\n                DType::U32 => ElemType::UInt(UIntKind::U32),\n                DType::U16 => ElemType::UInt(UIntKind::U16),\n                DType::U8 => ElemType::UInt(UIntKind::U8),\n                DType::Bool(store) => match store {\n                    crate::BoolStore::Native => ElemType::Bool,\n                    crate::BoolStore::U8 => ElemType::UInt(UIntKind::U8),\n                    crate::BoolStore::U32 => ElemType::UInt(UIntKind::U32),\n                },\n                DType::QFloat(scheme) => match scheme.store {\n                    QuantStore::Native => match scheme.value {\n                        QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),\n                        QuantValue::E4M3 => Self::Float(FloatKind::E4M3),\n                        QuantValue::E5M2 => Self::Float(FloatKind::E5M2),\n                        QuantValue::Q4F\n                        | QuantValue::Q4S\n                        | QuantValue::Q2F\n                        | QuantValue::Q2S\n                        | QuantValue::E2M1 => {\n                            panic!(\"Can't store native sub-byte values\")\n                        }\n                    },\n                    QuantStore::PackedU32(_) => Self::UInt(UIntKind::U32),\n                    QuantStore::PackedNative(_) => match scheme.value {\n                        QuantValue::E2M1 => panic!(\"Can't store native sub-byte values\"),\n                        other => panic!(\"{other:?} doesn't support native packing\"),\n                    },\n                },\n            }\n        }\n    }\n\n    impl From<DType> for cubecl::ir::StorageType {\n        fn from(dtype: DType) -> cubecl::ir::StorageType {\n            match dtype {\n                DType::QFloat(QuantScheme {\n                    store: QuantStore::PackedNative(_),\n                    value: QuantValue::E2M1,\n                    ..\n                }) => StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2),\n                _ => {\n                    let elem: ElemType = dtype.into();\n                    elem.into()\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/network.rs",
    "content": "//! # Common Network Utilities\n\n/// Network download utilities.\npub mod downloader {\n    use indicatif::{ProgressBar, ProgressState, ProgressStyle};\n    use reqwest::Client;\n    use std::io::Write;\n\n    /// Download the file at the specified url.\n    /// File download progress is reported with the help of a [progress bar](indicatif).\n    ///\n    /// # Arguments\n    ///\n    /// * `url` - The file URL to download.\n    /// * `message` - The message to display on the progress bar during download.\n    ///\n    /// # Returns\n    ///\n    /// A vector of bytes containing the downloaded file data.\n    #[tokio::main(flavor = \"current_thread\")]\n    pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {\n        // Get file from web\n        let mut response = Client::new().get(url).send().await.unwrap();\n        let total_size = response.content_length().unwrap();\n\n        // Pretty progress bar\n        let pb = ProgressBar::new(total_size);\n        let msg = message.to_owned();\n        pb.set_style(\n            ProgressStyle::with_template(\n                \"{msg}\\n    {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})\",\n            )\n            .unwrap()\n            .with_key(\n                \"eta\",\n                |state: &ProgressState, w: &mut dyn std::fmt::Write| {\n                    write!(w, \"{:.1}s\", state.eta().as_secs_f64()).unwrap()\n                },\n            )\n            .progress_chars(\"▬  \"),\n        );\n        pb.set_message(msg.clone());\n\n        // Read stream into bytes\n        let mut downloaded: u64 = 0;\n        let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);\n        while let Some(chunk) = response.chunk().await.unwrap() {\n            let num_bytes = bytes.write(&chunk).unwrap();\n            let new = std::cmp::min(downloaded + (num_bytes as u64), total_size);\n            downloaded = new;\n            pb.set_position(new);\n        }\n        pb.finish_with_message(msg);\n\n        bytes\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/tensor/dtype.rs",
    "content": "//! Tensor data type.\n\nuse serde::{Deserialize, Serialize};\n\nuse crate::tensor::quantization::{QuantScheme, QuantStore, QuantValue};\nuse crate::{bf16, f16};\n\n#[allow(missing_docs)]\n#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]\npub enum DType {\n    F64,\n    F32,\n    Flex32,\n    F16,\n    BF16,\n    I64,\n    I32,\n    I16,\n    I8,\n    U64,\n    U32,\n    U16,\n    U8,\n    Bool(BoolStore),\n    QFloat(QuantScheme),\n}\n\n#[cfg(feature = \"cubecl\")]\nimpl From<cubecl::ir::ElemType> for DType {\n    fn from(value: cubecl::ir::ElemType) -> Self {\n        match value {\n            cubecl::ir::ElemType::Float(float_kind) => match float_kind {\n                cubecl::ir::FloatKind::F16 => DType::F16,\n                cubecl::ir::FloatKind::BF16 => DType::BF16,\n                cubecl::ir::FloatKind::Flex32 => DType::Flex32,\n                cubecl::ir::FloatKind::F32 => DType::F32,\n                cubecl::ir::FloatKind::F64 => DType::F64,\n                cubecl::ir::FloatKind::TF32 => panic!(\"Not a valid DType for tensors.\"),\n                cubecl::ir::FloatKind::E2M1\n                | cubecl::ir::FloatKind::E2M3\n                | cubecl::ir::FloatKind::E3M2\n                | cubecl::ir::FloatKind::E4M3\n                | cubecl::ir::FloatKind::E5M2\n                | cubecl::ir::FloatKind::UE8M0 => {\n                    unimplemented!(\"Not yet supported, will be used for quantization\")\n                }\n            },\n            cubecl::ir::ElemType::Int(int_kind) => match int_kind {\n                cubecl::ir::IntKind::I8 => DType::I8,\n                cubecl::ir::IntKind::I16 => DType::I16,\n                cubecl::ir::IntKind::I32 => DType::I32,\n                cubecl::ir::IntKind::I64 => DType::I64,\n            },\n            cubecl::ir::ElemType::UInt(uint_kind) => match uint_kind {\n                cubecl::ir::UIntKind::U8 => DType::U8,\n                cubecl::ir::UIntKind::U16 => DType::U16,\n                cubecl::ir::UIntKind::U32 => DType::U32,\n                cubecl::ir::UIntKind::U64 => DType::U64,\n            },\n            _ => panic!(\"Not a valid DType for tensors.\"),\n        }\n    }\n}\n\nimpl DType {\n    /// Returns the size of a type in bytes.\n    pub const fn size(&self) -> usize {\n        match self {\n            DType::F64 => core::mem::size_of::<f64>(),\n            DType::F32 => core::mem::size_of::<f32>(),\n            DType::Flex32 => core::mem::size_of::<f32>(),\n            DType::F16 => core::mem::size_of::<f16>(),\n            DType::BF16 => core::mem::size_of::<bf16>(),\n            DType::I64 => core::mem::size_of::<i64>(),\n            DType::I32 => core::mem::size_of::<i32>(),\n            DType::I16 => core::mem::size_of::<i16>(),\n            DType::I8 => core::mem::size_of::<i8>(),\n            DType::U64 => core::mem::size_of::<u64>(),\n            DType::U32 => core::mem::size_of::<u32>(),\n            DType::U16 => core::mem::size_of::<u16>(),\n            DType::U8 => core::mem::size_of::<u8>(),\n            DType::Bool(store) => match store {\n                BoolStore::Native => core::mem::size_of::<bool>(),\n                BoolStore::U8 => core::mem::size_of::<u8>(),\n                BoolStore::U32 => core::mem::size_of::<u32>(),\n            },\n            DType::QFloat(scheme) => match scheme.store {\n                QuantStore::Native => match scheme.value {\n                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),\n                    // e2m1 native is automatically packed by the kernels, so the actual storage is\n                    // 8 bits wide.\n                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {\n                        core::mem::size_of::<u8>()\n                    }\n                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {\n                        // Sub-byte values have fractional size\n                        0\n                    }\n                },\n                QuantStore::PackedU32(_) => core::mem::size_of::<u32>(),\n                QuantStore::PackedNative(_) => match scheme.value {\n                    QuantValue::E2M1 => core::mem::size_of::<u8>(),\n                    _ => 0,\n                },\n            },\n        }\n    }\n    /// Returns true if the data type is a floating point type.\n    pub fn is_float(&self) -> bool {\n        matches!(\n            self,\n            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16\n        )\n    }\n    /// Returns true if the data type is a signed integer type.\n    pub fn is_int(&self) -> bool {\n        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)\n    }\n    /// Returns true if the data type is an unsigned integer type.\n    pub fn is_uint(&self) -> bool {\n        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)\n    }\n\n    /// Returns true if the data type is a boolean type\n    pub fn is_bool(&self) -> bool {\n        matches!(self, DType::Bool(_))\n    }\n\n    /// Returns the data type name.\n    pub fn name(&self) -> &'static str {\n        match self {\n            DType::F64 => \"f64\",\n            DType::F32 => \"f32\",\n            DType::Flex32 => \"flex32\",\n            DType::F16 => \"f16\",\n            DType::BF16 => \"bf16\",\n            DType::I64 => \"i64\",\n            DType::I32 => \"i32\",\n            DType::I16 => \"i16\",\n            DType::I8 => \"i8\",\n            DType::U64 => \"u64\",\n            DType::U32 => \"u32\",\n            DType::U16 => \"u16\",\n            DType::U8 => \"u8\",\n            DType::Bool(store) => match store {\n                BoolStore::Native => \"bool\",\n                BoolStore::U8 => \"bool(u8)\",\n                BoolStore::U32 => \"bool(u32)\",\n            },\n            DType::QFloat(_) => \"qfloat\",\n        }\n    }\n}\n\n#[allow(missing_docs)]\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]\npub enum FloatDType {\n    F64,\n    F32,\n    Flex32,\n    F16,\n    BF16,\n}\n\nimpl From<DType> for FloatDType {\n    fn from(value: DType) -> Self {\n        match value {\n            DType::F64 => FloatDType::F64,\n            DType::F32 => FloatDType::F32,\n            DType::Flex32 => FloatDType::Flex32,\n            DType::F16 => FloatDType::F16,\n            DType::BF16 => FloatDType::BF16,\n            _ => panic!(\"Expected float data type, got {value:?}\"),\n        }\n    }\n}\n\nimpl From<FloatDType> for DType {\n    fn from(value: FloatDType) -> Self {\n        match value {\n            FloatDType::F64 => DType::F64,\n            FloatDType::F32 => DType::F32,\n            FloatDType::Flex32 => DType::Flex32,\n            FloatDType::F16 => DType::F16,\n            FloatDType::BF16 => DType::BF16,\n        }\n    }\n}\n\n#[allow(missing_docs)]\n#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]\npub enum IntDType {\n    I64,\n    I32,\n    I16,\n    I8,\n    U64,\n    U32,\n    U16,\n    U8,\n}\n\nimpl From<DType> for IntDType {\n    fn from(value: DType) -> Self {\n        match value {\n            DType::I64 => IntDType::I64,\n            DType::I32 => IntDType::I32,\n            DType::I16 => IntDType::I16,\n            DType::I8 => IntDType::I8,\n            DType::U64 => IntDType::U64,\n            DType::U32 => IntDType::U32,\n            DType::U16 => IntDType::U16,\n            DType::U8 => IntDType::U8,\n            _ => panic!(\"Expected int data type, got {value:?}\"),\n        }\n    }\n}\n\nimpl From<IntDType> for DType {\n    fn from(value: IntDType) -> Self {\n        match value {\n            IntDType::I64 => DType::I64,\n            IntDType::I32 => DType::I32,\n            IntDType::I16 => DType::I16,\n            IntDType::I8 => DType::I8,\n            IntDType::U64 => DType::U64,\n            IntDType::U32 => DType::U32,\n            IntDType::U16 => DType::U16,\n            IntDType::U8 => DType::U8,\n        }\n    }\n}\n\n#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]\n/// Data type used to store boolean values.\npub enum BoolStore {\n    /// Stored as native boolean type (e.g. `bool`).\n    Native,\n    /// Stored as 8-bit unsigned integer.\n    U8,\n    /// Stored as 32-bit unsigned integer.\n    U32,\n}\n\n/// Boolean dtype.\n///\n/// This is currently an alias to [`BoolStore`], since it only varies by the storage representation.\npub type BoolDType = BoolStore;\n\n#[allow(deprecated)]\nimpl From<DType> for BoolDType {\n    fn from(value: DType) -> Self {\n        match value {\n            DType::Bool(store) => match store {\n                BoolStore::Native => BoolDType::Native,\n                BoolStore::U8 => BoolDType::U8,\n                BoolStore::U32 => BoolDType::U32,\n            },\n            _ => panic!(\"Expected bool data type, got {value:?}\"),\n        }\n    }\n}\n\nimpl From<BoolDType> for DType {\n    fn from(value: BoolDType) -> Self {\n        match value {\n            BoolDType::Native => DType::Bool(BoolStore::Native),\n            BoolDType::U8 => DType::Bool(BoolStore::U8),\n            BoolDType::U32 => DType::Bool(BoolStore::U32),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/tensor/mod.rs",
    "content": "pub mod dtype;\npub mod quantization;\npub mod shape;\npub mod slice;\n\npub use dtype::*;\npub use quantization::*;\npub use shape::*;\npub use slice::*;\n\npub use cubecl_zspace::indexing::{self, *};\npub use cubecl_zspace::{Strides, metadata::Metadata, strides};\n\n/// Check if the current tensor is contiguous.\n///\n/// A tensor is considered contiguous if its elements are stored in memory\n/// such that the stride at position `k` is equal to the product of the shapes\n/// of all dimensions greater than `k`.\n///\n/// This means that strides increase as you move from the rightmost to the leftmost dimension.\npub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {\n    if shape.is_empty() {\n        return true;\n    }\n\n    for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {\n        if expected != stride {\n            return false;\n        }\n    }\n\n    true\n}\n\n/// Computes the strides for a contiguous tensor with the given shape.\n///\n/// In a contiguous row-major tensor, the stride for each dimension\n/// equals the product of all dimension sizes to its right.\npub fn contiguous_strides(shape: &[usize]) -> Strides {\n    let mut strides = strides![0; shape.len()];\n    let mut current = 1;\n\n    for (i, &dim) in shape.iter().enumerate().rev() {\n        strides[i] = current;\n        current *= dim;\n    }\n\n    strides\n}\n\n/// The action to take for a reshape operation.\n#[derive(Debug)]\npub enum ReshapeAction {\n    /// Updating the strides is sufficient to handle the reshape.\n    UpdateStrides {\n        /// The new strides.\n        strides: Strides,\n    },\n    /// The strides are not compatible, we should recompute the buffer.\n    Recompute,\n    /// The strides are already correct.\n    NoChange,\n}\n\n/// The reshape kind.\n#[derive(Debug)]\npub enum ReshapeAnalysis {\n    /// Original tensor is contiguous, can update the strides.\n    IsContiguous,\n    /// Original tensor is highly permutated, can't update the strides.\n    HighlyPermuted,\n    /// Only batch dimensions are added, can update the strides.\n    Broadcasted,\n    /// Dimensions are only split, can update the strides.\n    Split,\n    /// Original tensor is bigger than output shape.\n    SmallerRank,\n    /// New shape is the same.\n    NoChange,\n}\n\nimpl ReshapeAnalysis {\n    /// Returns the proper action to take for the current analysis.\n    fn action(self, shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {\n        match self {\n            ReshapeAnalysis::IsContiguous => ReshapeAction::UpdateStrides {\n                strides: contiguous_strides(shape_new),\n            },\n            ReshapeAnalysis::NoChange => ReshapeAction::NoChange,\n            ReshapeAnalysis::HighlyPermuted | ReshapeAnalysis::SmallerRank => {\n                ReshapeAction::Recompute\n            }\n            ReshapeAnalysis::Broadcasted => {\n                let shape_rank = shape.len();\n                let shape_new_rank = shape_new.len();\n                let n_new_batch = shape_new_rank - shape_rank;\n                let num_elems = shape.iter().product::<usize>();\n                let strides_new = broadcast_strides(n_new_batch, shape_rank, num_elems, strides);\n\n                ReshapeAction::UpdateStrides {\n                    strides: strides_new,\n                }\n            }\n            ReshapeAnalysis::Split => {\n                let strides_new = split_strides(shape, strides, shape_new);\n\n                ReshapeAction::UpdateStrides {\n                    strides: strides_new,\n                }\n            }\n        }\n    }\n}\n\n/// Returns the proper action to take when reshaping a tensor.\npub fn reshape_action(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> ReshapeAction {\n    reshape_analysis(shape, Some(strides), shape_new).action(shape, strides, shape_new)\n}\n\n/// Calculate the new strides given added batch dimensions.\npub fn broadcast_strides(\n    n_new_batch: usize,\n    rank_prev: usize,\n    num_elems: usize,\n    strides: &[usize],\n) -> Strides {\n    let mut strides_new = strides![num_elems; rank_prev + n_new_batch];\n\n    for (i, s) in strides.iter().enumerate() {\n        strides_new[i + n_new_batch] = *s;\n    }\n\n    strides_new\n}\n\n/// Calculate the new strides given added split dimensions.\npub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides {\n    let mut strides_new = strides![1; shape_new.len()];\n\n    let mut old_idx = shape.len() - 1;\n    let mut current_stride = strides[old_idx];\n    let mut dim_prod = 1;\n\n    for (i, dim) in shape_new.iter().enumerate().rev() {\n        dim_prod *= *dim;\n        strides_new[i] = current_stride;\n        if *dim == 1 {\n            continue;\n        } else if dim_prod == shape[old_idx] {\n            old_idx = old_idx.saturating_sub(1);\n            current_stride = strides[old_idx];\n            dim_prod = 1;\n        } else {\n            current_stride *= *dim;\n        }\n    }\n\n    strides_new\n}\n\n/// Returns the analysis of a reshape operation.\npub fn reshape_analysis(\n    shape: &[usize],\n    strides: Option<&[usize]>,\n    shape_new: &[usize],\n) -> ReshapeAnalysis {\n    let shape_rank = shape.len();\n    let shape_new_rank = shape_new.len();\n\n    let is_contiguous = match strides {\n        Some(strides) => is_contiguous(shape, strides),\n        None => false,\n    };\n\n    if is_contiguous {\n        return ReshapeAnalysis::IsContiguous;\n    }\n\n    if shape_new_rank < shape_rank {\n        return ReshapeAnalysis::SmallerRank;\n    }\n\n    let n_new_batch = shape_new_rank - shape_rank;\n\n    match n_new_batch > 0 {\n        true => {\n            if shape == &shape_new[n_new_batch..shape_new_rank]\n                && shape_new[0..n_new_batch].iter().all(|it| *it == 1)\n            {\n                return ReshapeAnalysis::Broadcasted;\n            } else {\n                let mut dim_prod = 1;\n                let mut old_idx = 0;\n                for dim in shape_new {\n                    dim_prod *= *dim;\n\n                    // We need to ignore unit dims because they don't affect analysis and break\n                    // things because they match the default `dim_prod`. If we don't do this,\n                    // reshapes like [2, 3] to [2, 3, 1] will panic from out of bounds access.\n                    if *dim == 1 {\n                        continue;\n                    } else if dim_prod == shape[old_idx] {\n                        dim_prod = 1;\n                        old_idx += 1;\n                    } else if dim_prod > shape[old_idx] {\n                        return ReshapeAnalysis::HighlyPermuted;\n                    }\n                }\n                return ReshapeAnalysis::Split;\n            }\n        }\n\n        false => {\n            if shape == shape_new {\n                return ReshapeAnalysis::NoChange;\n            }\n        }\n    };\n\n    ReshapeAnalysis::HighlyPermuted\n}\n"
  },
  {
    "path": "crates/burn-std/src/tensor/quantization.rs",
    "content": "//! Quantization data representation.\n\n// Re-exported types\npub use cubecl_common::quant::scheme::{\n    BlockSize, QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue,\n};\n\n/// Alignment (in bytes) for quantization parameters in serialized tensor data.\n///\n/// NOTE: This is currently f32-based since scales were originally always f32.\n/// With `QuantParam` now supporting different precisions (F16, BF16, etc.),\n/// this alignment may need to be revisited in the future.\npub const QPARAM_ALIGN: usize = core::mem::align_of::<f32>();\n\nuse alloc::vec::Vec;\nuse core::any::TypeId;\nuse num_traits::PrimInt;\nuse serde::{Deserialize, Serialize};\n\nuse crate::{DType, Metadata, Shape, bytes::Bytes};\n\n#[derive(\n    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,\n)]\n/// The precision of accumulating elements.\npub enum QuantAcc {\n    /// Full precision.\n    #[default]\n    F32,\n    /// Half precision.\n    F16,\n    /// bfloat16 precision.\n    BF16,\n}\n\n/// Specify if the output of an operation is quantized using the scheme of the input\n/// or returned unquantized.\n#[derive(\n    Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default,\n)]\npub enum QuantPropagation {\n    /// The output is quantized using the scheme of the input.\n    Propagate,\n    /// The output is not quantized.\n    #[default]\n    Inhibit,\n}\n\n/// The quantization tensor data parameters.\n#[derive(Clone, Debug)]\npub struct QParams<S> {\n    /// The scaling factor.\n    pub scales: S,\n}\n\n/// A quantization parameter tensor descriptor.\n#[derive(Debug, Clone, PartialEq, Eq)]\npub struct QParamTensor {\n    /// Start of the tensor in the buffer\n    pub offset_start: usize,\n    /// Offset of tensor end from the end of the buffer\n    pub offset_end: usize,\n    /// Metadata of the tensor\n    pub metadata: Metadata,\n    /// Data type of the tensor\n    pub dtype: DType,\n}\n\n/// Calculate the shape of the quantization parameters for a given tensor and level\npub fn params_shape(data_shape: &Shape, level: QuantLevel) -> Shape {\n    match level {\n        QuantLevel::Tensor => Shape::new([1]),\n        QuantLevel::Block(block_size) => {\n            let mut params_shape = data_shape.clone();\n            let block_size = block_size.to_dim_vec(data_shape.num_dims());\n\n            for (shape, block_size) in params_shape.iter_mut().zip(block_size) {\n                *shape = (*shape).div_ceil(block_size as usize);\n            }\n\n            params_shape\n        }\n    }\n}\n\n/// Quantized data bytes representation.\n///\n/// # Notes\n/// 1) The quantized values are packed into 32-bit unsigned integers. For example, int8\n///    quantized values pack 4 grouped values into a single `u32`. When unpacking these values,\n///    we make sure to retrieve only the meaningful values (and ignore the alignment padding).\n/// 2) Quantization parameters are appended to the tensor data.\n///    As such, the last bytes always correspond to the scale parameter.\n///    If the quantization scheme includes an offset (zero-point) parameter, it is next to last.\npub struct QuantizedBytes {\n    /// The quantized values and quantization parameters represented as bytes.\n    pub bytes: Bytes,\n    /// The quantization scheme.\n    pub scheme: QuantScheme,\n    /// The number of quantized elements.\n    pub num_elements: usize,\n}\n\nimpl QuantizedBytes {\n    /// Creates a new quantized bytes representation.\n    pub fn new<E: bytemuck::CheckedBitPattern + bytemuck::NoUninit>(\n        value: Vec<E>,\n        scheme: QuantScheme,\n        scales: &[f32],\n    ) -> Self {\n        let num_elements = value.len();\n        // Only used for 8-bit quantization data comparison in tests\n        if TypeId::of::<E>() != TypeId::of::<i8>() {\n            panic!(\"Invalid quantized type\");\n        }\n\n        // Re-interpret `Vec<E>` as `Vec<i8>` with `Vec::from_raw_parts`\n        let i8s: Vec<i8> = bytemuck::allocation::cast_vec(value);\n        let mut bytes = Bytes::from_elems(i8s);\n\n        match scheme.level {\n            QuantLevel::Tensor => {\n                let scale_bytes = bytemuck::bytes_of(&scales[0]);\n                bytes.extend_from_byte_slice_aligned(scale_bytes, QPARAM_ALIGN);\n            }\n            QuantLevel::Block(_block_size) => {\n                let mut scale_bytes = Vec::with_capacity(size_of_val(scales));\n                for scale in scales {\n                    scale_bytes.extend_from_slice(bytemuck::bytes_of(scale));\n                }\n                bytes.extend_from_byte_slice_aligned(scale_bytes.as_slice(), QPARAM_ALIGN);\n            }\n        }\n\n        Self {\n            bytes,\n            scheme,\n            num_elements,\n        }\n    }\n\n    /// Returns the int8 quantized values with the quantization parameters.\n    pub fn into_vec_i8(self) -> (Vec<i8>, QParams<Vec<f32>>) {\n        let (values, (qparams, num_params)) = self.split_values_off();\n\n        // Quantization parameters are added at the end of the tensor data.\n        // As such, the last bytes always correspond to the scale parameter(s).\n        // For example, per-block quantization can have multiple parameters for a single tensor:\n        // [scale, scale, scale, ...]\n        let scale_size = core::mem::size_of::<f32>(); // scale is stored as f32\n        let qparams_bytes: &[u8] = bytemuck::cast_slice(&qparams);\n        let total_bytes = qparams_bytes.len();\n\n        let scales_size = scale_size * num_params;\n\n        let scales = bytemuck::cast_slice(&qparams_bytes[total_bytes - scales_size..]).to_vec();\n\n        (values, QParams { scales })\n    }\n\n    fn split_i8_values(self, num_params: usize) -> (Vec<i8>, Vec<u32>) {\n        let mut values = read_bytes_to_i8(self.bytes);\n\n        let scale_size = num_params * size_of::<f32>();\n        let values_end = values.len() - scale_size;\n\n        let qparams = values.split_off(values_end);\n\n        let qparams = if (qparams.as_ptr() as usize).is_multiple_of(4) {\n            let mut qparams = core::mem::ManuallyDrop::new(qparams);\n            unsafe {\n                Vec::<u32>::from_raw_parts(\n                    qparams.as_mut_ptr() as _,\n                    qparams.len() / 4,\n                    qparams.capacity() / 4,\n                )\n            }\n        } else {\n            #[cfg(target_endian = \"little\")]\n            {\n                // SAFETY: quantized bytes representation is created from packed u32 values in little endian\n                bytemuck::cast_vec(qparams)\n            }\n            #[cfg(target_endian = \"big\")]\n            {\n                crate::quantization::pack_i8s_to_u32s(bytemuck::cast_vec(qparams))\n            }\n        };\n        (values, qparams)\n    }\n\n    /// Splits the quantized values of the tensor from the quantization parameters.\n    ///\n    /// Returns the values in i8 and a newly allocated vector containing the quantization parameters.\n    fn split_values_off(self) -> (Vec<i8>, (Vec<u32>, usize)) {\n        let num_params = match self.scheme.level {\n            QuantLevel::Tensor => 1,\n            QuantLevel::Block(block_size) => self.num_elements / block_size.num_elements(),\n        };\n\n        if let QuantStore::PackedU32(packed_dim) = self.scheme.store {\n            assert_eq!(\n                packed_dim, 0,\n                \"Packing must be on innermost dimension for splitting off values\"\n            );\n        }\n\n        let (values, qparams) = match self.scheme.store {\n            QuantStore::Native => self.split_i8_values(num_params),\n            QuantStore::PackedU32(_) => match self.scheme.value {\n                QuantValue::Q8F | QuantValue::Q8S => self.split_i8_values(num_params),\n                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {\n                    let mut values = self.bytes.try_into_vec::<u32>().unwrap();\n                    let scale_size = num_params; // size of f32 same as u32\n                    let values_end = values.len() - scale_size;\n\n                    let qparams = values.split_off(values_end);\n                    // Sub-byte values are unpacked as i8s for value equality tests\n                    let values = unpack_q_to_i8s(&values, self.num_elements, &self.scheme.value);\n                    (values, qparams)\n                }\n                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {\n                    unimplemented!(\"Not yet supported\")\n                }\n            },\n            QuantStore::PackedNative(_) => unimplemented!(\"Not yet supported\"),\n        };\n\n        (values, (qparams, num_params))\n    }\n}\n\nfn read_bytes_to_i8(bytes: Bytes) -> Vec<i8> {\n    match bytes.try_into_vec::<i8>() {\n        Ok(val) => val,\n        // Safety,\n        //\n        // `Vec<u8>` can be Re-interpreted as `Vec<i8>` since they share the same alignment.\n        Err(bytes) => unsafe { core::mem::transmute::<Vec<u8>, Vec<i8>>(bytes.to_vec()) },\n    }\n}\n\n/// Pack signed 8-bit integer values into a sequence of unsigned 32-bit integers.\npub fn pack_i8s_to_u32s(values: Vec<i8>) -> Vec<u32> {\n    // Shift and combine groups of four 8-bit values into a u32.\n    // Same as doing this:\n    //     let result = (d_u8 & 0xFF) << 24 | (c_u8 & 0xFF) << 16 | (b_u8 & 0xFF) << 8 | (a_u8 & 0xFF);\n    #[cfg(target_endian = \"big\")]\n    {\n        values\n            .chunks(4)\n            .map(|x| {\n                x.iter()\n                    .enumerate()\n                    .fold(0u32, |acc, (i, x)| acc | (*x as u32 & 0xFF) << (i * 8))\n            })\n            .collect()\n    }\n\n    // The order of bytes in little endian matches the above description, we just need to\n    // handle padding when the number of values is not a factor of 4\n    #[cfg(target_endian = \"little\")]\n    {\n        let mut values = values;\n        let remainder = values.len() % 4;\n        if remainder != 0 {\n            // Pad with zeros\n            values.extend(core::iter::repeat_n(0, 4 - remainder));\n        }\n\n        let len = values.len() / 4;\n        let capacity = values.capacity() / 4;\n\n        // Pre-forget the old vec and re-interpret as u32\n        let mut values = core::mem::ManuallyDrop::new(values);\n        let ptr = values.as_mut_ptr() as *mut u32;\n\n        unsafe { Vec::from_raw_parts(ptr, len, capacity) }\n    }\n}\n\n/// Unpack integer values into a sequence of signed 8-bit integers.\npub(crate) fn unpack_q_to_i8s<Q: PrimInt>(\n    values: &[Q],\n    numel: usize,\n    value: &QuantValue,\n) -> Vec<i8> {\n    let size_store = size_of::<Q>() * 8;\n    let size_quant = value.size_bits();\n    let num_quants = size_store / size_quant;\n    let mask = Q::from((1 << size_quant) - 1).unwrap();\n    let sign_shift = 8 - size_quant; // sign extension for sub-byte values\n    values\n        .iter()\n        .enumerate()\n        .flat_map(|(i, &packed)| {\n            // A single u32 could contain less than four 8-bit values...\n            let n = core::cmp::min(num_quants, numel - i * num_quants);\n            // Extract each 8-bit segment from u32 and cast back to i8\n            // Same as doing this (when 4 values are fully packed):\n            //     let a = (packed & 0xFF) as i8;\n            //     let b = ((packed >> 8) & 0xFF) as i8;\n            //     let c = ((packed >> 16) & 0xFF) as i8;\n            //     let d = ((packed >> 24) & 0xFF) as i8;\n            (0..n).map(move |i| {\n                let raw = (packed >> (i * size_quant) & mask).to_u8().unwrap();\n                ((raw << sign_shift) as i8) >> sign_shift\n            })\n        })\n        .collect()\n}\n\n#[cfg(test)]\nmod tests {\n\n    use super::*;\n    use alloc::vec;\n\n    #[test]\n    fn should_pack_i8s_to_u32() {\n        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127]);\n\n        assert_eq!(packed, vec![2147287680]);\n    }\n\n    #[test]\n    fn should_pack_i8s_to_u32_padded() {\n        let packed = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55]);\n        let packed_padded = pack_i8s_to_u32s(vec![-128, 2, -3, 127, 55, 0, 0, 0]);\n\n        assert_eq!(packed, vec![2147287680, 55]);\n        assert_eq!(packed, packed_padded);\n    }\n\n    #[test]\n    fn should_unpack_u32s_to_i8s() {\n        let unpacked = unpack_q_to_i8s(&[2147287680u32], 4, &QuantValue::Q8S);\n\n        assert_eq!(unpacked, vec![-128, 2, -3, 127]);\n    }\n\n    #[test]\n    fn should_unpack_u32s_to_i8s_padded() {\n        let unpacked = unpack_q_to_i8s(&[55u32], 1, &QuantValue::Q8S);\n\n        assert_eq!(unpacked, vec![55]);\n    }\n\n    #[test]\n    fn should_unpack_u32s_to_i8s_arange() {\n        let unpacked = unpack_q_to_i8s(\n            &[\n                0u32, 286331136, 286331153, 572657937, 572662306, 857874978, 858993459, 858993459,\n                1145324612, 1145324612, 1431655748, 1431655765, 1717982549, 1717986918, 2003199590,\n                2004318071,\n            ],\n            128,\n            &QuantValue::Q4S,\n        );\n\n        assert_eq!(\n            unpacked,\n            vec![\n                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n                2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n                3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,\n                5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,\n                6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7\n            ]\n        );\n    }\n\n    #[test]\n    fn should_pack_unpack_quantization_parameters_per_tensor_symmetric() {\n        // Quantized [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]\n        let scale = 0.03937008;\n        let values = vec![0i8, 25, 51, 76, 102, 127];\n\n        let q_bytes = QuantizedBytes::new(\n            values.clone(),\n            QuantScheme::default()\n                .with_value(QuantValue::Q8S)\n                .with_store(QuantStore::Native),\n            &[scale],\n        );\n\n        let (q_values, qparams) = q_bytes.into_vec_i8();\n\n        assert_eq!(qparams.scales, vec![scale]);\n\n        assert_eq!(q_values, values);\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/tensor/shape.rs",
    "content": "//! Tensor shape definition.\n\nuse super::{Slice, SliceArg};\nuse alloc::vec::Vec;\nuse core::ops::Range;\n\npub use crate::errors::ExpressionError;\n\npub use cubecl_zspace::{MetadataError, Shape, SmallVec, calculate_matmul_output, shape};\n\n/// Slice-related ops on [`Shape`]\npub trait SliceOps: Sized {\n    /// Convert shape dimensions to full covering ranges (0..dim) for each dimension.\n    fn into_ranges(self) -> Vec<Range<usize>>;\n    /// Converts slice arguments into an array of slice specifications for the shape.\n    ///\n    /// This method returns an array of `Slice` objects that can be used for slicing operations.\n    /// The slices are clamped to the shape's dimensions. Similar to `into_ranges()`, but\n    /// allows custom slice specifications instead of full ranges.\n    /// For creating complex slice specifications, use the [`s!`] macro.\n    ///\n    /// # Arguments\n    ///\n    /// * `slices` - An array of slice specifications, where each element can be:\n    ///   - A range (e.g., `2..5`)\n    ///   - An index\n    ///   - A `Slice` object\n    ///   - The output of the [`s!`] macro for advanced slicing\n    ///\n    /// # Behavior\n    ///\n    /// - Supports partial and full slicing in any number of dimensions.\n    /// - Missing ranges are treated as full slices if D > D2.\n    /// - Handles negative indices by wrapping around from the end of the dimension.\n    /// - Clamps ranges to the shape's dimensions if they exceed the bounds.\n    ///\n    /// # Returns\n    ///\n    /// An array of `Slice` objects corresponding to the provided slice specifications,\n    /// clamped to the shape's actual dimensions.\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// use burn_std::{Shape, Slice, s, SliceOps};\n    ///\n    /// fn example() {\n    ///     // 1D slicing\n    ///     let slices = Shape::new([4]).into_slices(1..4);\n    ///     assert_eq!(slices[0].to_range(4), 1..3);\n    ///\n    ///     // 2D slicing\n    ///     let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);\n    ///     assert_eq!(slices[0].to_range(3), 1..3);\n    ///     assert_eq!(slices[1].to_range(4), 0..2);\n    ///\n    ///     // Using negative indices\n    ///     let slices = Shape::new([3]).into_slices(..-2);\n    ///     assert_eq!(slices[0].to_range(3), 0..1);\n    ///\n    ///     // Using the slice macro to select different ranges\n    ///     let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);\n    ///     assert_eq!(slices[0].to_range(2), 0..2);\n    ///     assert_eq!(slices[1].to_range(3), 1..2);\n    /// }\n    /// ```\n    ///\n    /// # See Also\n    ///\n    /// - [`s!`] - The recommended macro for creating slice specifications\n    /// - [`Shape::into_ranges`] - Convert to full covering ranges\n    ///\n    /// [`s!`]: crate::s!\n    fn into_slices<S>(self, slices: S) -> Vec<Slice>\n    where\n        S: SliceArg;\n    /// Compute the output shape from the given slices.\n    fn slice(self, slices: &[Slice]) -> Result<Self, MetadataError>;\n}\n\nimpl SliceOps for Shape {\n    fn into_ranges(self) -> Vec<Range<usize>> {\n        self.iter().map(|&d| 0..d).collect()\n    }\n\n    fn into_slices<S>(self, slices: S) -> Vec<Slice>\n    where\n        S: SliceArg,\n    {\n        slices.into_slices(&self)\n    }\n\n    fn slice(mut self, slices: &[Slice]) -> Result<Self, MetadataError> {\n        if slices.len() > self.rank() {\n            return Err(MetadataError::RankMismatch {\n                left: self.rank(),\n                right: slices.len(),\n            });\n        }\n\n        slices\n            .iter()\n            .zip(self.iter_mut())\n            .for_each(|(slice, dim_size)| *dim_size = slice.output_size(*dim_size));\n\n        Ok(self)\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::identity_op, reason = \"useful for clarity\")]\nmod tests {\n    use super::*;\n    use crate::s;\n    use alloc::vec;\n\n    #[test]\n    fn test_into_ranges() {\n        let dims = [2, 3, 4, 5];\n        let shape = Shape::new(dims);\n        assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]);\n    }\n\n    #[allow(clippy::single_range_in_vec_init)]\n    #[test]\n    fn test_into_slices() {\n        let slices = Shape::new([3]).into_slices(1..4);\n        assert_eq!(slices[0].to_range(3), 1..3);\n\n        let slices = Shape::new([3, 4]).into_slices(s![1..4, 0..2]);\n        assert_eq!(slices[0].to_range(3), 1..3);\n        assert_eq!(slices[1].to_range(4), 0..2);\n\n        let slices = Shape::new([3]).into_slices(..-2);\n        assert_eq!(slices[0].to_range(3), 0..1);\n\n        let slices = Shape::new([2, 3, 4]).into_slices(s![.., 1..-1]);\n        assert_eq!(slices[0].to_range(2), 0..2);\n        assert_eq!(slices[1].to_range(3), 1..2);\n\n        let slices = Shape::new([2, 3, 4]).into_slices(s![..20, 2]);\n        assert_eq!(slices[0].to_range(2), 0..2);\n        assert_eq!(slices[1].to_range(3), 2..3);\n    }\n\n    #[test]\n    fn test_shape_as_slice() {\n        let dims = [2, 3, 4, 5];\n        let shape = Shape::new(dims);\n\n        assert_eq!(shape.as_slice(), dims.as_slice());\n\n        // Deref coercion\n        let shape_slice: &[usize] = &shape;\n        assert_eq!(shape_slice, *&[2, 3, 4, 5]);\n    }\n\n    #[test]\n    fn test_shape_as_mut_slice() {\n        let mut dims = [2, 3, 4, 5];\n        let mut shape = Shape::new(dims);\n\n        let shape_mut = shape.as_mut_slice();\n        assert_eq!(shape_mut, dims.as_mut_slice());\n        shape_mut[1] = 6;\n\n        assert_eq!(shape_mut, &[2, 6, 4, 5]);\n\n        let mut shape = Shape::new(dims);\n        let shape = &mut shape[..];\n        shape[1] = 6;\n\n        assert_eq!(shape, shape_mut)\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_basic() {\n        // Test basic slicing with step=1\n        let slices = [\n            Slice::new(0, Some(5), 1), // 5 elements\n            Slice::new(2, Some(8), 1), // 6 elements\n        ];\n        let original_shape = Shape::new([10, 10, 10]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([5, 6, 10]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_with_positive_steps() {\n        // Test slicing with various positive steps\n        let slices = [\n            Slice::new(0, Some(10), 2), // [0,2,4,6,8] -> 5 elements\n            Slice::new(1, Some(9), 3),  // [1,4,7] -> 3 elements\n            Slice::new(0, Some(7), 4),  // [0,4] -> 2 elements\n        ];\n        let original_shape = Shape::new([20, 20, 20, 30]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([5, 3, 2, 30]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_with_negative_steps() {\n        // Test slicing with negative steps (backward iteration)\n        let slices = [\n            Slice::new(0, Some(10), -1), // 10 elements traversed backward\n            Slice::new(2, Some(8), -2),  // [7,5,3] -> 3 elements\n        ];\n        let original_shape = Shape::new([20, 20, 20]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([10, 3, 20]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_mixed_steps() {\n        // Test with a mix of positive, negative, and unit steps\n        let slices = [\n            Slice::from_range_stepped(1..6, 1),   // 5 elements\n            Slice::from_range_stepped(0..10, -3), // [9,6,3,0] -> 4 elements\n            Slice::from_range_stepped(2..14, 4),  // [2,6,10] -> 3 elements\n        ];\n        let original_shape = Shape::new([20, 20, 20]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([5, 4, 3]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_partial_dims() {\n        // Test when slices has fewer dimensions than original shape\n        let slices = [\n            Slice::from_range_stepped(2..7, 2), // [2,4,6] -> 3 elements\n        ];\n        let original_shape = Shape::new([10, 20, 30, 40]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([3, 20, 30, 40]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_edge_cases() {\n        // Test edge cases with small ranges and large steps\n        let slices = [\n            Slice::from_range_stepped(0..1, 1),    // Single element\n            Slice::from_range_stepped(0..10, 100), // Step larger than range -> 1 element\n            Slice::from_range_stepped(5..5, 1),    // Empty range -> 0 elements\n        ];\n        let original_shape = Shape::new([10, 20, 30]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([1, 1, 0]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_empty() {\n        // Test with no slice infos (should return original shape)\n        let slices = [];\n        let original_shape = Shape::new([10, 20, 30]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([10, 20, 30]));\n    }\n\n    #[test]\n    fn test_shape_slice_output_shape_uneven_division() {\n        // Test cases where range size doesn't divide evenly by step\n        let slices = [\n            Slice::from_range_stepped(0..7, 3), // ceil(7/3) = 3 elements: [0,3,6]\n            Slice::from_range_stepped(0..11, 4), // ceil(11/4) = 3 elements: [0,4,8]\n            Slice::from_range_stepped(1..10, 5), // ceil(9/5) = 2 elements: [1,6]\n        ];\n        let original_shape = Shape::new([20, 20, 20]);\n        let result = original_shape.slice(&slices).unwrap();\n        assert_eq!(result, Shape::new([3, 3, 2]));\n    }\n}\n"
  },
  {
    "path": "crates/burn-std/src/tensor/slice.rs",
    "content": "//! Tensor slice utilities.\n\nuse crate::Shape;\nuse crate::indexing::AsIndex;\nuse alloc::format;\nuse alloc::vec::Vec;\nuse core::fmt::{Display, Formatter};\nuse core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};\nuse core::str::FromStr;\n\n/// Trait for slice arguments that can be converted into an array of slices.\n/// This allows the `slice` method to accept both single slices (from `s![..]`)\n/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).\npub trait SliceArg {\n    /// Convert to an vec of slices with clamping to shape dimensions.\n    ///\n    /// Returns a [Slice] for each dimension in `shape`.\n    fn into_slices(self, shape: &Shape) -> Vec<Slice>;\n}\n\nimpl<S: Into<Slice> + Clone> SliceArg for &[S] {\n    fn into_slices(self, shape: &Shape) -> Vec<Slice> {\n        assert!(\n            self.len() <= shape.num_dims(),\n            \"Too many slices provided for shape, got {} but expected at most {}\",\n            self.len(),\n            shape.num_dims()\n        );\n\n        shape\n            .iter()\n            .enumerate()\n            .map(|(i, dim_size)| {\n                let slice = if i >= self.len() {\n                    Slice::full()\n                } else {\n                    self[i].clone().into()\n                };\n                // Apply shape clamping by converting to range and back\n                let clamped_range = slice.to_range(*dim_size);\n                Slice::new(\n                    clamped_range.start as isize,\n                    Some(clamped_range.end as isize),\n                    slice.step(),\n                )\n            })\n            .collect::<Vec<_>>()\n    }\n}\n\nimpl SliceArg for &Vec<Slice> {\n    fn into_slices(self, shape: &Shape) -> Vec<Slice> {\n        self.as_slice().into_slices(shape)\n    }\n}\n\nimpl<const R: usize, T> SliceArg for [T; R]\nwhere\n    T: Into<Slice> + Clone,\n{\n    fn into_slices(self, shape: &Shape) -> Vec<Slice> {\n        self.as_slice().into_slices(shape)\n    }\n}\n\nimpl<T> SliceArg for T\nwhere\n    T: Into<Slice>,\n{\n    fn into_slices(self, shape: &Shape) -> Vec<Slice> {\n        let slice: Slice = self.into();\n        [slice].as_slice().into_slices(shape)\n    }\n}\n\n/// Slice argument constructor for tensor indexing.\n///\n/// The `s![]` macro is used to create multi-dimensional slice specifications for tensors.\n/// It converts various range syntax forms into a `&[Slice]` that can be used with\n/// `tensor.slice()` and `tensor.slice_assign()` operations.\n///\n/// # Syntax Overview\n///\n/// ## Basic Forms\n///\n/// * **`s![index]`** - Index a single element (produces a subview with that axis removed)\n/// * **`s![range]`** - Slice a range of elements\n/// * **`s![range;step]`** - Slice a range with a custom step\n/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms\n///\n/// ## Range Types\n///\n/// All standard Rust range types are supported:\n/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)\n/// * **`a..=b`** - From `a` to `b` (both inclusive)\n/// * **`a..`** - From `a` to the end\n/// * **`..b`** - From the beginning to `b` (exclusive)\n/// * **`..=b`** - From the beginning to `b` (inclusive)\n/// * **`..`** - The full range (all elements)\n///\n/// ## Negative Indices\n///\n/// Negative indices count from the end of the axis:\n/// * **`-1`** refers to the last element\n/// * **`-2`** refers to the second-to-last element\n/// * And so on...\n///\n/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`\n///\n/// ## Step Syntax\n///\n/// Steps control the stride between selected elements:\n/// * **`;step`** after a range specifies the step\n/// * **Positive steps** select every nth element going forward\n/// * **Negative steps** select every nth element going backward\n/// * Default step is `1` when not specified\n/// * Step cannot be `0`\n///\n/// ### Negative Step Behavior\n///\n/// With negative steps, the range bounds still specify *which* elements to include,\n/// but the traversal order is reversed:\n///\n/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)\n/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)\n/// * `s![..;-1]` reverses the entire axis\n///\n/// This matches the semantics of NumPy and the ndarray crate.\n///\n/// # Examples\n///\n/// ## Basic Slicing\n///\n/// ```rust,ignore\n/// use burn_tensor::{Tensor, s};\n///\n/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {\n/// // Select rows 0-5 (exclusive)\n/// let subset = tensor.slice(s![0..5, .., ..]);\n///\n/// // Select the last row\n/// let last_row = tensor.slice(s![-1, .., ..]);\n///\n/// // Select columns 2, 3, 4\n/// let cols = tensor.slice(s![.., 2..5, ..]);\n///\n/// // Select a single element at position [1, 2, 3]\n/// let element = tensor.slice(s![1, 2, 3]);\n/// # }\n/// ```\n///\n/// ## Slicing with Steps\n///\n/// ```rust,ignore\n/// use burn_tensor::{Tensor, s};\n///\n/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {\n/// // Select every 2nd row\n/// let even_rows = tensor.slice(s![0..10;2, ..]);\n///\n/// // Select every 3rd column\n/// let cols = tensor.slice(s![.., 0..9;3]);\n///\n/// // Select every 2nd element in reverse order\n/// let reversed_even = tensor.slice(s![10..0;-2, ..]);\n/// # }\n/// ```\n///\n/// ## Reversing Dimensions\n///\n/// ```rust,ignore\n/// use burn_tensor::{Tensor, s};\n///\n/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {\n/// // Reverse the first dimension\n/// let reversed = tensor.slice(s![..;-1, ..]);\n///\n/// // Reverse both dimensions\n/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);\n///\n/// // Reverse a specific range\n/// let range_reversed = tensor.slice(s![2..8;-1, ..]);\n/// # }\n/// ```\n///\n/// ## Complex Multi-dimensional Slicing\n///\n/// ```rust,ignore\n/// use burn_tensor::{Tensor, s};\n///\n/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {\n/// // Mix of different slice types\n/// let complex = tensor.slice(s![\n///     0..10;2,    // Every 2nd element from 0 to 10\n///     ..,         // All elements in dimension 1\n///     5..15;-3,   // Every 3rd element from 14 down to 5\n///     -1          // Last element in dimension 3\n/// ]);\n///\n/// // Using inclusive ranges\n/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);\n///\n/// // Negative indices with steps\n/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);\n/// # }\n/// ```\n///\n/// ## Slice Assignment\n///\n/// ```rust,ignore\n/// use burn_tensor::{Tensor, s};\n///\n/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {\n/// // Assign to every 2nd row\n/// let tensor = tensor.slice_assign(s![0..10;2, ..], values);\n///\n/// // Assign to a reversed slice\n/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);\n/// # }\n/// ```\n#[macro_export]\nmacro_rules! s {\n    // Empty - should not happen\n    [] => {\n        compile_error!(\"Empty slice specification\")\n    };\n\n    // Single expression with step\n    [$range:expr; $step:expr] => {\n        {\n            #[allow(clippy::reversed_empty_ranges)]\n            {\n                $crate::tensor::Slice::from_range_stepped($range, $step)\n            }\n        }\n    };\n\n    // Single expression without step (no comma after)\n    [$range:expr] => {\n        {\n            #[allow(clippy::reversed_empty_ranges)]\n            {\n                $crate::tensor::Slice::from($range)\n            }\n        }\n    };\n\n    // Two or more expressions with first having step\n    [$range:expr; $step:expr, $($rest:tt)*] => {\n        {\n            #[allow(clippy::reversed_empty_ranges)]\n            {\n                $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step)] $($rest)*)\n            }\n        }\n    };\n\n    // Two or more expressions with first not having step\n    [$range:expr, $($rest:tt)*] => {\n        {\n            #[allow(clippy::reversed_empty_ranges)]\n            {\n                $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)\n            }\n        }\n    };\n\n    // Internal: finished parsing\n    (@internal [$($acc:expr),*]) => {\n        [$($acc),*]\n    };\n\n    // Internal: parse range with step followed by comma\n    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {\n        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)\n    };\n\n    // Internal: parse range with step at end\n    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {\n        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])\n    };\n\n    // Internal: parse range without step followed by comma\n    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {\n        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)\n    };\n\n    // Internal: parse range without step at end\n    (@internal [$($acc:expr),*] $range:expr) => {\n        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])\n    };\n}\n\n/// A slice specification for a single tensor dimension.\n///\n/// This struct represents a range with an optional step, used for advanced indexing\n/// operations on tensors. It is typically created using the [`s!`] macro rather than\n/// constructed directly.\n///\n/// # Fields\n///\n/// * `start` - The starting index (inclusive). Negative values count from the end.\n/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.\n/// * `step` - The stride between elements. Must be non-zero.\n///\n/// # Index Interpretation\n///\n/// - **Positive indices**: Count from the beginning (0-based)\n/// - **Negative indices**: Count from the end (-1 is the last element)\n/// - **Bounds checking**: Indices are clamped to valid ranges\n///\n/// # Step Behavior\n///\n/// - **Positive step**: Traverse forward through the range\n/// - **Negative step**: Traverse backward through the range\n/// - **Step size**: Determines how many elements to skip\n///\n/// # Examples\n///\n/// While you typically use the [`s!`] macro, you can also construct slices directly:\n///\n/// ```rust,ignore\n/// use burn_tensor::Slice;\n///\n/// // Equivalent to s![2..8]\n/// let slice1 = Slice::new(2, Some(8), 1);\n///\n/// // Equivalent to s![0..10;2]\n/// let slice2 = Slice::new(0, Some(10), 2);\n///\n/// // Equivalent to s![..;-1] (reverse)\n/// let slice3 = Slice::new(0, None, -1);\n/// ```\n///\n/// See also the [`s!`] macro for the preferred way to create slices.\n#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\npub struct Slice {\n    /// Slice start index.\n    pub start: isize,\n    /// Slice end index (exclusive).\n    pub end: Option<isize>,\n    /// Step between elements (default: 1).\n    pub step: isize,\n}\n\n/// Defines an [`Iterator`] over a [`Slice`].\n#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\npub struct SliceIter {\n    slice: Slice,\n    current: isize,\n}\n\nimpl Iterator for SliceIter {\n    type Item = isize;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        let next = self.current;\n        self.current += self.slice.step;\n\n        if let Some(end) = self.slice.end {\n            if self.slice.is_reversed() {\n                if next <= end {\n                    return None;\n                }\n            } else if next >= end {\n                return None;\n            }\n        }\n\n        Some(next)\n    }\n}\n\n/// Note: Unbounded [`Slice`]s produce infinite iterators.\nimpl IntoIterator for Slice {\n    type Item = isize;\n    type IntoIter = SliceIter;\n\n    fn into_iter(self) -> Self::IntoIter {\n        SliceIter {\n            slice: self,\n            current: self.start,\n        }\n    }\n}\n\nimpl Default for Slice {\n    fn default() -> Self {\n        Self::full()\n    }\n}\n\nimpl Slice {\n    /// Creates a new slice with start, end, and step\n    pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {\n        assert!(step != 0, \"Step cannot be zero\");\n        Self { start, end, step }\n    }\n\n    /// Creates a slice that represents the full range.\n    pub const fn full() -> Self {\n        Self::new(0, None, 1)\n    }\n\n    /// Creates a slice that represents a single index\n    pub fn index(idx: isize) -> Self {\n        Self {\n            start: idx,\n            end: handle_signed_inclusive_end(idx),\n            step: 1,\n        }\n    }\n\n    /// Converts the slice to a vector.\n    pub fn into_vec(self) -> Vec<isize> {\n        assert!(\n            self.end.is_some(),\n            \"Slice must have an end to convert to a vector: {self:?}\"\n        );\n        self.into_iter().collect()\n    }\n\n    /// Clips the slice to a maximum size.\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// assert_eq!(\n    ///     Slice::new(0, None, 1).bound_to(10),\n    ///     Slice::new(0, Some(10), 1));\n    /// assert_eq!(\n    ///     Slice::new(0, Some(5), 1).bound_to(10),\n    ///     Slice::new(0, Some(5), 1));\n    /// assert_eq!(\n    ///     Slice::new(0, None, -1).bound_to(10),\n    ///     Slice::new(0, Some(-11), -1));\n    /// assert_eq!(\n    ///     Slice::new(0, Some(-5), -1).bound_to(10),\n    ///     Slice::new(0, Some(-5), -1));\n    /// ```\n    pub fn bound_to(self, size: usize) -> Self {\n        let mut bounds = size as isize;\n\n        if let Some(end) = self.end {\n            if end > 0 {\n                bounds = end.min(bounds);\n            } else {\n                bounds = end.max(-(bounds + 1));\n            }\n        } else if self.is_reversed() {\n            bounds = -(bounds + 1);\n        }\n\n        Self {\n            end: Some(bounds),\n            ..self\n        }\n    }\n\n    /// Creates a slice with a custom step\n    pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {\n        assert!(step != 0, \"Step cannot be zero\");\n        Self { start, end, step }\n    }\n\n    /// Creates a slice from a range with a specified step\n    pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {\n        assert!(step != 0, \"Step cannot be zero\");\n        let mut slice = range.into();\n        slice.step = step;\n        slice\n    }\n\n    /// Returns the step of the slice\n    pub fn step(&self) -> isize {\n        self.step\n    }\n\n    /// Returns the range for this slice given a dimension size\n    pub fn range(&self, size: usize) -> Range<usize> {\n        self.to_range(size)\n    }\n\n    /// Convert this slice to a range for a dimension of the given size.\n    ///\n    /// # Arguments\n    ///\n    /// * `size` - The size of the dimension to slice.\n    ///\n    /// # Returns\n    ///\n    /// A `Range<usize>` representing the slice bounds.\n    pub fn to_range(&self, size: usize) -> Range<usize> {\n        // Always return a valid range with start <= end\n        // The step information will be handled separately\n        let start = convert_signed_index(self.start, size);\n        let end = match self.end {\n            Some(end) => convert_signed_index(end, size),\n            None => size,\n        };\n        start..end\n    }\n\n    /// Converts the slice into a range and step tuple\n    pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {\n        let range = self.to_range(size);\n        (range, self.step)\n    }\n\n    /// Returns true if the step is negative\n    pub fn is_reversed(&self) -> bool {\n        self.step < 0\n    }\n\n    /// Calculates the output size for this slice operation\n    pub fn output_size(&self, dim_size: usize) -> usize {\n        let range = self.to_range(dim_size);\n        // Handle empty slices (start >= end)\n        if range.start >= range.end {\n            return 0;\n        }\n        let len = range.end - range.start;\n        if self.step.unsigned_abs() == 1 {\n            len\n        } else {\n            len.div_ceil(self.step.unsigned_abs())\n        }\n    }\n}\n\nfn convert_signed_index(index: isize, size: usize) -> usize {\n    if index < 0 {\n        (size as isize + index).max(0) as usize\n    } else {\n        (index as usize).min(size)\n    }\n}\n\nfn handle_signed_inclusive_end(end: isize) -> Option<isize> {\n    match end {\n        -1 => None,\n        end => Some(end + 1),\n    }\n}\n\nimpl<I: AsIndex> From<Range<I>> for Slice {\n    fn from(r: Range<I>) -> Self {\n        Self {\n            start: r.start.as_index(),\n            end: Some(r.end.as_index()),\n            step: 1,\n        }\n    }\n}\n\nimpl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {\n    fn from(r: RangeInclusive<I>) -> Self {\n        Self {\n            start: r.start().as_index(),\n            end: handle_signed_inclusive_end(r.end().as_index()),\n            step: 1,\n        }\n    }\n}\n\nimpl<I: AsIndex> From<RangeFrom<I>> for Slice {\n    fn from(r: RangeFrom<I>) -> Self {\n        Self {\n            start: r.start.as_index(),\n            end: None,\n            step: 1,\n        }\n    }\n}\n\nimpl<I: AsIndex> From<RangeTo<I>> for Slice {\n    fn from(r: RangeTo<I>) -> Self {\n        Self {\n            start: 0,\n            end: Some(r.end.as_index()),\n            step: 1,\n        }\n    }\n}\n\nimpl<I: AsIndex> From<RangeToInclusive<I>> for Slice {\n    fn from(r: RangeToInclusive<I>) -> Self {\n        Self {\n            start: 0,\n            end: handle_signed_inclusive_end(r.end.as_index()),\n            step: 1,\n        }\n    }\n}\n\nimpl From<RangeFull> for Slice {\n    fn from(_: RangeFull) -> Self {\n        Self {\n            start: 0,\n            end: None,\n            step: 1,\n        }\n    }\n}\n\nimpl From<usize> for Slice {\n    fn from(i: usize) -> Self {\n        Slice::index(i as isize)\n    }\n}\n\nimpl From<isize> for Slice {\n    fn from(i: isize) -> Self {\n        Slice::index(i)\n    }\n}\n\nimpl From<i32> for Slice {\n    fn from(i: i32) -> Self {\n        Slice::index(i as isize)\n    }\n}\n\nimpl From<i64> for Slice {\n    fn from(i: i64) -> Self {\n        Slice::index(i as isize)\n    }\n}\n\nimpl Display for Slice {\n    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {\n        if self.step == 1\n            && let Some(end) = self.end\n            && self.start == end - 1\n        {\n            f.write_fmt(format_args!(\"{}\", self.start))\n        } else {\n            if self.start != 0 {\n                f.write_fmt(format_args!(\"{}\", self.start))?;\n            }\n            f.write_str(\"..\")?;\n            if let Some(end) = self.end {\n                f.write_fmt(format_args!(\"{}\", end))?;\n            }\n            if self.step != 1 {\n                f.write_fmt(format_args!(\";{}\", self.step))?;\n            }\n            Ok(())\n        }\n    }\n}\n\nimpl FromStr for Slice {\n    type Err = crate::ExpressionError;\n\n    fn from_str(source: &str) -> Result<Self, Self::Err> {\n        let mut s = source.trim();\n\n        let parse_int = |v: &str| -> Result<isize, Self::Err> {\n            v.parse::<isize>().map_err(|e| {\n                crate::ExpressionError::parse_error(\n                    format!(\"Invalid integer: '{v}': {}\", e),\n                    source,\n                )\n            })\n        };\n\n        let mut start: isize = 0;\n        let mut end: Option<isize> = None;\n        let mut step: isize = 1;\n\n        if let Some((head, tail)) = s.split_once(\";\") {\n            step = parse_int(tail)?;\n            s = head;\n        }\n\n        if s.is_empty() {\n            return Err(crate::ExpressionError::parse_error(\n                \"Empty expression\",\n                source,\n            ));\n        }\n\n        if let Some((start_s, end_s)) = s.split_once(\"..\") {\n            if !start_s.is_empty() {\n                start = parse_int(start_s)?;\n            }\n            if !end_s.is_empty() {\n                if let Some(end_s) = end_s.strip_prefix('=') {\n                    end = Some(parse_int(end_s)? + 1);\n                } else {\n                    end = Some(parse_int(end_s)?);\n                }\n            }\n        } else {\n            start = parse_int(s)?;\n            end = Some(start + 1);\n        }\n\n        if step == 0 {\n            return Err(crate::ExpressionError::invalid_expression(\n                \"Step cannot be zero\",\n                source,\n            ));\n        }\n\n        Ok(Slice::new(start, end, step))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use alloc::string::ToString;\n    use alloc::vec;\n\n    #[test]\n    fn test_slice_to_str() {\n        assert_eq!(Slice::new(0, None, 1).to_string(), \"..\");\n\n        assert_eq!(Slice::new(0, Some(1), 1).to_string(), \"0\");\n\n        assert_eq!(Slice::new(0, Some(10), 1).to_string(), \"..10\");\n        assert_eq!(Slice::new(1, Some(10), 1).to_string(), \"1..10\");\n\n        assert_eq!(Slice::new(-3, Some(10), -2).to_string(), \"-3..10;-2\");\n    }\n\n    #[test]\n    fn test_slice_from_str() {\n        assert_eq!(\"1\".parse::<Slice>(), Ok(Slice::new(1, Some(2), 1)));\n        assert_eq!(\"..\".parse::<Slice>(), Ok(Slice::new(0, None, 1)));\n        assert_eq!(\"..3\".parse::<Slice>(), Ok(Slice::new(0, Some(3), 1)));\n        assert_eq!(\"..=3\".parse::<Slice>(), Ok(Slice::new(0, Some(4), 1)));\n\n        assert_eq!(\"-12..3\".parse::<Slice>(), Ok(Slice::new(-12, Some(3), 1)));\n        assert_eq!(\"..;-1\".parse::<Slice>(), Ok(Slice::new(0, None, -1)));\n\n        assert_eq!(\"..=3;-2\".parse::<Slice>(), Ok(Slice::new(0, Some(4), -2)));\n\n        assert_eq!(\n            \"..;0\".parse::<Slice>(),\n            Err(crate::ExpressionError::invalid_expression(\n                \"Step cannot be zero\",\n                \"..;0\"\n            ))\n        );\n\n        assert_eq!(\n            \"\".parse::<Slice>(),\n            Err(crate::ExpressionError::parse_error(\"Empty expression\", \"\"))\n        );\n        assert_eq!(\n            \"a\".parse::<Slice>(),\n            Err(crate::ExpressionError::parse_error(\n                \"Invalid integer: 'a': invalid digit found in string\",\n                \"a\"\n            ))\n        );\n        assert_eq!(\n            \"..a\".parse::<Slice>(),\n            Err(crate::ExpressionError::parse_error(\n                \"Invalid integer: 'a': invalid digit found in string\",\n                \"..a\"\n            ))\n        );\n        assert_eq!(\n            \"a:b:c\".parse::<Slice>(),\n            Err(crate::ExpressionError::parse_error(\n                \"Invalid integer: 'a:b:c': invalid digit found in string\",\n                \"a:b:c\"\n            ))\n        );\n    }\n\n    #[test]\n    fn test_slice_output_size() {\n        // Test the output_size method directly\n        assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);\n        assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);\n        assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)\n        assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);\n        assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);\n        assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)\n        assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range\n    }\n\n    #[test]\n    fn test_bound_to() {\n        assert_eq!(\n            Slice::new(0, None, 1).bound_to(10),\n            Slice::new(0, Some(10), 1)\n        );\n        assert_eq!(\n            Slice::new(0, Some(5), 1).bound_to(10),\n            Slice::new(0, Some(5), 1)\n        );\n\n        assert_eq!(\n            Slice::new(0, None, -1).bound_to(10),\n            Slice::new(0, Some(-11), -1)\n        );\n        assert_eq!(\n            Slice::new(0, Some(-5), -1).bound_to(10),\n            Slice::new(0, Some(-5), -1)\n        );\n    }\n\n    #[test]\n    fn test_slice_iter() {\n        assert_eq!(\n            Slice::new(2, Some(3), 1).into_iter().collect::<Vec<_>>(),\n            vec![2]\n        );\n        assert_eq!(\n            Slice::new(3, Some(-1), -1).into_iter().collect::<Vec<_>>(),\n            vec![3, 2, 1, 0]\n        );\n\n        assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]);\n\n        assert_eq!(\n            Slice::new(3, None, 2)\n                .into_iter()\n                .take(3)\n                .collect::<Vec<_>>(),\n            vec![3, 5, 7]\n        );\n        assert_eq!(\n            Slice::new(3, None, 2)\n                .bound_to(8)\n                .into_iter()\n                .collect::<Vec<_>>(),\n            vec![3, 5, 7]\n        );\n    }\n\n    #[test]\n    #[should_panic(\n        expected = \"Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }\"\n    )]\n    fn test_unbound_slice_into_vec() {\n        Slice::new(0, None, 1).into_vec();\n    }\n\n    #[test]\n    fn into_slices_should_return_for_all_shape_dims() {\n        let slice = s![1];\n        let shape = Shape::new([2, 3, 1]);\n\n        let slices = slice.into_slices(&shape);\n\n        assert_eq!(slices.len(), shape.len());\n\n        assert_eq!(slices[0], Slice::new(1, Some(2), 1));\n        assert_eq!(slices[1], Slice::new(0, Some(3), 1));\n        assert_eq!(slices[2], Slice::new(0, Some(1), 1));\n\n        let slice = s![1, 0..2];\n        let slices = slice.into_slices(&shape);\n\n        assert_eq!(slices.len(), shape.len());\n\n        assert_eq!(slices[0], Slice::new(1, Some(2), 1));\n        assert_eq!(slices[1], Slice::new(0, Some(2), 1));\n        assert_eq!(slices[2], Slice::new(0, Some(1), 1));\n\n        let slice = s![..];\n        let slices = slice.into_slices(&shape);\n\n        assert_eq!(slices.len(), shape.len());\n\n        assert_eq!(slices[0], Slice::new(0, Some(2), 1));\n        assert_eq!(slices[1], Slice::new(0, Some(3), 1));\n        assert_eq!(slices[2], Slice::new(0, Some(1), 1));\n    }\n\n    #[test]\n    fn into_slices_all_dimensions() {\n        let slice = s![1, ..2, ..];\n        let shape = Shape::new([2, 3, 1]);\n\n        let slices = slice.into_slices(&shape);\n\n        assert_eq!(slices.len(), shape.len());\n\n        assert_eq!(slices[0], Slice::new(1, Some(2), 1));\n        assert_eq!(slices[1], Slice::new(0, Some(2), 1));\n        assert_eq!(slices[2], Slice::new(0, Some(1), 1));\n    }\n\n    #[test]\n    fn into_slices_supports_empty_dimensions() {\n        let slice = s![.., 1, ..];\n        let shape = Shape::new([0, 3, 1]);\n\n        let slices = slice.into_slices(&shape);\n\n        assert_eq!(slices.len(), shape.len());\n\n        assert_eq!(slices[0], Slice::new(0, Some(0), 1));\n        assert_eq!(slices[1], Slice::new(1, Some(2), 1));\n        assert_eq!(slices[2], Slice::new(0, Some(1), 1));\n    }\n\n    #[test]\n    #[should_panic = \"Too many slices provided for shape\"]\n    fn into_slices_should_match_shape_rank() {\n        let slice = s![.., 1, ..];\n        let shape = Shape::new([3, 1]);\n\n        let _ = slice.into_slices(&shape);\n    }\n\n    #[test]\n    fn should_support_const_and_full() {\n        static SLICES: [Slice; 2] = [Slice::full(), Slice::new(2, None, 1)];\n        assert_eq!(SLICES[0], Slice::new(0, None, 1));\n        assert_eq!(SLICES[1], Slice::new(2, None, 1));\n    }\n\n    #[test]\n    fn should_support_default() {\n        assert_eq!(Slice::default(), Slice::new(0, None, 1));\n    }\n\n    #[test]\n    fn should_support_copy() {\n        let mut slice = Slice::new(1, Some(3), 2);\n        let slice_copy = slice;\n\n        slice.end = Some(4);\n\n        assert_eq!(slice, Slice::new(1, Some(4), 2));\n        assert_eq!(slice_copy, Slice::new(1, Some(3), 2));\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/Cargo.toml",
    "content": "[package]\nauthors = [\"Dilshod Tadjibaev (@antimora)\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Storage and serialization infrastructure for Burn\"\ndocumentation = \"https://docs.rs/burn-store\"\nedition.workspace = true\nkeywords = [\n    \"deep-learning\",\n    \"machine-learning\",\n    \"tensor\",\n    \"storage\",\n    \"serialization\",\n]\nlicense.workspace = true\nname = \"burn-store\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-store\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\", \"pytorch\", \"safetensors\", \"burnpack\", \"memmap\"]\nmemmap = [\"std\", \"dep:memmap2\"]\nstd = [\n    \"dep:memmap2\",\n    \"safetensors/std\",\n    \"burn-core/std\",\n    \"burn-tensor/std\",\n    \"dep:regex\",\n    \"byteorder/std\",\n]\ntracing = [\n    \"burn-core/tracing\",\n    \"burn-cuda?/tracing\",\n    \"burn-nn/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-tensor/tracing\",\n    \"burn-wgpu?/tracing\",\n]\n\n\nburnpack = [\"serde\", \"ciborium\"]\ncuda = [\"burn-cuda\"]\nmetal = [\"wgpu\", \"burn-wgpu/metal\"]\ntch = [\"burn-tch\"]\nwgpu = [\"burn-wgpu\"]\n\nsafetensors = [\"dep:safetensors\"]\n\npytorch = [\"burn-core/record-item-custom-serde\", \"zip\", \"serde\", \"tar\"]\n\n[dependencies]\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\", default-features = false }\n\n# External dependencies\nbyteorder = { workspace = true, default-features = false }\nbytes = { workspace = true }\nciborium = { workspace = true, optional = true }\nhalf = { workspace = true }\nhashbrown = { workspace = true, features = [\"serde\"] }\nmemmap2 = { workspace = true, optional = true }\nregex = { workspace = true, optional = true }\nserde = { workspace = true, optional = true }\ntextdistance = { workspace = true }\nzip = { workspace = true, optional = true }\ntar = { workspace = true, optional = true }\n\n# Workaround to force broken minor version to update\nlzma-rust2 = { workspace = true, optional = true }\n\nsafetensors = { workspace = true, optional = true }\n\n# Optional backend dependencies for benchmarks\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", optional = true }\n\n[dev-dependencies]\n# burn-import = { path = \"../burn-import\", version = \"=0.21.0-pre.2\" } # disabled (circular dep in publish, only for bench)\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-nn = { path = \"../burn-nn\", version = \"=0.21.0-pre.2\", default-features = false }\ndivan = \"0.1\"\ntempfile = { workspace = true }\n\n[[bench]]\nharness = false\nname = \"resnet18_loading\"\n\n[[bench]]\nharness = false\nname = \"unified_loading\"\n\n[[bench]]\nharness = false\nname = \"unified_saving\"\n\n[[bench]]\nharness = false\nname = \"zero_copy_loading\"\n\n# Enable extra-platforms for bytes on targets without native atomics (e.g., thumbv6m-none-eabi)\n[target.'cfg(not(target_has_atomic = \"ptr\"))'.dependencies]\nbytes = { workspace = true, features = [\"extra-platforms\"] }\n"
  },
  {
    "path": "crates/burn-store/MIGRATION.md",
    "content": "# Migration Guide: burn-import to burn-store\n\nThis guide helps you migrate from the deprecated `burn-import` recorders (`PyTorchFileRecorder`,\n`SafetensorsFileRecorder`) to the new `burn-store` API (`PytorchStore`, `SafetensorsStore`).\n\n## Overview\n\nThe new `burn-store` API provides:\n\n- **Simpler API**: Load directly into models instead of records\n- **Fluent builder pattern**: Chain configuration methods\n- **Better error handling**: Detailed load results with applied/missing/errors info\n- **Bidirectional support**: Both load and save operations\n- **More features**: Filtering, partial loading, metadata, zero-copy loading\n\n## Quick Migration\n\n### PyTorch Files (.pt/.pth)\n\n**Before (burn-import):**\n\n```rust\nuse burn::record::{FullPrecisionSettings, Recorder};\nuse burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};\n\n// Load into a record, then create model from record\nlet record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()\n    .load(\"model.pt\".into(), &device)\n    .expect(\"Failed to load\");\n\nlet model = Model::init(&device).load_record(record);\n```\n\n**After (burn-store):**\n\n```rust\nuse burn_store::{ModuleSnapshot, PytorchStore};\n\n// Initialize model, then load weights directly\nlet mut model = Model::init(&device);\nlet mut store = PytorchStore::from_file(\"model.pt\");\nmodel.load_from(&mut store).expect(\"Failed to load\");\n```\n\n### SafeTensors Files (.safetensors)\n\n**Before (burn-import):**\n\n```rust\nuse burn::record::{FullPrecisionSettings, Recorder};\nuse burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};\n\nlet record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()\n    .load(\"model.safetensors\".into(), &device)\n    .expect(\"Failed to load\");\n\nlet model = Model::init(&device).load_record(record);\n```\n\n**After (burn-store):**\n\n```rust\nuse burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};\n\nlet mut model = Model::init(&device);\n\n// For SafeTensors exported from PyTorch, use the adapter\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_from_adapter(PyTorchToBurnAdapter);\nmodel.load_from(&mut store).expect(\"Failed to load\");\n\n// For native Burn SafeTensors, no adapter needed\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\");\nmodel.load_from(&mut store).expect(\"Failed to load\");\n```\n\n## API Mapping\n\n### PyTorchFileRecorder Options\n\n| burn-import                                    | burn-store                                  |\n| ---------------------------------------------- | ------------------------------------------- |\n| `LoadArgs::new(path)`                          | `PytorchStore::from_file(path)`             |\n| `.with_key_remap(pattern, replacement)`        | `.with_key_remapping(pattern, replacement)` |\n| `.with_top_level_key(key)`                     | `.with_top_level_key(key)`                  |\n| `.with_debug_print()`                          | _(use tracing/logging instead)_             |\n| `PyTorchFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_         |\n\n### SafetensorsFileRecorder Options\n\n| burn-import                                        | burn-store                                  |\n| -------------------------------------------------- | ------------------------------------------- |\n| `LoadArgs::new(path)`                              | `SafetensorsStore::from_file(path)`         |\n| `.with_key_remap(pattern, replacement)`            | `.with_key_remapping(pattern, replacement)` |\n| `.with_adapter_type(AdapterType::PyTorch)`         | `.with_from_adapter(PyTorchToBurnAdapter)`  |\n| `.with_adapter_type(AdapterType::NoAdapter)`       | _(default, no adapter)_                     |\n| `.with_debug_print()`                              | _(use tracing/logging instead)_             |\n| `SafetensorsFileRecorder::<FullPrecisionSettings>` | _(precision handled automatically)_         |\n\n## Detailed Examples\n\n### Key Remapping\n\n**Before:**\n\n```rust\nlet args = LoadArgs::new(\"model.pt\".into())\n    .with_key_remap(\"conv\\\\.(.*)\", \"$1\")\n    .with_key_remap(\"^old_prefix\\\\.\", \"new_prefix.\");\n\nlet record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()\n    .load(args, &device)?;\n```\n\n**After:**\n\n```rust\nlet mut store = PytorchStore::from_file(\"model.pt\")\n    .with_key_remapping(\"conv\\\\.(.*)\", \"$1\")\n    .with_key_remapping(\"^old_prefix\\\\.\", \"new_prefix.\");\n\nmodel.load_from(&mut store)?;\n```\n\n### Top-Level Key Access\n\n**Before:**\n\n```rust\nlet args = LoadArgs::new(\"checkpoint.pt\".into())\n    .with_top_level_key(\"state_dict\");\n\nlet record: ModelRecord<B> = PyTorchFileRecorder::<FullPrecisionSettings>::default()\n    .load(args, &device)?;\n```\n\n**After:**\n\n```rust\nlet mut store = PytorchStore::from_file(\"checkpoint.pt\")\n    .with_top_level_key(\"state_dict\");\n\nmodel.load_from(&mut store)?;\n```\n\n### PyTorch Adapter for SafeTensors\n\n**Before:**\n\n```rust\nuse burn_import::safetensors::{AdapterType, LoadArgs};\n\nlet args = LoadArgs::new(\"pytorch_model.safetensors\".into())\n    .with_adapter_type(AdapterType::PyTorch);\n\nlet record: ModelRecord<B> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()\n    .load(args, &device)?;\n```\n\n**After:**\n\n```rust\nuse burn_store::{PyTorchToBurnAdapter, SafetensorsStore};\n\nlet mut store = SafetensorsStore::from_file(\"pytorch_model.safetensors\")\n    .with_from_adapter(PyTorchToBurnAdapter);\n\nmodel.load_from(&mut store)?;\n```\n\n## New Features in burn-store\n\n### Partial Loading\n\nHandle missing tensors gracefully:\n\n```rust\nlet mut store = PytorchStore::from_file(\"model.pt\")\n    .allow_partial(true);\n\nlet result = model.load_from(&mut store)?;\nprintln!(\"Loaded: {:?}\", result.applied);\nprintln!(\"Missing: {:?}\", result.missing);\n```\n\n### Filtering\n\nLoad only specific tensors:\n\n```rust\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_regex(r\"^encoder\\..*\")  // Only encoder layers\n    .allow_partial(true);\n\nmodel.load_from(&mut store)?;\n```\n\n### Saving Models\n\nSave models (not supported by old recorders):\n\n```rust\n// Save to SafeTensors\nlet mut store = SafetensorsStore::from_file(\"output.safetensors\")\n    .metadata(\"version\", \"1.0\");\nmodel.save_into(&mut store)?;\n\n// Save to Burnpack (native format)\nlet mut store = BurnpackStore::from_file(\"output.bpk\");\nmodel.save_into(&mut store)?;\n```\n\n### Load Results\n\nGet detailed information about loading:\n\n```rust\nlet result = model.load_from(&mut store)?;\n\n// Print the full result for debugging - shows applied, skipped, missing, and errors\nprintln!(\"{}\", result);\n\n// Or access individual fields\nprintln!(\"Applied: {} tensors\", result.applied.len());\nprintln!(\"Skipped: {} tensors\", result.skipped.len());\nprintln!(\"Missing: {:?}\", result.missing);\nprintln!(\"Errors: {:?}\", result.errors);\n\n// Check if fully successful\nif result.is_success() {\n    println!(\"All tensors loaded successfully\");\n}\n```\n\nThe `LoadResult` implements `Display`, so printing it shows a formatted summary with suggestions for\ncommon issues (e.g., using `allow_partial(true)` for missing tensors).\n\n## Updating Cargo.toml\n\n**Before:**\n\n```toml\n[dependencies]\nburn-import = { version = \"0.x\", features = [\"pytorch\", \"safetensors\"] }\n```\n\n**After:**\n\n```toml\n[dependencies]\nburn-store = { version = \"0.x\", features = [\"pytorch\", \"safetensors\"] }\n```\n\n## Common Migration Issues\n\n### 1. Model vs Record\n\nThe new API loads directly into models, not records. Update your model initialization:\n\n```rust\n// Before: Create record, then model from record\nlet record = recorder.load(...)?;\nlet model = Model::init(&device).load_record(record);\n\n// After: Create model, then load into it\nlet mut model = Model::init(&device);\nmodel.load_from(&mut store)?;\n```\n\n### 2. Inference Functions\n\nIf you had functions that took `ModelRecord`, update them to take `Model`:\n\n```rust\n// Before\nfn infer(record: ModelRecord<B>) {\n    let model = Model::init(&device).load_record(record);\n    // ...\n}\n\n// After\nfn infer(model: Model<B>) {\n    // Model already has weights loaded\n    // ...\n}\n```\n\n### 3. Precision Settings\n\nThe old API required explicit precision settings. The new API handles this automatically:\n\n```rust\n// Before: Had to specify FullPrecisionSettings or HalfPrecisionSettings\nPyTorchFileRecorder::<FullPrecisionSettings>::default()\n\n// After: Precision handled automatically based on tensor dtype\nPytorchStore::from_file(\"model.pt\")\n```\n\n### 4. Error Handling\n\nThe new API provides richer error information:\n\n```rust\n// Before: Simple Result\nlet record = recorder.load(args, &device)?;\n\n// After: LoadResult with detailed info\nlet result = model.load_from(&mut store)?;\n\n// Print the result to see a helpful summary with suggestions\nprintln!(\"{}\", result);\n\n// Or handle specific issues programmatically\nif !result.errors.is_empty() {\n    for (path, error) in &result.errors {\n        eprintln!(\"Error loading {}: {}\", path, error);\n    }\n}\n```\n\n## See Also\n\n- [burn-store README](README.md) - Full documentation\n- [import-model-weights example](../../examples/import-model-weights/) - Working example\n"
  },
  {
    "path": "crates/burn-store/README.md",
    "content": "# Burn Store\n\n> Advanced model storage and serialization for the Burn deep learning framework\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-store.svg)](https://crates.io/crates/burn-store)\n[![Documentation](https://docs.rs/burn-store/badge.svg)](https://docs.rs/burn-store)\n\nA comprehensive storage library for Burn that enables efficient model serialization, cross-framework\ninteroperability, and advanced tensor management.\n\n> **Migrating from burn-import?** See the [Migration Guide](MIGRATION.md) for help moving from\n> `PyTorchFileRecorder`/`SafetensorsFileRecorder` to the new Store API.\n\n## Features\n\n- **Burnpack Format** - Native Burn format with CBOR metadata, memory-mapped loading, ParamId\n  persistence for stateful training, and no-std support\n- **SafeTensors Format** - Industry-standard format for secure and efficient tensor serialization\n- **PyTorch Support** - Direct loading of PyTorch .pth/.pt files with automatic weight\n  transformation\n- **Zero-Copy Loading** - Memory-mapped files and lazy tensor materialization for optimal\n  performance\n- **Flexible Filtering** - Load/save specific model subsets with regex, exact paths, or custom\n  predicates\n- **Tensor Remapping** - Rename tensors during load/save for framework compatibility\n- **Half-Precision Storage** - Automatic F32/F16 conversion with smart defaults for reduced model\n  file size\n- **No-std Support** - Burnpack and SafeTensors formats available in embedded and WASM environments\n\n## Quick Start\n\n```rust\nuse burn_store::{ModuleSnapshot, PytorchStore, SafetensorsStore, BurnpackStore, HalfPrecisionAdapter};\n\n// Load from PyTorch\nlet mut store = PytorchStore::from_file(\"model.pt\");\nmodel.load_from(&mut store)?;\n\n// Load from SafeTensors (with PyTorch adapter)\nlet mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    .with_from_adapter(PyTorchToBurnAdapter);\nmodel.load_from(&mut store)?;\n\n// Save to Burnpack\nlet mut store = BurnpackStore::from_file(\"model.bpk\");\nmodel.save_into(&mut store)?;\n\n// Save with half-precision (F32 -> F16, ~50% smaller files)\nlet adapter = HalfPrecisionAdapter::new();\nlet mut store = BurnpackStore::from_file(\"model_f16.bpk\")\n    .with_to_adapter(adapter.clone());\nmodel.save_into(&mut store)?;\n\n// Load half-precision back (F16 -> F32, same adapter)\nlet mut store = BurnpackStore::from_file(\"model_f16.bpk\")\n    .with_from_adapter(adapter);\nmodel.load_from(&mut store)?;\n```\n\n## Documentation\n\nFor comprehensive documentation including:\n\n- Exporting weights from PyTorch\n- Loading weights into Burn models\n- Saving models to various formats\n- Advanced features (filtering, remapping, partial loading, zero-copy)\n- API reference and troubleshooting\n\nSee the **[Burn Book - Saving and Loading](../../burn-book/src/saving-and-loading.md)** chapter.\n\n## Running Benchmarks\n\n```bash\n# Generate model files (one-time setup)\nuv run benches/generate_unified_models.py\n\n# Run loading benchmarks\ncargo bench --bench unified_loading\n\n# Run saving benchmarks\ncargo bench --bench unified_saving\n\n# With specific backend\ncargo bench --bench unified_loading --features metal\n```\n\n## License\n\nThis project is dual-licensed under MIT and Apache-2.0.\n"
  },
  {
    "path": "crates/burn-store/benches/download_resnet18.py",
    "content": "#!/usr/bin/env python3\n# /// script\n# requires-python = \">=3.8\"\n# dependencies = [\n#     \"torch\",\n#     \"torchvision\",\n# ]\n# ///\n\"\"\"\nDownload ResNet18 PyTorch model for benchmarking.\nThis script downloads a pre-trained ResNet18 model from PyTorch Hub\nand saves it in a format suitable for benchmarking.\n\"\"\"\n\nimport os\nimport sys\nimport tempfile\nfrom pathlib import Path\n\nimport torch\nimport torchvision.models as models\n\ndef download_resnet18():\n    \"\"\"Download ResNet18 model and save to temp directory.\"\"\"\n\n    # Create a temporary directory for the model\n    temp_dir = Path(tempfile.gettempdir()) / \"burn_resnet18_benchmark\"\n    temp_dir.mkdir(parents=True, exist_ok=True)\n\n    output_path = temp_dir / \"resnet18.pth\"\n\n    # Check if already downloaded\n    if output_path.exists():\n        file_size_mb = output_path.stat().st_size / (1024 * 1024)\n        print(f\"✅ ResNet18 already exists at: {output_path}\")\n        print(f\"   Size: {file_size_mb:.1f} MB\")\n        return str(output_path)\n\n    print(\"📥 Downloading ResNet18 model...\")\n\n    try:\n        # Download pre-trained ResNet18 model\n        model = models.resnet18(pretrained=True)\n\n        # Save the model state dict (this is what burn-store reads)\n        # Using the legacy format for compatibility\n        torch.save(model.state_dict(), output_path, _use_new_zipfile_serialization=False)\n\n        file_size_mb = output_path.stat().st_size / (1024 * 1024)\n        print(f\"✅ Successfully downloaded ResNet18 to: {output_path}\")\n        print(f\"   Size: {file_size_mb:.1f} MB\")\n        print(f\"   Format: PyTorch legacy format\")\n\n        # Verify it's readable\n        state_dict = torch.load(output_path, map_location='cpu')\n        print(f\"   Tensors: {len(state_dict)} tensors\")\n\n        # Print a few tensor names and shapes for verification\n        print(\"\\n   Sample tensors:\")\n        for i, (name, tensor) in enumerate(state_dict.items()):\n            if i < 3:\n                print(f\"     - {name}: {list(tensor.shape)}\")\n\n        return str(output_path)\n\n    except Exception as e:\n        print(f\"❌ Failed to download ResNet18: {e}\")\n        sys.exit(1)\n\ndef main():\n    \"\"\"Main entry point.\"\"\"\n    path = download_resnet18()\n\n    # Write the path to a file that the benchmark can read\n    bench_config = Path(tempfile.gettempdir()) / \"burn_resnet18_benchmark\" / \"path.txt\"\n    bench_config.write_text(path)\n\n    print(f\"\\n💡 Model ready for benchmarking\")\n    print(f\"   Run: cargo bench --bench resnet18_loading\")\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "crates/burn-store/benches/generate_unified_models.py",
    "content": "#!/usr/bin/env python3\n# /// script\n# requires-python = \">=3.8\"\n# dependencies = [\n#     \"torch\",\n#     \"safetensors\",\n#     \"packaging\",\n#     \"numpy\",\n# ]\n# ///\n\"\"\"\nGenerate a large model (~312MB) in both PyTorch and SafeTensors formats for unified benchmarking.\n\nUsage:\n    uv run benches/generate_unified_models.py\n\nThe script will create model files in /tmp/simple_bench_models/ directory.\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport os\nfrom pathlib import Path\nimport tempfile\nfrom safetensors.torch import save_file\n\ndef get_temp_dir():\n    \"\"\"Get the appropriate temp directory.\"\"\"\n    temp_dir = Path(tempfile.gettempdir()) / \"simple_bench_models\"\n    temp_dir.mkdir(parents=True, exist_ok=True)\n    return temp_dir\n\nclass LargeModel(nn.Module):\n    \"\"\"Large model with 20 layers to match Rust benchmark.\"\"\"\n    def __init__(self):\n        super().__init__()\n        self.layers = nn.ModuleList()\n\n        # Create a model with 20 layers matching the Rust LargeModel\n        for i in range(20):\n            in_size = 1024 if i == 0 else 2048\n            out_size = 2048\n            self.layers.append(nn.Linear(in_size, out_size))\n\n        print(f\"Created model with {len(self.layers)} layers\")\n\n    def forward(self, x):\n        for layer in self.layers:\n            x = layer(x)\n        return x\n\ndef calculate_model_size(model):\n    \"\"\"Calculate the size of the model in MB.\"\"\"\n    total_params = sum(p.numel() for p in model.parameters())\n    size_mb = (total_params * 4) / (1024 * 1024)  # 4 bytes per float32\n    return total_params, size_mb\n\ndef initialize_weights(model):\n    \"\"\"Initialize model weights with random values.\"\"\"\n    for param in model.parameters():\n        if param.dim() > 1:\n            nn.init.xavier_uniform_(param)\n        else:\n            nn.init.zeros_(param)\n\ndef save_pytorch_format(model, output_dir):\n    \"\"\"Save model in PyTorch format.\"\"\"\n    pt_path = output_dir / \"large_model.pt\"\n\n    # Save as checkpoint with model_state_dict (common format)\n    checkpoint = {\n        'model_state_dict': model.state_dict(),\n        'metadata': {\n            'model_type': 'large_benchmark_model',\n            'num_layers': len(model.layers),\n        }\n    }\n    torch.save(checkpoint, pt_path)\n\n    return pt_path\n\ndef save_safetensors_format(model, output_dir):\n    \"\"\"Save model in SafeTensors format.\"\"\"\n    st_path = output_dir / \"large_model.safetensors\"\n\n    # Convert state dict to safetensors format\n    state_dict = model.state_dict()\n    # Ensure all tensors are contiguous and on CPU\n    state_dict = {k: v.contiguous().cpu() for k, v in state_dict.items()}\n\n    # Save with metadata\n    metadata = {\n        'model_type': 'large_benchmark_model',\n        'num_layers': str(len(model.layers)),\n    }\n    save_file(state_dict, st_path, metadata=metadata)\n\n    return st_path\n\ndef verify_files(pt_path, st_path):\n    \"\"\"Verify the saved files can be loaded.\"\"\"\n    # Verify PyTorch file\n    checkpoint = torch.load(pt_path, map_location='cpu')\n    pt_keys = set(checkpoint['model_state_dict'].keys())\n    print(f\"  PyTorch file: {len(pt_keys)} tensors\")\n\n    # Verify SafeTensors file\n    from safetensors import safe_open\n    with safe_open(st_path, framework=\"pt\", device=\"cpu\") as f:\n        st_keys = set(f.keys())\n        print(f\"  SafeTensors file: {len(st_keys)} tensors\")\n\n    # Check keys match\n    if pt_keys != st_keys:\n        print(\"  ⚠️ Warning: Keys don't match between formats!\")\n    else:\n        print(\"  ✓ Keys match between formats\")\n\ndef main():\n    print(\"🔧 Generating unified benchmark model files...\")\n    print(\"\")\n\n    output_dir = get_temp_dir()\n    print(f\"📁 Output directory: {output_dir}\")\n    print(\"\")\n\n    # Set random seed for reproducibility\n    torch.manual_seed(42)\n\n    # Create the large model\n    print(\"📝 Creating large model...\")\n    model = LargeModel()\n\n    # Calculate and display model size\n    total_params, size_mb = calculate_model_size(model)\n    print(f\"  Total parameters: {total_params:,}\")\n    print(f\"  Model size: {size_mb:.2f} MB\")\n    print(\"\")\n\n    # Initialize weights\n    print(\"🎲 Initializing weights...\")\n    initialize_weights(model)\n\n    # Save in PyTorch format\n    print(\"💾 Saving PyTorch format...\")\n    pt_path = save_pytorch_format(model, output_dir)\n    pt_size_mb = pt_path.stat().st_size / (1024 * 1024)\n    print(f\"  Saved: {pt_path}\")\n    print(f\"  File size: {pt_size_mb:.2f} MB\")\n    print(\"\")\n\n    # Save in SafeTensors format\n    print(\"💾 Saving SafeTensors format...\")\n    st_path = save_safetensors_format(model, output_dir)\n    st_size_mb = st_path.stat().st_size / (1024 * 1024)\n    print(f\"  Saved: {st_path}\")\n    print(f\"  File size: {st_size_mb:.2f} MB\")\n    print(\"\")\n\n    # Verify files\n    print(\"🔍 Verifying saved files...\")\n    verify_files(pt_path, st_path)\n    print(\"\")\n\n    print(f\"✅ Model files generated successfully!\")\n    print(\"\")\n    print(\"📊 Summary:\")\n    print(f\"  PyTorch file: {pt_path.name} ({pt_size_mb:.2f} MB)\")\n    print(f\"  SafeTensors file: {st_path.name} ({st_size_mb:.2f} MB)\")\n    print(\"\")\n    print(\"💡 To run the unified benchmark:\")\n    print(\"   cargo bench --bench unified_loading\")\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "crates/burn-store/benches/resnet18_loading.rs",
    "content": "//! Benchmark for ResNet18 loading to verify lazy loading memory usage.\n//!\n//! resnet18.pth is pytorch's legacy file format.\n//!\n//! This benchmark loads a ResNet18 model and materializes all tensors\n//! to ensure memory usage stays reasonable with lazy loading.\n//!\n//! Run the benchmark:\n//! ```bash\n//! cargo bench --bench resnet18_loading\n//! ```\n\nuse burn_store::pytorch::PytorchReader;\nuse divan::{AllocProfiler, Bencher};\nuse std::path::PathBuf;\n\n#[global_allocator]\nstatic ALLOC: AllocProfiler = AllocProfiler::system();\n\n#[allow(clippy::manual_range_contains)]\nfn main() {\n    // Check if ResNet18 file exists\n    let path = resnet18_path();\n    if !path.exists() {\n        eprintln!(\"❌ ResNet18 model not found!\");\n        eprintln!();\n        eprintln!(\"Please download it first by running:\");\n        eprintln!(\"  python benches/download_resnet18.py\");\n        eprintln!();\n        eprintln!(\"Or if you don't have Python/PyTorch installed:\");\n        eprintln!(\"  uv run benches/download_resnet18.py\");\n        eprintln!();\n        eprintln!(\"Expected location: {}\", path.display());\n        std::process::exit(1);\n    }\n\n    // Verify file size is reasonable\n    let metadata = std::fs::metadata(&path).expect(\"Failed to read file metadata\");\n    let size_mb = metadata.len() as f64 / 1_048_576.0;\n\n    if size_mb < 40.0 || size_mb > 50.0 {\n        eprintln!(\n            \"⚠️ Warning: ResNet18 file size ({:.1} MB) seems unusual\",\n            size_mb\n        );\n        eprintln!(\"Expected size is around 45 MB\");\n    }\n\n    println!(\"✅ Found ResNet18 model at: {}\", path.display());\n    println!(\"📦 File size: {:.1} MB\", size_mb);\n    println!(\"📊 Running ResNet18 loading benchmarks...\\n\");\n\n    // Run divan benchmarks\n    divan::main();\n}\n\n/// Get the path to ResNet18 model file\nfn resnet18_path() -> PathBuf {\n    // First try to read from the path file created by download script\n    let temp_dir = std::env::temp_dir();\n    let config_file = temp_dir.join(\"burn_resnet18_benchmark\").join(\"path.txt\");\n\n    if config_file.exists()\n        && let Ok(path_str) = std::fs::read_to_string(&config_file)\n    {\n        let path = PathBuf::from(path_str.trim());\n        if path.exists() {\n            return path;\n        }\n    }\n\n    // Fallback to default location\n    temp_dir\n        .join(\"burn_resnet18_benchmark\")\n        .join(\"resnet18.pth\")\n}\n\n#[divan::bench(sample_count = 10)]\nfn load_resnet18_metadata(bencher: Bencher) {\n    let path = resnet18_path();\n\n    bencher.bench_local(|| {\n        let reader = PytorchReader::new(&path).expect(\"Failed to load ResNet18\");\n        let metadata = reader.metadata();\n\n        // Just access metadata without materializing tensors\n        assert_eq!(metadata.tensor_count, 122);\n    });\n}\n\n#[divan::bench(sample_count = 5)]\nfn load_resnet18_materialize_all(bencher: Bencher) {\n    let path = resnet18_path();\n\n    bencher.bench_local(|| {\n        let reader = PytorchReader::new(&path).expect(\"Failed to load ResNet18\");\n        let keys = reader.keys();\n\n        let mut total_bytes = 0usize;\n\n        // Materialize all tensors one by one\n        for key in &keys {\n            let tensor = reader.get(key).expect(\"Failed to get tensor\");\n            // Materialize the tensor data\n            let _data = tensor.to_data().expect(\"Failed to materialize tensor data\");\n            total_bytes += tensor.data_len();\n        }\n\n        // Verify we processed all the data\n        assert!(total_bytes > 40_000_000); // Should be ~45MB\n    });\n}\n\n#[divan::bench(sample_count = 5)]\nfn load_resnet18_materialize_sequential(bencher: Bencher) {\n    let path = resnet18_path();\n\n    bencher.bench_local(|| {\n        let reader = PytorchReader::new(&path).expect(\"Failed to load ResNet18\");\n        let keys = reader.keys();\n\n        // Materialize tensors one at a time, letting previous ones be dropped\n        // This simulates processing tensors sequentially without keeping all in memory\n        for key in &keys {\n            let tensor = reader.get(key).expect(\"Failed to get tensor\");\n            let data = tensor.to_data().expect(\"Failed to materialize tensor data\");\n\n            // Do minimal work with the data to prevent optimization\n            let sum = match data.dtype {\n                burn_tensor::DType::F32 => data\n                    .as_slice::<f32>()\n                    .map(|s| s.iter().sum::<f32>())\n                    .unwrap_or(0.0) as f64,\n                burn_tensor::DType::F64 => data\n                    .as_slice::<f64>()\n                    .map(|s| s.iter().sum::<f64>())\n                    .unwrap_or(0.0),\n                _ => 0.0,\n            };\n\n            // Use the sum to prevent dead code elimination\n            std::hint::black_box(sum);\n        }\n    });\n}\n\n#[divan::bench(sample_count = 10)]\nfn load_resnet18_largest_tensor(bencher: Bencher) {\n    let path = resnet18_path();\n\n    bencher.bench_local(|| {\n        let reader = PytorchReader::new(&path).expect(\"Failed to load ResNet18\");\n\n        // Find and materialize only the largest tensor\n        // This tests peak memory for a single tensor operation\n        let keys = reader.keys();\n        let mut largest_key = String::new();\n        let mut largest_size = 0usize;\n\n        for key in &keys {\n            let tensor = reader.get(key).expect(\"Failed to get tensor\");\n            let size = tensor.data_len();\n            if size > largest_size {\n                largest_size = size;\n                largest_key = key.clone();\n            }\n        }\n\n        // Materialize the largest tensor\n        let tensor = reader\n            .get(&largest_key)\n            .expect(\"Failed to get largest tensor\");\n        let _data = tensor.to_data().expect(\"Failed to materialize tensor data\");\n\n        assert!(largest_size > 9_000_000); // Should be ~9MB for layer4.0.conv2.weight\n    });\n}\n\n#[divan::bench(sample_count = 10)]\nfn load_resnet18_memory_profile(bencher: Bencher) {\n    let path = resnet18_path();\n\n    bencher\n        .with_inputs(|| path.clone())\n        .bench_local_values(|path| {\n            let reader = PytorchReader::new(&path).expect(\"Failed to load ResNet18\");\n            let keys = reader.keys();\n\n            let mut peak_single_tensor = 0usize;\n            let mut total_data = 0usize;\n\n            // Process each tensor and track memory\n            for key in &keys {\n                let tensor = reader.get(key).expect(\"Failed to get tensor\");\n                let tensor_size = tensor.data_len();\n\n                // Track largest single tensor\n                if tensor_size > peak_single_tensor {\n                    peak_single_tensor = tensor_size;\n                }\n\n                // Materialize the tensor\n                let data = tensor.to_data().expect(\"Failed to materialize tensor data\");\n                total_data += tensor_size;\n\n                // Drop data immediately to test lazy loading memory efficiency\n                drop(data);\n            }\n\n            // Return stats for verification\n            (peak_single_tensor, total_data)\n        });\n}\n"
  },
  {
    "path": "crates/burn-store/benches/unified_loading.rs",
    "content": "#![recursion_limit = \"256\"]\n\n//! Unified benchmark comparing all loading methods:\n//! - BurnpackStore (new native format)\n//! - NamedMpkFileRecorder (old native format)\n//! - SafetensorsStore (new)\n//! - SafetensorsFileRecorder (old)\n//! - PytorchStore (new)\n//! - PyTorchFileRecorder (old)\n//!\n//! Before running this benchmark, generate the model files:\n//! ```bash\n//! cd crates/burn-store\n//! uv run benches/generate_unified_models.py\n//! ```\n//!\n//! Then run the benchmark:\n//! ```bash\n//! cargo bench --bench unified_loading\n//! ```\n\nuse burn_core as burn;\n\nuse burn_core::module::Module;\nuse burn_core::prelude::*;\nuse burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};\n// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};\n// use burn_import::safetensors::SafetensorsFileRecorder;\nuse burn_nn as nn;\nuse burn_store::{\n    BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore,\n};\nuse divan::{AllocProfiler, Bencher};\nuse std::fs;\nuse std::path::{Path, PathBuf};\n\n#[global_allocator]\nstatic ALLOC: AllocProfiler = AllocProfiler::system();\n\n// Backend type aliases\ntype NdArrayBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(feature = \"wgpu\")]\ntype WgpuBackend = burn_wgpu::Wgpu;\n\n#[cfg(feature = \"cuda\")]\ntype CudaBackend = burn_cuda::Cuda<f32, i32>;\n\n#[cfg(feature = \"tch\")]\ntype TchBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(feature = \"metal\")]\ntype MetalBackend = burn_wgpu::Metal;\n\n// Use the same LargeModel as other benchmarks for fair comparison\n#[derive(Module, Debug)]\nstruct LargeModel<B: Backend> {\n    layers: Vec<nn::Linear<B>>,\n}\n\nimpl<B: Backend> LargeModel<B> {\n    fn new(device: &B::Device) -> Self {\n        let mut layers = Vec::new();\n        // Create a model with 20 layers - same as safetensor_loading benchmark\n        for i in 0..20 {\n            let in_size = if i == 0 { 1024 } else { 2048 };\n            layers.push(nn::LinearConfig::new(in_size, 2048).init(device));\n        }\n        Self { layers }\n    }\n}\n\n/// Get the path to the model files\nfn get_model_dir() -> PathBuf {\n    std::env::temp_dir().join(\"simple_bench_models\")\n}\n\n/// Generate Burnpack and NamedMpk files from existing SafeTensors file\nfn generate_burn_formats(st_path: &Path, bp_path: &Path, mpk_path: &Path) {\n    type TestBackend = NdArrayBackend;\n    let device = Default::default();\n\n    // Load the model from SafeTensors\n    let mut model = LargeModel::<TestBackend>::new(&device);\n    let mut store = SafetensorsStore::from_file(st_path).with_from_adapter(PyTorchToBurnAdapter);\n    model\n        .load_from(&mut store)\n        .expect(\"Failed to load from SafeTensors\");\n\n    // Save as Burnpack\n    if !bp_path.exists() {\n        println!(\"  Creating Burnpack file...\");\n        let mut burnpack_store = BurnpackStore::from_file(bp_path);\n        model\n            .save_into(&mut burnpack_store)\n            .expect(\"Failed to save as Burnpack\");\n    }\n\n    // Save as NamedMpk\n    if !mpk_path.exists() {\n        println!(\"  Creating NamedMpk file...\");\n        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();\n        model\n            .save_file(mpk_path, &recorder)\n            .expect(\"Failed to save as NamedMpk\");\n    }\n}\n\n/// Get paths to the model files\nfn get_model_paths() -> (PathBuf, PathBuf, PathBuf, PathBuf) {\n    let dir = get_model_dir();\n    (\n        dir.join(\"large_model.bpk\"),\n        dir.join(\"large_model.mpk\"),\n        dir.join(\"large_model.safetensors\"),\n        dir.join(\"large_model.pt\"),\n    )\n}\n\n/// Check if model files exist\nfn check_model_files() -> Result<(), String> {\n    let (_, _, st_path, pt_path) = get_model_paths();\n\n    // For now, only check safetensors and pytorch files (will generate burnpack/mpk later)\n    if !st_path.exists() || !pt_path.exists() {\n        return Err(format!(\n            \"\\n❌ Model files not found!\\n\\\n            \\n\\\n            Please generate the model files first by running:\\n\\\n            \\n\\\n            cd crates/burn-store\\n\\\n            uv run benches/generate_unified_models.py\\n\\\n            \\n\\\n            Expected files:\\n\\\n            - {}\\n\\\n            - {}\\n\",\n            st_path.display(),\n            pt_path.display()\n        ));\n    }\n\n    Ok(())\n}\n\nfn main() {\n    // Check if model files exist before running benchmarks\n    match check_model_files() {\n        Ok(()) => {\n            let (bp_path, mpk_path, st_path, pt_path) = get_model_paths();\n\n            // First, generate Burnpack and MPK files if they don't exist\n            if !bp_path.exists() || !mpk_path.exists() {\n                println!(\"⏳ Generating Burnpack and NamedMpk files from SafeTensors...\");\n                generate_burn_formats(&st_path, &bp_path, &mpk_path);\n            }\n\n            let bp_size = fs::metadata(&bp_path)\n                .ok()\n                .map(|m| m.len() as f64 / 1_048_576.0);\n            let mpk_size = fs::metadata(&mpk_path)\n                .ok()\n                .map(|m| m.len() as f64 / 1_048_576.0);\n            let st_size = fs::metadata(&st_path).unwrap().len() as f64 / 1_048_576.0;\n            let pt_size = fs::metadata(&pt_path).unwrap().len() as f64 / 1_048_576.0;\n\n            println!(\"✅ Found model files:\");\n            if let Some(size) = bp_size {\n                println!(\"  Burnpack: {} ({:.1} MB)\", bp_path.display(), size);\n            }\n            if let Some(size) = mpk_size {\n                println!(\"  NamedMpk: {} ({:.1} MB)\", mpk_path.display(), size);\n            }\n            println!(\"  SafeTensors: {} ({:.1} MB)\", st_path.display(), st_size);\n            println!(\"  PyTorch: {} ({:.1} MB)\", pt_path.display(), pt_size);\n            println!();\n            println!(\"🚀 Running unified loading benchmarks...\");\n            println!();\n            println!(\"Comparing 6 loading methods:\");\n            println!(\"  1. BurnpackStore (new native format - lazy loading)\");\n            println!(\"  2. NamedMpkFileRecorder (old native format - loads all to memory)\");\n            println!(\"  3. SafetensorsStore (new)\");\n            println!(\"  4. SafetensorsFileRecorder (old)\");\n            println!(\"  5. PytorchStore (new)\");\n            println!(\"  6. PyTorchFileRecorder (old)\");\n            println!();\n            println!(\"Available backends:\");\n            println!(\"  - NdArray (CPU)\");\n            #[cfg(feature = \"wgpu\")]\n            println!(\"  - WGPU (GPU)\");\n            #[cfg(feature = \"cuda\")]\n            println!(\"  - CUDA (NVIDIA GPU)\");\n            #[cfg(feature = \"tch\")]\n            println!(\"  - LibTorch\");\n            #[cfg(feature = \"metal\")]\n            println!(\"  - Metal (Apple GPU)\");\n            println!();\n\n            divan::main();\n        }\n        Err(msg) => {\n            eprintln!(\"{}\", msg);\n            std::process::exit(1);\n        }\n    }\n}\n\n// Macro to generate benchmarks for each backend\nmacro_rules! bench_backend {\n    ($backend:ty, $mod_name:ident, $backend_name:literal) => {\n        #[divan::bench_group(name = $backend_name, sample_count = 10)]\n        mod $mod_name {\n            use super::*;\n\n            type TestBackend = $backend;\n            type TestDevice = <TestBackend as Backend>::Device;\n\n            #[divan::bench]\n            fn burnpack_store(bencher: Bencher) {\n                let (bp_path, _, _, _) = get_model_paths();\n                let file_size = fs::metadata(&bp_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n                        let mut store = BurnpackStore::from_file(bp_path.clone());\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            #[divan::bench]\n            fn namedmpk_recorder(bencher: Bencher) {\n                let (_, mpk_path, _, _) = get_model_paths();\n                let file_size = fs::metadata(&mpk_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();\n                        let record = recorder\n                            .load(mpk_path.clone().into(), &device)\n                            .expect(\"Failed to load\");\n                        let _model = LargeModel::<TestBackend>::new(&device).load_record(record);\n                    });\n            }\n\n            #[divan::bench]\n            fn safetensors_store(bencher: Bencher) {\n                let (_, _, st_path, _) = get_model_paths();\n                let file_size = fs::metadata(&st_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n                        let mut store = SafetensorsStore::from_file(st_path.clone())\n                            .with_from_adapter(PyTorchToBurnAdapter);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            // #[divan::bench]\n            // fn safetensors_recorder(bencher: Bencher) {\n            //     let (_, _, st_path, _) = get_model_paths();\n            //     let file_size = fs::metadata(&st_path).unwrap().len();\n\n            //     bencher\n            //         .counter(divan::counter::BytesCount::new(file_size))\n            //         .bench(|| {\n            //             let device: TestDevice = Default::default();\n            //             let recorder = SafetensorsFileRecorder::<FullPrecisionSettings>::default();\n            //             let record = recorder\n            //                 .load(st_path.clone().into(), &device)\n            //                 .expect(\"Failed to load\");\n            //             let _model = LargeModel::<TestBackend>::new(&device).load_record(record);\n            //         });\n            // }\n\n            #[divan::bench]\n            fn pytorch_store(bencher: Bencher) {\n                let (_, _, _, pt_path) = get_model_paths();\n                let file_size = fs::metadata(&pt_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n                        let mut store = PytorchStore::from_file(pt_path.clone())\n                            .with_top_level_key(\"model_state_dict\")\n                            .allow_partial(true);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            // #[divan::bench]\n            // fn pytorch_recorder(bencher: Bencher) {\n            //     let (_, _, _, pt_path) = get_model_paths();\n            //     let file_size = fs::metadata(&pt_path).unwrap().len();\n\n            //     bencher\n            //         .counter(divan::counter::BytesCount::new(file_size))\n            //         .bench(|| {\n            //             let device: TestDevice = Default::default();\n            //             let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();\n            //             let load_args =\n            //                 LoadArgs::new(pt_path.clone()).with_top_level_key(\"model_state_dict\");\n            //             let record = recorder.load(load_args, &device).expect(\"Failed to load\");\n            //             let _model = LargeModel::<TestBackend>::new(&device).load_record(record);\n            //         });\n            // }\n        }\n    };\n}\n\n// Generate benchmarks for each backend\nbench_backend!(NdArrayBackend, ndarray_backend, \"NdArray Backend (CPU)\");\n\n#[cfg(feature = \"wgpu\")]\nbench_backend!(WgpuBackend, wgpu_backend, \"WGPU Backend (GPU)\");\n\n#[cfg(feature = \"cuda\")]\nbench_backend!(CudaBackend, cuda_backend, \"CUDA Backend (NVIDIA GPU)\");\n\n#[cfg(feature = \"tch\")]\nbench_backend!(TchBackend, tch_backend, \"LibTorch Backend\");\n\n#[cfg(feature = \"metal\")]\nbench_backend!(MetalBackend, metal_backend, \"Metal Backend (Apple GPU)\");\n"
  },
  {
    "path": "crates/burn-store/benches/unified_saving.rs",
    "content": "#![recursion_limit = \"256\"]\n\n//! Unified benchmark comparing all saving methods:\n//! - BurnpackStore (new native format)\n//! - NamedMpkFileRecorder (old native format)\n//! - SafetensorsStore (new)\n//!\n//! Before running this benchmark, ensure the directory exists:\n//! ```bash\n//! mkdir -p /tmp/simple_bench_models\n//! ```\n//!\n//! Then run the benchmark:\n//! ```bash\n//! cargo bench --bench unified_saving\n//! ```\nuse burn_core as burn;\n\nuse burn_core::module::Module;\nuse burn_core::prelude::*;\nuse burn_core::record::{FullPrecisionSettings, NamedMpkFileRecorder};\nuse burn_nn as nn;\nuse burn_store::{BurnpackStore, ModuleSnapshot, SafetensorsStore};\nuse divan::{AllocProfiler, Bencher};\nuse std::fs;\nuse std::path::PathBuf;\n\n#[global_allocator]\nstatic ALLOC: AllocProfiler = AllocProfiler::system();\n\n// Backend type aliases\ntype NdArrayBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(feature = \"wgpu\")]\ntype WgpuBackend = burn_wgpu::Wgpu;\n\n#[cfg(feature = \"cuda\")]\ntype CudaBackend = burn_cuda::Cuda<f32, i32>;\n\n#[cfg(feature = \"tch\")]\ntype TchBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(feature = \"metal\")]\ntype MetalBackend = burn_wgpu::Metal;\n\n// Use the same LargeModel as other benchmarks for fair comparison\n#[derive(Module, Debug)]\nstruct LargeModel<B: Backend> {\n    layers: Vec<nn::Linear<B>>,\n}\n\nimpl<B: Backend> LargeModel<B> {\n    fn new(device: &B::Device) -> Self {\n        let mut layers = Vec::new();\n        // Create a model with 20 layers - same as loading benchmarks\n        for i in 0..20 {\n            let in_size = if i == 0 { 1024 } else { 2048 };\n            layers.push(nn::LinearConfig::new(in_size, 2048).init(device));\n        }\n        Self { layers }\n    }\n}\n\n/// Get the path to the output directory\nfn get_output_dir() -> PathBuf {\n    std::env::temp_dir().join(\"simple_bench_models_saving\")\n}\n\n/// Ensure output directory exists\nfn ensure_output_dir() -> Result<(), String> {\n    let dir = get_output_dir();\n    if !dir.exists() {\n        fs::create_dir_all(&dir)\n            .map_err(|e| format!(\"Failed to create output directory: {}\", e))?;\n    }\n    Ok(())\n}\n\nfn main() {\n    match ensure_output_dir() {\n        Ok(()) => {\n            println!(\"✅ Output directory ready: {}\", get_output_dir().display());\n            println!();\n            println!(\"🚀 Running unified saving benchmarks...\");\n            println!();\n            println!(\"Comparing 3 saving methods:\");\n            println!(\"  1. BurnpackStore (new native format)\");\n            println!(\"  2. NamedMpkFileRecorder (old native format)\");\n            println!(\"  3. SafetensorsStore (new)\");\n            println!();\n            println!(\"Available backends:\");\n            println!(\"  - NdArray (CPU)\");\n            #[cfg(feature = \"wgpu\")]\n            println!(\"  - WGPU (GPU)\");\n            #[cfg(feature = \"cuda\")]\n            println!(\"  - CUDA (NVIDIA GPU)\");\n            #[cfg(feature = \"tch\")]\n            println!(\"  - LibTorch\");\n            #[cfg(feature = \"metal\")]\n            println!(\"  - Metal (Apple GPU)\");\n            println!();\n\n            divan::main();\n        }\n        Err(msg) => {\n            eprintln!(\"❌ {}\", msg);\n            std::process::exit(1);\n        }\n    }\n}\n\n// Macro to generate benchmarks for each backend\nmacro_rules! bench_backend {\n    ($backend:ty, $mod_name:ident, $backend_name:literal) => {\n        #[divan::bench_group(name = $backend_name, sample_count = 10)]\n        mod $mod_name {\n            use super::*;\n\n            type TestBackend = $backend;\n            type TestDevice = <TestBackend as Backend>::Device;\n\n            #[divan::bench]\n            fn burnpack_store(bencher: Bencher) {\n                bencher.bench(|| {\n                    let device: TestDevice = Default::default();\n                    let model = LargeModel::<TestBackend>::new(&device);\n                    let output_path = get_output_dir().join(\"test_burnpack.bpk\");\n                    let mut store = BurnpackStore::from_file(output_path.clone()).overwrite(true);\n                    model\n                        .save_into(&mut store)\n                        .expect(\"Failed to save with BurnpackStore\");\n                    // Clean up\n                    let _ = fs::remove_file(output_path);\n                });\n            }\n\n            #[divan::bench]\n            fn namedmpk_recorder(bencher: Bencher) {\n                bencher.bench(|| {\n                    let device: TestDevice = Default::default();\n                    let model = LargeModel::<TestBackend>::new(&device);\n                    let output_path = get_output_dir().join(\"test_namedmpk.mpk\");\n                    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();\n                    model\n                        .save_file(output_path.clone(), &recorder)\n                        .expect(\"Failed to save with NamedMpkFileRecorder\");\n                    // Clean up\n                    let _ = fs::remove_file(output_path);\n                });\n            }\n\n            #[divan::bench]\n            fn safetensors_store(bencher: Bencher) {\n                bencher.bench(|| {\n                    let device: TestDevice = Default::default();\n                    let model = LargeModel::<TestBackend>::new(&device);\n                    let output_path = get_output_dir().join(\"test_safetensors_store.safetensors\");\n                    let mut store = SafetensorsStore::from_file(output_path.clone());\n                    model\n                        .save_into(&mut store)\n                        .expect(\"Failed to save with SafetensorsStore\");\n                    // Clean up\n                    let _ = fs::remove_file(output_path);\n                });\n            }\n        }\n    };\n}\n\n// Generate benchmarks for each backend\nbench_backend!(NdArrayBackend, ndarray_backend, \"NdArray Backend (CPU)\");\n\n#[cfg(feature = \"wgpu\")]\nbench_backend!(WgpuBackend, wgpu_backend, \"WGPU Backend (GPU)\");\n\n#[cfg(feature = \"cuda\")]\nbench_backend!(CudaBackend, cuda_backend, \"CUDA Backend (NVIDIA GPU)\");\n\n#[cfg(feature = \"tch\")]\nbench_backend!(TchBackend, tch_backend, \"LibTorch Backend\");\n\n#[cfg(feature = \"metal\")]\nbench_backend!(MetalBackend, metal_backend, \"Metal Backend (Apple GPU)\");\n"
  },
  {
    "path": "crates/burn-store/benches/zero_copy_loading.rs",
    "content": "#![recursion_limit = \"256\"]\n\n//! Benchmark comparing zero-copy vs copy loading modes for BurnpackStore.\n//!\n//! This benchmark measures the performance difference between:\n//! - `zero_copy(false)` - Default mode, copies tensor data into new allocations\n//! - `zero_copy(true)` - Zero-copy mode, slices tensor data without copying\n//!\n//! ## Understanding the Results\n//!\n//! **IMPORTANT**: For NdArray backend, you'll see similar allocation numbers because:\n//! - NdArray uses `ndarray::ArrayD` which MUST own data as `Vec<T>`\n//! - Even with zero-copy, the backend eventually copies data into its own format\n//!\n//! The zero-copy benefit is:\n//! - **Without zero-copy**: File → Copy to heap (Bytes) → Copy to Vec (backend)\n//! - **With zero-copy**: File → Zero-copy slice → Copy to Vec (backend)\n//!\n//! So zero-copy saves ONE memory copy at the store level. The `store_only_*` benchmarks\n//! show the raw store performance without backend allocation overhead.\n//!\n//! GPU backends that can consume `Bytes` directly will show larger benefits.\n//!\n//! ## Running the benchmark\n//!\n//! Before running this benchmark, generate the model files:\n//! ```bash\n//! cd crates/burn-store\n//! uv run benches/generate_unified_models.py\n//! ```\n//!\n//! Then run the benchmark:\n//! ```bash\n//! cargo bench --bench zero_copy_loading\n//! ```\n\nuse burn_core as burn;\n\nuse burn_core::module::Module;\nuse burn_core::prelude::*;\nuse burn_nn as nn;\nuse burn_store::{\n    BurnpackStore, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore,\n};\nuse burn_tensor::{AllocationProperty, Bytes};\nuse divan::{AllocProfiler, Bencher};\nuse std::fs;\nuse std::path::PathBuf;\nuse std::sync::OnceLock;\n\n#[global_allocator]\nstatic ALLOC: AllocProfiler = AllocProfiler::system();\n\n// Static storage for embedded model bytes (simulating include_bytes!)\nstatic STATIC_MODEL_BYTES: OnceLock<&'static [u8]> = OnceLock::new();\n\n// Backend type aliases\ntype NdArrayBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(feature = \"wgpu\")]\ntype WgpuBackend = burn_wgpu::Wgpu;\n\n#[cfg(feature = \"cuda\")]\ntype CudaBackend = burn_cuda::Cuda<f32, i32>;\n\n#[cfg(feature = \"tch\")]\ntype TchBackend = burn_tch::LibTorch<f32>;\n\n#[cfg(feature = \"metal\")]\ntype MetalBackend = burn_wgpu::Metal;\n\n// Use the same LargeModel as other benchmarks for fair comparison\n#[derive(Module, Debug)]\nstruct LargeModel<B: Backend> {\n    layers: Vec<nn::Linear<B>>,\n}\n\nimpl<B: Backend> LargeModel<B> {\n    fn new(device: &B::Device) -> Self {\n        let mut layers = Vec::new();\n        // Create a model with 20 layers - same as unified_loading benchmark\n        for i in 0..20 {\n            let in_size = if i == 0 { 1024 } else { 2048 };\n            layers.push(nn::LinearConfig::new(in_size, 2048).init(device));\n        }\n        Self { layers }\n    }\n}\n\n/// Get the path to the model files\nfn get_model_dir() -> PathBuf {\n    std::env::temp_dir().join(\"simple_bench_models\")\n}\n\n/// Get path to Burnpack model file\nfn get_burnpack_path() -> PathBuf {\n    get_model_dir().join(\"large_model.bpk\")\n}\n\n/// Generate Burnpack file from existing SafeTensors file if needed\nfn ensure_burnpack_file() {\n    let bp_path = get_burnpack_path();\n    let st_path = get_model_dir().join(\"large_model.safetensors\");\n\n    if bp_path.exists() {\n        return;\n    }\n\n    if !st_path.exists() {\n        panic!(\n            \"\\n❌ SafeTensors model file not found!\\n\\\n            \\n\\\n            Please generate the model files first by running:\\n\\\n            \\n\\\n            cd crates/burn-store\\n\\\n            uv run benches/generate_unified_models.py\\n\\\n            \\n\\\n            Expected file: {}\\n\",\n            st_path.display()\n        );\n    }\n\n    println!(\"⏳ Generating Burnpack file from SafeTensors...\");\n\n    type TestBackend = NdArrayBackend;\n    let device = Default::default();\n\n    // Load from SafeTensors\n    let mut model = LargeModel::<TestBackend>::new(&device);\n    let mut store = SafetensorsStore::from_file(&st_path).with_from_adapter(PyTorchToBurnAdapter);\n    model\n        .load_from(&mut store)\n        .expect(\"Failed to load from SafeTensors\");\n\n    // Save as Burnpack\n    let mut burnpack_store = BurnpackStore::from_file(&bp_path);\n    model\n        .save_into(&mut burnpack_store)\n        .expect(\"Failed to save as Burnpack\");\n\n    println!(\"✅ Created Burnpack file: {}\", bp_path.display());\n}\n\n/// Initialize static model bytes (simulating include_bytes! at runtime for benchmarks)\nfn get_static_model_bytes() -> &'static [u8] {\n    STATIC_MODEL_BYTES.get_or_init(|| {\n        let bp_path = get_burnpack_path();\n        let bytes = fs::read(&bp_path).expect(\"Failed to read Burnpack file\");\n        // Leak the bytes to get a 'static lifetime (acceptable for benchmarks)\n        Box::leak(bytes.into_boxed_slice())\n    })\n}\n\nfn main() {\n    // Ensure Burnpack file exists\n    ensure_burnpack_file();\n\n    let bp_path = get_burnpack_path();\n    let file_size = fs::metadata(&bp_path).unwrap().len() as f64 / 1_048_576.0;\n\n    println!(\"✅ Found Burnpack model file:\");\n    println!(\"  Path: {}\", bp_path.display());\n    println!(\"  Size: {:.1} MB\", file_size);\n    println!();\n    println!(\"🚀 Running zero-copy loading benchmarks...\");\n    println!();\n    println!(\"Comparing loading modes:\");\n    println!(\"  1. file_copy        - from_file().zero_copy(false) - copies tensor data\");\n    println!(\"  2. file_zero_copy   - from_file().zero_copy(true)  - zero-copy via mmap\");\n    println!(\"  3. static_copy      - from_bytes() with Vec copy   - copies from static\");\n    println!(\"  4. static_zero_copy - from_static()                - zero-copy from static\");\n    println!();\n    println!(\"Available backends:\");\n    println!(\"  - NdArray (CPU)\");\n    #[cfg(feature = \"wgpu\")]\n    println!(\"  - WGPU (GPU)\");\n    #[cfg(feature = \"cuda\")]\n    println!(\"  - CUDA (NVIDIA GPU)\");\n    #[cfg(feature = \"tch\")]\n    println!(\"  - LibTorch\");\n    #[cfg(feature = \"metal\")]\n    println!(\"  - Metal (Apple GPU)\");\n    println!();\n\n    // Pre-initialize static bytes before benchmarks\n    let _ = get_static_model_bytes();\n\n    divan::main();\n}\n\n// Macro to generate benchmarks for each backend\nmacro_rules! bench_backend {\n    ($backend:ty, $mod_name:ident, $backend_name:literal) => {\n        #[divan::bench_group(name = $backend_name, sample_count = 10)]\n        mod $mod_name {\n            use super::*;\n\n            type TestBackend = $backend;\n            type TestDevice = <TestBackend as Backend>::Device;\n\n            /// File-based loading with copy mode (default)\n            #[divan::bench]\n            fn file_copy(bencher: Bencher) {\n                let bp_path = get_burnpack_path();\n                let file_size = fs::metadata(&bp_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n                        let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            /// File-based loading with zero-copy mode (mmap + bytes::Bytes)\n            #[divan::bench]\n            fn file_zero_copy(bencher: Bencher) {\n                let bp_path = get_burnpack_path();\n                let file_size = fs::metadata(&bp_path).unwrap().len();\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n                        let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            /// Static bytes with copy mode (simulating old behavior)\n            #[divan::bench]\n            fn static_copy(bencher: Bencher) {\n                let static_bytes = get_static_model_bytes();\n                let file_size = static_bytes.len() as u64;\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n\n                        // Simulate old behavior: copy static bytes to Vec, then load\n                        let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());\n                        let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            /// Static bytes with zero-copy mode (new from_static)\n            #[divan::bench]\n            fn static_zero_copy(bencher: Bencher) {\n                let static_bytes = get_static_model_bytes();\n                let file_size = static_bytes.len() as u64;\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n\n                        // Zero-copy: use from_static which keeps data in .rodata\n                        let mut store = BurnpackStore::from_static(static_bytes);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n\n            /// In-memory shared bytes with zero-copy\n            #[divan::bench]\n            fn memory_shared_zero_copy(bencher: Bencher) {\n                let static_bytes = get_static_model_bytes();\n                let file_size = static_bytes.len() as u64;\n\n                // Pre-create shared bytes outside the benchmark loop\n                let shared = bytes::Bytes::from_static(static_bytes);\n\n                bencher\n                    .counter(divan::counter::BytesCount::new(file_size))\n                    .bench(|| {\n                        let device: TestDevice = Default::default();\n                        let mut model = LargeModel::<TestBackend>::new(&device);\n\n                        // Create Bytes from shared (cheap clone of Arc)\n                        let bytes = Bytes::from_shared(shared.clone(), AllocationProperty::Other);\n                        let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(true);\n                        model.load_from(&mut store).expect(\"Failed to load\");\n                    });\n            }\n        }\n    };\n}\n\n// =============================================================================\n// Zero-copy verification (proves operations use static region data)\n// =============================================================================\n\n/// Verify that zero-copy loading actually uses data from the static region.\n/// This runs once at startup to prove correctness before benchmarking.\n#[divan::bench_group(name = \"Zero-Copy Verification\", sample_count = 1)]\nmod verification {\n    use super::*;\n    use burn_ndarray::NdArray;\n\n    type B = NdArray<f32>;\n\n    /// Verify zero-copy: tensor storage is borrowed (not owned)\n    #[divan::bench]\n    fn verify_storage_is_borrowed() {\n        let static_bytes = get_static_model_bytes();\n\n        // Load model with zero-copy from static bytes\n        let device = Default::default();\n        let mut model = LargeModel::<B>::new(&device);\n        let mut store = BurnpackStore::from_static(static_bytes);\n        model.load_from(&mut store).expect(\"Failed to load\");\n\n        // Get the first layer's weight tensor and verify it uses borrowed storage\n        let weight = model.layers[0].weight.val();\n        // .into_primitive() returns TensorPrimitive<B>, .tensor() extracts B::FloatTensorPrimitive\n        let ndarray_tensor = weight.into_primitive().tensor();\n\n        // Verify the storage is borrowed (zero-copy from static region)\n        assert!(\n            ndarray_tensor.is_borrowed(),\n            \"ZERO-COPY FAILURE: Tensor storage is NOT borrowed. \\\n             Data was copied instead of being zero-copy!\"\n        );\n\n        println!(\"✅ Verified: Tensor storage is borrowed (zero-copy from static region)\");\n    }\n\n    /// Verify ALL layers use borrowed (zero-copy) storage.\n    /// This is the key proof that loaded weights point to static memory.\n    #[divan::bench]\n    fn verify_all_layers_borrowed() {\n        let static_bytes = get_static_model_bytes();\n\n        // Load model with zero-copy\n        let device = Default::default();\n        let mut model = LargeModel::<B>::new(&device);\n        let mut store = BurnpackStore::from_static(static_bytes);\n        model.load_from(&mut store).expect(\"Failed to load\");\n\n        // Check ALL layers have borrowed storage\n        let mut total_elements = 0usize;\n        for (i, layer) in model.layers.iter().enumerate() {\n            let weight = layer.weight.val();\n            total_elements += weight.shape().num_elements();\n\n            assert!(\n                weight.into_primitive().tensor().is_borrowed(),\n                \"Layer {} weight should be borrowed (zero-copy)\",\n                i\n            );\n        }\n\n        let total_mb = (total_elements * 4) as f64 / 1_048_576.0;\n        println!(\n            \"✅ Verified: All {} layers use borrowed storage\",\n            model.layers.len()\n        );\n        println!(\n            \"   - Model size: {:.2} MB - all pointing to static region\",\n            total_mb\n        );\n    }\n\n    /// Verify data is readable and correct using sum().into_scalar().\n    /// Note: sum() triggers COW copy, so this shows ops work correctly on zero-copy data.\n    #[divan::bench]\n    fn verify_ops_produce_correct_results() {\n        let static_bytes = get_static_model_bytes();\n\n        let device = Default::default();\n        let mut model = LargeModel::<B>::new(&device);\n        let mut store = BurnpackStore::from_static(static_bytes);\n        model.load_from(&mut store).expect(\"Failed to load\");\n\n        // Compute sum of first layer weight - proves data is valid\n        let weight = model.layers[0].weight.val();\n        let sum: f32 = weight.sum().into_scalar();\n\n        assert!(sum.is_finite(), \"Sum should be finite\");\n        println!(\"✅ Verified: Operations on zero-copy data produce valid results\");\n        println!(\"   - First layer sum: {:.4}\", sum);\n    }\n\n    /// Verify operations produce correct results on zero-copy data\n    #[divan::bench]\n    fn verify_operations_on_static_data() {\n        let static_bytes = get_static_model_bytes();\n\n        // Load model with zero-copy\n        let device = Default::default();\n        let mut model = LargeModel::<B>::new(&device);\n        let mut store = BurnpackStore::from_static(static_bytes);\n        model.load_from(&mut store).expect(\"Failed to load\");\n\n        // Perform operations on the loaded weights\n        let weight = model.layers[0].weight.val();\n        let shape = weight.shape();\n\n        // Test 1: Sum should be finite (not NaN or Inf)\n        let sum: f32 = weight.clone().sum().to_data().to_vec().unwrap()[0];\n        assert!(\n            sum.is_finite(),\n            \"Operation failed: sum is not finite ({})\",\n            sum\n        );\n\n        // Test 2: Matrix multiply with itself transposed (W @ W.T)\n        let transposed = weight.clone().transpose();\n        let matmul_result = weight.clone().matmul(transposed);\n        let matmul_sum: f32 = matmul_result.sum().to_data().to_vec().unwrap()[0];\n        assert!(\n            matmul_sum.is_finite(),\n            \"Matmul failed: result sum is not finite ({})\",\n            matmul_sum\n        );\n\n        // Test 3: Element-wise operations\n        let doubled = weight.clone() * 2.0;\n        let doubled_sum: f32 = doubled.sum().to_data().to_vec().unwrap()[0];\n        assert!(\n            (doubled_sum - sum * 2.0).abs() < 1e-3,\n            \"Element-wise op failed: doubled_sum ({}) != sum*2 ({})\",\n            doubled_sum,\n            sum * 2.0\n        );\n\n        println!(\"✅ Verified: Operations on zero-copy data produce correct results\");\n        println!(\"   - Weight shape: {:?}\", shape.as_slice());\n        println!(\"   - Sum: {:.4}\", sum);\n        println!(\"   - Matmul result sum: {:.4}\", matmul_sum);\n    }\n\n    /// Compare zero-copy vs copy: verify both produce identical results\n    #[divan::bench]\n    fn verify_copy_vs_zero_copy_equality() {\n        let static_bytes = get_static_model_bytes();\n        let device: <B as Backend>::Device = Default::default();\n\n        // Load with zero-copy\n        let mut model_zc = LargeModel::<B>::new(&device);\n        let mut store_zc = BurnpackStore::from_static(static_bytes);\n        model_zc\n            .load_from(&mut store_zc)\n            .expect(\"Failed to load zero-copy\");\n\n        // Load with copy (simulate old behavior)\n        let mut model_copy = LargeModel::<B>::new(&device);\n        let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());\n        let mut store_copy = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);\n        model_copy\n            .load_from(&mut store_copy)\n            .expect(\"Failed to load copy\");\n\n        // Compare weights from both models\n        for (i, (layer_zc, layer_copy)) in model_zc\n            .layers\n            .iter()\n            .zip(model_copy.layers.iter())\n            .enumerate()\n        {\n            let weight_zc = layer_zc.weight.val();\n            let weight_copy = layer_copy.weight.val();\n\n            // Check shapes match\n            assert_eq!(\n                weight_zc.shape(),\n                weight_copy.shape(),\n                \"Layer {} weight shapes don't match\",\n                i\n            );\n\n            // Check values match (using sum as a proxy)\n            let sum_zc: f32 = weight_zc.clone().sum().to_data().to_vec().unwrap()[0];\n            let sum_copy: f32 = weight_copy.clone().sum().to_data().to_vec().unwrap()[0];\n            assert!(\n                (sum_zc - sum_copy).abs() < 1e-6,\n                \"Layer {} weight sums don't match: zero-copy={}, copy={}\",\n                i,\n                sum_zc,\n                sum_copy\n            );\n        }\n\n        println!(\n            \"✅ Verified: Zero-copy and copy loading produce identical results for all {} layers\",\n            model_zc.layers.len()\n        );\n    }\n}\n\n// =============================================================================\n// Store-only benchmarks (no backend allocation overhead)\n// These show the TRUE zero-copy benefit at the store level\n// =============================================================================\n\n#[divan::bench_group(name = \"Store Only (no backend)\", sample_count = 10)]\nmod store_only {\n    use super::*;\n\n    /// File-based store with copy mode - measures store overhead only\n    #[divan::bench]\n    fn file_copy(bencher: Bencher) {\n        let bp_path = get_burnpack_path();\n        let file_size = fs::metadata(&bp_path).unwrap().len();\n\n        bencher\n            .counter(divan::counter::BytesCount::new(file_size))\n            .bench(|| {\n                let mut store = BurnpackStore::from_file(&bp_path).zero_copy(false);\n                // Just iterate through all tensor snapshots, calling to_data() on each\n                // This forces the store to read and materialize all tensor data\n                let snapshots = store.get_all_snapshots().expect(\"Failed to get snapshots\");\n                for snapshot in snapshots.values() {\n                    let _data = snapshot.to_data().expect(\"Failed to get tensor data\");\n                }\n            });\n    }\n\n    /// File-based store with zero-copy mode - measures store overhead only\n    #[divan::bench]\n    fn file_zero_copy(bencher: Bencher) {\n        let bp_path = get_burnpack_path();\n        let file_size = fs::metadata(&bp_path).unwrap().len();\n\n        bencher\n            .counter(divan::counter::BytesCount::new(file_size))\n            .bench(|| {\n                let mut store = BurnpackStore::from_file(&bp_path).zero_copy(true);\n                let snapshots = store.get_all_snapshots().expect(\"Failed to get snapshots\");\n                for snapshot in snapshots.values() {\n                    let _data = snapshot.to_data().expect(\"Failed to get tensor data\");\n                }\n            });\n    }\n\n    /// Static bytes with copy mode - measures store overhead only\n    #[divan::bench]\n    fn static_copy(bencher: Bencher) {\n        let static_bytes = get_static_model_bytes();\n        let file_size = static_bytes.len() as u64;\n\n        bencher\n            .counter(divan::counter::BytesCount::new(file_size))\n            .bench(|| {\n                // Simulate old behavior: copy static bytes to Vec\n                let bytes = Bytes::from_bytes_vec(static_bytes.to_vec());\n                let mut store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(false);\n                let snapshots = store.get_all_snapshots().expect(\"Failed to get snapshots\");\n                for snapshot in snapshots.values() {\n                    let _data = snapshot.to_data().expect(\"Failed to get tensor data\");\n                }\n            });\n    }\n\n    /// Static bytes with zero-copy mode - measures store overhead only\n    #[divan::bench]\n    fn static_zero_copy(bencher: Bencher) {\n        let static_bytes = get_static_model_bytes();\n        let file_size = static_bytes.len() as u64;\n\n        bencher\n            .counter(divan::counter::BytesCount::new(file_size))\n            .bench(|| {\n                let mut store = BurnpackStore::from_static(static_bytes);\n                let snapshots = store.get_all_snapshots().expect(\"Failed to get snapshots\");\n                for snapshot in snapshots.values() {\n                    let _data = snapshot.to_data().expect(\"Failed to get tensor data\");\n                }\n            });\n    }\n}\n\n// =============================================================================\n// Full model loading benchmarks (includes backend allocation)\n// =============================================================================\n\n// Generate benchmarks for each backend\nbench_backend!(NdArrayBackend, ndarray_backend, \"NdArray Backend (CPU)\");\n\n#[cfg(feature = \"wgpu\")]\nbench_backend!(WgpuBackend, wgpu_backend, \"WGPU Backend (GPU)\");\n\n#[cfg(feature = \"cuda\")]\nbench_backend!(CudaBackend, cuda_backend, \"CUDA Backend (NVIDIA GPU)\");\n\n#[cfg(feature = \"tch\")]\nbench_backend!(TchBackend, tch_backend, \"LibTorch Backend\");\n\n#[cfg(feature = \"metal\")]\nbench_backend!(MetalBackend, metal_backend, \"Metal Backend (Apple GPU)\");\n"
  },
  {
    "path": "crates/burn-store/examples/burnpack_inspect.rs",
    "content": "//! Example: Generate a Burnpack file for inspection\n//!\n//! This example creates a simple Burnpack file that you can examine to understand the format.\n//!\n//! Usage:\n//!   cargo run --example burnpack-inspect [output_path]\n//!\n//! Example:\n//!   cargo run --example burnpack-inspect sample.bpk\n//!   cargo run --example burnpack-inspect /tmp/test.bpk\n//!\n//! After generating the file, examine it with:\n//!   hexdump -C sample.bpk | head -100\n//!   xxd sample.bpk | head -100\n//!   hexyl sample.bpk\nuse burn_core as burn;\n\nuse burn_core::module::Module;\nuse burn_ndarray::NdArray;\nuse burn_nn::{Linear, LinearConfig};\nuse burn_store::{BurnpackStore, ModuleSnapshot};\nuse burn_tensor::backend::Backend;\nuse std::env;\n\n// Simple model with a few layers\n#[derive(Module, Debug)]\nstruct SampleModel<B: Backend> {\n    linear1: Linear<B>,\n    linear2: Linear<B>,\n    linear3: Linear<B>,\n}\n\nimpl<B: Backend> SampleModel<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            linear1: LinearConfig::new(128, 64).init(device),\n            linear2: LinearConfig::new(64, 32).init(device),\n            linear3: LinearConfig::new(32, 10).init(device),\n        }\n    }\n}\n\nfn main() {\n    type Backend = NdArray<f32>;\n\n    // Get output path from command line or use default\n    let output_path = env::args()\n        .nth(1)\n        .unwrap_or_else(|| \"sample.bpk\".to_string());\n\n    println!(\"Creating sample Burnpack file: {}\", output_path);\n    println!();\n\n    // Create a simple model\n    let device = Default::default();\n    let model = SampleModel::<Backend>::new(&device);\n\n    // Save to Burnpack format with metadata\n    let mut store = BurnpackStore::from_file(&output_path)\n        .overwrite(true)\n        .metadata(\"format\", \"burnpack\")\n        .metadata(\"description\", \"Sample file for examining Burnpack format\")\n        .metadata(\"version\", env!(\"CARGO_PKG_VERSION\"))\n        .metadata(\"author\", \"Burn Example\");\n\n    model.save_into(&mut store).expect(\"Failed to save model\");\n\n    println!(\"✅ Successfully created: {}\", output_path);\n    println!();\n    println!(\"📋 File Structure:\");\n    println!(\"  ┌─────────────────────────────────────┐\");\n    println!(\"  │ Header (10 bytes)                   │\");\n    println!(\"  ├─────────────────────────────────────┤\");\n    println!(\"  │ - Magic: 0x4E525542 (BURN in LE)   │\");\n    println!(\"  │ - Version: 0x0001 (2 bytes)         │\");\n    println!(\"  │ - Metadata size: (4 bytes, u32 LE)  │\");\n    println!(\"  ├─────────────────────────────────────┤\");\n    println!(\"  │ Metadata (CBOR format)              │\");\n    println!(\"  ├─────────────────────────────────────┤\");\n    println!(\"  │ - Tensor descriptors                │\");\n    println!(\"  │   * name, dtype, shape, offsets     │\");\n    println!(\"  │ - User metadata                     │\");\n    println!(\"  ├─────────────────────────────────────┤\");\n    println!(\"  │ Tensor Data (raw bytes, LE)         │\");\n    println!(\"  ├─────────────────────────────────────┤\");\n    println!(\"  │ - linear1.weight [64, 128]          │\");\n    println!(\"  │ - linear1.bias [64]                 │\");\n    println!(\"  │ - linear2.weight [32, 64]           │\");\n    println!(\"  │ - linear2.bias [32]                 │\");\n    println!(\"  │ - linear3.weight [10, 32]           │\");\n    println!(\"  │ - linear3.bias [10]                 │\");\n    println!(\"  └─────────────────────────────────────┘\");\n    println!();\n    println!(\"📊 Model Contents:\");\n    println!(\"  - linear1.weight: [64, 128] = 8,192 params → 32,768 bytes\");\n    println!(\"  - linear1.bias:   [64]      = 64 params    → 256 bytes\");\n    println!(\"  - linear2.weight: [32, 64]  = 2,048 params → 8,192 bytes\");\n    println!(\"  - linear2.bias:   [32]      = 32 params    → 128 bytes\");\n    println!(\"  - linear3.weight: [10, 32]  = 320 params   → 1,280 bytes\");\n    println!(\"  - linear3.bias:   [10]      = 10 params    → 40 bytes\");\n    println!(\"  ───────────────────────────────────────────────────────\");\n\n    let total_params = 8192 + 64 + 2048 + 32 + 320 + 10;\n    let total_bytes = total_params * 4;\n    println!(\n        \"  Total: {} parameters = {} KB\",\n        total_params,\n        total_bytes / 1024\n    );\n    println!();\n\n    // Get actual file size\n    if let Ok(metadata) = std::fs::metadata(&output_path) {\n        let file_size = metadata.len();\n        println!(\n            \"📦 File size: {} bytes ({:.2} KB)\",\n            file_size,\n            file_size as f64 / 1024.0\n        );\n    }\n\n    println!();\n    println!(\"🔍 Inspection Commands:\");\n    println!();\n    println!(\"  # View first 100 bytes in hex:\");\n    println!(\"  hexdump -C {} | head -20\", output_path);\n    println!();\n    println!(\"  # View header only (10 bytes):\");\n    println!(\"  head -c 10 {} | hexdump -C\", output_path);\n    println!();\n    println!(\"  # View with prettier hex viewer (if installed):\");\n    println!(\"  hexyl {} | head -50\", output_path);\n    println!();\n    println!(\"  # View in binary format:\");\n    println!(\"  xxd -b {} | head -20\", output_path);\n    println!();\n    println!(\"  # Extract and examine header:\");\n    println!(\"  # Magic (bytes 0-3): Should be 42 55 52 4E (BURN)\");\n    println!(\"  # Version (bytes 4-5): Should be 01 00\");\n    println!(\"  # Metadata size (bytes 6-9): u32 little-endian\");\n    println!();\n    println!(\"  # Load back the model:\");\n    println!(\n        \"  # let mut store = BurnpackStore::from_file(\\\"{}\\\");\",\n        output_path\n    );\n    println!(\"  # model.load_from(&mut store)?;\");\n}\n"
  },
  {
    "path": "crates/burn-store/examples/half_precision.rs",
    "content": "//! Example: Save and load a model with half-precision (F32 <-> F16)\n//!\n//! Demonstrates using HalfPrecisionAdapter to automatically convert between\n//! F32 and F16 during saving/loading. The same adapter instance handles both\n//! directions.\n//!\n//! Usage:\n//!   cargo run -p burn-store --example half_precision\n\nuse burn_core as burn;\nuse burn_core::module::Module;\nuse burn_ndarray::NdArray;\nuse burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};\nuse burn_store::{BurnpackStore, HalfPrecisionAdapter, ModuleSnapshot};\nuse burn_tensor::backend::Backend;\n\n// A model with mixed layer types to show selective conversion\n#[derive(Module, Debug)]\nstruct DemoModel<B: Backend> {\n    linear1: Linear<B>,\n    norm: LayerNorm<B>,\n    linear2: Linear<B>,\n}\n\nimpl<B: Backend> DemoModel<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            linear1: LinearConfig::new(128, 64).init(device),\n            norm: LayerNormConfig::new(64).init(device),\n            linear2: LinearConfig::new(64, 10).init(device),\n        }\n    }\n}\n\nfn main() {\n    type B = NdArray<f32>;\n    let device = Default::default();\n    let model = DemoModel::<B>::new(&device);\n\n    // 1) Save at full F32 precision (baseline)\n    let dir = tempfile::tempdir().expect(\"Failed to create temp dir\");\n    let path_f32 = dir.path().join(\"model_f32\");\n    let path_f16 = dir.path().join(\"model_f16\");\n    let path_mixed = dir.path().join(\"model_mixed\");\n\n    let mut store = BurnpackStore::from_file(path_f32.to_str().unwrap()).overwrite(true);\n    model.save_into(&mut store).expect(\"Failed to save F32\");\n    let size_f32 = std::fs::metadata(format!(\"{}.bpk\", path_f32.display()))\n        .map(|m| m.len())\n        .unwrap_or(0);\n\n    // 2) Save with default half-precision (all default modules get F16)\n    let adapter = HalfPrecisionAdapter::new();\n    let mut store = BurnpackStore::from_file(path_f16.to_str().unwrap())\n        .overwrite(true)\n        .with_to_adapter(adapter.clone());\n    model.save_into(&mut store).expect(\"Failed to save F16\");\n    let size_f16 = std::fs::metadata(format!(\"{}.bpk\", path_f16.display()))\n        .map(|m| m.len())\n        .unwrap_or(0);\n\n    // 3) Save with without_module: keep LayerNorm at F32\n    let adapter_no_norm = HalfPrecisionAdapter::new().without_module(\"LayerNorm\");\n    let mut store = BurnpackStore::from_file(path_mixed.to_str().unwrap())\n        .overwrite(true)\n        .with_to_adapter(adapter_no_norm);\n    model.save_into(&mut store).expect(\"Failed to save mixed\");\n    let size_mixed = std::fs::metadata(format!(\"{}.bpk\", path_mixed.display()))\n        .map(|m| m.len())\n        .unwrap_or(0);\n\n    println!(\"F32 (full precision):    {} bytes\", size_f32);\n    println!(\"F16 (default modules):   {} bytes\", size_f16);\n    println!(\"Mixed (norm stays F32):  {} bytes\", size_mixed);\n    println!(\n        \"F16 savings: {:.1}%\",\n        (1.0 - size_f16 as f64 / size_f32 as f64) * 100.0\n    );\n\n    // 4) Round-trip: load the F16 file back to F32 with the same adapter\n    let mut load_store =\n        BurnpackStore::from_file(path_f16.to_str().unwrap()).with_from_adapter(adapter);\n    let mut model2 = DemoModel::<B>::new(&device);\n    let result = model2.load_from(&mut load_store).expect(\"Failed to load\");\n    println!(\n        \"\\nRound-trip loaded {} tensors successfully\",\n        result.applied.len()\n    );\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/Cargo.toml",
    "content": "[package]\nname = \"pytorch-tests\"\nversion.workspace = true\nedition.workspace = true\nlicense.workspace = true\n\n[dev-dependencies]\nburn = { path = \"../../burn\" }\nburn-ndarray = { path = \"../../burn-ndarray\" }\nburn-autodiff = { path = \"../../burn-autodiff\" }\nburn-store = { path = \"../\", features = [\"std\", \"pytorch\"] }\nserde = { workspace = true }\nfloat-cmp = { workspace = true }\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/src/lib.rs",
    "content": "\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/backend.rs",
    "content": "pub type TestBackend = burn_ndarray::NdArray<f32>;\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/batch_norm/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.norm1 = nn.BatchNorm2d(5)\n        \n    def forward(self, x):\n        x = self.norm1(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    # Condition batch norm (each forward will affect the running stats)\n    x1 = torch.ones(1, 5, 2, 2) - 0.5\n    _ = model(x1)\n    model.eval() # Set to eval mode to freeze running stats\n    # Save the model after the first forward\n    torch.save(model.state_dict(), \"batch_norm2d.pt\")\n    \n    x2 = torch.ones(1, 5, 2, 2) - 0.3\n    print(\"Input shape: {}\", x2.shape)\n    output = model(x2)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/batch_norm/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{BatchNorm, BatchNormConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    norm1: BatchNorm<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            norm1: BatchNormConfig::new(5).init(device), // Python model uses BatchNorm2d(5)\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.norm1.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    #[test]\n    fn batch_norm2d() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::new(&device);\n        let mut store = PytorchStore::from_file(\"tests/batch_norm/batch_norm2d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::ones([1, 5, 2, 2], &device) - 0.3;\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],\n                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],\n                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],\n                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],\n                [[0.68515635, 0.68515635], [0.68515635, 0.68515635]],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/boolean/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        buffer = torch.tensor([True, False, True])\n        self.register_buffer(\"buffer\", buffer, persistent=True)\n        \n    def forward(self, x):\n        x = self.buffer\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"boolean.pt\")\n    \n    input = torch.ones(3, 3)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/boolean/mod.rs",
    "content": "use burn::{\n    module::{Module, Param, ParamId},\n    tensor::{Bool, Tensor, TensorData, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    buffer: Param<Tensor<B, 1, Bool>>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model with placeholder values.\n    pub fn init(device: &B::Device) -> Self {\n        Self {\n            buffer: Param::initialized(\n                ParamId::new(),\n                Tensor::from_bool(TensorData::from([false, false, false]), device),\n            ),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Bool> {\n        self.buffer.val()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n\n    use burn::tensor::TensorData;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    use crate::backend::TestBackend;\n\n    #[test]\n    fn boolean() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/boolean/boolean.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 1, Bool>::from_bool(\n            TensorData::from([true, false, true]),\n            &device,\n        );\n\n        assert_eq!(output.to_data(), expected.to_data());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/buffer/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        buffer = torch.ones(3, 3)\n        self.register_buffer(\"buffer\", buffer, persistent=True)\n        \n    def forward(self, x):\n        x = self.buffer + x\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"buffer.pt\")\n    \n    input = torch.ones(3, 3)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/buffer/mod.rs",
    "content": "use burn::{\n    module::{Module, Param},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    buffer: Param<Tensor<B, 2>>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model with placeholder values.\n    pub fn init(device: &B::Device) -> Self {\n        Self {\n            buffer: Param::from_tensor(Tensor::zeros([3, 3], device)),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {\n        self.buffer.val() + x\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    #[test]\n    fn buffer() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/buffer/buffer.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 2>::ones([3, 3], &device) * 2.0;\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/complex_nested/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass ConvBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size):\n        super(ConvBlock, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)\n        self.norm = nn.BatchNorm2d(out_channels)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.norm(x)\n        return x\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n\n        self.conv_blocks = nn.Sequential(\n            ConvBlock(2, 4, (3, 2)),\n            ConvBlock(4, 6, (3, 2)),\n        )\n        self.norm1 = nn.BatchNorm2d(6)\n\n        self.fc1 = nn.Linear(120, 12)\n        self.fc2 = nn.Linear(12, 10)\n        \n    def forward(self, x):\n        x = self.conv_blocks(x)\n        x = self.norm1(x)\n        x = torch.flatten(x, 1)\n        x = self.fc1(x)\n        x = F.relu(x)\n        x = self.fc2(x)\n        x = F.log_softmax(x, dim=1)\n        return x\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(2)\n\n\n    model = Net().to(torch.device(\"cpu\"))\n\n    # Condition the model (batch norm requires a forward pass to compute the mean and variance)\n    x1 = torch.ones(1, 2, 9, 6) - 0.1\n    x2 = torch.ones(1, 2, 9, 6) - 0.3\n    output = model(x1)\n    output = model(x2)\n    model.eval() # set to eval mode\n\n    torch.save(model.state_dict(), \"complex_nested.pt\")\n\n    # feed test data\n    x = torch.ones(1, 2, 9, 6) - 0.5\n    output = model(x)\n    print(\"Input shape: {}\", x.shape)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/complex_nested/mod.rs",
    "content": "use burn::tensor::Tolerance;\nuse burn::tensor::ops::FloatElem;\nuse burn::{\n    module::Module,\n    nn::{\n        BatchNorm, BatchNormConfig, Linear, LinearConfig,\n        conv::{Conv2d, Conv2dConfig},\n    },\n    tensor::{\n        Tensor,\n        activation::{log_softmax, relu},\n        backend::Backend,\n    },\n};\nuse burn_autodiff::Autodiff;\nuse burn_store::{ModuleSnapshot, PytorchStore};\n\n#[derive(Module, Debug)]\npub struct ConvBlock<B: Backend> {\n    conv: Conv2d<B>,\n    norm: BatchNorm<B>,\n}\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv_blocks: Vec<ConvBlock<B>>,\n    norm1: BatchNorm<B>,\n    fc1: Linear<B>,\n    fc2: Linear<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    pub fn init(device: &B::Device) -> Self {\n        let conv_blocks = vec![\n            ConvBlock {\n                conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),\n                norm: BatchNormConfig::new(4).init(device), // matches conv output channels\n            },\n            ConvBlock {\n                conv: Conv2dConfig::new([4, 6], [3, 2]).init(device),\n                norm: BatchNormConfig::new(6).init(device), // matches conv output channels\n            },\n        ];\n        let norm1 = BatchNormConfig::new(6).init(device);\n        let fc1 = LinearConfig::new(120, 12).init(device);\n        let fc2 = LinearConfig::new(12, 10).init(device);\n\n        Self {\n            conv_blocks,\n            norm1,\n            fc1,\n            fc2,\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {\n        let x = self.conv_blocks[0].forward(x);\n        let x = self.conv_blocks[1].forward(x);\n        let x = self.norm1.forward(x);\n        let x = x.reshape([0, -1]);\n        let x = self.fc1.forward(x);\n        let x = relu(x);\n        let x = self.fc2.forward(x);\n\n        log_softmax(x, 1)\n    }\n}\n\nimpl<B: Backend> ConvBlock<B> {\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv.forward(x);\n\n        self.norm.forward(x)\n    }\n}\n\n/// Partial model to test loading of partial records.\n#[derive(Module, Debug)]\npub struct PartialNet<B: Backend> {\n    conv1: ConvBlock<B>,\n}\n\nimpl<B: Backend> PartialNet<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = ConvBlock {\n            conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),\n            norm: BatchNormConfig::new(4).init(device), // matches conv output channels\n        };\n        Self { conv1 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.conv1.forward(x)\n    }\n}\n\n/// Model with extra fields to test loading of records (e.g. from a different model).\n#[derive(Module, Debug)]\npub struct PartialWithExtraNet<B: Backend> {\n    conv1: ConvBlock<B>,\n    extra_field: bool, // This field is not present in the pytorch model\n}\n\nimpl<B: Backend> PartialWithExtraNet<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = ConvBlock {\n            conv: Conv2dConfig::new([2, 4], [3, 2]).init(device),\n            norm: BatchNormConfig::new(4).init(device), // matches conv output channels\n        };\n\n        Self {\n            conv1,\n            extra_field: true,\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.conv1.forward(x)\n    }\n}\n\ntype TestBackend = burn_ndarray::NdArray<f32>;\n\nfn model_test(model: Net<TestBackend>, precision: f32) {\n    let device = Default::default();\n\n    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;\n\n    let output = model.forward(input);\n\n    let expected = Tensor::<TestBackend, 2>::from_data(\n        [[\n            -2.306_613,\n            -2.058_945_4,\n            -2.298_372_7,\n            -2.358_294,\n            -2.296_395_5,\n            -2.416_090_5,\n            -2.107_669,\n            -2.428_420_8,\n            -2.526_469,\n            -2.319_918_6,\n        ]],\n        &device,\n    );\n\n    output.to_data().assert_approx_eq::<FloatElem<TestBackend>>(\n        &expected.to_data(),\n        Tolerance::absolute(precision),\n    );\n}\n\n#[test]\nfn full_record() {\n    let device = Default::default();\n    let mut model = Net::<TestBackend>::init(&device);\n    let mut store = PytorchStore::from_file(\"tests/complex_nested/complex_nested.pt\");\n    model\n        .load_from(&mut store)\n        .expect(\"Should decode state successfully\");\n\n    model_test(model, 1e-8);\n}\n\n#[test]\nfn full_record_autodiff() {\n    let device = Default::default();\n    let mut model = Net::<Autodiff<TestBackend>>::init(&device);\n    let mut store = PytorchStore::from_file(\"tests/complex_nested/complex_nested.pt\");\n    model\n        .load_from(&mut store)\n        .expect(\"Should decode state successfully\");\n}\n\n#[test]\nfn half_record() {\n    let device = Default::default();\n    let mut model = Net::<TestBackend>::init(&device);\n    let mut store = PytorchStore::from_file(\"tests/complex_nested/complex_nested.pt\");\n    model\n        .load_from(&mut store)\n        .expect(\"Should decode state successfully\");\n\n    model_test(model, 1e-4);\n}\n\n#[test]\nfn partial_model_loading() {\n    let device = Default::default();\n    let mut model = PartialNet::<TestBackend>::init(&device);\n\n    // Load the full model but rename \"conv_blocks.0.*\" to \"conv1.*\"\n    let mut store = PytorchStore::from_file(\"tests/complex_nested/complex_nested.pt\")\n        .with_key_remapping(\"conv_blocks\\\\.0\\\\.(.*)\", \"conv1.$1\")\n        .allow_partial(true);\n\n    model\n        .load_from(&mut store)\n        .expect(\"Should decode state successfully\");\n\n    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;\n\n    let output = model.forward(input);\n\n    // get the sum of all elements in the output tensor for quick check\n    let sum = output.sum();\n\n    assert!((sum.into_scalar() - 4.871538).abs() < 0.000002);\n}\n\n#[test]\nfn extra_field_model_loading() {\n    let device = Default::default();\n    let mut model = PartialWithExtraNet::<TestBackend>::init(&device);\n\n    // Load the full model but rename \"conv_blocks.0.*\" to \"conv1.*\"\n    let mut store = PytorchStore::from_file(\"tests/complex_nested/complex_nested.pt\")\n        .with_key_remapping(\"conv_blocks\\\\.0\\\\.(.*)\", \"conv1.$1\")\n        .allow_partial(true);\n\n    model\n        .load_from(&mut store)\n        .expect(\"Should decode state successfully\");\n\n    let input = Tensor::<TestBackend, 4>::ones([1, 2, 9, 6], &device) - 0.5;\n\n    let output = model.forward(input);\n\n    // get the sum of all elements in the output tensor for quick check\n    let sum = output.sum();\n\n    assert!((sum.into_scalar() - 4.871538).abs() < 0.000002);\n\n    assert!(model.extra_field);\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/config/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.fc1 = nn.Linear(2, 3)\n        self.fc2 = nn.Linear(3, 4, bias=False)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2\n        x = self.fc2(x)\n\n        return x\n\nCONFIG = {\n    \"n_head\": 2,\n    \"n_layer\": 3,\n    \"d_model\": 512,\n    \"some_float\": 0.1,\n    \"some_int\": 1,\n    \"some_bool\": True,\n    \"some_str\": \"hello\",\n    \"some_list_int\": [1, 2, 3],\n    \"some_list_str\": [\"hello\", \"world\"],\n    \"some_list_float\": [0.1, 0.2, 0.3],\n    \"some_dict\": {\n        \"some_key\": \"some_value\"\n    }\n}\n\nclass ModelWithBias(nn.Module):\n    def __init__(self):\n        super(ModelWithBias, self).__init__()\n        self.fc1 = nn.Linear(2, 3)\n\n    def forward(self, x):\n        x = self.fc1(x)\n\n        return x\n\n\ndef main():\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    weights_with_config = {\n        \"my_model\": model.state_dict(),\n        \"my_config\": CONFIG\n    }\n\n    torch.save(weights_with_config, \"weights_with_config.pt\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/config/mod.rs",
    "content": "#![allow(clippy::too_many_arguments)] // To mute derive Config warning\nuse std::collections::HashMap;\n\nuse burn::config::Config;\n\n#[allow(clippy::too_many_arguments)]\n#[derive(Debug, PartialEq, Config)]\nstruct NetConfig {\n    n_head: usize,\n    n_layer: usize,\n    d_model: usize,\n    some_float: f64,\n    some_int: i32,\n    some_bool: bool,\n    some_str: String,\n    some_list_int: Vec<i32>,\n    some_list_str: Vec<String>,\n    some_list_float: Vec<f64>,\n    some_dict: HashMap<String, String>,\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_store::pytorch::PytorchReader;\n\n    use super::*;\n\n    #[test]\n    fn test_net_config() {\n        let config_expected = NetConfig {\n            n_head: 2,\n            n_layer: 3,\n            d_model: 512,\n            some_float: 0.1,\n            some_int: 1,\n            some_bool: true,\n            some_str: \"hello\".to_string(),\n            some_list_int: vec![1, 2, 3],\n            some_list_str: vec![\"hello\".to_string(), \"world\".to_string()],\n            some_list_float: vec![0.1, 0.2, 0.3],\n            some_dict: {\n                let mut map = HashMap::new();\n                map.insert(\"some_key\".to_string(), \"some_value\".to_string());\n                map\n            },\n        };\n        let path = \"tests/config/weights_with_config.pt\";\n        let top_level_key = Some(\"my_config\");\n        let config: NetConfig = PytorchReader::load_config(path, top_level_key).unwrap();\n\n        assert_eq!(config, config_expected);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv1d/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv1d(2, 2, 2)\n        self.conv2 = nn.Conv1d(2, 2, 2, bias=False)\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"conv1d.pt\")\n    \n    input = torch.rand(1, 2, 6)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv1d/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{Conv1d, Conv1dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv1d<B>,\n    conv2: Conv1d<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = Conv1dConfig::new(2, 2, 2).init(device);\n        let conv2 = Conv1dConfig::new(2, 2, 2).with_bias(false).init(device);\n\n        Self { conv1, conv2 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {\n        let x = self.conv1.forward(x);\n\n        self.conv2.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    fn conv1d(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            [[\n                [\n                    0.93708336, 0.65559506, 0.31379688, 0.19801933, 0.41619217, 0.28432965,\n                ],\n                [\n                    0.33977574,\n                    0.523_940_8,\n                    0.798_063_9,\n                    0.77176833,\n                    0.01122457,\n                    0.80996025,\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 3>::from_data(\n            [[\n                [0.02987457, 0.03134188, 0.04234261, -0.02437721],\n                [-0.03788019, -0.02972012, -0.00806090, -0.01981254],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn conv1d_full_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv1d/conv1d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv1d(model, 1e-7);\n    }\n\n    #[test]\n    fn conv1d_half_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv1d/conv1d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv1d(model, 1e-4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv2d/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv2d(2, 2, (2,2))\n        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"conv2d.pt\")\n    \n    input = torch.rand(1, 2, 5, 5)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv2d/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{Conv2d, Conv2dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);\n        let conv2 = Conv2dConfig::new([2, 2], [2, 2])\n            .with_bias(false)\n            .init(device);\n\n        Self { conv1, conv2 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv1.forward(x);\n\n        self.conv2.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn conv2d(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [\n                        0.024_595_8,\n                        0.25883394,\n                        0.93905586,\n                        0.416_715_5,\n                        0.713_979_7,\n                    ],\n                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],\n                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],\n                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],\n                    [\n                        0.99802476,\n                        0.900_794_2,\n                        0.476_588_2,\n                        0.16625845,\n                        0.804_481_1,\n                    ],\n                ],\n                [\n                    [\n                        0.65517855,\n                        0.17679012,\n                        0.824_772_3,\n                        0.803_550_9,\n                        0.943_447_5,\n                    ],\n                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],\n                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],\n                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],\n                    [\n                        0.751_675_7,\n                        0.148_438_4,\n                        0.12274551,\n                        0.530_407_2,\n                        0.414_796_4,\n                    ],\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [-0.02502128, 0.00250649, 0.04841233],\n                    [0.04589614, -0.00296854, 0.01991477],\n                    [0.02920526, 0.059_497_3, 0.04326791],\n                ],\n                [\n                    [-0.04825336, 0.080_190_9, -0.02375088],\n                    [0.02885434, 0.09638263, -0.07460806],\n                    [0.02004079, 0.06244051, 0.035_887_1],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn conv2d_full_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv2d/conv2d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv2d(model, 1e-7);\n    }\n\n    #[test]\n    fn conv2d_half_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv2d/conv2d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv2d(model, 1e-4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv_transpose1d/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.ConvTranspose1d(2, 2, 2)\n        self.conv2 = nn.ConvTranspose1d(2, 2, 2, bias=False)\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"conv_transpose1d.pt\")\n    \n    input = torch.rand(1, 2, 2)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv_transpose1d/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{ConvTranspose1d, ConvTranspose1dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: ConvTranspose1d<B>,\n    conv2: ConvTranspose1d<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = ConvTranspose1dConfig::new([2, 2], 2).init(device);\n        let conv2 = ConvTranspose1dConfig::new([2, 2], 2)\n            .with_bias(false)\n            .init(device);\n\n        Self { conv1, conv2 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {\n        let x = self.conv1.forward(x);\n\n        self.conv2.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn conv_transpose1d(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 3>::from_data(\n            [[[0.93708336, 0.65559506], [0.31379688, 0.19801933]]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 3>::from_data(\n            [[\n                [0.02935525, 0.01119324, -0.01356167, -0.00682688],\n                [0.01644749, -0.01429807, 0.00083987, 0.00279229],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn conv_transpose1d_full() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv_transpose1d/conv_transpose1d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv_transpose1d(model, 1e-8);\n    }\n\n    #[test]\n    fn conv_transpose1d_half() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv_transpose1d/conv_transpose1d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv_transpose1d(model, 1e-4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv_transpose2d/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.ConvTranspose2d(2, 2, (2, 2))\n        self.conv2 = nn.ConvTranspose2d(2, 2, (2, 2), bias=False)\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"conv_transpose2d.pt\")\n    \n    input = torch.rand(1, 2, 2, 2)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/conv_transpose2d/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{ConvTranspose2d, ConvTranspose2dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: ConvTranspose2d<B>,\n    conv2: ConvTranspose2d<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = ConvTranspose2dConfig::new([2, 2], [2, 2]).init(device);\n        let conv2 = ConvTranspose2dConfig::new([2, 2], [2, 2])\n            .with_bias(false)\n            .init(device);\n\n        Self { conv1, conv2 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv1.forward(x);\n\n        self.conv2.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn conv_transpose2d(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],\n                [[0.713_979_7, 0.267_644_3], [0.990_609, 0.28845078]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.04547675, 0.01879685, -0.01636661, 0.00310803],\n                    [0.02090115, 0.01192738, -0.048_240_2, 0.02252235],\n                    [0.03249975, -0.00460748, 0.05003899, 0.04029131],\n                    [0.02185687, -0.10226749, -0.06508022, -0.01267705],\n                ],\n                [\n                    [0.00277598, -0.00513832, -0.059_048_3, 0.00567626],\n                    [-0.03149522, -0.195_757_4, 0.03474613, 0.01997269],\n                    [-0.10096474, 0.00679589, 0.041_919_7, -0.02464108],\n                    [-0.03174751, 0.02963913, -0.02703723, -0.01860938],\n                ],\n            ]],\n            &device,\n        );\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn conv_transpose2d_full() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv_transpose2d/conv_transpose2d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv_transpose2d(model, 1e-7);\n    }\n\n    #[test]\n    fn conv_transpose2d_half() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/conv_transpose2d/conv_transpose2d.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        conv_transpose2d(model, 1e-4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/embedding/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.embed = nn.Embedding(10, 3)\n        \n    def forward(self, x):\n        x = self.embed(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"embedding.pt\")\n    \n    input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/embedding/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{Embedding, EmbeddingConfig},\n    tensor::{Int, Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    embed: Embedding<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model.\n    pub fn init(device: &B::Device) -> Self {\n        let embed = EmbeddingConfig::new(10, 3).init(device);\n        Self { embed }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {\n        self.embed.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn embedding(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 2, Int>::from_data([[1, 2, 4, 5], [4, 3, 2, 9]], &device);\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 3>::from_data(\n            [\n                [\n                    [-1.609_484_9, -0.10016718, -0.609_188_9],\n                    [-0.97977227, -1.609_096_3, -0.712_144_6],\n                    [-0.22227049, 1.687_113_4, -0.32062083],\n                    [-0.29934573, 1.879_345_7, -0.07213178],\n                ],\n                [\n                    [-0.22227049, 1.687_113_4, -0.32062083],\n                    [0.303_722, -0.777_314_3, -0.25145486],\n                    [-0.97977227, -1.609_096_3, -0.712_144_6],\n                    [-0.02878714, 2.357_111, -1.037_338_7],\n                ],\n            ],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn embedding_full_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/embedding/embedding.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        embedding(model, 1e-3);\n    }\n\n    #[test]\n    fn embedding_half_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/embedding/embedding.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        embedding(model, 1e-3);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/enum_module/export_weights.py",
    "content": "#!/usr/bin/env python3\nimport torch\nfrom torch import nn, Tensor\n\nclass DwsConv(nn.Module):\n    \"\"\"Depthwise separable convolution.\"\"\"\n\n    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:\n        super().__init__()\n        # Depthwise conv\n        self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)\n        # Pointwise conv\n        self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.dconv(x)\n        return self.pconv(x)\n\n\nclass Model(nn.Module):\n    def __init__(self, depthwise: bool = False) -> None:\n        super().__init__()\n        self.conv = DwsConv(2, 2, 3) if depthwise else nn.Conv2d(2, 2, 3)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.conv(x)\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"enum_depthwise_false.pt\")\n\n    input = torch.rand(1, 2, 5, 5)\n\n    print(\"Depthwise is False\")\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n\n    print(\"Depthwise is True\")\n    model = Model(depthwise=True).to(torch.device(\"cpu\"))\n    torch.save(model.state_dict(), \"enum_depthwise_true.pt\")\n\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/enum_module/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{Conv2d, Conv2dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\n#[allow(clippy::large_enum_variant)]\npub enum Conv<B: Backend> {\n    DwsConv(DwsConv<B>),\n    Conv(Conv2d<B>),\n}\n\n#[derive(Module, Debug)]\npub struct DwsConv<B: Backend> {\n    dconv: Conv2d<B>,\n    pconv: Conv2d<B>,\n}\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv: Conv<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model with DwsConv variant.\n    pub fn init_dws_conv(device: &B::Device) -> Self {\n        let dconv = Conv2dConfig::new([2, 2], [3, 3])\n            .with_groups(2)\n            .init(device);\n        let pconv = Conv2dConfig::new([2, 2], [1, 1])\n            .with_groups(1)\n            .init(device);\n        Net {\n            conv: Conv::DwsConv(DwsConv { dconv, pconv }),\n        }\n    }\n\n    /// Create a new model with Conv variant.\n    pub fn init_conv(device: &B::Device) -> Self {\n        let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]);\n        Net {\n            conv: Conv::Conv(conv2d_config.init(device)),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        match &self.conv {\n            Conv::DwsConv(dws_conv) => {\n                let x = dws_conv.dconv.forward(x);\n                dws_conv.pconv.forward(x)\n            }\n            Conv::Conv(conv) => conv.forward(x),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    #[test]\n    fn depthwise_false() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init_conv(&device);\n        let mut store = PytorchStore::from_file(\"tests/enum_module/enum_depthwise_false.pt\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],\n                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],\n                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],\n                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],\n                    [\n                        0.804_481_1,\n                        0.65517855,\n                        0.17679012,\n                        0.824_772_3,\n                        0.803_550_9,\n                    ],\n                ],\n                [\n                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],\n                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],\n                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],\n                    [\n                        0.03694397,\n                        0.751_675_7,\n                        0.148_438_4,\n                        0.12274551,\n                        0.530_407_2,\n                    ],\n                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.35449377, -0.02832414, 0.490_976_1],\n                    [0.29709217, 0.332_586_3, 0.30594018],\n                    [0.18101373, 0.30932188, 0.30558896],\n                ],\n                [\n                    [-0.17683622, -0.13244139, -0.05608707],\n                    [0.23467252, -0.07038684, 0.255_044_1],\n                    [-0.241_931_3, -0.20476191, -0.14468731],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn depthwise_true() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init_dws_conv(&device);\n        let mut store = PytorchStore::from_file(\"tests/enum_module/enum_depthwise_true.pt\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.713_979_7, 0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4],\n                    [0.505_920_8, 0.23659128, 0.757_007_4, 0.23458993, 0.64705235],\n                    [0.355_621_4, 0.445_182_8, 0.01930594, 0.26160914, 0.771_317],\n                    [0.37846136, 0.99802476, 0.900_794_2, 0.476_588_2, 0.16625845],\n                    [\n                        0.804_481_1,\n                        0.65517855,\n                        0.17679012,\n                        0.824_772_3,\n                        0.803_550_9,\n                    ],\n                ],\n                [\n                    [0.943_447_5, 0.21972018, 0.417_697, 0.49031407, 0.57302874],\n                    [0.12054086, 0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7],\n                    [0.52850497, 0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537],\n                    [\n                        0.03694397,\n                        0.751_675_7,\n                        0.148_438_4,\n                        0.12274551,\n                        0.530_407_2,\n                    ],\n                    [0.414_796_4, 0.793_662, 0.21043217, 0.05550903, 0.863_884_4],\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.77874625, 0.859_017_6, 0.834_283_5],\n                    [0.773_056_4, 0.73817325, 0.78292674],\n                    [0.710_775_2, 0.747_187_2, 0.733_264_4],\n                ],\n                [\n                    [-0.44891885, -0.49027523, -0.394_170_7],\n                    [-0.43836114, -0.33961445, -0.387_311_5],\n                    [-0.581_134_3, -0.34197026, -0.535_035_7],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/group_norm/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.norm1 = nn.GroupNorm(2, 6)\n        \n    def forward(self, x):\n        x = self.norm1(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"group_norm.pt\")\n    \n    x2 = torch.rand(1, 6, 2, 2)\n    print(\"Input shape: {}\", x2.shape)\n    print(\"Input: {}\", x2)\n    output = model(x2)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/group_norm/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{GroupNorm, GroupNormConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    norm1: GroupNorm<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model from the given record.\n    pub fn init(device: &B::Device) -> Self {\n        let norm1 = GroupNormConfig::new(2, 6).init(device);\n        Self { norm1 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.norm1.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn group_norm(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],\n                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],\n                [[0.569_508_5, 0.43877792], [0.63868046, 0.524_665_9]],\n                [[0.682_614_1, 0.305_149_5], [0.46354562, 0.45498633]],\n                [[0.572_472, 0.498_002_6], [0.93708336, 0.65559506]],\n                [[0.31379688, 0.19801933], [0.41619217, 0.28432965]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[1.042_578_5, -1.122_016_7], [-0.56195974, 0.938_733_6]],\n                [[-2.253_500_7, 1.233_672_9], [-0.588_804_1, 1.027_827_3]],\n                [[0.19124532, -0.40036356], [0.504_276_5, -0.01168585]],\n                [[1.013_829_2, -0.891_984_6], [-0.09224463, -0.13546038]],\n                [[0.45772314, 0.08172822], [2.298_641_4, 0.877_410_4]],\n                [[-0.84832406, -1.432_883_4], [-0.331_331_5, -0.997_103_7]],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn group_norm_full() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/group_norm/group_norm.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        group_norm(model, 1e-3);\n    }\n\n    #[test]\n    fn group_norm_half() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/group_norm/group_norm.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        group_norm(model, 1e-3);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/integer/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        buffer = torch.tensor([1, 2, 3])\n        self.register_buffer(\"buffer\", buffer, persistent=True)\n        \n    def forward(self, x):\n        x = self.buffer\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"integer.pt\")\n    \n    input = torch.ones(3, 3)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/integer/mod.rs",
    "content": "use burn::{\n    module::{Module, Param, ParamId},\n    tensor::{Int, Tensor, TensorData, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    buffer: Param<Tensor<B, 1, Int>>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model with placeholder values.\n    pub fn init(device: &B::Device) -> Self {\n        Self {\n            buffer: Param::initialized(\n                ParamId::new(),\n                Tensor::<B, 1, Int>::from_data(TensorData::from([0, 0, 0]), device),\n            ),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, _x: Tensor<B, 2>) -> Tensor<B, 1, Int> {\n        self.buffer.val()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n    use burn::tensor::TensorData;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    fn integer(model: Net<TestBackend>) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 2>::ones([3, 3], &device);\n\n        let output = model.forward(input);\n\n        let expected =\n            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 2, 3]), &device);\n\n        assert_eq!(output.to_data(), expected.to_data());\n    }\n\n    #[test]\n    fn integer_full_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/integer/integer.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        integer(model);\n    }\n\n    #[test]\n    fn integer_half_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/integer/integer.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        integer(model);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/key_remap/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass ConvModule(nn.Module):\n    def __init__(self):\n        super(ConvModule, self).__init__()\n        self.conv1 = nn.Conv2d(2, 2, (2,2))\n        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        return x\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv = ConvModule()\n    \n    def forward(self, x):\n        x = self.conv(x)\n        return x\n    \n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"key_remap.pt\")\n    \n    input = torch.rand(1, 2, 5, 5)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/key_remap/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::conv::{Conv2d, Conv2dConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model.\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = Conv2dConfig::new([2, 2], [2, 2]).init(device);\n        let conv2 = Conv2dConfig::new([2, 2], [2, 2])\n            .with_bias(false)\n            .init(device);\n        Self { conv1, conv2 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv1.forward(x);\n\n        self.conv2.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    #[test]\n    fn key_remap() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/key_remap/key_remap.pt\")\n            .with_key_remapping(\"conv\\\\.(.*)\", \"$1\"); // Remove \"conv\" prefix, e.g. \"conv.conv1\" -> \"conv1\"\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [\n                        0.024_595_8,\n                        0.25883394,\n                        0.93905586,\n                        0.416_715_5,\n                        0.713_979_7,\n                    ],\n                    [0.267_644_3, 0.990_609, 0.28845078, 0.874_962_4, 0.505_920_8],\n                    [0.23659128, 0.757_007_4, 0.23458993, 0.64705235, 0.355_621_4],\n                    [0.445_182_8, 0.01930594, 0.26160914, 0.771_317, 0.37846136],\n                    [\n                        0.99802476,\n                        0.900_794_2,\n                        0.476_588_2,\n                        0.16625845,\n                        0.804_481_1,\n                    ],\n                ],\n                [\n                    [\n                        0.65517855,\n                        0.17679012,\n                        0.824_772_3,\n                        0.803_550_9,\n                        0.943_447_5,\n                    ],\n                    [0.21972018, 0.417_697, 0.49031407, 0.57302874, 0.12054086],\n                    [0.14518881, 0.772_002_3, 0.38275403, 0.744_236_7, 0.52850497],\n                    [0.664_172_4, 0.60994434, 0.681_799_7, 0.74785537, 0.03694397],\n                    [\n                        0.751_675_7,\n                        0.148_438_4,\n                        0.12274551,\n                        0.530_407_2,\n                        0.414_796_4,\n                    ],\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [-0.02502128, 0.00250649, 0.04841233],\n                    [0.04589614, -0.00296854, 0.01991477],\n                    [0.02920526, 0.059_497_3, 0.04326791],\n                ],\n                [\n                    [-0.04825336, 0.080_190_9, -0.02375088],\n                    [0.02885434, 0.09638263, -0.07460806],\n                    [0.02004079, 0.06244051, 0.035_887_1],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/key_remap_chained/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nfrom torch import nn, Tensor\n\n\nclass ConvBlock(nn.Module):\n    def __init__(self, in_channels: int, out_channels: int):\n        super().__init__()\n        self.block = nn.Sequential(\n            nn.Conv2d(in_channels, out_channels, 1, bias=False),\n            nn.BatchNorm2d(out_channels),\n        )\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.block(x)\n\n\nclass Model(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = nn.Conv2d(3, 6, 3, bias=False)\n        self.bn = nn.BatchNorm2d(6)\n        self.layer = nn.Sequential(ConvBlock(6, 6), ConvBlock(6, 6))\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.conv(x)\n        x = self.bn(x)\n        x = self.layer(x)\n\n        return x\n\n\ndef main():\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(42)\n\n    model = Model()\n\n    input = torch.rand(1, 3, 4, 4)\n    model(input)  # condition batch norm\n    model.eval()\n\n    with torch.no_grad():\n        print(f\"Input shape: {input.shape}\")\n        print(\"Input type: {}\", input.dtype)\n        print(f\"Input: {input}\")\n        output = model(input)\n\n    print(f\"Output: {output}\")\n    print(f\"Output Shape: {output.shape}\")\n\n    torch.save(model.state_dict(), \"key_remap.pt\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/key_remap_chained/mod.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn::{\n    module::Module,\n    nn::{\n        BatchNorm, BatchNormConfig,\n        conv::{Conv2d, Conv2dConfig},\n    },\n    tensor::{Device, Tensor, backend::Backend},\n};\n\n/// Some module that implements a specific method so it can be used in a sequential block.\npub trait ForwardModule<B: Backend> {\n    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;\n}\n\n/// Conv2d + BatchNorm block.\n#[derive(Module, Debug)]\npub struct ConvBlock<B: Backend> {\n    conv: Conv2d<B>,\n    bn: BatchNorm<B>,\n}\n\nimpl<B: Backend> ForwardModule<B> for ConvBlock<B> {\n    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let out = self.conv.forward(input);\n        self.bn.forward(out)\n    }\n}\n\nimpl<B: Backend> ConvBlock<B> {\n    pub fn new(in_channels: usize, out_channels: usize, device: &Device<B>) -> Self {\n        let conv = Conv2dConfig::new([in_channels, out_channels], [1, 1])\n            .with_bias(false)\n            .init(device);\n        let bn = BatchNormConfig::new(out_channels).init(device);\n\n        Self { conv, bn }\n    }\n}\n\n/// Collection of sequential blocks.\n#[derive(Module, Debug)]\npub struct ModuleBlock<B: Backend, M> {\n    blocks: Vec<M>,\n    _backend: PhantomData<B>,\n}\n\nimpl<B: Backend, M: ForwardModule<B>> ModuleBlock<B, M> {\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let mut out = input;\n        for block in &self.blocks {\n            out = block.forward(out);\n        }\n        out\n    }\n}\n\nimpl<B: Backend> ModuleBlock<B, ConvBlock<B>> {\n    pub fn new(device: &Device<B>) -> Self {\n        let blocks = vec![ConvBlock::new(6, 6, device), ConvBlock::new(6, 6, device)];\n\n        Self {\n            blocks,\n            _backend: PhantomData,\n        }\n    }\n}\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend, M> {\n    conv: Conv2d<B>,\n    bn: BatchNorm<B>,\n    layer: ModuleBlock<B, M>,\n}\n\nimpl<B: Backend> Model<B, ConvBlock<B>> {\n    pub fn new(device: &Device<B>) -> Self {\n        let conv = Conv2dConfig::new([3, 6], [3, 3])\n            .with_bias(false)\n            .init(device);\n        let bn = BatchNormConfig::new(6).init(device);\n\n        let layer = ModuleBlock::new(device);\n\n        Self { conv, bn, layer }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let out = self.conv.forward(input);\n        let out = self.bn.forward(out);\n        self.layer.forward(out)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    #[test]\n    #[should_panic]\n    fn key_remap_chained_missing_pattern() {\n        // Loading record should fail due to missing pattern to map the layer.blocks\n        let device = Default::default();\n        let mut model: Model<TestBackend, _> = Model::new(&device);\n        let mut store = PytorchStore::from_file(\"tests/key_remap_chained/key_remap.pt\")\n            // Map *.block.0.* -> *.conv.*\n            .with_key_remapping(\"(.+)\\\\.block\\\\.0\\\\.(.+)\", \"$1.conv.$2\")\n            // Map *.block.1.* -> *.bn.*\n            .with_key_remapping(\"(.+)\\\\.block\\\\.1\\\\.(.+)\", \"$1.bn.$2\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n    }\n\n    #[test]\n    fn key_remap_chained() {\n        let device = Default::default();\n        let mut model: Model<TestBackend, _> = Model::new(&device);\n        let mut store = PytorchStore::from_file(\"tests/key_remap_chained/key_remap.pt\")\n            // Map *.block.0.* -> *.conv.*\n            .with_key_remapping(\"(.+)\\\\.block\\\\.0\\\\.(.+)\", \"$1.conv.$2\")\n            // Map *.block.1.* -> *.bn.*\n            .with_key_remapping(\"(.+)\\\\.block\\\\.1\\\\.(.+)\", \"$1.bn.$2\")\n            // Map layer.[i].* -> layer.blocks.[i].*\n            .with_key_remapping(\"layer\\\\.([0-9])\\\\.(.+)\", \"layer.blocks.$1.$2\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.76193494, 0.626_546_1, 0.49510366, 0.11974698],\n                    [0.07161391, 0.03232569, 0.704_681, 0.254_516],\n                    [0.399_373_7, 0.21224737, 0.40888822, 0.14808255],\n                    [0.17329216, 0.665_855_4, 0.351_401_8, 0.808_671_6],\n                ],\n                [\n                    [0.33959562, 0.13321638, 0.41178054, 0.257_626_3],\n                    [0.347_029_2, 0.02400219, 0.77974546, 0.15189773],\n                    [0.75130886, 0.726_892_1, 0.85721636, 0.11647397],\n                    [0.859_598_4, 0.263_624_2, 0.685_534_6, 0.96955734],\n                ],\n                [\n                    [0.42948407, 0.49613327, 0.38488472, 0.08250773],\n                    [0.73995143, 0.00364107, 0.81039995, 0.87411255],\n                    [0.972_853_2, 0.38206023, 0.08917904, 0.61241513],\n                    [0.77621365, 0.00234562, 0.38650817, 0.20027226],\n                ],\n            ]],\n            &device,\n        );\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.198_967_1, 0.17847246], [0.06883702, 0.20012866]],\n                [[0.17582723, 0.11344293], [0.05444185, 0.13307181]],\n                [[0.192_229_5, 0.20391327], [0.06150475, 0.22688155]],\n                [[0.00230906, -0.02177845], [0.01129148, 0.00925517]],\n                [[0.14751078, 0.14433631], [0.05498439, 0.29049855]],\n                [[0.16868964, 0.133_269_3], [0.06917118, 0.35094324]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/layer_norm/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.norm1 = nn.LayerNorm(2)\n        \n    def forward(self, x):\n        x = self.norm1(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"layer_norm.pt\")\n    \n    x2 = torch.rand(1, 2, 2, 2)\n    print(\"Input shape: {}\", x2.shape)\n    print(\"Input: {}\", x2)\n    output = model(x2)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/layer_norm/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{LayerNorm, LayerNormConfig},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    norm1: LayerNorm<B>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model.\n    pub fn init(device: &B::Device) -> Self {\n        let norm1 = LayerNormConfig::new(2).init(device);\n        Self { norm1 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.norm1.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    fn layer_norm(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.757_631_6, 0.27931088], [0.40306926, 0.73468447]],\n                [[0.02928156, 0.799_858_6], [0.39713734, 0.75437194]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.99991274, -0.999_912_5], [-0.999_818_3, 0.999_818_3]],\n                [[-0.999_966_2, 0.99996626], [-0.99984336, 0.99984336]],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn layer_norm_full() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/layer_norm/layer_norm.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n        layer_norm(model, 1e-3);\n    }\n\n    #[test]\n    fn layer_norm_half() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/layer_norm/layer_norm.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n        layer_norm(model, 1e-3);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/linear/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.fc1 = nn.Linear(2, 3)\n        self.fc2 = nn.Linear(3, 4, bias=False)\n        \n    def forward(self, x):\n        x = self.fc1(x)\n        x = F.relu(x) # Add relu so that PyTorch optimizer does not combine fc1 and fc2\n        x = self.fc2(x)\n\n        return x\n\n\nclass ModelWithBias(nn.Module):\n    def __init__(self):\n        super(ModelWithBias, self).__init__()\n        self.fc1 = nn.Linear(2, 3)\n        \n    def forward(self, x):\n        x = self.fc1(x)\n\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n    model_with_bias = ModelWithBias().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"linear.pt\")\n    torch.save(model_with_bias.state_dict(), \"linear_with_bias.pt\")\n    \n    input = torch.rand(1, 2, 2, 2)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    print(\"Model with bias\")\n    output = model_with_bias(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n    \n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/linear/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{Linear, LinearConfig, Relu},\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    fc1: Linear<B>,\n    fc2: Linear<B>,\n    relu: Relu,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model.\n    pub fn init(device: &B::Device) -> Self {\n        let fc1 = LinearConfig::new(2, 3).init(device);\n        let fc2 = LinearConfig::new(3, 4).with_bias(false).init(device);\n        let relu = Relu;\n\n        Self { fc1, fc2, relu }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.fc1.forward(x);\n        let x = self.relu.forward(x);\n\n        self.fc2.forward(x)\n    }\n}\n\n#[derive(Module, Debug)]\nstruct NetWithBias<B: Backend> {\n    fc1: Linear<B>,\n}\n\nimpl<B: Backend> NetWithBias<B> {\n    /// Create a new model.\n    pub fn init(device: &B::Device) -> Self {\n        let fc1 = LinearConfig::new(2, 3).init(device);\n\n        Self { fc1 }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.fc1.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    fn linear_test(model: Net<TestBackend>, precision: f32) {\n        let device = Default::default();\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],\n                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.09778349, -0.13756673, 0.04962806, 0.08856435],\n                    [0.03163241, -0.02848549, 0.01437942, 0.11905234],\n                ],\n                [\n                    [0.07628226, -0.10757702, 0.03656857, 0.03824598],\n                    [0.05443089, -0.06904714, 0.02744314, 0.09997337],\n                ],\n            ]],\n            &device,\n        );\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(precision));\n    }\n\n    #[test]\n    fn linear_full_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/linear/linear.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        linear_test(model, 1e-7);\n    }\n\n    #[test]\n    fn linear_half_precision() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/linear/linear.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        linear_test(model, 1e-4);\n    }\n\n    #[test]\n    fn linear_with_bias() {\n        let device = Default::default();\n\n        let mut model = NetWithBias::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/linear/linear_with_bias.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [[0.63968194, 0.97427773], [0.830_029_9, 0.04443115]],\n                [[0.024_595_8, 0.25883394], [0.93905586, 0.416_715_5]],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [-0.00432095, -1.107_101_2, 0.870_691_4],\n                    [0.024_595_5, -0.954_462_9, 0.48518157],\n                ],\n                [\n                    [0.34315687, -0.757_384_2, 0.548_288],\n                    [-0.06608963, -1.072_072_7, 0.645_800_5],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/missing_module_field/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv2d(2, 2, (2,2))\n\n    def forward(self, x):\n        x = self.conv1(x)\n        return x\n\n\ndef main():\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n    model = Model().to(torch.device(\"cpu\"))\n    torch.save(model.state_dict(), \"missing_module_field.pt\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/missing_module_field/mod.rs",
    "content": "use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};\n\n#[derive(Module, Debug)]\n#[allow(unused)]\npub struct Net<B: Backend> {\n    do_not_exist_in_pt: Conv2d<B>,\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::nn::conv::Conv2dConfig;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    impl<B: Backend> Net<B> {\n        pub fn init(device: &B::Device) -> Self {\n            Self {\n                do_not_exist_in_pt: Conv2dConfig::new([2, 2], [2, 2]).init(device),\n            }\n        }\n    }\n\n    #[test]\n    #[should_panic(expected = \"do_not_exist_in_pt\")]\n    fn should_fail_if_struct_field_is_missing() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store =\n            PytorchStore::from_file(\"tests/missing_module_field/missing_module_field.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/non_contiguous_indexes/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        num_layers = 5  # Number of repeated convolutional layers\n\n        # Create a list to store the layers\n        layers = []\n        for _ in range(num_layers):\n            layers.append(nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True))\n            layers.append(nn.ReLU(inplace=True))\n\n        # Use nn.Sequential to create a single module from the layers\n        self.fc = nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.fc(x)\n        return x\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    torch.save(model.state_dict(), \"non_contiguous_indexes.pt\")\n\n    input = torch.rand(1, 2, 5, 5)\n    print(\"Input shape: {}\", input.shape)\n    print(\"Input: {}\", input)\n    output = model(input)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/non_contiguous_indexes/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{\n        PaddingConfig2d,\n        conv::{Conv2d, Conv2dConfig},\n    },\n    tensor::{Tensor, activation::relu, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    fc: Vec<Conv2d<B>>,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new model with placeholder values.\n    pub fn init(device: &B::Device) -> Self {\n        let conv2d_config = Conv2dConfig::new([2, 2], [3, 3]).with_padding(PaddingConfig2d::Same);\n        // The PyTorch file has 5 Conv2d layers at non-contiguous indices (0, 2, 4, 6, 8)\n        // in the Sequential (alternating with ReLU layers)\n        let fc = vec![\n            conv2d_config.init(device),\n            conv2d_config.init(device),\n            conv2d_config.init(device),\n            conv2d_config.init(device),\n            conv2d_config.init(device),\n        ];\n        Net { fc }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        self.fc.iter().fold(x, |x_i, conv| relu(conv.forward(x_i)))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::{Tolerance, ops::FloatElem};\n    use burn_store::{ModuleSnapshot, PytorchStore};\n    type FT = FloatElem<TestBackend>;\n\n    use super::*;\n\n    #[test]\n    fn non_contiguous_indexes() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store =\n            PytorchStore::from_file(\"tests/non_contiguous_indexes/non_contiguous_indexes.pt\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [\n                        0.67890584,\n                        0.307_537_2,\n                        0.265_156_2,\n                        0.528_318_8,\n                        0.86194897,\n                    ],\n                    [0.14828813, 0.73480314, 0.821_220_7, 0.989_098_6, 0.15003455],\n                    [0.62109494, 0.13028657, 0.926_875_1, 0.30604684, 0.80117637],\n                    [0.514_885_7, 0.46105868, 0.484_046_1, 0.58499724, 0.73569804],\n                    [0.58018994, 0.65252745, 0.05023766, 0.864_268_7, 0.935_932],\n                ],\n                [\n                    [0.913_302_9, 0.869_611_3, 0.139_184_3, 0.314_65, 0.94086266],\n                    [0.11917073, 0.953_610_6, 0.10675198, 0.14779574, 0.744_439],\n                    [0.14075547, 0.38544965, 0.863_745_9, 0.89604443, 0.97287786],\n                    [0.39854127, 0.11136961, 0.99230546, 0.39348692, 0.29428244],\n                    [0.621_886_9, 0.15033776, 0.828_640_1, 0.81336635, 0.10325938],\n                ],\n            ]],\n            &device,\n        );\n\n        let output = model.forward(input);\n\n        let expected = Tensor::<TestBackend, 4>::from_data(\n            [[\n                [\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.04485746, 0.03582812, 0.03432692, 0.02892298, 0.013_844_3],\n                ],\n                [\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                    [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000],\n                ],\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::absolute(1e-7));\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/test_mod.rs",
    "content": "mod backend;\n\nmod batch_norm;\nmod boolean;\nmod buffer;\nmod complex_nested;\nmod config;\nmod conv1d;\nmod conv2d;\nmod conv_transpose1d;\nmod conv_transpose2d;\nmod embedding;\nmod enum_module;\nmod group_norm;\nmod integer;\nmod key_remap;\nmod key_remap_chained;\nmod layer_norm;\nmod linear;\nmod missing_module_field;\nmod non_contiguous_indexes;\nmod top_level_key;\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/top_level_key/export_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv2d(2, 2, (2,2))\n\n    def forward(self, x):\n        x = self.conv1(x)\n        return x\n\n\ndef main():\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n    model = Model().to(torch.device(\"cpu\"))\n    torch.save({\"my_state_dict\": model.state_dict()}, \"top_level_key.pt\")\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "crates/burn-store/pytorch-tests/tests/top_level_key/mod.rs",
    "content": "use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};\n\n#[derive(Module, Debug)]\n#[allow(unused)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::nn::conv::Conv2dConfig;\n    use burn_store::{ModuleSnapshot, PytorchStore};\n\n    use super::*;\n\n    impl<B: Backend> Net<B> {\n        pub fn init(device: &B::Device) -> Self {\n            Self {\n                conv1: Conv2dConfig::new([2, 2], [2, 2]).init(device),\n            }\n        }\n    }\n\n    #[test]\n    #[should_panic]\n    fn should_fail_if_not_found() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/top_level_key/top_level_key.pt\");\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n    }\n\n    #[test]\n    fn should_load() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::init(&device);\n        let mut store = PytorchStore::from_file(\"tests/top_level_key/top_level_key.pt\")\n            .with_top_level_key(\"my_state_dict\");\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/Cargo.toml",
    "content": "[package]\nname = \"safetensors-tests\"\nversion.workspace = true\nedition.workspace = true\nlicense.workspace = true\n\n[dev-dependencies]\nburn = { path = \"../../burn\" }\nburn-ndarray = { path = \"../../burn-ndarray\" }\nburn-autodiff = { path = \"../../burn-autodiff\" }\nburn-store = { path = \"../\", features = [\"std\", \"safetensors\"] }\nserde = { workspace = true }\nfloat-cmp = { workspace = true }\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/src/lib.rs",
    "content": "\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/tests/backend.rs",
    "content": "pub type TestBackend = burn_ndarray::NdArray<f32>;\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/tests/multi_layer/mod.rs",
    "content": "use burn::{\n    module::Module,\n    nn::{\n        BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,\n        conv::{Conv2d, Conv2dConfig},\n    },\n    tensor::{Tensor, backend::Backend},\n};\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n    norm1: BatchNorm<B>,\n    fc1: Linear<B>,\n    relu: Relu,\n}\n\nimpl<B: Backend> Net<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            conv1: Conv2dConfig::new([3, 4], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .init(device),\n            norm1: BatchNormConfig::new(4).init(device),\n            fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),\n            relu: Relu::new(),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {\n        let x = self.conv1.forward(x);\n        let x = self.norm1.forward(x);\n        let x = self.relu.forward(x);\n        // Flatten all dimensions except the batch dimension\n        let x = x.flatten(1, 3);\n        self.fc1.forward(x)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::backend::TestBackend;\n\n    use burn::tensor::Tolerance;\n    use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};\n\n    use super::*;\n\n    #[test]\n    fn multi_layer_model() {\n        let device = Default::default();\n        let mut model = Net::<TestBackend>::new(&device);\n        let mut store = SafetensorsStore::from_file(\"tests/multi_layer/multi_layer.safetensors\")\n            .with_from_adapter(PyTorchToBurnAdapter);\n\n        model\n            .load_from(&mut store)\n            .expect(\"Should decode state successfully\");\n\n        let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);\n\n        let output = model.forward(input);\n\n        // Note: Expected values should be updated based on the actual output from the PyTorch model\n        let expected = Tensor::<TestBackend, 2>::from_data(\n            [[\n                0.04971555,\n                -0.16849735,\n                0.05182848,\n                -0.18032673,\n                0.23138367,\n                0.05041867,\n                0.13005908,\n                -0.32202929,\n                -0.07915690,\n                -0.03232457,\n                -0.19790289,\n                -0.17476529,\n                -0.19627589,\n                -0.21757686,\n                -0.31376451,\n                0.08377837,\n            ]],\n            &device,\n        );\n\n        output\n            .to_data()\n            .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/tests/multi_layer/multi_layer.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom safetensors.torch import save_file\n\n\nclass Model(nn.Module):\n    def __init__(self):\n        super(Model, self).__init__()\n        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, padding=1)\n        self.norm1 = nn.BatchNorm2d(4)\n        self.flatten = nn.Flatten()\n        self.fc1 = nn.Linear(4 * 8 * 8, 16)  # Changed for smaller input size\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = F.relu(x)\n        x = self.flatten(x)\n        x = self.fc1(x)\n        return x\n\n\ndef main():\n\n    torch.set_printoptions(precision=8)\n    torch.manual_seed(1)\n\n    model = Model().to(torch.device(\"cpu\"))\n\n    # Use a smaller input size\n    # 1 batch, 3 channels (RGB), 8x8 image (small input)\n    x1 = torch.ones(1, 3, 8, 8)\n    _ = model(x1)\n    model.eval()  # Set to eval mode to freeze running stats\n    # Save the model to safetensors after the first forward\n    save_file(model.state_dict(), \"multi_layer.safetensors\")\n\n    x2 = torch.ones(1, 3, 8, 8)\n    print(\"Input shape: {}\", x2.shape)\n    output = model(x2)\n    print(\"Output: {}\", output)\n    print(\"Output Shape: {}\", output.shape)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "crates/burn-store/safetensors-tests/tests/test_mod.rs",
    "content": "mod backend;\n\nmod multi_layer;\n"
  },
  {
    "path": "crates/burn-store/src/adapter.rs",
    "content": "//! Module adapters for transforming tensor snapshots during save/load\n//!\n//! This module provides adapters for:\n//! - PyTorch/Burn format conversion (weight transposition, parameter renaming)\n//! - Mixed-precision storage (F32/F16 dtype casting via [`HalfPrecisionAdapter`])\n//! - Adapter chaining for composing multiple transformations\n\nuse crate::TensorSnapshot;\n\nuse alloc::boxed::Box;\nuse alloc::format;\nuse alloc::rc::Rc;\nuse alloc::string::String;\nuse alloc::string::ToString;\nuse alloc::vec;\n\nuse burn_tensor::shape;\nuse burn_tensor::{DType, TensorData};\nuse hashbrown::HashSet;\n\n// Module type names as they appear in the container_type field\n// These come from the Module derive macro which uses stringify! on the struct name\n// Format: \"Struct:TypeName\" for user-defined structs\nmod module_names {\n    // The actual string constants that match what the Module derive macro produces\n    pub const LINEAR: &str = \"Struct:Linear\";\n    pub const BATCH_NORM: &str = \"Struct:BatchNorm\";\n    pub const LAYER_NORM: &str = \"Struct:LayerNorm\";\n    pub const GROUP_NORM: &str = \"Struct:GroupNorm\";\n    pub const EMBEDDING: &str = \"Struct:Embedding\";\n    pub const CONV1D: &str = \"Struct:Conv1d\";\n    pub const CONV2D: &str = \"Struct:Conv2d\";\n    pub const CONV3D: &str = \"Struct:Conv3d\";\n    pub const CONV_TRANSPOSE1D: &str = \"Struct:ConvTranspose1d\";\n    pub const CONV_TRANSPOSE2D: &str = \"Struct:ConvTranspose2d\";\n    pub const CONV_TRANSPOSE3D: &str = \"Struct:ConvTranspose3d\";\n    pub const DEFORM_CONV2D: &str = \"Struct:DeformConv2d\";\n    pub const INSTANCE_NORM: &str = \"Struct:InstanceNorm\";\n    pub const RMS_NORM: &str = \"Struct:RmsNorm\";\n    pub const PRELU: &str = \"Struct:PRelu\";\n}\n\n/// Trait for adapting tensor snapshots between different module formats\npub trait ModuleAdapter: Send + Sync {\n    /// Adapt a tensor snapshot based on its container type and parameter name\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;\n\n    /// Get alternative parameter name to try during matching\n    ///\n    /// When looking for a parameter in a module, this method provides an alternative\n    /// name to try if the direct name doesn't match. This enables matching parameters\n    /// with different naming conventions (e.g., PyTorch's \"weight\" vs Burn's \"gamma\").\n    ///\n    /// # Arguments\n    /// * `param_name` - The parameter name we're looking for\n    /// * `container_type` - The type of container module (e.g., \"BatchNorm\")\n    ///\n    /// # Returns\n    /// Alternative parameter name to try, or None if no alternative exists\n    fn get_alternative_param_name(\n        &self,\n        _param_name: &str,\n        _container_type: &str,\n    ) -> Option<String> {\n        None\n    }\n\n    /// Clone the adapter into a boxed trait object\n    fn clone_box(&self) -> Box<dyn ModuleAdapter>;\n\n    /// Chain adapters together, applying `self` first and then `next`.\n    ///\n    /// This is useful when multiple transformations are required when importing model weights\n    /// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping).\n    ///\n    /// The semantics follow a simple pipeline:\n    /// - `adapt`: `next.adapt(&self.adapt(snapshot))`\n    /// - `get_alternative_param_name`: try `self` first; if it returns an alternative name,\n    ///   try `next` with that name, otherwise return the first alternative name.\n    fn chain<A>(self, next: A) -> ChainAdapter\n    where\n        Self: Sized + 'static,\n        A: ModuleAdapter + 'static,\n    {\n        ChainAdapter::new(self, next)\n    }\n}\n\nimpl Clone for Box<dyn ModuleAdapter> {\n    fn clone(&self) -> Self {\n        self.clone_box()\n    }\n}\n\n/// Adapter that applies two adapters in sequence.\n///\n/// This allows composing smaller adapters instead of creating one large monolithic adapter.\n#[derive(Clone)]\npub struct ChainAdapter {\n    first: Box<dyn ModuleAdapter>,\n    second: Box<dyn ModuleAdapter>,\n}\n\nimpl ChainAdapter {\n    /// Create a new adapter chain.\n    pub fn new<A, B>(first: A, second: B) -> Self\n    where\n        A: ModuleAdapter + 'static,\n        B: ModuleAdapter + 'static,\n    {\n        Self {\n            first: Box::new(first),\n            second: Box::new(second),\n        }\n    }\n}\n\nimpl ModuleAdapter for ChainAdapter {\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n        let snapshot = self.first.adapt(snapshot);\n        self.second.adapt(&snapshot)\n    }\n\n    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {\n        if let Some(name) = self\n            .first\n            .get_alternative_param_name(param_name, container_type)\n        {\n            self.second\n                .get_alternative_param_name(&name, container_type)\n                .or(Some(name))\n        } else {\n            self.second\n                .get_alternative_param_name(param_name, container_type)\n        }\n    }\n\n    fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n        Box::new(self.clone())\n    }\n}\n\n/// Identity adapter that passes tensors through unchanged\n#[derive(Debug, Clone, Default)]\npub struct IdentityAdapter;\n\nimpl ModuleAdapter for IdentityAdapter {\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n        snapshot.clone()\n    }\n\n    fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n        Box::new(self.clone())\n    }\n}\n\n/// Returns the default set of module types that `HalfPrecisionAdapter` converts.\n///\n/// Includes: Linear, Embedding, all Conv variants, LayerNorm, GroupNorm,\n/// InstanceNorm, RmsNorm, PRelu.\n///\n/// Excludes BatchNorm by default because `running_var` underflows in F16.\nfn default_half_precision_modules() -> HashSet<String> {\n    let modules = [\n        module_names::LINEAR,\n        module_names::EMBEDDING,\n        module_names::CONV1D,\n        module_names::CONV2D,\n        module_names::CONV3D,\n        module_names::CONV_TRANSPOSE1D,\n        module_names::CONV_TRANSPOSE2D,\n        module_names::CONV_TRANSPOSE3D,\n        module_names::DEFORM_CONV2D,\n        module_names::LAYER_NORM,\n        module_names::GROUP_NORM,\n        module_names::INSTANCE_NORM,\n        module_names::RMS_NORM,\n        module_names::PRELU,\n    ];\n    modules.iter().map(|s| s.to_string()).collect()\n}\n\n/// Adapter for mixed-precision (F32/F16) model storage.\n///\n/// Auto-detects conversion direction from the snapshot's dtype:\n/// - F32 source -> cast to F16 (typical for saving)\n/// - F16 source -> cast to F32 (typical for loading)\n/// - Other dtypes -> passed through unchanged\n///\n/// The same instance works for both `with_to_adapter` (save) and `with_from_adapter` (load).\n///\n/// By default, converts weights in: Linear, Embedding, Conv*, LayerNorm, GroupNorm,\n/// InstanceNorm, RmsNorm, PRelu. BatchNorm is excluded because `running_var` underflows in F16.\n///\n/// # Examples\n///\n/// Default usage (same adapter for save and load):\n/// ```rust\n/// # use burn_store::HalfPrecisionAdapter;\n/// let adapter = HalfPrecisionAdapter::new();\n/// // store.with_to_adapter(adapter.clone());  // F32 -> F16 on save\n/// // store.with_from_adapter(adapter);        // F16 -> F32 on load\n/// ```\n///\n/// Exclude a module type:\n/// ```rust\n/// # use burn_store::HalfPrecisionAdapter;\n/// let adapter = HalfPrecisionAdapter::new()\n///     .without_module(\"LayerNorm\");\n/// ```\n///\n/// Add a custom module type:\n/// ```rust\n/// # use burn_store::HalfPrecisionAdapter;\n/// let adapter = HalfPrecisionAdapter::new()\n///     .with_module(\"CustomLayer\");\n/// ```\n#[derive(Debug, Clone)]\npub struct HalfPrecisionAdapter {\n    modules: HashSet<String>,\n}\n\nimpl HalfPrecisionAdapter {\n    /// Create a new adapter with the default set of modules.\n    pub fn new() -> Self {\n        Self {\n            modules: default_half_precision_modules(),\n        }\n    }\n\n    /// Add a module type to convert. Accepts both short (`\"MyLayer\"`) and\n    /// qualified (`\"Struct:MyLayer\"`) forms.\n    ///\n    /// Note: short names are mapped to `\"Struct:Name\"`. If you have an Enum-based\n    /// module, use the qualified form `\"Enum:MyModule\"` explicitly.\n    pub fn with_module(mut self, module_type: impl Into<String>) -> Self {\n        let name = module_type.into();\n        if name.contains(':') {\n            self.modules.insert(name);\n        } else {\n            self.modules.insert(format!(\"Struct:{}\", name));\n        }\n        self\n    }\n\n    /// Remove a module type from conversion. Accepts both short and qualified forms.\n    pub fn without_module(mut self, module_type: impl Into<String>) -> Self {\n        let name = module_type.into();\n        let key = if name.contains(':') {\n            name\n        } else {\n            format!(\"Struct:{}\", name)\n        };\n        assert!(\n            self.modules.contains(&key),\n            \"without_module called with '{}' which is not in the module set\",\n            key\n        );\n        self.modules.remove(&key);\n        self\n    }\n\n    /// Check whether the tensor belongs to a module that should be converted.\n    fn should_convert(&self, snapshot: &TensorSnapshot) -> bool {\n        snapshot\n            .module_type()\n            .is_some_and(|mt| self.modules.contains(&mt))\n    }\n}\n\nimpl Default for HalfPrecisionAdapter {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl ModuleAdapter for HalfPrecisionAdapter {\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n        // Determine target dtype from source: F32 -> F16, F16 -> F32, anything else -> skip\n        let target_dtype = match snapshot.dtype {\n            DType::F32 => DType::F16,\n            DType::F16 => DType::F32,\n            _ => return snapshot.clone(),\n        };\n\n        if !self.should_convert(snapshot) {\n            return snapshot.clone();\n        }\n\n        let original_data_fn = snapshot.clone_data_fn();\n\n        let cast_data_fn = Rc::new(move || {\n            let data = original_data_fn()?;\n            Ok(data.convert_dtype(target_dtype))\n        });\n\n        TensorSnapshot::from_closure(\n            cast_data_fn,\n            target_dtype,\n            snapshot.shape.clone(),\n            snapshot.path_stack.clone().unwrap_or_default(),\n            snapshot.container_stack.clone().unwrap_or_default(),\n            snapshot.tensor_id.unwrap_or_default(),\n        )\n    }\n\n    fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n        Box::new(self.clone())\n    }\n}\n\n/// Adapter for converting from PyTorch format to Burn format\n///\n/// Handles:\n/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])\n/// - Normalization parameter renaming (weight → gamma, bias → beta)\n#[derive(Debug, Clone, Default)]\npub struct PyTorchToBurnAdapter;\n\nimpl ModuleAdapter for PyTorchToBurnAdapter {\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)\n    }\n\n    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {\n        // For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias)\n        if is_normalization_layer(container_type) {\n            burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())\n        } else {\n            None\n        }\n    }\n\n    fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n        Box::new(self.clone())\n    }\n}\n\n/// Adapter for converting from Burn format to PyTorch format\n///\n/// Handles:\n/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])\n/// - Normalization parameter renaming (gamma → weight, beta → bias)\n#[derive(Debug, Clone, Default)]\npub struct BurnToPyTorchAdapter;\n\nimpl ModuleAdapter for BurnToPyTorchAdapter {\n    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)\n    }\n\n    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {\n        // For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta)\n        if is_normalization_layer(container_type) {\n            pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())\n        } else {\n            None\n        }\n    }\n\n    fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n        Box::new(self.clone())\n    }\n}\n\n/// Direction of PyTorch conversion for parameter naming\n#[derive(Debug, Clone, Copy)]\nenum PyTorchConversionDirection {\n    PyTorchToBurn,\n    BurnToPyTorch,\n}\n\n/// Check if container type is a normalization layer\nfn is_normalization_layer(container_type: &str) -> bool {\n    matches!(\n        container_type,\n        module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM\n    )\n}\n\n/// Map PyTorch normalization parameter name to Burn\nfn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {\n    match param_name {\n        \"weight\" => Some(\"gamma\"),\n        \"bias\" => Some(\"beta\"),\n        _ => None,\n    }\n}\n\n/// Map Burn normalization parameter name to PyTorch\nfn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {\n    match param_name {\n        \"gamma\" => Some(\"weight\"),\n        \"beta\" => Some(\"bias\"),\n        _ => None,\n    }\n}\n\n/// Core tensor adaptation logic for PyTorch format conversions\nfn adapt_pytorch_tensor(\n    snapshot: &TensorSnapshot,\n    direction: PyTorchConversionDirection,\n) -> TensorSnapshot {\n    // Extract path and parameter name\n    let (path_stack, param_name) = match get_path_and_param(snapshot) {\n        Some(result) => result,\n        None => return snapshot.clone(),\n    };\n\n    // Get module type for matching (ignores Vec/Array wrappers)\n    let module_type = match snapshot.module_type() {\n        Some(mt) => mt,\n        None => return snapshot.clone(), // No user-defined module found\n    };\n\n    // Linear: transpose weight (bidirectional - same operation both ways)\n    if module_type == module_names::LINEAR && param_name == \"weight\" && snapshot.shape.len() == 2 {\n        return transpose_2d_tensor(snapshot);\n    }\n\n    // Normalization layers: rename parameters based on direction\n    if is_normalization_layer(&module_type) {\n        let new_name = match direction {\n            PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),\n            PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),\n        };\n\n        if let Some(new_name) = new_name {\n            return rename_parameter(snapshot, path_stack, new_name);\n        }\n    }\n\n    snapshot.clone()\n}\n\n/// Extract path stack and parameter name from snapshot\nfn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {\n    let path_stack = snapshot.path_stack.as_ref()?;\n    let param_name = path_stack.last()?.as_str();\n    Some((path_stack.as_slice(), param_name))\n}\n\n/// Rename a parameter in the snapshot\nfn rename_parameter(\n    snapshot: &TensorSnapshot,\n    path_stack: &[String],\n    new_name: &str,\n) -> TensorSnapshot {\n    let mut new_path = path_stack.to_vec();\n    *new_path.last_mut().unwrap() = new_name.to_string();\n\n    TensorSnapshot::from_closure(\n        snapshot.clone_data_fn(),\n        snapshot.dtype,\n        snapshot.shape.clone(),\n        new_path,\n        snapshot.container_stack.clone().unwrap_or_default(),\n        snapshot.tensor_id.unwrap_or_default(),\n    )\n}\n\n/// Transpose a 2D tensor\nfn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {\n    if snapshot.shape.len() != 2 {\n        return snapshot.clone();\n    }\n\n    let original_data_fn = snapshot.clone_data_fn();\n    let dtype = snapshot.dtype;\n    let transposed_shape = shape![snapshot.shape[1], snapshot.shape[0]];\n\n    // Create a lazy closure that transposes when called\n    let transposed_data_fn = Rc::new(move || {\n        let data = original_data_fn()?;\n        Ok(transpose_tensor_data(data))\n    });\n\n    TensorSnapshot::from_closure(\n        transposed_data_fn,\n        dtype,\n        transposed_shape,\n        snapshot.path_stack.clone().unwrap_or_default(),\n        snapshot.container_stack.clone().unwrap_or_default(),\n        snapshot.tensor_id.unwrap_or_default(),\n    )\n}\n\n/// Transpose tensor data (assumes 2D shape is already validated)\nfn transpose_tensor_data(data: TensorData) -> TensorData {\n    let shape = &data.shape;\n    let rows = shape[0];\n    let cols = shape[1];\n    let transposed_shape = vec![cols, rows];\n\n    // Get the raw bytes and element size\n    let bytes = data.as_bytes();\n    let element_size = data.dtype.size();\n\n    // Create a new buffer for transposed data\n    let mut transposed_bytes = vec![0u8; bytes.len()];\n\n    // Transpose at the byte level - works for any data type\n    for i in 0..rows {\n        for j in 0..cols {\n            let src_idx = (i * cols + j) * element_size;\n            let dst_idx = (j * rows + i) * element_size;\n\n            // Copy the bytes for this element\n            transposed_bytes[dst_idx..dst_idx + element_size]\n                .copy_from_slice(&bytes[src_idx..src_idx + element_size]);\n        }\n    }\n\n    // Create new TensorData from transposed bytes\n    TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use alloc::rc::Rc;\n    use alloc::sync::Arc;\n    use burn_tensor::{DType, Shape, TensorData};\n    use core::sync::atomic::{AtomicUsize, Ordering};\n\n    #[test]\n    fn test_module_names_match_burn_nn() {\n        // If these types are renamed or moved in `burn-nn`, this test will fail to compile.\n        #[allow(unused_imports)]\n        use burn_nn::{\n            BatchNorm, Embedding, GroupNorm, InstanceNorm, LayerNorm, Linear, PRelu, RmsNorm,\n            conv::{\n                Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d,\n                DeformConv2d,\n            },\n        };\n\n        assert_eq!(module_names::LINEAR, \"Struct:Linear\");\n        assert_eq!(module_names::BATCH_NORM, \"Struct:BatchNorm\");\n        assert_eq!(module_names::LAYER_NORM, \"Struct:LayerNorm\");\n        assert_eq!(module_names::GROUP_NORM, \"Struct:GroupNorm\");\n        assert_eq!(module_names::EMBEDDING, \"Struct:Embedding\");\n        assert_eq!(module_names::CONV1D, \"Struct:Conv1d\");\n        assert_eq!(module_names::CONV2D, \"Struct:Conv2d\");\n        assert_eq!(module_names::CONV3D, \"Struct:Conv3d\");\n        assert_eq!(module_names::CONV_TRANSPOSE1D, \"Struct:ConvTranspose1d\");\n        assert_eq!(module_names::CONV_TRANSPOSE2D, \"Struct:ConvTranspose2d\");\n        assert_eq!(module_names::CONV_TRANSPOSE3D, \"Struct:ConvTranspose3d\");\n        assert_eq!(module_names::DEFORM_CONV2D, \"Struct:DeformConv2d\");\n        assert_eq!(module_names::INSTANCE_NORM, \"Struct:InstanceNorm\");\n        assert_eq!(module_names::RMS_NORM, \"Struct:RmsNorm\");\n        assert_eq!(module_names::PRELU, \"Struct:PRelu\");\n    }\n\n    fn create_test_snapshot(path: &str, shape: Shape, container_type: &str) -> TensorSnapshot {\n        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();\n        let values = vec![1.0f32; shape.iter().product()];\n        let data = TensorData::new(values, shape.clone());\n\n        TensorSnapshot::from_closure(\n            Rc::new(move || Ok(data.clone())),\n            DType::F32,\n            shape,\n            path_parts,\n            vec![container_type.to_string()],\n            burn_core::module::ParamId::new(),\n        )\n    }\n\n    #[test]\n    fn test_pytorch_to_burn_linear_weight() {\n        let adapter = PyTorchToBurnAdapter;\n\n        // Linear layer weight should be transposed\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![10, 5], module_names::LINEAR);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.shape, shape![5, 10]);\n\n        // Linear layer bias should not be transposed\n        let snapshot = create_test_snapshot(\"fc.bias\", shape![10], module_names::LINEAR);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.shape, shape![10]);\n    }\n\n    #[test]\n    fn test_pytorch_to_burn_norm_params() {\n        let adapter = PyTorchToBurnAdapter;\n\n        // BatchNorm weight -> gamma\n        let snapshot = create_test_snapshot(\"norm.weight\", shape![10], module_names::BATCH_NORM);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.full_path(), \"norm.gamma\");\n\n        // BatchNorm bias -> beta\n        let snapshot = create_test_snapshot(\"norm.bias\", shape![10], module_names::BATCH_NORM);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.full_path(), \"norm.beta\");\n    }\n\n    #[test]\n    fn test_burn_to_pytorch_linear_weight() {\n        let adapter = BurnToPyTorchAdapter;\n\n        // Linear layer weight should be transposed\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![5, 10], module_names::LINEAR);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.shape, shape![10, 5]);\n    }\n\n    #[test]\n    fn test_burn_to_pytorch_norm_params() {\n        let adapter = BurnToPyTorchAdapter;\n\n        // BatchNorm gamma -> weight\n        let snapshot = create_test_snapshot(\"norm.gamma\", shape![10], module_names::BATCH_NORM);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.full_path(), \"norm.weight\");\n\n        // BatchNorm beta -> bias\n        let snapshot = create_test_snapshot(\"norm.beta\", shape![10], module_names::BATCH_NORM);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.full_path(), \"norm.bias\");\n    }\n\n    #[test]\n    fn test_transpose_different_dtypes() {\n        // Test that transpose works for different data types\n\n        // Test with F32\n        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);\n        let transposed = transpose_tensor_data(f32_data);\n        assert_eq!(transposed.shape, shape![3, 2]);\n        let values = transposed.to_vec::<f32>().unwrap();\n        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);\n\n        // Test with I32\n        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], [2, 3]);\n        let transposed = transpose_tensor_data(i32_data);\n        assert_eq!(transposed.shape, shape![3, 2]);\n        let values = transposed.to_vec::<i32>().unwrap();\n        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);\n\n        // Test with F64\n        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);\n        let transposed = transpose_tensor_data(f64_data);\n        assert_eq!(transposed.shape, shape![2, 2]);\n        let values = transposed.to_vec::<f64>().unwrap();\n        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);\n    }\n\n    #[test]\n    fn test_no_container_info() {\n        let adapter = PyTorchToBurnAdapter;\n\n        // Without container info, adapter returns unchanged for non-norm parameters\n        let mut snapshot = create_test_snapshot(\"fc.weight\", shape![10, 5], module_names::LINEAR);\n        snapshot.container_stack = None;\n\n        // Without container info, no transformation occurs for linear layers\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.shape, shape![10, 5]); // No transposition without container info\n\n        // Test a non-linear, non-norm parameter - should pass through unchanged\n        let mut snapshot2 = create_test_snapshot(\"other.weight\", shape![10, 5], \"Struct:Other\");\n        snapshot2.container_stack = None;\n        let adapted2 = adapter.adapt(&snapshot2);\n        assert_eq!(adapted2.shape, shape![10, 5]); // No transposition\n    }\n\n    #[derive(Clone)]\n    struct RenameParamAdapter {\n        from: &'static str,\n        to: &'static str,\n        called: Arc<AtomicUsize>,\n    }\n\n    impl ModuleAdapter for RenameParamAdapter {\n        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n            self.called.fetch_add(1, Ordering::Relaxed);\n\n            let path_stack = match snapshot.path_stack.as_ref() {\n                Some(stack) => stack,\n                None => return snapshot.clone(),\n            };\n            let param = match path_stack.last() {\n                Some(p) => p.as_str(),\n                None => return snapshot.clone(),\n            };\n            if param != self.from {\n                return snapshot.clone();\n            }\n\n            let mut new_path = path_stack.to_vec();\n            *new_path.last_mut().unwrap() = self.to.to_string();\n\n            TensorSnapshot::from_closure(\n                snapshot.clone_data_fn(),\n                snapshot.dtype,\n                snapshot.shape.clone(),\n                new_path,\n                snapshot.container_stack.clone().unwrap_or_default(),\n                snapshot.tensor_id.unwrap_or_default(),\n            )\n        }\n\n        fn get_alternative_param_name(\n            &self,\n            _param_name: &str,\n            _container_type: &str,\n        ) -> Option<String> {\n            None\n        }\n\n        fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n            Box::new(self.clone())\n        }\n    }\n\n    #[derive(Clone)]\n    struct AltNameAdapter {\n        from: &'static str,\n        to: &'static str,\n        called: Arc<AtomicUsize>,\n    }\n\n    impl ModuleAdapter for AltNameAdapter {\n        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {\n            TensorSnapshot::from_closure(\n                snapshot.clone_data_fn(),\n                snapshot.dtype,\n                snapshot.shape.clone(),\n                snapshot.path_stack.clone().unwrap_or_default(),\n                snapshot.container_stack.clone().unwrap_or_default(),\n                snapshot.tensor_id.unwrap_or_default(),\n            )\n        }\n\n        fn get_alternative_param_name(\n            &self,\n            param_name: &str,\n            _container_type: &str,\n        ) -> Option<String> {\n            self.called.fetch_add(1, Ordering::Relaxed);\n            if param_name == self.from {\n                Some(self.to.to_string())\n            } else {\n                None\n            }\n        }\n\n        fn clone_box(&self) -> Box<dyn ModuleAdapter> {\n            Box::new(self.clone())\n        }\n    }\n\n    #[test]\n    fn test_chain_adapter_pipes_adapt() {\n        let called1 = Arc::new(AtomicUsize::new(0));\n        let called2 = Arc::new(AtomicUsize::new(0));\n\n        let a = RenameParamAdapter {\n            from: \"weight\",\n            to: \"a\",\n            called: called1.clone(),\n        };\n        let b = RenameParamAdapter {\n            from: \"a\",\n            to: \"b\",\n            called: called2.clone(),\n        };\n\n        let chain = a.chain(b);\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![2, 2], module_names::LINEAR);\n        let adapted = chain.adapt(&snapshot);\n\n        assert_eq!(adapted.full_path(), \"fc.b\");\n        assert_eq!(called1.load(Ordering::Relaxed), 1);\n        assert_eq!(called2.load(Ordering::Relaxed), 1);\n    }\n\n    #[test]\n    fn test_chain_adapter_alternative_name_pipes_and_fallbacks() {\n        let called1 = Arc::new(AtomicUsize::new(0));\n        let called2 = Arc::new(AtomicUsize::new(0));\n\n        let a = AltNameAdapter {\n            from: \"gamma\",\n            to: \"weight\",\n            called: called1.clone(),\n        };\n        let b = AltNameAdapter {\n            from: \"weight\",\n            to: \"scale\",\n            called: called2.clone(),\n        };\n\n        let chain = a.chain(b);\n        let alt = chain.get_alternative_param_name(\"gamma\", module_names::LAYER_NORM);\n        assert_eq!(alt.as_deref(), Some(\"scale\"));\n        assert_eq!(called1.load(Ordering::Relaxed), 1);\n        assert_eq!(called2.load(Ordering::Relaxed), 1);\n\n        // If the second adapter doesn't have a mapping for the first alternative,\n        // fall back to the first alternative name.\n        let called1 = Arc::new(AtomicUsize::new(0));\n        let called2 = Arc::new(AtomicUsize::new(0));\n        let a = AltNameAdapter {\n            from: \"gamma\",\n            to: \"weight\",\n            called: called1.clone(),\n        };\n        let b = AltNameAdapter {\n            from: \"something-else\",\n            to: \"unused\",\n            called: called2.clone(),\n        };\n        let chain = a.chain(b);\n        let alt = chain.get_alternative_param_name(\"gamma\", module_names::LAYER_NORM);\n        assert_eq!(alt.as_deref(), Some(\"weight\"));\n        assert_eq!(called1.load(Ordering::Relaxed), 1);\n        assert_eq!(called2.load(Ordering::Relaxed), 1);\n\n        // If the first adapter doesn't provide an alternative, try the second with the original name.\n        let called1 = Arc::new(AtomicUsize::new(0));\n        let called2 = Arc::new(AtomicUsize::new(0));\n        let a = AltNameAdapter {\n            from: \"something-else\",\n            to: \"unused\",\n            called: called1.clone(),\n        };\n        let b = AltNameAdapter {\n            from: \"gamma\",\n            to: \"weight\",\n            called: called2.clone(),\n        };\n        let chain = a.chain(b);\n        let alt = chain.get_alternative_param_name(\"gamma\", module_names::LAYER_NORM);\n        assert_eq!(alt.as_deref(), Some(\"weight\"));\n        assert_eq!(called1.load(Ordering::Relaxed), 1);\n        assert_eq!(called2.load(Ordering::Relaxed), 1);\n\n        // clone_box must preserve behavior.\n        let boxed = chain.clone_box();\n        let alt = boxed.get_alternative_param_name(\"gamma\", module_names::LAYER_NORM);\n        assert_eq!(alt.as_deref(), Some(\"weight\"));\n    }\n\n    #[test]\n    fn test_half_precision_f32_to_f16() {\n        let adapter = HalfPrecisionAdapter::new();\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![2, 3], module_names::LINEAR);\n\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.dtype, DType::F16);\n        assert_eq!(adapted.shape, shape![2, 3]);\n\n        let data = adapted.to_data().unwrap();\n        assert_eq!(data.dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_f16_to_f32() {\n        let adapter = HalfPrecisionAdapter::new();\n\n        // Create an F16 snapshot\n        let values = vec![1.0f32; 6];\n        let data = TensorData::new(values, shape![2, 3]).convert_dtype(DType::F16);\n        let path_parts = vec![\"fc\".to_string(), \"weight\".to_string()];\n        let snapshot = TensorSnapshot::from_closure(\n            Rc::new(move || Ok(data.clone())),\n            DType::F16,\n            shape![2, 3],\n            path_parts,\n            vec![module_names::LINEAR.to_string()],\n            burn_core::module::ParamId::new(),\n        );\n\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.dtype, DType::F32);\n    }\n\n    #[test]\n    fn test_half_precision_skips_batch_norm() {\n        let adapter = HalfPrecisionAdapter::new();\n\n        // BatchNorm is excluded by default\n        let snapshot = create_test_snapshot(\"norm.weight\", shape![10], module_names::BATCH_NORM);\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.dtype, DType::F32); // unchanged\n    }\n\n    #[test]\n    fn test_half_precision_converts_default_modules() {\n        let adapter = HalfPrecisionAdapter::new();\n\n        // Linear\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![2, 3], module_names::LINEAR);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n\n        // Embedding\n        let snapshot = create_test_snapshot(\"emb.weight\", shape![100, 64], module_names::EMBEDDING);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n\n        // Conv2d\n        let snapshot =\n            create_test_snapshot(\"conv.weight\", shape![3, 3, 3, 3], module_names::CONV2D);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n\n        // LayerNorm (included by default)\n        let snapshot = create_test_snapshot(\"norm.gamma\", shape![10], module_names::LAYER_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n\n        // GroupNorm\n        let snapshot = create_test_snapshot(\"gn.gamma\", shape![10], module_names::GROUP_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n\n        // RmsNorm\n        let snapshot = create_test_snapshot(\"rms.weight\", shape![10], module_names::RMS_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_without_module() {\n        let adapter = HalfPrecisionAdapter::new().without_module(\"LayerNorm\");\n\n        // LayerNorm removed from conversion set\n        let snapshot = create_test_snapshot(\"norm.gamma\", shape![10], module_names::LAYER_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);\n\n        // Linear still converted\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![2, 3], module_names::LINEAR);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_with_module() {\n        let adapter = HalfPrecisionAdapter::new().with_module(\"CustomLayer\");\n\n        // Custom module should now be converted\n        let snapshot = create_test_snapshot(\"custom.weight\", shape![5], \"Struct:CustomLayer\");\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_with_qualified_name() {\n        let adapter = HalfPrecisionAdapter::new().with_module(\"Struct:CustomLayer\");\n\n        let snapshot = create_test_snapshot(\"custom.weight\", shape![5], \"Struct:CustomLayer\");\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_chain() {\n        let adapter = PyTorchToBurnAdapter.chain(HalfPrecisionAdapter::new());\n\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![10, 5], module_names::LINEAR);\n        let adapted = adapter.adapt(&snapshot);\n\n        // Should be both transposed and cast\n        assert_eq!(adapted.shape, shape![5, 10]);\n        assert_eq!(adapted.dtype, DType::F16);\n    }\n\n    #[test]\n    fn test_half_precision_skips_no_container() {\n        let adapter = HalfPrecisionAdapter::new();\n        let mut snapshot = create_test_snapshot(\"fc.weight\", shape![2, 3], module_names::LINEAR);\n        snapshot.container_stack = None;\n\n        // No module type info: skip\n        let adapted = adapter.adapt(&snapshot);\n        assert_eq!(adapted.dtype, DType::F32);\n    }\n\n    #[test]\n    fn test_half_precision_skips_non_float() {\n        use burn_tensor::quantization::QuantScheme;\n\n        let adapter = HalfPrecisionAdapter::new();\n\n        // QFloat source: skip\n        let qfloat_dtype = DType::QFloat(QuantScheme::default());\n        let snapshot = create_test_snapshot(\"fc.weight\", shape![2, 3], module_names::LINEAR);\n        let qfloat_snapshot = TensorSnapshot::from_closure(\n            snapshot.clone_data_fn(),\n            qfloat_dtype,\n            snapshot.shape.clone(),\n            snapshot.path_stack.clone().unwrap_or_default(),\n            snapshot.container_stack.clone().unwrap_or_default(),\n            snapshot.tensor_id.unwrap_or_default(),\n        );\n        let adapted = adapter.adapt(&qfloat_snapshot);\n        assert_eq!(adapted.dtype, qfloat_dtype);\n    }\n\n    #[test]\n    fn test_half_precision_default_module_count() {\n        let adapter = HalfPrecisionAdapter::new();\n        // 14 modules: Linear, Embedding, Conv1d-3d, ConvTranspose1d-3d,\n        // DeformConv2d, LayerNorm, GroupNorm, InstanceNorm, RmsNorm, PRelu\n        assert_eq!(adapter.modules.len(), 14);\n    }\n\n    #[test]\n    fn test_half_precision_without_module_qualified() {\n        let adapter = HalfPrecisionAdapter::new().without_module(\"Struct:LayerNorm\");\n\n        let snapshot = create_test_snapshot(\"norm.gamma\", shape![10], module_names::LAYER_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);\n    }\n\n    #[test]\n    fn test_half_precision_with_module_batch_norm_opt_in() {\n        let adapter = HalfPrecisionAdapter::new().with_module(\"BatchNorm\");\n\n        let snapshot = create_test_snapshot(\"bn.weight\", shape![10], module_names::BATCH_NORM);\n        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/applier.rs",
    "content": "//! Applier that correctly applies tensor snapshots with adapter support\n\nuse alloc::boxed::Box;\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\n\nuse hashbrown::{HashMap, HashSet};\n\nuse burn_core::module::{ModuleMapper, Param};\nuse burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend};\n\nuse crate::apply_result::{ApplyError, ApplyResult};\nuse crate::{ModuleAdapter, PathFilter, TensorSnapshot};\n\n/// Applier that applies tensor snapshots to module parameters\n/// with proper adapter support using container type information\npub struct Applier<B: Backend> {\n    /// Map of tensor paths to their snapshots\n    snapshots: HashMap<String, TensorSnapshot>,\n    /// Current path in the module hierarchy\n    path_stack: Vec<String>,\n    /// Current container type stack in the module hierarchy\n    container_stack: Vec<String>,\n    /// Optional filter for selective application\n    filter: Option<PathFilter>,\n    /// Optional adapter to transform tensors based on container types\n    adapter: Option<Box<dyn ModuleAdapter>>,\n    /// Successfully applied tensor paths\n    applied: Vec<String>,\n    /// Skipped tensor paths\n    skipped: HashSet<String>,\n    /// Errors encountered during application\n    errors: Vec<ApplyError>,\n    /// Track visited paths with their container stacks (in dot notation) to find missing tensors\n    visited_paths: HashMap<String, String>,\n    /// Skip enum variant names when matching paths\n    /// When true, \"feature.BaseConv.weight\" will also try to match \"feature.weight\"\n    skip_enum_variants: bool,\n    /// Phantom data for backend type\n    _backend: core::marker::PhantomData<B>,\n}\n\nimpl<B: Backend> Applier<B> {\n    /// Create a new applier with snapshots, optional filter, and optional adapter\n    ///\n    /// # Arguments\n    ///\n    /// * `views` - A vector of TensorSnapshot objects to apply\n    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.\n    ///   When `None`, all available tensors are applied.\n    /// * `adapter` - Optional adapter to transform tensors based on container types\n    /// * `skip_enum_variants` - Skip enum variant names when matching paths\n    pub fn new(\n        views: Vec<TensorSnapshot>,\n        filter: Option<PathFilter>,\n        adapter: Option<Box<dyn ModuleAdapter>>,\n        skip_enum_variants: bool,\n    ) -> Self {\n        let views_map: HashMap<String, TensorSnapshot> = views\n            .into_iter()\n            .map(|view| (view.full_path(), view))\n            .collect();\n\n        Self {\n            snapshots: views_map,\n            path_stack: Vec::new(),\n            container_stack: Vec::new(),\n            filter,\n            adapter,\n            applied: Vec::new(),\n            skipped: HashSet::new(),\n            errors: Vec::new(),\n            visited_paths: HashMap::new(),\n            skip_enum_variants,\n            _backend: core::marker::PhantomData,\n        }\n    }\n\n    /// Get the current path in the module hierarchy\n    fn current_path(&self) -> String {\n        self.path_stack.join(\".\")\n    }\n\n    /// Get the current module type (last Struct/Enum in container stack)\n    fn current_module_type(&self) -> Option<&str> {\n        self.container_stack\n            .iter()\n            .rev()\n            .find(|ct| ct.starts_with(\"Struct:\") || ct.starts_with(\"Enum:\"))\n            .map(|s| s.as_str())\n    }\n\n    /// Check if a tensor should be applied based on filter\n    fn should_apply(&self) -> bool {\n        match &self.filter {\n            None => true,\n            Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),\n        }\n    }\n\n    /// Convert the applier into a result\n    pub fn into_result(self) -> ApplyResult {\n        let mut unused: Vec<String> = self\n            .snapshots\n            .keys()\n            .filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path))\n            .cloned()\n            .collect();\n        // Sort for stable output order\n        unused.sort();\n\n        // Create a set of successfully applied paths for efficient lookup\n        let applied_set: HashSet<String> = self.applied.iter().cloned().collect();\n\n        // Extract paths that have errors - these are not \"missing\", they were found but had issues\n        let errored_paths: HashSet<String> = self\n            .errors\n            .iter()\n            .map(|e| match e {\n                ApplyError::ShapeMismatch { path, .. } => path.clone(),\n                ApplyError::DTypeMismatch { path, .. } => path.clone(),\n                ApplyError::AdapterError { path, .. } => path.clone(),\n                ApplyError::LoadError { path, .. } => path.clone(),\n            })\n            .collect();\n\n        // A path is missing if it was visited but not successfully applied, not skipped, and didn't have an error\n        // Store both the path and its container stack (in dot notation)\n        let mut missing: Vec<(String, String)> = self\n            .visited_paths\n            .into_iter()\n            .filter(|(p, _)| {\n                !applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p)\n            })\n            .collect();\n        // Sort for stable output order (by path)\n        missing.sort_by(|a, b| a.0.cmp(&b.0));\n\n        // Convert skipped HashSet to sorted Vec for stable output\n        let mut skipped: Vec<String> = self.skipped.into_iter().collect();\n        skipped.sort();\n\n        ApplyResult {\n            applied: self.applied,\n            skipped,\n            missing,\n            unused,\n            errors: self.errors,\n        }\n    }\n\n    /// Apply a tensor snapshot with shape validation and optional adapter transformation\n    /// Returns None if snapshot not found, filtered, or validation fails\n    fn apply_tensor<const D: usize, K>(\n        &mut self,\n        target_device: &B::Device,\n        target_shape: Shape,\n    ) -> Option<Tensor<B, D, K>>\n    where\n        K: burn_tensor::TensorKind<B>,\n        K: burn_tensor::BasicOps<B>,\n    {\n        let path = self.current_path();\n        let container_stack_str = self.container_stack.join(\".\");\n        self.visited_paths.insert(path.clone(), container_stack_str);\n\n        // Try to get snapshot with original path first\n        let mut snapshot = self.snapshots.get(&path).cloned();\n\n        // If not found and we have an adapter, try alternative parameter names\n        if snapshot.is_none()\n            && let Some(ref adapter) = self.adapter\n            && let Some(module_type) = self.current_module_type()\n        {\n            // Get alternative name based on current module type (user-defined module only)\n            let param_name = self.path_stack.last()?;\n\n            if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) {\n                // Build alternative path with parameter name substitution\n                let mut alt_path_stack = self.path_stack.clone();\n                *alt_path_stack.last_mut().unwrap() = alt_name.clone();\n                let alt_path = alt_path_stack.join(\".\");\n\n                // Try to get snapshot with alternative name\n                snapshot = self.snapshots.get(&alt_path).cloned();\n\n                // Don't mark the alternative path as visited - only the original Burn path\n                // should be tracked. The alternative path is just for lookup.\n            }\n        }\n\n        let mut snapshot = snapshot?;\n\n        // Apply adapter transformation using current container_stack context (for data transformation like transpose)\n        if let Some(ref adapter) = self.adapter {\n            // Create a temporary snapshot with current context for adaptation\n            let snapshot_with_context = TensorSnapshot::from_closure(\n                snapshot.clone_data_fn(),\n                snapshot.dtype,\n                snapshot.shape.clone(),\n                self.path_stack.clone(),\n                self.container_stack.clone(),\n                snapshot.tensor_id.unwrap_or_default(),\n            );\n\n            // Transform using adapter (handles transpose)\n            snapshot = adapter.adapt(&snapshot_with_context);\n        }\n\n        // Check if we should apply based on filter\n        if !self.should_apply() {\n            self.skipped.insert(path.clone());\n            return None;\n        }\n\n        // Load tensor data\n        let data = match snapshot.to_data() {\n            Ok(data) => data,\n            Err(e) => {\n                self.errors.push(ApplyError::LoadError {\n                    path: path.clone(),\n                    message: format!(\"Failed to load tensor data: {:?}\", e),\n                });\n                return None; // Signal caller to fall back to initialization\n            }\n        };\n\n        // Validate shape\n        if data.shape != target_shape {\n            self.errors.push(ApplyError::ShapeMismatch {\n                path: path.clone(),\n                expected: target_shape,\n                found: data.shape,\n            });\n            return None; // Signal caller to fall back to initialization\n        }\n\n        self.applied.push(path);\n        Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype))\n    }\n}\n\nimpl<B: Backend> ModuleMapper<B> for Applier<B> {\n    fn enter_module(&mut self, name: &str, container_type: &str) {\n        // Always track the container type for proper module type detection\n        self.container_stack.push(container_type.to_string());\n\n        // Only add to path if it's not an enum variant (when skip_enum_variants is enabled)\n        // This ensures paths are built without enum variant names from the start\n        if !self.skip_enum_variants || !container_type.starts_with(\"Enum:\") {\n            self.path_stack.push(name.to_string());\n        }\n    }\n\n    fn exit_module(&mut self, _name: &str, container_type: &str) {\n        self.container_stack.pop();\n\n        // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)\n        if !self.skip_enum_variants || !container_type.starts_with(\"Enum:\") {\n            self.path_stack.pop();\n        }\n    }\n\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let param_id = param.id;\n        let target_device = param.lazy_device();\n        let target_shape = param.lazy_shape();\n\n        // Try to apply snapshot with shape validation\n        match self.apply_tensor(&target_device, target_shape) {\n            Some(tensor) => {\n                // We have a tensor to apply - load it\n                param.transform_for_load(tensor, param_id)\n            }\n            None => {\n                // No snapshot, filtered, or validation failed - return param unchanged\n                param\n            }\n        }\n    }\n\n    fn map_int<const D: usize>(\n        &mut self,\n        param: Param<Tensor<B, D, Int>>,\n    ) -> Param<Tensor<B, D, Int>> {\n        let param_id = param.id;\n        let target_device = param.lazy_device();\n        let target_shape = param.lazy_shape();\n\n        // Try to apply snapshot with shape validation\n        match self.apply_tensor(&target_device, target_shape) {\n            Some(tensor) => {\n                // We have a tensor to apply - load it\n                param.transform_for_load(tensor, param_id)\n            }\n            None => {\n                // No snapshot, filtered, or validation failed - return param unchanged\n                param\n            }\n        }\n    }\n\n    fn map_bool<const D: usize>(\n        &mut self,\n        param: Param<Tensor<B, D, Bool>>,\n    ) -> Param<Tensor<B, D, Bool>> {\n        let param_id = param.id;\n        let target_device = param.lazy_device();\n        let target_shape = param.lazy_shape();\n\n        // Try to apply snapshot with shape validation\n        match self.apply_tensor(&target_device, target_shape) {\n            Some(tensor) => {\n                // We have a tensor to apply - load it\n                param.transform_for_load(tensor, param_id)\n            }\n            None => {\n                // No snapshot, filtered, or validation failed - return param unchanged\n                param\n            }\n        }\n    }\n}\n\n#[cfg(all(test, feature = \"std\", target_has_atomic = \"ptr\"))]\nmod tests {\n    use super::*;\n    use burn_core::module::{ModuleMapper, Param, ParamId};\n    use burn_tensor::{DType, Tensor, TensorData};\n\n    type TestBackend = burn_ndarray::NdArray;\n\n    #[test]\n    fn root_level_parameters() {\n        let device = Default::default();\n\n        // Create root-level parameters (not inside any module)\n        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);\n\n        // Create snapshots with root-level paths (single-element path, no nested modules)\n        let weight_snapshot = crate::TensorSnapshot::from_data(\n            weight.val().to_data(),\n            vec![\"weight\".to_string()], // root-level parameter name\n            vec![],                     // no container\n            ParamId::new(),\n        );\n\n        let bias_snapshot = crate::TensorSnapshot::from_data(\n            bias.val().to_data(),\n            vec![\"bias\".to_string()], // root-level parameter name\n            vec![],                   // no container\n            ParamId::new(),\n        );\n\n        // Create applier with root-level snapshots\n        let mut applier =\n            Applier::<TestBackend>::new(vec![weight_snapshot, bias_snapshot], None, None, false);\n\n        // Create new params to load into\n        let weight_target = Param::initialized(\n            ParamId::new(),\n            Tensor::<TestBackend, 2>::zeros([2, 2], &device),\n        );\n        let bias_target = Param::initialized(\n            ParamId::new(),\n            Tensor::<TestBackend, 1>::zeros([2], &device),\n        );\n\n        // Apply using the ModuleMapper interface - simulate module traversal\n        // Enter \"weight\" path (as if we're visiting a field named \"weight\")\n        applier.enter_module(\"weight\", \"\");\n        let weight_loaded = applier.map_float(weight_target);\n        applier.exit_module(\"weight\", \"\");\n\n        // Enter \"bias\" path (as if we're visiting a field named \"bias\")\n        applier.enter_module(\"bias\", \"\");\n        let bias_loaded = applier.map_float(bias_target);\n        applier.exit_module(\"bias\", \"\");\n\n        // Verify values were loaded\n        let weight_data = weight_loaded.val().to_data().to_vec::<f32>().unwrap();\n        let bias_data = bias_loaded.val().to_data().to_vec::<f32>().unwrap();\n\n        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);\n        assert_eq!(bias_data, vec![5.0, 6.0]);\n\n        // Verify applier result\n        let result = applier.into_result();\n        assert_eq!(result.applied.len(), 2);\n        assert_eq!(result.errors.len(), 0);\n    }\n\n    /// Test that the applier preserves dtype when loading tensor data.\n    /// This is a regression test for the bug where F16 tensors were being\n    /// loaded as F32 because `Tensor::from_data` was used instead of\n    /// `Tensor::from_data_dtype`.\n    #[test]\n    fn dtype_preservation_f64() {\n        // Use NdArray<f64> backend to properly test F64 dtype preservation\n        type TestBackendF64 = burn_ndarray::NdArray<f64>;\n        let device = Default::default();\n\n        // Create TensorData with F64 dtype explicitly\n        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);\n        assert_eq!(f64_data.dtype, DType::F64, \"Test setup: data should be F64\");\n\n        // Create a snapshot with F64 data\n        let snapshot = crate::TensorSnapshot::from_data(\n            f64_data.clone(),\n            vec![\"weight\".to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        assert_eq!(\n            snapshot.dtype,\n            DType::F64,\n            \"Snapshot should preserve F64 dtype\"\n        );\n\n        // Create applier with the F64 snapshot\n        let mut applier = Applier::<TestBackendF64>::new(vec![snapshot], None, None, false);\n\n        // Create target parameter\n        let target = Param::initialized(\n            ParamId::new(),\n            Tensor::<TestBackendF64, 2>::zeros([2, 2], &device),\n        );\n\n        // Apply the snapshot\n        applier.enter_module(\"weight\", \"\");\n        let loaded = applier.map_float(target);\n        applier.exit_module(\"weight\", \"\");\n\n        // Verify dtype is preserved - this would fail before the fix\n        // because the data would be converted to the backend's default FloatElem\n        assert_eq!(\n            loaded.val().dtype(),\n            DType::F64,\n            \"Loaded tensor should have F64 dtype\"\n        );\n\n        // Verify data values are correct\n        let loaded_data = loaded.val().to_data().to_vec::<f64>().unwrap();\n        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);\n\n        // Verify applier result\n        let result = applier.into_result();\n        assert_eq!(result.applied.len(), 1);\n        assert_eq!(result.errors.len(), 0);\n    }\n\n    /// Test that F32 dtype is preserved when loading (verifies we didn't break F32 handling)\n    #[test]\n    fn dtype_preservation_f32() {\n        let device = Default::default();\n\n        // Create TensorData with F32 dtype\n        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);\n        assert_eq!(f32_data.dtype, DType::F32);\n\n        // Create a snapshot with F32 data\n        let snapshot = crate::TensorSnapshot::from_data(\n            f32_data.clone(),\n            vec![\"weight\".to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        assert_eq!(snapshot.dtype, DType::F32);\n\n        // Create applier with the F32 snapshot\n        let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);\n\n        // Create target parameter\n        let target = Param::initialized(\n            ParamId::new(),\n            Tensor::<TestBackend, 2>::zeros([2, 2], &device),\n        );\n\n        // Apply the snapshot\n        applier.enter_module(\"weight\", \"\");\n        let loaded = applier.map_float(target);\n        applier.exit_module(\"weight\", \"\");\n\n        // Verify dtype is F32\n        assert_eq!(loaded.val().dtype(), DType::F32);\n\n        // Verify data values\n        let loaded_data = loaded.val().to_data().to_vec::<f32>().unwrap();\n        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);\n    }\n\n    /// Test that F16 dtype is correctly preserved in TensorSnapshot.\n    ///\n    /// Note: Full F16 tensor loading requires a backend that supports F16\n    /// (e.g., CUDA, WebGPU). The NdArray backend does not support F16.\n    /// This test verifies that the snapshot correctly preserves F16 dtype,\n    /// which is the key part of the dtype preservation fix.\n    #[test]\n    fn dtype_preservation_f16_snapshot() {\n        use half::f16;\n\n        // Create TensorData with F16 dtype using the half crate\n        let f16_values: Vec<f16> = vec![\n            f16::from_f32(1.0),\n            f16::from_f32(2.0),\n            f16::from_f32(3.0),\n            f16::from_f32(4.0),\n        ];\n        let f16_data = TensorData::new(f16_values.clone(), [2, 2]);\n        assert_eq!(\n            f16_data.dtype,\n            DType::F16,\n            \"TensorData should have F16 dtype\"\n        );\n\n        // Create a snapshot with F16 data\n        let snapshot = crate::TensorSnapshot::from_data(\n            f16_data.clone(),\n            vec![\"weight\".to_string()],\n            vec![],\n            ParamId::new(),\n        );\n\n        // Verify snapshot preserves F16 dtype\n        assert_eq!(\n            snapshot.dtype,\n            DType::F16,\n            \"TensorSnapshot should preserve F16 dtype\"\n        );\n\n        // Verify the data can be retrieved with correct dtype\n        let retrieved_data = snapshot.to_data().expect(\"Should be able to retrieve data\");\n        assert_eq!(\n            retrieved_data.dtype,\n            DType::F16,\n            \"Retrieved data should have F16 dtype\"\n        );\n\n        // Verify the actual values are preserved\n        let retrieved_values: Vec<f16> = retrieved_data\n            .to_vec()\n            .expect(\"Should be able to convert to f16 vec\");\n        assert_eq!(\n            retrieved_values, f16_values,\n            \"F16 values should be preserved\"\n        );\n\n        // Note: To fully test F16 tensor creation, you would need a backend\n        // that supports F16 (like CUDA or WebGPU). The applier fix ensures\n        // that `Tensor::from_data_dtype(data, device, snapshot.dtype)` is\n        // called with DType::F16, which will correctly create an F16 tensor\n        // on backends that support it.\n    }\n\n    /// Test that BF16 dtype is correctly preserved in TensorSnapshot.\n    #[test]\n    fn dtype_preservation_bf16_snapshot() {\n        use half::bf16;\n\n        // Create TensorData with BF16 dtype\n        let bf16_values: Vec<bf16> = vec![\n            bf16::from_f32(1.0),\n            bf16::from_f32(2.0),\n            bf16::from_f32(3.0),\n            bf16::from_f32(4.0),\n        ];\n        let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]);\n        assert_eq!(\n            bf16_data.dtype,\n            DType::BF16,\n            \"TensorData should have BF16 dtype\"\n        );\n\n        // Create a snapshot with BF16 data\n        let snapshot = crate::TensorSnapshot::from_data(\n            bf16_data.clone(),\n            vec![\"weight\".to_string()],\n            vec![],\n            ParamId::new(),\n        );\n\n        // Verify snapshot preserves BF16 dtype\n        assert_eq!(\n            snapshot.dtype,\n            DType::BF16,\n            \"TensorSnapshot should preserve BF16 dtype\"\n        );\n\n        // Verify the data can be retrieved with correct dtype\n        let retrieved_data = snapshot.to_data().expect(\"Should be able to retrieve data\");\n        assert_eq!(\n            retrieved_data.dtype,\n            DType::BF16,\n            \"Retrieved data should have BF16 dtype\"\n        );\n\n        // Verify the actual values are preserved\n        let retrieved_values: Vec<bf16> = retrieved_data\n            .to_vec()\n            .expect(\"Should be able to convert to bf16 vec\");\n        assert_eq!(\n            retrieved_values, bf16_values,\n            \"BF16 values should be preserved\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/apply_result.rs",
    "content": "//! Result types and diagnostics for tensor application operations\n\nuse alloc::string::String;\nuse alloc::vec;\nuse alloc::vec::Vec;\n\nuse burn_tensor::{DType, Shape};\n\n/// Error types that can occur during tensor application\n#[derive(Debug, Clone)]\npub enum ApplyError {\n    /// Shape mismatch between expected and actual tensor\n    ShapeMismatch {\n        /// Path of the tensor\n        path: String,\n        /// Expected shape\n        expected: Shape,\n        /// Found shape\n        found: Shape,\n    },\n    /// Data type mismatch between expected and actual tensor\n    DTypeMismatch {\n        /// Path of the tensor\n        path: String,\n        /// Expected data type\n        expected: DType,\n        /// Found data type\n        found: DType,\n    },\n    /// Error from adapter transformation\n    AdapterError {\n        /// Path of the tensor\n        path: String,\n        /// Error message\n        message: String,\n    },\n    /// Error loading tensor data\n    LoadError {\n        /// Path of the tensor\n        path: String,\n        /// Error message\n        message: String,\n    },\n}\n\nimpl core::fmt::Display for ApplyError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        match self {\n            Self::ShapeMismatch {\n                path,\n                expected,\n                found,\n            } => {\n                write!(\n                    f,\n                    \"Shape mismatch for '{}': expected {:?}, found {:?}\",\n                    path, expected, found\n                )\n            }\n            Self::DTypeMismatch {\n                path,\n                expected,\n                found,\n            } => {\n                write!(\n                    f,\n                    \"DType mismatch for '{}': expected {:?}, found {:?}\",\n                    path, expected, found\n                )\n            }\n            Self::AdapterError { path, message } => {\n                write!(f, \"Adapter error for '{}': {}\", path, message)\n            }\n            Self::LoadError { path, message } => {\n                write!(f, \"Load error for '{}': {}\", path, message)\n            }\n        }\n    }\n}\n\nimpl core::error::Error for ApplyError {}\n\n/// Result of applying tensor snapshots to a module\n#[derive(Clone)]\npub struct ApplyResult {\n    /// Successfully applied tensor paths\n    pub applied: Vec<String>,\n    /// Skipped tensor paths (due to filter)\n    pub skipped: Vec<String>,\n    /// Missing tensor paths with their container stacks in dot notation (path, container_stack)\n    /// Container stack shows the hierarchy: \"Struct:Model.Struct:Linear\" or \"Struct:Model.Enum:ConvType.Struct:Linear\"\n    pub missing: Vec<(String, String)>,\n    /// Unused tensor paths (in snapshots but not in module)\n    pub unused: Vec<String>,\n    /// Errors encountered during application\n    pub errors: Vec<ApplyError>,\n}\n\nimpl ApplyResult {\n    /// Try to strip enum variant from a path\n    /// e.g., \"field.BaseConv.weight\" -> \"field.weight\"\n    fn strip_enum_variant(path: &str) -> Option<String> {\n        let segments: Vec<&str> = path.split('.').collect();\n\n        // Find segments that look like enum variants (CamelCase in middle of path)\n        let variant_indices: Vec<usize> = segments\n            .iter()\n            .enumerate()\n            .filter(|(i, segment)| {\n                *i > 0 && *i < segments.len() - 1 // Not first or last\n                    && !segment.is_empty()\n                    && segment.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)\n                    && segment.len() > 1\n                    && segment.chars().skip(1).any(|c| c.is_lowercase())\n            })\n            .map(|(i, _)| i)\n            .collect();\n\n        if variant_indices.is_empty() {\n            return None;\n        }\n\n        // Remove the first found variant and return the modified path\n        let mut result_segments = segments.clone();\n        result_segments.remove(variant_indices[0]);\n        Some(result_segments.join(\".\"))\n    }\n\n    /// Find similar paths for a given missing path (for \"Did you mean?\" suggestions)\n    fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec<String> {\n        // First, try exact match with enum variant stripped\n        if let Some(stripped) = Self::strip_enum_variant(missing_path)\n            && self.unused.contains(&stripped)\n        {\n            return vec![stripped];\n        }\n\n        // Fall back to Jaro similarity (used by Elixir for \"did you mean?\" suggestions)\n        // Jaro gives higher weight to matching prefixes, ideal for hierarchical tensor paths\n        let mut similarities: Vec<(String, f64)> = self\n            .unused\n            .iter()\n            .map(|available| {\n                let similarity = textdistance::nstr::jaro(missing_path, available);\n                (available.clone(), similarity)\n            })\n            .collect();\n\n        // Sort by similarity (higher = more similar)\n        similarities\n            .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal));\n\n        // Only suggest paths with >= 70% similarity\n        const SIMILARITY_THRESHOLD: f64 = 0.7;\n        similarities\n            .into_iter()\n            .filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD)\n            .take(max_suggestions)\n            .map(|(path, _)| path)\n            .collect()\n    }\n}\n\nimpl ApplyResult {\n    /// Check if the apply operation was successful (no errors)\n    /// Note: Missing tensors are not considered errors when allow_partial is true\n    pub fn is_success(&self) -> bool {\n        self.errors.is_empty()\n    }\n}\n\nimpl core::fmt::Debug for ApplyResult {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        // Delegate to Display for comprehensive output\n        core::fmt::Display::fmt(self, f)\n    }\n}\n\nimpl core::fmt::Display for ApplyResult {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        writeln!(f, \"┌─ Tensor Loading Summary ─────────────────────────\")?;\n        writeln!(f, \"│\")?;\n        writeln!(\n            f,\n            \"│ ✓ Successfully applied: {} tensors\",\n            self.applied.len()\n        )?;\n        writeln!(f, \"│ ⊘ Skipped (filtered):  {} tensors\", self.skipped.len())?;\n        writeln!(\n            f,\n            \"│ ✗ Missing in source:    {} tensors\",\n            self.missing.len()\n        )?;\n        writeln!(f, \"│ ? Unused in target:     {} tensors\", self.unused.len())?;\n        writeln!(f, \"│ ! Errors:               {} errors\", self.errors.len())?;\n\n        if !self.missing.is_empty() {\n            writeln!(f, \"│\")?;\n            writeln!(\n                f,\n                \"├─ Missing Tensors (requested by model but not found in source)\"\n            )?;\n            writeln!(f, \"│\")?;\n\n            // Use actual container stack data to detect enum variants\n            // Count how many missing paths have \"Enum:\" in their container stack\n            let enum_variant_missing: Vec<_> = self\n                .missing\n                .iter()\n                .filter(|(_, stack)| stack.contains(\"Enum:\"))\n                .collect();\n\n            if !enum_variant_missing.is_empty() {\n                writeln!(\n                    f,\n                    \"│  ⚠️  {} paths contain enum variants (detected from container stack)\",\n                    enum_variant_missing.len()\n                )?;\n                writeln!(\n                    f,\n                    \"│      Burn includes enum variant names in paths, but PyTorch doesn't.\"\n                )?;\n                writeln!(\n                    f,\n                    \"│      Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'\"\n                )?;\n                writeln!(f, \"│\")?;\n                writeln!(\n                    f,\n                    \"│      💡 Solution 1: Enable skip_enum_variants flag (simplest):\"\n                )?;\n                writeln!(f, \"│\")?;\n                writeln!(\n                    f,\n                    \"│         let mut store = PytorchStore::from_file(\\\"model.pth\\\")\"\n                )?;\n                writeln!(f, \"│             .skip_enum_variants(true);  // ← Add this\")?;\n                writeln!(f, \"│\")?;\n                writeln!(\n                    f,\n                    \"│      💡 Solution 2: Remap enum keys in source (most precise):\"\n                )?;\n                writeln!(f, \"│\")?;\n                writeln!(\n                    f,\n                    \"│         let mut store = SafetensorsStore::from_file(\\\"model.safetensors\\\")\"\n                )?;\n                writeln!(\n                    f,\n                    \"│             .with_key_remapping(r\\\"field\\\\.(\\\\w+)\\\", \\\"field.BaseConv.$1\\\");\"\n                )?;\n                writeln!(f, \"│\")?;\n            }\n\n            writeln!(f, \"│  First 10 missing tensors:\")?;\n            for (path, _) in self.missing.iter().take(10) {\n                writeln!(f, \"│    • {}\", path)?;\n\n                // Show \"Did you mean?\" suggestions for this path\n                let suggestions = self.find_similar_paths(path, 1);\n                if !suggestions.is_empty() {\n                    writeln!(f, \"│        Did you mean: '{}'?\", suggestions[0])?;\n                }\n            }\n            if self.missing.len() > 10 {\n                writeln!(f, \"│    ... and {} more\", self.missing.len() - 10)?;\n            }\n        }\n\n        if !self.unused.is_empty() {\n            writeln!(f, \"│\")?;\n            writeln!(f, \"├─ Unused Tensors (in source but not used by model)\")?;\n            writeln!(f, \"│\")?;\n            writeln!(f, \"│  First 10 unused tensors:\")?;\n            for path in self.unused.iter().take(10) {\n                writeln!(f, \"│    • {}\", path)?;\n            }\n            if self.unused.len() > 10 {\n                writeln!(f, \"│    ... and {} more\", self.unused.len() - 10)?;\n            }\n        }\n\n        if !self.errors.is_empty() {\n            writeln!(f, \"│\")?;\n            writeln!(f, \"├─ Errors\")?;\n            writeln!(f, \"│\")?;\n            for error in self.errors.iter().take(10) {\n                writeln!(f, \"│  ⚠️  {}\", error)?;\n            }\n            if self.errors.len() > 10 {\n                writeln!(f, \"│    ... and {} more\", self.errors.len() - 10)?;\n            }\n        }\n\n        writeln!(f, \"│\")?;\n        write!(f, \"└───────────────────────────────────────────────────\")?;\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/base.rs",
    "content": "//! Core types and constants for the Burnpack file format.\n//!\n//! See the [parent module](crate::burnpack) for the complete file format specification.\n\nuse alloc::collections::BTreeMap;\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse burn_tensor::DType;\nuse byteorder::{ByteOrder, LittleEndian};\nuse serde::{Deserialize, Serialize};\n\n/// Magic number identifying a Burnpack file: \"BURN\" in ASCII (0x4255524E)\n/// When written to file in little-endian format, appears as \"NRUB\" bytes\npub const MAGIC_NUMBER: u32 = 0x4255524E;\n\n/// Current format version\npub const FORMAT_VERSION: u16 = 0x0001;\n\n/// Size of the magic number in bytes\npub const MAGIC_SIZE: usize = 4;\n\n/// Size of the format version in bytes\npub const VERSION_SIZE: usize = 2;\n\n/// Size of the metadata size field in bytes\npub const METADATA_SIZE_FIELD_SIZE: usize = 4;\n\n/// Total header size (computed from components)\npub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;\n\n/// Alignment for tensor data in bytes.\n///\n/// All tensor data is aligned to 256-byte boundaries to enable efficient\n/// memory-mapped (mmap) zero-copy loading. This alignment ensures:\n/// - Proper pointer alignment for all tensor element types (f64 requires 8-byte alignment)\n/// - Cache-line friendly access (most CPUs use 64-byte cache lines)\n/// - GPU memory alignment (CUDA prefers 256-byte for coalesced access)\n/// - Future-proofing for wider SIMD (AVX-512 = 64 bytes, future AVX-1024 = 128 bytes)\n///\n/// Industry alignment choices:\n/// - 256-byte: GGUF, MLX, ncnn, MNN, TNN, vLLM-AWQ, Marlin (15+ formats)\n/// - 64-byte: SafeTensors (minimum for AVX-512)\n/// - 4096-byte: Core ML\n///\n/// 256-byte alignment has negligible overhead for typical tensor sizes while\n/// providing maximum compatibility with current and future hardware.\npub const TENSOR_ALIGNMENT: u64 = 256;\n\n/// Calculate the byte offset where the tensor data section starts.\n///\n/// The data section is padded to start at a 256-byte aligned position\n/// so that all tensor offsets (which are relative to data section) result\n/// in properly aligned absolute file positions for mmap zero-copy access.\n///\n/// This function must be used consistently by both writer and reader.\n#[inline]\npub fn aligned_data_section_start(metadata_size: usize) -> usize {\n    let unaligned_start = (HEADER_SIZE + metadata_size) as u64;\n    // Keep multiplication in u64 space to avoid overflow on 32-bit systems\n    (unaligned_start.div_ceil(TENSOR_ALIGNMENT) * TENSOR_ALIGNMENT) as usize\n}\n\n// Security limits to prevent DoS attacks via resource exhaustion\n// These can be adjusted based on your use case\n\n/// Maximum allowed metadata size (100 MB)\n/// Prevents memory exhaustion attacks via oversized metadata claims\npub const MAX_METADATA_SIZE: u32 = 100 * 1024 * 1024;\n\n/// Maximum allowed tensor size per tensor\n/// Prevents memory exhaustion attacks via oversized tensor claims\n/// 32-bit platforms: 2 GB limit (to fit within usize range)\n/// 64-bit platforms: 10 GB limit\n#[cfg(target_pointer_width = \"32\")]\npub const MAX_TENSOR_SIZE: usize = 2 * 1024 * 1024 * 1024;\n#[cfg(not(target_pointer_width = \"32\"))]\npub const MAX_TENSOR_SIZE: usize = 10 * 1024 * 1024 * 1024;\n\n/// Maximum allowed number of tensors (100,000)\n/// Prevents resource exhaustion via excessive tensor counts\npub const MAX_TENSOR_COUNT: usize = 100_000;\n\n/// Maximum CBOR deserialization recursion depth (128 levels)\n/// Prevents stack overflow attacks via deeply nested CBOR structures\npub const MAX_CBOR_RECURSION_DEPTH: usize = 128;\n\n/// Maximum allowed file size (100 GB)\n/// Prevents resource exhaustion from extremely large files\n/// This limit applies to file-based loading (mmap and buffered)\n#[cfg(feature = \"std\")]\npub const MAX_FILE_SIZE: u64 = 100 * 1024 * 1024 * 1024;\n\n/// Byte range for magic number in header\npub const fn magic_range() -> core::ops::Range<usize> {\n    let start = 0;\n    let end = start + MAGIC_SIZE;\n    start..end\n}\n\n/// Byte range for format version in header\npub const fn version_range() -> core::ops::Range<usize> {\n    let start = MAGIC_SIZE;\n    let end = start + VERSION_SIZE;\n    start..end\n}\n\n/// Byte range for metadata size field in header\npub const fn metadata_size_range() -> core::ops::Range<usize> {\n    let start = MAGIC_SIZE + VERSION_SIZE;\n    let end = start + METADATA_SIZE_FIELD_SIZE;\n    start..end\n}\n\n// Compile-time validation that ranges are correct\nconst _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);\n\n/// Header structure for Burnpack files\n#[derive(Debug, Clone, Copy)]\npub struct BurnpackHeader {\n    /// Magic number (4 bytes): 0x4255524E (\"BURN\")\n    pub magic: u32,\n    /// Format version (2 bytes)\n    pub version: u16,\n    /// Size of CBOR metadata in bytes (4 bytes)\n    pub metadata_size: u32,\n}\n\nimpl BurnpackHeader {\n    /// Create a new header with the given metadata size\n    #[allow(dead_code)]\n    pub fn new(metadata_size: u32) -> Self {\n        Self {\n            magic: MAGIC_NUMBER,\n            version: FORMAT_VERSION,\n            metadata_size,\n        }\n    }\n\n    /// Serialize header into bytes\n    pub fn into_bytes(self) -> [u8; HEADER_SIZE] {\n        let mut bytes = [0u8; HEADER_SIZE];\n        LittleEndian::write_u32(&mut bytes[magic_range()], self.magic);\n        LittleEndian::write_u16(&mut bytes[version_range()], self.version);\n        LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size);\n        bytes\n    }\n\n    /// Deserialize header from bytes\n    pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {\n        if bytes.len() < HEADER_SIZE {\n            return Err(BurnpackError::InvalidHeader);\n        }\n\n        let magic = LittleEndian::read_u32(&bytes[magic_range()]);\n        if magic != MAGIC_NUMBER {\n            return Err(BurnpackError::InvalidMagicNumber);\n        }\n\n        let version = LittleEndian::read_u16(&bytes[version_range()]);\n        let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]);\n\n        Ok(Self {\n            magic,\n            version,\n            metadata_size,\n        })\n    }\n}\n\n/// Metadata structure serialized with CBOR\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct BurnpackMetadata {\n    /// Tensor descriptors mapped by name for efficient lookup\n    pub tensors: BTreeMap<String, TensorDescriptor>,\n    /// Optional additional metadata\n    #[serde(default, skip_serializing_if = \"BTreeMap::is_empty\")]\n    pub metadata: BTreeMap<String, String>,\n}\n\n/// Individual tensor descriptor\n#[derive(Debug, Clone, Serialize, Deserialize)]\npub struct TensorDescriptor {\n    /// Data type of the tensor\n    pub dtype: DType,\n    /// Tensor shape dimensions\n    pub shape: Vec<u64>,\n    /// Byte offsets in data section (start, end)\n    pub data_offsets: (u64, u64),\n    /// Parameter ID for training state persistence matching.\n    /// Generated automatically if not present during loading.\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub param_id: Option<u64>,\n}\n\n/// Error types for Burnpack operations\n#[derive(Debug)]\npub enum BurnpackError {\n    InvalidHeader,\n    InvalidMagicNumber,\n    InvalidVersion,\n    MetadataSerializationError(String),\n    MetadataDeserializationError(String),\n    IoError(String),\n    TensorNotFound(String),\n    TensorBytesSizeMismatch(String),\n    ValidationError(String),\n}\n\nimpl core::fmt::Display for BurnpackError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        match self {\n            BurnpackError::InvalidHeader => write!(f, \"Invalid header: insufficient bytes\"),\n            BurnpackError::InvalidMagicNumber => write!(f, \"Invalid magic number\"),\n            BurnpackError::InvalidVersion => write!(f, \"Unsupported version\"),\n            BurnpackError::MetadataSerializationError(e) => {\n                write!(f, \"Metadata serialization error: {}\", e)\n            }\n            BurnpackError::MetadataDeserializationError(e) => {\n                write!(f, \"Metadata deserialization error: {}\", e)\n            }\n            BurnpackError::IoError(e) => write!(f, \"I/O error: {}\", e),\n            BurnpackError::TensorNotFound(name) => write!(f, \"Tensor not found: {}\", name),\n            BurnpackError::TensorBytesSizeMismatch(e) => {\n                write!(f, \"Tensor bytes size mismatch: {}\", e)\n            }\n            BurnpackError::ValidationError(e) => write!(f, \"Validation error: {}\", e),\n        }\n    }\n}\n\nimpl core::error::Error for BurnpackError {}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/mod.rs",
    "content": "//! # Burnpack - Native Burn Model Storage Format\n//!\n//! Burnpack is the native binary storage format for Burn models, designed for efficient\n//! serialization, fast loading, and cross-platform compatibility.\n//!\n//! ## Key Features\n//!\n//! - **CBOR Metadata**: Structured metadata with efficient binary encoding\n//! - **Memory-Mapped Loading**: Zero-copy loading for optimal performance\n//! - **256-byte Tensor Alignment**: Enables efficient mmap zero-copy access\n//! - **No-std Support**: Works in embedded and WASM environments\n//! - **ParamId Persistence**: Preserves parameter identities for stateful training\n//! - **Lazy Tensor Loading**: Deferred data materialization for efficient memory usage\n//!\n//! ## File Format Structure\n//!\n//! ```text\n//! ┌──────────────────────────────────┐\n//! │  Header (10 bytes)               │\n//! ├──────────────────────────────────┤\n//! │  - Magic number (4 bytes)        │  0x4E525542 (\"NRUB\" in LE)\n//! │  - Version (2 bytes)             │  Format version (0x0001)\n//! │  - Metadata size (4 bytes)       │  Size of CBOR metadata (u32)\n//! ├──────────────────────────────────┤\n//! │  Metadata (CBOR)                 │\n//! ├──────────────────────────────────┤\n//! │  - Tensor descriptors (BTreeMap) │  Order-preserving map of tensor metadata\n//! │    Key: tensor name (string)     │  e.g., \"model.layer1.weight\"\n//! │    Value: TensorDescriptor       │\n//! │      - dtype: DType              │  Data type (F32, F64, I32, etc.)\n//! │      - shape: Vec<u64>           │  Tensor dimensions\n//! │      - data_offsets: (u64, u64)  │  (start, end) byte offsets (256-byte aligned)\n//! │      - param_id: Option<u64>     │  Parameter ID (for training state)\n//! │  - Additional metadata(BTreeMap) │  User-defined key-value pairs\n//! ├──────────────────────────────────┤\n//! │  Tensor Data Section             │\n//! ├──────────────────────────────────┤\n//! │  [padding][tensor1][padding]...  │  Each tensor aligned to 256-byte boundary\n//! │  Raw tensor bytes (little-endian)│  Enables mmap zero-copy loading\n//! └──────────────────────────────────┘\n//! ```\n//!\n//! ## Tensor Alignment\n//!\n//! All tensor data is aligned to 256-byte boundaries to enable efficient memory-mapped\n//! (mmap) zero-copy loading. This alignment ensures:\n//!\n//! - Proper pointer alignment for all tensor element types (f64 requires 8 bytes)\n//! - Cache-line friendly access (most CPUs use 64-byte cache lines)\n//! - GPU memory alignment (CUDA prefers 256-byte for coalesced access)\n//! - Future-proofing for wider SIMD instructions (AVX-512, future AVX-1024)\n//!\n//! The 256-byte alignment matches industry standards used by GGUF, MLX, ncnn, MNN,\n//! and other major model formats.\n\npub mod base;\npub mod reader;\npub mod store;\npub mod writer;\n\n#[cfg(test)]\nmod tests;\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/reader.rs",
    "content": "#[cfg(feature = \"std\")]\nuse super::base::MAX_FILE_SIZE;\nuse super::base::{\n    BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,\n    MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE,\n    aligned_data_section_start,\n};\nuse crate::TensorSnapshot;\nuse alloc::format;\nuse alloc::rc::Rc;\nuse alloc::string::ToString;\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_core::module::ParamId;\nuse burn_tensor::{Bytes, Shape, TensorData};\n\n#[cfg(feature = \"std\")]\nuse std::cell::RefCell;\n#[cfg(feature = \"std\")]\nuse std::fs::File;\n#[cfg(feature = \"std\")]\nuse std::io::{Read, Seek};\n#[cfg(feature = \"std\")]\nuse std::path::Path;\n\n/// Storage backend for BurnpackReader\npub(crate) enum StorageBackend {\n    /// Memory-based storage (also used for memory-mapped files converted to bytes::Bytes)\n    Memory(Rc<Bytes>),\n    /// File-based storage with buffered reading\n    #[cfg(feature = \"std\")]\n    #[allow(dead_code)]\n    FileBuffered { file: Rc<RefCell<File>> },\n}\n\nimpl StorageBackend {\n    /// Read data from storage into the provided buffer at the given offset.\n    ///\n    /// # Arguments\n    /// * `bytes` - The buffer to read into (caller-allocated)\n    /// * `offset` - Absolute file/data position to start reading from\n    ///\n    /// # Errors\n    ///\n    /// Returns an error if:\n    /// - The requested data range is out of bounds\n    /// - Less data is available than requested (indicates corruption or incorrect offset)\n    /// - File I/O fails\n    ///\n    /// # Notes\n    ///\n    /// The caller allocates the buffer, which allows for buffer reuse and future optimizations\n    /// like memory pools and pinned memory.\n    ///\n    /// This method ensures all backends have consistent behavior: if the exact number of\n    /// requested bytes cannot be read, an error is returned to prevent data corruption.\n    pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> {\n        match self {\n            StorageBackend::Memory(data) => {\n                let data_bytes = data.as_ref();\n                let end = offset.checked_add(bytes.len()).ok_or_else(|| {\n                    BurnpackError::IoError(format!(\n                        \"Offset overflow: offset {} + length {} exceeds maximum\",\n                        offset,\n                        bytes.len()\n                    ))\n                })?;\n\n                if end > data_bytes.len() {\n                    return Err(BurnpackError::IoError(format!(\n                        \"Read out of bounds: requested {}..{} but data length is {}\",\n                        offset,\n                        end,\n                        data_bytes.len()\n                    )));\n                }\n\n                bytes.copy_from_slice(&data_bytes[offset..end]);\n                Ok(())\n            }\n            #[cfg(feature = \"std\")]\n            StorageBackend::FileBuffered { file } => {\n                use std::io::SeekFrom;\n\n                let mut file = file.borrow_mut();\n                file.seek(SeekFrom::Start(offset as u64)).map_err(|e| {\n                    BurnpackError::IoError(format!(\"Failed to seek in file: {}\", e))\n                })?;\n\n                file.read_exact(bytes).map_err(|e| {\n                    BurnpackError::IoError(format!(\"Failed to read from file: {}\", e))\n                })?;\n                Ok(())\n            }\n        }\n    }\n\n    /// Get full data reference for raw access\n    #[allow(dead_code)]\n    pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> {\n        match self {\n            StorageBackend::Memory(data) => Ok(data.as_ref()),\n            #[cfg(feature = \"std\")]\n            StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(\n                \"Cannot get full bytes reference for FileBuffered backend\".into(),\n            )),\n        }\n    }\n\n    /// Attempt to slice bytes without copying (zero-copy).\n    ///\n    /// This uses `Bytes::clone()` + `split()` which is zero-copy when the underlying\n    /// `Bytes` was created via `Bytes::from_shared()` (backed by `bytes::Bytes`).\n    ///\n    /// # Returns\n    /// - `Ok(bytes)` - Successfully created a zero-copy slice\n    /// - `Err(_)` - Backend doesn't support zero-copy or split failed\n    pub(crate) fn slice_bytes(&self, start: usize, end: usize) -> Result<Bytes, BurnpackError> {\n        if end < start {\n            return Err(BurnpackError::IoError(format!(\n                \"Invalid slice range: end ({}) < start ({})\",\n                end, start\n            )));\n        }\n\n        match self {\n            StorageBackend::Memory(data) => {\n                // Clone the Bytes - cheap if backed by SharedBytesAllocationController\n                let cloned = (**data).clone();\n\n                // Split at start offset to get (_, right)\n                let (_, right) = cloned.split(start).map_err(|(_, e)| {\n                    BurnpackError::IoError(format!(\"Failed to split at start {}: {:?}\", start, e))\n                })?;\n\n                // Split right at (end - start) to get (middle, _)\n                let slice_len = end - start;\n                let (middle, _) = right.split(slice_len).map_err(|(_, e)| {\n                    BurnpackError::IoError(format!(\n                        \"Failed to split at length {}: {:?}\",\n                        slice_len, e\n                    ))\n                })?;\n\n                Ok(middle)\n            }\n            #[cfg(feature = \"std\")]\n            StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(\n                \"Zero-copy not supported for buffered file reading. Use from_file() with memmap feature for zero-copy loading.\".into(),\n            )),\n        }\n    }\n}\n\n/// Reader for loading Burnpack files\npub struct BurnpackReader {\n    /// Parsed metadata\n    pub(crate) metadata: BurnpackMetadata,\n    /// Storage backend\n    pub(crate) storage: StorageBackend,\n    /// Offset to the start of tensor data\n    pub(crate) data_offset: usize,\n}\n\nimpl BurnpackReader {\n    /// Load from bytes\n    pub fn from_bytes(bytes: Bytes) -> Result<Self, BurnpackError> {\n        // Validate minimum size\n        if bytes.len() < HEADER_SIZE {\n            return Err(BurnpackError::InvalidHeader);\n        }\n\n        // Parse header\n        let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?;\n\n        // Verify magic number\n        if header.magic != MAGIC_NUMBER {\n            return Err(BurnpackError::InvalidMagicNumber);\n        }\n\n        // Verify version compatibility\n        if header.version > FORMAT_VERSION {\n            return Err(BurnpackError::InvalidVersion);\n        }\n\n        // Validate metadata size against security limit\n        if header.metadata_size > MAX_METADATA_SIZE {\n            return Err(BurnpackError::ValidationError(format!(\n                \"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)\",\n                header.metadata_size, MAX_METADATA_SIZE\n            )));\n        }\n\n        // Parse metadata\n        let metadata_start = HEADER_SIZE;\n        let metadata_end = metadata_start\n            .checked_add(header.metadata_size as usize)\n            .ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Metadata size overflow: {} + {}\",\n                    metadata_start, header.metadata_size\n                ))\n            })?;\n\n        if bytes.len() < metadata_end {\n            return Err(BurnpackError::InvalidHeader);\n        }\n\n        let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(\n            &bytes[metadata_start..metadata_end],\n            MAX_CBOR_RECURSION_DEPTH,\n        )\n        .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;\n\n        // Validate tensor count against security limit\n        if metadata.tensors.len() > MAX_TENSOR_COUNT {\n            return Err(BurnpackError::ValidationError(format!(\n                \"File contains {} tensors, exceeding maximum of {} (potential DoS attack)\",\n                metadata.tensors.len(),\n                MAX_TENSOR_COUNT\n            )));\n        }\n\n        // Validate total file size - ensure file is large enough for all claimed tensor data\n        if !metadata.tensors.is_empty() {\n            let max_data_offset = metadata\n                .tensors\n                .values()\n                .map(|t| t.data_offsets.1)\n                .max()\n                .unwrap_or(0);\n\n            let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {\n                BurnpackError::ValidationError(format!(\n                    \"Data offset {} exceeds platform maximum\",\n                    max_data_offset\n                ))\n            })?;\n\n            let min_file_size =\n                metadata_end\n                    .checked_add(max_data_offset_usize)\n                    .ok_or_else(|| {\n                        BurnpackError::ValidationError(\"File size calculation overflow\".into())\n                    })?;\n\n            if bytes.len() < min_file_size {\n                return Err(BurnpackError::ValidationError(format!(\n                    \"File truncated: expected at least {} bytes, got {} bytes\",\n                    min_file_size,\n                    bytes.len()\n                )));\n            }\n        }\n\n        Ok(Self {\n            metadata,\n            storage: StorageBackend::Memory(Rc::new(bytes)),\n            data_offset: aligned_data_section_start(header.metadata_size as usize),\n        })\n    }\n\n    /// Load from file with memory mapping (most efficient for large files)\n    #[cfg(all(feature = \"std\", feature = \"memmap\"))]\n    pub(crate) fn from_file_mmap<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {\n        let file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        // Validate maximum file size to prevent resource exhaustion\n        let file_size = file\n            .metadata()\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?\n            .len();\n\n        if file_size > MAX_FILE_SIZE {\n            return Err(BurnpackError::ValidationError(format!(\n                \"File size {} bytes exceeds maximum allowed size of {} bytes\",\n                file_size, MAX_FILE_SIZE\n            )));\n        }\n\n        // Memory map the file\n        let mmap = unsafe {\n            memmap2::MmapOptions::new()\n                .map(&file)\n                .map_err(|e| BurnpackError::IoError(e.to_string()))?\n        };\n\n        // Parse header\n        if mmap.len() < HEADER_SIZE {\n            return Err(BurnpackError::InvalidHeader);\n        }\n\n        let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?;\n\n        // Verify magic number and version\n        if header.magic != MAGIC_NUMBER {\n            return Err(BurnpackError::InvalidMagicNumber);\n        }\n\n        if header.version > FORMAT_VERSION {\n            return Err(BurnpackError::InvalidVersion);\n        }\n\n        // Validate metadata size against security limit\n        if header.metadata_size > MAX_METADATA_SIZE {\n            return Err(BurnpackError::ValidationError(format!(\n                \"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)\",\n                header.metadata_size, MAX_METADATA_SIZE\n            )));\n        }\n\n        // Parse metadata\n        let metadata_start = HEADER_SIZE;\n        let metadata_end = metadata_start\n            .checked_add(header.metadata_size as usize)\n            .ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Metadata size overflow: {} + {}\",\n                    metadata_start, header.metadata_size\n                ))\n            })?;\n\n        if mmap.len() < metadata_end {\n            return Err(BurnpackError::InvalidHeader);\n        }\n\n        let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(\n            &mmap[metadata_start..metadata_end],\n            MAX_CBOR_RECURSION_DEPTH,\n        )\n        .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;\n\n        // Validate tensor count against security limit\n        if metadata.tensors.len() > MAX_TENSOR_COUNT {\n            return Err(BurnpackError::ValidationError(format!(\n                \"File contains {} tensors, exceeding maximum of {} (potential DoS attack)\",\n                metadata.tensors.len(),\n                MAX_TENSOR_COUNT\n            )));\n        }\n\n        // Validate total file size - ensure file is large enough for all claimed tensor data\n        if !metadata.tensors.is_empty() {\n            let max_data_offset = metadata\n                .tensors\n                .values()\n                .map(|t| t.data_offsets.1)\n                .max()\n                .unwrap_or(0);\n\n            let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {\n                BurnpackError::ValidationError(format!(\n                    \"Data offset {} exceeds platform maximum\",\n                    max_data_offset\n                ))\n            })?;\n\n            let min_file_size =\n                metadata_end\n                    .checked_add(max_data_offset_usize)\n                    .ok_or_else(|| {\n                        BurnpackError::ValidationError(\"File size calculation overflow\".into())\n                    })?;\n\n            if mmap.len() < min_file_size {\n                return Err(BurnpackError::ValidationError(format!(\n                    \"File truncated: expected at least {} bytes, got {} bytes\",\n                    min_file_size,\n                    mmap.len()\n                )));\n            }\n        }\n\n        // Convert mmap to bytes::Bytes for zero-copy slicing support\n        // bytes::Bytes::from_owner takes ownership and enables efficient slicing\n        let shared_bytes = bytes::Bytes::from_owner(mmap);\n        let bytes = Bytes::from_shared(shared_bytes, burn_tensor::AllocationProperty::File);\n\n        Ok(Self {\n            metadata,\n            storage: StorageBackend::Memory(Rc::new(bytes)),\n            data_offset: aligned_data_section_start(header.metadata_size as usize),\n        })\n    }\n\n    /// Load from file - automatically uses memory mapping if available, otherwise uses buffered reading\n    #[cfg(feature = \"std\")]\n    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {\n        #[cfg(feature = \"memmap\")]\n        {\n            // Use memory mapping for efficient access\n            Self::from_file_mmap(path)\n        }\n        #[cfg(not(feature = \"memmap\"))]\n        {\n            // Fall back to buffered reading for memory efficiency\n            Self::from_file_buffered(path)\n        }\n    }\n\n    /// Load from file with buffered reading (memory efficient but slower)\n    /// This is less efficient than memory mapping but works everywhere\n    #[cfg(feature = \"std\")]\n    #[allow(dead_code)]\n    pub(crate) fn from_file_buffered<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {\n        let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        // Validate maximum file size to prevent resource exhaustion\n        let file_size = file\n            .metadata()\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?\n            .len();\n\n        if file_size > MAX_FILE_SIZE {\n            return Err(BurnpackError::ValidationError(format!(\n                \"File size {} bytes exceeds maximum allowed size of {} bytes\",\n                file_size, MAX_FILE_SIZE\n            )));\n        }\n\n        // Read header\n        let mut header_bytes = [0u8; HEADER_SIZE];\n        file.read_exact(&mut header_bytes)\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        let header = BurnpackHeader::from_bytes(&header_bytes)?;\n\n        // Verify version\n        if header.version > FORMAT_VERSION {\n            return Err(BurnpackError::InvalidVersion);\n        }\n\n        // Validate metadata size against security limit\n        if header.metadata_size > MAX_METADATA_SIZE {\n            return Err(BurnpackError::ValidationError(format!(\n                \"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)\",\n                header.metadata_size, MAX_METADATA_SIZE\n            )));\n        }\n\n        // Read metadata\n        let mut metadata_bytes = vec![0u8; header.metadata_size as usize];\n        file.read_exact(&mut metadata_bytes)\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(\n            metadata_bytes.as_slice(),\n            MAX_CBOR_RECURSION_DEPTH,\n        )\n        .map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;\n\n        // Validate tensor count against security limit\n        if metadata.tensors.len() > MAX_TENSOR_COUNT {\n            return Err(BurnpackError::ValidationError(format!(\n                \"File contains {} tensors, exceeding maximum of {} (potential DoS attack)\",\n                metadata.tensors.len(),\n                MAX_TENSOR_COUNT\n            )));\n        }\n\n        // Calculate metadata end offset\n        let metadata_end = HEADER_SIZE\n            .checked_add(header.metadata_size as usize)\n            .ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Metadata size overflow: {} + {}\",\n                    HEADER_SIZE, header.metadata_size\n                ))\n            })?;\n\n        // Validate total file size - ensure file is large enough for all claimed tensor data\n        if !metadata.tensors.is_empty() {\n            let max_data_offset = metadata\n                .tensors\n                .values()\n                .map(|t| t.data_offsets.1)\n                .max()\n                .unwrap_or(0);\n\n            let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {\n                BurnpackError::ValidationError(format!(\n                    \"Data offset {} exceeds platform maximum\",\n                    max_data_offset\n                ))\n            })?;\n\n            let min_file_size =\n                metadata_end\n                    .checked_add(max_data_offset_usize)\n                    .ok_or_else(|| {\n                        BurnpackError::ValidationError(\"File size calculation overflow\".into())\n                    })?;\n\n            // Get actual file size\n            let file_size = file\n                .metadata()\n                .map_err(|e| BurnpackError::IoError(e.to_string()))?\n                .len() as usize;\n\n            if file_size < min_file_size {\n                return Err(BurnpackError::ValidationError(format!(\n                    \"File truncated: expected at least {} bytes, got {} bytes\",\n                    min_file_size, file_size\n                )));\n            }\n        }\n\n        Ok(Self {\n            metadata,\n            storage: StorageBackend::FileBuffered {\n                file: Rc::new(RefCell::new(file)),\n            },\n            data_offset: aligned_data_section_start(header.metadata_size as usize),\n        })\n    }\n\n    /// Get all tensor snapshots at once for efficient loading (always copies data)\n    pub fn get_snapshots(&self) -> Result<Vec<TensorSnapshot>, BurnpackError> {\n        self.get_snapshots_internal(false)\n    }\n\n    /// Get all tensor snapshots with optional zero-copy loading.\n    ///\n    /// When `zero_copy` is true and the backend supports it (Memory backend with\n    /// `Bytes::from_shared()`), tensor data is sliced without copying. This keeps\n    /// the original data alive as long as any tensor holds a reference.\n    ///\n    /// When `zero_copy` is false or the backend doesn't support it, data is copied\n    /// into newly allocated buffers (default behavior).\n    pub fn get_snapshots_zero_copy(\n        &self,\n        zero_copy: bool,\n    ) -> Result<Vec<TensorSnapshot>, BurnpackError> {\n        self.get_snapshots_internal(zero_copy)\n    }\n\n    /// Internal implementation with optional zero-copy support\n    fn get_snapshots_internal(\n        &self,\n        zero_copy: bool,\n    ) -> Result<Vec<TensorSnapshot>, BurnpackError> {\n        let mut snapshots = Vec::new();\n\n        for (name, descriptor) in &self.metadata.tensors {\n            // Clone metadata for use in closure\n            // Convert shape dimensions with overflow checking\n            let shape: Shape = Shape::from(descriptor\n                .shape\n                .iter()\n                .map(|&s| {\n                    s.try_into().map_err(|_| {\n                        BurnpackError::ValidationError(format!(\n                            \"Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum\",\n                            name, s\n                        ))\n                    })\n                })\n                .collect::<Result<Vec<usize>, BurnpackError>>()?);\n\n            let dtype = descriptor.dtype;\n\n            // Clone storage reference for the closure\n            let storage = match &self.storage {\n                StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()),\n                #[cfg(feature = \"std\")]\n                StorageBackend::FileBuffered { file } => {\n                    StorageBackend::FileBuffered { file: file.clone() }\n                }\n            };\n\n            // Always use absolute positions for all backends\n            // Convert offsets with overflow checking\n            let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {\n                BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum\",\n                    name, descriptor.data_offsets.0\n                ))\n            })?;\n\n            let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {\n                BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum\",\n                    name, descriptor.data_offsets.1\n                ))\n            })?;\n\n            let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {\n                BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' has corrupted offset data: start offset overflow {} + {}\",\n                    name, self.data_offset, offset_start\n                ))\n            })?;\n\n            let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {\n                BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' has corrupted offset data: end offset overflow {} + {}\",\n                    name, self.data_offset, offset_end\n                ))\n            })?;\n\n            // Clone shape for the closure (TensorSnapshot::from_closure will also need it)\n            let shape_for_closure = shape.clone();\n\n            // Validate offset range\n            if end < start {\n                return Err(BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' has corrupted offset data: end offset {} < start offset {}\",\n                    name, end, start\n                )));\n            }\n\n            // Validate tensor size against security limit\n            let tensor_size = end - start;\n            if tensor_size > MAX_TENSOR_SIZE {\n                return Err(BurnpackError::ValidationError(format!(\n                    \"Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)\",\n                    name, tensor_size, MAX_TENSOR_SIZE\n                )));\n            }\n\n            // Restore param_id if it was saved, otherwise generate\n            let tensor_id = descriptor\n                .param_id\n                .map(ParamId::from)\n                .unwrap_or_else(ParamId::new);\n\n            // Create the data-loading closure based on zero_copy flag\n            let data_fn: Rc<dyn Fn() -> Result<TensorData, crate::TensorSnapshotError>> =\n                if zero_copy {\n                    // Zero-copy closure: slice without copying, error if not supported\n                    Rc::new(move || {\n                        let bytes = storage.slice_bytes(start, end).map_err(|e| {\n                            crate::TensorSnapshotError::IoError(format!(\n                                \"Zero-copy slice failed: {}\",\n                                e\n                            ))\n                        })?;\n                        Ok(TensorData::from_bytes(\n                            bytes,\n                            shape_for_closure.clone(),\n                            dtype,\n                        ))\n                    })\n                } else {\n                    // Copying closure: always allocate and copy\n                    Rc::new(move || {\n                        let len = end - start;\n                        // TODO Should be allocated by the backend in the future\n                        // See https://github.com/tracel-ai/burn/pull/3792#discussion_r2416812091\n                        let mut data_bytes = vec![0u8; len];\n                        storage.read_into(&mut data_bytes, start).map_err(|e| {\n                            crate::TensorSnapshotError::IoError(format!(\n                                \"Failed to read tensor data: {}\",\n                                e\n                            ))\n                        })?;\n                        Ok(TensorData::from_bytes_vec(\n                            data_bytes,\n                            shape_for_closure.clone(),\n                            dtype,\n                        ))\n                    })\n                };\n\n            // Create lazy TensorSnapshot\n            let snapshot = TensorSnapshot::from_closure(\n                data_fn,\n                dtype,\n                shape,\n                name.split('.').map(|s| s.to_string()).collect(),\n                vec![],    // empty container_stack\n                tensor_id, // restored or newly generated param id\n            );\n\n            snapshots.push(snapshot);\n        }\n\n        Ok(snapshots)\n    }\n\n    // Legacy methods for test compatibility - will be removed\n\n    /// Get tensor as TensorSnapshot with lazy loading\n    #[allow(dead_code)]\n    pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result<TensorSnapshot, BurnpackError> {\n        let snapshots = self.get_snapshots()?;\n        snapshots\n            .into_iter()\n            .find(|s| s.full_path() == name)\n            .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))\n    }\n\n    /// Get list of tensor names\n    #[allow(dead_code)]\n    pub(crate) fn tensor_names(&self) -> Vec<&str> {\n        self.metadata\n            .tensors\n            .keys()\n            .map(|name| name.as_str())\n            .collect()\n    }\n\n    /// Get metadata\n    #[allow(dead_code)]\n    pub(crate) fn metadata(&self) -> &BurnpackMetadata {\n        &self.metadata\n    }\n\n    /// Get tensor data as raw bytes\n    #[allow(dead_code)]\n    pub(crate) fn get_tensor_data(&self, name: &str) -> Result<Vec<u8>, BurnpackError> {\n        let descriptor = self\n            .metadata\n            .tensors\n            .get(name)\n            .ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?;\n\n        // Always use absolute positions for all backends\n        // Convert offsets with overflow checking\n        let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {\n            BurnpackError::IoError(format!(\n                \"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum\",\n                name, descriptor.data_offsets.0\n            ))\n        })?;\n\n        let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {\n            BurnpackError::IoError(format!(\n                \"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum\",\n                name, descriptor.data_offsets.1\n            ))\n        })?;\n\n        let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {\n            BurnpackError::IoError(format!(\n                \"Tensor '{}' has corrupted offset data: start offset overflow {} + {}\",\n                name, self.data_offset, offset_start\n            ))\n        })?;\n\n        let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {\n            BurnpackError::IoError(format!(\n                \"Tensor '{}' has corrupted offset data: end offset overflow {} + {}\",\n                name, self.data_offset, offset_end\n            ))\n        })?;\n\n        // Validate offset range\n        if end < start {\n            return Err(BurnpackError::IoError(format!(\n                \"Tensor '{}' has corrupted offset data: end offset {} < start offset {}\",\n                name, end, start\n            )));\n        }\n\n        let len = end - start;\n        let mut buffer = vec![0u8; len];\n        self.storage.read_into(&mut buffer, start)?;\n        Ok(buffer)\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/store.rs",
    "content": "#[cfg(feature = \"std\")]\nuse std::path::PathBuf;\n\nuse super::reader::BurnpackReader;\nuse super::writer::BurnpackWriter;\n#[cfg(feature = \"std\")]\nuse crate::KeyRemapper;\nuse crate::burnpack::base::BurnpackError;\nuse crate::{\n    IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot,\n};\nuse alloc::boxed::Box;\nuse alloc::collections::BTreeMap;\nuse alloc::format;\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse burn_core::prelude::Backend;\nuse burn_tensor::Bytes;\n\n/// Store mode for BurnpackStore\nenum StoreMode {\n    #[cfg(feature = \"std\")]\n    File(PathBuf),\n    Bytes(Option<Bytes>),\n}\n\n/// BurnpackStore - A Burn-specific file format store using CBOR for metadata\npub struct BurnpackStore {\n    /// Store mode - either file path or bytes\n    mode: StoreMode,\n    /// Optional filter for selective loading/saving\n    filter: Option<PathFilter>,\n    /// Additional metadata\n    metadata: BTreeMap<String, String>,\n    /// Allow partial loading (ignore missing tensors)\n    allow_partial: bool,\n    /// Validate tensors during loading (check shapes and dtypes)\n    validate: bool,\n    /// Allow overwriting existing files (default: false)\n    overwrite: bool,\n    /// Enable zero-copy tensor loading (default: false)\n    ///\n    /// When enabled and the backend supports it, tensor data is sliced from\n    /// the source without copying. This requires keeping the source data alive.\n    zero_copy: bool,\n    /// Automatically append .bpk extension if not present (default: true)\n    #[cfg(feature = \"std\")]\n    auto_extension: bool,\n    /// Key remapper for tensor name transformations\n    #[cfg(feature = \"std\")]\n    remapper: KeyRemapper,\n    /// Adapter applied when loading (source -> Burn)\n    from_adapter: Box<dyn ModuleAdapter>,\n    /// Adapter applied when saving (Burn -> target)\n    to_adapter: Box<dyn ModuleAdapter>,\n    /// Writer for saving\n    writer: Option<BurnpackWriter>,\n    /// Reader for loading\n    reader: Option<BurnpackReader>,\n    /// Cached tensor snapshots (parsed once, reused)\n    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,\n}\n\nimpl BurnpackStore {\n    /// Get the default metadata that includes Burn framework information.\n    ///\n    /// This includes:\n    /// - `format`: \"burnpack\"\n    /// - `producer`: \"burn\"\n    /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)\n    ///\n    /// These metadata fields are automatically added to all saved models.\n    pub fn default_metadata() -> BTreeMap<String, String> {\n        let mut metadata = BTreeMap::new();\n        metadata.insert(\"format\".into(), \"burnpack\".into());\n        metadata.insert(\"producer\".into(), \"burn\".into());\n        metadata.insert(\"version\".into(), env!(\"CARGO_PKG_VERSION\").into());\n        metadata\n    }\n    /// Create a new store from a file path\n    ///\n    /// By default, automatically appends `.bpk` extension if the path doesn't have one.\n    /// Use `.auto_extension(false)` to disable this behavior.\n    ///\n    /// # Examples\n    ///\n    /// ```no_run\n    /// # use burn_store::BurnpackStore;\n    /// // Automatically appends .bpk\n    /// let store = BurnpackStore::from_file(\"model\");  // creates \"model.bpk\"\n    ///\n    /// // Already has extension, no append\n    /// let store = BurnpackStore::from_file(\"model.bpk\");  // uses \"model.bpk\"\n    /// let store = BurnpackStore::from_file(\"model.myext\");  // uses \"model.myext\"\n    ///\n    /// // Disable auto-extension\n    /// let store = BurnpackStore::from_file(\"model\").auto_extension(false);  // uses \"model\"\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {\n        Self {\n            mode: StoreMode::File(path.as_ref().to_path_buf()),\n            filter: None,\n            metadata: Self::default_metadata(),\n            allow_partial: false,\n            validate: true,\n            overwrite: false,\n            zero_copy: false,\n            #[cfg(feature = \"std\")]\n            auto_extension: true,\n            #[cfg(feature = \"std\")]\n            remapper: KeyRemapper::new(),\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            writer: None,\n            reader: None,\n            snapshots_cache: None,\n        }\n    }\n\n    /// Create a new store from bytes (for reading) or empty (for writing)\n    pub fn from_bytes(bytes: Option<Bytes>) -> Self {\n        Self {\n            mode: StoreMode::Bytes(bytes),\n            filter: None,\n            metadata: Self::default_metadata(),\n            allow_partial: false,\n            validate: true,\n            overwrite: false,\n            zero_copy: false,\n            #[cfg(feature = \"std\")]\n            auto_extension: false, // Not used for bytes mode\n            #[cfg(feature = \"std\")]\n            remapper: KeyRemapper::new(),\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            writer: None,\n            reader: None,\n            snapshots_cache: None,\n        }\n    }\n\n    /// Create a new store from static bytes with zero-copy loading enabled.\n    ///\n    /// This is optimized for embedded model weights where the data lives in the\n    /// binary's `.rodata` section. Tensor data is sliced without copying, keeping\n    /// the static reference alive.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// static MODEL_DATA: &[u8] = include_bytes!(\"model.bpk\");\n    /// let store = BurnpackStore::from_static(MODEL_DATA);\n    /// ```\n    pub fn from_static(data: &'static [u8]) -> Self {\n        use burn_tensor::AllocationProperty;\n\n        // Create bytes::Bytes from static data (zero-copy, stays in .rodata)\n        let shared = bytes::Bytes::from_static(data);\n\n        // Wrap in cubecl Bytes with shared-bytes allocation controller\n        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);\n\n        Self {\n            mode: StoreMode::Bytes(Some(bytes)),\n            filter: None,\n            metadata: Self::default_metadata(),\n            allow_partial: false,\n            validate: true,\n            overwrite: false,\n            zero_copy: true, // Enable zero-copy by default for static data\n            #[cfg(feature = \"std\")]\n            auto_extension: false,\n            #[cfg(feature = \"std\")]\n            remapper: KeyRemapper::new(),\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            writer: None,\n            reader: None,\n            snapshots_cache: None,\n        }\n    }\n\n    /// Add metadata key-value pair\n    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {\n        self.metadata.insert(key.into(), value.into());\n        self\n    }\n\n    /// Clear all metadata (including defaults)\n    ///\n    /// This removes all metadata including the default format, producer, and version fields.\n    /// Use with caution as some tools may expect these fields to be present.\n    pub fn clear_metadata(mut self) -> Self {\n        self.metadata.clear();\n        self\n    }\n\n    /// Allow partial loading (ignore missing tensors)\n    ///\n    /// When set to `true`, the store will not fail if some tensors are missing\n    /// during loading. This is useful when loading a subset of a model's parameters.\n    ///\n    /// Default: `false`\n    pub fn allow_partial(mut self, allow: bool) -> Self {\n        self.allow_partial = allow;\n        self\n    }\n\n    /// Enable or disable validation during loading\n    ///\n    /// When validation is enabled, the store will check that loaded tensors\n    /// match the expected shapes and data types. Disabling validation can\n    /// improve performance but may lead to runtime errors if data is corrupted.\n    ///\n    /// Default: `true`\n    pub fn validate(mut self, validate: bool) -> Self {\n        self.validate = validate;\n        self\n    }\n\n    /// Allow overwriting existing files when saving\n    ///\n    /// When set to `false`, attempting to save to an existing file will result in an error.\n    /// When set to `true`, existing files will be overwritten without warning.\n    ///\n    /// Default: `false`\n    pub fn overwrite(mut self, overwrite: bool) -> Self {\n        self.overwrite = overwrite;\n        self\n    }\n\n    /// Enable or disable zero-copy tensor loading.\n    ///\n    /// When enabled and the backend supports it (memory-backed with shared bytes),\n    /// tensor data is sliced from the source without copying. This keeps the source\n    /// data alive as long as any tensor holds a reference.\n    ///\n    /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).\n    /// Use this method to enable it for other memory-backed stores created with\n    /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.\n    ///\n    /// Default: `false` (except for `from_static` which defaults to `true`)\n    pub fn zero_copy(mut self, enable: bool) -> Self {\n        self.zero_copy = enable;\n        self\n    }\n\n    /// Enable or disable automatic .bpk extension appending\n    ///\n    /// When enabled (default), automatically appends `.bpk` to the file path\n    /// if no extension is detected. If an extension is already present, it is preserved.\n    ///\n    /// When disabled, uses the exact path provided without modification.\n    ///\n    /// Default: `true`\n    ///\n    /// # Examples\n    ///\n    /// ```no_run\n    /// # use burn_store::BurnpackStore;\n    /// // With auto_extension enabled (default)\n    /// let store = BurnpackStore::from_file(\"model\");  // -> \"model.bpk\"\n    ///\n    /// // With auto_extension disabled\n    /// let store = BurnpackStore::from_file(\"model\")\n    ///     .auto_extension(false);  // -> \"model\"\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn auto_extension(mut self, enable: bool) -> Self {\n        self.auto_extension = enable;\n        self\n    }\n\n    /// Set the adapter for loading tensors (converting from source format to Burn).\n    pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {\n        self.from_adapter = Box::new(adapter);\n        self\n    }\n\n    /// Set the adapter for saving tensors (converting from Burn to target format).\n    pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {\n        self.to_adapter = Box::new(adapter);\n        self\n    }\n\n    /// Set path filter for selective loading/saving\n    pub fn with_filter(mut self, filter: PathFilter) -> Self {\n        self.filter = Some(filter);\n        self\n    }\n\n    /// Add regex pattern to filter\n    #[cfg(feature = \"std\")]\n    pub fn with_regex(mut self, pattern: &str) -> Self {\n        let filter = self.filter.unwrap_or_default();\n        self.filter = Some(filter.with_regex(pattern));\n        self\n    }\n\n    /// Add exact path to filter\n    pub fn with_full_path(mut self, path: impl Into<String>) -> Self {\n        let filter = self.filter.unwrap_or_default();\n        self.filter = Some(filter.with_full_path(path));\n        self\n    }\n\n    /// Match all tensors (no filtering)\n    pub fn match_all(mut self) -> Self {\n        self.filter = Some(PathFilter::new().match_all());\n        self\n    }\n\n    /// Set key remapper for tensor name transformations during loading\n    #[cfg(feature = \"std\")]\n    pub fn remap(mut self, remapper: KeyRemapper) -> Self {\n        self.remapper = remapper;\n        self\n    }\n\n    /// Add a single regex pattern for key remapping\n    #[cfg(feature = \"std\")]\n    pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self\n    where\n        S1: AsRef<str>,\n        S2: Into<String>,\n    {\n        self.remapper = self\n            .remapper\n            .add_pattern(from.as_ref(), to.into())\n            .expect(\"Invalid regex pattern\");\n        self\n    }\n\n    /// Set the path filter\n    pub fn filter(mut self, filter: PathFilter) -> Self {\n        self.filter = Some(filter);\n        self\n    }\n\n    /// Get the bytes after writing (only valid for bytes mode after collecting)\n    pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {\n        if let Some(writer) = &self.writer {\n            return writer.to_bytes();\n        }\n\n        match &self.mode {\n            StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),\n            _ => Err(BurnpackError::IoError(\"No bytes available\".into())),\n        }\n    }\n\n    /// Process the file path with auto-extension logic\n    #[cfg(feature = \"std\")]\n    fn process_path(&self, path: &std::path::Path) -> PathBuf {\n        if !self.auto_extension {\n            return path.to_path_buf();\n        }\n\n        // Check if path already has an extension\n        if path.extension().is_some() {\n            // Has extension, use as-is\n            return path.to_path_buf();\n        }\n\n        // No extension, append .bpk\n        let mut new_path = path.to_path_buf();\n        new_path.set_extension(\"bpk\");\n        new_path\n    }\n\n    /// Ensure the reader is initialized, loading from storage if needed\n    fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {\n        if self.reader.is_none() {\n            let reader = match &self.mode {\n                #[cfg(feature = \"std\")]\n                StoreMode::File(path) => {\n                    let final_path = self.process_path(path);\n                    BurnpackReader::from_file(&final_path)?\n                }\n                StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,\n                StoreMode::Bytes(None) => {\n                    return Err(BurnpackError::IoError(\"No bytes to read from\".into()));\n                }\n            };\n            self.reader = Some(reader);\n        }\n\n        self.reader\n            .as_ref()\n            .ok_or_else(|| BurnpackError::IoError(\"Reader not initialized\".into()))\n    }\n}\n\nimpl ModuleStore for BurnpackStore {\n    type Error = BurnpackError;\n\n    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &M,\n    ) -> Result<(), Self::Error> {\n        // Invalidate cache since we're writing new data\n        self.snapshots_cache = None;\n        self.reader = None;\n\n        // Collect snapshots from module with adapter\n        let snapshots = module.collect(self.filter.clone(), Some(self.to_adapter.clone()), false);\n\n        // Initialize writer with snapshots\n        let mut writer = BurnpackWriter::new(snapshots);\n\n        // Add metadata using builder pattern\n        for (key, value) in &self.metadata {\n            writer = writer.with_metadata(key.as_str(), value.as_str());\n        }\n\n        // Store the writer for finalization\n        self.writer = Some(writer);\n\n        // Write to storage based on mode\n        if let Some(writer) = &self.writer {\n            match &self.mode {\n                #[cfg(feature = \"std\")]\n                StoreMode::File(path) => {\n                    // Process path with auto-extension logic\n                    let final_path = self.process_path(path);\n\n                    // Check if file exists and overwrite is disabled\n                    if final_path.exists() && !self.overwrite {\n                        return Err(BurnpackError::IoError(format!(\n                            \"File already exists: {}. Use .overwrite(true) to overwrite.\",\n                            final_path.display()\n                        )));\n                    }\n                    writer.write_to_file(&final_path)?;\n                }\n                StoreMode::Bytes(_) => {\n                    // Generate and store the bytes\n                    let bytes_data = writer.to_bytes()?;\n                    // Update mode with bytes - this pattern is irrefutable in no-std mode\n                    #[cfg_attr(not(feature = \"std\"), allow(irrefutable_let_patterns))]\n                    let StoreMode::Bytes(bytes_ref) = &mut self.mode else {\n                        unreachable!(\"We just matched Bytes variant\");\n                    };\n                    *bytes_ref = Some(bytes_data);\n                }\n            }\n        }\n\n        Ok(())\n    }\n\n    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &mut M,\n    ) -> Result<crate::ApplyResult, Self::Error> {\n        // Get all snapshots using the cached method\n        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();\n\n        // Apply all snapshots at once to the module\n        // Burnpack is Burn's native format, so no enum variant skipping needed\n        // Filter is applied here during apply, not during cache population\n        let result = module.apply(\n            snapshots,\n            self.filter.clone(),\n            Some(self.from_adapter.clone()),\n            false,\n        );\n\n        // Validate if needed\n        if self.validate && !result.errors.is_empty() {\n            return Err(BurnpackError::ValidationError(format!(\n                \"Import errors: {:?}\",\n                result.errors\n            )));\n        }\n\n        // Check for missing tensors if partial loading is not allowed\n        if !self.allow_partial && !result.missing.is_empty() {\n            return Err(BurnpackError::ValidationError(format!(\n                \"Missing tensors: {:?}\",\n                result.missing\n            )));\n        }\n\n        Ok(result)\n    }\n\n    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {\n        // Ensure cache is populated\n        self.ensure_snapshots_cache()?;\n        Ok(self.snapshots_cache.as_ref().unwrap().get(name))\n    }\n\n    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {\n        // Ensure cache is populated\n        self.ensure_snapshots_cache()?;\n        Ok(self.snapshots_cache.as_ref().unwrap())\n    }\n\n    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {\n        // Always use the cache to ensure remapping is applied consistently\n        Ok(self.get_all_snapshots()?.keys().cloned().collect())\n    }\n}\n\nimpl BurnpackStore {\n    /// Ensure the snapshots cache is populated\n    fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {\n        if self.snapshots_cache.is_some() {\n            return Ok(());\n        }\n\n        // Ensure reader is loaded\n        self.ensure_reader()?;\n\n        // Get snapshots from reader with zero-copy if enabled\n        let reader = self.reader.as_ref().unwrap();\n        let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;\n\n        // Apply remapping if configured (but NOT filtering - that's done at apply time)\n        #[cfg(feature = \"std\")]\n        let snapshots = if !self.remapper.patterns.is_empty() {\n            let (remapped, _remapped_names) = self.remapper.remap(snapshots);\n            remapped\n        } else {\n            snapshots\n        };\n\n        // Build the cache as BTreeMap\n        let cache: BTreeMap<String, TensorSnapshot> =\n            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();\n\n        self.snapshots_cache = Some(cache);\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/alignment.rs",
    "content": "//! Tests for tensor data alignment in burnpack format.\n//!\n//! These tests verify that tensor data is properly aligned for mmap zero-copy access.\n\nuse crate::TensorSnapshot;\nuse crate::burnpack::{\n    base::{\n        BurnpackHeader, BurnpackMetadata, HEADER_SIZE, TENSOR_ALIGNMENT, aligned_data_section_start,\n    },\n    reader::BurnpackReader,\n    writer::BurnpackWriter,\n};\nuse burn_core::module::ParamId;\nuse burn_tensor::{DType, TensorData};\n\n/// Verify that aligned_data_section_start always returns 256-byte aligned values\n#[test]\nfn test_aligned_data_section_start_is_always_aligned() {\n    // Test various metadata sizes\n    for metadata_size in 0..1024 {\n        let result = aligned_data_section_start(metadata_size);\n        assert_eq!(\n            result % TENSOR_ALIGNMENT as usize,\n            0,\n            \"aligned_data_section_start({}) = {} is not aligned to {}\",\n            metadata_size,\n            result,\n            TENSOR_ALIGNMENT\n        );\n    }\n}\n\n/// Verify data section starts at correct aligned position\n#[test]\nfn test_data_section_alignment() {\n    // Create a tensor\n    let data = [1.0f32, 2.0, 3.0, 4.0];\n    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes, vec![4], DType::F32),\n        vec![\"tensor\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    // Parse header to get metadata size\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    // Verify data section starts at 256-byte aligned position\n    assert_eq!(\n        data_section_start % TENSOR_ALIGNMENT as usize,\n        0,\n        \"Data section start {} is not 256-byte aligned\",\n        data_section_start\n    );\n\n    // Verify the file is large enough\n    assert!(\n        file_bytes.len() >= data_section_start,\n        \"File too small: {} < {}\",\n        file_bytes.len(),\n        data_section_start\n    );\n}\n\n/// Verify that first tensor's absolute file position is 256-byte aligned\n#[test]\nfn test_first_tensor_absolute_position_aligned() {\n    let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, vec![8], DType::U8),\n        vec![\"first\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    let tensor_desc = metadata.tensors.get(\"first\").unwrap();\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    // Absolute file position of first tensor\n    let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize;\n\n    assert_eq!(\n        absolute_pos % TENSOR_ALIGNMENT as usize,\n        0,\n        \"First tensor absolute position {} is not 256-byte aligned\",\n        absolute_pos\n    );\n}\n\n/// Verify that all tensors in a multi-tensor file have 256-byte aligned absolute positions\n#[test]\nfn test_all_tensors_absolute_positions_aligned() {\n    // Create multiple tensors of different sizes (all U8 to simplify shape calculation)\n    let tensors = vec![\n        (\"tensor_a\", vec![1u8, 2, 3]), // 3 bytes\n        (\"tensor_b\", vec![0u8; 16]),   // 16 bytes\n        (\"tensor_c\", vec![0u8; 64]),   // 64 bytes\n        (\"tensor_d\", vec![42u8]),      // 1 byte\n        (\"tensor_e\", vec![0u8; 400]),  // 400 bytes\n    ];\n\n    let snapshots: Vec<TensorSnapshot> = tensors\n        .into_iter()\n        .map(|(name, data)| {\n            let len = data.len();\n            TensorSnapshot::from_data(\n                TensorData::from_bytes_vec(data, vec![len], DType::U8),\n                vec![name.to_string()],\n                vec![],\n                ParamId::new(),\n            )\n        })\n        .collect();\n\n    let writer = BurnpackWriter::new(snapshots);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    // Check every tensor has aligned absolute position\n    for (name, desc) in &metadata.tensors {\n        let absolute_pos = data_section_start + desc.data_offsets.0 as usize;\n        assert_eq!(\n            absolute_pos % TENSOR_ALIGNMENT as usize,\n            0,\n            \"Tensor '{}' at absolute position {} is not 256-byte aligned (offset in data section: {})\",\n            name,\n            absolute_pos,\n            desc.data_offsets.0\n        );\n    }\n}\n\n/// Test edge case: metadata size that results in no padding needed\n#[test]\nfn test_alignment_with_minimal_padding() {\n    // We can't control metadata size directly, but we can verify the math works\n    // When HEADER_SIZE + metadata_size is already a multiple of 256, no padding needed\n    let aligned_metadata_size = TENSOR_ALIGNMENT as usize - HEADER_SIZE; // 256 - 10 = 246\n\n    let result = aligned_data_section_start(aligned_metadata_size);\n    assert_eq!(result, TENSOR_ALIGNMENT as usize); // Should be exactly 256\n\n    // One byte more should still round up to 256\n    let result_plus_one = aligned_data_section_start(aligned_metadata_size + 1);\n    assert_eq!(result_plus_one, 2 * TENSOR_ALIGNMENT as usize); // Should be 512\n}\n\n/// Verify padding bytes in the file are zeros\n#[test]\nfn test_padding_bytes_are_zeros() {\n    let data: Vec<u8> = vec![0xAA; 16]; // Distinctive pattern\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data.clone(), vec![16], DType::U8),\n        vec![\"tensor\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    // Check padding between metadata and data section\n    if data_section_start > metadata_end {\n        let padding = &file_bytes[metadata_end..data_section_start];\n        assert!(\n            padding.iter().all(|&b| b == 0),\n            \"Padding bytes between metadata and data section contain non-zero values\"\n        );\n    }\n}\n\n/// Verify alignment is sufficient for all primitive types\n/// 256-byte alignment is a multiple of all primitive type alignments:\n/// - f64/i64/u64: 8 bytes\n/// - f32/i32/u32: 4 bytes\n/// - f16/bf16/i16/u16: 2 bytes\n/// - i8/u8/bool: 1 byte\n#[test]\n#[allow(clippy::modulo_one)]\nfn test_alignment_covers_all_primitive_types() {\n    // 256 must be divisible by all common alignments\n    assert_eq!(\n        TENSOR_ALIGNMENT % 8,\n        0,\n        \"256 not divisible by 8 (f64 alignment)\"\n    );\n    assert_eq!(\n        TENSOR_ALIGNMENT % 4,\n        0,\n        \"256 not divisible by 4 (f32 alignment)\"\n    );\n    assert_eq!(\n        TENSOR_ALIGNMENT % 2,\n        0,\n        \"256 not divisible by 2 (f16 alignment)\"\n    );\n    assert_eq!(\n        TENSOR_ALIGNMENT % 1,\n        0,\n        \"256 not divisible by 1 (u8 alignment)\"\n    );\n}\n\n/// Verify that tensor data can be read correctly after alignment\n#[test]\nfn test_aligned_tensor_data_readable() {\n    // Create f32 tensor\n    let f32_data = vec![1.0f32, 2.0, 3.0, 4.0];\n    let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(f32_bytes.clone(), vec![4], DType::F32),\n        vec![\"floats\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    let tensor_desc = metadata.tensors.get(\"floats\").unwrap();\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    let start = data_section_start + tensor_desc.data_offsets.0 as usize;\n    let end = data_section_start + tensor_desc.data_offsets.1 as usize;\n    let tensor_bytes = &file_bytes[start..end];\n\n    // Verify the bytes match what we wrote\n    assert_eq!(tensor_bytes, f32_bytes.as_slice());\n\n    // Verify we can interpret them as floats\n    let mut floats = Vec::new();\n    for chunk in tensor_bytes.chunks_exact(4) {\n        floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));\n    }\n    assert_eq!(floats, f32_data);\n}\n\n/// Verify alignment works with f64 data\n#[test]\nfn test_aligned_f64_tensor_data_readable() {\n    let f64_data = vec![1.0f64, 2.0, 3.0, 4.0];\n    let f64_bytes: Vec<u8> = f64_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(f64_bytes.clone(), vec![4], DType::F64),\n        vec![\"doubles\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    let tensor_desc = metadata.tensors.get(\"doubles\").unwrap();\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n\n    let start = data_section_start + tensor_desc.data_offsets.0 as usize;\n    let end = data_section_start + tensor_desc.data_offsets.1 as usize;\n    let tensor_bytes = &file_bytes[start..end];\n\n    // Verify the bytes match\n    assert_eq!(tensor_bytes, f64_bytes.as_slice());\n\n    // Verify we can interpret them as doubles\n    let mut doubles = Vec::new();\n    for chunk in tensor_bytes.chunks_exact(8) {\n        doubles.push(f64::from_le_bytes([\n            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],\n        ]));\n    }\n    assert_eq!(doubles, f64_data);\n}\n\n/// Test round-trip preserves alignment (write then read)\n#[test]\nfn test_round_trip_maintains_alignment() {\n    let f32_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];\n    let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(f32_bytes, vec![2, 4], DType::F32),\n        vec![\"matrix\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    // Write\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    // Read back\n    let reader = BurnpackReader::from_bytes(file_bytes.clone()).unwrap();\n    let snapshots = reader.get_snapshots().unwrap();\n\n    assert_eq!(snapshots.len(), 1);\n    let loaded = &snapshots[0];\n    assert_eq!(loaded.full_path(), \"matrix\");\n\n    // Verify the loaded data is correct\n    let tensor_data = loaded.to_data().unwrap();\n    let mut loaded_floats = Vec::new();\n    for chunk in tensor_data.bytes.chunks_exact(4) {\n        loaded_floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));\n    }\n    assert_eq!(loaded_floats, f32_data);\n}\n\n/// Test that tensor offsets within data section are also aligned\n#[test]\nfn test_tensor_relative_offsets_are_aligned() {\n    // Create several small tensors to force multiple alignment padding\n    let tensors: Vec<_> = (0..5)\n        .map(|i| {\n            let data = vec![i as u8; 7]; // 7 bytes each - not aligned\n            TensorSnapshot::from_data(\n                TensorData::from_bytes_vec(data, vec![7], DType::U8),\n                vec![format!(\"tensor_{}\", i)],\n                vec![],\n                ParamId::new(),\n            )\n        })\n        .collect();\n\n    let writer = BurnpackWriter::new(tensors);\n    let file_bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    // All tensor start offsets within data section should be multiples of 256\n    for (name, desc) in &metadata.tensors {\n        assert_eq!(\n            desc.data_offsets.0 % TENSOR_ALIGNMENT,\n            0,\n            \"Tensor '{}' relative offset {} is not 256-byte aligned\",\n            name,\n            desc.data_offsets.0\n        );\n    }\n}\n\n#[cfg(feature = \"std\")]\nmod file_tests {\n    use super::*;\n    use std::fs;\n    use tempfile::tempdir;\n\n    /// Test alignment is preserved when writing to and reading from file\n    #[test]\n    fn test_file_io_preserves_alignment() {\n        let dir = tempdir().unwrap();\n        let file_path = dir.path().join(\"aligned.bpk\");\n\n        let f32_data = [1.0f32, 2.0, 3.0, 4.0];\n        let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(f32_bytes, vec![4], DType::F32),\n            vec![\"floats\".to_string()],\n            vec![],\n            ParamId::new(),\n        );\n\n        // Write to file\n        let writer = BurnpackWriter::new(vec![snapshot]);\n        writer.write_to_file(&file_path).unwrap();\n\n        // Read file bytes directly\n        let file_bytes = fs::read(&file_path).unwrap();\n\n        let header = BurnpackHeader::from_bytes(&file_bytes[..HEADER_SIZE]).unwrap();\n        let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n        let metadata: BurnpackMetadata =\n            ciborium::de::from_reader(&file_bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n        let tensor_desc = metadata.tensors.get(\"floats\").unwrap();\n        let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n        let absolute_pos = data_section_start + tensor_desc.data_offsets.0 as usize;\n\n        assert_eq!(\n            absolute_pos % TENSOR_ALIGNMENT as usize,\n            0,\n            \"Tensor absolute position in file {} is not 256-byte aligned\",\n            absolute_pos\n        );\n\n        // Verify data is correct\n        let start = data_section_start + tensor_desc.data_offsets.0 as usize;\n        let end = data_section_start + tensor_desc.data_offsets.1 as usize;\n        let tensor_bytes = &file_bytes[start..end];\n\n        let mut floats = Vec::new();\n        for chunk in tensor_bytes.chunks_exact(4) {\n            floats.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));\n        }\n        assert_eq!(floats, vec![1.0f32, 2.0, 3.0, 4.0]);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/edge_cases.rs",
    "content": "use crate::TensorSnapshot;\nuse crate::burnpack::{\n    base::{BurnpackHeader, HEADER_SIZE},\n    reader::BurnpackReader,\n    writer::BurnpackWriter,\n};\nuse burn_core::module::ParamId;\nuse burn_tensor::{BoolStore, DType, TensorData, shape};\n\n#[test]\nfn test_maximum_metadata_size() {\n    // Create metadata that approaches u32::MAX (4GB limit)\n    // In practice, we'll test with a reasonably large metadata\n    let large_key = \"x\".repeat(1000);\n    let large_value = \"y\".repeat(10000);\n\n    let mut writer = BurnpackWriter::new(vec![]);\n\n    for i in 0..100 {\n        writer = writer.with_metadata(&format!(\"{}_{}\", large_key, i), &large_value);\n    }\n\n    let result = writer.to_bytes();\n    assert!(result.is_ok());\n\n    let bytes = result.unwrap();\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n\n    // Metadata size should be large but within u32 bounds\n    assert!(header.metadata_size > 1000000); // At least 1MB of metadata\n    assert!(header.metadata_size < u32::MAX);\n}\n\n#[test]\nfn test_zero_size_tensor_shapes() {\n    // Test various zero-dimensional shapes\n    let test_cases = [\n        (vec![0], vec![]),        // Empty 1D\n        (vec![0, 10], vec![]),    // Zero rows\n        (vec![10, 0], vec![]),    // Zero columns\n        (vec![0, 0], vec![]),     // Zero both dimensions\n        (vec![5, 0, 10], vec![]), // Zero in middle dimension\n    ];\n\n    let mut snapshots = vec![];\n    for (i, (shape, data)) in test_cases.iter().enumerate() {\n        let name = format!(\"zero_tensor_{}\", i);\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::F32),\n            vec![name.clone()],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read back and verify\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let names = reader.tensor_names();\n    assert_eq!(names.len(), 5);\n}\n\n#[test]\nfn test_extremely_long_tensor_names() {\n    // Create a tensor with an extremely long name\n    let long_name = \"a\".repeat(10000);\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),\n        vec![long_name.clone()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let names = reader.tensor_names();\n    assert_eq!(names[0].len(), 10000);\n}\n\n#[test]\nfn test_unicode_in_names_and_metadata() {\n    // Test various Unicode characters in tensor names and metadata\n    let unicode_names = vec![\n        \"测试_tensor\",    // Chinese\n        \"тест_tensor\",    // Cyrillic\n        \"テスト_tensor\",  // Japanese\n        \"🔥_burn_tensor\", // Emoji\n        \"αβγδ_tensor\",    // Greek\n        \"한글_tensor\",    // Korean\n    ];\n\n    let mut snapshots = vec![];\n    for name in &unicode_names {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),\n            vec![name.to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots)\n        .with_metadata(\"模型名称\", \"测试模型\")\n        .with_metadata(\"מודל\", \"בדיקה\")\n        .with_metadata(\"🔥\", \"fire\");\n\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Verify all Unicode names are preserved\n    let names = reader.tensor_names();\n    assert_eq!(names.len(), unicode_names.len());\n\n    // Verify metadata\n    assert_eq!(\n        reader.metadata().metadata.get(\"模型名称\"),\n        Some(&\"测试模型\".to_string())\n    );\n    assert_eq!(\n        reader.metadata().metadata.get(\"🔥\"),\n        Some(&\"fire\".to_string())\n    );\n}\n\n#[test]\nfn test_all_supported_dtypes() {\n    // Test all DTypes with their boundary values\n    let dtypes_with_data = [\n        (\n            DType::F32,\n            [\n                f32::MIN.to_le_bytes().to_vec(),\n                f32::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (\n            DType::F64,\n            [\n                f64::MIN.to_le_bytes().to_vec(),\n                f64::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (\n            DType::I32,\n            [\n                i32::MIN.to_le_bytes().to_vec(),\n                i32::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (\n            DType::I64,\n            [\n                i64::MIN.to_le_bytes().to_vec(),\n                i64::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (\n            DType::U32,\n            [\n                u32::MIN.to_le_bytes().to_vec(),\n                u32::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (\n            DType::U64,\n            [\n                u64::MIN.to_le_bytes().to_vec(),\n                u64::MAX.to_le_bytes().to_vec(),\n            ]\n            .concat(),\n        ),\n        (DType::U8, vec![u8::MIN, u8::MAX]),\n        (DType::Bool(BoolStore::Native), vec![0, 1]),\n    ];\n\n    let mut snapshots = vec![];\n    for (i, (dtype, data)) in dtypes_with_data.iter().enumerate() {\n        let name = format!(\"dtype_test_{}\", i);\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data.clone(), vec![2], *dtype),\n            vec![name],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    assert_eq!(reader.tensor_names().len(), dtypes_with_data.len());\n\n    // Verify dtypes are preserved\n    for (i, (expected_dtype, _)) in dtypes_with_data.iter().enumerate() {\n        let name = format!(\"dtype_test_{}\", i);\n        let snapshot = reader.get_tensor_snapshot(&name).unwrap();\n        assert_eq!(snapshot.dtype, *expected_dtype);\n    }\n}\n\n#[test]\nfn test_special_float_values() {\n    // Test special floating-point values (NaN, Inf, -Inf)\n    let special_values = [\n        f32::NAN,\n        f32::INFINITY,\n        f32::NEG_INFINITY,\n        0.0_f32,\n        -0.0_f32,\n    ];\n\n    let data: Vec<u8> = special_values\n        .iter()\n        .flat_map(|f| f.to_le_bytes())\n        .collect();\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data.clone(), vec![5], DType::F32),\n        vec![\"special_floats\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let tensor_data = reader.get_tensor_data(\"special_floats\").unwrap();\n\n    // Check data is preserved exactly (bit-for-bit)\n    assert_eq!(tensor_data, data);\n}\n\n#[test]\nfn test_metadata_with_empty_values() {\n    let writer = BurnpackWriter::new(vec![])\n        .with_metadata(\"empty_value\", \"\")\n        .with_metadata(\"\", \"empty_key\")\n        .with_metadata(\"normal\", \"value\");\n\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    let metadata = &reader.metadata().metadata;\n    assert_eq!(metadata.get(\"empty_value\"), Some(&\"\".to_string()));\n    assert_eq!(metadata.get(\"\"), Some(&\"empty_key\".to_string()));\n    assert_eq!(metadata.get(\"normal\"), Some(&\"value\".to_string()));\n}\n\n#[test]\nfn test_single_byte_tensor() {\n    // Test the smallest possible tensor (1 byte)\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![42], vec![1], DType::U8),\n        vec![\"single_byte\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let data = reader.get_tensor_data(\"single_byte\").unwrap();\n    assert_eq!(data, vec![42]);\n}\n\n#[test]\nfn test_high_dimensional_tensor() {\n    // Test a tensor with many dimensions (10D)\n    let shape = shape![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; // 10 dimensions, 1024 elements total\n    let data = vec![1u8; 1024];\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data.clone(), shape.clone(), DType::U8),\n        vec![\"high_dim\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let loaded_snapshot = reader.get_tensor_snapshot(\"high_dim\").unwrap();\n    assert_eq!(loaded_snapshot.shape, shape);\n}\n\n#[test]\nfn test_metadata_key_collision() {\n    // Test that later values override earlier ones for the same key\n    let writer = BurnpackWriter::new(vec![])\n        .with_metadata(\"key\", \"value1\")\n        .with_metadata(\"key\", \"value2\")\n        .with_metadata(\"key\", \"value3\");\n\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    assert_eq!(\n        reader.metadata().metadata.get(\"key\"),\n        Some(&\"value3\".to_string())\n    );\n}\n\n#[test]\nfn test_tensor_name_with_path_separators() {\n    // Test tensor names that look like file paths\n    let path_like_names = vec![\n        \"model/encoder/layer1/weights\",\n        \"model\\\\decoder\\\\layer1\\\\bias\",\n        \"model::module::param\",\n        \"model.submodule.weight\",\n    ];\n\n    let mut snapshots = vec![];\n    for name in &path_like_names {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),\n            vec![name.to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let names = reader.tensor_names();\n\n    // All names should be preserved exactly\n    for expected_name in &path_like_names {\n        assert!(names.contains(expected_name));\n    }\n}\n\n// The following tests are commented out as they test error conditions\n// that might be handled differently in the new API\n\n// #[test]\n// fn test_data_overflow_protection() {\n//     // Test that we handle potential integer overflows in offset calculations\n//     ...\n// }\n\n// #[test]\n// fn test_reading_corrupted_header() {\n//     // Test reading files with corrupted headers\n//     ...\n// }\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/header.rs",
    "content": "use crate::burnpack::base::*;\n\n#[test]\nfn test_header_serialization() {\n    let header = BurnpackHeader::new(12345);\n\n    // Check fields\n    assert_eq!(header.magic, MAGIC_NUMBER);\n    assert_eq!(header.version, FORMAT_VERSION);\n    assert_eq!(header.metadata_size, 12345);\n\n    // Serialize to bytes\n    let bytes = header.into_bytes();\n    assert_eq!(bytes.len(), HEADER_SIZE);\n\n    // Deserialize back\n    let header2 = BurnpackHeader::from_bytes(&bytes).unwrap();\n    assert_eq!(header2.magic, header.magic);\n    assert_eq!(header2.version, header.version);\n    assert_eq!(header2.metadata_size, header.metadata_size);\n}\n\n#[test]\nfn test_header_invalid_magic() {\n    let mut bytes = [0u8; HEADER_SIZE];\n    // Write wrong magic number\n    bytes[0..4].copy_from_slice(&[0x00, 0x00, 0x00, 0x00]);\n\n    let result = BurnpackHeader::from_bytes(&bytes);\n    match result {\n        Err(BurnpackError::InvalidMagicNumber) => {}\n        _ => panic!(\"Expected InvalidMagicNumber error\"),\n    }\n}\n\n#[test]\nfn test_header_insufficient_bytes() {\n    let bytes = [0u8; 5]; // Too short\n\n    let result = BurnpackHeader::from_bytes(&bytes);\n    match result {\n        Err(BurnpackError::InvalidHeader) => {}\n        _ => panic!(\"Expected InvalidHeader error\"),\n    }\n}\n\n#[test]\nfn test_version_compatibility() {\n    // Create a header with current version\n    let header = BurnpackHeader::new(100);\n    let bytes = header.into_bytes();\n\n    // Should succeed with current version\n    let result = BurnpackHeader::from_bytes(&bytes);\n    assert!(result.is_ok());\n\n    // Test with future version (should fail in real implementation)\n    // For now, we just verify the version field is correctly set\n    let header = result.unwrap();\n    assert_eq!(header.version, FORMAT_VERSION);\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/helpers.rs",
    "content": "use crate::TensorSnapshot;\nuse burn_core::module::ParamId;\nuse burn_tensor::{DType, TensorData};\n\n/// Helper to create a test TensorSnapshot\n#[allow(dead_code)]\npub fn create_test_snapshot(\n    name: String,\n    data: Vec<u8>,\n    shape: Vec<usize>,\n    dtype: DType,\n) -> TensorSnapshot {\n    TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, shape, dtype),\n        vec![name],\n        vec![],\n        ParamId::new(),\n    )\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/mod.rs",
    "content": "use crate::TensorSnapshot;\n\nmod alignment;\nmod edge_cases;\nmod header;\nmod helpers;\nmod reader;\nmod round_trip;\nmod store;\nmod writer;\nmod zero_copy;\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/reader.rs",
    "content": "use crate::burnpack::{\n    base::{\n        BurnpackError, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER, magic_range, metadata_size_range,\n        version_range,\n    },\n    reader::BurnpackReader,\n    writer::BurnpackWriter,\n};\n\nuse super::*;\nuse burn_tensor::{BoolStore, Bytes, DType, TensorData, shape};\n\n#[test]\nfn test_reader_from_bytes_empty() {\n    // Create empty burnpack data\n    let writer = BurnpackWriter::new(Vec::new());\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read it back\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    assert_eq!(reader.metadata().tensors.len(), 0);\n    assert!(reader.metadata().metadata.is_empty());\n}\n\n#[test]\nfn test_reader_from_bytes_with_data() {\n    // Create test data\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"test\", \"value\");\n\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read it back\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    assert_eq!(reader.metadata().tensors.len(), 1);\n    assert_eq!(\n        reader.metadata().metadata.get(\"test\"),\n        Some(&\"value\".to_string())\n    );\n\n    // Get tensor data\n    let tensor_data = reader.get_tensor_data(\"test_tensor\").unwrap();\n    assert_eq!(tensor_data, &[1, 2, 3, 4]);\n}\n\n#[test]\nfn test_reader_invalid_magic_number() {\n    let mut bytes = vec![0u8; 100];\n    // Write invalid magic number\n    bytes[magic_range()].copy_from_slice(b\"NOPE\");\n\n    let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));\n    assert!(matches!(result, Err(BurnpackError::InvalidMagicNumber)));\n}\n\n#[test]\nfn test_reader_invalid_version() {\n    let mut bytes = vec![0u8; 100];\n    // Write correct magic but invalid version\n    bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());\n    bytes[version_range()].copy_from_slice(&999u16.to_le_bytes()); // Invalid version\n    bytes[metadata_size_range()].copy_from_slice(&10u32.to_le_bytes()); // Metadata size\n\n    let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));\n    assert!(matches!(result, Err(BurnpackError::InvalidVersion)));\n}\n\n#[test]\nfn test_reader_header_too_short() {\n    let bytes = vec![0u8; 5]; // Less than HEADER_SIZE\n\n    let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));\n    assert!(matches!(result, Err(BurnpackError::InvalidHeader)));\n}\n\n#[test]\nfn test_reader_metadata_truncated() {\n    let mut bytes = vec![0u8; HEADER_SIZE + 10];\n    // Write valid header\n    bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());\n    bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());\n    bytes[metadata_size_range()].copy_from_slice(&100u32.to_le_bytes()); // Claims 100 bytes of metadata\n\n    // But only provide 10 bytes after header\n    let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));\n    assert!(matches!(result, Err(BurnpackError::InvalidHeader)));\n}\n\n#[test]\nfn test_reader_get_tensor_not_found() {\n    let writer = BurnpackWriter::new(Vec::new());\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    let result = reader.get_tensor_data(\"non_existent\");\n    assert!(matches!(result, Err(BurnpackError::TensorNotFound(_))));\n}\n\n#[test]\nfn test_reader_get_tensor_snapshot() {\n    let data = [1.0f32, 2.0, 3.0, 4.0];\n    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),\n        vec![\"weights\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let writer_bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(writer_bytes).unwrap();\n\n    // Get tensor as snapshot\n    let loaded_snapshot = reader.get_tensor_snapshot(\"weights\").unwrap();\n\n    // Verify snapshot metadata\n    assert_eq!(loaded_snapshot.full_path(), \"weights\");\n    assert_eq!(loaded_snapshot.dtype, DType::F32);\n    assert_eq!(loaded_snapshot.shape, shape![2, 2]);\n\n    // Verify data through closure\n    let tensor_data = loaded_snapshot.to_data().unwrap();\n    assert_eq!(tensor_data.shape, shape![2, 2]);\n}\n\n#[test]\nfn test_reader_multiple_tensors() {\n    // Add multiple tensors\n    let mut snapshots = Vec::new();\n    for i in 0..10 {\n        let name = format!(\"tensor_{}\", i);\n        let data = vec![i as u8; 100];\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data, shape![100], DType::U8),\n            vec![name.clone()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Verify all tensors can be read\n    for i in 0..10 {\n        let name = format!(\"tensor_{}\", i);\n        let data = reader.get_tensor_data(&name).unwrap();\n        assert_eq!(data.len(), 100);\n        assert!(data.iter().all(|&b| b == i as u8));\n    }\n}\n\n#[test]\nfn test_reader_lazy_loading() {\n    // Create large tensor\n    let size = 1024 * 1024; // 1MB\n    let data = vec![42u8; size];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data.clone(), vec![size], DType::U8),\n        vec![\"large\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Get snapshot (should be lazy)\n    let snapshot = reader.get_tensor_snapshot(\"large\").unwrap();\n\n    // Data should only be accessed when to_data is called\n    let tensor_data = snapshot.to_data().unwrap();\n    assert_eq!(tensor_data.bytes.len(), size);\n    assert!(tensor_data.bytes.iter().all(|&b| b == 42));\n}\n\n#[test]\nfn test_reader_all_dtypes() {\n    // Test all data types\n    let test_data = [\n        (DType::F32, [1.0f32.to_le_bytes().to_vec()].concat()),\n        (DType::F64, [2.0f64.to_le_bytes().to_vec()].concat()),\n        (DType::I32, [3i32.to_le_bytes().to_vec()].concat()),\n        (DType::I64, [4i64.to_le_bytes().to_vec()].concat()),\n        (DType::U32, [5u32.to_le_bytes().to_vec()].concat()),\n        (DType::U64, [6u64.to_le_bytes().to_vec()].concat()),\n        (DType::U8, vec![7u8]),\n        (DType::Bool(BoolStore::Native), vec![1u8]),\n    ];\n\n    let mut snapshots = Vec::new();\n    for (i, (dtype, data)) in test_data.iter().enumerate() {\n        let name = format!(\"tensor_{}\", i);\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data.clone(), vec![1], *dtype),\n            vec![name.clone()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Verify all dtypes are preserved\n    for (i, (expected_dtype, expected_data)) in test_data.iter().enumerate() {\n        let name = format!(\"tensor_{}\", i);\n        let snapshot = reader.get_tensor_snapshot(&name).unwrap();\n        assert_eq!(snapshot.dtype, *expected_dtype);\n\n        let data = reader.get_tensor_data(&name).unwrap();\n        assert_eq!(data, expected_data.as_slice());\n    }\n}\n\n#[test]\nfn test_reader_empty_tensor() {\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![], vec![0], DType::F32),\n        vec![\"empty\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    let data = reader.get_tensor_data(\"empty\").unwrap();\n    assert_eq!(data.len(), 0);\n\n    let snapshot = reader.get_tensor_snapshot(\"empty\").unwrap();\n    assert_eq!(snapshot.shape, shape![0]);\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn test_reader_from_file() {\n    use tempfile::tempdir;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"test.bpk\");\n\n    // Create test file\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![10, 20, 30], vec![3], DType::U8),\n        vec![\"file_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"from_file_test\", \"true\");\n\n    writer.write_to_file(&file_path).unwrap();\n\n    // Read from file\n    let reader = BurnpackReader::from_file(&file_path).unwrap();\n\n    assert_eq!(\n        reader.metadata().metadata.get(\"from_file_test\"),\n        Some(&\"true\".to_string())\n    );\n\n    let data = reader.get_tensor_data(\"file_tensor\").unwrap();\n    assert_eq!(data, &[10, 20, 30]);\n}\n\n#[cfg(all(feature = \"std\", feature = \"memmap\"))]\n#[test]\nfn test_reader_from_file_mmap() {\n    use tempfile::tempdir;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"test_mmap.bpk\");\n\n    // Create large test file\n    let size = 1024 * 1024; // 1MB\n    let data = vec![99u8; size];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, vec![size], DType::U8),\n        vec![\"large_mmap\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    writer.write_to_file(&file_path).unwrap();\n\n    // Read using mmap\n    let reader = BurnpackReader::from_file_mmap(&file_path).unwrap();\n\n    let data = reader.get_tensor_data(\"large_mmap\").unwrap();\n    assert_eq!(data.len(), size);\n    assert!(data.iter().all(|&b| b == 99));\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn test_reader_from_file_buffered() {\n    use tempfile::tempdir;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"test_buffered.bpk\");\n\n    // Create test file\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![5, 10, 15], vec![3], DType::U8),\n        vec![\"buffered_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    writer.write_to_file(&file_path).unwrap();\n\n    // Read using buffered reader\n    let reader = BurnpackReader::from_file_buffered(&file_path).unwrap();\n\n    let data = reader.get_tensor_data(\"buffered_tensor\").unwrap();\n    assert_eq!(data, &[5, 10, 15]);\n}\n\n#[test]\nfn test_reader_metadata_access() {\n    // Add various metadata using builder pattern\n    let writer = BurnpackWriter::new(Vec::new())\n        .with_metadata(\"model_name\", \"test_model\")\n        .with_metadata(\"version\", \"1.2.3\")\n        .with_metadata(\"author\", \"test_author\")\n        .with_metadata(\"description\", \"A test model\");\n\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    let metadata = reader.metadata();\n    assert_eq!(metadata.metadata.len(), 4);\n    assert_eq!(\n        metadata.metadata.get(\"model_name\"),\n        Some(&\"test_model\".to_string())\n    );\n    assert_eq!(metadata.metadata.get(\"version\"), Some(&\"1.2.3\".to_string()));\n    assert_eq!(\n        metadata.metadata.get(\"author\"),\n        Some(&\"test_author\".to_string())\n    );\n    assert_eq!(\n        metadata.metadata.get(\"description\"),\n        Some(&\"A test model\".to_string())\n    );\n}\n\n#[test]\nfn test_reader_tensor_iteration() {\n    // Add tensors\n    let tensor_names = vec![\"weights\", \"bias\", \"running_mean\", \"running_var\"];\n    let mut snapshots = Vec::new();\n    for name in &tensor_names {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),\n            vec![name.to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Iterate through all tensors\n    let metadata = reader.metadata();\n    assert_eq!(metadata.tensors.len(), 4);\n\n    // Check that all expected tensor names are present\n    for name in &tensor_names {\n        let tensor_desc = metadata.tensors.get(*name).unwrap();\n        assert_eq!(tensor_desc.shape, vec![4u64]);\n        assert_eq!(tensor_desc.dtype, DType::U8);\n    }\n\n    // Verify the keys match the expected names\n    let mut actual_names: Vec<_> = metadata.tensors.keys().cloned().collect();\n    actual_names.sort();\n    let mut expected_names = tensor_names\n        .iter()\n        .map(|s| s.to_string())\n        .collect::<Vec<_>>();\n    expected_names.sort();\n    assert_eq!(actual_names, expected_names);\n}\n\n#[test]\nfn test_reader_corrupt_metadata() {\n    let mut bytes = vec![0u8; 100];\n\n    // Write valid header\n    bytes[magic_range()].copy_from_slice(&MAGIC_NUMBER.to_le_bytes());\n    bytes[version_range()].copy_from_slice(&FORMAT_VERSION.to_le_bytes());\n    bytes[metadata_size_range()].copy_from_slice(&50u32.to_le_bytes()); // 50 bytes of metadata\n\n    // Write garbage as metadata\n    #[allow(clippy::needless_range_loop)]\n    for i in HEADER_SIZE..HEADER_SIZE + 50 {\n        bytes[i] = 0xFF;\n    }\n\n    let result = BurnpackReader::from_bytes(Bytes::from_bytes_vec(bytes));\n    assert!(result.is_err());\n}\n\n#[test]\nfn test_reader_data_offsets_validation() {\n    // Add two tensors\n    let snapshot1 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),\n        vec![\"tensor1\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n    let snapshot2 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8),\n        vec![\"tensor2\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]);\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Verify offsets don't overlap and are properly aligned\n    let metadata = reader.metadata();\n    let tensor1_desc = metadata.tensors.get(\"tensor1\").unwrap();\n    let tensor2_desc = metadata.tensors.get(\"tensor2\").unwrap();\n\n    // First tensor starts at offset 0 (already aligned to 256 bytes)\n    assert_eq!(tensor1_desc.data_offsets, (0, 4));\n    // Second tensor starts at next 256-byte aligned offset\n    assert_eq!(tensor2_desc.data_offsets, (256, 260));\n}\n\n#[test]\nfn test_reader_out_of_bounds_error() {\n    use crate::burnpack::reader::StorageBackend;\n    use alloc::rc::Rc;\n\n    // Create a small data buffer\n    let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]);\n    let backend = StorageBackend::Memory(Rc::new(data));\n\n    // Try to read beyond the available data\n    let mut buffer = vec![0u8; 10];\n    let result = backend.read_into(&mut buffer, 0);\n\n    // Should return an error\n    assert!(result.is_err());\n    let err = result.unwrap_err();\n    assert!(err.to_string().contains(\"out of bounds\"));\n}\n\n#[test]\nfn test_reader_offset_overflow_error() {\n    use crate::burnpack::reader::StorageBackend;\n    use alloc::rc::Rc;\n\n    let data = Bytes::from_bytes_vec(vec![1, 2, 3, 4, 5]);\n    let backend = StorageBackend::Memory(Rc::new(data));\n\n    // Try to read with an offset that would overflow\n    let mut buffer = vec![0u8; 10];\n    let result = backend.read_into(&mut buffer, usize::MAX - 5);\n\n    // Should return an error about overflow\n    assert!(result.is_err());\n    let err = result.unwrap_err();\n    assert!(err.to_string().contains(\"overflow\"));\n}\n\n#[test]\nfn test_reader_corrupted_shape_returns_error() {\n    // Only test this on platforms where usize is smaller than u64\n    // On 64-bit platforms, u64 values can fit in usize\n    #[cfg(target_pointer_width = \"32\")]\n    {\n        use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};\n        use alloc::collections::BTreeMap;\n        use alloc::rc::Rc;\n        use burn_tensor::DType;\n\n        // Create metadata with a shape dimension that exceeds usize::MAX on 32-bit platforms\n        let mut tensors = BTreeMap::new();\n        tensors.insert(\n            \"corrupted_tensor\".to_string(),\n            TensorDescriptor {\n                dtype: DType::F32,\n                shape: vec![u64::MAX, 2, 3], // First dimension exceeds usize::MAX on 32-bit\n                data_offsets: (0, 100),\n                param_id: None,\n            },\n        );\n\n        let metadata = BurnpackMetadata {\n            tensors,\n            metadata: BTreeMap::new(),\n        };\n\n        // Create a small data buffer\n        let data = Bytes::from_bytes_vec(vec![0u8; 1000]);\n        let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));\n\n        let reader = BurnpackReader {\n            metadata,\n            storage: backend,\n            data_offset: 0,\n        };\n\n        // This should return an error, not panic\n        let result = reader.get_snapshots();\n        assert!(result.is_err());\n        let err = result.unwrap_err();\n        assert!(matches!(err, BurnpackError::ValidationError(_)));\n        assert!(\n            err.to_string().contains(\"corrupted shape data\")\n                || err.to_string().contains(\"exceeds platform maximum\")\n        );\n    }\n\n    #[cfg(not(target_pointer_width = \"32\"))]\n    {\n        // On 64-bit platforms, just pass the test\n        // The conversion logic is still correct, but u64 fits in usize\n    }\n}\n\n#[test]\nfn test_reader_corrupted_offsets_returns_error() {\n    // Only test this on platforms where usize is smaller than u64\n    #[cfg(target_pointer_width = \"32\")]\n    {\n        use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};\n        use alloc::collections::BTreeMap;\n        use alloc::rc::Rc;\n        use burn_tensor::DType;\n\n        // Create metadata with offsets that would overflow\n        let mut tensors = BTreeMap::new();\n        tensors.insert(\n            \"tensor_bad_offset\".to_string(),\n            TensorDescriptor {\n                dtype: DType::F32,\n                shape: vec![2, 2],\n                data_offsets: (u64::MAX - 10, u64::MAX), // Offsets that exceed usize::MAX on 32-bit\n                param_id: None,\n            },\n        );\n\n        let metadata = BurnpackMetadata {\n            tensors,\n            metadata: BTreeMap::new(),\n        };\n\n        let data = Bytes::from_bytes_vec(vec![0u8; 1000]);\n        let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));\n\n        let reader = BurnpackReader {\n            metadata,\n            storage: backend,\n            data_offset: 0,\n        };\n\n        // This should return an error, not panic\n        let result = reader.get_snapshots();\n        assert!(result.is_err());\n        let err = result.unwrap_err();\n        assert!(matches!(err, BurnpackError::ValidationError(_)));\n        assert!(\n            err.to_string().contains(\"corrupted offset data\")\n                || err.to_string().contains(\"exceeds platform maximum\")\n        );\n    }\n\n    #[cfg(not(target_pointer_width = \"32\"))]\n    {\n        use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};\n        use alloc::collections::BTreeMap;\n        use alloc::rc::Rc;\n        use burn_tensor::DType;\n\n        // On 64-bit platforms, test offset overflow during addition\n        let mut tensors = BTreeMap::new();\n        tensors.insert(\n            \"tensor_overflow\".to_string(),\n            TensorDescriptor {\n                dtype: DType::F32,\n                shape: vec![2, 2],\n                data_offsets: (0, 100),\n                param_id: None,\n            },\n        );\n\n        let metadata = BurnpackMetadata {\n            tensors,\n            metadata: BTreeMap::new(),\n        };\n\n        let data = Bytes::from_bytes_vec(vec![0u8; 1000]);\n        let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));\n\n        // Use a data_offset that will overflow when added to the tensor offset\n        let reader = BurnpackReader {\n            metadata,\n            storage: backend,\n            data_offset: usize::MAX - 50, // Will overflow when added to 100\n        };\n\n        // This should return an error, not panic\n        let result = reader.get_snapshots();\n        assert!(result.is_err());\n        let err = result.unwrap_err();\n        assert!(matches!(err, BurnpackError::ValidationError(_)));\n        assert!(err.to_string().contains(\"overflow\"));\n    }\n}\n\n#[test]\nfn test_reader_inverted_offsets_returns_error() {\n    use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};\n    use alloc::collections::BTreeMap;\n    use alloc::rc::Rc;\n    use burn_tensor::DType;\n\n    // Create metadata with end offset < start offset (corrupted)\n    let mut tensors = BTreeMap::new();\n    tensors.insert(\n        \"inverted_tensor\".to_string(),\n        TensorDescriptor {\n            dtype: DType::F32,\n            shape: vec![2, 2],\n            data_offsets: (100, 50), // End offset < start offset\n            param_id: None,\n        },\n    );\n\n    let metadata = BurnpackMetadata {\n        tensors,\n        metadata: BTreeMap::new(),\n    };\n\n    let data = Bytes::from_bytes_vec(vec![0u8; 1000]);\n    let backend = crate::burnpack::reader::StorageBackend::Memory(Rc::new(data));\n\n    let reader = BurnpackReader {\n        metadata,\n        storage: backend,\n        data_offset: 0,\n    };\n\n    // This should return an error, not panic\n    let result = reader.get_snapshots();\n    assert!(result.is_err());\n    let err = result.unwrap_err();\n    assert!(matches!(err, BurnpackError::ValidationError(_)));\n    assert!(err.to_string().contains(\"end offset\") && err.to_string().contains(\"start offset\"));\n}\n\n#[test]\nfn test_reader_truncated_file_from_bytes() {\n    // Create a valid burnpack with tensor data\n    let tensor_size = 1024; // 1KB of data\n    let data = vec![42u8; tensor_size];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),\n        vec![\"large_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let full_bytes = writer.to_bytes().unwrap();\n\n    // Truncate the bytes - remove the last 512 bytes of tensor data\n    let truncated_len = full_bytes.len() - 512;\n    let truncated_bytes = Bytes::from_bytes_vec(full_bytes.to_vec()[..truncated_len].to_vec());\n\n    // This should fail with a validation error indicating file truncation\n    let result = BurnpackReader::from_bytes(truncated_bytes);\n    assert!(result.is_err());\n    if let Err(err) = result {\n        assert!(matches!(err, BurnpackError::ValidationError(_)));\n        assert!(err.to_string().contains(\"File truncated\"));\n        assert!(err.to_string().contains(\"expected at least\"));\n    }\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn test_reader_truncated_file_from_file() {\n    use std::fs::OpenOptions;\n    use tempfile::tempdir;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"truncated.bpk\");\n\n    // Create a valid burnpack file with tensor data\n    let tensor_size = 2048; // 2KB of data\n    let data = vec![99u8; tensor_size];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),\n        vec![\"data_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    writer.write_to_file(&file_path).unwrap();\n\n    // Read the full file to get its size\n    let full_size = std::fs::metadata(&file_path).unwrap().len();\n\n    // Truncate the file - remove the last 1KB\n    let truncated_size = full_size - 1024;\n    let truncated_file = OpenOptions::new().write(true).open(&file_path).unwrap();\n    truncated_file.set_len(truncated_size).unwrap();\n    drop(truncated_file);\n\n    // Try to read the truncated file - should fail with validation error\n    let result = BurnpackReader::from_file(&file_path);\n    assert!(result.is_err());\n    if let Err(err) = result {\n        assert!(matches!(err, BurnpackError::ValidationError(_)));\n        assert!(err.to_string().contains(\"File truncated\"));\n        assert!(err.to_string().contains(\"expected at least\"));\n    }\n}\n\n#[test]\nfn test_reader_file_size_exactly_correct() {\n    // Test that a file with exactly the right size passes validation\n    let tensor_size = 100;\n    let data = vec![77u8; tensor_size];\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(data, vec![tensor_size], DType::U8),\n        vec![\"exact_size\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    // This should succeed - file is exactly the right size\n    let reader = BurnpackReader::from_bytes(bytes);\n    assert!(reader.is_ok());\n\n    // Verify we can read the data\n    let reader = reader.unwrap();\n    let tensor_data = reader.get_tensor_data(\"exact_size\").unwrap();\n    assert_eq!(tensor_data.len(), tensor_size);\n    assert!(tensor_data.iter().all(|&b| b == 77));\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/round_trip.rs",
    "content": "use crate::burnpack::{reader::BurnpackReader, writer::BurnpackWriter};\n\nuse super::*;\nuse alloc::collections::BTreeMap;\nuse alloc::string::String;\nuse burn_tensor::{BoolStore, DType, TensorData, shape};\n\n/// Helper function to perform round-trip test\nfn round_trip_test<F>(setup: F)\nwhere\n    F: FnOnce(&mut Vec<TensorSnapshot>, &mut BTreeMap<String, String>),\n{\n    // Collect snapshots and metadata\n    let mut snapshots = Vec::new();\n    let mut metadata = BTreeMap::new();\n    setup(&mut snapshots, &mut metadata);\n\n    // Sort snapshots by name to ensure consistent ordering\n    // This is necessary because BTreeMap will store them sorted\n    snapshots.sort_by_key(|a| a.full_path());\n\n    // Create writer with snapshots and metadata\n    let mut writer = BurnpackWriter::new(snapshots);\n    for (key, value) in &metadata {\n        writer = writer.with_metadata(key, value);\n    }\n\n    let bytes = writer.to_bytes().unwrap();\n    let reader = BurnpackReader::from_bytes(bytes.clone()).unwrap();\n\n    // Write to bytes again from reader data\n    let mut snapshots2 = Vec::new();\n\n    // Copy tensors (metadata.tensors is now BTreeMap<String, TensorDescriptor>)\n    // They will come out in sorted order from tensor_names()\n    for tensor_name in reader.tensor_names() {\n        let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap();\n        snapshots2.push(snapshot);\n    }\n\n    // Create writer2 with collected snapshots and metadata\n    let mut writer2 = BurnpackWriter::new(snapshots2);\n    for (key, value) in &reader.metadata().metadata {\n        writer2 = writer2.with_metadata(key, value);\n    }\n\n    let bytes2 = writer2.to_bytes().unwrap();\n\n    // Both byte representations should be identical\n    assert_eq!(bytes, bytes2, \"Round-trip produced different bytes\");\n}\n\n#[test]\nfn test_round_trip_empty() {\n    round_trip_test(|_snapshots, _metadata| {\n        // Empty writer\n    });\n}\n\n#[test]\nfn test_round_trip_metadata_only() {\n    round_trip_test(|_snapshots, metadata| {\n        metadata.insert(\"key1\".to_string(), \"value1\".to_string());\n        metadata.insert(\"key2\".to_string(), \"value2\".to_string());\n        metadata.insert(\"key3\".to_string(), \"value3\".to_string());\n    });\n}\n\n#[test]\nfn test_round_trip_f32() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];\n        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![5], DType::F32),\n            vec![\"f32_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_f64() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [1.0f64, 2.0, 3.0];\n        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![3], DType::F64),\n            vec![\"f64_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_i32() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [-10i32, 0, 10, 20];\n        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![4], DType::I32),\n            vec![\"i32_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_i64() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [i64::MIN, 0, i64::MAX];\n        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![3], DType::I64),\n            vec![\"i64_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_u32() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [0u32, 100, 1000, u32::MAX];\n        let bytes: Vec<u8> = data.iter().flat_map(|u| u.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![4], DType::U32),\n            vec![\"u32_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_u64() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = [0u64, u64::MAX / 2, u64::MAX];\n        let bytes: Vec<u8> = data.iter().flat_map(|u| u.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![3], DType::U64),\n            vec![\"u64_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_u8() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = vec![0u8, 127, 255];\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data, vec![3], DType::U8),\n            vec![\"u8_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_bool() {\n    round_trip_test(|snapshots, _metadata| {\n        let data = vec![0u8, 1, 0, 1, 1];\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data, vec![5], DType::Bool(BoolStore::Native)),\n            vec![\"bool_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_mixed_dtypes() {\n    round_trip_test(|snapshots, _metadata| {\n        // F32\n        let f32_data = [1.0f32, 2.0];\n        let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let f32_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(f32_bytes, vec![2], DType::F32),\n            vec![\"f32\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(f32_snapshot);\n\n        // I64\n        let i64_data = [100i64, 200];\n        let i64_bytes: Vec<u8> = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect();\n        let i64_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(i64_bytes, vec![2], DType::I64),\n            vec![\"i64\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(i64_snapshot);\n\n        // Bool\n        let bool_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1, 0, 1], vec![3], DType::Bool(BoolStore::Native)),\n            vec![\"bool\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(bool_snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_multidimensional() {\n    round_trip_test(|snapshots, _metadata| {\n        // 2D tensor\n        let data_2d = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];\n        let bytes_2d: Vec<u8> = data_2d.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot_2d = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes_2d, vec![2, 3], DType::F32),\n            vec![\"tensor_2d\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot_2d);\n\n        // 3D tensor\n        let data_3d = [1.0f32; 24];\n        let bytes_3d: Vec<u8> = data_3d.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot_3d = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes_3d, vec![2, 3, 4], DType::F32),\n            vec![\"tensor_3d\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot_3d);\n\n        // 4D tensor (common for CNNs)\n        let data_4d = vec![1.0f32; 120];\n        let bytes_4d: Vec<u8> = data_4d.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot_4d = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes_4d, vec![2, 3, 4, 5], DType::F32),\n            vec![\"tensor_4d\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot_4d);\n    });\n}\n\n#[test]\nfn test_round_trip_with_metadata_and_tensors() {\n    round_trip_test(|snapshots, metadata| {\n        // Add metadata\n        metadata.insert(\"model_name\".to_string(), \"test_model\".to_string());\n        metadata.insert(\"version\".to_string(), \"1.0.0\".to_string());\n        metadata.insert(\n            \"description\".to_string(),\n            \"A test model for round-trip testing\".to_string(),\n        );\n\n        // Add tensors\n        let weights = [0.1f32, 0.2, 0.3, 0.4];\n        let weights_bytes: Vec<u8> = weights.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let weights_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(weights_bytes, vec![2, 2], DType::F32),\n            vec![\"layer1\".to_string(), \"weights\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(weights_snapshot);\n\n        let bias = [0.5f32, 0.6];\n        let bias_bytes: Vec<u8> = bias.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let bias_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bias_bytes, vec![2], DType::F32),\n            vec![\"layer1\".to_string(), \"bias\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(bias_snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_special_values() {\n    round_trip_test(|snapshots, _metadata| {\n        // Test special float values\n        let special_f32 = [\n            0.0f32,\n            -0.0,\n            f32::INFINITY,\n            f32::NEG_INFINITY,\n            f32::NAN,\n            f32::MIN,\n            f32::MAX,\n            f32::EPSILON,\n        ];\n        let f32_bytes: Vec<u8> = special_f32.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let f32_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(f32_bytes, vec![8], DType::F32),\n            vec![\"special_f32\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(f32_snapshot);\n\n        // Test special f64 values\n        let special_f64 = [\n            0.0f64,\n            -0.0,\n            f64::INFINITY,\n            f64::NEG_INFINITY,\n            f64::NAN,\n            f64::MIN,\n            f64::MAX,\n            f64::EPSILON,\n        ];\n        let f64_bytes: Vec<u8> = special_f64.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let f64_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(f64_bytes, vec![8], DType::F64),\n            vec![\"special_f64\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(f64_snapshot);\n    });\n}\n\n#[test]\nfn test_round_trip_large_tensors() {\n    round_trip_test(|snapshots, _metadata| {\n        // Large tensor (100KB)\n        let size = 25600; // 100KB / 4 bytes per f32\n        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();\n        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(bytes, vec![size], DType::F32),\n            vec![\"large_tensor\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    });\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn test_round_trip_file_io() {\n    use std::fs;\n    use tempfile::tempdir;\n\n    use crate::burnpack::writer::BurnpackWriter;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"round_trip.bpk\");\n\n    // Create original data\n    let data = [1.0f32, 2.0, 3.0, 4.0];\n    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),\n        vec![\"weights\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"test\", \"round_trip\");\n\n    // Write to file\n    writer.write_to_file(&file_path).unwrap();\n\n    // Read from file\n    let reader = BurnpackReader::from_file(&file_path).unwrap();\n\n    // Write to another file\n    let file_path2 = dir.path().join(\"round_trip2.bpk\");\n\n    // Collect snapshots from reader\n    let mut snapshots2 = Vec::new();\n    for tensor_name in reader.tensor_names() {\n        let snapshot = reader.get_tensor_snapshot(tensor_name).unwrap();\n        snapshots2.push(snapshot);\n    }\n\n    // Create writer2 with snapshots and metadata\n    let mut writer2 = BurnpackWriter::new(snapshots2);\n    for (key, value) in &reader.metadata().metadata {\n        writer2 = writer2.with_metadata(key, value);\n    }\n\n    writer2.write_to_file(&file_path2).unwrap();\n\n    // Compare files\n    let bytes1 = fs::read(&file_path).unwrap();\n    let bytes2 = fs::read(&file_path2).unwrap();\n\n    assert_eq!(\n        bytes1, bytes2,\n        \"Round-trip through files produced different content\"\n    );\n}\n\n#[test]\nfn test_round_trip_empty_shapes() {\n    round_trip_test(|snapshots, _metadata| {\n        // Scalar (0-dimensional)\n        let scalar = [42.0f32];\n        let scalar_bytes: Vec<u8> = scalar.iter().flat_map(|f| f.to_le_bytes()).collect();\n        let scalar_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(scalar_bytes, shape![], DType::F32),\n            vec![\"scalar\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(scalar_snapshot);\n\n        // Empty tensor\n        let empty_snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![], shape![0], DType::F32),\n            vec![\"empty\".to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(empty_snapshot);\n    });\n}\n\n#[test]\nfn test_param_id_persistence() {\n    use burn_core::module::ParamId;\n\n    // Create a specific ParamId with a known value\n    let original_param_id = ParamId::from(123456789u64);\n\n    let data = [1.0f32, 2.0, 3.0, 4.0];\n    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes, vec![2, 2], DType::F32),\n        vec![\"weights\".to_string()],\n        vec![],\n        original_param_id,\n    );\n\n    // Write to burnpack\n    let writer = BurnpackWriter::new(vec![snapshot]);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read back from burnpack\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n    let loaded_snapshot = reader.get_tensor_snapshot(\"weights\").unwrap();\n\n    // Verify ParamId was preserved\n    assert!(\n        loaded_snapshot.tensor_id.is_some(),\n        \"ParamId should be present\"\n    );\n    let loaded_param_id = loaded_snapshot.tensor_id.unwrap();\n    assert_eq!(\n        loaded_param_id.val(),\n        original_param_id.val(),\n        \"ParamId value should be preserved: expected {}, got {}\",\n        original_param_id.val(),\n        loaded_param_id.val()\n    );\n}\n\n#[test]\nfn test_param_id_backward_compatibility() {\n    use crate::burnpack::base::{BurnpackMetadata, TensorDescriptor};\n    use alloc::collections::BTreeMap;\n\n    // Create metadata without param_id (simulating old burnpack format)\n    let mut tensors = BTreeMap::new();\n    tensors.insert(\n        \"old_tensor\".to_string(),\n        TensorDescriptor {\n            dtype: DType::F32,\n            shape: vec![2, 2],\n            data_offsets: (0, 16),\n            param_id: None, // No param_id stored (old format)\n        },\n    );\n\n    let metadata = BurnpackMetadata {\n        tensors,\n        metadata: BTreeMap::new(),\n    };\n\n    // Serialize metadata\n    let mut metadata_bytes = Vec::new();\n    ciborium::ser::into_writer(&metadata, &mut metadata_bytes).unwrap();\n\n    // Create a complete burnpack with header and data\n    use crate::burnpack::base::{BurnpackHeader, FORMAT_VERSION, MAGIC_NUMBER};\n\n    let metadata_size = metadata_bytes.len() as u32;\n    let header = BurnpackHeader {\n        magic: MAGIC_NUMBER,\n        version: FORMAT_VERSION,\n        metadata_size,\n    };\n\n    let mut full_bytes = Vec::new();\n    full_bytes.extend_from_slice(&header.into_bytes());\n    full_bytes.extend_from_slice(&metadata_bytes);\n\n    // Add tensor data (4 f32 values = 16 bytes)\n    let tensor_data = vec![1.0f32, 2.0, 3.0, 4.0];\n    for value in tensor_data {\n        full_bytes.extend_from_slice(&value.to_le_bytes());\n    }\n\n    // Read the old format burnpack\n    let reader =\n        BurnpackReader::from_bytes(burn_tensor::Bytes::from_bytes_vec(full_bytes)).unwrap();\n    let loaded_snapshot = reader.get_tensor_snapshot(\"old_tensor\").unwrap();\n\n    // Verify that a new ParamId was generated (backward compatibility)\n    assert!(\n        loaded_snapshot.tensor_id.is_some(),\n        \"ParamId should be generated for old format\"\n    );\n\n    // The generated ParamId should be different each time (it's new), but we can't test the exact value\n    // We just verify it exists and has a valid u64 value\n    let generated_param_id = loaded_snapshot.tensor_id.unwrap();\n    assert!(\n        generated_param_id.val() > 0,\n        \"Generated ParamId should have a valid value\"\n    );\n}\n\n#[test]\nfn test_multiple_tensors_preserve_distinct_param_ids() {\n    use burn_core::module::ParamId;\n\n    // Create multiple tensors with distinct ParamIds\n    let param_id_1 = ParamId::from(111111u64);\n    let param_id_2 = ParamId::from(222222u64);\n    let param_id_3 = ParamId::from(333333u64);\n\n    let mut snapshots = Vec::new();\n\n    let data1 = [1.0f32, 2.0];\n    let bytes1: Vec<u8> = data1.iter().flat_map(|f| f.to_le_bytes()).collect();\n    snapshots.push(TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes1, vec![2], DType::F32),\n        vec![\"tensor1\".to_string()],\n        vec![],\n        param_id_1,\n    ));\n\n    let data2 = [3.0f32, 4.0];\n    let bytes2: Vec<u8> = data2.iter().flat_map(|f| f.to_le_bytes()).collect();\n    snapshots.push(TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes2, vec![2], DType::F32),\n        vec![\"tensor2\".to_string()],\n        vec![],\n        param_id_2,\n    ));\n\n    let data3 = [5.0f32, 6.0];\n    let bytes3: Vec<u8> = data3.iter().flat_map(|f| f.to_le_bytes()).collect();\n    snapshots.push(TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes3, vec![2], DType::F32),\n        vec![\"tensor3\".to_string()],\n        vec![],\n        param_id_3,\n    ));\n\n    // Write to burnpack\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read back\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    let snapshot1 = reader.get_tensor_snapshot(\"tensor1\").unwrap();\n    let snapshot2 = reader.get_tensor_snapshot(\"tensor2\").unwrap();\n    let snapshot3 = reader.get_tensor_snapshot(\"tensor3\").unwrap();\n\n    // Verify each ParamId was preserved correctly\n    assert_eq!(snapshot1.tensor_id.unwrap().val(), param_id_1.val());\n    assert_eq!(snapshot2.tensor_id.unwrap().val(), param_id_2.val());\n    assert_eq!(snapshot3.tensor_id.unwrap().val(), param_id_3.val());\n\n    // Verify they are distinct\n    let id1 = snapshot1.tensor_id.unwrap().val();\n    let id2 = snapshot2.tensor_id.unwrap().val();\n    let id3 = snapshot3.tensor_id.unwrap().val();\n\n    assert_ne!(id1, id2, \"ParamIds should be distinct\");\n    assert_ne!(id2, id3, \"ParamIds should be distinct\");\n    assert_ne!(id1, id3, \"ParamIds should be distinct\");\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/store.rs",
    "content": "#[cfg(feature = \"std\")]\nuse crate::KeyRemapper;\nuse crate::burnpack::store::BurnpackStore;\nuse crate::{ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter};\n\nuse burn_core as burn;\nuse burn_core::module::{Module, Param};\nuse burn_tensor::shape;\nuse burn_tensor::{Tensor, backend::Backend};\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[derive(Module, Debug)]\nstruct TestModule<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    bias: Param<Tensor<B, 1>>,\n    nested: NestedModule<B>,\n}\n\n#[derive(Module, Debug)]\nstruct NestedModule<B: Backend> {\n    gamma: Param<Tensor<B, 1>>,\n    beta: Param<Tensor<B, 1>>,\n}\n\nimpl<B: Backend> TestModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n            bias: Param::from_data([0.1, 0.2], device),\n            nested: NestedModule {\n                gamma: Param::from_data([1.0, 1.0], device),\n                beta: Param::from_data([0.0, 0.0], device),\n            },\n        }\n    }\n\n    fn new_zeros(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_tensor(Tensor::zeros([2, 2], device)),\n            bias: Param::from_tensor(Tensor::zeros([2], device)),\n            nested: NestedModule {\n                gamma: Param::from_tensor(Tensor::zeros([2], device)),\n                beta: Param::from_tensor(Tensor::zeros([2], device)),\n            },\n        }\n    }\n\n    fn new_uninitialized(device: &B::Device) -> Self {\n        use burn_core::module::ParamId;\n        let device_clone = device.clone();\n        let device_clone2 = device.clone();\n        let device_clone3 = device.clone();\n        let device_clone4 = device.clone();\n\n        Self {\n            weight: Param::uninitialized(\n                ParamId::new(),\n                move |d, _| Tensor::zeros([2, 2], d),\n                device_clone,\n                true,\n                [2, 2].into(),\n            ),\n            bias: Param::uninitialized(\n                ParamId::new(),\n                move |d, _| Tensor::zeros([2], d),\n                device_clone2,\n                true,\n                [2].into(),\n            ),\n            nested: NestedModule {\n                gamma: Param::uninitialized(\n                    ParamId::new(),\n                    move |d, _| Tensor::zeros([2], d),\n                    device_clone3,\n                    true,\n                    [2].into(),\n                ),\n                beta: Param::uninitialized(\n                    ParamId::new(),\n                    move |d, _| Tensor::zeros([2], d),\n                    device_clone4,\n                    true,\n                    [2].into(),\n                ),\n            },\n        }\n    }\n}\n\n#[test]\nfn test_store_from_bytes_round_trip() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load from bytes\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    // Verify success\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4); // weight, bias, nested.gamma, nested.beta\n    assert!(result.errors.is_empty());\n\n    // Verify data was loaded correctly\n    let weight1 = module.weight.val().to_data().to_vec::<f32>().unwrap();\n    let weight2 = module2.weight.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(weight1, weight2);\n}\n\n#[test]\nfn test_store_with_metadata() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with metadata\n    let mut save_store = BurnpackStore::from_bytes(None)\n        .metadata(\"version\", \"1.0.0\")\n        .metadata(\"model_name\", \"test_model\")\n        .metadata(\"author\", \"burn_team\");\n\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load and verify metadata is preserved\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_with_path_filter() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save all tensors\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with filter - only load weight and bias (not nested)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_regex(\"^(weight|bias)$\");\n\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 2); // Only weight and bias\n    assert_eq!(result.skipped.len(), 2); // nested.gamma and nested.beta skipped\n\n    // Verify weight and bias were loaded\n    let weight2 = module2.weight.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]);\n\n    // Verify nested module was NOT loaded (should still be zeros)\n    let gamma2 = module2\n        .nested\n        .gamma\n        .val()\n        .to_data()\n        .to_vec::<f32>()\n        .unwrap();\n    assert_eq!(gamma2, vec![0.0, 0.0]);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_with_key_remapping() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with original names\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with remapping: nested.gamma -> nested.new_gamma, nested.beta -> nested.new_beta\n    let remapper = KeyRemapper::new()\n        .add_pattern(r\"nested\\.gamma\", \"nested.new_gamma\")\n        .unwrap()\n        .add_pattern(r\"nested\\.beta\", \"nested.new_beta\")\n        .unwrap();\n\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes))\n        .remap(remapper)\n        .allow_partial(true);\n\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    // The remapping should cause missing tensors since names don't match\n    assert_eq!(result.applied.len(), 2); // Only weight and bias match\n    assert_eq!(result.unused.len(), 2); // nested.new_gamma and nested.new_beta are unused\n    assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing\n}\n\n#[test]\nfn test_store_allow_partial() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save only weight and bias\n    let filter = PathFilter::new()\n        .with_full_path(\"weight\")\n        .with_full_path(\"bias\");\n    let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with allow_partial\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true);\n\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 2);\n    assert_eq!(result.missing.len(), 2); // nested.gamma and nested.beta are missing but that's OK\n\n    // Verify loaded tensors\n    let weight2 = module2.weight.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(weight2, vec![1.0, 2.0, 3.0, 4.0]);\n}\n\n#[test]\nfn test_store_match_all() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with match_all filter (should save everything)\n    let mut save_store = BurnpackStore::from_bytes(None).match_all();\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load everything\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4);\n    assert!(result.errors.is_empty());\n    assert!(result.missing.is_empty());\n    assert!(result.unused.is_empty());\n}\n\n#[test]\nfn test_store_with_full_path() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save everything\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load only specific tensors by full path\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes))\n        .with_full_path(\"weight\")\n        .with_full_path(\"nested.gamma\");\n\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 2); // Only weight and nested.gamma\n    assert_eq!(result.skipped.len(), 2); // bias and nested.beta skipped\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_chain_multiple_patterns() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with chained metadata and filters\n    let mut save_store = BurnpackStore::from_bytes(None)\n        .metadata(\"version\", \"1.0\")\n        .metadata(\"format\", \"burnpack\")\n        .with_regex(r\"^(weight|nested\\.)\")\n        .match_all(); // This overrides the previous filter\n\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load everything since match_all was called last\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4); // All tensors loaded\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_with_remap_pattern() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save normally\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with single remap pattern using the convenience method\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes))\n        .with_remap_pattern(r\"^nested\\.\", \"sub_module.\")\n        .allow_partial(true);\n\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    // After remapping, nested.* becomes sub_module.*, which won't match\n    assert_eq!(result.applied.len(), 2); // Only weight and bias\n    assert_eq!(result.unused.len(), 2); // sub_module.gamma and sub_module.beta unused\n}\n\n#[test]\nfn test_store_default_metadata() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save without adding custom metadata\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Verify default metadata is included\n    // We can't directly inspect metadata from bytes, but we can verify\n    // that the model loads successfully which means metadata was written correctly\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_store_default_metadata_with_custom() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with custom metadata (should preserve defaults)\n    let mut save_store = BurnpackStore::from_bytes(None)\n        .metadata(\"custom_field\", \"custom_value\")\n        .metadata(\"author\", \"test_author\");\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load and verify it works (metadata including defaults was saved)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_store_clear_metadata() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save with cleared metadata (no defaults)\n    let mut save_store = BurnpackStore::from_bytes(None).clear_metadata();\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Verify it still loads correctly\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_store_validate_enabled() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save normally\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with validation enabled (default)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert!(result.errors.is_empty());\n}\n\n#[test]\nfn test_store_validate_disabled() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save normally\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load with validation disabled\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).validate(false);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    // Should still succeed\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_store_allow_partial_missing_tensors() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save only weight (not bias or nested)\n    let filter = PathFilter::new().with_full_path(\"weight\");\n    let mut save_store = BurnpackStore::from_bytes(None).with_filter(filter);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Try to load without allow_partial - should fail due to missing tensors\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes.clone()));\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2);\n\n    // Should fail because of missing tensors\n    assert!(result.is_err());\n\n    // Now try with allow_partial - should succeed\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).allow_partial(true);\n    let mut module3 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module3).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 1); // Only weight\n    assert!(!result.missing.is_empty()); // Has missing tensors\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_file_round_trip() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory and file path\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_file_round_trip.bpk\");\n\n    // Save to file\n    let mut save_store = BurnpackStore::from_file(&path).metadata(\"test\", \"value\");\n    save_store.collect_from(&module).unwrap();\n\n    // Verify file exists\n    assert!(path.exists());\n\n    // Load from file\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4);\n\n    // Verify data\n    let weight1 = module.weight.val().to_data().to_vec::<f32>().unwrap();\n    let weight2 = module2.weight.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(weight1, weight2);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_overwrite_protection() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory and file path (file doesn't exist yet)\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_model.bpk\");\n\n    // First save - should succeed\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n    assert!(path.exists());\n\n    // Second save without overwrite flag - should fail\n    let mut save_store2 = BurnpackStore::from_file(&path);\n    let result = save_store2.collect_from(&module);\n    assert!(result.is_err());\n    assert!(\n        result\n            .unwrap_err()\n            .to_string()\n            .contains(\"File already exists\")\n    );\n\n    // Third save with overwrite flag - should succeed\n    let mut save_store3 = BurnpackStore::from_file(&path).overwrite(true);\n    save_store3.collect_from(&module).unwrap();\n\n    // Verify file still exists and is valid\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_overwrite_with_metadata() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory and file path\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_model_metadata.bpk\");\n\n    // First save with v1 metadata\n    let mut save_store = BurnpackStore::from_file(&path)\n        .metadata(\"version\", \"1.0\")\n        .overwrite(true);\n    save_store.collect_from(&module).unwrap();\n\n    // Second save with v2 metadata and overwrite enabled\n    let mut save_store2 = BurnpackStore::from_file(&path)\n        .metadata(\"version\", \"2.0\")\n        .overwrite(true);\n    save_store2.collect_from(&module).unwrap();\n\n    // Verify file loads correctly\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_auto_extension_default() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"model\");\n\n    // Save without extension - should auto-append .bpk\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Verify that model.bpk was created\n    let expected_path = temp_dir.path().join(\"model.bpk\");\n    assert!(expected_path.exists());\n    assert!(!path.exists()); // Original path without extension should not exist\n\n    // Load using the path without extension - should work\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_auto_extension_with_existing_extension() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"model.bpk\");\n\n    // Save with .bpk extension - should not double append\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Verify that only model.bpk was created\n    assert!(path.exists());\n    let double_ext_path = temp_dir.path().join(\"model.bpk.bpk\");\n    assert!(!double_ext_path.exists());\n\n    // Load and verify\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_auto_extension_with_custom_extension() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"model.mpk\");\n\n    // Save with .mpk extension - should preserve it\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Verify that model.mpk was created (not model.mpk.bpk)\n    assert!(path.exists());\n    let burnpack_path = temp_dir.path().join(\"model.mpk.bpk\");\n    assert!(!burnpack_path.exists());\n\n    // Load and verify\n    let mut load_store = BurnpackStore::from_file(&path);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_auto_extension_disabled() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Create temp directory\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"model\");\n\n    // Save with auto_extension disabled - should use exact path\n    let mut save_store = BurnpackStore::from_file(&path).auto_extension(false);\n    save_store.collect_from(&module).unwrap();\n\n    // Verify that \"model\" (without extension) was created\n    assert!(path.exists());\n    let burnpack_path = temp_dir.path().join(\"model.bpk\");\n    assert!(!burnpack_path.exists());\n\n    // Load with auto_extension disabled\n    let mut load_store = BurnpackStore::from_file(&path).auto_extension(false);\n    let mut module2 = TestModule::<TestBackend>::new_zeros(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_partial_loading_preserves_lazy_initialization() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n\n    // Create and save a full module\n    let module = TestModule::<TestBackend>::new(&device);\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"model.bpk\");\n\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Create an uninitialized module (all params lazy)\n    let mut load_module = TestModule::<TestBackend>::new_uninitialized(&device);\n\n    // Before loading: verify ALL params are uninitialized (lazy)\n    assert!(\n        !load_module.weight.is_initialized(),\n        \"weight should be uninitialized before loading\"\n    );\n    assert!(\n        !load_module.bias.is_initialized(),\n        \"bias should be uninitialized before loading\"\n    );\n    assert!(\n        !load_module.nested.gamma.is_initialized(),\n        \"nested.gamma should be uninitialized before loading\"\n    );\n    assert!(\n        !load_module.nested.beta.is_initialized(),\n        \"nested.beta should be uninitialized before loading\"\n    );\n\n    // Partial load: only load weight and bias (skip nested.*)\n    let filter = PathFilter::new().with_regex(\"^(weight|bias)$\");\n    let mut load_store = BurnpackStore::from_file(&path).filter(filter);\n    let result = load_module.load_from(&mut load_store).unwrap();\n\n    // Verify only weight and bias were loaded\n    assert_eq!(result.applied.len(), 2);\n    assert!(result.applied.contains(&\"weight\".to_string()));\n    assert!(result.applied.contains(&\"bias\".to_string()));\n    assert_eq!(result.skipped.len(), 2);\n    assert!(result.skipped.contains(&\"nested.gamma\".to_string()));\n    assert!(result.skipped.contains(&\"nested.beta\".to_string()));\n\n    // After loading: verify loaded params are initialized, skipped remain lazy\n    assert!(\n        load_module.weight.is_initialized(),\n        \"weight should be initialized after loading\"\n    );\n    assert!(\n        load_module.bias.is_initialized(),\n        \"bias should be initialized after loading\"\n    );\n    assert!(\n        !load_module.nested.gamma.is_initialized(),\n        \"nested.gamma should remain uninitialized (was skipped)\"\n    );\n    assert!(\n        !load_module.nested.beta.is_initialized(),\n        \"nested.beta should remain uninitialized (was skipped)\"\n    );\n\n    // Verify the loaded values are correct\n    let weight_data = load_module.weight.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);\n\n    let bias_data = load_module.bias.val().to_data().to_vec::<f32>().unwrap();\n    assert_eq!(bias_data, vec![0.1, 0.2]);\n\n    // Now check that nested params can still be initialized on first access\n    let gamma_data = load_module\n        .nested\n        .gamma\n        .val()\n        .to_data()\n        .to_vec::<f32>()\n        .unwrap();\n    assert_eq!(gamma_data, vec![0.0, 0.0]); // Initialized to zeros via the init function\n\n    // After accessing, they should be initialized\n    assert!(\n        load_module.nested.gamma.is_initialized(),\n        \"nested.gamma should be initialized after first access\"\n    );\n}\n\n// Model with forward pass for testing weight preservation\n#[derive(Module, Debug)]\nstruct ForwardTestModel<B: Backend> {\n    linear1: burn_nn::Linear<B>,\n    linear2: burn_nn::Linear<B>,\n}\n\nimpl<B: Backend> ForwardTestModel<B> {\n    /// Forward pass: input -> linear1 -> gelu -> linear2\n    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        let x = self.linear1.forward(input);\n        let x = burn::tensor::activation::gelu(x);\n        self.linear2.forward(x)\n    }\n}\n\n#[derive(burn::config::Config, Debug)]\nstruct ForwardTestModelConfig {\n    input_size: usize,\n    hidden_size: usize,\n    output_size: usize,\n}\n\nimpl ForwardTestModelConfig {\n    fn init<B: Backend>(&self, device: &B::Device) -> ForwardTestModel<B> {\n        ForwardTestModel {\n            linear1: burn_nn::LinearConfig::new(self.input_size, self.hidden_size)\n                .with_bias(true)\n                .init(device),\n            linear2: burn_nn::LinearConfig::new(self.hidden_size, self.output_size)\n                .with_bias(true)\n                .init(device),\n        }\n    }\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_forward_pass_preservation_after_save_load() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n\n    // Create model config\n    let config = ForwardTestModelConfig {\n        input_size: 4,\n        hidden_size: 8,\n        output_size: 2,\n    };\n\n    // Initialize model1 with random weights\n    let model1 = config.init::<TestBackend>(&device);\n\n    // Create random input\n    let input = Tensor::<TestBackend, 2>::random(\n        [1, 4],\n        burn_tensor::Distribution::Uniform(-1.0, 1.0),\n        &device,\n    );\n\n    // Forward pass with model1 -> output1\n    let output1 = model1.forward(input.clone());\n\n    // Save model1 weights\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"forward_test_model.bpk\");\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&model1).unwrap();\n\n    // Initialize model2 with different random weights\n    let mut model2 = config.init::<TestBackend>(&device);\n\n    // Forward pass with model2 -> output2 (should differ from output1)\n    let output2 = model2.forward(input.clone());\n\n    // Verify output2 differs from output1 (different random weights)\n    assert!(\n        !output1\n            .clone()\n            .all_close(output2.clone(), Some(1e-6), Some(1e-6)),\n        \"output2 should differ from output1 (different random initializations)\"\n    );\n\n    // Load model1 weights into model2\n    let mut load_store = BurnpackStore::from_file(&path);\n    let result = load_store.apply_to(&mut model2).unwrap();\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases\n\n    // Forward pass with model2 (now has model1 weights) -> output3\n    let output3 = model2.forward(input.clone());\n\n    // Verify output3 equals output1 (same weights)\n    assert!(\n        output1.all_close(output3, Some(1e-6), Some(1e-6)),\n        \"output3 should equal output1 after loading weights\"\n    );\n}\n\n#[test]\nfn test_store_get_all_snapshots() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Get all snapshots (returns &BTreeMap<String, TensorSnapshot>)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshots = load_store.get_all_snapshots().unwrap();\n\n    // Should have 4 tensors\n    assert_eq!(snapshots.len(), 4);\n\n    // Verify tensor names exist (BTreeMap keys)\n    assert!(snapshots.contains_key(\"weight\"));\n    assert!(snapshots.contains_key(\"bias\"));\n    assert!(snapshots.contains_key(\"nested.gamma\"));\n    assert!(snapshots.contains_key(\"nested.beta\"));\n}\n\n#[test]\nfn test_store_get_snapshot_existing() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Get a specific snapshot (returns Option<&TensorSnapshot>)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshot = load_store.get_snapshot(\"weight\").unwrap();\n\n    // Should find the tensor\n    assert!(snapshot.is_some());\n    let snapshot = snapshot.unwrap();\n    assert_eq!(snapshot.full_path(), \"weight\");\n    assert_eq!(snapshot.shape, shape![2, 2]);\n\n    // Verify data can be loaded\n    let data = snapshot.to_data().unwrap();\n    assert_eq!(data.to_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);\n}\n\n#[test]\nfn test_store_get_snapshot_nested() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Get a nested snapshot\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshot = load_store.get_snapshot(\"nested.gamma\").unwrap();\n\n    assert!(snapshot.is_some());\n    let snapshot = snapshot.unwrap();\n    assert_eq!(snapshot.full_path(), \"nested.gamma\");\n    assert_eq!(snapshot.shape, shape![2]);\n}\n\n#[test]\nfn test_store_get_snapshot_not_found() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Try to get a non-existent snapshot\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshot = load_store.get_snapshot(\"nonexistent\").unwrap();\n\n    // Should return None\n    assert!(snapshot.is_none());\n}\n\n#[test]\nfn test_store_keys() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Get all keys\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let keys = load_store.keys().unwrap();\n\n    // Should have 4 keys\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"weight\".to_string()));\n    assert!(keys.contains(&\"bias\".to_string()));\n    assert!(keys.contains(&\"nested.gamma\".to_string()));\n    assert!(keys.contains(&\"nested.beta\".to_string()));\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_store_get_all_snapshots_from_file() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save to file\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_get_all_snapshots.bpk\");\n\n    let mut save_store = BurnpackStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Get snapshots from file (returns &BTreeMap)\n    let mut load_store = BurnpackStore::from_file(&path);\n    let snapshots = load_store.get_all_snapshots().unwrap();\n\n    assert_eq!(snapshots.len(), 4);\n\n    // Verify we can load data from a snapshot (use get() on BTreeMap)\n    let weight_snapshot = snapshots.get(\"weight\").unwrap();\n    let data = weight_snapshot.to_data().unwrap();\n    assert_eq!(data.to_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);\n}\n\n#[test]\nfn test_store_caching_behavior() {\n    let device = Default::default();\n    let module = TestModule::<TestBackend>::new(&device);\n\n    // Save module to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Create store and call get_snapshots multiple times\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n\n    // First call should populate cache\n    let snapshots1 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots1.len(), 4);\n\n    // Second call should return cached data (same reference)\n    let snapshots2 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots2.len(), 4);\n\n    // get_snapshot should also use the cache\n    let weight = load_store.get_snapshot(\"weight\").unwrap();\n    assert!(weight.is_some());\n}\n\n#[test]\nfn test_store_cache_invalidation_on_save() {\n    let device = Default::default();\n\n    // Create first module with specific weights\n    let module1 = TestModule::<TestBackend>::new(&device);\n\n    // Save module1 to bytes store\n    let mut store = BurnpackStore::from_bytes(None);\n    store.collect_from(&module1).unwrap();\n\n    // Populate cache by calling get_snapshots\n    let snapshots1 = store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots1.len(), 4);\n    let weight1_data = snapshots1.get(\"weight\").unwrap().to_data().unwrap();\n    let weight1_values: Vec<f32> = weight1_data.to_vec().unwrap();\n\n    // Create a different module with different weights\n    let module2 = TestModule::<TestBackend> {\n        weight: Param::from_tensor(Tensor::from_data([[10.0, 20.0], [30.0, 40.0]], &device)),\n        bias: Param::from_tensor(Tensor::from_data([100.0, 200.0], &device)),\n        nested: NestedModule {\n            gamma: Param::from_tensor(Tensor::from_data([1000.0, 2000.0], &device)),\n            beta: Param::from_tensor(Tensor::from_data([3000.0, 4000.0], &device)),\n        },\n    };\n\n    // Save module2 - this should invalidate the cache\n    store.collect_from(&module2).unwrap();\n\n    // Get snapshots again - should return NEW data, not cached old data\n    let snapshots2 = store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots2.len(), 4);\n    let weight2_data = snapshots2.get(\"weight\").unwrap().to_data().unwrap();\n    let weight2_values: Vec<f32> = weight2_data.to_vec().unwrap();\n\n    // Verify the data changed (cache was invalidated)\n    assert_ne!(weight1_values, weight2_values);\n    assert_eq!(weight2_values, vec![10.0, 20.0, 30.0, 40.0]);\n}\n\n/// Test storing and loading quantized weights with BurnpackStore.\n/// Regression test for https://github.com/tracel-ai/burn/issues/4179\n#[test]\nfn test_store_quantized_module_round_trip() {\n    use burn_core::module::Quantizer;\n    use burn_nn::LinearConfig;\n    use burn_tensor::quantization::{\n        Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue,\n    };\n\n    let device = Default::default();\n\n    // Create a simple linear module (512x512 as in the bug report)\n    let linear = LinearConfig::new(512, 512)\n        .with_bias(false)\n        .init::<TestBackend>(&device);\n\n    // Define quantization scheme (Q8S with tensor-level quantization)\n    let scheme = <<TestBackend as burn_tensor::backend::Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::Tensor)\n        .with_param(QuantParam::F32);\n\n    // Quantize the module\n    let calibration = Calibration::MinMax;\n    let mut quantizer = Quantizer {\n        calibration,\n        scheme,\n    };\n    let quantized_linear = linear.quantize_weights(&mut quantizer);\n\n    // Save the quantized module\n    let mut save_store = BurnpackStore::from_bytes(None);\n    let result = save_store.collect_from(&quantized_linear);\n    assert!(\n        result.is_ok(),\n        \"Failed to save quantized module: {:?}\",\n        result.err()\n    );\n\n    // Get the bytes\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    // Load the bytes and verify we can read the tensor metadata\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshots = load_store\n        .get_all_snapshots()\n        .expect(\"Failed to get snapshots\");\n\n    // Verify we have the weight tensor\n    assert_eq!(snapshots.len(), 1, \"Expected 1 tensor (weight)\");\n    assert!(snapshots.contains_key(\"weight\"), \"Expected 'weight' tensor\");\n\n    // Verify the tensor metadata\n    let weight_snapshot = snapshots.get(\"weight\").unwrap();\n    assert_eq!(weight_snapshot.shape, shape![512, 512]);\n\n    // Verify we can load the tensor data\n    let weight_data = weight_snapshot\n        .to_data()\n        .expect(\"Failed to load tensor data\");\n    assert_eq!(weight_data.shape, shape![512, 512]);\n}\n\n/// Test HalfPrecisionAdapter bidirectional round-trip: same adapter for save and load.\n#[test]\nfn test_store_half_precision_round_trip() {\n    use crate::HalfPrecisionAdapter;\n    use burn_nn::{Linear, LinearConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct HalfModel<B: Backend> {\n        linear: Linear<B>,\n    }\n\n    let device = Default::default();\n    let model = HalfModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n    };\n\n    // Save with HalfPrecisionAdapter (F32 -> F16)\n    let adapter = HalfPrecisionAdapter::new();\n    let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter.clone());\n    save_store.collect_from(&model).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Verify stored tensors are F16\n    let mut inspect_store = BurnpackStore::from_bytes(Some(bytes.clone()));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (_, snapshot) in snapshots.iter() {\n        assert_eq!(snapshot.dtype, DType::F16, \"Expected F16 in stored data\");\n    }\n\n    // Load back with same adapter instance (F16 -> F32)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_from_adapter(adapter);\n    let mut model2 = HalfModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n    };\n    let result = load_store.apply_to(&mut model2).unwrap();\n    assert!(result.is_success());\n\n    // Verify values are close (F32 -> F16 -> F32 has rounding)\n    let w1 = model.linear.weight.val().to_data().to_vec::<f32>().unwrap();\n    let w2 = model2\n        .linear\n        .weight\n        .val()\n        .to_data()\n        .to_vec::<f32>()\n        .unwrap();\n    for (a, b) in w1.iter().zip(w2.iter()) {\n        assert!(\n            (a - b).abs() < 0.01,\n            \"Weight values differ too much after F16 round-trip: {} vs {}\",\n            a,\n            b\n        );\n    }\n}\n\n/// Test HalfPrecisionAdapter: BatchNorm excluded by default.\n#[test]\nfn test_store_half_precision_batch_norm_excluded() {\n    use crate::HalfPrecisionAdapter;\n    use burn_nn::{BatchNorm, BatchNormConfig, Linear, LinearConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct BnModel<B: Backend> {\n        linear: Linear<B>,\n        bn: BatchNorm<B>,\n    }\n\n    let device = Default::default();\n    let model = BnModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n        bn: BatchNormConfig::new(2).init(&device),\n    };\n\n    let adapter = HalfPrecisionAdapter::new();\n    let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter);\n    save_store.collect_from(&model).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Verify: Linear tensors are F16, BatchNorm tensors remain F32\n    let mut inspect_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (name, snapshot) in snapshots.iter() {\n        if name.starts_with(\"linear\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F16,\n                \"Linear tensor '{}' should be F16\",\n                name\n            );\n        } else if name.starts_with(\"bn\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F32,\n                \"BatchNorm tensor '{}' should stay F32\",\n                name\n            );\n        }\n    }\n}\n\n/// Test HalfPrecisionAdapter with without_module customization.\n#[test]\nfn test_store_half_precision_without_module() {\n    use crate::HalfPrecisionAdapter;\n    use burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct MixedModel<B: Backend> {\n        linear: Linear<B>,\n        norm: LayerNorm<B>,\n    }\n\n    let device = Default::default();\n    let model = MixedModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n        norm: LayerNormConfig::new(2).init(&device),\n    };\n\n    // Remove LayerNorm from half-precision conversion\n    let adapter = HalfPrecisionAdapter::new().without_module(\"LayerNorm\");\n    let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter);\n    save_store.collect_from(&model).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut inspect_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (name, snapshot) in snapshots.iter() {\n        if name.starts_with(\"linear\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F16,\n                \"Linear tensor '{}' should be F16\",\n                name\n            );\n        } else if name.starts_with(\"norm\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F32,\n                \"LayerNorm tensor '{}' should stay F32\",\n                name\n            );\n        }\n    }\n}\n\n/// Test HalfPrecisionAdapter chained with PyTorch adapter.\n#[test]\nfn test_store_half_precision_chained_with_pytorch() {\n    use crate::{HalfPrecisionAdapter, PyTorchToBurnAdapter};\n    use burn_nn::{Linear, LinearConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct ChainModel<B: Backend> {\n        linear: Linear<B>,\n    }\n\n    let device = Default::default();\n    let model = ChainModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n    };\n\n    // Save with chained adapter: BurnToPyTorch then half-precision\n    let adapter = crate::BurnToPyTorchAdapter.chain(HalfPrecisionAdapter::new());\n    let mut save_store = BurnpackStore::from_bytes(None).with_to_adapter(adapter);\n    save_store.collect_from(&model).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Verify stored tensors are F16 and transposed\n    let mut inspect_store = BurnpackStore::from_bytes(Some(bytes.clone()));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    let weight = snapshots.get(\"linear.weight\").unwrap();\n    assert_eq!(weight.dtype, DType::F16);\n    // Weight should be transposed: [4, 2] original -> [2, 4] after BurnToPyTorch\n    assert_eq!(weight.shape, shape![2, 4]);\n\n    // Load back with reverse chain: half-precision (F16 -> F32) then PyTorchToBurn\n    let adapter = HalfPrecisionAdapter::new().chain(PyTorchToBurnAdapter);\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes)).with_from_adapter(adapter);\n    let mut model2 = ChainModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n    };\n    let result = load_store.apply_to(&mut model2).unwrap();\n    assert!(result.is_success());\n}\n\n/// Test storing quantized weights with block-level quantization.\n#[test]\nfn test_store_quantized_module_block_level() {\n    use burn_core::module::Quantizer;\n    use burn_nn::LinearConfig;\n    use burn_tensor::quantization::{\n        Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue,\n    };\n\n    let device = Default::default();\n\n    // Create a linear module\n    let linear = LinearConfig::new(128, 128)\n        .with_bias(false)\n        .init::<TestBackend>(&device);\n\n    // Define quantization scheme with block-level quantization\n    let scheme = <<TestBackend as burn_tensor::backend::Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()\n        .with_value(QuantValue::Q8S)\n        .with_level(QuantLevel::block([32])) // Block size of 32\n        .with_param(QuantParam::F32);\n\n    // Quantize the module\n    let calibration = Calibration::MinMax;\n    let mut quantizer = Quantizer {\n        calibration,\n        scheme,\n    };\n    let quantized_linear = linear.quantize_weights(&mut quantizer);\n\n    // Save the quantized module\n    let mut save_store = BurnpackStore::from_bytes(None);\n    let result = save_store.collect_from(&quantized_linear);\n    assert!(\n        result.is_ok(),\n        \"Failed to save quantized module with block-level quantization: {:?}\",\n        result.err()\n    );\n\n    // Get the bytes and verify round-trip\n    let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let snapshots = load_store\n        .get_all_snapshots()\n        .expect(\"Failed to get snapshots\");\n\n    assert_eq!(snapshots.len(), 1);\n    let weight_snapshot = snapshots.get(\"weight\").unwrap();\n    assert_eq!(weight_snapshot.shape, shape![128, 128]);\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/writer.rs",
    "content": "use crate::burnpack::{\n    base::{\n        BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,\n        aligned_data_section_start, magic_range,\n    },\n    writer::BurnpackWriter,\n};\n\nuse super::*;\nuse burn_core::module::ParamId;\nuse burn_tensor::{BoolStore, DType, TensorData, shape};\nuse std::rc::Rc;\n\n#[test]\nfn test_writer_new() {\n    let writer = BurnpackWriter::new(vec![]);\n    assert_eq!(writer.snapshots.len(), 0);\n    assert!(writer.metadata.is_empty());\n}\n\n#[test]\nfn test_writer_add_metadata() {\n    let writer = BurnpackWriter::new(vec![])\n        .with_metadata(\"model_name\", \"test_model\")\n        .with_metadata(\"version\", \"1.0.0\")\n        .with_metadata(\"author\", \"test_author\");\n\n    assert_eq!(writer.metadata.len(), 3);\n    assert_eq!(\n        writer.metadata.get(\"model_name\"),\n        Some(&\"test_model\".to_string())\n    );\n    assert_eq!(writer.metadata.get(\"version\"), Some(&\"1.0.0\".to_string()));\n    assert_eq!(\n        writer.metadata.get(\"author\"),\n        Some(&\"test_author\".to_string())\n    );\n}\n\n#[test]\nfn test_writer_add_tensor_snapshot() {\n    // Create test tensor snapshots\n    let snapshot1 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"layer1\".to_string(), \"weights\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let snapshot2 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![5, 6, 7, 8], vec![4], DType::U8),\n        vec![\"layer1\".to_string(), \"bias\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]);\n\n    assert_eq!(writer.snapshots.len(), 2);\n    assert_eq!(writer.snapshots[0].full_path(), \"layer1.weights\");\n    assert_eq!(writer.snapshots[1].full_path(), \"layer1.bias\");\n}\n\n#[test]\nfn test_writer_to_bytes_empty() {\n    let writer = BurnpackWriter::new(vec![]);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Verify header\n    assert!(bytes.len() >= HEADER_SIZE);\n    assert_eq!(&bytes[magic_range()], &MAGIC_NUMBER.to_le_bytes());\n\n    // Parse header\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    assert_eq!(header.magic, MAGIC_NUMBER);\n    assert_eq!(header.version, FORMAT_VERSION);\n\n    // Verify metadata\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata_bytes = &bytes[HEADER_SIZE..metadata_end];\n    let metadata: BurnpackMetadata = ciborium::de::from_reader(metadata_bytes).unwrap();\n\n    assert_eq!(metadata.tensors.len(), 0);\n    assert!(metadata.metadata.is_empty());\n}\n\n#[test]\nfn test_writer_to_bytes_with_tensors() {\n    // Add tensors with different data types\n    let f32_data = [1.0f32, 2.0, 3.0, 4.0];\n    let f32_bytes: Vec<u8> = f32_data.iter().flat_map(|f| f.to_le_bytes()).collect();\n    let snapshot_f32 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(f32_bytes.clone(), vec![2, 2], DType::F32),\n        vec![\"weights\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let i64_data = [10i64, 20, 30];\n    let i64_bytes: Vec<u8> = i64_data.iter().flat_map(|i| i.to_le_bytes()).collect();\n    let snapshot_i64 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(i64_bytes.clone(), vec![3], DType::I64),\n        vec![\"bias\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot_f32, snapshot_i64])\n        .with_metadata(\"test_key\", \"test_value\");\n\n    let bytes = writer.to_bytes().unwrap();\n\n    // Parse and verify\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    // Verify metadata\n    assert_eq!(\n        metadata.metadata.get(\"test_key\"),\n        Some(&\"test_value\".to_string())\n    );\n\n    // Verify tensors\n    assert_eq!(metadata.tensors.len(), 2);\n\n    let weights = metadata.tensors.get(\"weights\").unwrap();\n    assert_eq!(weights.dtype, DType::F32);\n    assert_eq!(weights.shape, vec![2, 2]);\n    assert_eq!(weights.data_offsets.1 - weights.data_offsets.0, 16); // 4 * 4 bytes\n\n    let bias = metadata.tensors.get(\"bias\").unwrap();\n    assert_eq!(bias.dtype, DType::I64);\n    assert_eq!(bias.shape, vec![3]);\n    assert_eq!(bias.data_offsets.1 - bias.data_offsets.0, 24); // 3 * 8 bytes\n\n    // Verify actual tensor data\n    // Data section starts at aligned position after metadata\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n    let weights = metadata.tensors.get(\"weights\").unwrap();\n    let bias = metadata.tensors.get(\"bias\").unwrap();\n    let weights_data = &bytes[data_section_start + weights.data_offsets.0 as usize\n        ..data_section_start + weights.data_offsets.1 as usize];\n    assert_eq!(weights_data, f32_bytes);\n\n    let bias_data = &bytes[data_section_start + bias.data_offsets.0 as usize\n        ..data_section_start + bias.data_offsets.1 as usize];\n    assert_eq!(bias_data, i64_bytes);\n}\n\n#[test]\nfn test_writer_all_dtypes() {\n    use half::{bf16, f16};\n\n    // Test all supported data types (excluding QFloat which is tested separately)\n    // Format: (DType, expected_size_per_element, sample_data_bytes)\n    let test_cases = vec![\n        // Floating point types\n        (DType::F64, 8, 1.0f64.to_le_bytes().to_vec()),\n        (DType::F32, 4, 1.0f32.to_le_bytes().to_vec()),\n        (DType::F16, 2, f16::from_f32(1.0).to_le_bytes().to_vec()),\n        (DType::BF16, 2, bf16::from_f32(1.0).to_le_bytes().to_vec()),\n        // Signed integers\n        (DType::I64, 8, 1i64.to_le_bytes().to_vec()),\n        (DType::I32, 4, 1i32.to_le_bytes().to_vec()),\n        (DType::I16, 2, 1i16.to_le_bytes().to_vec()),\n        (DType::I8, 1, 1i8.to_le_bytes().to_vec()),\n        // Unsigned integers\n        (DType::U64, 8, 255u64.to_le_bytes().to_vec()),\n        (DType::U32, 4, 255u32.to_le_bytes().to_vec()),\n        (DType::U16, 2, 255u16.to_le_bytes().to_vec()),\n        (DType::U8, 1, vec![255u8]),\n        // Boolean\n        (DType::Bool(BoolStore::Native), 1, vec![1u8]),\n    ];\n\n    let mut snapshots = vec![];\n    let mut expected_data = vec![];\n    for (i, (dtype, expected_size, data)) in test_cases.into_iter().enumerate() {\n        let name = format!(\"tensor_{}\", i);\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data.clone(), vec![1], dtype),\n            vec![name.clone()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n        expected_data.push((name, dtype, expected_size, data));\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Parse and verify metadata\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])\n            .unwrap();\n\n    assert_eq!(\n        metadata.tensors.len(),\n        13,\n        \"Expected 13 dtypes to be tested\"\n    );\n\n    // Verify each tensor's metadata and data\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n    for (name, expected_dtype, expected_size, expected_bytes) in expected_data {\n        let tensor = metadata\n            .tensors\n            .get(&name)\n            .unwrap_or_else(|| panic!(\"Missing tensor: {}\", name));\n        assert_eq!(tensor.dtype, expected_dtype, \"DType mismatch for {}\", name);\n        assert_eq!(tensor.shape, vec![1], \"Shape mismatch for {}\", name);\n\n        // Verify data size matches expected\n        let data_size = (tensor.data_offsets.1 - tensor.data_offsets.0) as usize;\n        assert_eq!(\n            data_size, expected_size,\n            \"Data size mismatch for {} ({:?})\",\n            name, expected_dtype\n        );\n\n        // Verify actual data bytes match\n        let actual_bytes = &bytes[data_section_start + tensor.data_offsets.0 as usize\n            ..data_section_start + tensor.data_offsets.1 as usize];\n        assert_eq!(\n            actual_bytes,\n            expected_bytes.as_slice(),\n            \"Data mismatch for {} ({:?})\",\n            name,\n            expected_dtype\n        );\n    }\n}\n\n#[test]\nfn test_writer_all_dtypes_round_trip() {\n    use crate::burnpack::reader::BurnpackReader;\n    use half::{bf16, f16};\n\n    // Test all dtypes can be written and read back correctly\n    let test_cases = vec![\n        // Floating point types - use multiple elements to better test\n        (\n            \"f64_tensor\",\n            DType::F64,\n            [1.0f64, 2.0, 3.0, 4.0]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![4],\n        ),\n        (\n            \"f32_tensor\",\n            DType::F32,\n            [1.0f32, 2.0, 3.0, 4.0]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2, 2],\n        ),\n        (\n            \"f16_tensor\",\n            DType::F16,\n            [f16::from_f32(1.0), f16::from_f32(2.0)]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2],\n        ),\n        (\n            \"bf16_tensor\",\n            DType::BF16,\n            [bf16::from_f32(1.0), bf16::from_f32(2.0)]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2],\n        ),\n        // Signed integers\n        (\n            \"i64_tensor\",\n            DType::I64,\n            [1i64, -2, 3, -4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![4],\n        ),\n        (\n            \"i32_tensor\",\n            DType::I32,\n            [1i32, -2, 3, -4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2, 2],\n        ),\n        (\n            \"i16_tensor\",\n            DType::I16,\n            [1i16, -2, 3, -4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![4],\n        ),\n        (\n            \"i8_tensor\",\n            DType::I8,\n            [1i8, -2, 3, -4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2, 2],\n        ),\n        // Unsigned integers\n        (\n            \"u64_tensor\",\n            DType::U64,\n            [1u64, 2, 3, 4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![4],\n        ),\n        (\n            \"u32_tensor\",\n            DType::U32,\n            [1u32, 2, 3, 4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![2, 2],\n        ),\n        (\n            \"u16_tensor\",\n            DType::U16,\n            [1u16, 2, 3, 4]\n                .iter()\n                .flat_map(|v| v.to_le_bytes())\n                .collect::<Vec<u8>>(),\n            shape![4],\n        ),\n        (\"u8_tensor\", DType::U8, vec![1u8, 2, 3, 4], shape![2, 2]),\n        // Boolean\n        (\n            \"bool_tensor\",\n            DType::Bool(BoolStore::Native),\n            vec![1u8, 0, 1, 0],\n            shape![4],\n        ),\n    ];\n\n    let mut snapshots = vec![];\n    let mut expected_results: Vec<(&str, DType, Vec<u8>, _)> = vec![];\n\n    for (name, dtype, data, shape) in test_cases.into_iter() {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(data.clone(), shape.clone(), dtype),\n            vec![name.to_string()],\n            vec![],\n            burn_core::module::ParamId::new(),\n        );\n        snapshots.push(snapshot);\n        expected_results.push((name, dtype, data, shape));\n    }\n\n    // Write to bytes\n    let writer = BurnpackWriter::new(snapshots);\n    let bytes = writer.to_bytes().unwrap();\n\n    // Read back using BurnpackReader\n    let reader = BurnpackReader::from_bytes(bytes).unwrap();\n\n    // Verify each tensor can be read back with correct data\n    for (name, expected_dtype, expected_data, expected_shape) in expected_results {\n        let snapshot = reader\n            .get_tensor_snapshot(name)\n            .unwrap_or_else(|e| panic!(\"Failed to get tensor snapshot {}: {}\", name, e));\n        let tensor_data = snapshot\n            .to_data()\n            .unwrap_or_else(|e| panic!(\"Failed to read tensor data {}: {}\", name, e));\n\n        assert_eq!(\n            tensor_data.dtype, expected_dtype,\n            \"DType mismatch for {}\",\n            name\n        );\n        assert_eq!(\n            tensor_data.shape, expected_shape,\n            \"Shape mismatch for {}\",\n            name\n        );\n        assert_eq!(\n            &tensor_data.bytes[..],\n            expected_data.as_slice(),\n            \"Data mismatch for {}\",\n            name\n        );\n    }\n}\n\n#[test]\nfn test_writer_large_tensor() {\n    // Create a large tensor (1MB)\n    let size = 256 * 1024; // 256K floats = 1MB\n    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();\n    let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(bytes.clone(), vec![size], DType::F32),\n        vec![\"large_tensor\".to_string()],\n        vec![],\n        burn_core::module::ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n\n    let result = writer.to_bytes().unwrap();\n\n    // Verify the large tensor is correctly stored\n    let header = BurnpackHeader::from_bytes(&result[..HEADER_SIZE]).unwrap();\n    let metadata: BurnpackMetadata = ciborium::de::from_reader(\n        &result[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize],\n    )\n    .unwrap();\n\n    assert_eq!(metadata.tensors.len(), 1);\n    let tensor = metadata.tensors.get(\"large_tensor\").unwrap();\n    assert_eq!(tensor.shape, vec![size as u64]);\n    assert_eq!(\n        tensor.data_offsets.1 - tensor.data_offsets.0,\n        (size * 4) as u64\n    );\n}\n\n#[test]\nfn test_writer_empty_tensors() {\n    // Add tensor with empty data\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![], vec![0], DType::F32),\n        vec![\"empty\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n\n    let bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])\n            .unwrap();\n\n    assert_eq!(metadata.tensors.len(), 1);\n    let tensor = metadata.tensors.get(\"empty\").unwrap();\n    assert_eq!(tensor.shape, vec![0]);\n    assert_eq!(tensor.data_offsets.1 - tensor.data_offsets.0, 0);\n}\n\n#[test]\nfn test_writer_special_characters_in_names() {\n    // Test various special characters in tensor names\n    let special_names = vec![\n        \"layer.0.weight\",\n        \"model/encoder/layer1\",\n        \"model::layer::weight\",\n        \"layer[0].bias\",\n        \"layer_1_weight\",\n        \"layer-1-bias\",\n        \"layer@1#weight\",\n        \"emoji_😀_tensor\",\n        \"unicode_测试_tensor\",\n        \"spaces in name\",\n    ];\n\n    let mut snapshots = vec![];\n    for name in &special_names {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![4], DType::U8),\n            vec![name.to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n\n    let bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])\n            .unwrap();\n\n    assert_eq!(metadata.tensors.len(), 10);\n    for (tensor_name, _tensor) in metadata.tensors.iter() {\n        assert!(!tensor_name.is_empty());\n        // Names should be preserved exactly\n        assert!(\n            tensor_name.contains(\"layer\")\n                || tensor_name.contains(\"model\")\n                || tensor_name.contains(\"emoji\")\n                || tensor_name.contains(\"unicode\")\n                || tensor_name.contains(\"spaces\")\n        );\n    }\n}\n\n#[test]\nfn test_writer_metadata_overwrite() {\n    let writer = BurnpackWriter::new(vec![])\n        .with_metadata(\"key\", \"value1\")\n        .with_metadata(\"key\", \"value2\");\n\n    assert_eq!(writer.metadata.get(\"key\"), Some(&\"value2\".to_string()));\n    assert_eq!(writer.metadata.len(), 1);\n}\n\n#[test]\nfn test_writer_tensor_order_preserved() {\n    // Add tensors in specific order\n    let names = vec![\"z_tensor\", \"a_tensor\", \"m_tensor\", \"b_tensor\"];\n\n    let mut snapshots = vec![];\n    for name in &names {\n        let snapshot = TensorSnapshot::from_data(\n            TensorData::from_bytes_vec(vec![1], vec![1], DType::U8),\n            vec![name.to_string()],\n            vec![],\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    let writer = BurnpackWriter::new(snapshots);\n\n    let bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..HEADER_SIZE + header.metadata_size as usize])\n            .unwrap();\n\n    // Verify all tensors are present (BTreeMap stores in sorted order by key)\n    let expected_sorted = vec![\"a_tensor\", \"b_tensor\", \"m_tensor\", \"z_tensor\"];\n    let actual_names: Vec<_> = metadata.tensors.keys().collect();\n    assert_eq!(actual_names, expected_sorted);\n}\n\n#[test]\nfn test_writer_lazy_snapshot_evaluation() {\n    // Create a lazy snapshot using closure\n    let data = Rc::new(vec![1.0f32, 2.0, 3.0, 4.0]);\n    let data_clone = data.clone();\n\n    let snapshot = TensorSnapshot::from_closure(\n        Rc::new(move || {\n            let bytes: Vec<u8> = data_clone.iter().flat_map(|f| f.to_le_bytes()).collect();\n            Ok(TensorData::from_bytes_vec(bytes, shape![2, 2], DType::F32))\n        }),\n        DType::F32,\n        shape![2, 2],\n        vec![\"lazy\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n\n    // The closure should only be evaluated when to_bytes is called\n    let bytes = writer.to_bytes().unwrap();\n\n    let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE]).unwrap();\n    let metadata_end = HEADER_SIZE + header.metadata_size as usize;\n    let metadata: BurnpackMetadata =\n        ciborium::de::from_reader(&bytes[HEADER_SIZE..metadata_end]).unwrap();\n\n    assert_eq!(metadata.tensors.len(), 1);\n    let tensor = metadata.tensors.get(\"lazy\").unwrap();\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, vec![2, 2]);\n\n    // Verify the data was correctly written\n    // Data section starts at aligned position after metadata\n    let data_section_start = aligned_data_section_start(header.metadata_size as usize);\n    let tensor_data = &bytes[data_section_start..data_section_start + 16];\n    let expected: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]\n        .iter()\n        .flat_map(|f| f.to_le_bytes())\n        .collect();\n    assert_eq!(tensor_data, expected.as_slice());\n}\n\n#[cfg(feature = \"std\")]\n#[test]\nfn test_writer_write_to_file() {\n    use std::fs;\n    use tempfile::tempdir;\n\n    let dir = tempdir().unwrap();\n    let file_path = dir.path().join(\"test.bpk\");\n\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"file_test\", \"true\");\n\n    writer.write_to_file(&file_path).unwrap();\n\n    // Verify file exists and has correct content\n    assert!(file_path.exists());\n\n    let file_bytes = fs::read(&file_path).unwrap();\n    let memory_bytes = writer.to_bytes().unwrap();\n\n    assert_eq!(file_bytes.as_slice(), &*memory_bytes);\n}\n\n#[test]\nfn test_writer_size() {\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"test\", \"value\");\n\n    let size = writer.size().unwrap();\n    let bytes = writer.to_bytes().unwrap();\n\n    // Size should match actual bytes length\n    assert_eq!(size, bytes.len());\n}\n\n#[test]\nfn test_writer_write_into() {\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]).with_metadata(\"test\", \"value\");\n\n    // Get size and allocate buffer\n    let size = writer.size().unwrap();\n    let mut buffer = vec![0u8; size];\n\n    // Write into buffer\n    writer.write_into(&mut buffer).unwrap();\n\n    // Compare with to_bytes()\n    let bytes = writer.to_bytes().unwrap();\n    assert_eq!(buffer.as_slice(), &*bytes);\n}\n\n#[test]\nfn test_writer_write_into_buffer_too_small() {\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n\n    // Allocate a buffer that's too small\n    let mut buffer = vec![0u8; 10];\n\n    // Should fail with buffer too small error\n    let result = writer.write_into(&mut buffer);\n    assert!(result.is_err());\n    assert!(result.unwrap_err().to_string().contains(\"Buffer too small\"));\n}\n\n#[test]\nfn test_writer_write_into_buffer_larger_than_needed() {\n    let snapshot = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"test\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot]);\n\n    // Allocate a larger buffer\n    let size = writer.size().unwrap();\n    let mut buffer = vec![0u8; size + 100]; // Extra 100 bytes\n\n    // Should succeed and only write the necessary bytes\n    writer.write_into(&mut buffer).unwrap();\n\n    // Compare the written portion with to_bytes()\n    let bytes = writer.to_bytes().unwrap();\n    assert_eq!(&buffer[..size], &*bytes);\n}\n\n#[test]\nfn test_writer_write_into_multiple_tensors() {\n    let snapshot1 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![1, 2, 3, 4], vec![2, 2], DType::U8),\n        vec![\"tensor1\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let snapshot2 = TensorSnapshot::from_data(\n        TensorData::from_bytes_vec(vec![5, 6, 7, 8, 9, 10], vec![2, 3], DType::U8),\n        vec![\"tensor2\".to_string()],\n        vec![],\n        ParamId::new(),\n    );\n\n    let writer = BurnpackWriter::new(vec![snapshot1, snapshot2]).with_metadata(\"test\", \"multiple\");\n\n    let size = writer.size().unwrap();\n    let mut buffer = vec![0u8; size];\n    writer.write_into(&mut buffer).unwrap();\n\n    let bytes = writer.to_bytes().unwrap();\n    assert_eq!(buffer.as_slice(), &*bytes);\n}\n\n#[test]\nfn test_writer_write_into_empty() {\n    let writer = BurnpackWriter::new(vec![]);\n\n    let size = writer.size().unwrap();\n    let mut buffer = vec![0u8; size];\n    writer.write_into(&mut buffer).unwrap();\n\n    let bytes = writer.to_bytes().unwrap();\n    assert_eq!(buffer.as_slice(), &*bytes);\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/tests/zero_copy.rs",
    "content": "//! Tests for zero-copy tensor loading functionality.\n\nuse crate::ModuleStore;\nuse crate::burnpack::store::BurnpackStore;\n\nuse burn_core as burn;\nuse burn_core::module::{Module, Param};\nuse burn_tensor::{AllocationProperty, Bytes, Tensor, backend::Backend};\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[derive(Module, Debug)]\nstruct SimpleModule<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    bias: Param<Tensor<B, 1>>,\n}\n\nimpl<B: Backend> SimpleModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_data([[1.0f32, 2.0], [3.0, 4.0]], device),\n            bias: Param::from_data([0.5f32, 1.5], device),\n        }\n    }\n\n    fn new_zeros(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_tensor(Tensor::zeros([2, 2], device)),\n            bias: Param::from_tensor(Tensor::zeros([2], device)),\n        }\n    }\n}\n\n/// Test that from_static creates a store with zero_copy enabled by default.\n#[test]\nfn test_from_static_enables_zero_copy() {\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to bytes first\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Convert to Vec<u8> and then leak to get &'static [u8]\n    let bytes_vec: Vec<u8> = bytes.to_vec();\n    let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice());\n\n    // Create store from static - zero_copy should be enabled\n    let mut load_store = BurnpackStore::from_static(static_bytes);\n\n    // Load into a new module\n    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);\n    load_store.apply_to(&mut loaded_module).unwrap();\n\n    // Verify data is correct\n    let loaded_weight = loaded_module.weight.val().to_data();\n    let loaded_bias = loaded_module.bias.val().to_data();\n\n    assert_eq!(\n        loaded_weight.to_vec::<f32>().unwrap(),\n        vec![1.0, 2.0, 3.0, 4.0]\n    );\n    assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);\n}\n\n/// Test that zero_copy builder method works.\n#[test]\nfn test_zero_copy_builder_method() {\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to bytes first\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Create shared bytes for zero-copy\n    let shared = bytes::Bytes::from(bytes.to_vec());\n    let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);\n\n    // Create store with zero_copy enabled\n    let mut load_store = BurnpackStore::from_bytes(Some(cubecl_bytes)).zero_copy(true);\n\n    // Load into a new module\n    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);\n    load_store.apply_to(&mut loaded_module).unwrap();\n\n    // Verify data is correct\n    let loaded_weight = loaded_module.weight.val().to_data();\n    assert_eq!(\n        loaded_weight.to_vec::<f32>().unwrap(),\n        vec![1.0, 2.0, 3.0, 4.0]\n    );\n}\n\n/// Test that zero_copy(false) uses copying even with shared bytes.\n#[test]\nfn test_zero_copy_disabled_uses_copy() {\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to bytes first\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Convert to Vec<u8> and then leak to get &'static [u8]\n    let bytes_vec: Vec<u8> = bytes.to_vec();\n    let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice());\n\n    // Create store from static but disable zero_copy\n    let mut load_store = BurnpackStore::from_static(static_bytes).zero_copy(false);\n\n    // Load into a new module\n    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);\n    load_store.apply_to(&mut loaded_module).unwrap();\n\n    // Verify data is correct (copied, not zero-copy)\n    let loaded_weight = loaded_module.weight.val().to_data();\n    assert_eq!(\n        loaded_weight.to_vec::<f32>().unwrap(),\n        vec![1.0, 2.0, 3.0, 4.0]\n    );\n}\n\n/// Test that from_bytes with regular Bytes uses copying by default.\n#[test]\nfn test_from_bytes_uses_copy_by_default() {\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to bytes\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Load from bytes (default: zero_copy = false)\n    let mut load_store = BurnpackStore::from_bytes(Some(bytes));\n    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);\n    load_store.apply_to(&mut loaded_module).unwrap();\n\n    // Verify data is correct\n    let loaded_weight = loaded_module.weight.val().to_data();\n    assert_eq!(\n        loaded_weight.to_vec::<f32>().unwrap(),\n        vec![1.0, 2.0, 3.0, 4.0]\n    );\n}\n\n/// Test that slice_bytes works correctly on StorageBackend.\n#[test]\nfn test_storage_backend_slice_bytes() {\n    use crate::burnpack::reader::BurnpackReader;\n\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to bytes first\n    let mut save_store = BurnpackStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    // Create shared bytes\n    let shared = bytes::Bytes::from(bytes.to_vec());\n    let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);\n\n    // Create reader and get snapshots with zero-copy\n    let reader = BurnpackReader::from_bytes(cubecl_bytes).unwrap();\n    let snapshots = reader.get_snapshots_zero_copy(true).unwrap();\n\n    // Verify we got the expected number of tensors\n    assert_eq!(snapshots.len(), 2);\n\n    // Load the tensor data\n    for snapshot in &snapshots {\n        let data = snapshot.to_data().unwrap();\n        // Just verify we can access the data - the actual content depends on tensor order\n        assert!(!data.bytes.is_empty());\n    }\n}\n\n/// Test that zero_copy=true with file-based loading works (via mmap + bytes::Bytes).\n#[test]\nfn test_zero_copy_file_based_works() {\n    use tempfile::NamedTempFile;\n\n    let device = Default::default();\n    let module = SimpleModule::<TestBackend>::new(&device);\n\n    // Save to a temporary file\n    let temp_file = NamedTempFile::new().unwrap();\n    let path = temp_file.path();\n\n    let mut save_store = BurnpackStore::from_file(path).overwrite(true);\n    save_store.collect_from(&module).unwrap();\n\n    // Load with zero_copy=true - should work because mmap is converted to bytes::Bytes\n    let mut load_store = BurnpackStore::from_file(path).zero_copy(true);\n    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);\n\n    // The apply should succeed - mmap now supports zero-copy via bytes::Bytes::from_owner()\n    load_store.apply_to(&mut loaded_module).unwrap();\n\n    // Verify data is correct\n    let loaded_weight = loaded_module.weight.val().to_data();\n    assert_eq!(\n        loaded_weight.to_vec::<f32>().unwrap(),\n        vec![1.0, 2.0, 3.0, 4.0]\n    );\n}\n"
  },
  {
    "path": "crates/burn-store/src/burnpack/writer.rs",
    "content": "use super::base::{\n    BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,\n    TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,\n};\nuse crate::TensorSnapshot;\nuse alloc::collections::BTreeMap;\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_tensor::Bytes;\n\n#[cfg(feature = \"std\")]\nuse std::fs::File;\n#[cfg(feature = \"std\")]\nuse std::io::Write;\n#[cfg(feature = \"std\")]\nuse std::path::Path;\n\n/// Align an offset to the specified alignment boundary.\n///\n/// Returns the smallest value >= `offset` that is a multiple of `alignment`.\n#[inline]\nconst fn align_offset(offset: u64, alignment: u64) -> u64 {\n    offset.div_ceil(alignment) * alignment\n}\n\n/// Writer for creating Burnpack files\npub struct BurnpackWriter {\n    /// Tensors to write\n    pub(crate) snapshots: Vec<TensorSnapshot>,\n    /// Metadata key-value pairs\n    pub(crate) metadata: BTreeMap<String, String>,\n}\n\nimpl BurnpackWriter {\n    /// Create a new writer\n    pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {\n        Self {\n            snapshots,\n            metadata: BTreeMap::new(),\n        }\n    }\n\n    /// Builder pattern: add metadata and return self\n    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {\n        self.metadata.insert(key.to_string(), value.to_string());\n        self\n    }\n\n    /// Build tensor descriptors and metadata\n    fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {\n        // Build tensor descriptors and calculate offsets with alignment\n        let mut tensors = BTreeMap::new();\n        let mut current_offset = 0u64;\n\n        for snapshot in &self.snapshots {\n            let data_len = snapshot.data_len() as u64;\n\n            // Align the start offset for mmap zero-copy support\n            let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);\n            let end = aligned_start.checked_add(data_len).ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Tensor offset overflow: {} + {} exceeds maximum\",\n                    aligned_start, data_len\n                ))\n            })?;\n\n            tensors.insert(\n                snapshot.full_path(),\n                TensorDescriptor {\n                    dtype: snapshot.dtype,\n                    shape: snapshot.shape.iter().map(|&s| s as u64).collect(),\n                    data_offsets: (aligned_start, end),\n                    param_id: snapshot.tensor_id.map(|id| id.val()),\n                },\n            );\n\n            current_offset = end;\n        }\n\n        // Create metadata structure\n        let metadata = BurnpackMetadata {\n            tensors,\n            metadata: self.metadata.clone(),\n        };\n\n        // Serialize metadata with CBOR\n        let mut metadata_bytes = Vec::new();\n        ciborium::ser::into_writer(&metadata, &mut metadata_bytes)\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        Ok((metadata, metadata_bytes))\n    }\n\n    /// Calculate the total size needed for the burnpack data\n    ///\n    /// This is useful when you want to pre-allocate a buffer for `write_into()`.\n    /// The size includes padding bytes for both metadata alignment and tensor alignment.\n    pub fn size(&self) -> Result<usize, BurnpackError> {\n        let (metadata, metadata_bytes) = self.build_metadata()?;\n\n        // Data section starts at aligned position after header + metadata\n        let data_section_start = aligned_data_section_start(metadata_bytes.len());\n\n        // Calculate total data section size from aligned offsets\n        // The last tensor's end offset gives us the total data section size\n        let data_size = metadata\n            .tensors\n            .values()\n            .map(|t| t.data_offsets.1)\n            .max()\n            .unwrap_or(0) as usize;\n\n        Ok(data_section_start + data_size)\n    }\n\n    /// Write burnpack data into a caller-provided buffer\n    ///\n    /// The buffer must be large enough to hold all data. Use `size()` to determine\n    /// the required buffer size. If the buffer is too small, this will return an error.\n    ///\n    /// This allows the caller to control buffer allocation, enabling optimizations like:\n    /// - Buffer reuse across multiple writes\n    /// - Custom allocators\n    /// - Pinned memory for GPU transfers\n    ///\n    /// # Arguments\n    ///\n    /// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes.\n    pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {\n        let (metadata, metadata_bytes) = self.build_metadata()?;\n\n        // Check metadata size fits in u32\n        let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {\n            BurnpackError::IoError(format!(\n                \"Metadata size {} exceeds maximum of {} bytes\",\n                metadata_bytes.len(),\n                u32::MAX\n            ))\n        })?;\n\n        // Create header\n        let header = BurnpackHeader {\n            magic: MAGIC_NUMBER,\n            version: FORMAT_VERSION,\n            metadata_size,\n        };\n\n        // Data section starts at aligned position after header + metadata\n        let data_section_start = aligned_data_section_start(metadata_bytes.len());\n\n        // Calculate required size from aligned offsets\n        let data_size = metadata\n            .tensors\n            .values()\n            .map(|t| t.data_offsets.1)\n            .max()\n            .unwrap_or(0) as usize;\n        let total_size = data_section_start + data_size;\n\n        // Check buffer size\n        if buffer.len() < total_size {\n            return Err(BurnpackError::IoError(format!(\n                \"Buffer too small: need {} bytes, got {} bytes\",\n                total_size,\n                buffer.len()\n            )));\n        }\n\n        let mut offset = 0;\n\n        // Write header\n        let header_bytes = header.into_bytes();\n        buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);\n        offset += HEADER_SIZE;\n\n        // Write metadata\n        buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);\n        offset += metadata_bytes.len();\n\n        // Write padding to align data section start\n        if data_section_start > offset {\n            buffer[offset..data_section_start].fill(0);\n            offset = data_section_start;\n        }\n\n        // Write tensor data with alignment padding\n        for snapshot in &self.snapshots {\n            // Get the aligned offset from metadata\n            let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Internal error: tensor '{}' not found in metadata\",\n                    snapshot.full_path()\n                ))\n            })?;\n            let aligned_offset = descriptor.data_offsets.0 as usize;\n            let target_offset = data_section_start + aligned_offset;\n\n            // Write padding zeros if needed\n            if target_offset > offset {\n                buffer[offset..target_offset].fill(0);\n                offset = target_offset;\n            }\n\n            let expected_len = snapshot.data_len();\n            let data = snapshot.to_data().map_err(|e| {\n                BurnpackError::IoError(format!(\"Failed to get tensor data: {:?}\", e))\n            })?;\n            let actual_len = data.bytes.len();\n\n            // Validate data length consistency\n            if actual_len != expected_len {\n                return Err(BurnpackError::IoError(format!(\n                    \"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})\",\n                    snapshot.full_path(),\n                    expected_len,\n                    actual_len\n                )));\n            }\n\n            buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);\n            offset += actual_len;\n        }\n\n        Ok(())\n    }\n\n    /// Write to a byte buffer (convenience method)\n    ///\n    /// This allocates a buffer internally and writes the burnpack data.\n    /// For more control over buffer allocation, use `size()` + `write_into()`.\n    pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {\n        let size = self.size()?;\n        let mut buffer = vec![0u8; size];\n        self.write_into(&mut buffer)?;\n        Ok(Bytes::from_bytes_vec(buffer))\n    }\n\n    /// Write directly to a file (more memory efficient for large models)\n    #[cfg(feature = \"std\")]\n    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {\n        let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        let (metadata, metadata_bytes) = self.build_metadata()?;\n\n        // Check metadata size fits in u32\n        let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {\n            BurnpackError::IoError(format!(\n                \"Metadata size {} exceeds maximum of {} bytes\",\n                metadata_bytes.len(),\n                u32::MAX\n            ))\n        })?;\n\n        // Create and write header\n        let header = BurnpackHeader {\n            magic: MAGIC_NUMBER,\n            version: FORMAT_VERSION,\n            metadata_size,\n        };\n\n        file.write_all(&header.into_bytes())\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        // Write metadata\n        file.write_all(&metadata_bytes)\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        // Data section starts at aligned position after header + metadata\n        let data_section_start = aligned_data_section_start(metadata_bytes.len());\n        let current_file_pos = HEADER_SIZE + metadata_bytes.len();\n\n        // Write padding to align data section start\n        if data_section_start > current_file_pos {\n            let padding_size = data_section_start - current_file_pos;\n            let padding = vec![0u8; padding_size];\n            file.write_all(&padding)\n                .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n        }\n\n        // Track current position within data section (relative to data_section_start)\n        let mut data_offset = 0usize;\n\n        // Stream tensor data directly to file with alignment padding\n        for snapshot in &self.snapshots {\n            // Get the aligned offset from metadata\n            let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {\n                BurnpackError::IoError(format!(\n                    \"Internal error: tensor '{}' not found in metadata\",\n                    snapshot.full_path()\n                ))\n            })?;\n            let aligned_offset = descriptor.data_offsets.0 as usize;\n\n            // Write padding zeros if needed\n            if aligned_offset > data_offset {\n                let padding_size = aligned_offset - data_offset;\n                let padding = vec![0u8; padding_size];\n                file.write_all(&padding)\n                    .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n                data_offset = aligned_offset;\n            }\n\n            let expected_len = snapshot.data_len();\n            let data = snapshot.to_data().map_err(|e| {\n                BurnpackError::IoError(format!(\"Failed to get tensor data: {:?}\", e))\n            })?;\n            let actual_len = data.bytes.len();\n\n            // Validate data length consistency\n            if actual_len != expected_len {\n                return Err(BurnpackError::IoError(format!(\n                    \"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})\",\n                    snapshot.full_path(),\n                    expected_len,\n                    actual_len\n                )));\n            }\n\n            file.write_all(&data.bytes)\n                .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n            data_offset += actual_len;\n        }\n\n        file.flush()\n            .map_err(|e| BurnpackError::IoError(e.to_string()))?;\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/collector.rs",
    "content": "use alloc::boxed::Box;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\n\nuse burn_tensor::{Bool, Int, Tensor, backend::Backend};\n\nuse crate::{ModuleAdapter, PathFilter, TensorSnapshot};\nuse burn_core::module::{ModuleVisitor, Param, ParamId};\n\n/// Collects tensor views from modules without copying data.\n///\n/// This collector traverses a module hierarchy and creates lightweight views\n/// of tensors that can be materialized to `TensorData` on demand.\n///\n/// # Examples\n///\n/// ## Collect all tensors\n/// ```rust,no_run\n/// # use burn_store::Collector;\n/// let collector = Collector::new(None, None, false);\n/// // Use with module.visit(&mut collector);\n/// let all_tensors = collector.tensors;\n/// ```\n///\n/// ## Filter with single pattern\n/// ```rust,no_run\n/// # use burn_store::{Collector, PathFilter};\n/// let filter = PathFilter::new().with_regex(r\"^encoder\\..*\");\n/// let collector = Collector::new(Some(filter), None, false);\n/// // Use with module.visit(&mut collector);\n/// // Only collects tensors starting with \"encoder.\"\n/// ```\n///\n/// ## Filter with multiple patterns (OR union)\n/// ```rust,no_run\n/// # use burn_store::{Collector, PathFilter};\n/// let filter = PathFilter::new()\n///     .with_regex(r\"^encoder\\..*\")  // Match all encoder tensors\n///     .with_regex(r\".*\\.bias$\");    // OR match any bias tensors\n/// let collector = Collector::new(Some(filter), None, false);\n/// // Use with module.visit(&mut collector);\n/// // Collects tensors matching ANY of the patterns\n/// ```\npub struct Collector {\n    /// Collection of tensor views\n    pub tensors: Vec<TensorSnapshot>,\n    path_stack: Vec<String>,\n    container_stack: Vec<String>,\n    filter: Option<PathFilter>,\n    adapter: Option<Box<dyn ModuleAdapter>>,\n    /// Skip enum variant names when building paths\n    /// When true, enum variant names are not included in tensor paths\n    skip_enum_variants: bool,\n}\n\nimpl Default for Collector {\n    fn default() -> Self {\n        Self::new(None, None, false)\n    }\n}\n\nimpl Collector {\n    /// Create a new tensor view collector with an optional filter and adapter.\n    ///\n    /// # Arguments\n    ///\n    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.\n    ///   When `None`, all tensors are collected.\n    /// * `adapter` - Optional adapter to transform tensors based on container types.\n    ///   Applied to all collected tensors before returning.\n    /// * `skip_enum_variants` - Skip enum variant names when building paths.\n    ///   When true, paths will not include enum variant names (e.g., \"feature.weight\"\n    ///   instead of \"feature.BaseConv.weight\"). Useful when exporting to formats\n    ///   like PyTorch that don't use enum variants.\n    ///\n    /// # Examples\n    ///\n    /// ```rust,no_run\n    /// # use burn_store::{Collector, PathFilter};\n    /// // Collect all tensors without adapter\n    /// let collector = Collector::new(None, None, false);\n    ///\n    /// // Use PathFilter builder\n    /// let filter = PathFilter::new()\n    ///     .with_regex(r\"^encoder\\..*\")\n    ///     .with_full_path(\"decoder.weight\");\n    /// let collector = Collector::new(Some(filter), None, false);\n    ///\n    /// // Skip enum variants for PyTorch export\n    /// let collector = Collector::new(None, None, true);\n    /// ```\n    pub fn new(\n        filter: Option<PathFilter>,\n        adapter: Option<Box<dyn ModuleAdapter>>,\n        skip_enum_variants: bool,\n    ) -> Self {\n        Self {\n            tensors: Vec::new(),\n            path_stack: Vec::new(),\n            container_stack: Vec::new(),\n            filter,\n            adapter,\n            skip_enum_variants,\n        }\n    }\n\n    /// Apply the adapter to collected tensors and return the result.\n    pub fn into_tensors(self) -> Vec<TensorSnapshot> {\n        if let Some(adapter) = self.adapter {\n            self.tensors\n                .into_iter()\n                .map(|snapshot| adapter.adapt(&snapshot))\n                .collect()\n        } else {\n            self.tensors\n        }\n    }\n\n    fn should_collect(&self, path: &[String], container_stack: &[String]) -> bool {\n        // If filter is present, use it; otherwise collect all\n        match &self.filter {\n            None => true,\n            Some(f) => f.matches_with_container_path(path, container_stack),\n        }\n    }\n}\n\nimpl<B: Backend> ModuleVisitor<B> for Collector {\n    fn enter_module(&mut self, name: &str, container_type: &str) {\n        // Always track the container type for proper filtering and module type detection\n        self.container_stack.push(container_type.to_string());\n\n        // Only add to path if it's not an enum variant (when skip_enum_variants is enabled)\n        // This ensures paths are built without enum variant names from the start\n        if !self.skip_enum_variants || !container_type.starts_with(\"Enum:\") {\n            self.path_stack.push(name.to_string());\n        }\n    }\n\n    fn exit_module(&mut self, _name: &str, container_type: &str) {\n        self.container_stack.pop();\n\n        // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)\n        if !self.skip_enum_variants || !container_type.starts_with(\"Enum:\") {\n            self.path_stack.pop();\n        }\n    }\n\n    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n        if self.should_collect(&self.path_stack, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_float(\n                &param.transform_for_save().val(),\n                self.path_stack.clone(),\n                self.container_stack.clone(),\n                param.id,\n            ));\n        }\n    }\n\n    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {\n        if self.should_collect(&self.path_stack, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_int(\n                &param.transform_for_save().val(),\n                self.path_stack.clone(),\n                self.container_stack.clone(),\n                param.id,\n            ));\n        }\n    }\n\n    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {\n        if self.should_collect(&self.path_stack, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_bool(\n                &param.transform_for_save().val(),\n                self.path_stack.clone(),\n                self.container_stack.clone(),\n                param.id,\n            ));\n        }\n    }\n\n    fn visit_float_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D>,\n    ) {\n        // For path-based visits, we use the current container stack for filtering\n        if self.should_collect(path, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_float(\n                tensor,\n                path.to_vec(),\n                self.container_stack.clone(),\n                id,\n            ));\n        }\n    }\n\n    fn visit_int_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D, Int>,\n    ) {\n        if self.should_collect(path, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_int(\n                tensor,\n                path.to_vec(),\n                self.container_stack.clone(),\n                id,\n            ));\n        }\n    }\n\n    fn visit_bool_with_path<const D: usize>(\n        &mut self,\n        path: &[String],\n        id: ParamId,\n        tensor: &Tensor<B, D, Bool>,\n    ) {\n        if self.should_collect(path, &self.container_stack) {\n            self.tensors.push(TensorSnapshot::from_bool(\n                tensor,\n                path.to_vec(),\n                self.container_stack.clone(),\n                id,\n            ));\n        }\n    }\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use super::*;\n\n    use burn_core as burn;\n\n    type TestBackend = burn_ndarray::NdArray;\n    use alloc::collections::BTreeMap;\n    use alloc::string::String;\n    use burn_core::module::{Module, Param};\n    use burn_nn::LinearConfig;\n    use burn_tensor::shape;\n\n    #[test]\n    fn tensor_snapshot_collector() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        let mut collector = Collector::new(None, None, false);\n        let id = ParamId::new();\n\n        // Collect a tensor\n        collector.visit_float_with_path(&[\"model\".to_string(), \"weight\".to_string()], id, &tensor);\n\n        assert_eq!(collector.tensors.len(), 1);\n        assert_eq!(collector.tensors[0].full_path(), \"model.weight\");\n\n        // Verify the tensor can be converted to data\n        let view = &collector.tensors[0];\n        let data = view.to_data().unwrap();\n        assert_eq!(data.shape, shape![2, 2]);\n    }\n\n    #[test]\n    fn root_level_parameters() {\n        use burn_core::module::ModuleVisitor;\n\n        let device = Default::default();\n\n        // Create root-level parameters (single-element path, not nested in modules)\n        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);\n\n        let mut collector = Collector::new(None, None, false);\n\n        // Simulate module traversal for root-level parameters\n        // Enter \"weight\" path (as if we're visiting a field named \"weight\")\n        ModuleVisitor::<TestBackend>::enter_module(&mut collector, \"weight\", \"\");\n        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &weight);\n        ModuleVisitor::<TestBackend>::exit_module(&mut collector, \"weight\", \"\");\n\n        // Enter \"bias\" path (as if we're visiting a field named \"bias\")\n        ModuleVisitor::<TestBackend>::enter_module(&mut collector, \"bias\", \"\");\n        ModuleVisitor::<TestBackend>::visit_float(&mut collector, &bias);\n        ModuleVisitor::<TestBackend>::exit_module(&mut collector, \"bias\", \"\");\n\n        // Verify both parameters were collected\n        assert_eq!(collector.tensors.len(), 2);\n\n        // Verify paths are correct (single-element paths)\n        assert_eq!(collector.tensors[0].full_path(), \"weight\");\n        assert_eq!(collector.tensors[1].full_path(), \"bias\");\n\n        // Verify data is correct\n        let weight_data = collector.tensors[0]\n            .to_data()\n            .unwrap()\n            .to_vec::<f32>()\n            .unwrap();\n        let bias_data = collector.tensors[1]\n            .to_data()\n            .unwrap()\n            .to_vec::<f32>()\n            .unwrap();\n\n        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);\n        assert_eq!(bias_data, vec![5.0, 6.0]);\n    }\n\n    #[test]\n    #[cfg(target_has_atomic = \"ptr\")]\n    fn tensor_snapshot_collector_with_filter() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        let filter = PathFilter::new().with_regex(r\"^encoder\\..*\");\n        let mut collector = Collector::new(Some(filter), None, false);\n        let id = ParamId::new();\n\n        // This should be collected\n        collector.visit_float_with_path(\n            &[\"encoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        );\n        // This should NOT be collected\n        collector.visit_float_with_path(\n            &[\"decoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        );\n\n        assert_eq!(collector.tensors.len(), 1);\n        assert_eq!(collector.tensors[0].full_path(), \"encoder.weight\");\n    }\n\n    #[test]\n    #[cfg(target_has_atomic = \"ptr\")]\n    fn tensor_snapshot_collector_with_multiple_filters() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        // Multiple patterns - collect if matches ANY (OR union)\n        let filter = PathFilter::new()\n            .with_regex(r\"^encoder\\..*\") // Match encoder.*\n            .with_regex(r\".*\\.bias$\"); // Match *.bias\n        let mut collector = Collector::new(Some(filter), None, false);\n        let id = ParamId::new();\n\n        // These should be collected\n        collector.visit_float_with_path(\n            &[\"encoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        ); // matches first pattern\n        collector.visit_float_with_path(&[\"decoder\".to_string(), \"bias\".to_string()], id, &tensor); // matches second pattern\n        collector.visit_float_with_path(&[\"encoder\".to_string(), \"bias\".to_string()], id, &tensor); // matches both patterns\n\n        // This should NOT be collected\n        collector.visit_float_with_path(\n            &[\"decoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        ); // matches neither\n\n        assert_eq!(collector.tensors.len(), 3);\n        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();\n        assert!(paths.contains(&\"encoder.weight\".to_string()));\n        assert!(paths.contains(&\"decoder.bias\".to_string()));\n        assert!(paths.contains(&\"encoder.bias\".to_string()));\n        assert!(!paths.contains(&\"decoder.weight\".to_string()));\n    }\n\n    #[test]\n    fn tensor_snapshot_collector_with_predicate() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        // Use predicate function for filtering\n        fn filter_fn(path: &str, _container_path: &str) -> bool {\n            path.starts_with(\"encoder.\") || path == \"decoder.bias\"\n        }\n        let filter = PathFilter::new().with_predicate(filter_fn);\n        let mut collector = Collector::new(Some(filter), None, false);\n        let id = ParamId::new();\n\n        // These should be collected\n        collector.visit_float_with_path(\n            &[\"encoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        );\n        collector.visit_float_with_path(&[\"encoder\".to_string(), \"bias\".to_string()], id, &tensor);\n        collector.visit_float_with_path(&[\"decoder\".to_string(), \"bias\".to_string()], id, &tensor);\n\n        // This should NOT be collected\n        collector.visit_float_with_path(\n            &[\"decoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        );\n        collector.visit_float_with_path(&[\"other\".to_string(), \"tensor\".to_string()], id, &tensor);\n\n        assert_eq!(collector.tensors.len(), 3);\n        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();\n        assert!(paths.contains(&\"encoder.weight\".to_string()));\n        assert!(paths.contains(&\"encoder.bias\".to_string()));\n        assert!(paths.contains(&\"decoder.bias\".to_string()));\n        assert!(!paths.contains(&\"decoder.weight\".to_string()));\n        assert!(!paths.contains(&\"other.tensor\".to_string()));\n    }\n\n    #[test]\n    fn tensor_snapshot_collector_predicate_with_complex_logic() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        // Complex predicate with multiple conditions\n        fn complex_filter(path: &str, _container_path: &str) -> bool {\n            let parts: Vec<&str> = path.split('.').collect();\n            if parts.len() != 3 {\n                return false;\n            }\n            // Only collect if it's layer1 or layer2, and it's a weight tensor\n            (parts[1] == \"layer1\" || parts[1] == \"layer2\") && parts[2] == \"weight\"\n        }\n        let filter = PathFilter::new().with_predicate(complex_filter);\n        let mut collector = Collector::new(Some(filter), None, false);\n        let id = ParamId::new();\n\n        // These should be collected\n        collector.visit_float_with_path(\n            &[\n                \"model\".to_string(),\n                \"layer1\".to_string(),\n                \"weight\".to_string(),\n            ],\n            id,\n            &tensor,\n        );\n        collector.visit_float_with_path(\n            &[\n                \"model\".to_string(),\n                \"layer2\".to_string(),\n                \"weight\".to_string(),\n            ],\n            id,\n            &tensor,\n        );\n\n        // These should NOT be collected\n        collector.visit_float_with_path(\n            &[\n                \"model\".to_string(),\n                \"layer1\".to_string(),\n                \"bias\".to_string(),\n            ],\n            id,\n            &tensor,\n        );\n        collector.visit_float_with_path(\n            &[\n                \"model\".to_string(),\n                \"layer3\".to_string(),\n                \"weight\".to_string(),\n            ],\n            id,\n            &tensor,\n        );\n        collector.visit_float_with_path(\n            &[\"encoder\".to_string(), \"weight\".to_string()],\n            id,\n            &tensor,\n        ); // wrong structure\n\n        assert_eq!(collector.tensors.len(), 2);\n        let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();\n        assert!(paths.contains(&\"model.layer1.weight\".to_string()));\n        assert!(paths.contains(&\"model.layer2.weight\".to_string()));\n        assert!(!paths.contains(&\"model.layer1.bias\".to_string()));\n        assert!(!paths.contains(&\"model.layer3.weight\".to_string()));\n        assert!(!paths.contains(&\"encoder.weight\".to_string()));\n    }\n\n    // Test visitor that collects tensor paths\n    struct TensorPathCollector {\n        pub paths: BTreeMap<String, (ParamId, Vec<usize>)>,\n        path_stack: Vec<String>,\n    }\n\n    impl TensorPathCollector {\n        fn new() -> Self {\n            Self {\n                paths: BTreeMap::new(),\n                path_stack: Vec::new(),\n            }\n        }\n\n        fn current_path(&self) -> String {\n            self.path_stack.join(\".\")\n        }\n    }\n\n    impl<B: Backend> ModuleVisitor<B> for TensorPathCollector {\n        fn enter_module(&mut self, name: &str, _container_type: &str) {\n            self.path_stack.push(name.to_string());\n        }\n\n        fn exit_module(&mut self, _name: &str, _container_type: &str) {\n            self.path_stack.pop();\n        }\n\n        fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {\n            let path = self.current_path();\n            if !path.is_empty() {\n                self.paths.insert(\n                    path,\n                    (param.id, param.transform_for_save().val().shape().to_vec()),\n                );\n            }\n        }\n\n        fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {\n            let path = self.current_path();\n            if !path.is_empty() {\n                self.paths.insert(\n                    path,\n                    (param.id, param.transform_for_save().val().shape().to_vec()),\n                );\n            }\n        }\n\n        fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {\n            let path = self.current_path();\n            if !path.is_empty() {\n                self.paths.insert(\n                    path,\n                    (param.id, param.transform_for_save().val().shape().to_vec()),\n                );\n            }\n        }\n    }\n\n    // Simple nested module for testing\n    #[derive(Module, Debug)]\n    struct InnerModule<B: Backend> {\n        weight: Param<Tensor<B, 2>>,\n        bias: Param<Tensor<B, 1>>,\n    }\n\n    #[derive(Module, Debug)]\n    struct OuterModule<B: Backend> {\n        layer1: InnerModule<B>,\n        layer2: InnerModule<B>,\n    }\n\n    impl<B: Backend> InnerModule<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n                bias: Param::from_data([5.0, 6.0], device),\n            }\n        }\n    }\n\n    impl<B: Backend> OuterModule<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                layer1: InnerModule::new(device),\n                layer2: InnerModule::new(device),\n            }\n        }\n    }\n\n    #[test]\n    fn nested_module_path_tracking() {\n        let device = Default::default();\n        let module = OuterModule::<TestBackend>::new(&device);\n\n        let mut collector = TensorPathCollector::new();\n        module.visit(&mut collector);\n\n        let paths = collector.paths;\n\n        // Verify we have the expected paths\n        // Note: Param<Tensor> fields are themselves modules, so we get an extra level\n        assert!(paths.contains_key(\"layer1.weight\"), \"Missing layer1.weight\");\n        assert!(paths.contains_key(\"layer1.bias\"), \"Missing layer1.bias\");\n        assert!(paths.contains_key(\"layer2.weight\"), \"Missing layer2.weight\");\n        assert!(paths.contains_key(\"layer2.bias\"), \"Missing layer2.bias\");\n\n        // Verify the shapes are correct\n        assert_eq!(paths.get(\"layer1.weight\").unwrap().1, vec![2, 2]);\n        assert_eq!(paths.get(\"layer1.bias\").unwrap().1, vec![2]);\n        assert_eq!(paths.get(\"layer2.weight\").unwrap().1, vec![2, 2]);\n        assert_eq!(paths.get(\"layer2.bias\").unwrap().1, vec![2]);\n    }\n\n    #[test]\n    fn linear_module_paths() {\n        let device = Default::default();\n        let config = LinearConfig::new(10, 20).with_bias(true);\n        let linear = config.init::<TestBackend>(&device);\n\n        let mut collector = TensorPathCollector::new();\n        linear.visit(&mut collector);\n\n        let paths = collector.paths;\n\n        // Linear module has weight and optional bias\n        assert!(paths.contains_key(\"weight\"));\n        assert!(paths.contains_key(\"bias\"));\n\n        // Check dimensions\n        assert_eq!(paths.get(\"weight\").unwrap().1, vec![10, 20]);\n        assert_eq!(paths.get(\"bias\").unwrap().1, vec![20]);\n    }\n\n    // Deep nesting test structures (4+ levels)\n    #[derive(Module, Debug)]\n    struct Level4Module<B: Backend> {\n        weight: Param<Tensor<B, 2>>,\n        bias: Param<Tensor<B, 1>>,\n    }\n\n    #[derive(Module, Debug)]\n    struct Level3Module<B: Backend> {\n        layer: Level4Module<B>,\n        extra: Level4Module<B>,\n    }\n\n    #[derive(Module, Debug)]\n    struct Level2Module<B: Backend> {\n        block1: Level3Module<B>,\n        block2: Level3Module<B>,\n    }\n\n    #[derive(Module, Debug)]\n    struct Level1Module<B: Backend> {\n        encoder: Level2Module<B>,\n        decoder: Level2Module<B>,\n    }\n\n    #[derive(Module, Debug)]\n    struct DeepModel<B: Backend> {\n        backbone: Level1Module<B>,\n        head: Level4Module<B>,\n    }\n\n    impl<B: Backend> Level4Module<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n                bias: Param::from_data([5.0, 6.0], device),\n            }\n        }\n    }\n\n    impl<B: Backend> Level3Module<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                layer: Level4Module::new(device),\n                extra: Level4Module::new(device),\n            }\n        }\n    }\n\n    impl<B: Backend> Level2Module<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                block1: Level3Module::new(device),\n                block2: Level3Module::new(device),\n            }\n        }\n    }\n\n    impl<B: Backend> Level1Module<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                encoder: Level2Module::new(device),\n                decoder: Level2Module::new(device),\n            }\n        }\n    }\n\n    impl<B: Backend> DeepModel<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                backbone: Level1Module::new(device),\n                head: Level4Module::new(device),\n            }\n        }\n    }\n\n    #[test]\n    fn deep_module_path_tracking() {\n        let device = Default::default();\n        let model = DeepModel::<TestBackend>::new(&device);\n\n        let mut collector = Collector::new(None, None, false);\n        model.visit(&mut collector);\n\n        let views = collector.tensors;\n        let paths: Vec<String> = views.iter().map(|v| v.full_path()).collect();\n\n        // Test 5-level deep paths\n        assert!(paths.contains(&\"backbone.encoder.block1.layer.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block1.layer.bias\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block1.extra.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block1.extra.bias\".to_string()));\n\n        assert!(paths.contains(&\"backbone.encoder.block2.layer.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block2.layer.bias\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block2.extra.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.encoder.block2.extra.bias\".to_string()));\n\n        assert!(paths.contains(&\"backbone.decoder.block1.layer.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block1.layer.bias\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block1.extra.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block1.extra.bias\".to_string()));\n\n        assert!(paths.contains(&\"backbone.decoder.block2.layer.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block2.layer.bias\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block2.extra.weight\".to_string()));\n        assert!(paths.contains(&\"backbone.decoder.block2.extra.bias\".to_string()));\n\n        // Test 2-level paths\n        assert!(paths.contains(&\"head.weight\".to_string()));\n        assert!(paths.contains(&\"head.bias\".to_string()));\n\n        // Total should be 18 tensors (16 from backbone + 2 from head)\n        assert_eq!(views.len(), 18);\n\n        // Verify data can be materialized\n        let view = views\n            .iter()\n            .find(|v| v.full_path() == \"backbone.encoder.block1.layer.weight\")\n            .unwrap();\n        let data = view.to_data().unwrap();\n        assert_eq!(data.shape, shape![2, 2]);\n    }\n\n    #[test]\n    fn deep_module_filtered_export() {\n        let device = Default::default();\n        let model = DeepModel::<TestBackend>::new(&device);\n\n        // Test filtering at different depths\n        #[cfg(target_has_atomic = \"ptr\")]\n        {\n            let filter = PathFilter::new().with_regex(r\"^backbone\\.encoder\\..*\");\n            let mut collector = Collector::new(Some(filter), None, false);\n            model.visit(&mut collector);\n            assert_eq!(collector.tensors.len(), 8); // Only encoder tensors\n        }\n\n        // Test filtering specific blocks\n        #[cfg(target_has_atomic = \"ptr\")]\n        {\n            let filter = PathFilter::new().with_regex(r\".*\\.block1\\..*\");\n            let mut collector = Collector::new(Some(filter), None, false);\n            model.visit(&mut collector);\n            assert_eq!(collector.tensors.len(), 8); // block1 in both encoder and decoder\n        }\n\n        // Test filtering by tensor type at any depth\n        #[cfg(target_has_atomic = \"ptr\")]\n        {\n            let filter = PathFilter::new().with_regex(r\".*\\.weight$\");\n            let mut collector = Collector::new(Some(filter), None, false);\n            model.visit(&mut collector);\n            assert_eq!(collector.tensors.len(), 9); // All weight tensors\n        }\n\n        // Test complex multi-pattern filtering\n        #[cfg(target_has_atomic = \"ptr\")]\n        {\n            let filter = PathFilter::new()\n                .with_regex(r\"^backbone\\.encoder\\.block1\\..*\") // All encoder.block1 tensors\n                .with_regex(r\"^backbone\\.decoder\\..*\\.bias$\") // All decoder biases\n                .with_regex(r\"^head\\.weight$\"); // Head weight only\n            let mut collector = Collector::new(Some(filter), None, false);\n            model.visit(&mut collector);\n\n            // Should have:\n            // - 4 from encoder.block1 (2 weights + 2 biases)\n            // - 4 decoder biases\n            // - 1 head weight\n            assert_eq!(collector.tensors.len(), 9);\n\n            let paths: Vec<String> = collector.tensors.iter().map(|v| v.full_path()).collect();\n            assert!(paths.contains(&\"backbone.encoder.block1.layer.weight\".to_string()));\n            assert!(paths.contains(&\"backbone.decoder.block1.layer.bias\".to_string()));\n            assert!(paths.contains(&\"head.weight\".to_string()));\n            assert!(!paths.contains(&\"head.bias\".to_string())); // Not included\n        }\n    }\n\n    use crate::traits::ModuleSnapshot;\n    use burn_nn::Linear;\n    use hashbrown::HashMap;\n\n    // Test module with Option fields\n    #[derive(Module, Debug)]\n    struct OptionalFieldModule<B: Backend> {\n        required: Param<Tensor<B, 2>>,\n        optional: Option<Param<Tensor<B, 1>>>,\n    }\n\n    impl<B: Backend> OptionalFieldModule<B> {\n        fn new_with_optional(device: &B::Device) -> Self {\n            Self {\n                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n                optional: Some(Param::from_data([5.0, 6.0], device)),\n            }\n        }\n\n        fn new_without_optional(device: &B::Device) -> Self {\n            Self {\n                required: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n                optional: None,\n            }\n        }\n    }\n\n    #[test]\n    fn optional_field_module_with_value() {\n        let device = Default::default();\n        let module = OptionalFieldModule::<TestBackend>::new_with_optional(&device);\n\n        let views: HashMap<String, TensorSnapshot> = module\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        assert_eq!(views.len(), 2);\n        assert!(views.contains_key(\"required\"));\n        assert!(views.contains_key(\"optional\"));\n    }\n\n    #[test]\n    fn optional_field_module_without_value() {\n        let device = Default::default();\n        let module = OptionalFieldModule::<TestBackend>::new_without_optional(&device);\n\n        let views: HashMap<String, TensorSnapshot> = module\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        assert_eq!(views.len(), 1);\n        assert!(views.contains_key(\"required\"));\n        assert!(!views.contains_key(\"optional\"));\n    }\n\n    // Test Vec of modules\n    #[derive(Module, Debug)]\n    struct VecModule<B: Backend> {\n        layers: Vec<Linear<B>>,\n    }\n\n    impl<B: Backend> VecModule<B> {\n        fn new(device: &B::Device, num_layers: usize) -> Self {\n            Self {\n                layers: (0..num_layers)\n                    .map(|_| LinearConfig::new(10, 10).init(device))\n                    .collect(),\n            }\n        }\n    }\n\n    // Test tuple of modules\n    #[derive(Module, Debug)]\n    struct TupleModule<B: Backend> {\n        layers: (Linear<B>, Linear<B>, Linear<B>),\n    }\n\n    impl<B: Backend> TupleModule<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                layers: (\n                    LinearConfig::new(10, 10).init(device),\n                    LinearConfig::new(10, 10).init(device),\n                    LinearConfig::new(10, 10).init(device),\n                ),\n            }\n        }\n    }\n\n    #[test]\n    fn vec_module_collect() {\n        let device = Default::default();\n        let module = VecModule::<TestBackend>::new(&device, 3);\n\n        let views: HashMap<String, TensorSnapshot> = module\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        // With the fix, all Vec items should now be properly indexed and visited\n        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors\n\n        // Check that all indexed paths exist\n        assert!(views.contains_key(\"layers.0.weight\"));\n        assert!(views.contains_key(\"layers.0.bias\"));\n        assert!(views.contains_key(\"layers.1.weight\"));\n        assert!(views.contains_key(\"layers.1.bias\"));\n        assert!(views.contains_key(\"layers.2.weight\"));\n        assert!(views.contains_key(\"layers.2.bias\"));\n    }\n\n    #[test]\n    fn tuple_module_collect() {\n        let device = Default::default();\n        let module = TupleModule::<TestBackend>::new(&device);\n\n        let snapshots = module.collect(None, None, false);\n        assert_eq!(snapshots.len(), 6);\n\n        let views: HashMap<String, TensorSnapshot> =\n            snapshots.into_iter().map(|v| (v.full_path(), v)).collect();\n\n        assert_eq!(views.len(), 6);\n\n        assert!(views.contains_key(\"layers.0.weight\"));\n        assert!(views.contains_key(\"layers.0.bias\"));\n        assert!(views.contains_key(\"layers.1.weight\"));\n        assert!(views.contains_key(\"layers.1.bias\"));\n        assert!(views.contains_key(\"layers.2.weight\"));\n        assert!(views.contains_key(\"layers.2.bias\"));\n    }\n\n    // Test array of modules\n    #[derive(Module, Debug)]\n    struct ArrayModule<B: Backend> {\n        layers: [Linear<B>; 3],\n    }\n\n    impl<B: Backend> ArrayModule<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                layers: [\n                    LinearConfig::new(10, 10).init(device),\n                    LinearConfig::new(10, 10).init(device),\n                    LinearConfig::new(10, 10).init(device),\n                ],\n            }\n        }\n    }\n\n    #[test]\n    fn array_module_collect() {\n        let device = Default::default();\n        let module = ArrayModule::<TestBackend>::new(&device);\n\n        let views: HashMap<String, TensorSnapshot> = module\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        // All array items should be properly indexed\n        assert_eq!(views.len(), 6); // 3 layers × 2 tensors each = 6 tensors\n\n        // Check indexed paths\n        for i in 0..3 {\n            assert!(views.contains_key(&format!(\"layers.{}.weight\", i)));\n            assert!(views.contains_key(&format!(\"layers.{}.bias\", i)));\n        }\n    }\n\n    // Test enum modules\n    #[derive(Module, Debug)]\n    enum EnumModule<B: Backend> {\n        LayerA(Linear<B>),\n        LayerB(Linear<B>),\n        LayerC(Linear<B>),\n    }\n\n    #[test]\n    fn enum_module_collect() {\n        let device = Default::default();\n\n        // Test variant A\n        let module_a = EnumModule::<TestBackend>::LayerA(LinearConfig::new(10, 20).init(&device));\n        let views_a: HashMap<String, TensorSnapshot> = module_a\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        // Should have the variant name in the path\n        assert_eq!(views_a.len(), 2);\n        assert!(views_a.contains_key(\"LayerA.weight\"));\n        assert!(views_a.contains_key(\"LayerA.bias\"));\n\n        // Test variant B\n        let module_b = EnumModule::<TestBackend>::LayerB(LinearConfig::new(10, 20).init(&device));\n        let views_b: HashMap<String, TensorSnapshot> = module_b\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        assert_eq!(views_b.len(), 2);\n        assert!(views_b.contains_key(\"LayerB.weight\"));\n        assert!(views_b.contains_key(\"LayerB.bias\"));\n    }\n\n    // Container type tracking tests\n    #[test]\n    fn linear_container_type() {\n        let device = Default::default();\n\n        #[derive(Module, Debug)]\n        struct ModelWithLinear<B: Backend> {\n            linear: Linear<B>,\n        }\n\n        impl<B: Backend> ModelWithLinear<B> {\n            fn new(device: &B::Device) -> Self {\n                Self {\n                    linear: LinearConfig::new(10, 20).init(device),\n                }\n            }\n        }\n\n        let model = ModelWithLinear::<TestBackend>::new(&device);\n\n        let views: HashMap<String, TensorSnapshot> = model\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        // Check that tensors inside Linear layers have \"Struct:Linear\" as their module type\n        for (path, view) in views.iter() {\n            if path == \"linear.weight\" || path == \"linear.bias\" {\n                assert_eq!(\n                    view.module_type(),\n                    Some(\"Struct:Linear\".to_string()),\n                    \"Tensor '{}' should have module type 'Struct:Linear'\",\n                    path\n                );\n            }\n        }\n    }\n\n    #[test]\n    fn complex_model_container_types() {\n        let device = Default::default();\n\n        #[derive(Module, Debug)]\n        struct ComplexModel<B: Backend> {\n            linear_layers: [Linear<B>; 2],\n            vec_layers: Vec<Linear<B>>,\n            single_linear: Linear<B>,\n        }\n\n        impl<B: Backend> ComplexModel<B> {\n            fn new(device: &B::Device) -> Self {\n                Self {\n                    linear_layers: [\n                        LinearConfig::new(100, 50).init(device),\n                        LinearConfig::new(50, 10).init(device),\n                    ],\n                    vec_layers: vec![\n                        LinearConfig::new(10, 10).init(device),\n                        LinearConfig::new(10, 10).init(device),\n                    ],\n                    single_linear: LinearConfig::new(10, 1).init(device),\n                }\n            }\n        }\n\n        let model = ComplexModel::<TestBackend>::new(&device);\n\n        let views: HashMap<String, TensorSnapshot> = model\n            .collect(None, None, false)\n            .into_iter()\n            .map(|v| (v.full_path(), v))\n            .collect();\n\n        // Should have 10 tensors total\n        assert_eq!(views.len(), 10);\n\n        // Verify different module types\n        for (_path, view) in views.iter() {\n            assert_eq!(view.module_type(), Some(\"Struct:Linear\".to_string()));\n        }\n    }\n\n    #[test]\n    fn collect_with_container_filter() {\n        let device = Default::default();\n\n        #[derive(Module, Debug)]\n        struct FilterTestModel<B: Backend> {\n            layers: Vec<Linear<B>>,\n        }\n\n        impl<B: Backend> FilterTestModel<B> {\n            fn new(device: &B::Device) -> Self {\n                Self {\n                    layers: vec![\n                        LinearConfig::new(10, 10).init(device),\n                        LinearConfig::new(10, 10).init(device),\n                    ],\n                }\n            }\n        }\n\n        let model = FilterTestModel::<TestBackend>::new(&device);\n\n        // Filter to only collect tensors from Linear modules\n        let filter = PathFilter::new().with_predicate(|_path, container_path| {\n            container_path.split('.').next_back() == Some(\"Struct:Linear\")\n        });\n\n        let linear_views: Vec<TensorSnapshot> = model.collect(Some(filter), None, false);\n\n        // All collected tensors should be from Linear modules\n        for view in linear_views.iter() {\n            assert_eq!(\n                view.module_type(),\n                Some(\"Struct:Linear\".to_string()),\n                \"All tensors should be from Linear modules\"\n            );\n        }\n\n        // Should have collected all Linear tensors\n        assert_eq!(linear_views.len(), 4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/filter.rs",
    "content": "use alloc::format;\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse core::fmt;\n\n#[cfg(feature = \"std\")]\nuse regex::Regex;\n\n/// A sophisticated path filter that supports multiple matching strategies.\n///\n/// The filter uses an OR logic - a path is included if it matches ANY of the configured criteria.\n/// This allows for flexible and powerful filtering configurations.\n///\n/// # Examples\n///\n/// ```rust,no_run\n/// # use burn_store::PathFilter;\n/// // Create a filter that matches encoder paths or any weight path\n/// let filter = PathFilter::new()\n///     .with_regex(r\"^encoder\\..*\")\n///     .with_regex(r\".*\\.weight$\")\n///     .with_full_path(\"special_tensor\");\n///\n/// // Check if a path should be included\n/// if filter.matches(\"encoder.layer1.weight\") {\n///     // This will match due to both regex patterns\n/// }\n/// ```\n#[derive(Debug, Clone, Default)]\npub struct PathFilter {\n    /// Compiled regex patterns for matching paths\n    #[cfg(feature = \"std\")]\n    regex_patterns: Vec<Regex>,\n\n    /// Exact full paths to match\n    exact_paths: Vec<String>,\n\n    /// Predicate functions for custom matching logic based on path and container path\n    /// Note: These cannot be cloned, so we store them separately\n    predicates: Vec<fn(&str, &str) -> bool>,\n\n    /// If true, matches all paths (overrides other filters)\n    match_all: bool,\n}\n\nimpl PathFilter {\n    /// Create a new empty filter (matches nothing by default)\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Create a filter that matches all paths\n    pub fn all() -> Self {\n        Self {\n            match_all: true,\n            ..Default::default()\n        }\n    }\n\n    /// Create a filter that matches nothing\n    pub fn none() -> Self {\n        Self::default()\n    }\n\n    /// Add a regex pattern for matching paths\n    #[cfg(feature = \"std\")]\n    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {\n        if let Ok(regex) = Regex::new(pattern.as_ref()) {\n            self.regex_patterns.push(regex);\n        }\n        // TODO: Consider returning Result to handle regex compilation errors\n        self\n    }\n\n    /// Add multiple regex patterns\n    #[cfg(feature = \"std\")]\n    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: AsRef<str>,\n    {\n        for pattern in patterns {\n            if let Ok(regex) = Regex::new(pattern.as_ref()) {\n                self.regex_patterns.push(regex);\n            }\n        }\n        self\n    }\n\n    /// Add an exact full path to match\n    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {\n        self.exact_paths.push(path.into());\n        self\n    }\n\n    /// Add multiple exact full paths\n    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: Into<String>,\n    {\n        self.exact_paths.extend(paths.into_iter().map(|p| p.into()));\n        self\n    }\n\n    /// Add a predicate function for custom matching based on path and container path\n    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {\n        self.predicates.push(predicate);\n        self\n    }\n\n    /// Add multiple predicates\n    pub fn with_predicates<I>(mut self, predicates: I) -> Self\n    where\n        I: IntoIterator<Item = fn(&str, &str) -> bool>,\n    {\n        self.predicates.extend(predicates);\n        self\n    }\n\n    /// Set to match all paths\n    pub fn match_all(mut self) -> Self {\n        self.match_all = true;\n        self\n    }\n\n    /// Check if a path matches this filter (assumes empty container path for backward compatibility)\n    pub fn matches(&self, path: &str) -> bool {\n        self.matches_with_container_path_str(path, \"\")\n    }\n\n    /// Check if a path and container type match this filter (for backward compatibility)\n    pub fn matches_with_container(&self, path: &str, container_type: &str) -> bool {\n        // For backward compatibility, treat single container type as the full path\n        self.matches_with_container_path_str(path, container_type)\n    }\n\n    /// Check if a path and container path match this filter\n    pub fn matches_with_container_path(&self, path: &[String], container_stack: &[String]) -> bool {\n        let path_str = path.join(\".\");\n        let container_path = container_stack.join(\".\");\n        self.matches_with_container_path_str(&path_str, &container_path)\n    }\n\n    /// Check if a path and container path (dot-notated strings) match this filter\n    pub fn matches_with_container_path_str(&self, path: &str, container_path: &str) -> bool {\n        // If match_all is set, always return true\n        if self.match_all {\n            return true;\n        }\n\n        // If no filters are configured, match nothing\n        if self.is_empty() {\n            return false;\n        }\n\n        // Check exact path matches\n        if self.exact_paths.iter().any(|p| p == path) {\n            return true;\n        }\n\n        // Check regex patterns (on the path)\n        #[cfg(feature = \"std\")]\n        {\n            for regex in &self.regex_patterns {\n                if regex.is_match(path) {\n                    return true;\n                }\n            }\n        }\n\n        // Check predicates with container path\n        if self\n            .predicates\n            .iter()\n            .any(|pred| pred(path, container_path))\n        {\n            return true;\n        }\n\n        false\n    }\n\n    /// Check if the filter is empty (matches nothing)\n    pub fn is_empty(&self) -> bool {\n        if self.match_all {\n            return false;\n        }\n\n        #[cfg(feature = \"std\")]\n        let regex_empty = self.regex_patterns.is_empty();\n        #[cfg(not(feature = \"std\"))]\n        let regex_empty = true;\n\n        self.exact_paths.is_empty() && self.predicates.is_empty() && regex_empty\n    }\n\n    /// Get the number of filter criteria configured\n    pub fn criteria_count(&self) -> usize {\n        if self.match_all {\n            return 1;\n        }\n\n        #[allow(unused_mut)]\n        let mut count = self.exact_paths.len() + self.predicates.len();\n\n        #[cfg(feature = \"std\")]\n        {\n            count += self.regex_patterns.len();\n        }\n\n        count\n    }\n\n    /// Clear all regex patterns\n    #[cfg(feature = \"std\")]\n    pub fn clear_regex(&mut self) -> &mut Self {\n        self.regex_patterns.clear();\n        self\n    }\n\n    /// Clear all exact paths\n    pub fn clear_paths(&mut self) -> &mut Self {\n        self.exact_paths.clear();\n        self\n    }\n\n    /// Clear all predicates\n    pub fn clear_predicates(&mut self) -> &mut Self {\n        self.predicates.clear();\n        self\n    }\n\n    /// Clear all filters\n    pub fn clear(&mut self) -> &mut Self {\n        #[cfg(feature = \"std\")]\n        self.clear_regex();\n\n        self.clear_paths().clear_predicates();\n        self.match_all = false;\n        self\n    }\n\n    /// Create a filter from regex patterns only\n    #[cfg(feature = \"std\")]\n    pub fn from_regex_patterns<I, S>(patterns: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: AsRef<str>,\n    {\n        Self::new().with_regexes(patterns)\n    }\n\n    /// Create a filter from exact paths only\n    pub fn from_paths<I, S>(paths: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: Into<String>,\n    {\n        Self::new().with_full_paths(paths)\n    }\n\n    /// Create a filter from a single predicate\n    pub fn from_predicate(predicate: fn(&str, &str) -> bool) -> Self {\n        Self::new().with_predicate(predicate)\n    }\n\n    /// Combine with another filter using OR logic\n    pub fn or(mut self, other: Self) -> Self {\n        if self.match_all || other.match_all {\n            return Self::all();\n        }\n\n        #[cfg(feature = \"std\")]\n        {\n            self.regex_patterns.extend(other.regex_patterns);\n        }\n\n        self.exact_paths.extend(other.exact_paths);\n        self.predicates.extend(other.predicates);\n\n        self\n    }\n}\n\nimpl fmt::Display for PathFilter {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        if self.match_all {\n            return write!(f, \"PathFilter::all()\");\n        }\n\n        if self.is_empty() {\n            return write!(f, \"PathFilter::none()\");\n        }\n\n        write!(f, \"PathFilter[\")?;\n\n        let mut parts = Vec::new();\n\n        #[cfg(feature = \"std\")]\n        if !self.regex_patterns.is_empty() {\n            parts.push(format!(\"regex: {:?}\", self.regex_patterns));\n        }\n\n        if !self.exact_paths.is_empty() {\n            parts.push(format!(\"paths: {:?}\", self.exact_paths));\n        }\n\n        if !self.predicates.is_empty() {\n            parts.push(format!(\"predicates: {}\", self.predicates.len()));\n        }\n\n        write!(f, \"{}]\", parts.join(\", \"))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn empty_filter() {\n        let filter = PathFilter::new();\n        assert!(filter.is_empty());\n        assert!(!filter.matches(\"encoder.weight\"));\n        assert!(!filter.matches(\"decoder.bias\"));\n    }\n\n    #[test]\n    fn match_all() {\n        let filter = PathFilter::all();\n        assert!(!filter.is_empty());\n        assert!(filter.matches(\"encoder.weight\"));\n        assert!(filter.matches(\"decoder.bias\"));\n        assert!(filter.matches(\"anything\"));\n    }\n\n    #[test]\n    fn exact_paths() {\n        let filter = PathFilter::new()\n            .with_full_path(\"encoder.weight\")\n            .with_full_path(\"decoder.bias\");\n\n        assert!(filter.matches(\"encoder.weight\"));\n        assert!(filter.matches(\"decoder.bias\"));\n        assert!(!filter.matches(\"encoder.bias\"));\n        assert!(!filter.matches(\"decoder.weight\"));\n    }\n\n    #[test]\n    #[cfg(feature = \"std\")]\n    fn regex_patterns() {\n        let filter = PathFilter::new()\n            .with_regex(r\"^encoder\\..*\")\n            .with_regex(r\".*\\.weight$\");\n\n        assert!(filter.matches(\"encoder.layer1.bias\"));\n        assert!(filter.matches(\"decoder.weight\"));\n        assert!(filter.matches(\"encoder.weight\"));\n        assert!(!filter.matches(\"decoder.bias\"));\n    }\n\n    #[test]\n    fn predicates() {\n        fn contains_norm(path: &str, _container_path: &str) -> bool {\n            path.contains(\"norm\")\n        }\n\n        fn is_short(path: &str, _container_path: &str) -> bool {\n            path.len() < 10\n        }\n\n        let filter = PathFilter::new()\n            .with_predicate(contains_norm)\n            .with_predicate(is_short);\n\n        assert!(filter.matches(\"norm.weight\"));\n        assert!(filter.matches(\"layer.norm.bias\"));\n        assert!(filter.matches(\"bias\"));\n        assert!(!filter.matches(\"encoder.decoder.weight.long.name\"));\n    }\n\n    #[test]\n    fn combined_filters() {\n        let filter = PathFilter::new()\n            .with_full_path(\"special.tensor\")\n            .with_predicate(|path, _container_path| path.contains(\"attention\"));\n\n        #[cfg(feature = \"std\")]\n        let filter = filter.with_regex(r\"^encoder\\..*\");\n\n        assert!(filter.matches(\"special.tensor\"));\n        assert!(filter.matches(\"self_attention.query\"));\n\n        #[cfg(feature = \"std\")]\n        assert!(filter.matches(\"encoder.anything\"));\n\n        assert!(!filter.matches(\"decoder.weight\"));\n    }\n\n    #[test]\n    fn or_combination() {\n        let encoder_filter = PathFilter::new().with_full_path(\"encoder.weight\");\n        let decoder_filter = PathFilter::new().with_full_path(\"decoder.bias\");\n\n        let combined = encoder_filter.or(decoder_filter);\n\n        assert!(combined.matches(\"encoder.weight\"));\n        assert!(combined.matches(\"decoder.bias\"));\n        assert!(!combined.matches(\"model.head.weight\"));\n    }\n\n    #[test]\n    #[cfg(feature = \"std\")]\n    fn common_patterns() {\n        // Test encoder pattern\n        let encoder = PathFilter::new().with_regex(r\"^encoder\\..*\");\n        assert!(encoder.matches(\"encoder.weight\"));\n        assert!(!encoder.matches(\"decoder.weight\"));\n\n        // Test weights-only pattern\n        let weights = PathFilter::new().with_regex(r\".*\\.weight$\");\n        assert!(weights.matches(\"encoder.weight\"));\n        assert!(weights.matches(\"decoder.weight\"));\n        assert!(!weights.matches(\"encoder.bias\"));\n\n        // Test layer-specific patterns\n        let layers = PathFilter::new()\n            .with_regex(r\"(^|.*\\.)layers\\.0\\.\")\n            .with_regex(r\"(^|.*\\.)layers\\.2\\.\")\n            .with_regex(r\"(^|.*\\.)layers\\.4\\.\");\n        assert!(layers.matches(\"model.layers.0.weight\"));\n        assert!(layers.matches(\"layers.2.bias\"));\n        assert!(!layers.matches(\"layers.1.weight\"));\n    }\n\n    #[test]\n    fn criteria_count() {\n        let filter = PathFilter::new()\n            .with_full_path(\"path1\")\n            .with_full_path(\"path2\")\n            .with_predicate(|_, _| true);\n\n        #[cfg(feature = \"std\")]\n        let filter = filter.with_regex(\".*\");\n\n        #[cfg(feature = \"std\")]\n        assert_eq!(filter.criteria_count(), 4);\n\n        #[cfg(not(feature = \"std\"))]\n        assert_eq!(filter.criteria_count(), 3);\n    }\n\n    #[test]\n    fn clear_operations() {\n        let mut filter = PathFilter::new().with_full_path(\"test\");\n\n        filter.clear_paths();\n        assert!(!filter.matches(\"test\"));\n\n        filter.clear();\n        assert!(filter.is_empty());\n    }\n\n    #[test]\n    fn container_predicates() {\n        // Filter that matches only Linear module weights\n        let linear_weights = PathFilter::new().with_predicate(|path, container_path| {\n            container_path.split('.').next_back() == Some(\"Linear\") && path.ends_with(\".weight\")\n        });\n\n        assert!(linear_weights.matches_with_container(\"layer1.weight\", \"Linear\"));\n        assert!(!linear_weights.matches_with_container(\"layer1.weight\", \"Conv2d\"));\n        assert!(!linear_weights.matches_with_container(\"layer1.bias\", \"Linear\"));\n\n        // Filter for specific container types\n        let conv_only = PathFilter::new().with_predicate(|_path, container_path| {\n            let last = container_path.split('.').next_back();\n            last == Some(\"Conv2d\") || last == Some(\"ConvTranspose2d\")\n        });\n\n        assert!(conv_only.matches_with_container(\"encoder.weight\", \"Conv2d\"));\n        assert!(conv_only.matches_with_container(\"decoder.weight\", \"ConvTranspose2d\"));\n        assert!(!conv_only.matches_with_container(\"fc.weight\", \"Linear\"));\n\n        // Combine path and container predicates\n        let combined = PathFilter::new()\n            .with_predicate(|path, _container_path| path.starts_with(\"encoder.\"))\n            .with_predicate(|_path, container_path| {\n                container_path.split('.').next_back() == Some(\"BatchNorm2d\")\n            });\n\n        // Should match either condition (OR logic)\n        assert!(combined.matches_with_container(\"encoder.layer1\", \"Linear\"));\n        assert!(combined.matches_with_container(\"decoder.bn\", \"BatchNorm2d\"));\n        assert!(!combined.matches_with_container(\"decoder.layer\", \"Linear\"));\n    }\n\n    #[test]\n    fn container_predicate_with_regex() {\n        // Combine regex patterns with container predicates\n        #[cfg(feature = \"std\")]\n        {\n            let filter = PathFilter::new()\n                .with_regex(r\"^encoder\\..*\")\n                .with_predicate(|path, container_path| {\n                    container_path.split('.').next_back() == Some(\"Linear\")\n                        && path.contains(\".bias\")\n                });\n\n            // Matches due to regex\n            assert!(filter.matches_with_container(\"encoder.layer1.weight\", \"Conv2d\"));\n            // Matches due to container predicate\n            assert!(filter.matches_with_container(\"decoder.fc.bias\", \"Linear\"));\n            // Doesn't match either\n            assert!(!filter.matches_with_container(\"decoder.conv.weight\", \"Conv2d\"));\n        }\n    }\n\n    #[test]\n    fn container_stack_predicates() {\n        // Filter using full container path - only tensors nested in a specific hierarchy\n        let nested_filter = PathFilter::new().with_predicate(|_path, container_path| {\n            // Check if tensor is nested within: Model -> TransformerBlock -> Linear\n            let parts: Vec<&str> = container_path.split('.').collect();\n            parts.len() >= 3\n                && parts[0] == \"Model\"\n                && parts[1] == \"TransformerBlock\"\n                && parts[2] == \"Linear\"\n        });\n\n        assert!(nested_filter.matches_with_container_path_str(\n            \"encoder.weight\",\n            \"Model.TransformerBlock.Linear.Param\"\n        ));\n        assert!(\n            !nested_filter\n                .matches_with_container_path_str(\"decoder.weight\", \"Model.Decoder.Linear.Param\")\n        );\n        assert!(!nested_filter.matches_with_container_path_str(\n            \"encoder.weight\",\n            \"Model.TransformerBlock.Conv2d.Param\"\n        ));\n\n        // Filter that checks for specific depth in hierarchy\n        let depth_filter = PathFilter::new().with_predicate(|_path, container_path| {\n            let parts: Vec<&str> = container_path.split('.').collect();\n            parts.len() == 4 && parts.get(2) == Some(&\"Linear\")\n        });\n\n        assert!(depth_filter.matches_with_container_path_str(\n            \"model.layer.weight\",\n            \"Model.TransformerBlock.Linear.Param\"\n        ));\n        assert!(\n            !depth_filter\n                .matches_with_container_path_str(\"model.weight\", \"Model.TransformerBlock.Conv2d\")\n        ); // Too shallow\n\n        // Filter that checks any Linear in the path (not just the last)\n        let any_linear = PathFilter::new()\n            .with_predicate(|_path, container_path| container_path.contains(\"Linear\"));\n\n        assert!(\n            any_linear.matches_with_container_path_str(\n                \"some.path\",\n                \"Model.TransformerBlock.Linear.Param\"\n            )\n        );\n        assert!(\n            any_linear.matches_with_container_path_str(\"other.path\", \"Model.Decoder.Linear.Param\")\n        );\n        assert!(\n            !any_linear.matches_with_container_path_str(\n                \"conv.path\",\n                \"Model.TransformerBlock.Conv2d.Param\"\n            )\n        );\n    }\n\n    #[test]\n    fn container_path_dot_notation() {\n        // Filter using dot-notated container path\n        let dot_filter = PathFilter::new().with_predicate(|_path, container_path| {\n            container_path.starts_with(\"Model.TransformerBlock\")\n        });\n\n        // Test with matches_with_container_path\n        assert!(\n            dot_filter.matches_with_container_path_str(\"weight\", \"Model.TransformerBlock.Linear\")\n        );\n        assert!(!dot_filter.matches_with_container_path_str(\"weight\", \"Model.Decoder.Linear\"));\n\n        // Filter that checks for specific patterns in container path\n        let pattern_filter = PathFilter::new().with_predicate(|_path, container_path| {\n            // Match any path that has Linear after a block\n            container_path.contains(\"Block.Linear\") || container_path.contains(\"Block.Conv\")\n        });\n\n        assert!(\n            pattern_filter\n                .matches_with_container_path_str(\"weight\", \"Model.TransformerBlock.Linear\")\n        );\n        assert!(pattern_filter.matches_with_container_path_str(\"weight\", \"Model.ResBlock.Conv2d\"));\n        assert!(!pattern_filter.matches_with_container_path_str(\"weight\", \"Model.Linear.Param\"));\n\n        // Filter combining path and container path patterns\n        let combined = PathFilter::new().with_predicate(|path, container_path| {\n            // Only weights in Linear layers that are inside blocks\n            path.ends_with(\".weight\")\n                && container_path.contains(\"Block\")\n                && container_path.split('.').next_back() == Some(\"Linear\")\n        });\n\n        assert!(\n            combined\n                .matches_with_container_path_str(\"layer.weight\", \"Model.TransformerBlock.Linear\")\n        );\n        assert!(\n            !combined\n                .matches_with_container_path_str(\"layer.bias\", \"Model.TransformerBlock.Linear\")\n        );\n        assert!(!combined.matches_with_container_path_str(\"layer.weight\", \"Model.Decoder.Linear\"));\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/keyremapper.rs",
    "content": "use alloc::collections::BTreeMap;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\n\nuse regex::{self, Regex};\n\nuse crate::TensorSnapshot;\n\n/// Key remapper for transforming tensor names.\n///\n/// This allows mapping tensor names from one naming convention to another,\n/// which is useful for loading models from different frameworks or versions.\n///\n/// # Examples\n///\n/// ```rust\n/// # use burn_store::KeyRemapper;\n/// // Create a key remapper\n/// let remapper = KeyRemapper::new()\n///     .add_pattern(r\"^pytorch\\.(.*)\", \"burn.$1\").expect(\"valid regex\")  // pytorch.layer -> burn.layer\n///     .add_pattern(r\"\\.gamma$\", \".weight\").expect(\"valid regex\");       // layer.gamma -> layer.weight\n///\n/// // Use remapper with stores\n/// // store.remap(remapper)\n/// ```\n#[derive(Debug, Clone, Default)]\npub struct KeyRemapper {\n    /// Pattern-based remapping rules (regex pattern, replacement string)\n    pub patterns: Vec<(Regex, String)>,\n}\n\nimpl KeyRemapper {\n    /// Create a new empty key remapper\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Add a remapping pattern (compiles regex)\n    ///\n    /// # Arguments\n    ///\n    /// * `from` - Source pattern (regex string)\n    /// * `to` - Replacement string (can include capture groups like `$1`)\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(Self)` - Updated remapping configuration\n    /// * `Err(regex::Error)` - If regex compilation fails\n    pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>\n    where\n        S1: AsRef<str>,\n        S2: Into<String>,\n    {\n        let regex = Regex::new(from.as_ref())?;\n        self.patterns.push((regex, to.into()));\n        Ok(self)\n    }\n\n    /// Create from a list of compiled regex patterns\n    pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {\n        Self { patterns }\n    }\n\n    /// Create from string patterns (will compile to regex)\n    ///\n    /// # Arguments\n    ///\n    /// * `patterns` - Vector of (pattern, replacement) tuples\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(Self)` - New remapping configuration\n    /// * `Err(regex::Error)` - If any regex compilation fails\n    pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>\n    where\n        S1: AsRef<str>,\n        S2: Into<String>,\n    {\n        let mut compiled_patterns = Vec::new();\n        for (pattern, replacement) in patterns {\n            let regex = Regex::new(pattern.as_ref())?;\n            compiled_patterns.push((regex, replacement.into()));\n        }\n        Ok(Self {\n            patterns: compiled_patterns,\n        })\n    }\n\n    /// Create from an iterator of patterns\n    ///\n    /// # Arguments\n    ///\n    /// * `iter` - Iterator yielding (pattern, replacement) tuples\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(Self)` - New remapping configuration\n    /// * `Err(regex::Error)` - If any regex compilation fails\n    pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>\n    where\n        I: IntoIterator<Item = (S1, S2)>,\n        S1: AsRef<str>,\n        S2: Into<String>,\n    {\n        let patterns: Result<Vec<_>, _> = iter\n            .into_iter()\n            .map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))\n            .collect();\n        Ok(Self {\n            patterns: patterns?,\n        })\n    }\n\n    /// Check if the remapping is empty\n    pub fn is_empty(&self) -> bool {\n        self.patterns.is_empty()\n    }\n\n    /// Convert to the format expected by remap_tensor_paths_with_patterns\n    pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {\n        self.patterns.clone()\n    }\n\n    /// Remap tensor paths using the configured patterns.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensors` - Vec of TensorSnapshots to remap\n    ///\n    /// # Returns\n    ///\n    /// A tuple containing:\n    /// * The remapped Vec of TensorSnapshots with updated paths\n    /// * A vector of (new_path, original_path) showing the transformations\n    pub fn remap(\n        &self,\n        mut tensors: Vec<TensorSnapshot>,\n    ) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {\n        if self.patterns.is_empty() {\n            let remapped_names = tensors\n                .iter()\n                .map(|v| {\n                    let path = v.full_path();\n                    (path.clone(), path)\n                })\n                .collect();\n            return (tensors, remapped_names);\n        }\n\n        let mut remapped_snapshots = Vec::new();\n        let mut remapped_names = Vec::new();\n\n        for mut snapshot in tensors.drain(..) {\n            let original_path = snapshot.full_path();\n            let mut new_path = original_path.clone();\n\n            // Apply all patterns to get the new path\n            for (pattern, replacement) in &self.patterns {\n                if pattern.is_match(&new_path) {\n                    new_path = pattern\n                        .replace_all(&new_path, replacement.as_str())\n                        .to_string();\n                }\n            }\n\n            // Update the snapshot's internal path_stack if the path changed\n            if new_path != original_path\n                && let Some(ref mut path_stack) = snapshot.path_stack\n            {\n                *path_stack = new_path.split('.').map(|s| s.to_string()).collect();\n            }\n\n            remapped_names.push((new_path.clone(), original_path));\n            remapped_snapshots.push(snapshot);\n        }\n\n        (remapped_snapshots, remapped_names)\n    }\n}\n\n/// Map tensor paths to have contiguous numeric indices.\n///\n/// This function detects numeric indices in tensor paths and renumbers them\n/// to be contiguous (0, 1, 2, ...) while preserving their relative order.\n/// It handles nested sequential structures by processing ALL numeric indices\n/// in each path independently based on their position context.\n///\n/// This is useful when loading PyTorch models that have gaps in layer numbering,\n/// such as when using `nn.Sequential` with mixed layer types (e.g., Conv2d + ReLU\n/// where only Conv2d has parameters).\n///\n/// # Example\n///\n/// Simple case - input paths:\n/// - `fc.0.weight`, `fc.0.bias`\n/// - `fc.2.weight`, `fc.2.bias`\n/// - `fc.4.weight`, `fc.4.bias`\n///\n/// Output paths:\n/// - `fc.0.weight`, `fc.0.bias`\n/// - `fc.1.weight`, `fc.1.bias`\n/// - `fc.2.weight`, `fc.2.bias`\n///\n/// Nested case - input paths:\n/// - `feature.layers.0.conv_block.0.weight`\n/// - `feature.layers.0.conv_block.2.weight`\n/// - `feature.layers.2.conv_block.0.weight`\n/// - `feature.layers.2.conv_block.2.weight`\n///\n/// Output paths:\n/// - `feature.layers.0.conv_block.0.weight`\n/// - `feature.layers.0.conv_block.1.weight`\n/// - `feature.layers.1.conv_block.0.weight`\n/// - `feature.layers.1.conv_block.1.weight`\n///\n/// # Arguments\n///\n/// * `tensors` - Vec of TensorSnapshots to map\n///\n/// # Returns\n///\n/// A tuple containing:\n/// * The mapped Vec of TensorSnapshots with updated paths\n/// * A vector of (new_path, original_path) showing the transformations\npub fn map_indices_contiguous(\n    mut tensors: Vec<TensorSnapshot>,\n) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {\n    if tensors.is_empty() {\n        return (tensors, Vec::new());\n    }\n\n    // Step 1: Collect all paths and find all index positions\n    // For each index position (identified by prefix using ORIGINAL indices),\n    // collect all indices seen at that position.\n    //\n    // Key: prefix using original path (e.g., \"feature.layers.\" or \"feature.layers.0.conv_block.\")\n    // Value: BTreeMap of original_index -> new_index\n    let mut index_maps: BTreeMap<String, BTreeMap<usize, usize>> = BTreeMap::new();\n\n    // First pass: collect all indices at each position using original prefixes\n    for snapshot in &tensors {\n        let path = snapshot.full_path();\n        let parts: Vec<&str> = path.split('.').collect();\n\n        // Check each part for numeric indices\n        for (i, part) in parts.iter().enumerate() {\n            if let Ok(index) = part.parse::<usize>() {\n                // The prefix is everything before this index (using original path)\n                let prefix = if i > 0 {\n                    format!(\"{}.\", parts[..i].join(\".\"))\n                } else {\n                    String::new()\n                };\n\n                index_maps\n                    .entry(prefix)\n                    .or_default()\n                    .entry(index)\n                    .or_insert(usize::MAX); // Placeholder\n            }\n        }\n    }\n\n    // Second pass: assign contiguous indices for each position\n    for indices in index_maps.values_mut() {\n        let mut sorted_indices: Vec<usize> = indices.keys().cloned().collect();\n        sorted_indices.sort();\n\n        for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() {\n            indices.insert(old_idx, new_idx);\n        }\n    }\n\n    // Third pass: apply the remapping to all tensors\n    // We use original prefixes for lookup since that's how we collected indices\n    let mut mapped_snapshots = Vec::new();\n    let mut transformations = Vec::new();\n\n    for mut snapshot in tensors.drain(..) {\n        let original_path = snapshot.full_path();\n        let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps);\n\n        // Update the snapshot's internal path_stack if the path changed\n        if new_path != original_path\n            && let Some(ref mut path_stack) = snapshot.path_stack\n        {\n            *path_stack = new_path.split('.').map(|s| s.to_string()).collect();\n        }\n\n        transformations.push((new_path, original_path));\n        mapped_snapshots.push(snapshot);\n    }\n\n    (mapped_snapshots, transformations)\n}\n\n/// Remap all numeric indices in a path using the provided index maps.\n/// Uses original path prefixes for lookup.\nfn remap_all_indices_with_original_prefix(\n    path: &str,\n    index_maps: &BTreeMap<String, BTreeMap<usize, usize>>,\n) -> String {\n    let parts: Vec<&str> = path.split('.').collect();\n    let mut result_parts: Vec<String> = Vec::with_capacity(parts.len());\n\n    for (i, part) in parts.iter().enumerate() {\n        if let Ok(index) = part.parse::<usize>() {\n            // Build the prefix from ORIGINAL parts (not remapped)\n            let prefix = if i > 0 {\n                format!(\"{}.\", parts[..i].join(\".\"))\n            } else {\n                String::new()\n            };\n\n            // Look up the new index using original prefix\n            if let Some(index_map) = index_maps.get(&prefix)\n                && let Some(&new_index) = index_map.get(&index)\n            {\n                result_parts.push(new_index.to_string());\n                continue;\n            }\n        }\n        // Not a numeric index or no mapping found, keep as-is\n        result_parts.push((*part).to_string());\n    }\n\n    result_parts.join(\".\")\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use super::*;\n    use burn_core::module::ParamId;\n    use burn_tensor::{TensorData, shape};\n\n    fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {\n        let data = TensorData {\n            bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]),\n            shape: shape![2, 2],\n            dtype: burn_tensor::DType::F32,\n        };\n        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();\n        TensorSnapshot::from_data(data, path_parts, vec![\"Test\".to_string()], ParamId::new())\n    }\n\n    #[test]\n    fn test_key_remapper_basic() {\n        let remapper = KeyRemapper::new()\n            .add_pattern(r\"^encoder\\.\", \"transformer.encoder.\")\n            .expect(\"valid regex\");\n\n        let tensors = vec![\n            create_test_tensor_snapshot(\"encoder.layer1.weight\"),\n            create_test_tensor_snapshot(\"decoder.layer1.weight\"),\n        ];\n\n        let (remapped, transformations) = remapper.remap(tensors);\n\n        // Check that remapped views exist with correct paths\n        assert!(\n            remapped\n                .iter()\n                .any(|v| v.full_path() == \"transformer.encoder.layer1.weight\")\n        );\n        assert!(\n            remapped\n                .iter()\n                .any(|v| v.full_path() == \"decoder.layer1.weight\")\n        );\n        assert_eq!(remapped.len(), 2);\n\n        // Check transformations\n        let encoder_transform = transformations\n            .iter()\n            .find(|(_new, old)| old == \"encoder.layer1.weight\")\n            .expect(\"should find encoder transformation\");\n        assert_eq!(encoder_transform.0, \"transformer.encoder.layer1.weight\");\n    }\n\n    #[test]\n    fn test_key_remapper_multiple_patterns() {\n        let remapper = KeyRemapper::new()\n            .add_pattern(r\"^encoder\\.\", \"transformer.encoder.\")\n            .expect(\"valid regex\")\n            .add_pattern(r\"\\.gamma$\", \".weight\")\n            .expect(\"valid regex\");\n\n        let tensors = vec![create_test_tensor_snapshot(\"encoder.layer1.gamma\")];\n\n        let (remapped, _) = remapper.remap(tensors);\n\n        assert!(\n            remapped\n                .iter()\n                .any(|v| v.full_path() == \"transformer.encoder.layer1.weight\")\n        );\n        assert_eq!(remapped.len(), 1);\n    }\n\n    #[test]\n    fn test_key_remapper_from_patterns() {\n        let patterns = vec![(r\"^pytorch\\.\", \"burn.\"), (r\"\\.bias$\", \".bias_param\")];\n        let remapper = KeyRemapper::from_patterns(patterns).expect(\"valid patterns\");\n\n        let tensors = vec![create_test_tensor_snapshot(\"pytorch.linear.bias\")];\n\n        let (remapped, _) = remapper.remap(tensors);\n\n        assert!(\n            remapped\n                .iter()\n                .any(|v| v.full_path() == \"burn.linear.bias_param\")\n        );\n    }\n\n    #[test]\n    fn test_key_remapper_empty() {\n        let remapper = KeyRemapper::new();\n        assert!(remapper.is_empty());\n\n        let tensors = vec![create_test_tensor_snapshot(\"test.weight\")];\n\n        let (remapped, transformations) = remapper.remap(tensors);\n\n        assert!(remapped.iter().any(|v| v.full_path() == \"test.weight\"));\n        assert_eq!(remapped.len(), 1);\n        assert_eq!(transformations.len(), 1);\n        assert_eq!(\n            transformations[0],\n            (\"test.weight\".to_string(), \"test.weight\".to_string())\n        );\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_basic() {\n        // Simulate PyTorch nn.Sequential with Conv2d (0, 2, 4) and ReLU (1, 3, 5)\n        // Only Conv2d layers have parameters\n        let tensors = vec![\n            create_test_tensor_snapshot(\"fc.0.weight\"),\n            create_test_tensor_snapshot(\"fc.0.bias\"),\n            create_test_tensor_snapshot(\"fc.2.weight\"),\n            create_test_tensor_snapshot(\"fc.2.bias\"),\n            create_test_tensor_snapshot(\"fc.4.weight\"),\n            create_test_tensor_snapshot(\"fc.4.bias\"),\n        ];\n\n        let (reindexed, transformations) = map_indices_contiguous(tensors);\n\n        // Check that indices are now contiguous\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.0.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.0.bias\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.1.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.1.bias\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.2.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.2.bias\"));\n        assert_eq!(reindexed.len(), 6);\n\n        // Check transformations\n        let transform_2_to_1 = transformations\n            .iter()\n            .find(|(_, old)| old == \"fc.2.weight\")\n            .expect(\"should find fc.2.weight transformation\");\n        assert_eq!(transform_2_to_1.0, \"fc.1.weight\");\n\n        let transform_4_to_2 = transformations\n            .iter()\n            .find(|(_, old)| old == \"fc.4.weight\")\n            .expect(\"should find fc.4.weight transformation\");\n        assert_eq!(transform_4_to_2.0, \"fc.2.weight\");\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_already_contiguous() {\n        // Already contiguous indices should remain unchanged\n        let tensors = vec![\n            create_test_tensor_snapshot(\"fc.0.weight\"),\n            create_test_tensor_snapshot(\"fc.1.weight\"),\n            create_test_tensor_snapshot(\"fc.2.weight\"),\n        ];\n\n        let (reindexed, transformations) = map_indices_contiguous(tensors);\n\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.0.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.1.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.2.weight\"));\n        assert_eq!(reindexed.len(), 3);\n\n        // All transformations should have same old and new paths\n        for (new, old) in &transformations {\n            assert_eq!(new, old);\n        }\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_multiple_prefixes() {\n        // Different prefixes should be mapped independently\n        let tensors = vec![\n            create_test_tensor_snapshot(\"encoder.0.weight\"),\n            create_test_tensor_snapshot(\"encoder.2.weight\"),\n            create_test_tensor_snapshot(\"decoder.1.weight\"),\n            create_test_tensor_snapshot(\"decoder.5.weight\"),\n        ];\n\n        let (reindexed, _) = map_indices_contiguous(tensors);\n\n        // encoder: 0, 2 -> 0, 1\n        assert!(\n            reindexed\n                .iter()\n                .any(|v| v.full_path() == \"encoder.0.weight\")\n        );\n        assert!(\n            reindexed\n                .iter()\n                .any(|v| v.full_path() == \"encoder.1.weight\")\n        );\n\n        // decoder: 1, 5 -> 0, 1\n        assert!(\n            reindexed\n                .iter()\n                .any(|v| v.full_path() == \"decoder.0.weight\")\n        );\n        assert!(\n            reindexed\n                .iter()\n                .any(|v| v.full_path() == \"decoder.1.weight\")\n        );\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_no_indices() {\n        // Paths without indices should remain unchanged\n        let tensors = vec![\n            create_test_tensor_snapshot(\"encoder.weight\"),\n            create_test_tensor_snapshot(\"decoder.bias\"),\n        ];\n\n        let (reindexed, transformations) = map_indices_contiguous(tensors);\n\n        assert!(reindexed.iter().any(|v| v.full_path() == \"encoder.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"decoder.bias\"));\n\n        for (new, old) in &transformations {\n            assert_eq!(new, old);\n        }\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_empty() {\n        let tensors: Vec<TensorSnapshot> = vec![];\n        let (reindexed, transformations) = map_indices_contiguous(tensors);\n\n        assert!(reindexed.is_empty());\n        assert!(transformations.is_empty());\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() {\n        // Mix of indexed and non-indexed paths\n        let tensors = vec![\n            create_test_tensor_snapshot(\"fc.0.weight\"),\n            create_test_tensor_snapshot(\"fc.2.weight\"),\n            create_test_tensor_snapshot(\"output.weight\"), // no index\n        ];\n\n        let (reindexed, _) = map_indices_contiguous(tensors);\n\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.0.weight\"));\n        assert!(reindexed.iter().any(|v| v.full_path() == \"fc.1.weight\")); // 2 -> 1\n        assert!(reindexed.iter().any(|v| v.full_path() == \"output.weight\")); // unchanged\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_nested_sequential() {\n        // Test nested sequential structures like:\n        // feature = nn.Sequential(ConvBlock, ReLU, ConvBlock, ReLU, ConvBlock)\n        // where ConvBlock = nn.Sequential(Conv2d, ReLU, Conv2d)\n        //\n        // This produces paths like:\n        // feature.layers.0.conv_block.0.weight (layer 0, conv 0)\n        // feature.layers.0.conv_block.2.weight (layer 0, conv 2 - skipping ReLU at 1)\n        // feature.layers.2.conv_block.0.weight (layer 2 - skipping ReLU at 1, conv 0)\n        // feature.layers.2.conv_block.2.weight (layer 2, conv 2)\n        let tensors = vec![\n            create_test_tensor_snapshot(\"feature.layers.0.conv_block.0.weight\"),\n            create_test_tensor_snapshot(\"feature.layers.0.conv_block.2.weight\"),\n            create_test_tensor_snapshot(\"feature.layers.2.conv_block.0.weight\"),\n            create_test_tensor_snapshot(\"feature.layers.2.conv_block.2.weight\"),\n        ];\n\n        let (mapped, transformations) = map_indices_contiguous(tensors);\n\n        // Expected mapping:\n        // feature.layers: 0, 2 -> 0, 1\n        // feature.layers.0.conv_block: 0, 2 -> 0, 1\n        // feature.layers.2.conv_block: 0, 2 -> 0, 1\n        //\n        // Result:\n        // feature.layers.0.conv_block.0.weight -> feature.layers.0.conv_block.0.weight\n        // feature.layers.0.conv_block.2.weight -> feature.layers.0.conv_block.1.weight\n        // feature.layers.2.conv_block.0.weight -> feature.layers.1.conv_block.0.weight\n        // feature.layers.2.conv_block.2.weight -> feature.layers.1.conv_block.1.weight\n\n        assert!(\n            mapped\n                .iter()\n                .any(|v| v.full_path() == \"feature.layers.0.conv_block.0.weight\"),\n            \"0.0 should stay as 0.0\"\n        );\n        assert!(\n            mapped\n                .iter()\n                .any(|v| v.full_path() == \"feature.layers.0.conv_block.1.weight\"),\n            \"0.2 should become 0.1\"\n        );\n        assert!(\n            mapped\n                .iter()\n                .any(|v| v.full_path() == \"feature.layers.1.conv_block.0.weight\"),\n            \"2.0 should become 1.0\"\n        );\n        assert!(\n            mapped\n                .iter()\n                .any(|v| v.full_path() == \"feature.layers.1.conv_block.1.weight\"),\n            \"2.2 should become 1.1\"\n        );\n\n        // Verify specific transformations\n        let t1 = transformations\n            .iter()\n            .find(|(_, old)| old == \"feature.layers.2.conv_block.2.weight\");\n        assert_eq!(\n            t1.map(|(new, _)| new.as_str()),\n            Some(\"feature.layers.1.conv_block.1.weight\"),\n            \"2.2 should map to 1.1\"\n        );\n    }\n\n    #[test]\n    fn test_map_indices_contiguous_deeply_nested() {\n        // Test with three levels of nesting\n        let tensors = vec![\n            create_test_tensor_snapshot(\"a.0.b.0.c.0.weight\"),\n            create_test_tensor_snapshot(\"a.0.b.0.c.2.weight\"),\n            create_test_tensor_snapshot(\"a.0.b.2.c.0.weight\"),\n            create_test_tensor_snapshot(\"a.2.b.0.c.0.weight\"),\n        ];\n\n        let (mapped, _) = map_indices_contiguous(tensors);\n\n        // a: 0, 2 -> 0, 1\n        // a.0.b: 0, 2 -> 0, 1\n        // a.2.b: 0 -> 0\n        // a.0.b.0.c: 0, 2 -> 0, 1\n        // a.0.b.2.c: 0 -> 0\n        // a.2.b.0.c: 0 -> 0\n\n        assert!(mapped.iter().any(|v| v.full_path() == \"a.0.b.0.c.0.weight\"));\n        assert!(\n            mapped.iter().any(|v| v.full_path() == \"a.0.b.0.c.1.weight\"),\n            \"a.0.b.0.c.2 should become a.0.b.0.c.1\"\n        );\n        assert!(\n            mapped.iter().any(|v| v.full_path() == \"a.0.b.1.c.0.weight\"),\n            \"a.0.b.2.c.0 should become a.0.b.1.c.0\"\n        );\n        assert!(\n            mapped.iter().any(|v| v.full_path() == \"a.1.b.0.c.0.weight\"),\n            \"a.2.b.0.c.0 should become a.1.b.0.c.0\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n\n//! # Burn Store\n//!\n//! Advanced model storage and serialization infrastructure for the Burn deep learning framework.\n//!\n//! This crate provides comprehensive functionality for storing and loading Burn modules\n//! and their tensor data, with support for cross-framework interoperability, flexible filtering,\n//! and efficient memory management through lazy materialization.\n//!\n//! ## Key Features\n//!\n//! - **Burnpack Format**: Native Burn format with CBOR metadata, ParamId persistence for stateful training, and no-std support\n//! - **SafeTensors Format**: Industry-standard format for secure and efficient tensor serialization\n//! - **PyTorch Compatibility**: Load PyTorch models directly into Burn with automatic weight transformation\n//! - **Zero-Copy Loading**: Memory-mapped files and lazy tensor materialization for optimal performance\n//! - **Flexible Filtering**: Load/save specific model subsets using regex, exact paths, or custom predicates\n//! - **Tensor Remapping**: Rename tensors during load/save operations for framework compatibility\n//! - **No-std Support**: Core functionality available in embedded and WASM environments\n//!\n//! ## Quick Start\n//!\n//! ### Basic Save and Load\n//!\n//! ```rust,ignore\n//! use burn_store::{ModuleSnapshot, SafetensorsStore};\n//!\n//! // Save a model\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\");\n//! model.save_into(&mut store)?;\n//!\n//! // Load a model\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\");\n//! model.load_from(&mut store)?;\n//! ```\n//!\n//! ### Loading PyTorch Models\n//!\n//! ```rust,ignore\n//! use burn_store::PytorchStore;\n//!\n//! // Load PyTorch model (automatic weight transformation via PyTorchToBurnAdapter)\n//! let mut store = PytorchStore::from_file(\"pytorch_model.pth\")\n//!     .with_top_level_key(\"state_dict\")  // Access nested state dict if needed\n//!     .allow_partial(true);               // Skip unknown tensors\n//!\n//! model.load_from(&mut store)?;\n//! ```\n//!\n//! ### Filtering and Remapping\n//!\n//! ```rust,no_run\n//! # use burn_store::SafetensorsStore;\n//! // Save only specific layers with renaming\n//! let mut store = SafetensorsStore::from_file(\"encoder.safetensors\")\n//!     .with_regex(r\"^encoder\\..*\")                         // Filter: only encoder layers\n//!     .with_key_remapping(r\"^encoder\\.\", \"transformer.\")   // Rename: encoder.X -> transformer.X\n//!     .metadata(\"subset\", \"encoder_only\");\n//!\n//! // Use store with model.save_into(&mut store)?;\n//! ```\n//!\n//! ## Core Components\n//!\n//! - [`ModuleSnapshot`]: Extension trait for Burn modules providing `collect()` and `apply()` methods\n//! - [`BurnpackStore`]: Native Burn format with ParamId persistence for stateful training workflows\n//! - [`SafetensorsStore`]: Primary storage implementation supporting the SafeTensors format\n//! - [`PytorchStore`]: PyTorch model loader supporting .pth and .pt files\n//! - [`PathFilter`]: Flexible filtering system for selective tensor loading/saving\n//! - [`KeyRemapper`]: Advanced tensor name remapping with regex patterns\n//! - [`ModuleAdapter`]: Framework adapters for cross-framework compatibility\n//!\n//! ## Feature Flags\n//!\n//! - `std`: Enables file I/O and other std-only features (default)\n//! - `safetensors`: Enables SafeTensors format support (default)\n\nextern crate alloc;\n\nmod adapter;\nmod applier;\nmod apply_result;\nmod collector;\nmod filter;\nmod tensor_snapshot;\nmod traits;\n\npub use adapter::{\n    BurnToPyTorchAdapter, ChainAdapter, HalfPrecisionAdapter, IdentityAdapter, ModuleAdapter,\n    PyTorchToBurnAdapter,\n};\npub use applier::Applier;\npub use apply_result::{ApplyError, ApplyResult};\npub use collector::Collector;\npub use filter::PathFilter;\npub use tensor_snapshot::{TensorSnapshot, TensorSnapshotError};\npub use traits::{ModuleSnapshot, ModuleStore};\n\n#[cfg(feature = \"std\")]\nmod keyremapper;\n#[cfg(feature = \"std\")]\npub use keyremapper::{KeyRemapper, map_indices_contiguous};\n\n#[cfg(feature = \"pytorch\")]\npub mod pytorch;\n#[cfg(feature = \"pytorch\")]\npub use pytorch::{PytorchStore, PytorchStoreError};\n\n#[cfg(feature = \"safetensors\")]\nmod safetensors;\n#[cfg(feature = \"safetensors\")]\npub use safetensors::{SafetensorsStore, SafetensorsStoreError};\n\n#[cfg(feature = \"burnpack\")]\nmod burnpack;\n#[cfg(feature = \"burnpack\")]\npub use burnpack::writer::BurnpackWriter;\n#[cfg(feature = \"burnpack\")]\npub use burnpack::{base::BurnpackError, store::BurnpackStore};\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/lazy_data.rs",
    "content": "//! Lazy data loading support for PyTorch files.\n//!\n//! This module provides abstractions for lazy loading of tensor data from PyTorch files,\n//! avoiding the need to load all data into memory upfront.\n\nuse alloc::string::String;\nuse alloc::vec::Vec;\nuse std::collections::HashMap;\nuse std::fs::File;\nuse std::io::{BufReader, Read, Seek};\nuse std::path::{Path, PathBuf};\nuse std::sync::{Arc, Mutex, RwLock};\nuse zip::ZipArchive;\n\n/// A data source that can lazily load tensor data.\n#[derive(Clone)]\npub enum LazyDataSource {\n    /// ZIP archive with lazy loading\n    Zip(Arc<Mutex<ZipSource>>),\n    /// TAR archive format (older torchvision models)\n    Tar(Arc<Mutex<TarSource>>),\n    /// Legacy format with multiple storages in single blob\n    LegacyMultiStorage(Arc<Mutex<LegacyMultiStorageSource>>),\n}\n\n/// ZIP archive source for lazy loading\npub struct ZipSource {\n    path: PathBuf,\n    // Cache the file list to avoid reopening archive repeatedly\n    file_list: Vec<(String, u64, u64)>, // (name, offset, compressed_size)\n}\n\n/// TAR archive source for lazy loading (older torchvision models like AlexNet, SqueezeNet)\n///\n/// Older PyTorch/torchvision models (pre-1.6) use TAR format instead of ZIP.\n/// The TAR archive contains:\n/// - `sys_info`: System info pickle (endianness, type sizes)\n/// - `pickle`: OrderedDict mapping tensor names to storage keys\n/// - `tensors`: Tensor metadata pickles (unused, metadata is embedded in pickle)\n/// - `storages`: Storage count + sequential (metadata pickle, element count, raw data)\npub struct TarSource {\n    /// Cached storage map: storage_key -> (offset_in_storages, size_bytes)\n    storage_map: HashMap<String, (usize, usize)>,\n    /// The raw storages data (kept in memory for TAR format)\n    storages_data: Vec<u8>,\n}\n\n/// Legacy multi-storage source for old PyTorch format (0.1.10 - 1.5)\n///\n/// Legacy format stores tensor data as concatenated raw binary without explicit\n/// storage boundaries. This source tracks storage usage during tensor parsing\n/// to build a storage map for lazy loading.\n///\n/// ## Storage Layout\n/// - Pickle metadata with tensor definitions\n/// - List of storage keys (determines concatenation order)\n/// - Raw binary blob with all storages concatenated\npub struct LegacyMultiStorageSource {\n    path: PathBuf,\n    data_offset: u64,\n    #[allow(dead_code)]\n    data_size: u64,\n    // Map of storage_key -> (offset_in_blob, size)\n    storage_map: RwLock<Option<HashMap<String, (u64, u64)>>>,\n    // Storage keys in order (for boundary calculation)\n    storage_keys: RwLock<Option<Vec<String>>>,\n    // Track storage usage as tensors are accessed\n    storage_usage: RwLock<HashMap<String, usize>>, // key -> max_bytes_needed\n}\n\nimpl ZipSource {\n    /// Create a new ZIP source\n    pub fn new(path: PathBuf) -> std::io::Result<Self> {\n        let file = File::open(&path)?;\n        let reader = BufReader::new(file);\n        let mut archive = ZipArchive::new(reader)?;\n\n        // Cache file metadata\n        let mut file_list = Vec::new();\n        for i in 0..archive.len() {\n            let file = archive.by_index(i)?;\n            let name = file.name().to_string();\n            let offset = file.data_start();\n            let compressed_size = file.compressed_size();\n            file_list.push((\n                name,\n                offset.expect(\"should have an offset\"),\n                compressed_size,\n            ));\n        }\n\n        Ok(Self { path, file_list })\n    }\n\n    /// Check if a file exists in the archive\n    pub fn contains(&self, name: &str) -> bool {\n        self.file_list.iter().any(|(n, _, _)| n == name)\n    }\n\n    /// Get list of data files (excluding pickle files)\n    pub fn data_files(&self) -> Vec<String> {\n        self.file_list\n            .iter()\n            .filter(|(name, _, _)| name.starts_with(\"data/\") || name.contains(\"/data/\"))\n            .filter(|(name, _, _)| !name.ends_with(\".pkl\") && !name.ends_with(\"/\"))\n            .map(|(name, _, _)| name.clone())\n            .collect()\n    }\n\n    /// Read a specific file from the archive\n    pub fn read_file(&self, name: &str) -> std::io::Result<Vec<u8>> {\n        let file = File::open(&self.path)?;\n        let reader = BufReader::new(file);\n        let mut archive = ZipArchive::new(reader)?;\n\n        let mut file = archive.by_name(name)?;\n        let mut contents = Vec::with_capacity(file.size() as usize);\n        file.read_to_end(&mut contents)?;\n        Ok(contents)\n    }\n\n    /// Read a portion of a file\n    pub fn read_file_range(\n        &self,\n        name: &str,\n        offset: usize,\n        length: usize,\n    ) -> std::io::Result<Vec<u8>> {\n        let file = File::open(&self.path)?;\n        let reader = BufReader::new(file);\n        let mut archive = ZipArchive::new(reader)?;\n\n        let mut file = archive.by_name(name)?;\n        let mut buffer = vec![0u8; length];\n\n        // Skip to offset\n        let mut skip_buffer = vec![0u8; offset.min(8192)];\n        let mut skipped = 0;\n        while skipped < offset {\n            let to_skip = (offset - skipped).min(skip_buffer.len());\n            file.read_exact(&mut skip_buffer[..to_skip])?;\n            skipped += to_skip;\n        }\n\n        // Read the requested data\n        file.read_exact(&mut buffer)?;\n        Ok(buffer)\n    }\n}\n\nimpl LegacyMultiStorageSource {\n    /// Create a new legacy multi-storage source\n    pub fn new(path: PathBuf, data_offset: u64, data_size: u64) -> Self {\n        Self {\n            path,\n            data_offset,\n            data_size,\n            storage_map: RwLock::new(None),\n            storage_keys: RwLock::new(None),\n            storage_usage: RwLock::new(HashMap::new()),\n        }\n    }\n\n    /// Set the ordered storage keys from the pickle\n    pub fn set_storage_keys(&self, keys: Vec<String>) {\n        let mut storage_keys = self\n            .storage_keys\n            .write()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n        *storage_keys = Some(keys);\n    }\n\n    /// Track storage usage from tensor access\n    /// This is called from within tensor loading closures\n    pub fn track_storage_usage(&self, storage_key: &str, offset: usize, size: usize) {\n        let mut usage = self\n            .storage_usage\n            .write()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n        let max_extent = offset + size;\n        usage\n            .entry(storage_key.to_string())\n            .and_modify(|current| *current = (*current).max(max_extent))\n            .or_insert(max_extent);\n\n        // Try to build storage map if we have enough information\n        drop(usage);\n        self.try_build_storage_map();\n    }\n\n    /// Try to build the storage map from tracked usage\n    fn try_build_storage_map(&self) {\n        // Only build if we don't already have a map\n        if self\n            .storage_map\n            .read()\n            .unwrap_or_else(|poisoned| poisoned.into_inner())\n            .is_some()\n        {\n            return;\n        }\n\n        // Check if we have storage keys\n        let keys_guard = self\n            .storage_keys\n            .read()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n        if let Some(ref keys) = *keys_guard {\n            let usage = self\n                .storage_usage\n                .read()\n                .unwrap_or_else(|poisoned| poisoned.into_inner());\n\n            // Only build if we have usage info for all storages\n            if keys.iter().all(|k| usage.contains_key(k)) {\n                let mut map = HashMap::new();\n                let mut current_offset = 0u64;\n\n                for key in keys {\n                    if let Some(&size) = usage.get(key) {\n                        map.insert(key.clone(), (current_offset, size as u64));\n                        current_offset += size as u64;\n                    }\n                }\n\n                // Set the storage map\n                drop(keys_guard);\n                drop(usage);\n                let mut storage_map = self\n                    .storage_map\n                    .write()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                *storage_map = Some(map);\n            }\n        }\n    }\n\n    /// Read data for a specific storage key\n    /// Only loads the specific storage portion, never the entire blob\n    pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {\n        // Extract numeric key from paths like \"data/0\" or just \"0\"\n        let storage_key = key.split('/').next_back().unwrap_or(key);\n\n        // Get storage map - must be available for lazy loading to work\n        let storage_map = self\n            .storage_map\n            .read()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n\n        if let Some(ref map) = *storage_map\n            && let Some(&(offset, size)) = map.get(storage_key)\n        {\n            // Load only this specific storage\n            let mut file = File::open(&self.path)?;\n            file.seek(std::io::SeekFrom::Start(self.data_offset + offset))?;\n\n            let mut buffer = vec![0u8; size as usize];\n            file.read_exact(&mut buffer)?;\n            return Ok(buffer);\n        }\n\n        // NO FALLBACK! If we don't have storage boundaries, we cannot load data lazily\n        // The storage map MUST be built from tensor metadata for lazy loading to work\n        Err(std::io::Error::new(\n            std::io::ErrorKind::InvalidData,\n            format!(\n                \"Storage boundaries not available for key '{}'. Cannot perform lazy loading.\",\n                storage_key\n            ),\n        ))\n    }\n}\n\nimpl TarSource {\n    /// Create a new TAR source by parsing storages data.\n    ///\n    /// # Arguments\n    /// * `storages_data` - Raw storages blob with structure:\n    ///   - Count pickle (number of storages)\n    ///   - For each storage: metadata pickle + u64 num_elements + raw binary data\n    pub fn new(storages_data: Vec<u8>) -> std::io::Result<Self> {\n        use super::pickle_reader::{read_pickle, storage_type_to_element_size};\n        use std::io::Cursor;\n\n        let mut storage_map = HashMap::new();\n        let mut pos = 0usize;\n\n        // First, read the count of storages\n        let mut cursor = Cursor::new(&storages_data[pos..]);\n        let storage_count =\n            if let Ok(super::pickle_reader::Object::Int(count)) = read_pickle(&mut cursor) {\n                pos += cursor.position() as usize;\n                count as usize\n            } else {\n                0\n            };\n\n        // Parse each storage entry\n        for _i in 0..storage_count {\n            if pos >= storages_data.len() {\n                break;\n            }\n\n            // Read the storage metadata pickle: (storage_key, device, storage_type)\n            let mut cursor = Cursor::new(&storages_data[pos..]);\n            if let Ok(obj) = read_pickle(&mut cursor) {\n                let pickle_size = cursor.position() as usize;\n                pos += pickle_size;\n\n                // Extract storage info from pickle tuple\n                let (storage_key, storage_type) = match obj {\n                    super::pickle_reader::Object::Tuple(tuple) if tuple.len() >= 3 => {\n                        let key = match &tuple[0] {\n                            super::pickle_reader::Object::Int(i) => i.to_string(),\n                            super::pickle_reader::Object::String(s) => s.clone(),\n                            _ => continue,\n                        };\n                        // tuple[1] is device (e.g., \"cpu\")\n                        // tuple[2] is storage type class\n                        let stype = match &tuple[2] {\n                            super::pickle_reader::Object::Class { name, .. } => name.clone(),\n                            other => {\n                                return Err(std::io::Error::new(\n                                    std::io::ErrorKind::InvalidData,\n                                    format!(\"Expected Class for storage type, got {:?}\", other),\n                                ));\n                            }\n                        };\n                        (key, stype)\n                    }\n                    _ => continue,\n                };\n\n                // Read the number of elements (u64 little-endian)\n                if pos + 8 > storages_data.len() {\n                    break;\n                }\n                let num_elements = u64::from_le_bytes([\n                    storages_data[pos],\n                    storages_data[pos + 1],\n                    storages_data[pos + 2],\n                    storages_data[pos + 3],\n                    storages_data[pos + 4],\n                    storages_data[pos + 5],\n                    storages_data[pos + 6],\n                    storages_data[pos + 7],\n                ]) as usize;\n                pos += 8;\n\n                // Determine element size from storage type\n                let element_size = storage_type_to_element_size(&storage_type)\n                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;\n\n                let data_size = num_elements * element_size;\n\n                // Store the offset to raw data and its size\n                storage_map.insert(storage_key, (pos, data_size));\n\n                // Skip the raw binary data\n                pos += data_size;\n            } else {\n                break;\n            }\n        }\n\n        Ok(Self {\n            storage_map,\n            storages_data,\n        })\n    }\n\n    /// Read data for a specific storage key\n    pub fn read_file(&self, key: &str) -> std::io::Result<Vec<u8>> {\n        // Extract the storage key from paths like \"data/0\"\n        let storage_key = key.split('/').next_back().unwrap_or(key);\n\n        if let Some(&(offset, size)) = self.storage_map.get(storage_key)\n            && offset + size <= self.storages_data.len()\n        {\n            return Ok(self.storages_data[offset..offset + size].to_vec());\n        }\n\n        Err(std::io::Error::new(\n            std::io::ErrorKind::NotFound,\n            format!(\"Storage key '{}' not found in TAR archive\", storage_key),\n        ))\n    }\n\n    /// Read a range of data for a specific storage key (avoids double allocation)\n    pub fn read_file_range(\n        &self,\n        key: &str,\n        offset: usize,\n        length: usize,\n    ) -> std::io::Result<Vec<u8>> {\n        let storage_key = key.split('/').next_back().unwrap_or(key);\n\n        if let Some(&(storage_offset, storage_size)) = self.storage_map.get(storage_key)\n            && storage_offset + storage_size <= self.storages_data.len()\n        {\n            let start = storage_offset + offset;\n            let end = (storage_offset + offset + length).min(storage_offset + storage_size);\n            return Ok(self.storages_data[start..end].to_vec());\n        }\n\n        Err(std::io::Error::new(\n            std::io::ErrorKind::NotFound,\n            format!(\"Storage key '{}' not found in TAR archive\", storage_key),\n        ))\n    }\n\n    /// Check if a storage key exists\n    pub fn contains(&self, key: &str) -> bool {\n        let storage_key = key.split('/').next_back().unwrap_or(key);\n        self.storage_map.contains_key(storage_key)\n    }\n\n    /// Get list of storage keys\n    pub fn keys(&self) -> Vec<String> {\n        self.storage_map.keys().cloned().collect()\n    }\n}\n\nimpl LazyDataSource {\n    /// Create from a ZIP file\n    pub fn from_zip(path: impl AsRef<Path>) -> std::io::Result<Self> {\n        Ok(Self::Zip(Arc::new(Mutex::new(ZipSource::new(\n            path.as_ref().to_path_buf(),\n        )?))))\n    }\n\n    /// Create from a TAR archive's storages data\n    pub fn from_tar(storages_data: &[u8]) -> std::io::Result<Self> {\n        Ok(Self::Tar(Arc::new(Mutex::new(TarSource::new(\n            storages_data.to_vec(),\n        )?))))\n    }\n\n    /// Create from a legacy multi-storage file\n    pub fn from_legacy_multi_storage(\n        path: impl AsRef<Path>,\n        data_offset: u64,\n        data_size: u64,\n    ) -> Self {\n        Self::LegacyMultiStorage(Arc::new(Mutex::new(LegacyMultiStorageSource::new(\n            path.as_ref().to_path_buf(),\n            data_offset,\n            data_size,\n        ))))\n    }\n\n    /// Read data for a specific key\n    pub fn read(&self, key: &str) -> std::io::Result<Vec<u8>> {\n        match self {\n            Self::Zip(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.read_file(key)\n            }\n            Self::Tar(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.read_file(key)\n            }\n            Self::LegacyMultiStorage(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.read(key)\n            }\n        }\n    }\n\n    /// Read a portion of data for a specific key\n    pub fn read_range(&self, key: &str, offset: usize, length: usize) -> std::io::Result<Vec<u8>> {\n        match self {\n            Self::Zip(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.read_file_range(key, offset, length)\n            }\n            Self::Tar(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.read_file_range(key, offset, length)\n            }\n            Self::LegacyMultiStorage(source) => {\n                // For legacy format, read only the requested range\n                let storage_key = key.split('/').next_back().unwrap_or(key);\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n\n                // Get storage boundaries\n                let storage_map = source\n                    .storage_map\n                    .read()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                if let Some(ref map) = *storage_map\n                    && let Some(&(storage_offset, storage_size)) = map.get(storage_key)\n                {\n                    // Calculate actual file position\n                    let file_offset = source.data_offset + storage_offset + offset as u64;\n                    let read_length = length.min((storage_size as usize).saturating_sub(offset));\n\n                    // Read only the requested range\n                    let mut file = File::open(&source.path)?;\n                    file.seek(std::io::SeekFrom::Start(file_offset))?;\n\n                    let mut buffer = vec![0u8; read_length];\n                    file.read_exact(&mut buffer)?;\n                    Ok(buffer)\n                } else {\n                    Err(std::io::Error::new(\n                        std::io::ErrorKind::InvalidData,\n                        format!(\n                            \"Storage boundaries not available for key '{}'. Cannot perform lazy loading.\",\n                            storage_key\n                        ),\n                    ))\n                }\n            }\n        }\n    }\n\n    /// Check if a key exists\n    pub fn contains(&self, key: &str) -> bool {\n        match self {\n            Self::Zip(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.contains(key)\n            }\n            Self::Tar(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.contains(key)\n            }\n            Self::LegacyMultiStorage(_) => true, // Legacy format has all data\n        }\n    }\n\n    /// Get list of available keys (for ZIP sources)\n    pub fn keys(&self) -> Vec<String> {\n        match self {\n            Self::Zip(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.data_files()\n            }\n            Self::Tar(source) => {\n                let source = source\n                    .lock()\n                    .unwrap_or_else(|poisoned| poisoned.into_inner());\n                source.keys()\n            }\n            Self::LegacyMultiStorage(_) => vec![], // Legacy format doesn't have distinct keys\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/mod.rs",
    "content": "//! PyTorch format support for burn-store.\n//!\n//! This module provides comprehensive support for loading PyTorch model files (.pth, .pt)\n//! into Burn, with automatic weight transformation and flexible configuration options.\n//!\n//! ## Features\n//!\n//! - **Direct .pth/.pt file loading**: Load PyTorch checkpoint and state dict files\n//! - **Automatic weight transformation**: `PyTorchToBurnAdapter` is applied by default:\n//!   - Linear layer weights are automatically transposed\n//!   - Normalization parameters are renamed (gamma → weight, beta → bias)\n//!   - Conv2d weights maintain their format\n//! - **Flexible filtering**: Load only specific layers or parameters\n//! - **Key remapping**: Rename tensors during loading to match your model structure\n//! - **Partial loading**: Continue even when some tensors are missing\n//!\n//! ## Example\n//!\n//! ```rust,ignore\n//! use burn_store::PytorchStore;\n//!\n//! // Load a PyTorch model (PyTorchToBurnAdapter is applied automatically)\n//! let mut store = PytorchStore::from_file(\"model.pth\")\n//!     .with_top_level_key(\"state_dict\")              // Access nested state dict\n//!     .with_regex(r\"^encoder\\..*\")                   // Only load encoder layers\n//!     .with_key_remapping(r\"^fc\\.\", \"linear.\")       // Rename fc -> linear\n//!     .allow_partial(true);                          // Skip missing tensors\n//!\n//! let mut model = MyModel::new(&device);\n//! let result = model.load_from(&mut store)?;\n//!\n//! println!(\"Loaded {} tensors\", result.applied.len());\n//! if !result.missing.is_empty() {\n//!     println!(\"Missing tensors: {:?}\", result.missing);\n//! }\n//! ```\n\npub mod lazy_data;\npub mod pickle_reader;\npub mod reader;\npub mod store;\n\n#[cfg(test)]\npub mod tests;\n\n// Main public interface\npub use reader::{PytorchError, PytorchReader};\npub use store::{PytorchStore, PytorchStoreError};\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/pickle_reader.rs",
    "content": "//! Just enough pickle support to be able to read PyTorch checkpoints.\n//!\n//! This implementation is based on the candle project's pickle loader with significant\n//! modifications for improved separation of concerns and extended PyTorch compatibility.\n//!\n//! Original source: <https://github.com/huggingface/candle/blob/main/candle-core/src/pickle.rs>\n//!\n//! Modifications include:\n//! - Lazy tensor data loading for memory efficiency\n//! - Extended PyTorch version compatibility (0.1.10 - 2.x)\n//! - Better separation of pickle parsing and tensor extraction\n//! - Support for both legacy and modern PyTorch formats\nuse crate::TensorSnapshot;\nuse crate::pytorch::lazy_data::LazyDataSource;\nuse alloc::rc::Rc;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\nuse burn_core::module::ParamId;\nuse burn_tensor::{BoolStore, DType, TensorData};\nuse byteorder::{LittleEndian, ReadBytesExt};\nuse half::{bf16, f16};\nuse std::collections::HashMap;\nuse std::io::{self, BufRead};\nuse std::sync::Arc;\n\n/// Error type for pickle operations\n#[derive(Debug)]\npub enum PickleError {\n    Io(io::Error),\n    InvalidOpCode(u8),\n    InvalidProtocol(u8),\n    UnexpectedOpCode(OpCode),\n    UnsupportedType(String),\n    InvalidData(String),\n    StackUnderflow,\n    MemoNotFound(u32),\n    InvalidShapeOrType,\n}\n\nimpl From<io::Error> for PickleError {\n    fn from(e: io::Error) -> Self {\n        PickleError::Io(e)\n    }\n}\n\nimpl std::fmt::Display for PickleError {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            PickleError::Io(e) => write!(f, \"IO error: {}\", e),\n            PickleError::InvalidOpCode(code) => write!(\n                f,\n                \"Invalid pickle opcode: 0x{:02x}. The file may be corrupted or use an unsupported pickle protocol.\",\n                code\n            ),\n            PickleError::InvalidProtocol(proto) => write!(\n                f,\n                \"Invalid or unsupported pickle protocol version: {}. Supported versions are 2-5.\",\n                proto\n            ),\n            PickleError::UnexpectedOpCode(op) => {\n                write!(f, \"Unexpected pickle opcode {:?} in current context\", op)\n            }\n            PickleError::UnsupportedType(ty) => write!(\n                f,\n                \"Unsupported Python type '{}'. This may indicate a full model save rather than a state_dict.\",\n                ty\n            ),\n            PickleError::InvalidData(msg) => write!(f, \"Invalid data in pickle file: {}\", msg),\n            PickleError::StackUnderflow => {\n                write!(f, \"Pickle stack underflow - the file may be corrupted\")\n            }\n            PickleError::MemoNotFound(idx) => write!(\n                f,\n                \"Pickle memo reference {} not found - the file may be corrupted\",\n                idx\n            ),\n            PickleError::InvalidShapeOrType => {\n                write!(f, \"Invalid tensor shape or data type in PyTorch file\")\n            }\n        }\n    }\n}\n\nimpl std::error::Error for PickleError {}\n\ntype Result<T> = std::result::Result<T, PickleError>;\n\n/// Convert PyTorch storage type name to element size in bytes.\n///\n/// This is used to calculate storage sizes for lazy loading.\n/// The storage type names follow PyTorch's naming convention (e.g., \"FloatStorage\", \"BFloat16Storage\").\n///\n/// Returns an error for unknown storage types to avoid silently loading garbage data.\npub fn storage_type_to_element_size(storage_type: &str) -> std::result::Result<usize, String> {\n    match storage_type {\n        \"DoubleStorage\" | \"LongStorage\" | \"ComplexFloatStorage\" => Ok(8),\n        \"FloatStorage\" | \"IntStorage\" | \"ComplexHalfStorage\" => Ok(4),\n        \"HalfStorage\" | \"BFloat16Storage\" | \"ShortStorage\" => Ok(2),\n        \"ByteStorage\" | \"CharStorage\" | \"BoolStorage\" => Ok(1),\n        _ => Err(format!(\"Unknown storage type: {}\", storage_type)),\n    }\n}\n\n// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/\n#[repr(u8)]\n#[derive(Debug, Eq, PartialEq, Clone)]\npub enum OpCode {\n    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123\n    Proto = 0x80,\n    Global = b'c',\n    BinPut = b'q',\n    LongBinPut = b'r',\n    EmptyTuple = b')',\n    Reduce = b'R',\n    Mark = b'(',\n    BinUnicode = b'X',\n    ShortBinString = b'U',\n    BinInt = b'J',\n    Int = b'I',\n    Tuple = b't',\n    BinPersId = b'Q',\n    BinInt1 = b'K',\n    BinInt2 = b'M',\n    Tuple1 = 0x85,\n    Tuple2 = 0x86,\n    Tuple3 = 0x87,\n    NewTrue = 0x88,\n    NewFalse = 0x89,\n    None = b'N',\n    BinGet = b'h',\n    LongBinGet = b'j',\n    SetItem = b's',\n    SetItems = b'u',\n    EmptyDict = b'}',\n    Dict = b'd',\n    Build = b'b',\n    Stop = b'.',\n    NewObj = 0x81,\n    EmptyList = b']',\n    List = b'l',\n    BinFloat = b'G',\n    Append = b'a',\n    Appends = b'e',\n    Long1 = 0x8a,\n    Memoize = 0x94,\n}\n\n// Avoid using FromPrimitive so as not to drag another dependency.\nimpl TryFrom<u8> for OpCode {\n    type Error = u8;\n    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {\n        match value {\n            0x80 => Ok(Self::Proto),\n            b'c' => Ok(Self::Global),\n            b'q' => Ok(Self::BinPut),\n            b'r' => Ok(Self::LongBinPut),\n            b')' => Ok(Self::EmptyTuple),\n            b'R' => Ok(Self::Reduce),\n            b'(' => Ok(Self::Mark),\n            b'X' => Ok(Self::BinUnicode),\n            b'U' => Ok(Self::ShortBinString),\n            b'J' => Ok(Self::BinInt),\n            b'I' => Ok(Self::Int),\n            b't' => Ok(Self::Tuple),\n            b'Q' => Ok(Self::BinPersId),\n            b'K' => Ok(Self::BinInt1),\n            b'M' => Ok(Self::BinInt2),\n            b'N' => Ok(Self::None),\n            0x85 => Ok(Self::Tuple1),\n            0x86 => Ok(Self::Tuple2),\n            0x87 => Ok(Self::Tuple3),\n            0x88 => Ok(Self::NewTrue),\n            0x89 => Ok(Self::NewFalse),\n            b'h' => Ok(Self::BinGet),\n            b'j' => Ok(Self::LongBinGet),\n            b's' => Ok(Self::SetItem),\n            b'u' => Ok(Self::SetItems),\n            b'}' => Ok(Self::EmptyDict),\n            b'd' => Ok(Self::Dict),\n            b'b' => Ok(Self::Build),\n            b'.' => Ok(Self::Stop),\n            0x81 => Ok(Self::NewObj),\n            b']' => Ok(Self::EmptyList),\n            b'l' => Ok(Self::List),\n            b'G' => Ok(Self::BinFloat),\n            b'a' => Ok(Self::Append),\n            b'e' => Ok(Self::Appends),\n            0x8a => Ok(Self::Long1),\n            0x94 => Ok(Self::Memoize),\n            value => Err(value),\n        }\n    }\n}\n\nfn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {\n    let mut data: Vec<u8> = Vec::with_capacity(32);\n    r.read_until(b'\\n', &mut data)?;\n    data.pop();\n    if data.last() == Some(&b'\\r') {\n        data.pop();\n    }\n    Ok(data)\n}\n\nfn buf_to_str(buf: &[u8]) -> Result<String> {\n    String::from_utf8(buf.to_vec())\n        .map_err(|e| PickleError::InvalidData(format!(\"Invalid UTF-8: {}\", e)))\n}\n\n#[derive(Debug, Clone)]\npub enum Object {\n    Class {\n        module_name: String,\n        name: String,\n    },\n    String(String),\n    Int(i64),\n    Float(f64),\n    Bool(bool),\n    None,\n    Tuple(Vec<Object>),\n    List(Vec<Object>),\n    Dict(HashMap<String, Object>),\n    Persistent(Vec<u8>),\n    PersistentTuple(Vec<Object>),\n    Reduce {\n        callable: Box<Object>,\n        args: Box<Object>,\n    },\n    Build {\n        callable: Box<Object>,\n        args: Box<Object>,\n    },\n    TorchParam(TensorSnapshot),\n}\n\nfn rebuild_from_type_v2(\n    o: Object,\n    memo: &mut HashMap<u32, Object>,\n    data_source: &Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    let args = if let Object::Tuple(args) = o {\n        if args.is_empty() {\n            return Err(PickleError::InvalidData(\n                \"rebuild_from_type_v2: empty args\".to_string(),\n            ));\n        }\n        args\n    } else {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_from_type_v2: expected tuple got {:?}\",\n            o\n        )));\n    };\n    let func = &args[0];\n    match func {\n        Object::Class { module_name, name } => {\n            let module_name = module_name.as_str();\n            let name = name.as_str();\n            // For rebuild_tensor_v2, the args might already be in a tuple\n            let actual_args = if args.len() == 2 && matches!(&args[1], Object::Tuple(_)) {\n                // If there's only one arg and it's a tuple, use it directly\n                args[1].clone()\n            } else {\n                // Otherwise, wrap the remaining args in a tuple\n                Object::Tuple(args[1..].to_vec())\n            };\n            if module_name == \"torch._utils\" && name == \"_rebuild_tensor_v2\" {\n                rebuild_tensor_v2(actual_args, memo, data_source)\n            } else if module_name == \"torch._utils\" && name == \"_rebuild_tensor\" {\n                // Legacy _rebuild_tensor (PyTorch < 1.6)\n                // Same as v2 but with fewer arguments: (storage, storage_offset, size, stride)\n                rebuild_tensor(actual_args, memo, data_source)\n            } else if module_name == \"torch._tensor\" && name == \"_rebuild_from_type_v2\" {\n                rebuild_from_type_v2(actual_args, memo, data_source)\n            } else if module_name == \"torch._utils\" && name == \"_rebuild_parameter\" {\n                rebuild_parameter(actual_args, memo, data_source)\n            } else if module_name == \"collections\" && name == \"OrderedDict\" {\n                // OrderedDict is treated as a regular Dict in our implementation\n                Ok(Object::Dict(HashMap::new()))\n            } else {\n                Err(PickleError::UnsupportedType(format!(\n                    \"{}.{}\",\n                    module_name, name\n                )))\n            }\n        }\n        _ => Err(PickleError::InvalidData(format!(\n            \"rebuild_from_type_v2: expected class got {:?}\",\n            func\n        ))),\n    }\n}\n\nfn rebuild_parameter(\n    args: Object,\n    memo: &mut HashMap<u32, Object>,\n    data_source: &Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    let args = if let Object::Tuple(args) = args {\n        if args.is_empty() {\n            return Err(PickleError::InvalidData(\n                \"rebuild_parameter: empty args\".to_string(),\n            ));\n        }\n        args\n    } else {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_parameter: expected tuple got {:?}\",\n            args\n        )));\n    };\n    let data = &args[0];\n    let tensor = match data {\n        Object::Reduce {\n            callable: _,\n            args: _,\n        } => rebuild_from_type_v2(data.clone(), memo, data_source)?,\n        _ => data.clone(),\n    };\n    Ok(tensor)\n}\n\n/// Parse storage argument and extract storage info and tuple.\nfn parse_storage_arg(arg: &Object, fn_name: &str) -> Result<(Vec<u8>, Option<Vec<Object>>)> {\n    match arg {\n        Object::Persistent(data) => Ok((data.clone(), None)),\n        Object::PersistentTuple(tuple) => Ok((vec![], Some(tuple.clone()))),\n        // Also accept regular Tuple for TAR format compatibility\n        Object::Tuple(tuple) => Ok((vec![], Some(tuple.clone()))),\n        _ => Err(PickleError::InvalidData(format!(\n            \"{}: expected persistent id got {:?}\",\n            fn_name, arg\n        ))),\n    }\n}\n\n/// Parse shape argument.\nfn parse_shape_arg(arg: &Object, fn_name: &str) -> Result<Vec<usize>> {\n    match arg {\n        Object::Tuple(shape) => shape\n            .iter()\n            .map(|x| match x {\n                Object::Int(i) => Ok(*i as usize),\n                _ => Err(PickleError::InvalidData(\n                    \"shape must contain ints\".to_string(),\n                )),\n            })\n            .collect::<Result<Vec<_>>>(),\n        _ => Err(PickleError::InvalidData(format!(\n            \"{}: expected shape tuple got {:?}\",\n            fn_name, arg\n        ))),\n    }\n}\n\n/// Legacy _rebuild_tensor function for PyTorch < 1.6.\n/// Thin wrapper that parses 4 arguments and calls rebuild_tensor_impl.\nfn rebuild_tensor(\n    args: Object,\n    _memo: &mut HashMap<u32, Object>,\n    data_source: &Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    let args = if let Object::Tuple(args) = args {\n        args\n    } else {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_tensor: expected tuple got {:?}\",\n            args\n        )));\n    };\n\n    if args.len() < 4 {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_tensor: expected at least 4 args, got {}\",\n            args.len()\n        )));\n    }\n\n    let (storage_info, storage_tuple) = parse_storage_arg(&args[0], \"rebuild_tensor\")?;\n    let storage_offset = match &args[1] {\n        Object::Int(offset) => *offset as usize,\n        _ => 0,\n    };\n    let shape = parse_shape_arg(&args[2], \"rebuild_tensor\")?;\n\n    rebuild_tensor_impl(\n        storage_info,\n        storage_tuple,\n        storage_offset,\n        shape,\n        data_source,\n    )\n}\n\n/// Modern _rebuild_tensor_v2 function for PyTorch >= 1.6.\n/// Thin wrapper that parses 5+ arguments and calls rebuild_tensor_impl.\nfn rebuild_tensor_v2(\n    args: Object,\n    _memo: &mut HashMap<u32, Object>,\n    data_source: &Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    let args = if let Object::Tuple(args) = args {\n        args\n    } else {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_tensor_v2: expected tuple got {:?}\",\n            args\n        )));\n    };\n\n    if args.len() < 5 {\n        return Err(PickleError::InvalidData(format!(\n            \"rebuild_tensor_v2: expected at least 5 args, got {}\",\n            args.len()\n        )));\n    }\n\n    let (storage_info, storage_tuple) = parse_storage_arg(&args[0], \"rebuild_tensor_v2\")?;\n    let storage_offset = match &args[1] {\n        Object::Int(offset) => *offset as usize,\n        _ => 0,\n    };\n    let shape = parse_shape_arg(&args[2], \"rebuild_tensor_v2\")?;\n    // args[3] is stride (unused)\n    // args[4] is requires_grad (unused)\n    // args[5] is backward_hooks (unused)\n\n    rebuild_tensor_impl(\n        storage_info,\n        storage_tuple,\n        storage_offset,\n        shape,\n        data_source,\n    )\n}\n\n/// Helper to convert storage type name to DType.\nfn storage_type_to_dtype(storage_type: &str) -> Result<DType> {\n    match storage_type {\n        \"FloatStorage\" => Ok(DType::F32),\n        \"DoubleStorage\" => Ok(DType::F64),\n        \"HalfStorage\" => Ok(DType::F16),\n        \"BFloat16Storage\" => Ok(DType::BF16),\n        \"LongStorage\" => Ok(DType::I64),\n        \"IntStorage\" => Ok(DType::I32),\n        \"ShortStorage\" => Ok(DType::I16),\n        \"CharStorage\" => Ok(DType::I8),\n        \"ByteStorage\" => Ok(DType::U8),\n        \"BoolStorage\" => Ok(DType::Bool(BoolStore::Native)),\n        _ => Err(PickleError::InvalidData(format!(\n            \"Unknown storage type: {}\",\n            storage_type\n        ))),\n    }\n}\n\n/// Core implementation for rebuilding tensors.\n/// Shared by both rebuild_tensor (legacy) and rebuild_tensor_v2 (modern).\nfn rebuild_tensor_impl(\n    storage_info: Vec<u8>,\n    storage_tuple: Option<Vec<Object>>,\n    storage_offset: usize,\n    shape: Vec<usize>,\n    data_source: &Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    // Parse the storage info to extract dtype and storage key\n    // The persistent ID is typically a tuple like: ('storage', 'FloatStorage', '0', 'cpu', 4)\n    let (dtype, storage_key) = if let Some(tuple) = storage_tuple {\n        // Direct tuple access\n        if tuple.len() >= 3 {\n            let storage_type = match &tuple[1] {\n                Object::String(s) => s.as_str(),\n                Object::Class {\n                    module_name: _,\n                    name,\n                } => name.as_str(),\n                other => {\n                    return Err(PickleError::InvalidData(format!(\n                        \"Expected storage type as String or Class, got {:?}\",\n                        other\n                    )));\n                }\n            };\n            let dtype = storage_type_to_dtype(storage_type)?;\n            let key = match &tuple[2] {\n                Object::String(s) => s.clone(),\n                other => {\n                    return Err(PickleError::InvalidData(format!(\n                        \"Expected storage key as String, got {:?}\",\n                        other\n                    )));\n                }\n            };\n            (dtype, key)\n        } else {\n            return Err(PickleError::InvalidData(format!(\n                \"Storage tuple too short, expected at least 3 elements, got {}\",\n                tuple.len()\n            )));\n        }\n    } else if !storage_info.is_empty() {\n        // Legacy string-based parsing\n        let storage_str = String::from_utf8_lossy(&storage_info);\n        if storage_str.starts_with(\"Tuple(\") {\n            // Parse from the debug representation we stored\n            let parts: Vec<&str> = storage_str\n                .trim_start_matches(\"Tuple(\")\n                .trim_end_matches(\")\")\n                .split(\", \")\n                .map(|s| {\n                    let trimmed = s.trim_matches('\"');\n                    if let Some(inner) = trimmed\n                        .strip_prefix(\"Object::String(\\\"\")\n                        .and_then(|s| s.strip_suffix(\"\\\")\"))\n                    {\n                        inner\n                    } else {\n                        trimmed\n                    }\n                })\n                .collect();\n\n            if parts.len() >= 3 {\n                let dtype = storage_type_to_dtype(parts[1])?;\n                (dtype, parts[2].to_string())\n            } else {\n                return Err(PickleError::InvalidData(format!(\n                    \"Storage info tuple too short, expected at least 3 parts, got {}\",\n                    parts.len()\n                )));\n            }\n        } else {\n            return Err(PickleError::InvalidData(format!(\n                \"Invalid storage info format: {}\",\n                storage_str\n            )));\n        }\n    } else {\n        return Err(PickleError::InvalidData(\n            \"No storage information available\".to_string(),\n        ));\n    };\n\n    // If no data source, we can't load tensor data\n    let data_source = match data_source {\n        Some(ds) => ds.clone(),\n        None => {\n            return Err(PickleError::InvalidData(\n                \"Cannot load tensor data without a data source\".to_string(),\n            ));\n        }\n    };\n\n    // Create clones for the closure\n    let data_source_clone = data_source.clone();\n    let shape_clone = shape.clone();\n\n    // Find the correct data file key\n    let data_file_key = {\n        let exact_key = format!(\"data/{}\", storage_key);\n        if data_source.contains(&exact_key) {\n            exact_key\n        } else {\n            // Try other patterns\n            data_source\n                .keys()\n                .into_iter()\n                .find(|key| {\n                    key.ends_with(&format!(\"/data/{}\", storage_key))\n                        || (key.contains(\"/data/\") && key.rsplit('/').next() == Some(&storage_key))\n                })\n                .unwrap_or_else(|| format!(\"data/{}\", storage_key))\n        }\n    };\n\n    // Track storage usage IMMEDIATELY for lazy boundary detection\n    // This must happen BEFORE creating the closure, not inside it!\n    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source {\n        let source = source\n            .lock()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n        let num_elements: usize = shape.iter().product();\n        let bytes_needed = storage_offset * dtype.size() + num_elements * dtype.size();\n        source.track_storage_usage(&storage_key, 0, bytes_needed);\n    }\n\n    // Create a TensorSnapshot with a closure that loads the actual data on-demand\n    Ok(Object::TorchParam(TensorSnapshot::from_closure(\n        Rc::new(move || {\n            // Load data only when needed\n            if let Ok(data) = data_source_clone.read(&data_file_key) {\n                // Parse the binary data based on dtype\n                let num_elements = shape_clone.iter().product::<usize>().max(1);\n\n                // Use dtype.size() to get element size in bytes\n                let element_size = dtype.size();\n\n                // Apply storage offset\n                let offset_bytes = storage_offset * element_size;\n                if offset_bytes >= data.len() {\n                    return Ok(TensorData::new(\n                        vec![0.0f32; num_elements],\n                        shape_clone.clone(),\n                    ));\n                }\n\n                let data_slice = &data[offset_bytes..];\n                let available_elements = data_slice.len() / element_size;\n                let elements_to_read = num_elements.min(available_elements);\n\n                // Convert bytes to the appropriate type\n                match dtype {\n                    DType::F32 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let bytes = [\n                                data_slice[i * element_size],\n                                data_slice[i * element_size + 1],\n                                data_slice[i * element_size + 2],\n                                data_slice[i * element_size + 3],\n                            ];\n                            values.push(f32::from_le_bytes(bytes));\n                        }\n                        // Pad with zeros if needed\n                        values.resize(num_elements, 0.0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::F64 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 8];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(f64::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0.0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::I64 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 8];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(i64::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::I32 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 4];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(i32::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::I16 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 2];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(i16::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::I8 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for &byte in data_slice.iter().take(elements_to_read) {\n                            values.push(byte as i8);\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::Bool(BoolStore::Native) => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for &byte in data_slice.iter().take(elements_to_read) {\n                            values.push(byte != 0);\n                        }\n                        values.resize(num_elements, false);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::F16 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 2];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(f16::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, f16::ZERO);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::BF16 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 2];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(bf16::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, bf16::ZERO);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::U8 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for &byte in data_slice.iter().take(elements_to_read) {\n                            values.push(byte);\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::U16 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 2];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(u16::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::U32 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 4];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(u32::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    DType::U64 => {\n                        let mut values = Vec::with_capacity(num_elements);\n                        for i in 0..elements_to_read {\n                            let mut bytes = [0u8; 8];\n                            bytes.copy_from_slice(\n                                &data_slice[i * element_size..(i + 1) * element_size],\n                            );\n                            values.push(u64::from_le_bytes(bytes));\n                        }\n                        values.resize(num_elements, 0);\n                        Ok(TensorData::new(values, shape_clone.clone()))\n                    }\n                    _ => {\n                        // For any remaining unsupported types, return an error\n                        Err(crate::TensorSnapshotError::DataError(format!(\n                            \"Unsupported dtype for tensor data reading: {:?}\",\n                            dtype\n                        )))\n                    }\n                }\n            } else {\n                // If no data file found, return zeros of the appropriate type\n                let num_elements = shape_clone.iter().product::<usize>().max(1);\n                match dtype {\n                    DType::F32 => Ok(TensorData::new(\n                        vec![0.0f32; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::F64 => Ok(TensorData::new(\n                        vec![0.0f64; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::F16 => Ok(TensorData::new(\n                        vec![f16::ZERO; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::BF16 => Ok(TensorData::new(\n                        vec![bf16::ZERO; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::I64 => Ok(TensorData::new(\n                        vec![0i64; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::I32 => Ok(TensorData::new(\n                        vec![0i32; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::I16 => Ok(TensorData::new(\n                        vec![0i16; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::I8 => Ok(TensorData::new(\n                        vec![0i8; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::U8 => Ok(TensorData::new(\n                        vec![0u8; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::U16 => Ok(TensorData::new(\n                        vec![0u16; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::U32 => Ok(TensorData::new(\n                        vec![0u32; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::U64 => Ok(TensorData::new(\n                        vec![0u64; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    DType::Bool(BoolStore::Native) => Ok(TensorData::new(\n                        vec![false; num_elements],\n                        shape_clone.clone(),\n                    )),\n                    _ => {\n                        // For any remaining unsupported types, return an error\n                        Err(crate::TensorSnapshotError::DataError(format!(\n                            \"Unsupported dtype for tensor data reading: {:?}\",\n                            dtype\n                        )))\n                    }\n                }\n            }\n        }),\n        dtype,\n        shape.into(),\n        vec![],         // path_stack\n        vec![],         // container_stack\n        ParamId::new(), // tensor_id\n    )))\n}\n\npub struct Stack {\n    stack: Vec<Object>,\n    memo: HashMap<u32, Object>,\n    data_source: Option<Arc<LazyDataSource>>,\n}\n\nimpl Default for Stack {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Stack {\n    pub fn new() -> Self {\n        // For cases where no data source is needed (pure pickle without tensor data)\n        Self {\n            stack: Vec::new(),\n            memo: HashMap::new(),\n            data_source: None,\n        }\n    }\n\n    pub fn with_data_source(data_source: Arc<LazyDataSource>) -> Self {\n        Self {\n            stack: Vec::new(),\n            memo: HashMap::new(),\n            data_source: Some(data_source),\n        }\n    }\n\n    fn push(&mut self, o: Object) {\n        self.stack.push(o)\n    }\n\n    fn pop(&mut self) -> Result<Object> {\n        match self.stack.pop() {\n            None => Err(PickleError::StackUnderflow),\n            Some(o) => Ok(o),\n        }\n    }\n\n    fn top(&self) -> Result<Object> {\n        match self.stack.last() {\n            None => Err(PickleError::StackUnderflow),\n            Some(o) => Ok(o.clone()),\n        }\n    }\n\n    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {\n        let marker_pos = self\n            .stack\n            .iter()\n            .rposition(|o| {\n                matches!(o, Object::Class { module_name, name }\n                if module_name == \"mark\" && name == \"mark\")\n            })\n            .ok_or(PickleError::InvalidData(\"marker not found\".to_string()))?;\n\n        let result = self.stack.split_off(marker_pos + 1);\n        self.stack.pop(); // Remove the marker\n        Ok(result)\n    }\n\n    fn last_mut(&mut self) -> Result<&mut Object> {\n        match self.stack.last_mut() {\n            None => Err(PickleError::StackUnderflow),\n            Some(o) => Ok(o),\n        }\n    }\n\n    fn push_mark(&mut self) {\n        self.stack.push(Object::Class {\n            module_name: \"mark\".to_string(),\n            name: \"mark\".to_string(),\n        });\n    }\n\n    fn memo_get(&self, idx: u32) -> Result<Object> {\n        self.memo\n            .get(&idx)\n            .cloned()\n            .ok_or(PickleError::MemoNotFound(idx))\n    }\n\n    fn memo_put(&mut self, idx: u32, obj: Object) {\n        self.memo.insert(idx, obj);\n    }\n\n    fn memo_len(&self) -> usize {\n        self.memo.len()\n    }\n}\n\nfn read_global<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    let module_name = buf_to_str(&read_to_newline(r)?)?;\n    let name = buf_to_str(&read_to_newline(r)?)?;\n    stack.push(Object::Class { module_name, name });\n    Ok(())\n}\n\nfn read_long1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    let len = r.read_u8()? as usize;\n    let mut data = vec![0u8; len];\n    r.read_exact(&mut data)?;\n    // Handle little-endian signed integer\n    let mut value = 0i64;\n    for (i, &byte) in data.iter().enumerate().take(8) {\n        // Only process up to 8 bytes for i64, and use wrapping to avoid overflow\n        value |= (byte as i64).wrapping_shl((i as u32) * 8);\n    }\n    // Handle sign extension for negative numbers\n    if len < 8 && data.last().is_some_and(|&b| b & 0x80 != 0) {\n        // Sign extend\n        for i in len..8 {\n            value |= 0xffi64.wrapping_shl((i as u32) * 8);\n        }\n    }\n    stack.push(Object::Int(value));\n    Ok(())\n}\n\nfn read_string<R: BufRead>(r: &mut R, stack: &mut Stack, len: usize) -> Result<()> {\n    let mut data = vec![0u8; len];\n    r.read_exact(&mut data)?;\n    let s = buf_to_str(&data)?;\n    stack.push(Object::String(s));\n    Ok(())\n}\n\nfn read_bin_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    let v = r.read_i32::<LittleEndian>()?;\n    stack.push(Object::Int(v as i64));\n    Ok(())\n}\n\nfn read_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    // INT opcode reads an integer as ASCII string followed by newline\n    let line = read_to_newline(r)?;\n    let s = buf_to_str(&line)?;\n    let v = s\n        .parse::<i64>()\n        .map_err(|e| PickleError::InvalidData(format!(\"Invalid INT value '{}': {}\", s, e)))?;\n    stack.push(Object::Int(v));\n    Ok(())\n}\n\nfn read_bin_int1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    let v = r.read_u8()?;\n    stack.push(Object::Int(v as i64));\n    Ok(())\n}\n\nfn read_bin_int2<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    let v = r.read_u16::<LittleEndian>()?;\n    stack.push(Object::Int(v as i64));\n    Ok(())\n}\n\nfn read_bin_float<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {\n    // Python's BINFLOAT uses big-endian encoding\n    let v = r.read_f64::<byteorder::BigEndian>()?;\n    stack.push(Object::Float(v));\n    Ok(())\n}\n\npub fn read_pickle<R: BufRead>(r: &mut R) -> Result<Object> {\n    // For pure pickle without tensor data, no data source is needed\n    read_pickle_with_optional_data(r, None)\n}\n\n/// Skip over a pickle without parsing it fully\n/// This is useful for legacy format where we need to skip the main object\n/// that contains tensors but we don't have a data source yet\npub fn skip_pickle<R: BufRead>(r: &mut R) -> Result<()> {\n    // Read the protocol marker if present\n    let mut first_byte = [0u8; 1];\n    r.read_exact(&mut first_byte)?;\n\n    if first_byte[0] == 0x80 {\n        // PROTO marker - read protocol version\n        let mut proto_version = [0u8; 1];\n        r.read_exact(&mut proto_version)?;\n    }\n    // If not PROTO, the first byte is an opcode - continue to main loop\n\n    // Helper to skip until newline\n    fn skip_line<R: BufRead>(r: &mut R) -> Result<()> {\n        let mut buf = Vec::new();\n        r.read_until(b'\\n', &mut buf)?;\n        Ok(())\n    }\n\n    // Helper to skip length-prefixed data\n    fn skip_length_prefixed<R: BufRead>(r: &mut R, length: usize) -> Result<()> {\n        let mut skip_buf = vec![0u8; length.min(8192)];\n        let mut skipped = 0;\n        while skipped < length {\n            let to_skip = (length - skipped).min(skip_buf.len());\n            r.read_exact(&mut skip_buf[..to_skip])?;\n            skipped += to_skip;\n        }\n        Ok(())\n    }\n\n    // Process first byte if it wasn't PROTO\n    let mut pending_byte = if first_byte[0] != 0x80 {\n        Some(first_byte[0])\n    } else {\n        None\n    };\n\n    // Scan until we find STOP (0x2e) opcode\n    loop {\n        let byte = if let Some(b) = pending_byte.take() {\n            b\n        } else {\n            let mut byte = [0u8; 1];\n            r.read_exact(&mut byte)?;\n            byte[0]\n        };\n\n        match byte {\n            0x2e => {\n                // STOP - end of pickle\n                break;\n            }\n            // === Newline-terminated string opcodes ===\n            0x63 => {\n                // GLOBAL - two newline-terminated strings (module\\nname\\n)\n                skip_line(r)?;\n                skip_line(r)?;\n            }\n            0x69 => {\n                // INST - two newline-terminated strings\n                skip_line(r)?;\n                skip_line(r)?;\n            }\n            0x53 => {\n                // STRING - quoted string ending with newline\n                skip_line(r)?;\n            }\n            0x46 | 0x49 | 0x4c => {\n                // FLOAT, INT, LONG - newline-terminated ASCII\n                skip_line(r)?;\n            }\n            0x50 => {\n                // PERSID - newline-terminated persistent ID\n                skip_line(r)?;\n            }\n            // === Length-prefixed binary opcodes ===\n            0x58 | 0x42 | 0x43 | 0x54 | 0x55 | 0x56 | 0x8c | 0x8d | 0x8e => {\n                // String/bytes opcodes with length prefixes\n                let length = match byte {\n                    0x43 | 0x55 | 0x8c => {\n                        // SHORT versions - 1 byte length\n                        let mut len_byte = [0u8; 1];\n                        r.read_exact(&mut len_byte)?;\n                        len_byte[0] as usize\n                    }\n                    0x42 | 0x54 | 0x58 | 0x56 => {\n                        // Regular versions - 4 byte length\n                        let mut len_bytes = [0u8; 4];\n                        r.read_exact(&mut len_bytes)?;\n                        u32::from_le_bytes(len_bytes) as usize\n                    }\n                    0x8d | 0x8e => {\n                        // 8-byte length versions\n                        let mut len_bytes = [0u8; 8];\n                        r.read_exact(&mut len_bytes)?;\n                        u64::from_le_bytes(len_bytes) as usize\n                    }\n                    _ => 0,\n                };\n                skip_length_prefixed(r, length)?;\n            }\n            // === Fixed-size integer opcodes ===\n            0x4b => {\n                // BININT1 - 1 byte\n                let mut buf = [0u8; 1];\n                r.read_exact(&mut buf)?;\n            }\n            0x4d => {\n                // BININT2 - 2 bytes\n                let mut buf = [0u8; 2];\n                r.read_exact(&mut buf)?;\n            }\n            0x4a => {\n                // BININT - 4 bytes (signed int)\n                let mut buf = [0u8; 4];\n                r.read_exact(&mut buf)?;\n            }\n            0x47 => {\n                // BINFLOAT - 8 bytes\n                let mut buf = [0u8; 8];\n                r.read_exact(&mut buf)?;\n            }\n            // === Variable-length integer opcodes ===\n            0x8a => {\n                // LONG1 - 1 byte length, then that many bytes\n                let mut len_byte = [0u8; 1];\n                r.read_exact(&mut len_byte)?;\n                let length = len_byte[0] as usize;\n                skip_length_prefixed(r, length)?;\n            }\n            0x8b => {\n                // LONG4 - 4 byte length, then that many bytes\n                let mut len_bytes = [0u8; 4];\n                r.read_exact(&mut len_bytes)?;\n                let length = u32::from_le_bytes(len_bytes) as usize;\n                skip_length_prefixed(r, length)?;\n            }\n            // === Memo opcodes ===\n            0x71 | 0x68 => {\n                // BINPUT, BINGET - 1 byte index\n                let mut buf = [0u8; 1];\n                r.read_exact(&mut buf)?;\n            }\n            0x72 | 0x6a => {\n                // LONG_BINPUT, LONG_BINGET - 4 byte index\n                let mut buf = [0u8; 4];\n                r.read_exact(&mut buf)?;\n            }\n            0x67 | 0x70 => {\n                // GET, PUT - newline-terminated decimal index\n                skip_line(r)?;\n            }\n            // === Extension opcodes ===\n            0x82 => {\n                // EXT1 - 1 byte code\n                let mut buf = [0u8; 1];\n                r.read_exact(&mut buf)?;\n            }\n            0x83 => {\n                // EXT2 - 2 byte code\n                let mut buf = [0u8; 2];\n                r.read_exact(&mut buf)?;\n            }\n            0x84 => {\n                // EXT4 - 4 byte code\n                let mut buf = [0u8; 4];\n                r.read_exact(&mut buf)?;\n            }\n            // === Frame opcode (protocol 4+) ===\n            0x95 => {\n                // FRAME - 8 byte frame size (we don't actually use framing, just skip the size)\n                let mut buf = [0u8; 8];\n                r.read_exact(&mut buf)?;\n            }\n            // === Opcodes with no additional data ===\n            // These just manipulate the stack or are markers\n            0x28 | 0x29 | 0x30 | 0x31 | 0x32 | // MARK, TUPLE, POP, POP_MARK, DUP\n            0x4e | 0x52 | 0x5d | 0x5b | 0x7d | // NONE, REDUCE, LIST, EMPTY_LIST, EMPTY_DICT\n            0x61 | 0x62 | 0x64 | 0x65 | 0x73 | // APPEND, BUILD, DICT, APPENDS, SETITEM\n            0x74 | 0x75 | 0x85 | 0x86 | 0x87 | // TUPLE, SETITEMS, TUPLE1, TUPLE2, TUPLE3\n            0x88 | 0x89 | 0x8f | 0x90 | 0x91 | // NEWTRUE, NEWFALSE, STACK_GLOBAL, MEMOIZE, EMPTY_SET\n            0x92 | 0x93 | 0x94 | 0x51 | 0x81 => { // ADDITEMS, FROZENSET, NEWOBJ, BINPERSID, NEWOBJ_EX\n                // No additional data to skip\n            }\n            _ => {\n                // Unknown opcode - assume no additional data\n                // This is a best-effort approach\n            }\n        }\n    }\n\n    Ok(())\n}\n\npub fn read_pickle_with_data<R: BufRead>(\n    r: &mut R,\n    data_source: Arc<LazyDataSource>,\n) -> Result<Object> {\n    read_pickle_with_optional_data(r, Some(data_source))\n}\n\nfn get_dict_key(obj: Object) -> Result<String> {\n    match obj {\n        Object::String(s) => Ok(s),\n        Object::Int(i) => Ok(i.to_string()),\n        _ => Err(PickleError::InvalidData(format!(\n            \"dict key must be a valid type, got {obj:?}\"\n        ))),\n    }\n}\n\npub fn read_pickle_with_optional_data<R: BufRead>(\n    r: &mut R,\n    data_source: Option<Arc<LazyDataSource>>,\n) -> Result<Object> {\n    let mut stack = match data_source {\n        Some(ds) => Stack::with_data_source(ds),\n        None => Stack::new(),\n    };\n    loop {\n        let op_code = r.read_u8()?;\n        let op_code = OpCode::try_from(op_code).map_err(PickleError::InvalidOpCode)?;\n        match op_code {\n            OpCode::Proto => {\n                let version = r.read_u8()?;\n                if version > 5 {\n                    return Err(PickleError::InvalidProtocol(version));\n                }\n            }\n            OpCode::Global => read_global(r, &mut stack)?,\n            OpCode::BinInt => read_bin_int(r, &mut stack)?,\n            OpCode::Int => read_int(r, &mut stack)?,\n            OpCode::BinInt1 => read_bin_int1(r, &mut stack)?,\n            OpCode::BinInt2 => read_bin_int2(r, &mut stack)?,\n            OpCode::BinFloat => read_bin_float(r, &mut stack)?,\n            OpCode::BinUnicode => {\n                let len = r.read_u32::<LittleEndian>()? as usize;\n                read_string(r, &mut stack, len)?\n            }\n            OpCode::ShortBinString => {\n                let len = r.read_u8()? as usize;\n                read_string(r, &mut stack, len)?\n            }\n            OpCode::Long1 => read_long1(r, &mut stack)?,\n            OpCode::None => stack.push(Object::None),\n            OpCode::NewTrue => stack.push(Object::Bool(true)),\n            OpCode::NewFalse => stack.push(Object::Bool(false)),\n            OpCode::EmptyTuple => stack.push(Object::Tuple(Vec::new())),\n            OpCode::EmptyList => stack.push(Object::List(Vec::new())),\n            OpCode::EmptyDict => stack.push(Object::Dict(HashMap::new())),\n            OpCode::Tuple => {\n                let objs = stack.pop_to_marker()?;\n                stack.push(Object::Tuple(objs))\n            }\n            OpCode::Tuple1 => {\n                let obj = stack.pop()?;\n                stack.push(Object::Tuple(vec![obj]))\n            }\n            OpCode::Tuple2 => {\n                let obj2 = stack.pop()?;\n                let obj1 = stack.pop()?;\n                stack.push(Object::Tuple(vec![obj1, obj2]))\n            }\n            OpCode::Tuple3 => {\n                let obj3 = stack.pop()?;\n                let obj2 = stack.pop()?;\n                let obj1 = stack.pop()?;\n                stack.push(Object::Tuple(vec![obj1, obj2, obj3]))\n            }\n            OpCode::Append => {\n                let value = stack.pop()?;\n                match stack.last_mut()? {\n                    Object::List(list) => list.push(value),\n                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),\n                }\n            }\n            OpCode::Appends => {\n                let objs = stack.pop_to_marker()?;\n                match stack.last_mut()? {\n                    Object::List(list) => list.extend(objs),\n                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),\n                }\n            }\n            OpCode::SetItem => {\n                let value = stack.pop()?;\n                let key = stack.pop()?;\n                match stack.last_mut()? {\n                    Object::Dict(dict) => {\n                        if let Object::String(key) = key {\n                            dict.insert(key, value);\n                        } else {\n                            return Err(PickleError::InvalidData(\n                                \"dict key must be a string\".to_string(),\n                            ));\n                        }\n                    }\n                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),\n                }\n            }\n            OpCode::SetItems => {\n                let mut objs = stack.pop_to_marker()?;\n                if objs.len() % 2 != 0 {\n                    return Err(PickleError::InvalidData(\n                        \"setitems requires even number of objects\".to_string(),\n                    ));\n                }\n                match stack.last_mut()? {\n                    Object::Dict(dict) => {\n                        while !objs.is_empty() {\n                            let key = objs.remove(0);\n                            let value = objs.remove(0);\n                            let key = get_dict_key(key)?;\n                            dict.insert(key, value);\n                        }\n                    }\n                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),\n                }\n            }\n            OpCode::BinPut => {\n                let idx = r.read_u8()? as u32;\n                let obj = stack.top()?;\n                stack.memo_put(idx, obj);\n            }\n            OpCode::LongBinPut => {\n                let idx = r.read_u32::<LittleEndian>()?;\n                let obj = stack.top()?;\n                stack.memo_put(idx, obj);\n            }\n            OpCode::BinGet => {\n                let idx = r.read_u8()? as u32;\n                let obj = stack.memo_get(idx)?;\n                stack.push(obj);\n            }\n            OpCode::LongBinGet => {\n                let idx = r.read_u32::<LittleEndian>()?;\n                let obj = stack.memo_get(idx)?;\n                stack.push(obj);\n            }\n            OpCode::Mark => stack.push_mark(),\n            OpCode::BinPersId => {\n                let pid = stack.pop()?;\n                match pid {\n                    Object::String(s) => {\n                        stack.push(Object::Persistent(s.into_bytes()));\n                    }\n                    Object::Tuple(tuple) => {\n                        // The persistent ID is a tuple (e.g., ('storage', 'FloatStorage', '0', 'cpu', 4))\n                        // Store it as a PersistentTuple for proper handling\n                        stack.push(Object::PersistentTuple(tuple));\n                    }\n                    _ => {\n                        return Err(PickleError::InvalidData(format!(\n                            \"persistent id must be a string or tuple, got {:?}\",\n                            pid\n                        )));\n                    }\n                }\n            }\n            OpCode::Reduce => {\n                let args = stack.pop()?;\n                let callable = stack.pop()?;\n\n                // Check if this is an OrderedDict\n                if let Object::Class { module_name, name } = &callable {\n                    if module_name == \"collections\" && name == \"OrderedDict\" {\n                        // OrderedDict can be created with items: OrderedDict([(key1, val1), ...])\n                        // The args is typically a tuple containing a list of [key, value] pairs\n                        let mut dict = HashMap::new();\n\n                        // Extract items from args\n                        let items = match &args {\n                            Object::Tuple(tuple) if !tuple.is_empty() => {\n                                // Args is a tuple, get the first element (the list of items)\n                                match &tuple[0] {\n                                    Object::List(list) => Some(list.clone()),\n                                    _ => None,\n                                }\n                            }\n                            Object::List(list) => Some(list.clone()),\n                            _ => None,\n                        };\n\n                        if let Some(items) = items {\n                            for item in items {\n                                // Each item is a list/tuple of [key, value]\n                                match item {\n                                    Object::List(pair) | Object::Tuple(pair) if pair.len() >= 2 => {\n                                        if let Object::String(key) = &pair[0] {\n                                            dict.insert(key.clone(), pair[1].clone());\n                                        }\n                                    }\n                                    _ => {}\n                                }\n                            }\n                        }\n\n                        stack.push(Object::Dict(dict));\n                    } else {\n                        let _obj = Object::Reduce {\n                            callable: Box::new(callable.clone()),\n                            args: Box::new(args.clone()),\n                        };\n                        let obj = rebuild_from_type_v2(\n                            Object::Tuple(vec![callable, args]),\n                            &mut stack.memo,\n                            &stack.data_source,\n                        )?;\n                        stack.push(obj);\n                    }\n                } else {\n                    let _obj = Object::Reduce {\n                        callable: Box::new(callable.clone()),\n                        args: Box::new(args.clone()),\n                    };\n                    let obj = rebuild_from_type_v2(\n                        Object::Tuple(vec![callable, args]),\n                        &mut stack.memo,\n                        &stack.data_source,\n                    )?;\n                    stack.push(obj);\n                }\n            }\n            OpCode::Build => {\n                let args = stack.pop()?;\n                let obj = stack.pop()?;\n                match obj {\n                    Object::Dict(mut dict) => {\n                        // For dicts, BUILD updates with the args\n                        if let Object::Dict(update) = args {\n                            dict.extend(update);\n                        }\n                        stack.push(Object::Dict(dict));\n                    }\n                    _ => {\n                        stack.push(Object::Build {\n                            callable: Box::new(obj),\n                            args: Box::new(args),\n                        });\n                    }\n                }\n            }\n            OpCode::NewObj => {\n                let args = stack.pop()?;\n                let cls = stack.pop()?;\n                stack.push(Object::Reduce {\n                    callable: Box::new(cls),\n                    args: Box::new(args),\n                });\n            }\n            OpCode::Dict => {\n                let objs = stack.pop_to_marker()?;\n                let mut dict = HashMap::new();\n                if objs.len() % 2 != 0 {\n                    return Err(PickleError::InvalidData(\n                        \"dict requires even number of objects\".to_string(),\n                    ));\n                }\n                for chunk in objs.chunks(2) {\n                    let key = get_dict_key(chunk[0].clone())?;\n                    dict.insert(key, chunk[1].clone());\n                }\n                stack.push(Object::Dict(dict));\n            }\n            OpCode::List => {\n                let objs = stack.pop_to_marker()?;\n                stack.push(Object::List(objs));\n            }\n            OpCode::Memoize => {\n                // Store top of stack in memo without popping\n                // The memo index is the current number of items in the memo\n                let obj = stack.top()?;\n                let idx = stack.memo_len() as u32;\n                stack.memo_put(idx, obj);\n            }\n            OpCode::Stop => break,\n        }\n    }\n    stack.pop()\n}\n\n/// Load tensors from a pickle file (PyTorch checkpoint format)\npub fn read_pickle_tensors<R: BufRead>(reader: &mut R) -> Result<HashMap<String, TensorSnapshot>> {\n    let obj = read_pickle(reader)?;\n\n    // Extract tensors from the loaded object\n    let mut tensors = HashMap::new();\n    let mut path = Vec::new();\n    extract_tensors(&obj, &mut path, &mut tensors);\n\n    Ok(tensors)\n}\n\nfn extract_tensors<'a>(\n    obj: &'a Object,\n    path: &mut Vec<&'a str>,\n    tensors: &mut HashMap<String, TensorSnapshot>,\n) {\n    match obj {\n        Object::Dict(dict) => {\n            for (key, value) in dict {\n                path.push(key);\n                extract_tensors(value, path, tensors);\n                path.pop();\n            }\n        }\n        Object::TorchParam(snapshot) => {\n            // Only allocate the string here when we actually insert\n            tensors.insert(path.join(\".\"), snapshot.clone());\n        }\n        _ => {}\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/reader.rs",
    "content": "//! PyTorch file reader implementation.\n//!\n//! This module provides support for reading PyTorch checkpoint files (.pt/.pth).\n//!\n//! # Supported Formats\n//!\n//! ## 1. Modern ZIP Format (PyTorch 1.6+)\n//! Files are ZIP archives containing:\n//! - `data.pkl` or `archive/data.pkl`: Pickled tensor metadata\n//! - `data/` directory: Binary tensor data files\n//!\n//! ## 2. TAR Format (older torchvision models like AlexNet, SqueezeNet)\n//! TAR archives containing:\n//! - `sys_info`: System info pickle (endianness, type sizes)\n//! - `pickle`: OrderedDict mapping tensor names to storage keys\n//! - `tensors`: Tensor metadata (unused, metadata is in pickle)\n//! - `storages`: Count pickle + sequential (metadata, num_elements, raw data)\n//!\n//! ## 3. Legacy Pickle Format (PyTorch 0.1.10 - 1.5)\n//! Sequential pickle streams with the structure:\n//! - Magic number pickle (0x1950a86a20f9469cfc6c)\n//! - Protocol version pickle (e.g., 1001)\n//! - System info pickle (endianness, type sizes)\n//! - Model data pickle (state_dict or full model)\n//!\n//! ## 4. Simple Pickle Format\n//! Direct pickle file with a dictionary at the root, commonly used for\n//! manually saved state_dicts.\n//!\n//! # Compatibility\n//!\n//! The reader handles backward compatibility by detecting the file format\n//! automatically. Files from PyTorch 0.1.10 through current versions are\n//! supported, though full model saves (vs state_dict) may have limitations\n//! as they contain Python code references.\n\nuse crate::TensorSnapshot;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\nuse burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};\nuse serde::de::DeserializeOwned;\nuse std::collections::HashMap;\nuse std::fs::File;\nuse std::io::{BufReader, Read, Seek, SeekFrom};\nuse std::path::Path;\n\nuse super::lazy_data::LazyDataSource;\nuse super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};\nuse std::sync::Arc;\n\n/// Error type for PyTorch file operations\n#[derive(Debug)]\npub enum PytorchError {\n    /// IO error\n    Io(std::io::Error),\n    /// Pickle parsing error\n    Pickle(PickleError),\n    /// Zip archive error\n    Zip(zip::result::ZipError),\n    /// TAR archive error\n    Tar(std::io::Error),\n    /// Invalid file format\n    InvalidFormat(String),\n    /// Key not found\n    KeyNotFound(String),\n    /// Serde deserialization error\n    Serde(burn_core::record::serde::error::Error),\n}\n\nimpl From<std::io::Error> for PytorchError {\n    fn from(e: std::io::Error) -> Self {\n        PytorchError::Io(e)\n    }\n}\n\nimpl From<PickleError> for PytorchError {\n    fn from(e: PickleError) -> Self {\n        PytorchError::Pickle(e)\n    }\n}\n\nimpl From<zip::result::ZipError> for PytorchError {\n    fn from(e: zip::result::ZipError) -> Self {\n        PytorchError::Zip(e)\n    }\n}\n\nimpl From<burn_core::record::serde::error::Error> for PytorchError {\n    fn from(e: burn_core::record::serde::error::Error) -> Self {\n        PytorchError::Serde(e)\n    }\n}\n\nimpl std::fmt::Display for PytorchError {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            PytorchError::Io(e) => write!(f, \"IO error: {}\", e),\n            PytorchError::Pickle(e) => write!(\n                f,\n                \"Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.\",\n                e\n            ),\n            PytorchError::Zip(e) => write!(f, \"Zip archive error: {}\", e),\n            PytorchError::Tar(e) => write!(f, \"TAR archive error: {}\", e),\n            PytorchError::InvalidFormat(msg) => write!(f, \"Invalid PyTorch file format: {}\", msg),\n            PytorchError::KeyNotFound(key) => write!(\n                f,\n                \"Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.\",\n                key\n            ),\n            PytorchError::Serde(e) => write!(f, \"Serde deserialization error: {}\", e),\n        }\n    }\n}\n\nimpl std::error::Error for PytorchError {}\n\ntype Result<T> = std::result::Result<T, PytorchError>;\n\n/// Metadata about a PyTorch file\n///\n/// Contains information about the file format, version, and other properties\n/// that can be useful for debugging or compatibility checking.\n#[derive(Debug, Clone)]\npub struct PytorchMetadata {\n    /// Format version (e.g., \"1.0\" for modern ZIP format)\n    pub format_version: Option<String>,\n    /// File format type (ZIP, Legacy, or Pickle)\n    pub format_type: FileFormat,\n    /// Byte order (endianness) - currently only LittleEndian is supported\n    pub byte_order: ByteOrder,\n    /// Whether the file has storage alignment information\n    pub has_storage_alignment: bool,\n    /// PyTorch version that saved the file (if available)\n    pub pytorch_version: Option<String>,\n    /// Number of tensors in the file\n    pub tensor_count: usize,\n    /// Total size of tensor data in bytes (if available)\n    pub total_data_size: Option<usize>,\n}\n\nimpl PytorchMetadata {\n    /// Check if this is a modern format file (ZIP-based, PyTorch 1.6+)\n    pub fn is_modern_format(&self) -> bool {\n        matches!(self.format_type, FileFormat::Zip)\n    }\n\n    /// Check if this is a legacy format file (PyTorch 0.1.10 - 1.5)\n    pub fn is_legacy_format(&self) -> bool {\n        matches!(self.format_type, FileFormat::Legacy)\n    }\n}\n\n/// File format type\n#[derive(Debug, Clone, PartialEq)]\npub enum FileFormat {\n    /// ZIP-based format (PyTorch 1.6+)\n    Zip,\n    /// TAR-based format (older torchvision models)\n    Tar,\n    /// Legacy format (PyTorch 0.1.10 - 1.5)\n    Legacy,\n    /// Simple pickle file\n    Pickle,\n}\n\n/// Byte order (endianness)\n#[derive(Debug, Clone, PartialEq)]\npub enum ByteOrder {\n    LittleEndian,\n    BigEndian,\n}\n\n/// PyTorch checkpoint reader\n///\n/// This is the main interface for reading PyTorch checkpoint files (.pt/.pth).\n/// It supports multiple PyTorch formats including modern ZIP-based format (1.6+),\n/// legacy format (0.1.10-1.5), and simple pickle files.\n///\n/// # Example\n/// ```rust,no_run\n/// # use burn_store::pytorch::PytorchReader;\n/// # fn example() -> Result<(), Box<dyn std::error::Error>> {\n/// // Load a checkpoint file\n/// let reader = PytorchReader::new(\"model.pt\")?;\n///\n/// // Get tensor names\n/// let keys = reader.keys();\n///\n/// // Access a specific tensor\n/// if let Some(tensor) = reader.get(\"conv1.weight\") {\n///     let data = tensor.to_data(); // Materializes the tensor\n/// }\n///\n/// // Check file metadata\n/// println!(\"Format: {:?}\", reader.metadata().format_type);\n/// println!(\"Tensor count: {}\", reader.metadata().tensor_count);\n/// # Ok(())\n/// # }\n/// ```\npub struct PytorchReader {\n    tensors: HashMap<String, TensorSnapshot>,\n    metadata: PytorchMetadata,\n}\n\nimpl PytorchReader {\n    /// Load a PyTorch checkpoint file\n    ///\n    /// # Arguments\n    /// * `path` - Path to the PyTorch file (.pt or .pth)\n    ///\n    /// # Returns\n    /// A `PytorchReader` with lazy-loaded tensors and metadata\n    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {\n        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?;\n        Ok(Self { tensors, metadata })\n    }\n\n    /// Load a PyTorch checkpoint with a specific top-level key\n    ///\n    /// Many PyTorch checkpoints store the model weights under a specific key\n    /// like \"state_dict\", \"model\", or \"model_state_dict\".\n    ///\n    /// # Arguments\n    /// * `path` - Path to the PyTorch file\n    /// * `key` - Top-level key to extract (e.g., \"state_dict\")\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::pytorch::PytorchReader;\n    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {\n    /// let reader = PytorchReader::with_top_level_key(\"checkpoint.pt\", \"state_dict\")?;\n    /// # Ok(())\n    /// # }\n    /// ```\n    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {\n        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?;\n        Ok(Self { tensors, metadata })\n    }\n\n    /// Load from a reader\n    ///\n    /// This method is useful when loading from non-file sources like memory buffers.\n    /// Note: Metadata detection is limited when loading from a reader.\n    ///\n    /// # Arguments\n    /// * `reader` - Any type implementing `Read`\n    /// * `top_level_key` - Optional key to extract\n    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {\n        // For reader-based loading, we don't have full metadata access\n        let tensors = load_from_reader(reader, top_level_key)?;\n        let metadata = PytorchMetadata {\n            format_version: None,\n            format_type: FileFormat::Pickle, // Default assumption\n            byte_order: ByteOrder::LittleEndian,\n            has_storage_alignment: false,\n            pytorch_version: None,\n            tensor_count: tensors.len(),\n            total_data_size: None,\n        };\n        Ok(Self { tensors, metadata })\n    }\n\n    /// Get all tensor names\n    pub fn keys(&self) -> Vec<String> {\n        self.tensors.keys().cloned().collect()\n    }\n\n    /// Get a tensor by name\n    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {\n        self.tensors.get(name)\n    }\n\n    /// Get all tensors\n    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {\n        &self.tensors\n    }\n\n    /// Take ownership of all tensors\n    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {\n        self.tensors\n    }\n\n    /// Get metadata about the loaded file\n    ///\n    /// Provides information about the file format, version, endianness, etc.\n    pub fn metadata(&self) -> &PytorchMetadata {\n        &self.metadata\n    }\n\n    /// Get the number of tensors in the file\n    pub fn len(&self) -> usize {\n        self.tensors.len()\n    }\n\n    /// Check if the file contains no tensors\n    pub fn is_empty(&self) -> bool {\n        self.tensors.is_empty()\n    }\n\n    /// Read raw pickle data from a PyTorch file\n    ///\n    /// This is useful for extracting configuration or metadata that isn't tensor data.\n    /// Returns a simplified JSON-like structure that can be easily converted to other formats.\n    ///\n    /// # Arguments\n    /// * `path` - Path to the PyTorch file\n    /// * `top_level_key` - Optional key to extract from the top-level dictionary\n    ///\n    /// # Returns\n    /// A `PickleValue` representing the pickle data structure\n    pub fn read_pickle_data<P: AsRef<Path>>(\n        path: P,\n        top_level_key: Option<&str>,\n    ) -> Result<PickleValue> {\n        read_pickle_as_value(path.as_ref(), top_level_key)\n    }\n\n    /// Load and deserialize configuration data from a PyTorch file\n    ///\n    /// This method reads configuration or metadata stored in PyTorch checkpoint files\n    /// and deserializes it into the specified type. It's particularly useful for\n    /// extracting model configurations that might be saved alongside model weights.\n    ///\n    /// # Arguments\n    /// * `path` - Path to the PyTorch file (.pt or .pth)\n    /// * `top_level_key` - Optional key to extract specific data within the pickle file.\n    ///   If `None`, the entire content is deserialized.\n    ///\n    /// # Type Parameters\n    /// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`.\n    ///\n    /// # Returns\n    /// A `Result` containing the deserialized configuration data, or an `Error` if\n    /// reading or deserialization fails.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::pytorch::PytorchReader;\n    /// # use serde::Deserialize;\n    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {\n    /// #[derive(Debug, Deserialize)]\n    /// struct ModelConfig {\n    ///     hidden_size: usize,\n    ///     num_layers: usize,\n    /// }\n    ///\n    /// let config: ModelConfig = PytorchReader::load_config(\"model.pth\", Some(\"config\"))?;\n    /// # Ok(())\n    /// # }\n    /// ```\n    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>\n    where\n        D: DeserializeOwned,\n        P: AsRef<Path>,\n    {\n        // Read the PyTorch file and extract the pickle data\n        let pickle_value = Self::read_pickle_data(path, top_level_key)?;\n\n        // Convert PickleValue to NestedValue\n        let nested_value = convert_pickle_to_nested_value(pickle_value)?;\n\n        // Create a deserializer with the default adapter\n        let deserializer = Deserializer::<DefaultAdapter>::new(nested_value, false);\n\n        // Deserialize the nested value into the target type\n        let value = D::deserialize(deserializer)?;\n        Ok(value)\n    }\n}\n\n/// Simplified representation of pickle data\n///\n/// This enum provides a JSON-like structure that's easier to work with\n/// than the internal pickle Object type.\n#[derive(Debug, Clone, PartialEq)]\npub enum PickleValue {\n    /// None/null value\n    None,\n    /// Boolean value\n    Bool(bool),\n    /// Integer value\n    Int(i64),\n    /// Floating point value\n    Float(f64),\n    /// String value\n    String(String),\n    /// List/array of values\n    List(Vec<PickleValue>),\n    /// Dictionary/map of string keys to values\n    Dict(HashMap<String, PickleValue>),\n    /// Binary data\n    Bytes(Vec<u8>),\n}\n\n/// Internal function to load a PyTorch file with metadata\nfn load_pytorch_file_with_metadata(\n    path: &Path,\n    top_level_key: Option<&str>,\n) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {\n    // First, try to read as a zip file\n    if let Ok(file) = File::open(path)\n        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))\n    {\n        // PyTorch saves the main data in various locations within the zip\n        let mut pickle_data = Vec::new();\n        let mut pickle_found = false;\n\n        // Try different common pickle file locations\n        let possible_pickle_paths = [\n            \"data.pkl\",\n            \"archive/data.pkl\",\n            // Look for any .pkl file in the root or first-level directories\n        ];\n\n        for pickle_path in &possible_pickle_paths {\n            if archive.by_name(pickle_path).is_ok() {\n                let mut pickle_file = archive.by_name(pickle_path)?;\n                pickle_file.read_to_end(&mut pickle_data)?;\n                pickle_found = true;\n                break;\n            }\n        }\n\n        // If not found in standard locations, search for any .pkl file\n        if !pickle_found {\n            for i in 0..archive.len() {\n                let file = archive.by_index(i)?;\n                let name = file.name().to_string();\n                drop(file); // Release the borrow\n\n                if name.ends_with(\"data.pkl\") {\n                    let mut file = archive.by_index(i)?;\n                    file.read_to_end(&mut pickle_data)?;\n                    pickle_found = true;\n                    break;\n                }\n            }\n        }\n\n        if !pickle_found {\n            return Err(PytorchError::InvalidFormat(\n                \"No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl\".to_string(),\n            ));\n        }\n\n        // Check for format version (optional)\n        let format_version = if let Ok(mut version_file) = archive.by_name(\".format_version\") {\n            let mut version_data = Vec::new();\n            version_file.read_to_end(&mut version_data)?;\n            let version_str = String::from_utf8_lossy(&version_data);\n            let version = version_str.trim().to_string();\n            Some(version)\n        } else {\n            None\n        };\n\n        // Check for byteorder file to detect endianness\n        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name(\"byteorder\") {\n            let mut byteorder_data = Vec::new();\n            byteorder_file.read_to_end(&mut byteorder_data)?;\n            let byteorder_str = String::from_utf8_lossy(&byteorder_data);\n            byteorder_str.trim() == \"big\"\n        } else {\n            false // Default to little-endian if no byteorder file\n        };\n\n        if is_big_endian {\n            // Big-endian files are not yet supported as they require different byte order conversion\n            // TODO: To support big-endian files, we need to:\n            // 1. Pass endianness info through to pickle_reader\n            // 2. Use from_be_bytes instead of from_le_bytes for tensor data\n            // 3. Handle byte swapping for all numeric types (f32, f64, i32, etc.)\n            return Err(PytorchError::InvalidFormat(\n                \"Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.\".to_string()\n            ));\n        }\n\n        // Check for storage alignment file\n        let has_storage_alignment = archive.by_name(\".storage_alignment\").is_ok();\n\n        // Check for PyTorch version (if saved)\n        let pytorch_version = if let Ok(mut version_file) = archive.by_name(\"version\") {\n            let mut version_data = Vec::new();\n            version_file.read_to_end(&mut version_data)?;\n            Some(String::from_utf8_lossy(&version_data).trim().to_string())\n        } else {\n            None\n        };\n\n        // Create a lazy data source instead of loading all data upfront\n        let data_source = Arc::new(LazyDataSource::from_zip(path)?);\n\n        // Calculate total data size without loading\n        let mut total_data_size = 0usize;\n        for i in 0..archive.len() {\n            let file = archive.by_index(i)?;\n            let name = file.name();\n\n            // Look for data files - they can be in various locations\n            let is_data_file = (name.contains(\"/data/\")\n                || name.starts_with(\"data/\")\n                || name.starts_with(\"archive/data/\"))\n                && !name.ends_with(\".pkl\")\n                && !name.ends_with(\"/\");\n\n            if is_data_file {\n                total_data_size += file.size() as usize;\n            }\n        }\n\n        // Parse the pickle data with lazy data source\n        let mut pickle_reader = BufReader::new(pickle_data.as_slice());\n        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;\n\n        // Extract tensors with their data\n        let tensors = extract_tensors_with_data(obj, top_level_key)?;\n\n        // Create metadata\n        let metadata = PytorchMetadata {\n            format_version,\n            format_type: FileFormat::Zip,\n            byte_order: if is_big_endian {\n                ByteOrder::BigEndian\n            } else {\n                ByteOrder::LittleEndian\n            },\n            has_storage_alignment,\n            pytorch_version,\n            tensor_count: tensors.len(),\n            total_data_size: Some(total_data_size),\n        };\n\n        return Ok((tensors, metadata));\n    }\n\n    // If not a zip or zip reading failed, try TAR format\n    if is_tar_file(path) {\n        return load_tar_pytorch_file_with_metadata(path, top_level_key);\n    }\n\n    // Try reading as a plain pickle file\n    let mut file = File::open(path)?;\n\n    // Check for PyTorch legacy format (starts with magic number as pickled integer)\n    let mut header = [0u8; 15];\n    // Use read() instead of read_exact() to handle files smaller than 15 bytes\n    let bytes_read = file.read(&mut header)?;\n    file.seek(std::io::SeekFrom::Start(0))?;\n\n    // Only check for legacy format if we have enough bytes\n    // PyTorch legacy format detection (PyTorch 0.1.10 - 1.3)\n    // Reference: https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L65\n    //\n    // These files use sequential pickle streams with metadata before the actual data.\n    // Format structure:\n    //   1. Magic number (0x1950a86a20f9469cfc6c) stored as LONG1 pickle\n    //   2. Protocol version (e.g., 1001)\n    //   3. System info dict (protocol_version, little_endian, type_sizes)\n    //   4. Actual model data (state_dict or full model)\n    //   5. Storage keys list (pickle)\n    //   6. Raw binary data for each storage\n    //\n    // The pattern is: 0x80 0x02 0x8a 0x0a (PROTO 2, LONG1 with 10 bytes)\n    // followed by 10 bytes of magic number (little-endian), then 0x2e (STOP)\n    let is_legacy_format = bytes_read >= 15\n        && header[0] == 0x80  // PROTO opcode\n        && header[1] == 0x02  // Protocol version 2\n        && header[2] == 0x8a  // LONG1 opcode\n        && header[3] == 0x0a  // 10 bytes follow\n        // Magic number 0x1950a86a20f9469cfc6c in little-endian\n        && header[4] == 0x6c\n        && header[5] == 0xfc\n        && header[6] == 0x9c\n        && header[7] == 0x46\n        && header[8] == 0xf9\n        && header[9] == 0x20\n        && header[10] == 0x6a\n        && header[11] == 0xa8\n        && header[12] == 0x50\n        && header[13] == 0x19\n        && header[14] == 0x2e; // STOP opcode\n\n    if is_legacy_format {\n        return load_legacy_pytorch_file_with_metadata(path, top_level_key);\n    }\n\n    // Standard pickle file\n    // This might be a pickle with tensor references, so we need to handle that case\n    // For plain pickle files without a separate data section, we can't use lazy loading\n    // so we'll just create empty placeholder tensors for the structure\n    let file = File::open(path)?;\n    let mut reader = BufReader::new(file);\n\n    // Try reading without data source first\n    match read_pickle(&mut reader) {\n        Ok(obj) => {\n            let tensors = extract_tensors_with_data(obj, top_level_key)?;\n            let tensor_count = tensors.len();\n            Ok((\n                tensors,\n                PytorchMetadata {\n                    format_version: None,\n                    format_type: FileFormat::Pickle,\n                    byte_order: ByteOrder::LittleEndian,\n                    has_storage_alignment: false,\n                    pytorch_version: None,\n                    tensor_count,\n                    total_data_size: None,\n                },\n            ))\n        }\n        Err(e)\n            if e.to_string()\n                .contains(\"Cannot load tensor data without a data source\") =>\n        {\n            // This pickle file contains tensor data but we're trying to read it without\n            // providing a data source. This shouldn't happen in normal usage as PyTorch\n            // files with actual tensor data should be in ZIP or legacy format.\n            Err(PytorchError::InvalidFormat(\n                \"Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.\".to_string()\n            ))\n        }\n        Err(e) => Err(PytorchError::Pickle(e)),\n    }\n}\n\n/// Load from a reader\nfn load_from_reader<R: Read>(\n    reader: R,\n    top_level_key: Option<&str>,\n) -> Result<HashMap<String, TensorSnapshot>> {\n    let mut buf_reader = BufReader::new(reader);\n\n    // Try reading without data source\n    match read_pickle(&mut buf_reader) {\n        Ok(obj) => extract_tensors_with_data(obj, top_level_key),\n        Err(e)\n            if e.to_string()\n                .contains(\"Cannot load tensor data without a data source\") =>\n        {\n            // This reader contains tensor data but we can't load it without a file path\n            Err(PytorchError::InvalidFormat(\n                \"Reader contains tensor data but no data source is available. Use file-based loading instead.\".to_string()\n            ))\n        }\n        Err(e) => Err(PytorchError::Pickle(e)),\n    }\n}\n\n/// Extract tensors from a parsed pickle object\nfn extract_tensors_with_data(\n    obj: Object,\n    top_level_key: Option<&str>,\n) -> Result<HashMap<String, TensorSnapshot>> {\n    let dict = match obj {\n        Object::Dict(dict) => {\n            if let Some(key) = top_level_key {\n                // Extract the nested dictionary if a top-level key is specified\n                match dict.get(key) {\n                    Some(Object::Dict(nested)) => nested.clone(),\n                    _ => {\n                        return Err(PytorchError::KeyNotFound(format!(\n                            \"Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}\",\n                            key,\n                            dict.keys().collect::<Vec<_>>()\n                        )));\n                    }\n                }\n            } else {\n                dict\n            }\n        }\n        _ => {\n            return Err(PytorchError::InvalidFormat(\n                \"Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.\".to_string(),\n            ));\n        }\n    };\n\n    let mut tensors = HashMap::new();\n    let mut path = Vec::new();\n    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);\n    Ok(tensors)\n}\n\n/// Recursively extract tensors from an object\nfn extract_tensors_recursive<'a>(\n    obj: &'a Object,\n    path: &mut Vec<&'a str>,\n    tensors: &mut HashMap<String, TensorSnapshot>,\n) {\n    match obj {\n        Object::Dict(dict) => {\n            for (key, value) in dict {\n                path.push(key);\n                extract_tensors_recursive(value, path, tensors);\n                path.pop();\n            }\n        }\n        Object::TorchParam(snapshot) => {\n            // The TensorSnapshot already contains the data loading closure\n            // Only allocate the string here when we actually insert\n            tensors.insert(path.join(\".\"), snapshot.clone());\n        }\n        _ => {}\n    }\n}\n\n/// Load a legacy PyTorch file with metadata\nfn load_legacy_pytorch_file_with_metadata(\n    path: &Path,\n    top_level_key: Option<&str>,\n) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {\n    let file = File::open(path)?;\n    let mut reader = BufReader::new(file);\n\n    // Skip metadata pickles\n    // 1. Magic number\n    let _ = read_pickle(&mut reader).map_err(|e| {\n        PytorchError::InvalidFormat(format!(\n            \"Failed to read magic number from legacy format: {}\",\n            e\n        ))\n    })?;\n\n    // 2. Protocol version\n    let _ = read_pickle(&mut reader).map_err(|e| {\n        PytorchError::InvalidFormat(format!(\n            \"Failed to read protocol version from legacy format: {}\",\n            e\n        ))\n    })?;\n\n    // 3. System info\n    let _ = read_pickle(&mut reader).map_err(|e| {\n        PytorchError::InvalidFormat(format!(\n            \"Failed to read system info from legacy format: {}\",\n            e\n        ))\n    })?;\n\n    // Save position before main pickle\n    let main_pickle_pos = reader.stream_position()?;\n\n    // 4. Skip main object - it might contain tensors so we can't parse it yet\n    // We'll re-read it with a data source later\n    use crate::pytorch::pickle_reader::skip_pickle;\n    skip_pickle(&mut reader).map_err(|e| {\n        PytorchError::InvalidFormat(format!(\n            \"Failed to skip main object in legacy format: {}\",\n            e\n        ))\n    })?;\n\n    // 5. Storage keys list (sorted keys as written by PyTorch)\n    let storage_keys = match read_pickle(&mut reader) {\n        Ok(Object::List(keys)) => keys\n            .into_iter()\n            .filter_map(|obj| match obj {\n                Object::String(s) => Some(s),\n                _ => None,\n            })\n            .collect::<Vec<_>>(),\n        _ => vec![],\n    };\n\n    // 6. Skip 8-byte header before raw binary data\n    // PyTorch legacy format has an 8-byte header (possibly protocol version or alignment)\n    // between the storage keys list and the actual tensor data\n    let mut header = [0u8; 8];\n    if reader.read(&mut header).is_ok() {\n        // Header read successfully, data starts after this\n    }\n\n    // 7. Raw binary data starts here\n    let data_start_pos = reader.stream_position()?;\n    let file_size = reader.seek(SeekFrom::End(0))?;\n    let data_size = file_size - data_start_pos;\n\n    // Create a lazy data source for legacy multi-storage format\n    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(\n        path,\n        data_start_pos,\n        data_size,\n    ));\n\n    // Set storage keys BEFORE parsing the main pickle\n    // This is critical because track_storage_usage() is called during parsing\n    // and it needs storage_keys to build the storage map\n    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source\n        && !storage_keys.is_empty()\n    {\n        let source = source\n            .lock()\n            .unwrap_or_else(|poisoned| poisoned.into_inner());\n        source.set_storage_keys(storage_keys.clone());\n    }\n\n    // Now re-read the main pickle with lazy data source\n    reader.seek(SeekFrom::Start(main_pickle_pos))?;\n    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;\n\n    // Extract tensors normally\n    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;\n\n    // Create metadata for legacy format\n    let metadata = PytorchMetadata {\n        format_version: None, // Legacy format doesn't have version files\n        format_type: FileFormat::Legacy,\n        byte_order: ByteOrder::LittleEndian, // Legacy format is little-endian\n        has_storage_alignment: false,\n        pytorch_version: None, // Could parse from protocol version, but not reliable\n        tensor_count: tensors.len(),\n        total_data_size: Some(data_size as usize),\n    };\n\n    Ok((tensors, metadata))\n}\n\n/// Check if a file is a TAR archive\nfn is_tar_file(path: &Path) -> bool {\n    if let Ok(mut file) = File::open(path) {\n        // TAR files have \"ustar\" magic at offset 257\n        let mut header = [0u8; 263];\n        if file.read_exact(&mut header).is_ok() {\n            // Check for \"ustar\" magic at offset 257\n            return &header[257..262] == b\"ustar\";\n        }\n    }\n    false\n}\n\n/// Load a TAR format PyTorch file with metadata\nfn load_tar_pytorch_file_with_metadata(\n    path: &Path,\n    top_level_key: Option<&str>,\n) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {\n    use tar::Archive;\n\n    let file = File::open(path)?;\n    let mut archive = Archive::new(BufReader::new(file));\n\n    // Extract the main entries from the TAR archive\n    let mut sys_info_data: Option<Vec<u8>> = None;\n    let mut pickle_data: Option<Vec<u8>> = None;\n    let mut storages_data: Option<Vec<u8>> = None;\n\n    for entry in archive.entries().map_err(PytorchError::Tar)? {\n        let mut entry = entry.map_err(PytorchError::Tar)?;\n        let entry_path = entry\n            .path()\n            .map_err(PytorchError::Tar)?\n            .to_string_lossy()\n            .to_string();\n\n        // Skip PAX headers\n        if entry_path.contains(\"@PaxHeader\") {\n            continue;\n        }\n\n        // Normalize path (remove ./ prefix if present)\n        let normalized = entry_path.trim_start_matches(\"./\");\n\n        match normalized {\n            \"sys_info\" => {\n                let mut data = Vec::new();\n                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;\n                sys_info_data = Some(data);\n            }\n            \"pickle\" => {\n                let mut data = Vec::new();\n                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;\n                pickle_data = Some(data);\n            }\n            \"storages\" => {\n                let mut data = Vec::new();\n                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;\n                storages_data = Some(data);\n            }\n            _ => {}\n        }\n    }\n\n    // Validate required entries\n    let pickle_data = pickle_data.ok_or_else(|| {\n        PytorchError::InvalidFormat(\"TAR file missing 'pickle' entry\".to_string())\n    })?;\n    let storages_data = storages_data.ok_or_else(|| {\n        PytorchError::InvalidFormat(\"TAR file missing 'storages' entry\".to_string())\n    })?;\n\n    // Parse sys_info to check endianness\n    let is_little_endian = if let Some(ref data) = sys_info_data {\n        parse_tar_sys_info(data)?\n    } else {\n        true // Default to little-endian\n    };\n\n    if !is_little_endian {\n        return Err(PytorchError::InvalidFormat(\n            \"Big-endian TAR PyTorch files are not supported\".to_string(),\n        ));\n    }\n\n    // Create TarSource for lazy loading\n    let data_source = Arc::new(LazyDataSource::from_tar(&storages_data)?);\n\n    // Parse the pickle (OrderedDict of name -> storage_key)\n    let mut pickle_reader = BufReader::new(pickle_data.as_slice());\n    let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;\n\n    // Extract tensors\n    let tensors = extract_tensors_with_data(obj, top_level_key)?;\n\n    let metadata = PytorchMetadata {\n        format_version: None,\n        format_type: FileFormat::Tar,\n        byte_order: ByteOrder::LittleEndian,\n        has_storage_alignment: false,\n        pytorch_version: None,\n        tensor_count: tensors.len(),\n        total_data_size: Some(storages_data.len()),\n    };\n\n    Ok((tensors, metadata))\n}\n\n/// Parse sys_info pickle from TAR format to extract endianness\nfn parse_tar_sys_info(data: &[u8]) -> Result<bool> {\n    let mut reader = BufReader::new(data);\n    let obj = read_pickle(&mut reader)?;\n\n    if let Object::Dict(dict) = obj\n        && let Some(Object::Bool(little_endian)) = dict.get(\"little_endian\")\n    {\n        return Ok(*little_endian);\n    }\n\n    Ok(true) // Default assumption\n}\n\n/// Read pickle data from a PyTorch file as a simplified value\nfn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {\n    use crate::pytorch::lazy_data::LazyDataSource;\n    use crate::pytorch::pickle_reader::{read_pickle, read_pickle_with_data};\n    use std::sync::Arc;\n\n    // Try to open as ZIP first\n    if let Ok(file) = File::open(path)\n        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))\n    {\n        // Read pickle data from ZIP\n        let mut pickle_data = Vec::new();\n\n        // Try standard locations\n        for pickle_path in &[\"data.pkl\", \"archive/data.pkl\"] {\n            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {\n                pickle_file.read_to_end(&mut pickle_data)?;\n                break;\n            }\n        }\n\n        // If not found, search for any .pkl file\n        if pickle_data.is_empty() {\n            for i in 0..archive.len() {\n                let file = archive.by_index(i)?;\n                let name = file.name().to_string();\n                drop(file);\n\n                if name.ends_with(\"data.pkl\") {\n                    let mut file = archive.by_index(i)?;\n                    file.read_to_end(&mut pickle_data)?;\n                    break;\n                }\n            }\n        }\n\n        if !pickle_data.is_empty() {\n            // Create a data source for the ZIP file\n            let data_source = LazyDataSource::from_zip(path)?;\n            let data_source_arc = Arc::new(data_source);\n\n            let mut reader = BufReader::new(pickle_data.as_slice());\n            let obj = read_pickle_with_data(&mut reader, data_source_arc)?;\n            return convert_object_to_value(obj, top_level_key);\n        }\n    }\n\n    // Try as plain pickle file\n    // First attempt without data source (for pure metadata files)\n    let file = File::open(path)?;\n    let mut reader = BufReader::new(file);\n\n    match read_pickle(&mut reader) {\n        Ok(obj) => convert_object_to_value(obj, top_level_key),\n        Err(e)\n            if e.to_string()\n                .contains(\"Cannot load tensor data without a data source\") =>\n        {\n            // File contains tensors, need to use full PytorchReader\n            // Use the regular reader to get proper tensor handling\n            let reader = PytorchReader::new(path)?;\n\n            // Convert tensors to PickleValue structure\n            let mut result = std::collections::HashMap::new();\n            for key in reader.keys() {\n                // For pickle value extraction, we just need the structure, not the actual data\n                result.insert(\n                    key.clone(),\n                    PickleValue::String(format!(\"<Tensor:{}>\", key)),\n                );\n            }\n\n            if let Some(key) = top_level_key {\n                Ok(PickleValue::Dict(\n                    [(key.to_string(), PickleValue::Dict(result))]\n                        .into_iter()\n                        .collect(),\n                ))\n            } else {\n                Ok(PickleValue::Dict(result))\n            }\n        }\n        Err(e) => Err(PytorchError::Pickle(e)),\n    }\n}\n\n/// Convert internal Object to public PickleValue\nfn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {\n    use crate::pytorch::pickle_reader::Object;\n\n    // If a top-level key is specified, extract it first\n    if let Some(key) = top_level_key\n        && let Object::Dict(dict) = obj\n    {\n        if let Some(value) = dict.get(key) {\n            return object_to_pickle_value(value.clone());\n        } else {\n            return Err(PytorchError::KeyNotFound(format!(\n                \"Key '{}' not found in pickle data\",\n                key\n            )));\n        }\n    }\n\n    object_to_pickle_value(obj)\n}\n\n/// Convert Object to PickleValue\nfn object_to_pickle_value(obj: Object) -> Result<PickleValue> {\n    use crate::pytorch::pickle_reader::Object;\n\n    Ok(match obj {\n        Object::None => PickleValue::None,\n        Object::Bool(b) => PickleValue::Bool(b),\n        Object::Int(i) => PickleValue::Int(i),\n        Object::Float(f) => PickleValue::Float(f),\n        Object::String(s) => PickleValue::String(s),\n        Object::Persistent(data) => {\n            // Persistent data is raw bytes\n            PickleValue::Bytes(data)\n        }\n        Object::PersistentTuple(tuple) => {\n            // Convert persistent tuples to lists\n            let mut values = Vec::new();\n            for item in tuple {\n                values.push(object_to_pickle_value(item)?);\n            }\n            PickleValue::List(values)\n        }\n        Object::List(list) => {\n            let mut values = Vec::new();\n            for item in list {\n                values.push(object_to_pickle_value(item)?);\n            }\n            PickleValue::List(values)\n        }\n        Object::Dict(dict) => {\n            let mut map = HashMap::new();\n            for (k, v) in dict {\n                map.insert(k, object_to_pickle_value(v)?);\n            }\n            PickleValue::Dict(map)\n        }\n        Object::Tuple(tuple) => {\n            // Convert tuples to lists in the public API\n            let mut values = Vec::new();\n            for item in tuple {\n                values.push(object_to_pickle_value(item)?);\n            }\n            PickleValue::List(values)\n        }\n        Object::TorchParam(_) => {\n            // Skip tensor parameters in config reading\n            PickleValue::None\n        }\n        Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => {\n            // Complex objects are represented as None for simplicity\n            PickleValue::None\n        }\n    })\n}\n\n/// Convert PickleValue to NestedValue for deserialization\nfn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {\n    Ok(match value {\n        PickleValue::None => NestedValue::Default(None),\n        PickleValue::Bool(b) => NestedValue::Bool(b),\n        PickleValue::Int(i) => NestedValue::I64(i),\n        PickleValue::Float(f) => NestedValue::F64(f),\n        PickleValue::String(s) => NestedValue::String(s),\n        PickleValue::List(list) => {\n            let mut vec = Vec::new();\n            for item in list {\n                vec.push(convert_pickle_to_nested_value(item)?);\n            }\n            NestedValue::Vec(vec)\n        }\n        PickleValue::Dict(dict) => {\n            let mut map = HashMap::new();\n            for (k, v) in dict {\n                map.insert(k, convert_pickle_to_nested_value(v)?);\n            }\n            NestedValue::Map(map)\n        }\n        PickleValue::Bytes(data) => {\n            // Convert bytes to a list of u8 values\n            let vec: Vec<NestedValue> = data.into_iter().map(NestedValue::U8).collect();\n            NestedValue::Vec(vec)\n        }\n    })\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/store.rs",
    "content": "//! PyTorch store implementation for saving and loading models in PyTorch format.\n\nuse crate::{\n    ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,\n    TensorSnapshot, map_indices_contiguous,\n};\n\nuse alloc::collections::BTreeMap;\n\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse alloc::vec::Vec;\nuse burn_tensor::backend::Backend;\nuse core::fmt;\nuse std::path::PathBuf;\n\nuse super::reader::{PytorchError as ReaderError, PytorchReader};\n\n/// Errors that can occur during PyTorch operations.\n#[derive(Debug)]\npub enum PytorchStoreError {\n    /// Reader error.\n    Reader(ReaderError),\n\n    /// I/O error.\n    Io(std::io::Error),\n\n    /// Tensor not found.\n    TensorNotFound(String),\n\n    /// Validation failed.\n    ValidationFailed(String),\n\n    /// Other error.\n    Other(String),\n}\n\nimpl fmt::Display for PytorchStoreError {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        match self {\n            Self::Reader(e) => write!(f, \"PyTorch reader error: {}\", e),\n            Self::Io(e) => write!(f, \"I/O error: {}\", e),\n            Self::TensorNotFound(name) => write!(f, \"Tensor not found: {}\", name),\n            Self::ValidationFailed(msg) => write!(f, \"Validation failed: {}\", msg),\n            Self::Other(msg) => write!(f, \"{}\", msg),\n        }\n    }\n}\n\nimpl std::error::Error for PytorchStoreError {}\n\nimpl From<ReaderError> for PytorchStoreError {\n    fn from(e: ReaderError) -> Self {\n        PytorchStoreError::Reader(e)\n    }\n}\n\nimpl From<std::io::Error> for PytorchStoreError {\n    fn from(e: std::io::Error) -> Self {\n        PytorchStoreError::Io(e)\n    }\n}\n\n/// PyTorch store for file-based storage only.\n///\n/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)\n/// with automatic weight transformation using `PyTorchToBurnAdapter`.\n/// Linear weights are automatically transposed and normalization parameters\n/// are renamed (gamma -> weight, beta -> bias).\n///\n/// Note that saving to PyTorch format is not yet supported.\npub struct PytorchStore {\n    pub(crate) path: PathBuf,\n    pub(crate) filter: PathFilter,\n    pub(crate) remapper: KeyRemapper,\n    pub(crate) validate: bool,\n    pub(crate) allow_partial: bool,\n    pub(crate) top_level_key: Option<String>,\n    pub(crate) skip_enum_variants: bool,\n    /// Enable contiguous mapping of layer indices (default: true)\n    pub(crate) map_indices_contiguous: bool,\n    /// Cached tensor snapshots (parsed once, reused)\n    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,\n}\n\nimpl PytorchStore {\n    /// Create a store for loading from a PyTorch file.\n    ///\n    /// # Arguments\n    /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// use burn_store::PytorchStore;\n    ///\n    /// let store = PytorchStore::from_file(\"model.pth\");\n    /// ```\n    pub fn from_file(path: impl Into<PathBuf>) -> Self {\n        Self {\n            path: path.into(),\n            filter: PathFilter::new(),\n            remapper: KeyRemapper::new(),\n            validate: true,\n            allow_partial: false,\n            top_level_key: None,\n            // PyTorch models never include enum variant names in paths\n            skip_enum_variants: true,\n            // Enable contiguous index mapping by default for PyTorch files\n            // This handles nn.Sequential models with gaps in layer indices\n            map_indices_contiguous: true,\n            snapshots_cache: None,\n        }\n    }\n\n    /// Set a top-level key to extract tensors from.\n    ///\n    /// PyTorch files often contain nested dictionaries. Use this to extract\n    /// tensors from a specific top-level key like \"state_dict\" or \"model_state_dict\".\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// let store = PytorchStore::from_file(\"checkpoint.pth\")\n    ///     .with_top_level_key(\"model_state_dict\");\n    /// ```\n    pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {\n        self.top_level_key = Some(key.into());\n        self\n    }\n\n    /// Filter which tensors to load.\n    pub fn filter(mut self, filter: PathFilter) -> Self {\n        self.filter = filter;\n        self\n    }\n\n    /// Add a regex pattern to filter tensors.\n    ///\n    /// Multiple patterns can be added and they work with OR logic.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .with_regex(r\"^encoder\\..*\")  // Match all encoder tensors\n    ///     .with_regex(r\".*\\.weight$\");   // OR match any weight tensors\n    /// ```\n    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {\n        self.filter = self.filter.with_regex(pattern);\n        self\n    }\n\n    /// Add multiple regex patterns to filter tensors.\n    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: AsRef<str>,\n    {\n        self.filter = self.filter.with_regexes(patterns);\n        self\n    }\n\n    /// Add an exact full path to match.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .with_full_path(\"encoder.layer1.weight\")\n    ///     .with_full_path(\"decoder.output.bias\");\n    /// ```\n    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {\n        self.filter = self.filter.with_full_path(path);\n        self\n    }\n\n    /// Add multiple exact full paths to match.\n    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: Into<String>,\n    {\n        self.filter = self.filter.with_full_paths(paths);\n        self\n    }\n\n    /// Add a predicate function for custom filtering logic.\n    ///\n    /// The predicate receives the tensor path and container path.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .with_predicate(|path, _| path.starts_with(\"encoder.\") || path.ends_with(\".bias\"));\n    /// ```\n    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {\n        self.filter = self.filter.with_predicate(predicate);\n        self\n    }\n\n    /// Add multiple predicate functions.\n    pub fn with_predicates<I>(mut self, predicates: I) -> Self\n    where\n        I: IntoIterator<Item = fn(&str, &str) -> bool>,\n    {\n        self.filter = self.filter.with_predicates(predicates);\n        self\n    }\n\n    /// Set the filter to match all paths (disables filtering).\n    pub fn match_all(mut self) -> Self {\n        self.filter = self.filter.match_all();\n        self\n    }\n\n    /// Remap tensor names during load.\n    pub fn remap(mut self, remapper: KeyRemapper) -> Self {\n        self.remapper = remapper;\n        self\n    }\n\n    /// Add a regex pattern to remap tensor names during load.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .with_key_remapping(r\"^encoder\\.\", \"transformer.encoder.\")  // encoder.X -> transformer.encoder.X\n    ///     .with_key_remapping(r\"\\.gamma$\", \".weight\");               // X.gamma -> X.weight\n    /// ```\n    pub fn with_key_remapping(\n        mut self,\n        from_pattern: impl AsRef<str>,\n        to_pattern: impl Into<String>,\n    ) -> Self {\n        self.remapper = self\n            .remapper\n            .add_pattern(from_pattern, to_pattern)\n            .expect(\"Invalid regex pattern\");\n        self\n    }\n\n    /// Set whether to validate tensors during loading (default: true).\n    pub fn validate(mut self, validate: bool) -> Self {\n        self.validate = validate;\n        self\n    }\n\n    /// Allow partial loading of tensors (continue even if some tensors are missing).\n    pub fn allow_partial(mut self, allow: bool) -> Self {\n        self.allow_partial = allow;\n        self\n    }\n\n    /// Skip enum variant names when matching tensor paths (default: true).\n    ///\n    /// When enabled, tensor paths from PyTorch that don't include enum variants\n    /// can be matched against Burn module paths that do include them.\n    /// For example, PyTorch path \"feature.weight\" can match Burn path \"feature.BaseConv.weight\".\n    ///\n    /// This defaults to `true` for PytorchStore since PyTorch models never include\n    /// enum variant names in their parameter paths.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// // Disable enum variant skipping (not typical)\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .skip_enum_variants(false);\n    /// ```\n    pub fn skip_enum_variants(mut self, skip: bool) -> Self {\n        self.skip_enum_variants = skip;\n        self\n    }\n\n    /// Enable or disable automatic contiguous mapping of layer indices (default: true).\n    ///\n    /// When enabled, non-contiguous numeric indices in tensor paths are renumbered\n    /// to be contiguous. This is useful when loading PyTorch models that have gaps\n    /// in layer numbering, such as when using `nn.Sequential` with mixed layer types\n    /// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).\n    ///\n    /// # Example\n    ///\n    /// With index mapping enabled (default):\n    /// - `fc.0.weight` → `fc.0.weight`\n    /// - `fc.2.weight` → `fc.1.weight` (gap filled)\n    /// - `fc.4.weight` → `fc.2.weight` (gap filled)\n    ///\n    /// # Arguments\n    ///\n    /// * `map` - `true` to enable contiguous index mapping, `false` to disable\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::PytorchStore;\n    /// // Disable contiguous index mapping if your model already has contiguous indices\n    /// let store = PytorchStore::from_file(\"model.pth\")\n    ///     .map_indices_contiguous(false);\n    /// ```\n    pub fn map_indices_contiguous(mut self, map: bool) -> Self {\n        self.map_indices_contiguous = map;\n        self\n    }\n\n    /// Apply remapping to tensor snapshots.\n    fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {\n        if self.remapper.is_empty() {\n            return snapshots;\n        }\n\n        let (remapped, _) = self.remapper.remap(snapshots);\n        remapped\n    }\n\n    /// Create a PytorchReader for the configured path and options.\n    fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {\n        let reader = if let Some(ref key) = self.top_level_key {\n            PytorchReader::with_top_level_key(&self.path, key)?\n        } else {\n            PytorchReader::new(&self.path)?\n        };\n        Ok(reader)\n    }\n}\n\nimpl ModuleStore for PytorchStore {\n    type Error = PytorchStoreError;\n\n    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        _module: &M,\n    ) -> Result<(), Self::Error> {\n        // Saving to PyTorch format is not yet supported\n        Err(PytorchStoreError::Other(\n            \"Saving to PyTorch format is not yet supported. Use other formats for saving.\"\n                .to_string(),\n        ))\n    }\n\n    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &mut M,\n    ) -> Result<ApplyResult, Self::Error> {\n        // Get snapshots from cache\n        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();\n\n        // Get filter (convert to Option for apply)\n        let filter_opt = if self.filter.is_empty() {\n            None\n        } else {\n            Some(self.filter.clone())\n        };\n\n        // Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)\n        // This adapter handles:\n        // - Transposing linear weights from PyTorch format to Burn format\n        // - Renaming normalization parameters (gamma -> weight, beta -> bias)\n        // Filter is applied here during apply, not during cache population\n        let result = module.apply(\n            snapshots,\n            filter_opt,\n            Some(Box::new(PyTorchToBurnAdapter)),\n            self.skip_enum_variants,\n        );\n\n        // Validate if needed\n        if self.validate && !result.errors.is_empty() {\n            return Err(PytorchStoreError::ValidationFailed(format!(\n                \"Import errors:\\n{}\",\n                result\n            )));\n        }\n\n        if !self.allow_partial && !result.missing.is_empty() {\n            return Err(PytorchStoreError::TensorNotFound(format!(\"\\n{}\", result)));\n        }\n\n        Ok(result)\n    }\n\n    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {\n        self.ensure_snapshots_cache()?;\n        Ok(self.snapshots_cache.as_ref().unwrap().get(name))\n    }\n\n    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {\n        self.ensure_snapshots_cache()?;\n        Ok(self.snapshots_cache.as_ref().unwrap())\n    }\n\n    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {\n        // Always use the cache to ensure remapping is applied consistently\n        Ok(self.get_all_snapshots()?.keys().cloned().collect())\n    }\n}\n\nimpl PytorchStore {\n    /// Ensure the snapshots cache is populated\n    fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {\n        if self.snapshots_cache.is_some() {\n            return Ok(());\n        }\n\n        let reader = self.create_reader()?;\n\n        // Convert to tensor snapshots\n        let mut snapshots: Vec<TensorSnapshot> = reader\n            .into_tensors()\n            .into_iter()\n            .map(|(key, mut snapshot)| {\n                // Parse the key into path parts (split by '.')\n                let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();\n\n                // Set the path stack from the key\n                snapshot.path_stack = Some(path_parts);\n                snapshot.container_stack = None;\n                snapshot.tensor_id = None;\n\n                snapshot\n            })\n            .collect();\n\n        // Apply remapping (but NOT filtering - that's done at apply time)\n        snapshots = self.apply_remapping(snapshots);\n\n        // Apply contiguous index mapping if enabled\n        // This must be done after remapping so that remapped paths are mapped\n        if self.map_indices_contiguous {\n            let (mapped, _) = map_indices_contiguous(snapshots);\n            snapshots = mapped;\n        }\n\n        // Build cache as BTreeMap\n        let cache: BTreeMap<String, TensorSnapshot> =\n            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();\n\n        self.snapshots_cache = Some(cache);\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/mod.rs",
    "content": "pub mod reader;\npub mod store;\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/create_legacy_with_offsets.py",
    "content": "#!/usr/bin/env python3\n# /// script\n# dependencies = [\"torch\"]\n# ///\n\"\"\"Create a legacy format PyTorch file with specific storage offsets to test offset handling.\"\"\"\n\nimport torch\n\n# Create tensors with known values at specific storage offsets\n# This will help us verify we're reading from the correct location\n\n# Create a state dict with tensors that share storage\n# This is common in PyTorch models (e.g., weight and transposed weight views)\nstate_dict = {}\n\n# Create a base tensor with known pattern\nbase_data = torch.arange(100, dtype=torch.float32)\n\n# tensor1: uses elements 10-19 (offset 10*4 = 40 bytes)\ntensor1 = base_data[10:20].clone()\ntensor1[:] = torch.arange(1.0, 1.1, 0.01)[:10]  # 1.00, 1.01, 1.02, ...\n\n# tensor2: uses elements 30-35 (offset 30*4 = 120 bytes)\ntensor2 = base_data[30:35].clone()\ntensor2[:] = torch.arange(2.0, 2.5, 0.1)[:5]  # 2.0, 2.1, 2.2, 2.3, 2.4\n\n# tensor3: starts at beginning (offset 0)\ntensor3 = base_data[:5].clone()\ntensor3[:] = torch.arange(3.0, 3.5, 0.1)[:5]  # 3.0, 3.1, 3.2, 3.3, 3.4\n\nstate_dict['tensor1'] = tensor1\nstate_dict['tensor2'] = tensor2\nstate_dict['tensor3'] = tensor3\n\n# Save in legacy format\noutput_file = 'test_data/legacy_with_offsets.pt'\ntorch.save(state_dict, output_file, _use_new_zipfile_serialization=False)\n\nprint(f\"Created {output_file}\")\n\n# Verify by loading\nloaded = torch.load(output_file, weights_only=False)\nprint(\"\\nVerification - expected values:\")\nfor key, tensor in loaded.items():\n    print(f\"  {key}: {tensor.tolist()}\")\n    print(f\"    Storage offset: {tensor.storage_offset()}\")\n    print(f\"    Storage size: {len(tensor.storage())}\")\n\n# Also create a test with multiple tensors sharing the same storage\n# This is important for proper offset handling\nshared_storage = torch.randn(1000)\n\n# Create views into the same storage at different offsets\nview1 = shared_storage[100:110]  # offset 100\nview2 = shared_storage[500:520]  # offset 500\nview3 = shared_storage[0:10]     # offset 0\n\n# Need to save these properly - PyTorch will handle the storage sharing\nshared_dict = {\n    'view1': view1.clone(),  # Clone to avoid view issues\n    'view2': view2.clone(),\n    'view3': view3.clone(),\n}\n\noutput_file2 = 'test_data/legacy_shared_storage.pt'\ntorch.save(shared_dict, output_file2, _use_new_zipfile_serialization=False)\nprint(f\"\\nCreated {output_file2}\")\n\n# Print exact values for test verification\nprint(\"\\nExact test values for legacy_with_offsets.pt:\")\nprint(\"tensor1 (10 elements starting at 1.0):\")\nprint(\"  First 3 values: [1.00, 1.01, 1.02]\")\nprint(\"tensor2 (5 elements starting at 2.0):\")\nprint(\"  All values: [2.0, 2.1, 2.2, 2.3, 2.4]\")\nprint(\"tensor3 (5 elements starting at 3.0):\")\nprint(\"  All values: [3.0, 3.1, 3.2, 3.3, 3.4]\")"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/create_tar_format.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nCreate TAR format test fixtures for burn-store integration tests.\n\nThe TAR format was used by very early versions of PyTorch (pre 0.1.10).\nModern torch.save cannot create this format, so we construct it manually.\n\nTAR format structure:\n  - sys_info: pickle with {protocol_version, little_endian, type_sizes}\n  - pickle: pickle with OrderedDict containing _rebuild_tensor_v2 REDUCE calls\n  - storages: count_pickle + for each storage: (key, device, class) pickle + u64 num_elements + raw data\n\"\"\"\n\nimport io\nimport pickle\nimport struct\nimport tarfile\nimport os\nfrom collections import OrderedDict\n\n\ndef create_sys_info():\n    \"\"\"Create sys_info pickle data.\"\"\"\n    sys_info = {\n        \"protocol_version\": 1000,\n        \"little_endian\": True,\n        \"type_sizes\": {\n            \"short\": 2,\n            \"int\": 4,\n            \"long\": 8,\n        },\n    }\n    return pickle.dumps(sys_info, protocol=2)\n\n\ndef encode_tensor_data(values: list, storage_type: str) -> tuple:\n    \"\"\"Encode tensor values to bytes and return (bytes, element_size).\"\"\"\n    fmt_map = {\n        \"FloatStorage\": (\"<f\", 4),\n        \"DoubleStorage\": (\"<d\", 8),\n        \"LongStorage\": (\"<q\", 8),\n        \"IntStorage\": (\"<i\", 4),\n        \"ShortStorage\": (\"<h\", 2),\n        \"ByteStorage\": (\"<B\", 1),\n        \"CharStorage\": (\"<b\", 1),\n        \"BoolStorage\": (\"<B\", 1),\n        \"HalfStorage\": (\"<e\", 2),\n    }\n    fmt, size = fmt_map[storage_type]\n    data = b\"\".join(struct.pack(fmt, v) for v in values)\n    return data, size\n\n\ndef write_int(buffer, value):\n    \"\"\"Write an integer using appropriate pickle opcode.\"\"\"\n    if 0 <= value < 256:\n        buffer.write(b'K')  # BININT1\n        buffer.write(bytes([value]))\n    elif 0 <= value < 65536:\n        buffer.write(b'M')  # BININT2\n        buffer.write(struct.pack('<H', value))\n    else:\n        buffer.write(b'J')  # BININT\n        buffer.write(struct.pack('<i', value))\n\n\ndef write_string(buffer, s):\n    \"\"\"Write a string using appropriate pickle opcode.\"\"\"\n    s_bytes = s.encode('utf-8')\n    if len(s_bytes) < 256:\n        buffer.write(b'U')  # SHORT_BINSTRING\n        buffer.write(bytes([len(s_bytes)]))\n        buffer.write(s_bytes)\n    else:\n        buffer.write(b'T')  # BINSTRING\n        buffer.write(struct.pack('<I', len(s_bytes)))\n        buffer.write(s_bytes)\n\n\ndef create_storages_blob_manual(tensors: list) -> bytes:\n    \"\"\"\n    Create the storages binary blob manually.\n\n    Args:\n        tensors: List of (key, storage_type, element_size, data_bytes) tuples\n    \"\"\"\n    buffer = io.BytesIO()\n\n    # Write storage count as pickle (simple integer)\n    pickle.dump(len(tensors), buffer, protocol=2)\n\n    for key, storage_type, element_size, data_bytes in tensors:\n        # Manually construct the tuple pickle with GLOBAL class reference\n        # Format: (key, \"cpu\", <class 'torch.FloatStorage'>)\n\n        tuple_buffer = io.BytesIO()\n        # Protocol 2 header\n        tuple_buffer.write(b'\\x80\\x02')\n\n        # Build tuple with MARK + items + TUPLE\n        tuple_buffer.write(b'(')  # MARK\n\n        # First item: storage key (string)\n        write_string(tuple_buffer, key)\n\n        # Second item: device \"cpu\"\n        tuple_buffer.write(b'U\\x03cpu')\n\n        # Third item: class reference using GLOBAL\n        tuple_buffer.write(b'c')  # GLOBAL opcode\n        tuple_buffer.write(b'torch\\n')  # module\n        tuple_buffer.write(storage_type.encode('ascii') + b'\\n')  # name\n\n        # End tuple\n        tuple_buffer.write(b't')  # TUPLE\n        tuple_buffer.write(b'.')  # STOP\n\n        buffer.write(tuple_buffer.getvalue())\n\n        # Write num_elements as u64 little-endian\n        num_elements = len(data_bytes) // element_size\n        buffer.write(struct.pack(\"<Q\", num_elements))\n\n        # Write raw data\n        buffer.write(data_bytes)\n\n    return buffer.getvalue()\n\n\ndef create_main_pickle_manual(tensors_info: list) -> bytes:\n    \"\"\"\n    Create the main pickle containing _rebuild_tensor_v2 REDUCE calls.\n\n    For each tensor, we need:\n    - GLOBAL torch._utils _rebuild_tensor_v2\n    - MARK\n    - args tuple: (persistent_id, offset, shape, stride, requires_grad, hooks)\n    - TUPLE\n    - REDUCE\n\n    The persistent_id is a PersistentTuple: ('storage', <class>, key, device, num_elements)\n    \"\"\"\n    buffer = io.BytesIO()\n\n    # Protocol 2 header\n    buffer.write(b'\\x80\\x02')\n\n    # Build OrderedDict: GLOBAL + EMPTY_LIST + items + TUPLE + REDUCE\n    # OrderedDict([('name1', tensor1), ('name2', tensor2)])\n\n    # GLOBAL collections OrderedDict\n    buffer.write(b'ccollections\\nOrderedDict\\n')\n\n    # Start list for items\n    buffer.write(b'(')  # MARK\n    buffer.write(b']')  # EMPTY_LIST\n\n    # For each tensor, add (name, rebuilt_tensor) to the list\n    for name, storage_key, storage_type, shape, num_elements in tensors_info:\n        # Calculate stride for row-major (C) order\n        stride = []\n        s = 1\n        for dim in reversed(shape):\n            stride.insert(0, s)\n            s *= dim\n\n        # Build inner tuple: (name, tensor_value)\n        buffer.write(b'(')  # MARK for (name, value) tuple\n\n        # Write name\n        write_string(buffer, name)\n\n        # Now build the tensor using _rebuild_tensor_v2 REDUCE\n        # GLOBAL torch._utils _rebuild_tensor_v2\n        buffer.write(b'ctorch._utils\\n_rebuild_tensor_v2\\n')\n\n        # Build args tuple for _rebuild_tensor_v2\n        # (persistent_id, offset, shape, stride, requires_grad, backward_hooks)\n        buffer.write(b'(')  # MARK for args tuple\n\n        # arg 0: persistent_id tuple: ('storage', class, key, device, num_elements)\n        # This will be converted to PersistentTuple by the reader\n        buffer.write(b'(')  # MARK for persistent_id\n\n        write_string(buffer, 'storage')\n\n        # Class reference - GLOBAL torch FloatStorage\n        buffer.write(b'c')\n        buffer.write(b'torch\\n')\n        buffer.write(storage_type.encode('ascii') + b'\\n')\n\n        # Storage key\n        write_string(buffer, storage_key)\n\n        # Device\n        buffer.write(b'U\\x03cpu')\n\n        # num_elements\n        write_int(buffer, num_elements)\n\n        buffer.write(b't')  # TUPLE - end persistent_id\n\n        # arg 1: storage offset (0)\n        buffer.write(b'K\\x00')\n\n        # arg 2: shape tuple\n        buffer.write(b'(')\n        for dim in shape:\n            write_int(buffer, dim)\n        buffer.write(b't')\n\n        # arg 3: stride tuple\n        buffer.write(b'(')\n        for s_val in stride:\n            write_int(buffer, s_val)\n        buffer.write(b't')\n\n        # arg 4: requires_grad (False)\n        buffer.write(b'\\x89')  # NEWFALSE\n\n        # arg 5: backward_hooks (empty OrderedDict)\n        buffer.write(b'ccollections\\nOrderedDict\\n')\n        buffer.write(b'(')\n        buffer.write(b']')\n        buffer.write(b't')\n        buffer.write(b'R')  # REDUCE to create empty OrderedDict\n\n        buffer.write(b't')  # TUPLE - end args tuple\n\n        buffer.write(b'R')  # REDUCE - call _rebuild_tensor_v2 with args\n\n        buffer.write(b't')  # TUPLE - end (name, tensor) tuple\n\n        buffer.write(b'a')  # APPEND to list\n\n    buffer.write(b't')  # TUPLE - wrap list in tuple for REDUCE\n    buffer.write(b'R')  # REDUCE - call OrderedDict with the list\n    buffer.write(b'.')  # STOP\n\n    return buffer.getvalue()\n\n\ndef create_tar_pytorch_file(filename: str, tensors: dict, dtypes: dict):\n    \"\"\"\n    Create a TAR format PyTorch file.\n\n    Args:\n        filename: Output file path\n        tensors: Dict of tensor_name -> (values_list, shape)\n        dtypes: Dict of tensor_name -> storage_type\n    \"\"\"\n    # Prepare storage data\n    storage_list = []  # (key, storage_type, element_size, data_bytes)\n    tensors_info = []  # (name, storage_key, storage_type, shape, num_elements)\n\n    for idx, (name, (values, shape)) in enumerate(tensors.items()):\n        storage_key = str(idx)\n        storage_type = dtypes[name]\n        data_bytes, element_size = encode_tensor_data(values, storage_type)\n        num_elements = len(values)\n\n        storage_list.append((storage_key, storage_type, element_size, data_bytes))\n        tensors_info.append((name, storage_key, storage_type, shape, num_elements))\n\n    # Create the three main entries\n    sys_info_data = create_sys_info()\n    pickle_data = create_main_pickle_manual(tensors_info)\n    storages_data = create_storages_blob_manual(storage_list)\n\n    # Write TAR archive\n    os.makedirs(os.path.dirname(filename) or \".\", exist_ok=True)\n\n    with tarfile.open(filename, \"w\") as tar:\n        # Add sys_info\n        tarinfo = tarfile.TarInfo(name=\"sys_info\")\n        tarinfo.size = len(sys_info_data)\n        tar.addfile(tarinfo, io.BytesIO(sys_info_data))\n\n        # Add pickle\n        tarinfo = tarfile.TarInfo(name=\"pickle\")\n        tarinfo.size = len(pickle_data)\n        tar.addfile(tarinfo, io.BytesIO(pickle_data))\n\n        # Add storages\n        tarinfo = tarfile.TarInfo(name=\"storages\")\n        tarinfo.size = len(storages_data)\n        tar.addfile(tarinfo, io.BytesIO(storages_data))\n\n    size = os.path.getsize(filename)\n    print(f\"Created {filename} ({size} bytes)\")\n    print(f\"  Tensors: {list(tensors.keys())}\")\n\n\ndef main():\n    # Create test_data directory\n    os.makedirs(\"test_data\", exist_ok=True)\n\n    # Test 1: Single float32 tensor\n    create_tar_pytorch_file(\n        \"test_data/tar_float32.tar\",\n        {\"tensor\": ([1.0, 2.5, -3.7, 0.0], [4])},\n        {\"tensor\": \"FloatStorage\"},\n    )\n\n    # Test 2: Single float64 tensor\n    create_tar_pytorch_file(\n        \"test_data/tar_float64.tar\",\n        {\"tensor\": ([1.1, 2.2, 3.3], [3])},\n        {\"tensor\": \"DoubleStorage\"},\n    )\n\n    # Test 3: Single int64 tensor\n    create_tar_pytorch_file(\n        \"test_data/tar_int64.tar\",\n        {\"tensor\": ([100, -200, 300, 0], [4])},\n        {\"tensor\": \"LongStorage\"},\n    )\n\n    # Test 4: Multiple tensors (weight + bias)\n    create_tar_pytorch_file(\n        \"test_data/tar_weight_bias.tar\",\n        {\n            \"weight\": ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 3]),\n            \"bias\": ([0.01, 0.02], [2]),\n        },\n        {\n            \"weight\": \"FloatStorage\",\n            \"bias\": \"FloatStorage\",\n        },\n    )\n\n    # Test 5: Different dtypes in one file\n    create_tar_pytorch_file(\n        \"test_data/tar_multi_dtype.tar\",\n        {\n            \"float_tensor\": ([1.5, 2.5, 3.5], [3]),\n            \"double_tensor\": ([1.111, 2.222], [2]),\n            \"int_tensor\": ([10, 20, 30, 40], [4]),\n        },\n        {\n            \"float_tensor\": \"FloatStorage\",\n            \"double_tensor\": \"DoubleStorage\",\n            \"int_tensor\": \"LongStorage\",\n        },\n    )\n\n    # Test 6: 2D tensor for shape verification\n    create_tar_pytorch_file(\n        \"test_data/tar_2d_tensor.tar\",\n        {\n            \"matrix\": ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [3, 4]),\n        },\n        {\"matrix\": \"FloatStorage\"},\n    )\n\n    print(\"\\nAll TAR format test files created!\")\n    print(\"\\nTo run tests: cargo test -p burn-store --features pytorch test_tar\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/mod.rs",
    "content": "//! Tests for PyTorch file reader functionality\n//!\n//! Floating-point comparison tolerances:\n//! - F16/BF16: 1e-2 (~3 decimal digits precision)\n//! - F32: 1e-6 (~7 decimal digits precision)\n//! - F64: 1e-10 (~16 decimal digits precision)\n\n#![allow(clippy::needless_range_loop)]\n\nuse crate::pytorch::PytorchReader;\n// Import internal types for testing only\nuse crate::pytorch::reader::{ByteOrder, FileFormat};\nuse burn_tensor::{BoolStore, DType, shape};\nuse std::path::PathBuf;\n\nfn test_data_path(filename: &str) -> PathBuf {\n    // Get the path relative to the crate root\n    PathBuf::from(env!(\"CARGO_MANIFEST_DIR\"))\n        .join(\"src\")\n        .join(\"pytorch\")\n        .join(\"tests\")\n        .join(\"reader\")\n        .join(\"test_data\")\n        .join(filename)\n}\n\n#[test]\nfn test_float32_tensor() {\n    let path = test_data_path(\"float32.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load float32.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 4);\n    assert!((values[0] - 1.0).abs() < 1e-6);\n    assert!((values[1] - 2.5).abs() < 1e-6);\n    assert!((values[2] - (-3.7)).abs() < 1e-6);\n    assert!((values[3] - 0.0).abs() < 1e-6);\n}\n\n#[test]\nfn test_float64_tensor() {\n    let path = test_data_path(\"float64.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load float64.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F64);\n    assert_eq!(tensor.shape, shape![3]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f64>().unwrap();\n    assert_eq!(values.len(), 3);\n    assert!((values[0] - 1.1).abs() < 1e-10);\n    assert!((values[1] - 2.2).abs() < 1e-10);\n    assert!((values[2] - 3.3).abs() < 1e-10);\n}\n\n#[test]\nfn test_int64_tensor() {\n    let path = test_data_path(\"int64.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load int64.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::I64);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<i64>().unwrap();\n    assert_eq!(values, &[100, -200, 300, 0]);\n}\n\n#[test]\nfn test_int32_tensor() {\n    let path = test_data_path(\"int32.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load int32.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::I32);\n    assert_eq!(tensor.shape, shape![3]);\n\n    let data = tensor.to_data().unwrap();\n    // Convert to the appropriate element type\n    let data_converted = data.convert::<i32>();\n    let values = data_converted.as_slice::<i32>().unwrap();\n    assert_eq!(values, &[10, 20, -30]);\n}\n\n#[test]\nfn test_int16_tensor() {\n    let path = test_data_path(\"int16.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load int16.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::I16);\n    assert_eq!(tensor.shape, shape![3]);\n\n    let data = tensor.to_data().unwrap();\n    let data_converted = data.convert::<i16>();\n    let values = data_converted.as_slice::<i16>().unwrap();\n    assert_eq!(values, &[1000, -2000, 3000]);\n}\n\n#[test]\nfn test_int8_tensor() {\n    let path = test_data_path(\"int8.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load int8.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::I8);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let data_converted = data.convert::<i8>();\n    let values = data_converted.as_slice::<i8>().unwrap();\n    assert_eq!(values, &[127, -128, 0, 50]);\n}\n\n#[test]\nfn test_bool_tensor() {\n    let path = test_data_path(\"bool.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load bool.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::Bool(BoolStore::Native));\n    assert_eq!(tensor.shape, shape![5]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<bool>().unwrap();\n    assert_eq!(values, &[true, false, true, true, false]);\n}\n\n#[test]\nfn test_uint8_tensor() {\n    let path = test_data_path(\"uint8.pt\");\n\n    let reader = PytorchReader::new(&path).expect(\"Failed to load uint8.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::U8);\n    assert_eq!(tensor.shape, shape![4]);\n\n    // Verify actual U8 values [0, 128, 255, 42] from test_data.py\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<u8>().unwrap();\n    assert_eq!(values, &[0, 128, 255, 42]);\n}\n\n#[test]\nfn test_float16_tensor() {\n    use half::f16;\n\n    let path = test_data_path(\"float16.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load float16.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F16);\n    assert_eq!(tensor.shape, shape![3]);\n\n    // Verify actual F16 values [1.5, -2.25, 3.125] from test_data.py\n    let data = tensor.to_data().unwrap();\n    assert_eq!(data.shape, shape![3]);\n    let values = data.as_slice::<f16>().unwrap();\n    assert_eq!(values.len(), 3);\n    assert!((values[0].to_f32() - 1.5).abs() < 1e-2);\n    assert!((values[1].to_f32() - (-2.25)).abs() < 1e-2);\n    assert!((values[2].to_f32() - 3.125).abs() < 1e-2);\n}\n\n#[test]\nfn test_bfloat16_tensor() {\n    use half::bf16;\n\n    let path = test_data_path(\"bfloat16.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load bfloat16.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::BF16);\n    assert_eq!(tensor.shape, shape![3]);\n\n    // Verify actual BF16 values [1.5, -2.5, 3.5] from test_data.py\n    let data = tensor.to_data().unwrap();\n    assert_eq!(data.shape, shape![3]);\n    let values = data.as_slice::<bf16>().unwrap();\n    assert_eq!(values.len(), 3);\n    assert!((values[0].to_f32() - 1.5).abs() < 1e-2);\n    assert!((values[1].to_f32() - (-2.5)).abs() < 1e-2);\n    assert!((values[2].to_f32() - 3.5).abs() < 1e-2);\n}\n\n#[test]\nfn test_2d_tensor() {\n    let path = test_data_path(\"tensor_2d.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tensor_2d.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![3, 2]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 6);\n    // Check flattened values [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]\n    for (i, expected) in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {\n        assert!((values[i] - expected).abs() < 1e-6);\n    }\n}\n\n#[test]\nfn test_3d_tensor() {\n    let path = test_data_path(\"tensor_3d.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tensor_3d.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![2, 3, 4]);\n\n    let data = tensor.to_data().unwrap();\n    assert_eq!(data.shape, shape![2, 3, 4]);\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 24);\n}\n\n#[test]\nfn test_4d_tensor() {\n    let path = test_data_path(\"tensor_4d.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tensor_4d.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![2, 3, 2, 2]);\n\n    let data = tensor.to_data().unwrap();\n    assert_eq!(data.shape, shape![2, 3, 2, 2]);\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 24);\n}\n\n#[test]\nfn test_state_dict() {\n    let path = test_data_path(\"state_dict.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load state_dict.pt\");\n    let keys = reader.keys();\n\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"weight\".to_string()));\n    assert!(keys.contains(&\"bias\".to_string()));\n    assert!(keys.contains(&\"running_mean\".to_string()));\n    assert!(keys.contains(&\"running_var\".to_string()));\n\n    // Check weight tensor\n    let weight = reader.get(\"weight\").unwrap();\n    assert_eq!(weight.shape, shape![3, 4]);\n    assert_eq!(weight.dtype, DType::F32);\n\n    // Check bias tensor\n    let bias = reader.get(\"bias\").unwrap();\n    assert_eq!(bias.shape, shape![3]);\n    assert_eq!(bias.dtype, DType::F32);\n\n    // Check running_mean (should be zeros)\n    let running_mean = reader.get(\"running_mean\").unwrap();\n    assert_eq!(running_mean.shape, shape![3]);\n    let mean_data = running_mean.to_data().unwrap();\n    let mean_values = mean_data.as_slice::<f32>().unwrap();\n    assert!(mean_values.iter().all(|&v| v.abs() < 1e-6));\n\n    // Check running_var (should be ones)\n    let running_var = reader.get(\"running_var\").unwrap();\n    assert_eq!(running_var.shape, shape![3]);\n    let var_data = running_var.to_data().unwrap();\n    let var_values = var_data.as_slice::<f32>().unwrap();\n    assert!(var_values.iter().all(|&v| (v - 1.0).abs() < 1e-6));\n}\n\n#[test]\nfn test_nested_dict() {\n    let path = test_data_path(\"nested_dict.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load nested_dict.pt\");\n    let keys = reader.keys();\n\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"layer1.weight\".to_string()));\n    assert!(keys.contains(&\"layer1.bias\".to_string()));\n    assert!(keys.contains(&\"layer2.weight\".to_string()));\n    assert!(keys.contains(&\"layer2.bias\".to_string()));\n\n    // Check layer1.weight and load data\n    let layer1_weight = reader.get(\"layer1.weight\").unwrap();\n    assert_eq!(layer1_weight.shape, shape![2, 3]);\n    assert_eq!(layer1_weight.dtype, DType::F32);\n    let data = layer1_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 6); // 2x3 = 6 elements\n\n    // Check layer2.weight and load data\n    let layer2_weight = reader.get(\"layer2.weight\").unwrap();\n    assert_eq!(layer2_weight.shape, shape![4, 2]);\n    assert_eq!(layer2_weight.dtype, DType::F32);\n    let data = layer2_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 8); // 4x2 = 8 elements\n}\n\n#[test]\nfn test_checkpoint() {\n    let path = test_data_path(\"checkpoint.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load checkpoint.pt\");\n    let keys = reader.keys();\n\n    // Should have model_state_dict entries and optimizer entries\n    assert!(keys.contains(&\"model_state_dict.fc1.weight\".to_string()));\n    assert!(keys.contains(&\"model_state_dict.fc1.bias\".to_string()));\n    assert!(keys.contains(&\"model_state_dict.fc2.weight\".to_string()));\n    assert!(keys.contains(&\"model_state_dict.fc2.bias\".to_string()));\n\n    // Check fc1.weight dimensions and load data\n    let fc1_weight = reader.get(\"model_state_dict.fc1.weight\").unwrap();\n    assert_eq!(fc1_weight.shape, shape![10, 5]);\n    let data = fc1_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 50); // 10x5 = 50 elements\n\n    // Check fc2.weight dimensions and load data\n    let fc2_weight = reader.get(\"model_state_dict.fc2.weight\").unwrap();\n    assert_eq!(fc2_weight.shape, shape![3, 10]);\n    let data = fc2_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 30); // 3x10 = 30 elements\n}\n\n#[test]\nfn test_empty_tensor() {\n    let path = test_data_path(\"empty.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load empty.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.shape, shape![0]); // Empty tensor has shape [0]\n    assert_eq!(tensor.dtype, DType::F32);\n\n    // Note: Empty tensors cannot be loaded with to_data() due to TensorData validation\n    // We verify the metadata is correct, which confirms the .pt file is being read\n}\n\n#[test]\nfn test_scalar_tensor() {\n    let path = test_data_path(\"scalar.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load scalar.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.shape, shape![]); // Scalar has empty shape\n    assert_eq!(tensor.dtype, DType::F32);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 1);\n    assert!((values[0] - 42.0).abs() < 1e-6);\n}\n\n#[test]\nfn test_large_shape() {\n    let path = test_data_path(\"large_shape.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load large_shape.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.shape, shape![100, 100]);\n    assert_eq!(tensor.dtype, DType::F32);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 10000);\n\n    // Check specific non-zero values\n    assert!((values[0] - 1.0).abs() < 1e-6); // [0, 0] = 1.0\n    assert!((values[5050] - 2.0).abs() < 1e-6); // [50, 50] = 2.0\n    assert!((values[9999] - 3.0).abs() < 1e-6); // [99, 99] = 3.0\n}\n\n#[test]\nfn test_mixed_types() {\n    let path = test_data_path(\"mixed_types.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load mixed_types.pt\");\n    let tensors = reader.tensors();\n\n    assert_eq!(tensors.len(), 4);\n\n    // Check float32 tensor [1.0, 2.0] from test_data.py\n    let float32 = reader.get(\"float32\").unwrap();\n    assert_eq!(float32.dtype, DType::F32);\n    assert_eq!(float32.shape, shape![2]);\n    let data = float32.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert!((values[0] - 1.0).abs() < 1e-6);\n    assert!((values[1] - 2.0).abs() < 1e-6);\n\n    // Check int64 tensor [100, 200] from test_data.py\n    let int64 = reader.get(\"int64\").unwrap();\n    assert_eq!(int64.dtype, DType::I64);\n    assert_eq!(int64.shape, shape![2]);\n    let data = int64.to_data().unwrap();\n    let values = data.as_slice::<i64>().unwrap();\n    assert_eq!(values, &[100, 200]);\n\n    // Check bool tensor [True, False] from test_data.py\n    let bool_tensor = reader.get(\"bool\").unwrap();\n    assert_eq!(bool_tensor.dtype, DType::Bool(BoolStore::Native));\n    assert_eq!(bool_tensor.shape, shape![2]);\n    let data = bool_tensor.to_data().unwrap();\n    let values = data.as_slice::<bool>().unwrap();\n    assert_eq!(values, &[true, false]);\n\n    // Check float64 tensor [1.1, 2.2] from test_data.py\n    let float64 = reader.get(\"float64\").unwrap();\n    assert_eq!(float64.dtype, DType::F64);\n    assert_eq!(float64.shape, shape![2]);\n    let data = float64.to_data().unwrap();\n    let values = data.as_slice::<f64>().unwrap();\n    assert!((values[0] - 1.1).abs() < 1e-10);\n    assert!((values[1] - 2.2).abs() < 1e-10);\n}\n\n#[test]\nfn test_special_values() {\n    let path = test_data_path(\"special_values.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load special_values.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![5]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 5);\n\n    // Check for special values\n    assert!(values[0].is_nan());\n    assert!(values[1].is_infinite() && values[1] > 0.0);\n    assert!(values[2].is_infinite() && values[2] < 0.0);\n    assert!((values[3] - 0.0).abs() < 1e-6);\n    assert!((values[4] - 1.0).abs() < 1e-6);\n}\n\n#[test]\nfn test_extreme_values() {\n    let path = test_data_path(\"extreme_values.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load extreme_values.pt\");\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 4);\n\n    // Very small positive\n    assert!(values[0] > 0.0 && values[0] < 1e-20);\n    // Very large positive\n    assert!(values[1] > 1e20);\n    // Very small negative\n    assert!(values[2] < 0.0 && values[2] > -1e-20);\n    // Very large negative\n    assert!(values[3] < -1e20);\n}\n\n#[test]\nfn test_parameter() {\n    let path = test_data_path(\"parameter.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load parameter.pt\");\n    let tensors = reader.tensors();\n\n    // nn.Parameter is typically saved as a regular tensor\n    assert_eq!(tensors.len(), 1);\n    let param = reader.get(\"param\").unwrap();\n    assert_eq!(param.shape, shape![3, 3]);\n    assert_eq!(param.dtype, DType::F32);\n\n    let data = param.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 9);\n}\n\n#[test]\nfn test_buffers() {\n    let path = test_data_path(\"buffers.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load buffers.pt\");\n    let tensors = reader.tensors();\n\n    assert_eq!(tensors.len(), 2);\n\n    // Check buffer1 (int32)\n    let buffer1 = reader.get(\"buffer1\").unwrap();\n    assert_eq!(buffer1.dtype, DType::I32);\n    assert_eq!(buffer1.shape, shape![3]);\n    let data1 = buffer1.to_data().unwrap();\n    let data1_converted = data1.convert::<i32>();\n    let values1 = data1_converted.as_slice::<i32>().unwrap();\n    assert_eq!(values1, &[1, 2, 3]);\n\n    // Check buffer2 (bool)\n    let buffer2 = reader.get(\"buffer2\").unwrap();\n    assert_eq!(buffer2.dtype, DType::Bool(BoolStore::Native));\n    assert_eq!(buffer2.shape, shape![2]);\n    let data2 = buffer2.to_data().unwrap();\n    let values2 = data2.as_slice::<bool>().unwrap();\n    assert_eq!(values2, &[true, false]);\n}\n\n#[test]\nfn test_complex_structure() {\n    let path = test_data_path(\"complex_structure.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load complex_structure.pt\");\n    let keys = reader.keys();\n\n    // Should have nested structure tensors\n    assert!(keys.contains(&\"state.encoder.layer_0.weight\".to_string()));\n    assert!(keys.contains(&\"state.encoder.layer_0.bias\".to_string()));\n    assert!(keys.contains(&\"state.encoder.layer_1.weight\".to_string()));\n    assert!(keys.contains(&\"state.encoder.layer_1.bias\".to_string()));\n    assert!(keys.contains(&\"state.decoder.weight\".to_string()));\n    assert!(keys.contains(&\"state.decoder.bias\".to_string()));\n\n    // Check encoder layer_0 weight and load data\n    let layer0_weight = reader.get(\"state.encoder.layer_0.weight\").unwrap();\n    assert_eq!(layer0_weight.shape, shape![4, 3]);\n    let data = layer0_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 12); // 4x3 = 12 elements\n\n    // Check decoder weight and load data\n    let decoder_weight = reader.get(\"state.decoder.weight\").unwrap();\n    assert_eq!(decoder_weight.shape, shape![3, 2]);\n    let data = decoder_weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 6); // 3x2 = 6 elements\n}\n\n#[test]\nfn test_read_pytorch_tensors_convenience() {\n    // Test reading and materializing tensors into memory\n    let path = test_data_path(\"state_dict.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to read file\");\n\n    let keys = reader.keys();\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"weight\".to_string()));\n    assert!(keys.contains(&\"bias\".to_string()));\n\n    // Check that data can be materialized\n    let weight = reader.get(\"weight\").unwrap();\n    let weight_data = weight.to_data().unwrap();\n    assert_eq!(weight_data.shape, shape![3, 4]);\n    assert_eq!(weight_data.dtype, DType::F32);\n}\n\n#[test]\nfn test_with_top_level_key() {\n    // Test loading with a specific top-level key\n    let path = test_data_path(\"checkpoint.pt\");\n\n    // Load only model_state_dict\n    let reader = PytorchReader::with_top_level_key(&path, \"model_state_dict\")\n        .expect(\"Failed to load with top-level key\");\n\n    let keys = reader.keys();\n    // Should only have model weights, not optimizer state\n    assert!(keys.contains(&\"fc1.weight\".to_string()));\n    assert!(keys.contains(&\"fc1.bias\".to_string()));\n    assert!(keys.contains(&\"fc2.weight\".to_string()));\n    assert!(keys.contains(&\"fc2.bias\".to_string()));\n\n    // Should NOT have nested paths with model_state_dict prefix\n    assert!(!keys.contains(&\"model_state_dict.fc1.weight\".to_string()));\n}\n\n#[test]\nfn test_legacy_format() {\n    // Test loading PyTorch legacy format (pre-1.6)\n    let path = test_data_path(\"simple_legacy.pt\");\n\n    // This file has the sequential pickle structure of legacy PyTorch format\n    let reader = PytorchReader::new(&path).expect(\"Failed to load legacy format\");\n    let keys = reader.keys();\n\n    // Should have the tensors from the state dict\n    assert!(keys.contains(&\"weight\".to_string()), \"Missing 'weight' key\");\n    assert!(keys.contains(&\"bias\".to_string()), \"Missing 'bias' key\");\n    assert!(\n        keys.contains(&\"running_mean\".to_string()),\n        \"Missing 'running_mean' key\"\n    );\n\n    // Check weight tensor\n    let weight = reader.get(\"weight\").expect(\"weight not found\");\n    assert_eq!(weight.shape, shape![2, 3]);\n    assert_eq!(weight.dtype, DType::F32);\n\n    // Check bias tensor\n    let bias = reader.get(\"bias\").expect(\"bias not found\");\n    assert_eq!(bias.shape, shape![2]);\n    assert_eq!(bias.dtype, DType::F32);\n\n    // Verify bias values are all ones\n    let bias_data = bias.to_data().unwrap();\n    let bias_values = bias_data.as_slice::<f32>().unwrap();\n    // Note: values in simple_legacy.pt are randomly generated, not necessarily 1.0\n    assert_eq!(bias_values.len(), 2);\n\n    // Check running_mean tensor\n    let running_mean = reader.get(\"running_mean\").expect(\"running_mean not found\");\n    assert_eq!(running_mean.shape, shape![2]);\n    assert_eq!(running_mean.dtype, DType::F32);\n\n    // Verify running_mean values are accessible\n    let mean_data = running_mean.to_data().unwrap();\n    let mean_values = mean_data.as_slice::<f32>().unwrap();\n    assert_eq!(mean_values.len(), 2);\n}\n\n#[test]\nfn test_legacy_with_offsets() {\n    // Test with legacy format file that has storage offsets\n    let path = test_data_path(\"legacy_with_offsets.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Should read legacy file with offsets\");\n\n    let keys = reader.keys();\n    assert_eq!(keys.len(), 3, \"Should have 3 tensors\");\n\n    // Check that tensors exist\n    for key in &keys {\n        assert!(reader.get(key).is_some(), \"Should have tensor: {}\", key);\n        let tensor = reader.get(key).unwrap();\n        let data = tensor.to_data().unwrap();\n        let values = data.as_slice::<f32>().unwrap();\n        assert!(!values.is_empty(), \"Tensor {} should have data\", key);\n    }\n}\n\n#[test]\nfn test_legacy_shared_storage() {\n    // Test with legacy format file that has shared storage\n    let path = test_data_path(\"legacy_shared_storage.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Should read legacy file with shared storage\");\n\n    let keys = reader.keys();\n    assert!(keys.len() >= 2, \"Should have at least 2 tensors\");\n\n    // Check that tensors exist and can be loaded\n    for key in &keys {\n        assert!(reader.get(key).is_some(), \"Should have tensor: {}\", key);\n        let tensor = reader.get(key).unwrap();\n        let data = tensor.to_data().unwrap();\n\n        // Verify tensor data can be accessed\n        match tensor.dtype {\n            DType::F32 => {\n                let values = data.as_slice::<f32>().unwrap();\n                assert!(!values.is_empty(), \"Tensor {} should have data\", key);\n            }\n            DType::I64 => {\n                let values = data.as_slice::<i64>().unwrap();\n                assert!(!values.is_empty(), \"Tensor {} should have data\", key);\n            }\n            _ => {\n                // For other types, just verify we can convert to data\n                assert!(!data.shape.is_empty(), \"Tensor {} should have shape\", key);\n            }\n        }\n    }\n}\n\n#[test]\nfn test_metadata_zip_format() {\n    // Test that metadata is properly populated for ZIP format files\n    let path = test_data_path(\"float32.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load float32.pt\");\n\n    // Check metadata\n    let metadata = reader.metadata();\n    assert_eq!(metadata.format_type, FileFormat::Zip);\n    assert_eq!(metadata.byte_order, ByteOrder::LittleEndian);\n    assert_eq!(metadata.tensor_count, 1);\n    assert!(metadata.total_data_size.is_some());\n\n    // Check that metadata is accessible\n    assert!(metadata.is_modern_format());\n    assert!(!metadata.is_legacy_format());\n}\n\n#[test]\nfn test_metadata_legacy_format() {\n    // Test that metadata is properly populated for legacy format files\n    let path = test_data_path(\"simple_legacy.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load legacy file\");\n\n    // Check metadata\n    let metadata = reader.metadata();\n    assert_eq!(metadata.format_type, FileFormat::Legacy);\n    assert_eq!(metadata.byte_order, ByteOrder::LittleEndian);\n    assert_eq!(metadata.tensor_count, 3); // weight, bias, running_mean\n    assert!(metadata.total_data_size.is_some());\n}\n\n#[test]\nfn test_legacy_metadata_detailed() {\n    // Detailed test to prove we load all metadata for legacy format files\n    let path = test_data_path(\"simple_legacy.pt\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load legacy file\");\n\n    // Get and examine metadata\n    let metadata = reader.metadata();\n\n    // Verify the metadata is correct for legacy format\n    assert_eq!(\n        metadata.format_type,\n        FileFormat::Legacy,\n        \"Should be Legacy format\"\n    );\n    assert_eq!(\n        metadata.byte_order,\n        ByteOrder::LittleEndian,\n        \"Legacy format is little-endian\"\n    );\n    assert_eq!(\n        metadata.tensor_count, 3,\n        \"Should have 3 tensors: weight, bias, running_mean\"\n    );\n    assert!(\n        metadata.total_data_size.is_some(),\n        \"Should have total data size\"\n    );\n    assert!(\n        metadata.total_data_size.unwrap() > 0,\n        \"Data size should be positive\"\n    );\n\n    // Legacy format specifics\n    assert_eq!(\n        metadata.format_version, None,\n        \"Legacy format doesn't have version file\"\n    );\n    assert_eq!(\n        metadata.pytorch_version, None,\n        \"Legacy format doesn't store PyTorch version reliably\"\n    );\n    assert!(\n        !metadata.has_storage_alignment,\n        \"Legacy format doesn't have storage alignment\"\n    );\n\n    // Also verify we can access the tensors\n    let keys = reader.keys();\n    assert!(\n        keys.contains(&\"weight\".to_string()),\n        \"Should have weight tensor\"\n    );\n    assert!(\n        keys.contains(&\"bias\".to_string()),\n        \"Should have bias tensor\"\n    );\n    assert!(\n        keys.contains(&\"running_mean\".to_string()),\n        \"Should have running_mean tensor\"\n    );\n}\n\n#[test]\nfn test_small_invalid_file() {\n    // Test that we handle broken/invalid files gracefully\n    let path = test_data_path(\"broken.pt\");\n\n    // Should fail gracefully with an appropriate error\n    let result = PytorchReader::new(&path);\n    assert!(result.is_err(), \"Expected error for broken file\");\n\n    // The error should be a pickle error since the file is too small to be valid\n    if let Err(e) = result {\n        let err_str = format!(\"{}\", e);\n        assert!(\n            err_str.contains(\"Pickle\") || err_str.contains(\"Invalid\"),\n            \"Error should mention pickle or invalid format: {}\",\n            err_str\n        );\n    }\n}\n\n#[test]\nfn test_read_pickle_data_basic() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test reading pickle data from a checkpoint file\n    let path = test_data_path(\"checkpoint.pt\");\n\n    // Read the entire pickle data\n    let data = PytorchReader::read_pickle_data(&path, None).expect(\"Failed to read pickle data\");\n\n    // Should be a dictionary at the root\n    if let PickleValue::Dict(dict) = data {\n        // Check that expected keys exist\n        assert!(dict.contains_key(\"model_state_dict\"));\n        assert!(dict.contains_key(\"optimizer_state_dict\"));\n        assert!(dict.contains_key(\"epoch\"));\n        assert!(dict.contains_key(\"loss\"));\n\n        // Check epoch value\n        if let Some(PickleValue::Int(epoch)) = dict.get(\"epoch\") {\n            assert_eq!(*epoch, 42);\n        } else {\n            panic!(\"Expected epoch to be an integer\");\n        }\n\n        // Check loss value\n        if let Some(PickleValue::Float(loss)) = dict.get(\"loss\") {\n            assert!(*loss > 0.0 && *loss < 1.0, \"Loss should be between 0 and 1\");\n        } else {\n            panic!(\"Expected loss to be a float\");\n        }\n    } else {\n        panic!(\"Expected root to be a dictionary\");\n    }\n}\n\n#[test]\nfn test_read_pickle_data_with_key() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test reading specific key from checkpoint\n    let path = test_data_path(\"checkpoint.pt\");\n\n    // Read only the model_state_dict\n    let data = PytorchReader::read_pickle_data(&path, Some(\"model_state_dict\"))\n        .expect(\"Failed to read pickle data with key\");\n\n    // Should get the model_state_dict directly\n    if let PickleValue::Dict(dict) = data {\n        // Should have model weights\n        assert!(dict.contains_key(\"fc1.weight\"));\n        assert!(dict.contains_key(\"fc1.bias\"));\n        assert!(dict.contains_key(\"fc2.weight\"));\n        assert!(dict.contains_key(\"fc2.bias\"));\n\n        // Should NOT have optimizer keys\n        assert!(!dict.contains_key(\"optimizer_state_dict\"));\n        assert!(!dict.contains_key(\"epoch\"));\n    } else {\n        panic!(\"Expected model_state_dict to be a dictionary\");\n    }\n}\n\n#[test]\nfn test_read_pickle_data_nested_structure() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test reading nested dictionary structure\n    let path = test_data_path(\"nested_dict.pt\");\n\n    let data =\n        PytorchReader::read_pickle_data(&path, None).expect(\"Failed to read nested structure\");\n\n    if let PickleValue::Dict(dict) = data {\n        // nested_dict.pt has a nested structure, not flat keys\n        // It should have layer1 and layer2 as nested dicts\n        assert!(!dict.is_empty(), \"Dictionary should not be empty\");\n\n        // The structure depends on how the file was saved\n        // It could be flat keys like \"layer1.weight\" or nested dicts\n        // Just verify it's a valid dict structure\n        for (_key, value) in dict.iter() {\n            // Values could be None (tensors), nested dicts, or other types\n            assert!(\n                matches!(value, PickleValue::None | PickleValue::Dict(_)),\n                \"Values should be None or nested dicts\"\n            );\n        }\n    } else {\n        panic!(\"Expected nested_dict to be a dictionary\");\n    }\n}\n\n#[test]\nfn test_read_pickle_data_types() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test various data types in mixed_types.pt\n    let path = test_data_path(\"mixed_types.pt\");\n\n    let data = PytorchReader::read_pickle_data(&path, None).expect(\"Failed to read mixed types\");\n\n    if let PickleValue::Dict(dict) = data {\n        // The file contains different tensor types\n        assert!(dict.len() >= 3, \"Should have at least 3 tensor types\");\n\n        // All tensor values should be None in pickle data\n        for (_key, value) in dict.iter() {\n            // All values should be None (tensors are not included in pickle data)\n            assert!(\n                matches!(value, PickleValue::None),\n                \"Tensors should be None in pickle data\"\n            );\n        }\n    } else {\n        panic!(\"Expected mixed_types to be a dictionary\");\n    }\n}\n\n#[test]\nfn test_read_pickle_data_key_not_found() {\n    // Test error handling when key doesn't exist\n    let path = test_data_path(\"checkpoint.pt\");\n\n    let result = PytorchReader::read_pickle_data(&path, Some(\"nonexistent_key\"));\n    assert!(result.is_err());\n\n    if let Err(e) = result {\n        let err_str = format!(\"{}\", e);\n        assert!(\n            err_str.contains(\"not found\"),\n            \"Error should mention key not found: {}\",\n            err_str\n        );\n    }\n}\n\n#[test]\nfn test_read_pickle_data_simple_pickle() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test reading a simple pickle file (not ZIP)\n    // Note: simple_legacy.pt is a legacy format file, not a simple pickle\n    // It may return None because legacy format reading is different\n    let path = test_data_path(\"state_dict.pt\"); // Use a proper simple pickle file\n\n    let data = PytorchReader::read_pickle_data(&path, None).expect(\"Failed to read simple pickle\");\n\n    // Should contain state dict entries\n    if let PickleValue::Dict(dict) = data {\n        // state_dict.pt has weight, bias, running_mean, running_var\n        assert!(dict.len() >= 3);\n        assert!(dict.contains_key(\"weight\"));\n        assert!(dict.contains_key(\"bias\"));\n\n        // All tensor values should be None in pickle data\n        for (_key, value) in dict.iter() {\n            assert!(matches!(value, PickleValue::None));\n        }\n    } else {\n        panic!(\"Expected state_dict to contain a dictionary\");\n    }\n}\n\n#[test]\nfn test_load_config_basic() {\n    let path = test_data_path(\"checkpoint.pt\");\n\n    // Define a struct that matches part of the checkpoint data\n    #[derive(Debug, serde::Deserialize, PartialEq)]\n    struct CheckpointConfig {\n        epoch: i64,\n        loss: f64,\n    }\n\n    // Load config\n    let config: CheckpointConfig =\n        PytorchReader::load_config(&path, None).expect(\"Failed to load config\");\n\n    // Verify values - based on test_read_pickle_data_basic\n    assert_eq!(config.epoch, 42);\n    assert!((config.loss - 0.123).abs() < 1e-6);\n}\n\n#[test]\nfn test_load_config_with_top_level_key() {\n    // Test that we can extract a non-existent key and get an appropriate error\n    let path = test_data_path(\"checkpoint.pt\");\n\n    #[derive(Debug, serde::Deserialize, PartialEq)]\n    struct DummyConfig {\n        field: String,\n    }\n\n    // Try loading with a valid top-level key that exists but has wrong structure\n    let result: Result<DummyConfig, _> = PytorchReader::load_config(&path, Some(\"epoch\"));\n\n    // This should fail because epoch is an integer, not a struct with a field\n    assert!(result.is_err());\n\n    // Now test that we can load with a real key that has the right structure\n    // Since checkpoint.pt doesn't have nested configs, let's use nested_dict.pt\n    let path2 = test_data_path(\"nested_dict.pt\");\n\n    // Try to extract a specific nested key if it exists\n    // Since nested_dict has complex structure, let's just verify we can read it\n    let data = PytorchReader::read_pickle_data(&path2, None).unwrap();\n\n    // Verify it's a dict\n    if let crate::pytorch::reader::PickleValue::Dict(dict) = data {\n        assert!(!dict.is_empty());\n    } else {\n        panic!(\"Expected a dict\");\n    }\n}\n\n#[test]\nfn test_load_config_complex_types() {\n    // For this test, let's create a comprehensive test using checkpoint.pt\n    // which has both metadata and state_dict fields\n    let path = test_data_path(\"checkpoint.pt\");\n\n    // Define a partial config that only captures metadata fields\n    #[derive(Debug, serde::Deserialize, PartialEq)]\n    struct PartialCheckpoint {\n        epoch: i64,\n        loss: f64,\n        // We skip model_state_dict and optimizer_state_dict\n        // as they contain tensor references that become None\n    }\n\n    // Load partial config\n    let config: PartialCheckpoint =\n        PytorchReader::load_config(&path, None).expect(\"Failed to load config\");\n\n    // Verify we can extract the metadata\n    assert_eq!(config.epoch, 42);\n    assert!((config.loss - 0.123).abs() < 1e-6);\n}\n\n#[test]\nfn test_load_config_key_not_found() {\n    let path = test_data_path(\"checkpoint.pt\");\n\n    #[derive(Debug, serde::Deserialize)]\n    struct DummyConfig {\n        #[allow(dead_code)]\n        field: String,\n    }\n\n    // Try to load with non-existent key\n    let result: Result<DummyConfig, _> = PytorchReader::load_config(&path, Some(\"nonexistent\"));\n\n    assert!(result.is_err());\n    let error = result.unwrap_err();\n    assert!(error.to_string().contains(\"not found\") || error.to_string().contains(\"Key\"));\n}\n\n#[test]\nfn test_pickle_value_conversion() {\n    use crate::pytorch::reader::PickleValue;\n\n    // Test that PickleValue provides useful data structures\n    let path = test_data_path(\"checkpoint.pt\");\n    let data = PytorchReader::read_pickle_data(&path, None).unwrap();\n\n    // Test pattern matching and data extraction\n    match data {\n        PickleValue::Dict(dict) => {\n            // Extract epoch as integer\n            if let Some(PickleValue::Int(epoch)) = dict.get(\"epoch\") {\n                assert!(*epoch >= 0);\n            }\n\n            // Extract loss as float\n            if let Some(PickleValue::Float(loss)) = dict.get(\"loss\") {\n                assert!(loss.is_finite());\n            }\n\n            // Test nested access\n            if let Some(PickleValue::Dict(model_dict)) = dict.get(\"model_state_dict\") {\n                assert!(!model_dict.is_empty());\n            }\n        }\n        _ => panic!(\"Unexpected root type\"),\n    }\n}\n\n// ============================================================================\n// TAR Format Tests\n// ============================================================================\n// The TAR format was used by very early versions of PyTorch (pre 0.1.10).\n// These tests verify that we can correctly load models saved in this format.\n\n#[test]\nfn test_tar_format_detection() {\n    // Test that is_tar_file correctly detects TAR files\n    let tar_path = test_data_path(\"tar_float32.tar\");\n    let zip_path = test_data_path(\"float32.pt\");\n\n    // TAR file should be detected as TAR\n    let reader = PytorchReader::new(&tar_path).expect(\"Failed to load TAR file\");\n    let metadata = reader.metadata();\n    assert_eq!(metadata.format_type, FileFormat::Tar);\n\n    // ZIP file should NOT be detected as TAR\n    let reader = PytorchReader::new(&zip_path).expect(\"Failed to load ZIP file\");\n    let metadata = reader.metadata();\n    assert_ne!(metadata.format_type, FileFormat::Tar);\n}\n\n#[test]\nfn test_tar_float32_tensor() {\n    let path = test_data_path(\"tar_float32.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_float32.tar\");\n\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F32);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 4);\n    assert!((values[0] - 1.0).abs() < 1e-6);\n    assert!((values[1] - 2.5).abs() < 1e-6);\n    assert!((values[2] - (-3.7)).abs() < 1e-6);\n    assert!((values[3] - 0.0).abs() < 1e-6);\n}\n\n#[test]\nfn test_tar_float64_tensor() {\n    let path = test_data_path(\"tar_float64.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_float64.tar\");\n\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::F64);\n    assert_eq!(tensor.shape, shape![3]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<f64>().unwrap();\n    assert_eq!(values.len(), 3);\n    assert!((values[0] - 1.1).abs() < 1e-10);\n    assert!((values[1] - 2.2).abs() < 1e-10);\n    assert!((values[2] - 3.3).abs() < 1e-10);\n}\n\n#[test]\nfn test_tar_int64_tensor() {\n    let path = test_data_path(\"tar_int64.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_int64.tar\");\n\n    let tensor = reader.get(\"tensor\").expect(\"tensor key not found\");\n    assert_eq!(tensor.dtype, DType::I64);\n    assert_eq!(tensor.shape, shape![4]);\n\n    let data = tensor.to_data().unwrap();\n    let values = data.as_slice::<i64>().unwrap();\n    assert_eq!(values, &[100, -200, 300, 0]);\n}\n\n#[test]\nfn test_tar_multiple_tensors() {\n    // Test loading multiple tensors (weight + bias) with correct shapes\n    let path = test_data_path(\"tar_weight_bias.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_weight_bias.tar\");\n\n    // Check weight tensor (2x3 matrix)\n    let weight = reader.get(\"weight\").expect(\"weight key not found\");\n    assert_eq!(weight.dtype, DType::F32);\n    assert_eq!(weight.shape, shape![2, 3]);\n\n    let data = weight.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 6);\n    assert!((values[0] - 0.1).abs() < 1e-6);\n    assert!((values[1] - 0.2).abs() < 1e-6);\n    assert!((values[5] - 0.6).abs() < 1e-6);\n\n    // Check bias tensor (2-element vector)\n    let bias = reader.get(\"bias\").expect(\"bias key not found\");\n    assert_eq!(bias.dtype, DType::F32);\n    assert_eq!(bias.shape, shape![2]);\n\n    let data = bias.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 2);\n    assert!((values[0] - 0.01).abs() < 1e-6);\n    assert!((values[1] - 0.02).abs() < 1e-6);\n}\n\n#[test]\nfn test_tar_multi_dtype() {\n    // Test loading different dtypes from the same TAR file\n    let path = test_data_path(\"tar_multi_dtype.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_multi_dtype.tar\");\n\n    // Float32 tensor\n    let float_tensor = reader\n        .get(\"float_tensor\")\n        .expect(\"float_tensor key not found\");\n    assert_eq!(float_tensor.dtype, DType::F32);\n    let data = float_tensor.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert!((values[0] - 1.5).abs() < 1e-6);\n\n    // Float64 tensor\n    let double_tensor = reader\n        .get(\"double_tensor\")\n        .expect(\"double_tensor key not found\");\n    assert_eq!(double_tensor.dtype, DType::F64);\n    let data = double_tensor.to_data().unwrap();\n    let values = data.as_slice::<f64>().unwrap();\n    assert!((values[0] - 1.111).abs() < 1e-10);\n\n    // Int64 tensor\n    let int_tensor = reader.get(\"int_tensor\").expect(\"int_tensor key not found\");\n    assert_eq!(int_tensor.dtype, DType::I64);\n    let data = int_tensor.to_data().unwrap();\n    let values = data.as_slice::<i64>().unwrap();\n    assert_eq!(values, &[10, 20, 30, 40]);\n}\n\n#[test]\nfn test_tar_2d_tensor_shape() {\n    // Test that 2D tensor shapes are correctly preserved\n    let path = test_data_path(\"tar_2d_tensor.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_2d_tensor.tar\");\n\n    let matrix = reader.get(\"matrix\").expect(\"matrix key not found\");\n    assert_eq!(matrix.dtype, DType::F32);\n    assert_eq!(matrix.shape, shape![3, 4]); // 3 rows, 4 columns\n\n    let data = matrix.to_data().unwrap();\n    let values = data.as_slice::<f32>().unwrap();\n    assert_eq!(values.len(), 12);\n\n    // Verify values in row-major order\n    for i in 0..12 {\n        assert!((values[i] - (i as f32 + 1.0)).abs() < 1e-6);\n    }\n}\n\n#[test]\nfn test_tar_metadata() {\n    // Test that TAR metadata is correctly populated\n    let path = test_data_path(\"tar_float32.tar\");\n    let reader = PytorchReader::new(&path).expect(\"Failed to load tar_float32.tar\");\n\n    let metadata = reader.metadata();\n    assert_eq!(metadata.format_type, FileFormat::Tar);\n    assert_eq!(metadata.byte_order, ByteOrder::LittleEndian);\n    assert_eq!(metadata.tensor_count, 1);\n    assert!(metadata.total_data_size.is_some());\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/simple_legacy.py",
    "content": "#!/usr/bin/env python3\n# /// script\n# dependencies = [\"torch\"]\n# ///\n\"\"\"Create a simple legacy format PyTorch file.\"\"\"\n\nimport torch\n\n# Create a simple state dict\nstate_dict = {\n    'weight': torch.randn(2, 3),\n    'bias': torch.ones(2),\n    'running_mean': torch.zeros(2),\n}\n\n# Save without using zip format (legacy format)\ntorch.save(state_dict, 'test_data/simple_legacy.pt', _use_new_zipfile_serialization=False)\n\nprint(\"Created simple_legacy.pt\")\n\n# Verify\nloaded = torch.load('test_data/simple_legacy.pt', weights_only=False)\nprint(f\"Loaded {len(loaded)} tensors\")\nfor key, val in loaded.items():\n    print(f\"  {key}: shape {val.shape}, dtype {val.dtype}\")"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/test_data/broken.pt",
    "content": "abc"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/reader/test_data.py",
    "content": "#!/usr/bin/env python3\n# /// script\n# dependencies = [\"torch\", \"numpy\"]\n# ///\n\"\"\"\nGenerate test PyTorch .pt files for testing the burn-store PyTorch reader.\nRun with: uv run test_files.py\n\"\"\"\n\nimport torch\nimport numpy as np\nimport os\nfrom pathlib import Path\n\n# Create test directory\ntest_dir = Path(__file__).parent / \"test_data\"\ntest_dir.mkdir(exist_ok=True)\n\ndef save_test_file(filename, data, description):\n    \"\"\"Save a test file and print what was saved.\"\"\"\n    filepath = test_dir / filename\n    torch.save(data, filepath)\n    print(f\"✓ {filename}: {description}\")\n    return filepath\n\n# Test 1: Simple tensors of different types\nprint(\"\\n=== Generating Basic Tensor Tests ===\")\n\n# Float32 tensor (wrap in dict for compatibility)\nfloat32_tensor = torch.tensor([1.0, 2.5, -3.7, 0.0], dtype=torch.float32)\nsave_test_file(\"float32.pt\", {\"tensor\": float32_tensor}, \"Float32 tensor [1.0, 2.5, -3.7, 0.0]\")\n\n# Float64 tensor\nfloat64_tensor = torch.tensor([1.1, 2.2, 3.3], dtype=torch.float64)\nsave_test_file(\"float64.pt\", {\"tensor\": float64_tensor}, \"Float64 tensor [1.1, 2.2, 3.3]\")\n\n# Int64 tensor\nint64_tensor = torch.tensor([100, -200, 300, 0], dtype=torch.int64)\nsave_test_file(\"int64.pt\", {\"tensor\": int64_tensor}, \"Int64 tensor [100, -200, 300, 0]\")\n\n# Int32 tensor\nint32_tensor = torch.tensor([10, 20, -30], dtype=torch.int32)\nsave_test_file(\"int32.pt\", {\"tensor\": int32_tensor}, \"Int32 tensor [10, 20, -30]\")\n\n# Int16 tensor\nint16_tensor = torch.tensor([1000, -2000, 3000], dtype=torch.int16)\nsave_test_file(\"int16.pt\", {\"tensor\": int16_tensor}, \"Int16 tensor [1000, -2000, 3000]\")\n\n# Int8 tensor\nint8_tensor = torch.tensor([127, -128, 0, 50], dtype=torch.int8)\nsave_test_file(\"int8.pt\", {\"tensor\": int8_tensor}, \"Int8 tensor [127, -128, 0, 50]\")\n\n# Boolean tensor\nbool_tensor = torch.tensor([True, False, True, True, False], dtype=torch.bool)\nsave_test_file(\"bool.pt\", {\"tensor\": bool_tensor}, \"Bool tensor [True, False, True, True, False]\")\n\n# Float16 tensor (half precision)\nfloat16_tensor = torch.tensor([1.5, -2.25, 3.125], dtype=torch.float16)\nsave_test_file(\"float16.pt\", {\"tensor\": float16_tensor}, \"Float16 tensor [1.5, -2.25, 3.125]\")\n\n# BFloat16 tensor\nbfloat16_tensor = torch.tensor([1.5, -2.5, 3.5], dtype=torch.bfloat16)\nsave_test_file(\"bfloat16.pt\", {\"tensor\": bfloat16_tensor}, \"BFloat16 tensor [1.5, -2.5, 3.5]\")\n\n# UInt8 tensor\nuint8_tensor = torch.tensor([0, 128, 255, 42], dtype=torch.uint8)\nsave_test_file(\"uint8.pt\", {\"tensor\": uint8_tensor}, \"UInt8 tensor [0, 128, 255, 42]\")\n\n# Test 2: Multi-dimensional tensors\nprint(\"\\n=== Generating Multi-dimensional Tensor Tests ===\")\n\n# 2D tensor\ntensor_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float32)\nsave_test_file(\"tensor_2d.pt\", {\"tensor\": tensor_2d}, \"2D tensor shape (3, 2)\")\n\n# 3D tensor\ntorch.manual_seed(42)\ntensor_3d = torch.randn(2, 3, 4) * 10\nsave_test_file(\"tensor_3d.pt\", {\"tensor\": tensor_3d}, \"3D tensor shape (2, 3, 4)\")\n\n# 4D tensor (common for conv weights)\ntensor_4d = torch.randn(2, 3, 2, 2)\nsave_test_file(\"tensor_4d.pt\", {\"tensor\": tensor_4d}, \"4D tensor shape (2, 3, 2, 2)\")\n\n# Test 3: State dict (multiple tensors)\nprint(\"\\n=== Generating State Dict Tests ===\")\n\nstate_dict = {\n    \"weight\": torch.randn(3, 4),\n    \"bias\": torch.randn(3),\n    \"running_mean\": torch.zeros(3),\n    \"running_var\": torch.ones(3),\n}\nsave_test_file(\"state_dict.pt\", state_dict, \"State dict with 4 tensors\")\n\n# Nested state dict\nnested_dict = {\n    \"layer1\": {\n        \"weight\": torch.randn(2, 3),\n        \"bias\": torch.randn(2)\n    },\n    \"layer2\": {\n        \"weight\": torch.randn(4, 2),\n        \"bias\": torch.randn(4)\n    }\n}\nsave_test_file(\"nested_dict.pt\", nested_dict, \"Nested state dict\")\n\n# Test 4: Model checkpoint format\nprint(\"\\n=== Generating Model Checkpoint Tests ===\")\n\n# Typical checkpoint format (use string keys for compatibility)\ncheckpoint = {\n    \"model_state_dict\": {\n        \"fc1.weight\": torch.randn(10, 5),\n        \"fc1.bias\": torch.randn(10),\n        \"fc2.weight\": torch.randn(3, 10),\n        \"fc2.bias\": torch.randn(3),\n    },\n    \"optimizer_state_dict\": {\n        \"state\": {\n            \"0\": {  # Use string key instead of integer\n                \"momentum_buffer\": torch.randn(10, 5)\n            }\n        }\n    },\n    \"epoch\": 42,\n    \"loss\": 0.123\n}\nsave_test_file(\"checkpoint.pt\", checkpoint, \"Full checkpoint with model and optimizer state\")\n\n# Test 5: Edge cases\nprint(\"\\n=== Generating Edge Case Tests ===\")\n\n# Empty tensor (1D with 0 elements)\nempty_tensor = torch.zeros(0)\nsave_test_file(\"empty.pt\", {\"tensor\": empty_tensor}, \"Empty tensor\")\n\n# Scalar tensor (0-dimensional)\nscalar_tensor = torch.tensor(42.0)\nsave_test_file(\"scalar.pt\", {\"tensor\": scalar_tensor}, \"Scalar tensor (0-dim)\")\n\n# Large shape but small data (testing shape vs actual data)\nsparse_like = torch.zeros(100, 100)\nsparse_like[0, 0] = 1.0\nsparse_like[50, 50] = 2.0\nsparse_like[99, 99] = 3.0\nsave_test_file(\"large_shape.pt\", {\"tensor\": sparse_like}, \"Large shape (100, 100) mostly zeros\")\n\n# Test 6: Mixed types in dict\nprint(\"\\n=== Generating Mixed Type Tests ===\")\n\nmixed_types = {\n    \"float32\": torch.tensor([1.0, 2.0], dtype=torch.float32),\n    \"int64\": torch.tensor([100, 200], dtype=torch.int64),\n    \"bool\": torch.tensor([True, False], dtype=torch.bool),\n    \"float64\": torch.tensor([1.1, 2.2], dtype=torch.float64),\n}\nsave_test_file(\"mixed_types.pt\", mixed_types, \"Dict with mixed tensor types\")\n\n# Test 7: Special values\nprint(\"\\n=== Generating Special Value Tests ===\")\n\n# NaN and Inf values\nspecial_values = torch.tensor([float('nan'), float('inf'), float('-inf'), 0.0, 1.0])\nsave_test_file(\"special_values.pt\", {\"tensor\": special_values}, \"Tensor with NaN and Inf\")\n\n# Very small and very large values\nextreme_values = torch.tensor([1e-30, 1e30, -1e-30, -1e30], dtype=torch.float32)\nsave_test_file(\"extreme_values.pt\", {\"tensor\": extreme_values}, \"Tensor with extreme values\")\n\n# Test 8: Parameter wrapper (common in models)\nprint(\"\\n=== Generating Parameter Tests ===\")\n\nimport torch.nn as nn\nparam = nn.Parameter(torch.randn(3, 3))\nparam_dict = {\"param\": param}\nsave_test_file(\"parameter.pt\", param_dict, \"nn.Parameter wrapped tensor\")\n\n# Test 9: Buffer-style tensors\nprint(\"\\n=== Generating Buffer Tests ===\")\n\n# Simulate model buffers\nbuffers = {\n    \"buffer1\": torch.tensor([1, 2, 3], dtype=torch.int32),\n    \"buffer2\": torch.tensor([True, False], dtype=torch.bool),\n}\nsave_test_file(\"buffers.pt\", buffers, \"Model buffers\")\n\n# Test 10: Complex nested structure\nprint(\"\\n=== Generating Complex Structure Tests ===\")\n\ncomplex_structure = {\n    \"metadata\": {\n        \"version\": 1,\n        \"name\": \"test_model\"\n    },\n    \"state\": {\n        \"encoder\": {\n            \"layer_0\": {\n                \"weight\": torch.randn(4, 3),\n                \"bias\": torch.randn(4)\n            },\n            \"layer_1\": {\n                \"weight\": torch.randn(2, 4),\n                \"bias\": torch.randn(2)\n            }\n        },\n        \"decoder\": {\n            \"weight\": torch.randn(3, 2),\n            \"bias\": torch.randn(3)\n        }\n    },\n    \"config\": {\n        \"hidden_size\": 4,\n        \"num_layers\": 2\n    }\n}\nsave_test_file(\"complex_structure.pt\", complex_structure, \"Complex nested structure\")\n\nprint(f\"\\n✅ Generated {len(list(test_dir.glob('*.pt')))} test files in {test_dir}\")\nprint(\"\\nTest files can be used to verify PyTorch reader functionality:\")\nprint(\"- Different data types (float32, int64, bool, etc.)\")\nprint(\"- Multi-dimensional tensors\")\nprint(\"- State dicts and nested structures\")\nprint(\"- Edge cases (empty, scalar, special values)\")\nprint(\"- Model checkpoints and parameters\")\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/store/mod.rs",
    "content": "//! Comprehensive tests for PytorchStore with real model application\nuse burn_core as burn;\n\nuse std::path::PathBuf;\n\nuse crate::ModuleStore;\nuse crate::pytorch::PytorchStore;\nuse burn_core::module::Module;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\nuse burn_nn::{Linear, LinearConfig};\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\n/// Path to pytorch test files (now under burn-store)\nfn pytorch_test_path(subdir: &str, filename: &str) -> PathBuf {\n    PathBuf::from(env!(\"CARGO_MANIFEST_DIR\"))\n        .join(\"pytorch-tests\")\n        .join(\"tests\")\n        .join(subdir)\n        .join(filename)\n}\n\n/// Path to burn-store test data files\nfn test_data_path(filename: &str) -> PathBuf {\n    PathBuf::from(env!(\"CARGO_MANIFEST_DIR\"))\n        .join(\"src\")\n        .join(\"pytorch\")\n        .join(\"tests\")\n        .join(\"reader\")\n        .join(\"test_data\")\n        .join(filename)\n}\n\n/// Path to store test data files\nfn store_test_data_path(filename: &str) -> PathBuf {\n    PathBuf::from(env!(\"CARGO_MANIFEST_DIR\"))\n        .join(\"src\")\n        .join(\"pytorch\")\n        .join(\"tests\")\n        .join(\"store\")\n        .join(\"test_data\")\n        .join(filename)\n}\n\n#[cfg(test)]\nmod basic_tests {\n    use super::*;\n\n    #[test]\n    fn test_store_creation() {\n        let store = PytorchStore::from_file(\"model.pth\");\n        assert!(store.validate);\n        assert!(!store.allow_partial);\n        assert!(store.top_level_key.is_none());\n        // Contiguous index mapping is enabled by default for PyTorch files\n        assert!(store.map_indices_contiguous);\n    }\n\n    #[test]\n    fn test_store_map_indices_contiguous_default() {\n        // Verify that map_indices_contiguous is enabled by default\n        let store = PytorchStore::from_file(\"model.pth\");\n        assert!(\n            store.map_indices_contiguous,\n            \"map_indices_contiguous should be enabled by default\"\n        );\n    }\n\n    #[test]\n    fn test_store_map_indices_contiguous_disabled() {\n        // Verify that we can disable map_indices_contiguous\n        let store = PytorchStore::from_file(\"model.pth\").map_indices_contiguous(false);\n        assert!(\n            !store.map_indices_contiguous,\n            \"map_indices_contiguous should be disabled after explicit call\"\n        );\n    }\n\n    #[test]\n    fn test_store_with_top_level_key() {\n        let store = PytorchStore::from_file(\"model.pth\").with_top_level_key(\"state_dict\");\n        assert_eq!(store.top_level_key, Some(\"state_dict\".to_string()));\n    }\n\n    #[test]\n    fn test_store_configuration() {\n        let store = PytorchStore::from_file(\"model.pth\")\n            .validate(false)\n            .allow_partial(true)\n            .with_regex(r\"^encoder\\.\")\n            .with_full_path(\"decoder.weight\");\n\n        assert!(!store.validate);\n        assert!(store.allow_partial);\n        assert!(!store.filter.is_empty());\n    }\n\n    #[test]\n    fn test_store_with_remapping() {\n        let store = PytorchStore::from_file(\"model.pth\").with_key_remapping(r\"^old\\.\", \"new.\");\n\n        assert!(!store.remapper.is_empty());\n    }\n\n    #[test]\n    fn test_store_save_not_supported() {\n        // Currently, saving to PyTorch format is not implemented\n        // The collect_from method always returns an error\n        let store = PytorchStore::from_file(\"test.pth\");\n\n        // Just verify that store creation works\n        assert!(store.validate);\n\n        // Note: Actually testing save would require a proper Module implementation\n        // which is complex. The implementation guarantees it returns an error.\n    }\n}\n\n#[cfg(test)]\nmod linear_model_tests {\n    use super::*;\n    type TestBackend = burn_ndarray::NdArray;\n\n    #[derive(Module, Debug)]\n    pub struct SimpleLinearModel<B: Backend> {\n        fc1: Linear<B>,\n        fc2: Linear<B>,\n    }\n\n    impl<B: Backend> SimpleLinearModel<B> {\n        pub fn new(device: &B::Device) -> Self {\n            Self {\n                fc1: LinearConfig::new(2, 3).init(device),\n                fc2: LinearConfig::new(3, 4).init(device),\n            }\n        }\n\n        pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {\n            let x = self.fc1.forward(x);\n            self.fc2.forward(x)\n        }\n    }\n\n    #[test]\n    fn test_load_linear_model() {\n        let device = Default::default();\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        // Create a model and load weights from PyTorch\n        let mut model = SimpleLinearModel::<TestBackend>::new(&device);\n        let mut store = PytorchStore::from_file(path).allow_partial(true);\n\n        // Apply the PyTorch weights to our model\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        assert!(\n            result.is_ok(),\n            \"Failed to load linear model: {:?}\",\n            result.err()\n        );\n\n        let result = result.unwrap();\n        assert!(!result.applied.is_empty(), \"No tensors were applied\");\n\n        // Test forward pass with loaded weights\n        let input = Tensor::<TestBackend, 2>::ones([1, 2], &device);\n        let output = model.forward(input);\n\n        // Verify output shape\n        assert_eq!(&*output.shape(), [1, 4]);\n    }\n\n    #[test]\n    fn test_load_linear_with_bias() {\n        let device = Default::default();\n        let path = pytorch_test_path(\"linear\", \"linear_with_bias.pt\");\n\n        // Single linear layer with bias\n        #[derive(Module, Debug)]\n        struct LinearWithBias<B: Backend> {\n            fc1: Linear<B>,\n        }\n\n        let mut model = LinearWithBias {\n            fc1: LinearConfig::new(2, 3).init(&device),\n        };\n\n        let mut store = PytorchStore::from_file(path).allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n        assert!(result.is_ok(), \"Failed to load model with bias\");\n\n        // Verify biases were loaded\n        let result = result.unwrap();\n        let bias_loaded = result.applied.iter().any(|s| s.contains(\"bias\"));\n        assert!(bias_loaded, \"Bias parameters not loaded\");\n    }\n\n    #[test]\n    fn test_filter_layers() {\n        let device = Default::default();\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        let mut model = SimpleLinearModel::<TestBackend>::new(&device);\n\n        // Only load fc1 layers\n        let mut store = PytorchStore::from_file(path)\n            .with_regex(r\"^fc1\\.\")\n            .allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model).unwrap();\n\n        // Should only have fc1 tensors\n        for tensor in &result.applied {\n            assert!(tensor.contains(\"fc1\"));\n            assert!(!tensor.contains(\"fc2\"));\n        }\n    }\n\n    #[test]\n    fn test_remap_layer_names() {\n        let device = Default::default();\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        // Model with different layer names\n        #[derive(Module, Debug)]\n        struct RemappedModel<B: Backend> {\n            linear1: Linear<B>,\n            linear2: Linear<B>,\n        }\n\n        let mut model = RemappedModel {\n            linear1: LinearConfig::new(2, 3).init(&device),\n            linear2: LinearConfig::new(3, 4).init(&device),\n        };\n\n        let mut store = PytorchStore::from_file(path)\n            .with_key_remapping(r\"^fc1\\.\", \"linear1.\")\n            .with_key_remapping(r\"^fc2\\.\", \"linear2.\")\n            .allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n        assert!(result.is_ok(), \"Failed to load with remapped names\");\n\n        let result = result.unwrap();\n        // Verify remapped names were applied\n        let has_linear1 = result.applied.iter().any(|s| s.contains(\"linear1\"));\n        assert!(has_linear1, \"Remapped names not applied\");\n    }\n}\n\n#[cfg(test)]\nmod conv_model_tests {\n    use super::*;\n\n    type TestBackend = burn_ndarray::NdArray;\n\n    #[derive(Module, Debug)]\n    struct SimpleConvModel<B: Backend> {\n        conv1: Conv2d<B>,\n        conv2: Conv2d<B>,\n    }\n\n    impl<B: Backend> SimpleConvModel<B> {\n        pub fn new(device: &B::Device) -> Self {\n            Self {\n                conv1: Conv2dConfig::new([3, 16], [3, 3]).init(device),\n                conv2: Conv2dConfig::new([16, 32], [3, 3]).init(device),\n            }\n        }\n    }\n\n    #[test]\n    fn test_load_conv2d_model() {\n        let device = Default::default();\n        let path = pytorch_test_path(\"conv2d\", \"conv2d.pt\");\n\n        // Check if file exists, skip if not\n        if !path.exists() {\n            println!(\"Skipping conv2d test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut model = SimpleConvModel::<TestBackend>::new(&device);\n        let mut store = PytorchStore::from_file(path).allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        if let Ok(result) = result {\n            assert!(!result.applied.is_empty(), \"No conv tensors applied\");\n\n            // Check for conv weights\n            let has_conv_weights = result.applied.iter().any(|s| s.contains(\"weight\"));\n            assert!(has_conv_weights, \"Conv weights not loaded\");\n        }\n    }\n\n    #[test]\n    fn test_load_conv1d_model() {\n        let path = pytorch_test_path(\"conv1d\", \"conv1d.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping conv1d test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Just test that we can create a store for conv1d files\n        let store = PytorchStore::from_file(path).allow_partial(true);\n\n        assert!(store.allow_partial);\n    }\n}\n\n#[cfg(test)]\nmod complex_model_tests {\n    use super::*;\n    type TestBackend = burn_ndarray::NdArray;\n\n    #[test]\n    fn test_load_with_top_level_key() {\n        let path = test_data_path(\"checkpoint.pt\");\n\n        // Just verify that we can create a store with top-level key\n        let store = PytorchStore::from_file(path)\n            .with_top_level_key(\"model_state_dict\")\n            .allow_partial(true);\n\n        assert_eq!(store.top_level_key, Some(\"model_state_dict\".to_string()));\n    }\n\n    #[test]\n    fn test_load_nested_structure() {\n        let path = test_data_path(\"complex_structure.pt\");\n\n        // Just verify that we can create a store for nested structure\n        let store = PytorchStore::from_file(path).allow_partial(true);\n\n        assert!(store.allow_partial);\n    }\n\n    #[test]\n    fn test_legacy_format() {\n        let path = test_data_path(\"simple_legacy.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping legacy format test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Just verify that we can create a store for legacy format\n        let store = PytorchStore::from_file(path).allow_partial(true);\n\n        assert!(store.allow_partial);\n\n        // Could load into an actual model if we had legacy model structure\n    }\n\n    #[test]\n    fn test_key_remap_chained() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping key remap test - file not found: {:?}\", path);\n            return;\n        }\n\n        let device = Default::default();\n\n        // Model with different layer names that need remapping\n        #[derive(Module, Debug)]\n        struct RemappedChainModel<B: Backend> {\n            convolution1: Linear<B>, // Will be remapped from fc1\n            linear2: Linear<B>,      // Will be remapped from fc2\n        }\n\n        let mut model = RemappedChainModel {\n            convolution1: LinearConfig::new(2, 3).init(&device),\n            linear2: LinearConfig::new(3, 4).init(&device),\n        };\n\n        // Chain multiple remappings\n        let mut store = PytorchStore::from_file(path)\n            .with_key_remapping(r\"^fc1\\.\", \"convolution1.\")\n            .with_key_remapping(r\"^fc2\\.\", \"linear2.\")\n            .allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        if let Ok(result) = result {\n            // Check that remapped names were applied\n            assert!(\n                !result.applied.is_empty(),\n                \"No tensors were applied after remapping\"\n            );\n        }\n    }\n}\n\n#[cfg(test)]\nmod adapter_tests {\n    use super::*;\n\n    type TestBackend = burn_ndarray::NdArray;\n\n    #[derive(Module, Debug)]\n    pub struct SimpleLinearModel<B: Backend> {\n        fc1: Linear<B>,\n        fc2: Linear<B>,\n    }\n\n    impl<B: Backend> SimpleLinearModel<B> {\n        pub fn new(device: &B::Device) -> Self {\n            Self {\n                fc1: LinearConfig::new(2, 3).init(device),\n                fc2: LinearConfig::new(3, 4).init(device),\n            }\n        }\n    }\n\n    #[test]\n    fn test_pytorch_adapter_always_applied() {\n        // Test that PyTorchToBurnAdapter is always applied internally\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping adapter test - file not found: {:?}\", path);\n            return;\n        }\n\n        let device = Default::default();\n        let mut model = SimpleLinearModel::<TestBackend>::new(&device);\n\n        let mut store = PytorchStore::from_file(path).allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        // PyTorchToBurnAdapter is always applied internally\n        assert!(\n            result.is_ok(),\n            \"Failed to load with internal PyTorchToBurnAdapter: {:?}\",\n            result.err()\n        );\n        assert!(!result.unwrap().applied.is_empty());\n    }\n\n    #[test]\n    fn test_pytorch_adapter_with_filtering() {\n        // Test that PyTorchToBurnAdapter works with filtering\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping filtering test - file not found: {:?}\", path);\n            return;\n        }\n\n        let device = Default::default();\n        let mut model = SimpleLinearModel::<TestBackend>::new(&device);\n\n        // Filter to exclude bias tensors\n        let mut store = PytorchStore::from_file(path)\n            .with_predicate(|path, _| !path.contains(\"bias\"))\n            .allow_partial(true);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model).unwrap();\n\n        // Should not have any bias tensors due to filtering\n        for applied_path in &result.applied {\n            assert!(\n                !applied_path.contains(\"bias\"),\n                \"Bias tensor was not filtered: {}\",\n                applied_path\n            );\n        }\n    }\n}\n\n#[cfg(test)]\nmod error_handling_tests {\n    use super::*;\n    use burn_ndarray::NdArray;\n\n    #[derive(Module, Debug)]\n    pub struct SimpleLinearModel<B: Backend> {\n        fc1: Linear<B>,\n        fc2: Linear<B>,\n    }\n\n    impl<B: Backend> SimpleLinearModel<B> {\n        pub fn new(device: &B::Device) -> Self {\n            Self {\n                fc1: LinearConfig::new(2, 3).init(device),\n                fc2: LinearConfig::new(3, 4).init(device),\n            }\n        }\n    }\n\n    #[test]\n    fn test_missing_file() {\n        let device = Default::default();\n        let mut model = SimpleLinearModel::<NdArray>::new(&device);\n        let mut store = PytorchStore::from_file(\"nonexistent.pth\");\n\n        let result = store.apply_to::<NdArray, _>(&mut model);\n\n        assert!(result.is_err());\n        match result {\n            Err(crate::pytorch::PytorchStoreError::Reader(_)) => {}\n            _ => panic!(\"Expected reader error for missing file\"),\n        }\n    }\n\n    #[test]\n    fn test_invalid_top_level_key() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\n                \"Skipping invalid top level key test - file not found: {:?}\",\n                path\n            );\n            return;\n        }\n\n        let device = Default::default();\n        let mut model = SimpleLinearModel::<NdArray>::new(&device);\n\n        let mut store = PytorchStore::from_file(path).with_top_level_key(\"nonexistent_key\");\n\n        let result = store.apply_to::<NdArray, _>(&mut model);\n\n        assert!(result.is_err(), \"Should fail with invalid top level key\");\n    }\n\n    #[test]\n    fn test_strict_validation() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\n                \"Skipping strict validation test - file not found: {:?}\",\n                path\n            );\n            return;\n        }\n\n        let device = Default::default();\n        let mut model = SimpleLinearModel::<NdArray>::new(&device);\n\n        // Apply very restrictive filter that matches nothing\n        let mut store = PytorchStore::from_file(path)\n            .with_regex(r\"^this_will_never_match$\")\n            .validate(true)\n            .allow_partial(false);\n\n        let result = store.apply_to::<NdArray, _>(&mut model);\n\n        // Should fail because no tensors match and allow_partial is false\n        assert!(\n            result.is_err(),\n            \"Should fail when no tensors match with allow_partial=false\"\n        );\n    }\n}\n\n#[cfg(test)]\nmod enum_variant_tests {\n    use super::*;\n    use crate::ModuleSnapshot;\n    use burn_ndarray::NdArray;\n\n    /// Enum representing different convolution block types (similar to YOLOX architecture)\n    #[derive(Module, Debug)]\n    pub enum ConvBlock<B: Backend> {\n        /// Base convolution block\n        BaseConv(Linear<B>),\n        /// Depthwise separable convolution block\n        DwsConv(Linear<B>),\n    }\n\n    /// Model with enum field that will have variant names in Burn paths\n    #[derive(Module, Debug)]\n    pub struct ModelWithEnum<B: Backend> {\n        /// Feature extractor with enum variants\n        feature: ConvBlock<B>,\n        /// Output classifier\n        classifier: Linear<B>,\n    }\n\n    impl<B: Backend> ModelWithEnum<B> {\n        pub fn new(device: &B::Device) -> Self {\n            Self {\n                feature: ConvBlock::BaseConv(LinearConfig::new(3, 64).init(device)),\n                classifier: LinearConfig::new(64, 10).init(device),\n            }\n        }\n    }\n\n    #[test]\n    fn test_enum_variant_path_mismatch() {\n        let device = Default::default();\n        let mut model = ModelWithEnum::<NdArray>::new(&device);\n\n        // Load PyTorch model that was generated without enum variant names\n        // PyTorch paths: \"feature.weight\", \"feature.bias\", \"classifier.weight\", \"classifier.bias\"\n        // Burn paths:    \"feature.BaseConv.weight\", \"feature.BaseConv.bias\", \"classifier.weight\", \"classifier.bias\"\n        //                         ^^^^^^^^ enum variant name is included in Burn but not PyTorch\n\n        let pytorch_file = store_test_data_path(\"model_without_enum_variants.pt\");\n\n        // Try to load from PyTorch format (without enum variants)\n        // Explicitly disable skip_enum_variants to demonstrate the mismatch problem\n        let mut store = PytorchStore::from_file(pytorch_file)\n            .skip_enum_variants(false) // Disable to show the mismatch\n            .allow_partial(true) // Allow partial to see what's missing\n            .validate(false); // Disable validation to get detailed missing info\n\n        let result = store.apply_to::<NdArray, _>(&mut model);\n\n        // The load should succeed (allow_partial=true) but report missing tensors\n        match result {\n            Ok(apply_result) => {\n                // Verify we have missing tensors\n                assert!(\n                    !apply_result.missing.is_empty(),\n                    \"Should have missing tensors due to enum variant path mismatch\"\n                );\n\n                // Check that missing paths contain enum variants\n                let enum_missing: Vec<_> = apply_result\n                    .missing\n                    .iter()\n                    .filter(|(_, container_stack)| container_stack.contains(\"Enum:\"))\n                    .collect();\n\n                assert!(\n                    !enum_missing.is_empty(),\n                    \"Missing tensors should be detected as having enum containers\"\n                );\n\n                // Verify the paths look like what we expect\n                let has_base_conv_path = apply_result\n                    .missing\n                    .iter()\n                    .any(|(path, _)| path.contains(\"BaseConv\"));\n\n                assert!(\n                    has_base_conv_path,\n                    \"Should have missing paths with 'BaseConv' enum variant. Missing: {:?}\",\n                    apply_result\n                        .missing\n                        .iter()\n                        .map(|(p, _)| p)\n                        .collect::<Vec<_>>()\n                );\n\n                // Print the diagnostic output to show enum detection\n                println!(\"\\n{}\", apply_result);\n\n                // Verify the diagnostic message mentions enum variants\n                let display_output = format!(\"{}\", apply_result);\n                assert!(\n                    display_output.contains(\"enum variant\"),\n                    \"Display output should mention enum variants\"\n                );\n            }\n            Err(e) => panic!(\n                \"Load should succeed with allow_partial=true, got error: {}\",\n                e\n            ),\n        }\n    }\n\n    #[test]\n    fn test_enum_variant_detection_in_container_stack() {\n        let device = Default::default();\n\n        // Create model with enum\n        let model = ModelWithEnum::<NdArray>::new(&device);\n\n        // Collect snapshots to inspect container stacks\n        let snapshots = model.collect(None, None, false);\n\n        // Find a snapshot from inside the enum\n        let enum_snapshot = snapshots\n            .iter()\n            .find(|s| s.full_path().contains(\"feature\"))\n            .expect(\"Should have feature snapshots\");\n\n        // Verify container stack contains enum marker\n        if let Some(container_stack) = &enum_snapshot.container_stack {\n            let container_str = container_stack.join(\".\");\n            assert!(\n                container_str.contains(\"Enum:ConvBlock\"),\n                \"Container stack should contain Enum:ConvBlock marker. Got: {}\",\n                container_str\n            );\n        } else {\n            panic!(\"Snapshot should have container_stack\");\n        }\n    }\n\n    #[test]\n    fn test_skip_enum_variants_feature() {\n        let device = Default::default();\n        let mut model = ModelWithEnum::<NdArray>::new(&device);\n\n        // Load PyTorch model that was generated without enum variant names\n        // PyTorch paths: \"feature.weight\", \"feature.bias\", \"classifier.weight\", \"classifier.bias\"\n        // Burn paths:    \"feature.BaseConv.weight\", \"feature.BaseConv.bias\", \"classifier.weight\", \"classifier.bias\"\n\n        let pytorch_file = store_test_data_path(\"model_without_enum_variants.pt\");\n\n        // Try to load with skip_enum_variants enabled\n        let mut store = PytorchStore::from_file(pytorch_file)\n            .skip_enum_variants(true) // Enable enum variant skipping\n            .allow_partial(true)\n            .validate(false);\n\n        let result = store.apply_to::<NdArray, _>(&mut model);\n\n        // The load should succeed and all tensors should be loaded\n        match result {\n            Ok(apply_result) => {\n                println!(\"\\n{}\", apply_result);\n\n                // With skip_enum_variants enabled, we should successfully load the feature tensors\n                let feature_applied = apply_result\n                    .applied\n                    .iter()\n                    .filter(|path| path.contains(\"feature\"))\n                    .count();\n\n                assert!(\n                    feature_applied > 0,\n                    \"Should have applied feature tensors with skip_enum_variants=true. Applied: {:?}\",\n                    apply_result.applied\n                );\n\n                // The feature tensors should NOT be in missing anymore\n                let feature_missing = apply_result\n                    .missing\n                    .iter()\n                    .filter(|(path, _)| path.contains(\"feature\"))\n                    .count();\n\n                assert_eq!(\n                    feature_missing, 0,\n                    \"Feature tensors should not be missing with skip_enum_variants=true. Missing: {:?}\",\n                    apply_result.missing\n                );\n            }\n            Err(e) => panic!(\n                \"Load with skip_enum_variants should succeed, got error: {}\",\n                e\n            ),\n        }\n    }\n}\n\n#[cfg(test)]\nmod direct_access_tests {\n    use super::*;\n\n    #[test]\n    fn test_get_all_snapshots() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut store = PytorchStore::from_file(path);\n        let snapshots = store.get_all_snapshots().unwrap();\n\n        // linear.pt should have fc1.weight, fc1.bias, fc2.weight, fc2.bias\n        assert!(!snapshots.is_empty(), \"Should have snapshots\");\n        assert!(\n            snapshots.contains_key(\"fc1.weight\"),\n            \"Should contain fc1.weight\"\n        );\n        assert!(\n            snapshots.contains_key(\"fc1.bias\"),\n            \"Should contain fc1.bias\"\n        );\n    }\n\n    #[test]\n    fn test_get_snapshot_existing() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut store = PytorchStore::from_file(path);\n\n        // Get existing snapshot\n        let snapshot = store.get_snapshot(\"fc1.weight\").unwrap();\n        assert!(snapshot.is_some(), \"Should find fc1.weight\");\n\n        let snapshot = snapshot.unwrap();\n        // Linear weight should be 2D\n        assert_eq!(snapshot.shape.len(), 2, \"Weight should be 2D tensor\");\n\n        // Verify we can load data\n        let data = snapshot.to_data().unwrap();\n        assert!(!data.bytes.is_empty(), \"Data should not be empty\");\n    }\n\n    #[test]\n    fn test_get_snapshot_not_found() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut store = PytorchStore::from_file(path);\n\n        // Get non-existent snapshot\n        let snapshot = store.get_snapshot(\"nonexistent.weight\").unwrap();\n        assert!(snapshot.is_none(), \"Should not find nonexistent tensor\");\n    }\n\n    #[test]\n    fn test_keys() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut store = PytorchStore::from_file(path);\n        let keys = store.keys().unwrap();\n\n        assert!(!keys.is_empty(), \"Should have keys\");\n        assert!(\n            keys.contains(&\"fc1.weight\".to_string()),\n            \"Keys should contain fc1.weight\"\n        );\n        assert!(\n            keys.contains(&\"fc1.bias\".to_string()),\n            \"Keys should contain fc1.bias\"\n        );\n    }\n\n    #[test]\n    fn test_keys_fast_path() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Create fresh store - cache should be empty\n        let mut store = PytorchStore::from_file(&path);\n\n        // keys() should work without populating the full cache (fast path)\n        let keys = store.keys().unwrap();\n        assert!(!keys.is_empty(), \"Should have keys via fast path\");\n\n        // Now call get_all_snapshots to populate cache\n        let snapshots = store.get_all_snapshots().unwrap();\n        assert!(!snapshots.is_empty(), \"Should have snapshots\");\n\n        // keys() should now use the cached data\n        let keys2 = store.keys().unwrap();\n        assert_eq!(keys.len(), keys2.len(), \"Keys count should match\");\n    }\n\n    #[test]\n    fn test_caching_behavior() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let mut store = PytorchStore::from_file(path);\n\n        // First call populates cache\n        let snapshots1 = store.get_all_snapshots().unwrap();\n        let count1 = snapshots1.len();\n\n        // Second call uses cache\n        let snapshots2 = store.get_all_snapshots().unwrap();\n        let count2 = snapshots2.len();\n\n        assert_eq!(count1, count2, \"Cached results should match\");\n    }\n\n    #[test]\n    fn test_get_all_snapshots_with_remapping() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Create store with key remapping\n        let mut store = PytorchStore::from_file(path).with_key_remapping(r\"^fc1\\.\", \"linear1.\");\n\n        let snapshots = store.get_all_snapshots().unwrap();\n\n        // Should have remapped keys\n        assert!(\n            snapshots.contains_key(\"linear1.weight\"),\n            \"Should contain remapped key linear1.weight. Keys: {:?}\",\n            snapshots.keys().collect::<Vec<_>>()\n        );\n        assert!(\n            snapshots.contains_key(\"linear1.bias\"),\n            \"Should contain remapped key linear1.bias\"\n        );\n\n        // Original keys should not exist\n        assert!(\n            !snapshots.contains_key(\"fc1.weight\"),\n            \"Should not contain original key fc1.weight\"\n        );\n    }\n\n    #[test]\n    fn test_get_snapshot_with_remapped_name() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Create store with key remapping\n        let mut store = PytorchStore::from_file(path).with_key_remapping(r\"^fc1\\.\", \"linear1.\");\n\n        // Should find by remapped name\n        let snapshot = store.get_snapshot(\"linear1.weight\").unwrap();\n        assert!(snapshot.is_some(), \"Should find tensor by remapped name\");\n\n        // Should NOT find by original name\n        let snapshot_orig = store.get_snapshot(\"fc1.weight\").unwrap();\n        assert!(\n            snapshot_orig.is_none(),\n            \"Should not find tensor by original name after remapping\"\n        );\n    }\n\n    #[test]\n    fn test_get_all_snapshots_ignores_filter() {\n        let path = pytorch_test_path(\"linear\", \"linear.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        // Create store with filter that only matches fc1\n        let mut store = PytorchStore::from_file(path).with_regex(r\"^fc1\\.\");\n\n        // get_all_snapshots should return ALL tensors regardless of filter\n        let snapshots = store.get_all_snapshots().unwrap();\n\n        // Should have both fc1 and fc2 tensors\n        assert!(\n            snapshots.contains_key(\"fc1.weight\"),\n            \"Should contain fc1.weight\"\n        );\n        assert!(\n            snapshots.contains_key(\"fc2.weight\"),\n            \"Should contain fc2.weight (filter not applied to get_all_snapshots)\"\n        );\n    }\n}\n\n/// Tests for contiguous index mapping feature\n#[cfg(test)]\nmod map_indices_contiguous_tests {\n    use super::*;\n    type TestBackend = burn_ndarray::NdArray;\n\n    /// Model with a Vec of Conv2d layers that expects contiguous indices\n    #[derive(Module, Debug)]\n    struct SequentialConvModel<B: Backend> {\n        fc: Vec<Conv2d<B>>,\n    }\n\n    impl<B: Backend> SequentialConvModel<B> {\n        pub fn new(device: &B::Device, num_layers: usize) -> Self {\n            Self {\n                fc: (0..num_layers)\n                    .map(|_| {\n                        Conv2dConfig::new([2, 2], [3, 3])\n                            .with_bias(true)\n                            .init(device)\n                    })\n                    .collect(),\n            }\n        }\n    }\n\n    #[test]\n    fn test_load_non_contiguous_indexes_with_mapping() {\n        // This test uses the non_contiguous_indexes.pt file which has:\n        // fc.0.weight, fc.0.bias, fc.2.weight, fc.2.bias, fc.4.weight, ... (non-contiguous)\n        // The Burn model expects fc.0, fc.1, fc.2, ... (contiguous)\n\n        let path = pytorch_test_path(\"non_contiguous_indexes\", \"non_contiguous_indexes.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let device = Default::default();\n\n        // Create model with 5 conv layers (matching the PyTorch model)\n        let mut model = SequentialConvModel::<TestBackend>::new(&device, 5);\n\n        // Load with contiguous index mapping enabled (default)\n        let mut store = PytorchStore::from_file(&path)\n            .map_indices_contiguous(true)\n            .allow_partial(true)\n            .validate(false);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        match result {\n            Ok(apply_result) => {\n                println!(\"Applied tensors: {:?}\", apply_result.applied);\n                println!(\"Missing tensors: {:?}\", apply_result.missing);\n                println!(\"Unused tensors: {:?}\", apply_result.unused);\n\n                // All fc layers should be loaded successfully\n                assert!(\n                    !apply_result.applied.is_empty(),\n                    \"Should have applied tensors\"\n                );\n\n                // Verify we have tensors from all 5 layers\n                // With mapping: fc.0, fc.1, fc.2, fc.3, fc.4\n                for i in 0..5 {\n                    let has_weight = apply_result\n                        .applied\n                        .iter()\n                        .any(|p| p.contains(&format!(\"fc.{}.weight\", i)));\n                    let has_bias = apply_result\n                        .applied\n                        .iter()\n                        .any(|p| p.contains(&format!(\"fc.{}.bias\", i)));\n\n                    assert!(\n                        has_weight,\n                        \"Should have applied fc.{}.weight, applied: {:?}\",\n                        i, apply_result.applied\n                    );\n                    assert!(\n                        has_bias,\n                        \"Should have applied fc.{}.bias, applied: {:?}\",\n                        i, apply_result.applied\n                    );\n                }\n\n                // There should be no missing tensors (assuming model matches)\n                let missing_fc: Vec<_> = apply_result\n                    .missing\n                    .iter()\n                    .filter(|(p, _)| p.starts_with(\"fc.\"))\n                    .collect();\n                assert!(\n                    missing_fc.is_empty(),\n                    \"Should have no missing fc tensors with index mapping. Missing: {:?}\",\n                    missing_fc\n                );\n            }\n            Err(e) => panic!(\"Failed to load with index mapping: {}\", e),\n        }\n    }\n\n    #[test]\n    fn test_load_non_contiguous_indexes_without_mapping() {\n        // This test verifies that loading fails or has missing tensors when\n        // map_indices_contiguous is disabled\n\n        let path = pytorch_test_path(\"non_contiguous_indexes\", \"non_contiguous_indexes.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        let device = Default::default();\n\n        // Create model with 5 conv layers\n        let mut model = SequentialConvModel::<TestBackend>::new(&device, 5);\n\n        // Load with contiguous index mapping DISABLED\n        let mut store = PytorchStore::from_file(&path)\n            .map_indices_contiguous(false) // Disable index mapping\n            .allow_partial(true)\n            .validate(false);\n\n        let result = store.apply_to::<TestBackend, _>(&mut model);\n\n        match result {\n            Ok(apply_result) => {\n                println!(\n                    \"Without index mapping - Applied tensors: {:?}\",\n                    apply_result.applied\n                );\n                println!(\n                    \"Without index mapping - Missing tensors: {:?}\",\n                    apply_result.missing\n                );\n\n                // Without index mapping, we should have missing tensors for fc.1, fc.3\n                // because the source has fc.0, fc.2, fc.4, fc.6, fc.8 but model expects fc.0-4\n                let missing_fc: Vec<_> = apply_result\n                    .missing\n                    .iter()\n                    .filter(|(p, _)| p.starts_with(\"fc.\"))\n                    .collect();\n\n                assert!(\n                    !missing_fc.is_empty(),\n                    \"Should have missing fc tensors without index mapping (indices 1, 3 don't exist in file)\"\n                );\n\n                // Specifically, fc.1 and fc.3 should be missing\n                let has_fc1_missing = apply_result\n                    .missing\n                    .iter()\n                    .any(|(p, _)| p.starts_with(\"fc.1.\"));\n                let has_fc3_missing = apply_result\n                    .missing\n                    .iter()\n                    .any(|(p, _)| p.starts_with(\"fc.3.\"));\n\n                assert!(\n                    has_fc1_missing || has_fc3_missing,\n                    \"Should have fc.1 or fc.3 missing. Missing: {:?}\",\n                    apply_result.missing\n                );\n            }\n            Err(e) => panic!(\"Unexpected error: {}\", e),\n        }\n    }\n\n    #[test]\n    fn test_mapping_applied_to_keys() {\n        // Verify that the keys returned by the store are mapped\n        let path = pytorch_test_path(\"non_contiguous_indexes\", \"non_contiguous_indexes.pt\");\n\n        if !path.exists() {\n            println!(\"Skipping test - file not found: {:?}\", path);\n            return;\n        }\n\n        // With index mapping enabled (default)\n        let mut store_mapped = PytorchStore::from_file(&path).map_indices_contiguous(true);\n\n        let keys_mapped = store_mapped.keys().unwrap();\n        println!(\"Keys with index mapping: {:?}\", keys_mapped);\n\n        // Should have contiguous keys: fc.0, fc.1, fc.2, fc.3, fc.4\n        assert!(\n            keys_mapped.iter().any(|k| k.starts_with(\"fc.1.\")),\n            \"With index mapping, should have fc.1 (from fc.2)\"\n        );\n        assert!(\n            keys_mapped.iter().any(|k| k.starts_with(\"fc.2.\")),\n            \"With index mapping, should have fc.2 (from fc.4)\"\n        );\n\n        // Without index mapping\n        let mut store_no_mapping = PytorchStore::from_file(&path).map_indices_contiguous(false);\n\n        let keys_no_mapping = store_no_mapping.keys().unwrap();\n        println!(\"Keys without index mapping: {:?}\", keys_no_mapping);\n\n        // Should have original non-contiguous keys: fc.0, fc.2, fc.4, fc.6, fc.8\n        assert!(\n            keys_no_mapping.iter().any(|k| k.starts_with(\"fc.2.\")),\n            \"Without index mapping, should have original fc.2\"\n        );\n        assert!(\n            keys_no_mapping.iter().any(|k| k.starts_with(\"fc.4.\")),\n            \"Without index mapping, should have original fc.4\"\n        );\n        assert!(\n            !keys_no_mapping.iter().any(|k| k.starts_with(\"fc.1.\")),\n            \"Without index mapping, should NOT have fc.1 (not in original file)\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/pytorch/tests/store/test_data/generate_enum_test.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nGenerate PyTorch test data for enum variant path mismatch testing.\n\nThis script creates a PyTorch checkpoint that simulates how PyTorch models\nexport their state dicts WITHOUT enum variant names in the paths.\n\nExample:\n- PyTorch path: \"feature.weight\"\n- Burn path:    \"feature.BaseConv.weight\"  (includes enum variant \"BaseConv\")\n\nRun with: uv run generate_enum_test.py\n\"\"\"\n\nimport torch\nimport torch.nn as nn\n\n\nclass SimpleModel(nn.Module):\n    \"\"\"\n    Simple PyTorch model that represents what a Burn enum model would look like\n    WITHOUT the enum variant names in the path.\n\n    In Burn, this would be:\n    struct ModelWithEnum {\n        feature: ConvBlock,  // enum with BaseConv, DwsConv variants\n        classifier: Linear,\n    }\n\n    But PyTorch exports it as flat paths without the enum variant names.\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n        # This represents the \"feature\" field which is an enum in Burn\n        # PyTorch doesn't have enums, so it's just a Linear layer\n        # Path will be: \"feature.weight\" and \"feature.bias\"\n        self.feature = nn.Linear(3, 64)\n\n        # This represents the \"classifier\" field\n        # Path will be: \"classifier.weight\" and \"classifier.bias\"\n        self.classifier = nn.Linear(64, 10)\n\n    def forward(self, x):\n        x = self.feature(x)\n        x = torch.relu(x)\n        x = self.classifier(x)\n        return x\n\n\ndef generate_enum_variant_mismatch_test():\n    \"\"\"Generate test file demonstrating enum variant path mismatch.\"\"\"\n    model = SimpleModel()\n\n    # Initialize with some deterministic weights for testing\n    torch.manual_seed(42)\n    for param in model.parameters():\n        param.data.normal_(0, 0.1)\n\n    # Save the state dict\n    # PyTorch paths: \"feature.weight\", \"feature.bias\", \"classifier.weight\", \"classifier.bias\"\n    # Burn paths:    \"feature.BaseConv.weight\", \"feature.BaseConv.bias\", ...\n    #                        ^^^^^^^^ enum variant is missing in PyTorch\n    torch.save(model.state_dict(), \"model_without_enum_variants.pt\")\n\n    print(\"Generated: model_without_enum_variants.pt\")\n    print(\"\\nPyTorch state dict keys:\")\n    for key in model.state_dict().keys():\n        shape = tuple(model.state_dict()[key].shape)\n        print(f\"  {key}: {shape}\")\n\n    print(\"\\nExpected Burn paths (with enum variant):\")\n    print(\"  feature.BaseConv.weight: (3, 64)\")\n    print(\"  feature.BaseConv.bias: (64,)\")\n    print(\"  classifier.weight: (64, 10)\")\n    print(\"  classifier.bias: (10,)\")\n\n    print(\"\\n⚠️  Notice: Burn includes 'BaseConv' enum variant, PyTorch doesn't!\")\n\n\nif __name__ == \"__main__\":\n    generate_enum_variant_mismatch_test()\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/mod.rs",
    "content": "//! SafeTensors format support for Burn deep learning framework.\n//!\n//! [SafeTensors](https://github.com/huggingface/safetensors) is a simple, safe, and efficient format\n//! for storing and loading tensors. It provides fast zero-copy deserialization and strong safety\n//! guarantees, making it ideal for production environments.\n//!\n//! # Features\n//!\n//! - **Fast Loading**: Zero-copy tensor access using safetensors' built-in mechanisms\n//! - **Safety**: Prevents arbitrary code execution during model loading\n//! - **Efficiency**: Memory-mapped files enable lazy loading without reading entire file\n//! - **Filtering**: Load only specific tensors using path filters\n//! - **Remapping**: Transform tensor names during load/save operations\n//! - **Metadata**: Store and retrieve custom metadata alongside tensors (automatic `format`, `producer` and `version` metadata included)\n//! - **Cross-Platform**: Works on all platforms including no-std environments\n//!\n//! # Usage Examples\n//!\n//! ## Basic Save and Load\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot};\n//!\n//! // Save a model to a file\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\");\n//! model.save_into(&mut store)?;\n//!\n//! // Load a model from a file\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\");\n//! let mut model = Model::new(&device);\n//! model.load_from(&mut store)?;\n//! ```\n//!\n//! ## Memory-Based Operations\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot};\n//!\n//! // Save to memory buffer\n//! let mut store = SafetensorsStore::from_bytes(None);\n//! model.save_into(&mut store)?;\n//! let bytes = store.get_bytes()?;\n//!\n//! // Load from memory buffer\n//! let mut store = SafetensorsStore::from_bytes(Some(bytes));\n//! let mut model = Model::new(&device);\n//! model.load_from(&mut store)?;\n//! ```\n//!\n//! ## Advanced Features\n//!\n//! ### Filter Configuration with Builder Pattern\n//!\n//! ```rust,no_run\n//! # use burn_store::SafetensorsStore;\n//! // Filter with regex patterns (OR logic - matches any pattern)\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_regex(r\"^encoder\\..*\")     // Match all encoder tensors\n//!     .with_regex(r\".*\\.bias$\");        // OR match any bias tensors\n//!\n//! // Filter with exact paths\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_full_path(\"encoder.weight\")\n//!     .with_full_path(\"encoder.bias\")\n//!     .with_full_paths(vec![\"decoder.scale\", \"decoder.norm\"]);\n//!\n//! // Custom filter logic with predicate\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_predicate(|path, _dtype| {\n//!         // Only save layer weights (not biases)\n//!         path.contains(\"layer\") && path.ends_with(\"weight\")\n//!     });\n//!\n//! // Combine multiple filter methods\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_regex(r\"^encoder\\..*\")           // All encoder tensors\n//!     .with_full_path(\"decoder.scale\")       // Plus specific decoder.scale\n//!     .with_predicate(|path, _| {            // Plus any projection tensors\n//!         path.contains(\"projection\")\n//!     });\n//!\n//! // Save or load all tensors (no filtering)\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .match_all();\n//! ```\n//!\n//! ### Tensor Name Remapping\n//!\n//! Remap tensor names during load/save operations for compatibility between different frameworks:\n//!\n//! ```rust,no_run\n//! # use burn_store::{SafetensorsStore, KeyRemapper};\n//! // Using builder pattern for common remapping patterns\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_key_remapping(r\"^encoder\\.\", \"transformer.encoder.\")  // encoder.X -> transformer.encoder.X\n//!     .with_key_remapping(r\"\\.gamma$\", \".weight\")                // X.gamma -> X.weight\n//!     .with_key_remapping(r\"\\.beta$\", \".bias\");                  // X.beta -> X.bias\n//!\n//! // Or using a pre-configured KeyRemapper for complex transformations\n//! let remapper = KeyRemapper::new()\n//!     .add_pattern(r\"^pytorch\\.(.*)\", \"burn.$1\").expect(\"valid regex\")           // pytorch.layer -> burn.layer\n//!     .add_pattern(r\"^(.*)\\.running_mean$\", \"$1.mean\").expect(\"valid regex\")     // layer.running_mean -> layer.mean\n//!     .add_pattern(r\"^(.*)\\.running_var$\", \"$1.variance\").expect(\"valid regex\"); // layer.running_var -> layer.variance\n//!\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .remap(remapper);\n//! ```\n//!\n//! ### Framework Adapters\n//!\n//! Use adapters for automatic framework-specific transformations:\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter, BurnToPyTorchAdapter};\n//!\n//! // Loading PyTorch model into Burn\n//! let mut store = SafetensorsStore::from_file(\"pytorch_model.safetensors\")\n//!     .with_from_adapter(PyTorchToBurnAdapter)  // Transposes linear weights, renames norm params\n//!     .allow_partial(true);                     // PyTorch models may have extra tensors\n//!\n//! let mut burn_model = Model::new(&device);\n//! burn_model.load_from(&mut store)?;\n//!\n//! // Saving Burn model for PyTorch\n//! let mut store = SafetensorsStore::from_file(\"for_pytorch.safetensors\")\n//!     .with_to_adapter(BurnToPyTorchAdapter);   // Transposes weights back, renames for PyTorch\n//!\n//! burn_model.save_into(&mut store)?;\n//! ```\n//!\n//! ### Additional Configuration Options\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot};\n//!\n//! let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     // Add custom metadata\n//!     .metadata(\"version\", \"1.0.0\")\n//!     .metadata(\"producer\", \"burn\")\n//!     // Allow partial loading (continue even if some tensors are missing)\n//!     .allow_partial(true)\n//!     // Disable validation for faster loading\n//!     .validate(false);\n//!\n//! // Use the configured store\n//! model.save_into(&mut store)?;  // For saving\n//! // or\n//! model.load_from(&mut store)?;   // For loading\n//! ```\n//!\n//! # Efficient Loading with SafeTensors\n//!\n//! SafeTensors provides efficient tensor loading through its zero-copy design:\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot};\n//!\n//! let mut store = SafetensorsStore::from_file(\"large_model.safetensors\");\n//! // Uses memory mapping (when available) for zero-copy access\n//! // Falls back to buffered reading when mmap is not available\n//! let mut model = Model::new(&device);\n//! model.load_from(&mut store)?;\n//! ```\n//!\n//! The safetensors approach provides:\n//! - Zero-copy views - tensors are accessed directly from the mapped file\n//! - Lazy loading - only accessed tensors are materialized\n//! - Efficient memory usage - no unnecessary data duplication\n//!\n//! # Lazy Loading and Inspection\n//!\n//! SafeTensors provides efficient inspection and selective loading through its\n//! zero-copy design and built-in metadata handling:\n//!\n//! ```rust,ignore\n//! use burn_store::SafetensorsStore;\n//!\n//! // Open a file - uses safetensors' efficient header reading\n//! let store = SafetensorsStore::from_file(\"large_model.safetensors\");\n//!\n//! // List all tensor names from the metadata\n//! let tensor_names = store.list_tensors()?;\n//! println!(\"Model contains {} tensors\", tensor_names.len());\n//!\n//! // Get tensor metadata without loading tensor data\n//! if let Some((shape, dtype)) = store.tensor_info(\"encoder.weight\")? {\n//!     println!(\"Encoder weight shape: {:?}, dtype: {:?}\", shape, dtype);\n//! }\n//!\n//! // Selectively load tensors - safetensors handles efficient access\n//! let encoder_tensors = store.load_tensors(&[\n//!     \"encoder.weight\",\n//!     \"encoder.bias\",\n//!     \"encoder.norm\"\n//! ])?;\n//!\n//! // Distributed loading: each worker loads only its assigned layers\n//! // SafeTensors' zero-copy views ensure minimal memory usage\n//! let worker_layers = match worker_id {\n//!     0 => vec![\"encoder.layer1\", \"encoder.layer2\"],\n//!     1 => vec![\"encoder.layer3\", \"encoder.layer4\"],\n//!     2 => vec![\"decoder.layer1\", \"decoder.layer2\"],\n//!     _ => vec![\"head.weight\", \"head.bias\"],\n//! };\n//! let worker_tensors = store.load_tensors(&worker_layers)?;\n//! ```\n//!\n//! # Builder Pattern API Reference\n//!\n//! The SafetensorsStore provides a fluent builder API for configuration:\n//!\n//! ## Filtering Methods\n//!\n//! - **`with_regex(pattern)`** - Add regex pattern to match tensor names (OR logic with multiple patterns)\n//! - **`with_full_path(path)`** - Add exact tensor path to include\n//! - **`with_full_paths(paths)`** - Add multiple exact tensor paths to include\n//! - **`with_predicate(fn)`** - Add custom filter function `fn(&str, &str) -> bool`\n//! - **`match_all()`** - Disable filtering, include all tensors\n//!\n//! ## Remapping Methods\n//!\n//! - **`with_key_remapping(from, to)`** - Add regex pattern to rename tensors\n//! - **`remap(KeyRemapper)`** - Use a pre-configured KeyRemapper for complex transformations\n//!\n//! ## Adapter Methods\n//!\n//! - **`with_from_adapter(adapter)`** - Set adapter for loading (e.g., PyTorchToBurnAdapter)\n//! - **`with_to_adapter(adapter)`** - Set adapter for saving (e.g., BurnToPyTorchAdapter)\n//!\n//! ## Configuration Methods\n//!\n//! - **`metadata(key, value)`** - Add custom metadata to saved files (in addition to automatic `format`, `producer` and `version`)\n//! - **`allow_partial(bool)`** - Allow loading even if some tensors are missing\n//! - **`validate(bool)`** - Enable/disable tensor validation during loading\n//!\n//! All methods return `Self` for chaining:\n//!\n//! ```rust,no_run\n//! use burn_store::{SafetensorsStore, PyTorchToBurnAdapter};\n//!\n//! let store = SafetensorsStore::from_file(\"model.safetensors\")\n//!     .with_regex(r\"^encoder\\..*\")\n//!     .with_key_remapping(r\"\\.gamma$\", \".weight\")\n//!     .with_from_adapter(PyTorchToBurnAdapter)\n//!     .allow_partial(true)\n//!     .metadata(\"version\", \"2.0\");\n//! ```\n//!\n//! # Working with Bytes\n//!\n//! For direct byte operations without files:\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot};\n//!\n//! // Save to bytes with filtering and remapping\n//! let mut store = SafetensorsStore::from_bytes(None)\n//!     .with_regex(r\"^encoder\\..*\")                       // Only save encoder tensors\n//!     .with_key_remapping(r\"^encoder\\.\", \"transformer.\")  // Rename encoder.X -> transformer.X\n//!     .metadata(\"subset\", \"encoder_only\");\n//! model.save_into(&mut store)?;\n//! let bytes = store.get_bytes()?;\n//!\n//! // Load from bytes (allow partial since we only saved encoder)\n//! let mut store = SafetensorsStore::from_bytes(Some(bytes))\n//!     .with_key_remapping(r\"^transformer\\.\", \"encoder.\")  // Rename back: transformer.X -> encoder.X\n//!     .allow_partial(true);\n//! let mut model = Model::new(&device);\n//! let result = model.load_from(&mut store)?;\n//! println!(\"Applied {} tensors\", result.applied.len());\n//! ```\n//!\n//! # Complete Example: PyTorch Model Migration\n//!\n//! Migrating a PyTorch model to Burn with filtering, remapping, and adapters:\n//!\n//! ```rust,ignore\n//! use burn_store::{SafetensorsStore, ModuleSnapshot, PyTorchToBurnAdapter};\n//!\n//! // Load PyTorch model with all transformations\n//! let mut store = SafetensorsStore::from_file(\"pytorch_model.safetensors\")\n//!     // Use PyTorch adapter for automatic transformations\n//!     .with_from_adapter(PyTorchToBurnAdapter)\n//!     // Only load transformer layers\n//!     .with_regex(r\"^transformer\\..*\")\n//!     // Rename old layer names to new structure\n//!     .with_key_remapping(r\"^transformer\\.h\\.(\\d+)\\.\", \"transformer.layer$1.\")\n//!     // Skip unexpected tensors from PyTorch\n//!     .allow_partial(true)\n//!     // Add metadata about the conversion\n//!     .metadata(\"source\", \"pytorch\")\n//!     .metadata(\"converted_by\", \"burn-store\");\n//!\n//! let mut model = TransformerModel::new(&device);\n//! let result = model.load_from(&mut store)?;\n//!\n//! println!(\"Successfully loaded {} tensors\", result.applied.len());\n//! if !result.missing.is_empty() {\n//!     println!(\"Missing tensors: {:?}\", result.missing);\n//! }\n//! ```\n//!\n//! # Format Details\n//!\n//! SafeTensors uses a simple binary format:\n//! - **8 bytes**: Header size (unsigned little-endian 64-bit integer)\n//! - **N bytes**: JSON header with tensor metadata\n//!   - Contains: `{\"tensor_name\": {\"dtype\": \"F32\", \"shape\": [1, 2, 3], \"data_offsets\": [start, end]}, ...}`\n//!   - Special key `__metadata__` for user-defined string metadata\n//! - **Rest**: Raw tensor data (referenced by offsets in header)\n//!\n//! The format enables:\n//! - **Secure loading**: No code execution, just data\n//! - **Efficient access**: Use offsets to read only needed tensors\n//! - **Simple parsing**: Standard JSON header with fixed structure\n\nmod store;\n\npub use store::{SafetensorsStore, SafetensorsStoreError};\n\n#[cfg(test)]\nmod tests;\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/store.rs",
    "content": "//! SafeTensors store implementation using the official safetensors crate.\n\nuse crate::{\n    ApplyResult, IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter,\n    TensorSnapshot,\n};\n\n#[cfg(feature = \"std\")]\nuse crate::{KeyRemapper, map_indices_contiguous};\nuse alloc::boxed::Box;\nuse alloc::collections::BTreeMap;\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_core::module::ParamId;\nuse burn_tensor::backend::Backend;\nuse burn_tensor::{BoolStore, DType, TensorData};\nuse core::fmt;\nuse core::ops::Deref;\nuse hashbrown::HashMap;\n\n// Arc is only available on targets with atomic pointers\n#[cfg(target_has_atomic = \"ptr\")]\nuse alloc::sync::Arc;\n\n// For targets without atomic pointers, we use Box instead\n#[cfg(not(target_has_atomic = \"ptr\"))]\ntype Arc<T> = Box<T>;\n\n/// Errors that can occur during SafeTensors operations.\n#[derive(Debug)]\npub enum SafetensorsStoreError {\n    /// SafeTensors crate error.\n    Safetensors(safetensors::SafeTensorError),\n\n    /// I/O error.\n    #[cfg(feature = \"std\")]\n    Io(std::io::Error),\n\n    /// Tensor not found.\n    TensorNotFound(String),\n\n    /// Validation failed.\n    ValidationFailed(String),\n\n    /// Other error.\n    Other(String),\n}\n\nimpl fmt::Display for SafetensorsStoreError {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        match self {\n            Self::Safetensors(e) => write!(f, \"SafeTensors error: {}\", e),\n            #[cfg(feature = \"std\")]\n            Self::Io(e) => write!(f, \"I/O error: {}\", e),\n            Self::TensorNotFound(name) => write!(f, \"Tensor not found: {}\", name),\n            Self::ValidationFailed(msg) => write!(f, \"Validation failed: {}\", msg),\n            Self::Other(msg) => write!(f, \"{}\", msg),\n        }\n    }\n}\n\nimpl core::error::Error for SafetensorsStoreError {}\n\nimpl From<safetensors::SafeTensorError> for SafetensorsStoreError {\n    fn from(e: safetensors::SafeTensorError) -> Self {\n        SafetensorsStoreError::Safetensors(e)\n    }\n}\n\n#[cfg(feature = \"std\")]\nimpl From<std::io::Error> for SafetensorsStoreError {\n    fn from(e: std::io::Error) -> Self {\n        SafetensorsStoreError::Io(e)\n    }\n}\n\n/// SafeTensors store supporting both file and memory storage.\npub enum SafetensorsStore {\n    /// File-based storage.\n    #[cfg(feature = \"std\")]\n    File(FileStore),\n\n    /// Memory-based storage.\n    Memory(MemoryStore),\n}\n\nimpl Default for SafetensorsStore {\n    /// Create a default memory-based store.\n    fn default() -> Self {\n        Self::from_bytes(None)\n    }\n}\n\nimpl SafetensorsStore {\n    /// Get the default metadata that includes Burn framework information.\n    ///\n    /// This includes:\n    /// - `format`: \"safetensors\"\n    /// - `producer`: \"burn\"\n    /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)\n    ///\n    /// These metadata fields are automatically added to all saved models.\n    pub fn default_metadata() -> HashMap<String, String> {\n        let mut metadata = HashMap::new();\n        metadata.insert(\"format\".to_string(), \"safetensors\".to_string());\n        metadata.insert(\"producer\".to_string(), \"burn\".to_string());\n        metadata.insert(\"version\".to_string(), env!(\"CARGO_PKG_VERSION\").to_string());\n        metadata\n    }\n\n    /// Create a store for loading from or saving to a file.\n    #[cfg(feature = \"std\")]\n    pub fn from_file(path: impl Into<std::path::PathBuf>) -> Self {\n        Self::File(FileStore {\n            path: path.into(),\n            filter: PathFilter::new(),\n            remapper: KeyRemapper::new(),\n            metadata: Self::default_metadata(),\n            validate: true,\n            allow_partial: false,\n            overwrite: false,\n            skip_enum_variants: false,\n            // Contiguous index mapping is off by default for SafeTensors\n            // (SafeTensors files typically have clean, contiguous indices)\n            map_indices_contiguous: false,\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            snapshots_cache: None,\n        })\n    }\n\n    /// Create a store for working with bytes in memory.\n    pub fn from_bytes(bytes: Option<Vec<u8>>) -> Self {\n        Self::Memory(MemoryStore {\n            data: bytes.map(Arc::new),\n            filter: PathFilter::new(),\n            #[cfg(feature = \"std\")]\n            remapper: KeyRemapper::new(),\n            metadata: Self::default_metadata(),\n            validate: true,\n            allow_partial: false,\n            skip_enum_variants: false,\n            // Contiguous index mapping is off by default for SafeTensors\n            #[cfg(feature = \"std\")]\n            map_indices_contiguous: false,\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            snapshots_cache: None,\n        })\n    }\n\n    /// Filter which tensors to load/save.\n    pub fn filter(mut self, filter: PathFilter) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = filter,\n            Self::Memory(p) => p.filter = filter,\n        }\n        self\n    }\n\n    /// Add a regex pattern to filter tensors.\n    ///\n    /// Multiple patterns can be added and they work with OR logic.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .with_regex(r\"^encoder\\..*\")  // Match all encoder tensors\n    ///     .with_regex(r\".*\\.weight$\");   // OR match any weight tensors\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_regex(pattern),\n            Self::Memory(p) => p.filter = p.filter.clone().with_regex(pattern),\n        }\n        self\n    }\n\n    /// Add multiple regex patterns to filter tensors.\n    #[cfg(feature = \"std\")]\n    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: AsRef<str>,\n    {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_regexes(patterns),\n            Self::Memory(p) => p.filter = p.filter.clone().with_regexes(patterns),\n        }\n        self\n    }\n\n    /// Add an exact full path to match.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .with_full_path(\"encoder.layer1.weight\")\n    ///     .with_full_path(\"decoder.output.bias\");\n    /// ```\n    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_full_path(path),\n            Self::Memory(p) => p.filter = p.filter.clone().with_full_path(path),\n        }\n        self\n    }\n\n    /// Add multiple exact full paths to match.\n    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self\n    where\n        I: IntoIterator<Item = S>,\n        S: Into<String>,\n    {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_full_paths(paths),\n            Self::Memory(p) => p.filter = p.filter.clone().with_full_paths(paths),\n        }\n        self\n    }\n\n    /// Add a predicate function for custom filtering logic.\n    ///\n    /// The predicate receives the tensor path and container path.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .with_predicate(|path, _| path.starts_with(\"encoder.\") || path.ends_with(\".bias\"));\n    /// ```\n    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_predicate(predicate),\n            Self::Memory(p) => p.filter = p.filter.clone().with_predicate(predicate),\n        }\n        self\n    }\n\n    /// Add multiple predicate functions.\n    pub fn with_predicates<I>(mut self, predicates: I) -> Self\n    where\n        I: IntoIterator<Item = fn(&str, &str) -> bool>,\n    {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().with_predicates(predicates),\n            Self::Memory(p) => p.filter = p.filter.clone().with_predicates(predicates),\n        }\n        self\n    }\n\n    /// Set the filter to match all paths (disables filtering).\n    pub fn match_all(mut self) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.filter = p.filter.clone().match_all(),\n            Self::Memory(p) => p.filter = p.filter.clone().match_all(),\n        }\n        self\n    }\n\n    /// Remap tensor names during load/save.\n    #[cfg(feature = \"std\")]\n    pub fn remap(mut self, remapper: KeyRemapper) -> Self {\n        match &mut self {\n            Self::File(p) => p.remapper = remapper,\n            Self::Memory(p) => p.remapper = remapper,\n        }\n        self\n    }\n\n    /// Add a regex pattern to remap tensor names during load/save.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .with_key_remapping(r\"^encoder\\.\", \"transformer.encoder.\")  // encoder.X -> transformer.encoder.X\n    ///     .with_key_remapping(r\"\\.gamma$\", \".weight\");               // X.gamma -> X.weight\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn with_key_remapping(\n        mut self,\n        from_pattern: impl AsRef<str>,\n        to_pattern: impl Into<String>,\n    ) -> Self {\n        match &mut self {\n            Self::File(p) => {\n                p.remapper = p\n                    .remapper\n                    .clone()\n                    .add_pattern(from_pattern, to_pattern)\n                    .expect(\"Invalid regex pattern\");\n            }\n            Self::Memory(p) => {\n                p.remapper = p\n                    .remapper\n                    .clone()\n                    .add_pattern(from_pattern, to_pattern)\n                    .expect(\"Invalid regex pattern\");\n            }\n        }\n        self\n    }\n\n    /// Add metadata to be saved with the tensors.\n    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {\n        let key = key.into();\n        let value = value.into();\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => {\n                p.metadata.insert(key, value);\n            }\n            Self::Memory(p) => {\n                p.metadata.insert(key, value);\n            }\n        }\n        self\n    }\n\n    /// Clear all metadata including the default Burn framework metadata.\n    ///\n    /// This removes the automatic `format`, `producer` and `version` fields.\n    /// Use this when you need complete control over metadata or when\n    /// saving models for use with other frameworks.\n    pub fn clear_metadata(mut self) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => {\n                p.metadata.clear();\n            }\n            Self::Memory(p) => {\n                p.metadata.clear();\n            }\n        }\n        self\n    }\n\n    /// Set whether to validate tensors during loading (default: true).\n    pub fn validate(mut self, validate: bool) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.validate = validate,\n            Self::Memory(p) => p.validate = validate,\n        }\n        self\n    }\n\n    /// Allow partial loading of tensors (continue even if some tensors are missing).\n    pub fn allow_partial(mut self, allow: bool) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.allow_partial = allow,\n            Self::Memory(p) => p.allow_partial = allow,\n        }\n        self\n    }\n\n    /// Skip enum variant names when loading or saving tensor paths.\n    ///\n    /// When enabled during **loading**, tensor paths from the source that don't include enum variants\n    /// can be matched against Burn module paths that do include them.\n    /// For example, source path \"feature.weight\" can match Burn path \"feature.BaseConv.weight\".\n    ///\n    /// When enabled during **saving**, enum variant names are omitted from the exported tensor paths,\n    /// making them compatible with PyTorch naming conventions.\n    /// For example, \"feature.BaseConv.weight\" becomes \"feature.weight\" in the exported file.\n    ///\n    /// This is useful when working with models from/to formats that don't include enum variant\n    /// names in their parameter paths (like PyTorch models).\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// // For PyTorch compatibility\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .skip_enum_variants(true);\n    /// ```\n    pub fn skip_enum_variants(mut self, skip: bool) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.skip_enum_variants = skip,\n            Self::Memory(p) => p.skip_enum_variants = skip,\n        }\n        self\n    }\n\n    /// Enable or disable automatic contiguous mapping of layer indices (default: false).\n    ///\n    /// When enabled, non-contiguous numeric indices in tensor paths are renumbered\n    /// to be contiguous. This is useful when loading models that have gaps\n    /// in layer numbering, such as PyTorch models using `nn.Sequential` with mixed\n    /// layer types (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).\n    ///\n    /// # Example\n    ///\n    /// With index mapping enabled:\n    /// - `fc.0.weight` → `fc.0.weight`\n    /// - `fc.2.weight` → `fc.1.weight` (gap filled)\n    /// - `fc.4.weight` → `fc.2.weight` (gap filled)\n    ///\n    /// # Arguments\n    ///\n    /// * `map` - `true` to enable contiguous index mapping, `false` to disable\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// // Enable contiguous index mapping for PyTorch-exported safetensors\n    /// let store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .map_indices_contiguous(true);\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn map_indices_contiguous(mut self, map: bool) -> Self {\n        match &mut self {\n            Self::File(p) => p.map_indices_contiguous = map,\n            Self::Memory(p) => p.map_indices_contiguous = map,\n        }\n        self\n    }\n\n    /// Set whether to overwrite existing files when saving (default: false).\n    ///\n    /// When set to `false`, attempting to save to an existing file will result in an error.\n    /// When set to `true`, existing files will be overwritten without warning.\n    ///\n    /// This setting only applies to file-based stores.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// let mut store = SafetensorsStore::from_file(\"model.safetensors\")\n    ///     .overwrite(true);\n    /// // Will overwrite if file exists when saving\n    /// ```\n    #[cfg(feature = \"std\")]\n    pub fn overwrite(mut self, overwrite: bool) -> Self {\n        match &mut self {\n            Self::File(p) => p.overwrite = overwrite,\n            Self::Memory(_) => {\n                // Memory stores don't have overwrite semantics, ignore\n            }\n        }\n        self\n    }\n\n    /// Set the adapter for loading tensors (converting from source format to Burn).\n    pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.from_adapter = Box::new(adapter),\n            Self::Memory(p) => p.from_adapter = Box::new(adapter),\n        }\n        self\n    }\n\n    /// Set the adapter for saving tensors (converting from Burn to target format).\n    pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {\n        match &mut self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.to_adapter = Box::new(adapter),\n            Self::Memory(p) => p.to_adapter = Box::new(adapter),\n        }\n        self\n    }\n\n    /// Get saved bytes from memory-based store.\n    ///\n    /// # Example\n    /// ```rust,no_run\n    /// # use burn_store::SafetensorsStore;\n    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {\n    /// let mut store = SafetensorsStore::from_bytes(None);\n    /// // After saving model with collect_to()...\n    /// let bytes = store.get_bytes()?;\n    /// # Ok(())\n    /// # }\n    /// ```\n    pub fn get_bytes(&self) -> Result<Vec<u8>, SafetensorsStoreError> {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(_) => Err(SafetensorsStoreError::Other(\n                \"Cannot get bytes from file-based store\".to_string(),\n            )),\n            Self::Memory(p) => p\n                .data()\n                .map(|arc| arc.as_ref().clone())\n                .ok_or_else(|| SafetensorsStoreError::Other(\"No data available\".to_string())),\n        }\n    }\n}\n\n/// File-based store.\n#[cfg(feature = \"std\")]\npub struct FileStore {\n    path: std::path::PathBuf,\n    filter: PathFilter,\n    remapper: KeyRemapper,\n    metadata: HashMap<String, String>,\n    validate: bool,\n    allow_partial: bool,\n    overwrite: bool,\n    skip_enum_variants: bool,\n    /// Enable contiguous mapping of layer indices (default: false)\n    map_indices_contiguous: bool,\n    from_adapter: Box<dyn ModuleAdapter>,\n    to_adapter: Box<dyn ModuleAdapter>,\n    /// Cached tensor snapshots (parsed once, reused)\n    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,\n}\n\n/// Memory-based store.\npub struct MemoryStore {\n    data: Option<Arc<Vec<u8>>>,\n    filter: PathFilter,\n    #[cfg(feature = \"std\")]\n    remapper: KeyRemapper,\n    metadata: HashMap<String, String>,\n    validate: bool,\n    allow_partial: bool,\n    skip_enum_variants: bool,\n    /// Enable contiguous mapping of layer indices (default: false)\n    #[cfg(feature = \"std\")]\n    map_indices_contiguous: bool,\n    from_adapter: Box<dyn ModuleAdapter>,\n    to_adapter: Box<dyn ModuleAdapter>,\n    /// Cached tensor snapshots (parsed once, reused)\n    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,\n}\n\nimpl Default for MemoryStore {\n    fn default() -> Self {\n        Self {\n            data: None,\n            filter: PathFilter::new(),\n            #[cfg(feature = \"std\")]\n            remapper: KeyRemapper::new(),\n            metadata: HashMap::new(),\n            validate: true,\n            allow_partial: false,\n            skip_enum_variants: false,\n            #[cfg(feature = \"std\")]\n            map_indices_contiguous: false,\n            from_adapter: Box::new(IdentityAdapter),\n            to_adapter: Box::new(IdentityAdapter),\n            snapshots_cache: None,\n        }\n    }\n}\n\nimpl MemoryStore {\n    #[cfg(test)]\n    pub(crate) fn data(&self) -> Option<Arc<Vec<u8>>> {\n        self.data.clone()\n    }\n\n    #[cfg(not(test))]\n    fn data(&self) -> Option<Arc<Vec<u8>>> {\n        self.data.clone()\n    }\n\n    #[cfg(test)]\n    pub(crate) fn set_data(&mut self, data: Vec<u8>) {\n        self.data = Some(Arc::new(data));\n    }\n}\n\n// Adapter to use TensorSnapshot directly with safetensors\n#[derive(Debug)]\nstruct TensorSnapshotAdapter(TensorSnapshot);\n\nimpl safetensors::View for TensorSnapshotAdapter {\n    fn dtype(&self) -> safetensors::Dtype {\n        // Convert from burn dtype to safetensors dtype\n        dtype_to_safetensors(self.0.dtype).unwrap_or(safetensors::Dtype::F32)\n    }\n\n    fn shape(&self) -> &[usize] {\n        &self.0.shape\n    }\n\n    fn data(&self) -> alloc::borrow::Cow<'_, [u8]> {\n        // Only materialize data when actually needed for serialization\n        let data = self\n            .0\n            .to_data()\n            .unwrap_or_else(|e| panic!(\"Failed to get tensor data: {:?}\", e));\n        alloc::borrow::Cow::Owned(data.bytes.deref().to_vec())\n    }\n\n    fn data_len(&self) -> usize {\n        // Use the efficient data_len method from TensorSnapshot\n        self.0.data_len()\n    }\n}\n\nimpl ModuleStore for SafetensorsStore {\n    type Error = SafetensorsStoreError;\n\n    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &M,\n    ) -> Result<(), Self::Error> {\n        // Invalidate cache since we're writing new data\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.snapshots_cache = None,\n            Self::Memory(p) => p.snapshots_cache = None,\n        }\n\n        // Collect tensor snapshots from module with adapter\n        // The to_adapter converts from Burn format to target format for saving\n        let to_adapter = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.to_adapter.clone(),\n            Self::Memory(p) => p.to_adapter.clone(),\n        };\n        let mut snapshots = module.collect(None, Some(to_adapter), self.get_skip_enum_variants());\n\n        // Apply filtering\n        snapshots = apply_filter(snapshots, self.get_filter());\n\n        // Apply remapping\n        #[cfg(feature = \"std\")]\n        {\n            snapshots = apply_remapping(snapshots, self.get_remapper());\n        }\n\n        // Get metadata (already includes format, producer and version from default_metadata)\n        let metadata = self.get_metadata().clone();\n\n        #[cfg(feature = \"std\")]\n        let std_metadata: std::collections::HashMap<String, String> = metadata\n            .iter()\n            .map(|(k, v)| (k.clone(), v.clone()))\n            .collect();\n\n        // Write to storage\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => {\n                // Check if file exists and overwrite is disabled\n                if p.path.exists() && !p.overwrite {\n                    return Err(SafetensorsStoreError::Other(format!(\n                        \"File already exists: {}. Use .overwrite(true) to overwrite.\",\n                        p.path.display()\n                    )));\n                }\n\n                // Convert to safetensors format\n                let tensors = snapshots_to_safetensors(snapshots)?;\n\n                // Use serialize_to_file which streams directly to disk\n                // This calls the lazy closures on-demand without buffering everything\n                safetensors::serialize_to_file(tensors, Some(std_metadata), &p.path)?;\n                Ok(())\n            }\n            Self::Memory(p) => {\n                // For memory, we need to serialize to bytes\n                let tensors = snapshots_to_safetensors(snapshots)?;\n                // For no-std, serialize still needs std HashMap when std feature is enabled\n                #[cfg(feature = \"std\")]\n                let data = safetensors::serialize(tensors, Some(std_metadata))?;\n\n                #[cfg(not(feature = \"std\"))]\n                let data = safetensors::serialize(tensors, Some(metadata))?;\n                p.data = Some(Arc::new(data));\n                Ok(())\n            }\n        }\n    }\n\n    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &mut M,\n    ) -> Result<ApplyResult, Self::Error> {\n        // Get snapshots from cache\n        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();\n\n        // Get the adapter\n        let adapter: Box<dyn ModuleAdapter> = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.from_adapter.clone(),\n            Self::Memory(p) => p.from_adapter.clone(),\n        };\n\n        // Get filter (cloned to Option for apply)\n        let filter = self.get_filter();\n        let filter_opt = if filter.is_empty() {\n            None\n        } else {\n            Some(filter.clone())\n        };\n\n        // Apply to module with adapter\n        // The adapter will be applied during module traversal with proper container info\n        // Filter is applied here during apply, not during cache population\n        let result = module.apply(\n            snapshots,\n            filter_opt,\n            Some(adapter),\n            self.get_skip_enum_variants(),\n        );\n\n        // Validate if needed\n        if self.get_validate() && !result.errors.is_empty() {\n            return Err(SafetensorsStoreError::ValidationFailed(format!(\n                \"Import errors: {:?}\",\n                result.errors\n            )));\n        }\n\n        if !self.get_allow_partial() && !result.missing.is_empty() {\n            return Err(SafetensorsStoreError::TensorNotFound(format!(\n                \"\\n{}\",\n                result\n            )));\n        }\n\n        Ok(result)\n    }\n\n    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {\n        // Ensure cache is populated\n        self.ensure_snapshots_cache()?;\n        let cache = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.snapshots_cache.as_ref().unwrap(),\n            Self::Memory(p) => p.snapshots_cache.as_ref().unwrap(),\n        };\n        Ok(cache.get(name))\n    }\n\n    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {\n        // Ensure cache is populated\n        self.ensure_snapshots_cache()?;\n        let cache = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.snapshots_cache.as_ref().unwrap(),\n            Self::Memory(p) => p.snapshots_cache.as_ref().unwrap(),\n        };\n        Ok(cache)\n    }\n\n    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {\n        // Always use the cache to ensure remapping is applied consistently\n        Ok(self.get_all_snapshots()?.keys().cloned().collect())\n    }\n}\n\nimpl SafetensorsStore {\n    fn get_filter(&self) -> &PathFilter {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => &p.filter,\n            Self::Memory(p) => &p.filter,\n        }\n    }\n\n    #[cfg(feature = \"std\")]\n    fn get_remapper(&self) -> &KeyRemapper {\n        match self {\n            Self::File(p) => &p.remapper,\n            Self::Memory(p) => &p.remapper,\n        }\n    }\n\n    fn get_metadata(&self) -> &HashMap<String, String> {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => &p.metadata,\n            Self::Memory(p) => &p.metadata,\n        }\n    }\n\n    fn get_validate(&self) -> bool {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.validate,\n            Self::Memory(p) => p.validate,\n        }\n    }\n\n    fn get_allow_partial(&self) -> bool {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.allow_partial,\n            Self::Memory(p) => p.allow_partial,\n        }\n    }\n\n    fn get_skip_enum_variants(&self) -> bool {\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.skip_enum_variants,\n            Self::Memory(p) => p.skip_enum_variants,\n        }\n    }\n\n    #[cfg(feature = \"std\")]\n    fn get_map_indices_contiguous(&self) -> bool {\n        match self {\n            Self::File(p) => p.map_indices_contiguous,\n            Self::Memory(p) => p.map_indices_contiguous,\n        }\n    }\n\n    /// Ensure the snapshots cache is populated\n    fn ensure_snapshots_cache(&mut self) -> Result<(), SafetensorsStoreError> {\n        // Check if cache exists\n        let has_cache = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.snapshots_cache.is_some(),\n            Self::Memory(p) => p.snapshots_cache.is_some(),\n        };\n\n        if has_cache {\n            return Ok(());\n        }\n\n        // Load snapshots\n        #[allow(unused_mut)]\n        let mut snapshots = match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => safetensors_to_snapshots_lazy_file(&p.path)?,\n            Self::Memory(p) => {\n                let data_arc = p\n                    .data\n                    .clone()\n                    .ok_or_else(|| SafetensorsStoreError::Other(\"No data loaded\".to_string()))?;\n                safetensors_to_snapshots_lazy(data_arc)?\n            }\n        };\n\n        // Apply remapping (but NOT filtering - that's done at apply time)\n        #[cfg(feature = \"std\")]\n        {\n            snapshots = match self {\n                Self::File(p) => apply_remapping(snapshots, &p.remapper),\n                Self::Memory(p) => apply_remapping(snapshots, &p.remapper),\n            };\n        }\n\n        // Apply contiguous index mapping if enabled\n        // This must be done after remapping so that remapped paths are mapped\n        #[cfg(feature = \"std\")]\n        if self.get_map_indices_contiguous() {\n            let (mapped, _) = map_indices_contiguous(snapshots);\n            snapshots = mapped;\n        }\n\n        // Build cache as BTreeMap\n        let cache: BTreeMap<String, TensorSnapshot> =\n            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();\n\n        // Store cache\n        match self {\n            #[cfg(feature = \"std\")]\n            Self::File(p) => p.snapshots_cache = Some(cache),\n            Self::Memory(p) => p.snapshots_cache = Some(cache),\n        }\n\n        Ok(())\n    }\n}\n\n/// Apply filter to tensor snapshots.\nfn apply_filter(mut snapshots: Vec<TensorSnapshot>, filter: &PathFilter) -> Vec<TensorSnapshot> {\n    if filter.is_empty() {\n        return snapshots;\n    }\n\n    snapshots.retain(|snapshot| {\n        let path = snapshot.full_path();\n        filter.matches(&path)\n    });\n\n    snapshots\n}\n\n/// Apply remapping to tensor snapshots.\n#[cfg(feature = \"std\")]\nfn apply_remapping(snapshots: Vec<TensorSnapshot>, remapper: &KeyRemapper) -> Vec<TensorSnapshot> {\n    if remapper.is_empty() {\n        return snapshots;\n    }\n\n    let (remapped, _) = remapper.remap(snapshots);\n    remapped\n}\n\n/// Convert TensorSnapshots to safetensors format lazily.\nfn snapshots_to_safetensors(\n    snapshots: Vec<TensorSnapshot>,\n) -> Result<Vec<(String, TensorSnapshotAdapter)>, SafetensorsStoreError> {\n    let mut tensors = Vec::new();\n\n    for snapshot in snapshots {\n        let name = snapshot.full_path();\n        // No need to materialize data - TensorSnapshot now has dtype and shape cached!\n        tensors.push((name, TensorSnapshotAdapter(snapshot)));\n    }\n\n    Ok(tensors)\n}\n\n/// Convert safetensors to TensorSnapshots with lazy loading.\nfn safetensors_to_snapshots_lazy(\n    data_arc: Arc<Vec<u8>>,\n) -> Result<Vec<TensorSnapshot>, SafetensorsStoreError> {\n    // Parse to get metadata\n    let tensors = safetensors::SafeTensors::deserialize(&data_arc)?;\n    let mut snapshots = Vec::new();\n\n    for (name, tensor_snapshot) in tensors.tensors() {\n        // Extract metadata without materializing data\n        let dtype = safetensor_dtype_to_burn(tensor_snapshot.dtype())?;\n        let shape = tensor_snapshot.shape();\n        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();\n\n        // Create a lazy closure that will deserialize only this tensor when needed\n        #[cfg(target_has_atomic = \"ptr\")]\n        let data_clone = Arc::clone(&data_arc);\n        #[cfg(not(target_has_atomic = \"ptr\"))]\n        let data_clone = data_arc.clone();\n        let name_clone = name.to_string();\n        let data_fn = alloc::rc::Rc::new(move || {\n            // Re-deserialize when needed (this is cheap, just parsing header)\n            let tensors = safetensors::SafeTensors::deserialize(&data_clone).map_err(|e| {\n                crate::TensorSnapshotError::IoError(format!(\n                    \"Failed to re-deserialize safetensors: {}\",\n                    e\n                ))\n            })?;\n\n            // Find our specific tensor\n            let tensor = tensors.tensor(&name_clone).map_err(|e| {\n                crate::TensorSnapshotError::DataError(format!(\n                    \"Tensor '{}' not found: {}\",\n                    name_clone, e\n                ))\n            })?;\n\n            // Now materialize just this tensor's data\n            let bytes = burn_tensor::Bytes::from_bytes_vec(tensor.data().to_vec());\n            Ok(TensorData {\n                bytes,\n                shape: tensor.shape().into(),\n                dtype: safetensor_dtype_to_burn(tensor.dtype())\n                    .map_err(|_| crate::TensorSnapshotError::DataError(\"Invalid dtype\".into()))?,\n            })\n        });\n\n        let snapshot = TensorSnapshot::from_closure(\n            data_fn,\n            dtype,\n            shape.into(),\n            path_parts,\n            vec![], // Empty container_stack - will be filled during module traversal\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    Ok(snapshots)\n}\n\n/// Convert safetensors to TensorSnapshots with true on-demand loading from file.\n/// This reads only the header initially, then loads tensor data on demand.\n#[cfg(feature = \"std\")]\nfn safetensors_to_snapshots_lazy_file(\n    path: &std::path::Path,\n) -> Result<Vec<TensorSnapshot>, SafetensorsStoreError> {\n    // Always use memory mapping for the most efficient access\n    use memmap2::MmapOptions;\n\n    // Memory map the file for efficient access\n    let file = std::fs::File::open(path)?;\n    let mmap = unsafe { MmapOptions::new().map(&file)? };\n    let mmap_arc = Arc::new(mmap);\n\n    // Parse just to get metadata (safetensors won't copy data with mmap)\n    let tensors = safetensors::SafeTensors::deserialize(&mmap_arc)?;\n    let mut snapshots = Vec::new();\n\n    for (name, tensor_snapshot) in tensors.tensors() {\n        let dtype = safetensor_dtype_to_burn(tensor_snapshot.dtype())?;\n        let shape = tensor_snapshot.shape();\n        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();\n\n        // Create a lazy closure that accesses the mmap'd data\n        let mmap_clone = Arc::clone(&mmap_arc);\n        let name_clone = name.to_string();\n\n        let data_fn = alloc::rc::Rc::new(move || {\n            // Re-parse to get the tensor snapshot (this is cheap with mmap)\n            let tensors = safetensors::SafeTensors::deserialize(&mmap_clone).map_err(|e| {\n                crate::TensorSnapshotError::IoError(format!(\"Failed to deserialize: {}\", e))\n            })?;\n            let tensor = tensors.tensor(&name_clone).map_err(|e| {\n                crate::TensorSnapshotError::DataError(format!(\n                    \"Tensor '{}' not found: {}\",\n                    name_clone, e\n                ))\n            })?;\n\n            // Only now do we actually copy the tensor data\n            Ok(TensorData {\n                bytes: burn_tensor::Bytes::from_bytes_vec(tensor.data().to_vec()),\n                shape: tensor.shape().into(),\n                dtype: safetensor_dtype_to_burn(tensor.dtype())\n                    .map_err(|_| crate::TensorSnapshotError::DataError(\"Invalid dtype\".into()))?,\n            })\n        });\n\n        let snapshot = TensorSnapshot::from_closure(\n            data_fn,\n            dtype,\n            shape.into(),\n            path_parts,\n            vec![], // Empty container_stack - will be filled during module traversal\n            ParamId::new(),\n        );\n        snapshots.push(snapshot);\n    }\n\n    Ok(snapshots)\n}\n\n/// Helper to convert safetensors Dtype to burn DType.\nfn safetensor_dtype_to_burn(dtype: safetensors::Dtype) -> Result<DType, SafetensorsStoreError> {\n    use safetensors::Dtype;\n\n    match dtype {\n        Dtype::F64 => Ok(DType::F64),\n        Dtype::F32 => Ok(DType::F32),\n        Dtype::F16 => Ok(DType::F16),\n        Dtype::BF16 => Ok(DType::BF16),\n        Dtype::I64 => Ok(DType::I64),\n        Dtype::I32 => Ok(DType::I32),\n        Dtype::I16 => Ok(DType::I16),\n        Dtype::I8 => Ok(DType::I8),\n        Dtype::U64 => Ok(DType::U64),\n        Dtype::U32 => Ok(DType::U32),\n        Dtype::U8 => Ok(DType::U8),\n        Dtype::BOOL => Ok(DType::Bool(BoolStore::Native)),\n        _ => Err(SafetensorsStoreError::Other(format!(\n            \"Unsupported dtype: {:?}\",\n            dtype\n        ))),\n    }\n}\n\n/// Helper to convert DType to safetensors Dtype.\nfn dtype_to_safetensors(dtype: DType) -> Result<safetensors::Dtype, SafetensorsStoreError> {\n    use safetensors::Dtype;\n\n    match dtype {\n        DType::F64 => Ok(Dtype::F64),\n        DType::F32 | DType::Flex32 => Ok(Dtype::F32), // Flex32 is stored as F32\n        DType::F16 => Ok(Dtype::F16),\n        DType::BF16 => Ok(Dtype::BF16),\n        DType::I64 => Ok(Dtype::I64),\n        DType::I32 => Ok(Dtype::I32),\n        DType::I16 => Ok(Dtype::I16),\n        DType::I8 => Ok(Dtype::I8),\n        DType::U64 => Ok(Dtype::U64),\n        DType::U32 => Ok(Dtype::U32),\n        DType::U16 => Err(SafetensorsStoreError::Other(\n            \"U16 dtype not yet supported in safetensors\".to_string(),\n        )),\n        DType::U8 => Ok(Dtype::U8),\n        DType::Bool(BoolStore::Native) => Ok(Dtype::BOOL),\n        DType::Bool(BoolStore::U32) => Ok(Dtype::U32),\n        DType::Bool(BoolStore::U8) => Ok(Dtype::U8),\n        DType::QFloat(_) => Err(SafetensorsStoreError::Other(\n            \"Quantized tensors not yet supported in safetensors\".to_string(),\n        )),\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/adapter.rs",
    "content": "use burn_core as burn;\n\nuse crate::{\n    BurnToPyTorchAdapter, ModuleSnapshot, ModuleStore, PyTorchToBurnAdapter, SafetensorsStore,\n};\nuse burn_core::module::{Module, Param};\nuse burn_nn::{Linear, LinearConfig};\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[derive(Module, Debug)]\nstruct TestModel<B: Backend> {\n    linear: Linear<B>,\n    norm_weight: Param<Tensor<B, 1>>,\n    norm_bias: Param<Tensor<B, 1>>,\n}\n\nimpl<B: Backend> TestModel<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            linear: LinearConfig::new(4, 2).with_bias(true).init(device),\n            norm_weight: Param::from_data([1.0, 1.0], device),\n            norm_bias: Param::from_data([0.0, 0.0], device),\n        }\n    }\n}\n\n#[test]\nfn pytorch_to_burn_adapter_linear_transpose() {\n    let device = Default::default();\n    let model = TestModel::<TestBackend>::new(&device);\n\n    // Save with BurnToPyTorch adapter (will transpose linear weights)\n    let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter);\n    model.save_into(&mut save_store).unwrap();\n\n    // Load with PyTorchToBurn adapter (will transpose back)\n    let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let mut model2 = TestModel::<TestBackend>::new(&device);\n    let result = model2.load_from(&mut load_store).unwrap();\n\n    // Should successfully load all tensors\n    assert!(!result.applied.is_empty());\n\n    // Verify the linear weights are the same after round-trip\n    let weight1 = model.linear.weight.val().to_data();\n    let weight2 = model2.linear.weight.val().to_data();\n\n    assert_eq!(weight1.shape, weight2.shape);\n    let data1 = weight1.to_vec::<f32>().unwrap();\n    let data2 = weight2.to_vec::<f32>().unwrap();\n\n    for (a, b) in data1.iter().zip(data2.iter()) {\n        assert!(\n            (a - b).abs() < 1e-6,\n            \"Weights differ after adapter round-trip\"\n        );\n    }\n}\n\n#[test]\nfn pytorch_to_burn_adapter_norm_rename() {\n    let device = Default::default();\n\n    // Create a model with norm-like naming\n    #[derive(Module, Debug)]\n    struct NormModel<B: Backend> {\n        norm_gamma: Param<Tensor<B, 1>>,\n        norm_beta: Param<Tensor<B, 1>>,\n    }\n\n    impl<B: Backend> NormModel<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                norm_gamma: Param::from_data([1.0, 2.0, 3.0], device),\n                norm_beta: Param::from_data([0.1, 0.2, 0.3], device),\n            }\n        }\n    }\n\n    let model = NormModel::<TestBackend>::new(&device);\n\n    // Save with BurnToPyTorch adapter (will rename gamma->weight, beta->bias)\n    let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(BurnToPyTorchAdapter);\n    model.save_into(&mut save_store).unwrap();\n\n    // The saved data should have PyTorch naming convention\n    // We can't directly verify the internal names, but we can verify round-trip works\n\n    // Load with PyTorchToBurn adapter (will rename weight->gamma, bias->beta)\n    let mut load_store = SafetensorsStore::from_bytes(None).with_from_adapter(PyTorchToBurnAdapter);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let mut model2 = NormModel::<TestBackend>::new(&device);\n    let result = model2.load_from(&mut load_store).unwrap();\n\n    // Should load successfully\n    assert!(!result.applied.is_empty());\n\n    // Verify data is preserved\n    let gamma1 = model.norm_gamma.val().to_data().to_vec::<f32>().unwrap();\n    let gamma2 = model2.norm_gamma.val().to_data().to_vec::<f32>().unwrap();\n    let beta1 = model.norm_beta.val().to_data().to_vec::<f32>().unwrap();\n    let beta2 = model2.norm_beta.val().to_data().to_vec::<f32>().unwrap();\n\n    assert_eq!(gamma1, gamma2);\n    assert_eq!(beta1, beta2);\n}\n\n#[test]\nfn no_adapter_preserves_original() {\n    let device = Default::default();\n    let model = TestModel::<TestBackend>::new(&device);\n\n    // Save without adapter\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    model.save_into(&mut save_store).unwrap();\n\n    // Load without adapter\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let mut model2 = TestModel::<TestBackend>::new(&device);\n    let result = model2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n    assert!(!result.applied.is_empty());\n\n    // Verify data is exactly the same\n    let weight1 = model.linear.weight.val().to_data();\n    let weight2 = model2.linear.weight.val().to_data();\n\n    assert_eq!(weight1.shape, weight2.shape);\n    assert_eq!(\n        weight1.to_vec::<f32>().unwrap(),\n        weight2.to_vec::<f32>().unwrap()\n    );\n}\n\n#[test]\n#[cfg(all(feature = \"std\", target_has_atomic = \"ptr\"))]\nfn adapter_with_pytorch_import() {\n    use crate::PyTorchToBurnAdapter;\n\n    let device = Default::default();\n\n    // Reference the safetensors file from burn-store\n    let safetensors_path = concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    );\n\n    // Simple test model that matches some of the PyTorch structure\n    #[derive(Module, Debug)]\n    struct SimpleNet<B: Backend> {\n        fc1: Linear<B>,\n    }\n\n    impl<B: Backend> SimpleNet<B> {\n        fn new(device: &B::Device) -> Self {\n            Self {\n                fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),\n            }\n        }\n    }\n\n    // Load with PyTorchToBurn adapter\n    let mut store = SafetensorsStore::from_file(safetensors_path)\n        .with_from_adapter(PyTorchToBurnAdapter)\n        .validate(false)\n        .allow_partial(true);\n\n    let mut model = SimpleNet::<TestBackend>::new(&device);\n    let result = model.load_from(&mut store).unwrap();\n\n    // Should load some tensors (fc1 if it exists in the file)\n    // This mainly tests that the adapter works with real PyTorch files\n    assert!(!result.applied.is_empty() || !result.missing.is_empty());\n}\n\n#[test]\nfn half_precision_adapter_round_trip() {\n    use crate::HalfPrecisionAdapter;\n    use burn_tensor::DType;\n\n    let device = Default::default();\n    let model = TestModel::<TestBackend>::new(&device);\n\n    // Save with HalfPrecisionAdapter (F32 -> F16)\n    let adapter = HalfPrecisionAdapter::new();\n    let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter.clone());\n    model.save_into(&mut save_store).unwrap();\n\n    // Verify Linear tensors are F16, raw params stay F32 (no recognized module type)\n    let save_bytes = match &save_store {\n        SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(),\n        _ => panic!(\"Expected memory store\"),\n    };\n    let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes.clone()));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (name, snapshot) in snapshots.iter() {\n        if name.starts_with(\"linear\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F16,\n                \"Linear tensor '{}' should be F16\",\n                name\n            );\n        } else {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F32,\n                \"Raw param '{}' should stay F32\",\n                name\n            );\n        }\n    }\n\n    // Load back with same adapter (F16 -> F32)\n    let mut load_store = SafetensorsStore::from_bytes(Some(save_bytes)).with_from_adapter(adapter);\n\n    let mut model2 = TestModel::<TestBackend>::new(&device);\n    let result = model2.load_from(&mut load_store).unwrap();\n\n    assert!(!result.applied.is_empty());\n\n    // Verify values are close (F32 -> F16 -> F32 has rounding)\n    let w1 = model.linear.weight.val().to_data().to_vec::<f32>().unwrap();\n    let w2 = model2\n        .linear\n        .weight\n        .val()\n        .to_data()\n        .to_vec::<f32>()\n        .unwrap();\n    for (a, b) in w1.iter().zip(w2.iter()) {\n        assert!(\n            (a - b).abs() < 0.01,\n            \"Weight values differ too much after F16 round-trip: {} vs {}\",\n            a,\n            b\n        );\n    }\n}\n\n#[test]\nfn half_precision_adapter_without_module() {\n    use crate::HalfPrecisionAdapter;\n    use burn_nn::{LayerNorm, LayerNormConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct MixedModel<B: Backend> {\n        linear: Linear<B>,\n        norm: LayerNorm<B>,\n    }\n\n    let device = Default::default();\n    let model = MixedModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n        norm: LayerNormConfig::new(2).init(&device),\n    };\n\n    // Save: exclude LayerNorm from half-precision conversion\n    let adapter = HalfPrecisionAdapter::new().without_module(\"LayerNorm\");\n    let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter);\n    model.save_into(&mut save_store).unwrap();\n\n    // Verify: Linear tensors are F16, LayerNorm tensors remain F32\n    let save_bytes = match &save_store {\n        SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(),\n        _ => panic!(\"Expected memory store\"),\n    };\n    let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (name, snapshot) in snapshots {\n        if name.starts_with(\"linear\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F16,\n                \"Linear tensor '{}' should be F16\",\n                name\n            );\n        } else if name.starts_with(\"norm\") {\n            assert_eq!(\n                snapshot.dtype,\n                DType::F32,\n                \"LayerNorm tensor '{}' should stay F32\",\n                name\n            );\n        }\n    }\n}\n\n#[test]\nfn half_precision_adapter_default_converts_layer_norm() {\n    use crate::HalfPrecisionAdapter;\n    use burn_nn::{LayerNorm, LayerNormConfig};\n    use burn_tensor::DType;\n\n    #[derive(Module, Debug)]\n    struct NormModel<B: Backend> {\n        linear: Linear<B>,\n        norm: LayerNorm<B>,\n    }\n\n    let device = Default::default();\n    let model = NormModel::<TestBackend> {\n        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),\n        norm: LayerNormConfig::new(2).init(&device),\n    };\n\n    // Default adapter converts LayerNorm\n    let adapter = HalfPrecisionAdapter::new();\n    let mut save_store = SafetensorsStore::from_bytes(None).with_to_adapter(adapter);\n    model.save_into(&mut save_store).unwrap();\n\n    let save_bytes = match &save_store {\n        SafetensorsStore::Memory(p) => p.data().unwrap().as_ref().clone(),\n        _ => panic!(\"Expected memory store\"),\n    };\n    let mut inspect_store = SafetensorsStore::from_bytes(Some(save_bytes));\n    let snapshots = inspect_store.get_all_snapshots().unwrap();\n    for (name, snapshot) in snapshots {\n        assert_eq!(\n            snapshot.dtype,\n            DType::F16,\n            \"All tensors should be F16 by default, but '{}' is {:?}\",\n            name,\n            snapshot.dtype\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/direct_access.rs",
    "content": "use burn_core as burn;\n\nuse crate::{ModuleStore, SafetensorsStore};\nuse burn_core::module::{Module, Param};\nuse burn_tensor::backend::Backend;\nuse burn_tensor::{Tensor, shape};\n\ntype TestBackend = burn_ndarray::NdArray;\n\n// Test module for direct access tests\n#[derive(Module, Debug)]\nstruct DirectAccessTestModule<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    bias: Param<Tensor<B, 1>>,\n    nested: DirectAccessNestedModule<B>,\n}\n\n#[derive(Module, Debug)]\nstruct DirectAccessNestedModule<B: Backend> {\n    gamma: Param<Tensor<B, 1>>,\n    beta: Param<Tensor<B, 1>>,\n}\n\nimpl<B: Backend> DirectAccessTestModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),\n            bias: Param::from_data([0.1, 0.2], device),\n            nested: DirectAccessNestedModule {\n                gamma: Param::from_data([1.0, 2.0], device),\n                beta: Param::from_data([0.5, 0.5], device),\n            },\n        }\n    }\n}\n\n#[test]\nfn test_memory_get_all_snapshots() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    // Save module to memory\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n\n    // Get bytes and create load store\n    let bytes = save_store.get_bytes().unwrap();\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    // Get all snapshots\n    let snapshots = load_store.get_all_snapshots().unwrap();\n\n    assert_eq!(snapshots.len(), 4);\n    assert!(snapshots.contains_key(\"weight\"));\n    assert!(snapshots.contains_key(\"bias\"));\n    assert!(snapshots.contains_key(\"nested.gamma\"));\n    assert!(snapshots.contains_key(\"nested.beta\"));\n}\n\n#[test]\nfn test_memory_get_snapshot_existing() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    // Get existing snapshot\n    let snapshot = load_store.get_snapshot(\"weight\").unwrap();\n    assert!(snapshot.is_some());\n\n    let snapshot = snapshot.unwrap();\n    assert_eq!(snapshot.shape, shape![2, 2]);\n\n    // Verify data\n    let data = snapshot.to_data().unwrap();\n    let values: Vec<f32> = data.to_vec().unwrap();\n    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);\n}\n\n#[test]\nfn test_memory_get_snapshot_nested() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    // Get nested snapshot\n    let snapshot = load_store.get_snapshot(\"nested.gamma\").unwrap();\n    assert!(snapshot.is_some());\n\n    let snapshot = snapshot.unwrap();\n    let data = snapshot.to_data().unwrap();\n    let values: Vec<f32> = data.to_vec().unwrap();\n    assert_eq!(values, vec![1.0, 2.0]);\n}\n\n#[test]\nfn test_memory_get_snapshot_not_found() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    // Get non-existent snapshot\n    let snapshot = load_store.get_snapshot(\"nonexistent\").unwrap();\n    assert!(snapshot.is_none());\n}\n\n#[test]\nfn test_memory_keys() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    let keys = load_store.keys().unwrap();\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"weight\".to_string()));\n    assert!(keys.contains(&\"bias\".to_string()));\n    assert!(keys.contains(&\"nested.gamma\".to_string()));\n    assert!(keys.contains(&\"nested.beta\".to_string()));\n}\n\n#[test]\nfn test_memory_caching_behavior() {\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    save_store.collect_from(&module).unwrap();\n    let bytes = save_store.get_bytes().unwrap();\n\n    let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n\n    // Call get_all_snapshots multiple times - should return same cached data\n    let snapshots1 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots1.len(), 4);\n\n    let snapshots2 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots2.len(), 4);\n\n    // Verify we can still access individual snapshots after caching\n    let snapshot = load_store.get_snapshot(\"bias\").unwrap();\n    assert!(snapshot.is_some());\n}\n\n// ============================================================================\n// Tests for FileStore variant\n// ============================================================================\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_get_all_snapshots() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_get_all_snapshots.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n    let snapshots = load_store.get_all_snapshots().unwrap();\n\n    assert_eq!(snapshots.len(), 4);\n    assert!(snapshots.contains_key(\"weight\"));\n    assert!(snapshots.contains_key(\"bias\"));\n    assert!(snapshots.contains_key(\"nested.gamma\"));\n    assert!(snapshots.contains_key(\"nested.beta\"));\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_get_snapshot_existing() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_get_snapshot.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n\n    let snapshot = load_store.get_snapshot(\"weight\").unwrap();\n    assert!(snapshot.is_some());\n\n    let snapshot = snapshot.unwrap();\n    assert_eq!(snapshot.shape, shape![2, 2]);\n\n    let data = snapshot.to_data().unwrap();\n    let values: Vec<f32> = data.to_vec().unwrap();\n    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_get_snapshot_not_found() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_not_found.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n\n    let snapshot = load_store.get_snapshot(\"nonexistent\").unwrap();\n    assert!(snapshot.is_none());\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_keys() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_keys.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n\n    let keys = load_store.keys().unwrap();\n    assert_eq!(keys.len(), 4);\n    assert!(keys.contains(&\"weight\".to_string()));\n    assert!(keys.contains(&\"bias\".to_string()));\n    assert!(keys.contains(&\"nested.gamma\".to_string()));\n    assert!(keys.contains(&\"nested.beta\".to_string()));\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_keys_fast_path() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_keys_fast.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    // Create fresh store - cache should be empty\n    let mut load_store = SafetensorsStore::from_file(&path);\n\n    // keys() should work without populating the full cache (fast path)\n    let keys = load_store.keys().unwrap();\n    assert_eq!(keys.len(), 4);\n\n    // Now call get_all_snapshots to populate cache\n    let snapshots = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots.len(), 4);\n\n    // keys() should now use the cached data\n    let keys2 = load_store.keys().unwrap();\n    assert_eq!(keys2.len(), 4);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_caching_behavior() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_caching.safetensors\");\n\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n\n    // First call populates cache\n    let snapshots1 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots1.len(), 4);\n\n    // Second call uses cache\n    let snapshots2 = load_store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots2.len(), 4);\n}\n\n#[test]\n#[cfg(feature = \"std\")]\nfn test_file_cache_invalidation_on_save() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = DirectAccessTestModule::<TestBackend>::new(&device);\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_invalidation.safetensors\");\n\n    // Create store, save, and populate cache\n    let mut store = SafetensorsStore::from_file(&path).overwrite(true);\n    store.collect_from(&module).unwrap();\n\n    let snapshots1 = store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots1.len(), 4);\n\n    // Save again (this should invalidate cache)\n    store.collect_from(&module).unwrap();\n\n    // Cache should be repopulated with fresh data\n    let snapshots2 = store.get_all_snapshots().unwrap();\n    assert_eq!(snapshots2.len(), 4);\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/error_handling.rs",
    "content": "use crate::{ModuleSnapshot, SafetensorsStore};\nuse burn_nn::LinearConfig;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[test]\nfn shape_mismatch_errors() {\n    let device = Default::default();\n\n    // Create a module\n    let module = LinearConfig::new(2, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Save module\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    module.save_into(&mut save_store).unwrap();\n\n    // Try to load into incompatible module (different dimensions)\n    let mut incompatible_module = LinearConfig::new(3, 3)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Load without validation - should return errors in the result\n    let mut load_store = SafetensorsStore::from_bytes(None).validate(false); // Disable validation to get errors in result\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        // Get Arc and extract data\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n\n    let result = incompatible_module.load_from(&mut load_store).unwrap();\n\n    // Should have errors due to shape mismatch\n    assert!(!result.errors.is_empty());\n\n    // Try again with validation enabled - should return Err\n    let mut load_store_with_validation = SafetensorsStore::from_bytes(None).validate(true);\n    if let SafetensorsStore::Memory(ref mut p) = load_store_with_validation\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        // Get Arc and extract data\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n\n    let validation_result = incompatible_module.load_from(&mut load_store_with_validation);\n    assert!(validation_result.is_err());\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/file_io.rs",
    "content": "use burn_core as burn;\n\nuse crate::{ModuleSnapshot, ModuleStore, SafetensorsStore};\nuse burn_core::module::{Module, Param};\nuse burn_nn::{Initializer, LinearConfig};\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\nuse tempfile::tempdir;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n// Define a test model with forward pass\n#[derive(Module, Debug)]\nstruct ForwardTestModel<B: burn_tensor::backend::Backend> {\n    linear1: burn_nn::Linear<B>,\n    linear2: burn_nn::Linear<B>,\n}\n\nimpl<B: burn_tensor::backend::Backend> ForwardTestModel<B> {\n    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        let x = self.linear1.forward(input);\n        let x = burn::tensor::activation::gelu(x);\n        self.linear2.forward(x)\n    }\n}\n\n// Define config for the model\n#[derive(burn::config::Config, Debug)]\nstruct ForwardTestModelConfig {\n    input_size: usize,\n    hidden_size: usize,\n    output_size: usize,\n}\n\nimpl ForwardTestModelConfig {\n    fn init<B: burn_tensor::backend::Backend>(&self, device: &B::Device) -> ForwardTestModel<B> {\n        ForwardTestModel {\n            linear1: LinearConfig::new(self.input_size, self.hidden_size)\n                .with_bias(true)\n                .init(device),\n            linear2: LinearConfig::new(self.hidden_size, self.output_size)\n                .with_bias(true)\n                .init(device),\n        }\n    }\n}\n\n#[derive(Module, Debug)]\npub struct ModuleBasic<B: Backend> {\n    weight_basic: Param<Tensor<B, 2>>,\n}\n\nimpl<B: Backend> ModuleBasic<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight_basic: Initializer::Normal {\n                std: 1.0,\n                mean: 0.0,\n            }\n            .init([20, 20], device),\n        }\n    }\n}\n\n#[derive(Module, Debug)]\npub struct ModuleComposed<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    basic: ModuleBasic<B>,\n    tuple: (ModuleBasic<B>, ModuleBasic<B>),\n}\n\nimpl<B: Backend> ModuleComposed<B> {\n    fn new(device: &B::Device) -> Self {\n        let weight = Initializer::Normal {\n            std: 1.0,\n            mean: 0.0,\n        }\n        .init([20, 20], device);\n\n        Self {\n            weight,\n            basic: ModuleBasic::new(device),\n            tuple: (ModuleBasic::new(device), ModuleBasic::new(device)),\n        }\n    }\n}\n\n#[test]\nfn file_based_loading() {\n    use std::fs;\n\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Create temp file path\n    let temp_dir = std::env::temp_dir();\n    let file_path = temp_dir.join(\"test_safetensors.st\");\n\n    // Save to file\n    let mut save_store = SafetensorsStore::from_file(&file_path).metadata(\"test\", \"file_loading\");\n\n    module.save_into(&mut save_store).unwrap();\n\n    // Verify file exists\n    assert!(file_path.exists());\n\n    // Load from file (will use memory-mapped loading if available)\n    let mut load_store = SafetensorsStore::from_file(&file_path);\n\n    let mut loaded_module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    let result = loaded_module.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 2); // weight and bias\n\n    // Clean up\n    fs::remove_file(file_path).ok();\n}\n\n#[test]\nfn test_store_overwrite_protection() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Create temp directory and file path (file doesn't exist yet)\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_model.safetensors\");\n\n    // First save - should succeed\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&module).unwrap();\n    assert!(path.exists());\n\n    // Second save without overwrite flag - should fail\n    let mut save_store2 = SafetensorsStore::from_file(&path);\n    let result = save_store2.collect_from(&module);\n    assert!(result.is_err());\n    assert!(\n        result\n            .unwrap_err()\n            .to_string()\n            .contains(\"File already exists\")\n    );\n\n    // Third save with overwrite flag - should succeed\n    let mut save_store3 = SafetensorsStore::from_file(&path).overwrite(true);\n    save_store3.collect_from(&module).unwrap();\n\n    // Verify file still exists and is valid\n    let mut load_store = SafetensorsStore::from_file(&path);\n    let mut module2 = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n    let result = load_store.apply_to(&mut module2).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_store_overwrite_with_metadata() {\n    use tempfile::tempdir;\n\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Create temp directory and file path\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"test_model_metadata.safetensors\");\n\n    // First save with v1 metadata and overwrite enabled\n    let mut save_store = SafetensorsStore::from_file(&path)\n        .metadata(\"model_version\", \"v1\")\n        .overwrite(true);\n    save_store.collect_from(&module).unwrap();\n\n    // Second save with v2 metadata and overwrite enabled\n    let mut save_store2 = SafetensorsStore::from_file(&path)\n        .metadata(\"model_version\", \"v2\")\n        .overwrite(true);\n    save_store2.collect_from(&module).unwrap();\n\n    // Load and verify the metadata was updated to v2\n    let mut load_store = SafetensorsStore::from_file(&path);\n    // Since we can't easily access metadata after loading, we just verify the file loads successfully\n    let mut module2 = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n    let result = module2.load_from(&mut load_store).unwrap();\n    assert!(result.is_success());\n}\n\n#[test]\nfn test_forward_pass_preservation_after_save_load() {\n    let device = Default::default();\n\n    // Create model config\n    let config = ForwardTestModelConfig {\n        input_size: 4,\n        hidden_size: 8,\n        output_size: 2,\n    };\n\n    // Initialize model1 with random weights\n    let model1 = config.init::<TestBackend>(&device);\n\n    // Create random input\n    let input = Tensor::<TestBackend, 2>::random(\n        [1, 4],\n        burn_tensor::Distribution::Uniform(-1.0, 1.0),\n        &device,\n    );\n\n    // Forward pass with model1 -> output1\n    let output1 = model1.forward(input.clone());\n\n    // Save model1 weights\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"forward_test_model.safetensors\");\n    let mut save_store = SafetensorsStore::from_file(&path);\n    save_store.collect_from(&model1).unwrap();\n\n    // Initialize model2 with different random weights\n    let mut model2 = config.init::<TestBackend>(&device);\n\n    // Forward pass with model2 -> output2 (should differ from output1)\n    let output2 = model2.forward(input.clone());\n\n    // Verify output2 differs from output1 (different random weights)\n    assert!(\n        !output1\n            .clone()\n            .all_close(output2.clone(), Some(1e-6), Some(1e-6)),\n        \"output2 should differ from output1 (different random initializations)\"\n    );\n\n    // Load model1 weights into model2\n    let mut load_store = SafetensorsStore::from_file(&path);\n    let result = load_store.apply_to(&mut model2).unwrap();\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 4); // 2 weights + 2 biases\n\n    // Forward pass with model2 (now has model1 weights) -> output3\n    let output3 = model2.forward(input.clone());\n\n    // Verify output3 equals output1 (same weights)\n    assert!(\n        output1.all_close(output3, Some(1e-6), Some(1e-6)),\n        \"output3 should equal output1 after loading weights\"\n    );\n}\n\n#[test]\nfn should_save_load_compose() {\n    let device = <TestBackend as Backend>::Device::default();\n    let module_1 = ModuleComposed::<TestBackend>::new(&device);\n    let mut module_2 = ModuleComposed::<TestBackend>::new(&device);\n    assert_ne!(module_1.weight.to_data(), module_2.weight.to_data());\n    assert_ne!(\n        module_1.basic.weight_basic.to_data(),\n        module_2.basic.weight_basic.to_data()\n    );\n\n    let temp_dir = tempdir().unwrap();\n    let path = temp_dir.path().join(\"save_load_compose.safetensors\");\n    let mut store = SafetensorsStore::from_file(&path);\n    module_1.save_into(&mut store).unwrap();\n\n    let mut load_store = SafetensorsStore::from_file(&path);\n    let result = module_2.load_from(&mut load_store).unwrap();\n    assert!(result.is_success());\n\n    assert_eq!(module_1.weight.to_data(), module_2.weight.to_data());\n    assert_eq!(\n        module_1.basic.weight_basic.to_data(),\n        module_2.basic.weight_basic.to_data()\n    );\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/filtering.rs",
    "content": "use crate::{ModuleSnapshot, SafetensorsStore};\n\nuse super::round_trip::ComplexModule;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[test]\n#[cfg(target_has_atomic = \"ptr\")]\nfn filtered_export_import() {\n    let device = Default::default();\n    let module1 = ComplexModule::<TestBackend>::new(&device);\n    let mut module2 = ComplexModule::<TestBackend>::new_zeros(&device);\n\n    // Export only encoder tensors using the builder pattern\n    let mut save_store = SafetensorsStore::from_bytes(None).with_regex(r\"^encoder\\..*\");\n    module1.save_into(&mut save_store).unwrap();\n\n    // Import filtered tensors - need to allow partial since we only saved encoder tensors\n    let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        // Get Arc and extract data\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n    let result = module2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 3); // encoder.weight, encoder.bias, encoder.norm\n    assert!(!result.missing.is_empty()); // decoder and layers tensors are missing\n}\n\n#[test]\n#[cfg(target_has_atomic = \"ptr\")]\nfn builder_pattern_filtering() {\n    let device = Default::default();\n    let module = ComplexModule::<TestBackend>::new(&device);\n\n    // Test with_regex - multiple patterns (OR logic)\n    let mut store = SafetensorsStore::from_bytes(None)\n        .with_regex(r\"^encoder\\..*\") // Match encoder tensors\n        .with_regex(r\".*\\.bias$\"); // OR match any bias tensors\n\n    let views = module.collect(None, None, false);\n    let filtered_count = views\n        .iter()\n        .filter(|v| {\n            let path = v.full_path();\n            path.starts_with(\"encoder.\") || path.ends_with(\".bias\")\n        })\n        .count();\n\n    module.save_into(&mut store).unwrap();\n\n    // Verify we saved the expected number of tensors\n    if let SafetensorsStore::Memory(ref p) = store {\n        let data = p.data().unwrap();\n        let tensors = safetensors::SafeTensors::deserialize(&data).unwrap();\n        assert_eq!(tensors.len(), filtered_count);\n    }\n}\n\n#[test]\nfn builder_pattern_exact_paths() {\n    let device = Default::default();\n    let module = ComplexModule::<TestBackend>::new(&device);\n\n    // Test with_full_path and with_full_paths\n    let paths = vec![\"encoder.weight\", \"decoder.scale\"];\n    let mut store = SafetensorsStore::from_bytes(None)\n        .with_full_path(\"encoder.norm\")\n        .with_full_paths(paths.clone());\n\n    module.save_into(&mut store).unwrap();\n\n    // Verify only specified tensors were saved\n    if let SafetensorsStore::Memory(ref p) = store {\n        let data = p.data().unwrap();\n        let tensors = safetensors::SafeTensors::deserialize(&data).unwrap();\n        assert_eq!(tensors.len(), 3); // encoder.norm + encoder.weight + decoder.scale\n\n        for (name, _) in tensors.tensors() {\n            assert!(name == \"encoder.norm\" || name == \"encoder.weight\" || name == \"decoder.scale\");\n        }\n    }\n}\n\n#[test]\nfn builder_pattern_with_predicate() {\n    let device = Default::default();\n    let module = ComplexModule::<TestBackend>::new(&device);\n\n    // Test with_predicate - custom logic\n    let mut store = SafetensorsStore::from_bytes(None).with_predicate(|path, _| {\n        // Only save tensors with \"layer\" in the path and ending with \"weight\"\n        path.contains(\"layer\") && path.ends_with(\"weight\")\n    });\n\n    module.save_into(&mut store).unwrap();\n\n    // Verify only layer weights were saved\n    if let SafetensorsStore::Memory(ref p) = store {\n        let data = p.data().unwrap();\n        let tensors = safetensors::SafeTensors::deserialize(&data).unwrap();\n\n        for (name, _) in tensors.tensors() {\n            assert!(name.contains(\"layer\"));\n            assert!(name.ends_with(\"weight\"));\n        }\n    }\n}\n\n#[test]\nfn builder_pattern_combined() {\n    let device = Default::default();\n    let module = ComplexModule::<TestBackend>::new(&device);\n\n    // Combine multiple filter methods\n    #[cfg(target_has_atomic = \"ptr\")]\n    {\n        let mut store = SafetensorsStore::from_bytes(None)\n            .with_regex(r\"^encoder\\..*\") // All encoder tensors\n            .with_full_path(\"decoder.scale\") // Plus specific decoder.scale\n            .with_predicate(|path, _| {\n                // Plus any projection tensors\n                path.contains(\"projection\")\n            });\n\n        module.save_into(&mut store).unwrap();\n\n        if let SafetensorsStore::Memory(ref p) = store {\n            let data = p.data().unwrap();\n            let tensors = safetensors::SafeTensors::deserialize(&data).unwrap();\n\n            // Should have encoder.*, decoder.scale, and projection tensors\n            let mut names = Vec::new();\n            for (name, _) in tensors.tensors() {\n                names.push(name);\n            }\n            assert!(names.iter().any(|n| n == \"encoder.weight\"));\n            assert!(names.iter().any(|n| n == \"encoder.bias\"));\n            assert!(names.iter().any(|n| n == \"encoder.norm\"));\n            assert!(names.iter().any(|n| n == \"decoder.scale\"));\n            // decoder.projection.* should also be included due to predicate\n            assert!(names.iter().any(|n| n.contains(\"projection\")));\n        }\n    }\n}\n\n#[test]\nfn builder_pattern_match_all() {\n    let device = Default::default();\n    let module = ComplexModule::<TestBackend>::new(&device);\n\n    let all_views = module.collect(None, None, false);\n    let total_count = all_views.len();\n\n    // Test match_all - should save everything\n    let mut store = SafetensorsStore::from_bytes(None).match_all();\n\n    module.save_into(&mut store).unwrap();\n\n    if let SafetensorsStore::Memory(ref p) = store {\n        let data = p.data().unwrap();\n        let tensors = safetensors::SafeTensors::deserialize(&data).unwrap();\n        assert_eq!(tensors.len(), total_count);\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/integration.rs",
    "content": "use burn_core as burn;\n\nuse crate::{ModuleSnapshot, SafetensorsStore};\nuse burn_core::module::{Module, Param};\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n// Integration tests demonstrating the SafeTensors store API\n#[derive(Module, Debug)]\nstruct IntegrationTestModel<B: Backend> {\n    encoder: IntegrationEncoderModule<B>,\n    decoder: IntegrationDecoderModule<B>,\n    head: IntegrationHeadModule<B>,\n}\n\n#[derive(Module, Debug)]\nstruct IntegrationEncoderModule<B: Backend> {\n    layer1: IntegrationLinearLayer<B>,\n    layer2: IntegrationLinearLayer<B>,\n    norm: IntegrationNormLayer<B>,\n}\n\n#[derive(Module, Debug)]\nstruct IntegrationDecoderModule<B: Backend> {\n    layer1: IntegrationLinearLayer<B>,\n    layer2: IntegrationLinearLayer<B>,\n    norm: IntegrationNormLayer<B>,\n}\n\n#[derive(Module, Debug)]\nstruct IntegrationHeadModule<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    bias: Param<Tensor<B, 1>>,\n}\n\n#[derive(Module, Debug)]\nstruct IntegrationLinearLayer<B: Backend> {\n    weight: Param<Tensor<B, 2>>,\n    bias: Param<Tensor<B, 1>>,\n}\n\n#[derive(Module, Debug)]\nstruct IntegrationNormLayer<B: Backend> {\n    scale: Param<Tensor<B, 1>>,\n    shift: Param<Tensor<B, 1>>,\n}\n\nimpl<B: Backend> IntegrationTestModel<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            encoder: IntegrationEncoderModule::new(device),\n            decoder: IntegrationDecoderModule::new(device),\n            head: IntegrationHeadModule::new(device),\n        }\n    }\n}\n\nimpl<B: Backend> IntegrationEncoderModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            layer1: IntegrationLinearLayer::new(device, 1),\n            layer2: IntegrationLinearLayer::new(device, 2),\n            norm: IntegrationNormLayer::new(device),\n        }\n    }\n}\n\nimpl<B: Backend> IntegrationDecoderModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            layer1: IntegrationLinearLayer::new(device, 3),\n            layer2: IntegrationLinearLayer::new(device, 4),\n            norm: IntegrationNormLayer::new(device),\n        }\n    }\n}\n\nimpl<B: Backend> IntegrationHeadModule<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            weight: Param::from_data([[5.0, 6.0], [7.0, 8.0]], device),\n            bias: Param::from_data([9.0, 10.0], device),\n        }\n    }\n}\n\nimpl<B: Backend> IntegrationLinearLayer<B> {\n    fn new(device: &B::Device, seed: i32) -> Self {\n        let weight_data = [\n            [seed as f32, (seed + 1) as f32],\n            [(seed + 2) as f32, (seed + 3) as f32],\n        ];\n        let bias_data = [(seed + 4) as f32, (seed + 5) as f32];\n\n        Self {\n            weight: Param::from_data(weight_data, device),\n            bias: Param::from_data(bias_data, device),\n        }\n    }\n}\n\nimpl<B: Backend> IntegrationNormLayer<B> {\n    fn new(device: &B::Device) -> Self {\n        Self {\n            scale: Param::from_data([1.0, 2.0], device),\n            shift: Param::from_data([0.1, 0.2], device),\n        }\n    }\n}\n\n#[test]\nfn basic_usage() {\n    let device = Default::default();\n    let model = IntegrationTestModel::<TestBackend>::new(&device);\n\n    // Save using new API (format, producer and version are automatically added)\n    let mut save_store = SafetensorsStore::from_bytes(None).metadata(\"model_name\", \"test_model\");\n\n    // Use collect_to method\n    model.save_into(&mut save_store).unwrap();\n\n    // Load using new API\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let mut target_model = IntegrationTestModel::<TestBackend>::new(&device);\n    let result = target_model.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n    assert_eq!(result.applied.len(), 14); // All tensors should be applied\n    assert_eq!(result.errors.len(), 0);\n    assert_eq!(result.unused.len(), 0);\n}\n\n#[test]\n#[cfg(target_has_atomic = \"ptr\")]\nfn with_filtering() {\n    let device = Default::default();\n    let model = IntegrationTestModel::<TestBackend>::new(&device);\n\n    // Save only encoder tensors using the builder pattern\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .with_regex(r\"^encoder\\..*\")\n        .metadata(\"subset\", \"encoder_only\");\n\n    model.save_into(&mut save_store).unwrap();\n\n    // Load into new model - need to allow partial loading since we only saved encoder tensors\n    let mut load_store = SafetensorsStore::from_bytes(None).allow_partial(true);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let mut target_model = IntegrationTestModel::<TestBackend>::new(&device);\n    let result = target_model.load_from(&mut load_store).unwrap();\n\n    // Only encoder tensors should be applied\n    assert_eq!(result.applied.len(), 6); // encoder has 6 tensors (2 layers × 2 + norm × 2)\n\n    // Check that only encoder tensors were applied\n    for tensor_name in &result.applied {\n        assert!(tensor_name.starts_with(\"encoder.\"));\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/metadata.rs",
    "content": "use crate::{ModuleSnapshot, SafetensorsStore};\nuse burn_nn::LinearConfig;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[test]\nfn default_metadata_included() {\n    // Verify that default metadata is automatically included\n    let default_metadata = SafetensorsStore::default_metadata();\n\n    // Check that format, producer and version are present\n    assert_eq!(default_metadata.get(\"format\").unwrap(), \"safetensors\");\n    assert_eq!(default_metadata.get(\"producer\").unwrap(), \"burn\");\n    assert!(default_metadata.contains_key(\"version\"));\n\n    // The version should be the crate version\n    let version = default_metadata.get(\"version\").unwrap();\n    assert!(!version.is_empty());\n}\n\n#[test]\nfn metadata_preservation() {\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Write with metadata - note that format, producer and version are automatically added\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .metadata(\"model_type\", \"linear\")\n        .metadata(\"custom_field\", \"test_value\");\n\n    module.save_into(&mut save_store).unwrap();\n\n    // Verify metadata was saved (would need to add a method to check metadata)\n    // For now, just verify the round trip works\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        // Get Arc and extract data\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n\n    let mut module2 = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n    let result = module2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n}\n\n#[test]\nfn clear_metadata_removes_all() {\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Create store with custom metadata, then clear all\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .metadata(\"model_type\", \"linear\")\n        .metadata(\"custom_field\", \"test_value\")\n        .clear_metadata(); // Should remove all metadata including defaults\n\n    module.save_into(&mut save_store).unwrap();\n\n    // Load and verify the module still works (metadata is optional)\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n\n    let mut module2 = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n    let result = module2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n}\n\n#[test]\nfn clear_then_add_custom_metadata() {\n    let device = Default::default();\n    let module = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n\n    // Clear all metadata, then add only custom ones\n    let mut save_store = SafetensorsStore::from_bytes(None)\n        .clear_metadata()\n        .metadata(\"only_custom\", \"value\");\n\n    module.save_into(&mut save_store).unwrap();\n\n    // Verify round-trip works\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n\n    let mut module2 = LinearConfig::new(4, 2)\n        .with_bias(true)\n        .init::<TestBackend>(&device);\n    let result = module2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/mixed_datatypes.rs",
    "content": "use burn_core as burn;\n\nuse burn_core::module::{Module, Param, ParamId};\nuse burn_nn as nn;\nuse burn_tensor::{Bool, Int, Tensor, backend::Backend};\n\nuse crate::{ModuleSnapshot, SafetensorsStore};\n\n/// Simple model with different data types for testing\n#[derive(Module, Debug)]\npub struct MixedDtypeModel<B: Backend> {\n    // Standard neural network layers (float tensors)\n    linear: nn::Linear<B>,\n\n    // Direct tensor parameters of different types\n    float_tensor: Param<Tensor<B, 2>>,\n\n    int_tensor: Param<Tensor<B, 2, Int>>,\n\n    bool_tensor: Param<Tensor<B, 2, Bool>>,\n}\n\nimpl<B: Backend> MixedDtypeModel<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            linear: nn::LinearConfig::new(3, 3).init(device),\n\n            // Simple float values\n            float_tensor: Param::from_tensor(Tensor::from_floats(\n                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],\n                device,\n            )),\n\n            // Simple integer values\n            int_tensor: Param::initialized(\n                ParamId::new(),\n                Tensor::from_ints([[1, 2, 3], [4, 5, 6]], device),\n            ),\n\n            // Simple boolean values\n            bool_tensor: Param::initialized(\n                ParamId::new(),\n                Tensor::from_bool(\n                    burn::tensor::TensorData::new(\n                        vec![true, false, true, false, true, false],\n                        [2, 3],\n                    ),\n                    device,\n                ),\n            ),\n        }\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::excessive_precision)]\nmod tests {\n    use burn_tensor::BoolStore;\n\n    use super::*;\n\n    #[test]\n    fn test_mixed_dtypes_round_trip() {\n        type TestBackend = burn_ndarray::NdArray<f32>;\n        let device = Default::default();\n\n        // Create model with mixed data types\n        let model = MixedDtypeModel::<TestBackend>::new(&device);\n\n        // Save to bytes\n        let mut save_store = SafetensorsStore::from_bytes(None);\n        model.save_into(&mut save_store).expect(\"Failed to save\");\n        let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n        // Load into a new model\n        let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n        let mut loaded_model = MixedDtypeModel::<TestBackend>::new(&device);\n        loaded_model\n            .load_from(&mut load_store)\n            .expect(\"Failed to load\");\n\n        // Verify float tensor is preserved\n        let orig_float = model.float_tensor.val().into_data();\n        let loaded_float = loaded_model.float_tensor.val().into_data();\n        assert_eq!(orig_float, loaded_float, \"Float tensor not preserved\");\n\n        // Verify integer tensor is preserved\n        let orig_int = model.int_tensor.val().into_data();\n        let loaded_int = loaded_model.int_tensor.val().into_data();\n        assert_eq!(orig_int, loaded_int, \"Integer tensor not preserved\");\n\n        // Verify boolean tensor is preserved\n        let orig_bool = model.bool_tensor.val().into_data();\n        let loaded_bool = loaded_model.bool_tensor.val().into_data();\n        assert_eq!(orig_bool, loaded_bool, \"Boolean tensor not preserved\");\n    }\n\n    #[test]\n    fn test_dtype_detection() {\n        type TestBackend = burn_ndarray::NdArray<f32>;\n        let device = Default::default();\n\n        let model = MixedDtypeModel::<TestBackend>::new(&device);\n        let snapshots = model.collect(None, None, false);\n\n        for snapshot in snapshots {\n            let path = snapshot.full_path();\n            let dtype = snapshot.dtype;\n\n            if path.contains(\"float_tensor\") || path.contains(\"linear\") {\n                assert_eq!(\n                    dtype,\n                    burn::tensor::DType::F32,\n                    \"Float tensor {} should have F32 dtype\",\n                    path\n                );\n            } else if path.contains(\"int_tensor\") {\n                assert!(\n                    matches!(\n                        dtype,\n                        burn::tensor::DType::I64\n                            | burn::tensor::DType::I32\n                            | burn::tensor::DType::I16\n                            | burn::tensor::DType::I8\n                    ),\n                    \"Integer tensor {} should have integer dtype, got {:?}\",\n                    path,\n                    dtype\n                );\n            } else if path.contains(\"bool_tensor\") {\n                assert_eq!(\n                    dtype,\n                    burn::tensor::DType::Bool(BoolStore::Native),\n                    \"Boolean tensor {} should have Bool dtype\",\n                    path\n                );\n            }\n        }\n    }\n\n    #[test]\n    fn test_extreme_values() {\n        type TestBackend = burn_ndarray::NdArray<f32>;\n        let device = <TestBackend as Backend>::Device::default();\n\n        #[derive(Module, Debug)]\n        struct ExtremeValueModel<B: Backend> {\n            large_floats: Param<Tensor<B, 1>>,\n            small_floats: Param<Tensor<B, 1>>,\n            large_ints: Param<Tensor<B, 1, Int>>,\n        }\n\n        impl<B: Backend> ExtremeValueModel<B> {\n            fn new(device: &B::Device) -> Self {\n                Self {\n                    large_floats: Param::from_tensor(Tensor::from_floats(\n                        [1e30, -1e30, f32::MAX, f32::MIN],\n                        device,\n                    )),\n                    small_floats: Param::from_tensor(Tensor::from_floats(\n                        [1e-30, -1e-30, f32::MIN_POSITIVE, f32::EPSILON],\n                        device,\n                    )),\n                    large_ints: Param::initialized(\n                        ParamId::new(),\n                        Tensor::from_ints([i32::MAX, i32::MIN, 0, -1], device),\n                    ),\n                }\n            }\n        }\n\n        let model = ExtremeValueModel::<TestBackend>::new(&device);\n\n        // Save and load\n        let mut save_store = SafetensorsStore::from_bytes(None);\n        model.save_into(&mut save_store).expect(\"Failed to save\");\n        let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n        let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n        let mut loaded_model = ExtremeValueModel::<TestBackend>::new(&device);\n        loaded_model\n            .load_from(&mut load_store)\n            .expect(\"Failed to load\");\n\n        // Check exact preservation\n        assert_eq!(\n            model.large_floats.val().into_data(),\n            loaded_model.large_floats.val().into_data(),\n            \"Large floats not preserved\"\n        );\n        assert_eq!(\n            model.small_floats.val().into_data(),\n            loaded_model.small_floats.val().into_data(),\n            \"Small floats not preserved\"\n        );\n        assert_eq!(\n            model.large_ints.val().into_data(),\n            loaded_model.large_ints.val().into_data(),\n            \"Large integers not preserved\"\n        );\n    }\n\n    #[test]\n    fn test_mixed_precision_floats() {\n        // Note: While SafeTensors format supports storing tensors with different precisions\n        // (F16, BF16, F32, F64, etc.) in the same file, Burn's backend architecture currently\n        // requires all tensors in a model instance to share the same floating-point precision.\n        // This is determined at the backend level (e.g., NdArray<f32> or NdArray<f64>).\n        //\n        // However, for storage purposes, SafeTensors can correctly save and load tensors\n        // with their original precision, preserving the data type information in the file format.\n        // This test demonstrates that different precision backends work correctly with SafeTensors.\n\n        // Test with f32 backend\n        {\n            type TestBackend = burn_ndarray::NdArray<f32>;\n            let device = Default::default();\n\n            let model = MixedDtypeModel::<TestBackend>::new(&device);\n\n            // Save to bytes\n            let mut save_store = SafetensorsStore::from_bytes(None);\n            model.save_into(&mut save_store).expect(\"Failed to save\");\n            let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n            // Load and verify\n            let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n            let mut loaded_model = MixedDtypeModel::<TestBackend>::new(&device);\n            loaded_model\n                .load_from(&mut load_store)\n                .expect(\"Failed to load\");\n\n            assert_eq!(\n                model.float_tensor.val().into_data(),\n                loaded_model.float_tensor.val().into_data(),\n                \"F32 float tensor not preserved\"\n            );\n        }\n\n        // Test with f64 backend\n        {\n            type TestBackend = burn_ndarray::NdArray<f64>;\n            let device = Default::default();\n\n            #[derive(Module, Debug)]\n            struct F64Model<B: Backend> {\n                linear: nn::Linear<B>,\n                double_precision: Param<Tensor<B, 2>>,\n            }\n\n            let model = F64Model::<TestBackend> {\n                linear: nn::LinearConfig::new(2, 2).init(&device),\n                double_precision: Param::from_tensor(Tensor::from_floats(\n                    [\n                        [1.234567890123456789, 2.345678901234567890],\n                        [3.456789012345678901, 4.567890123456789012],\n                    ],\n                    &device,\n                )),\n            };\n\n            // Save to bytes\n            let mut save_store = SafetensorsStore::from_bytes(None);\n            model.save_into(&mut save_store).expect(\"Failed to save\");\n            let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n            // Load and verify\n            let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n            let mut loaded_model = F64Model::<TestBackend> {\n                linear: nn::LinearConfig::new(2, 2).init(&device),\n                double_precision: Param::from_tensor(Tensor::zeros([2, 2], &device)),\n            };\n            loaded_model\n                .load_from(&mut load_store)\n                .expect(\"Failed to load\");\n\n            let orig = model.double_precision.val().into_data();\n            let loaded = loaded_model.double_precision.val().into_data();\n            assert_eq!(orig, loaded, \"F64 double precision not preserved\");\n        }\n    }\n\n    #[test]\n    fn test_mixed_precision_integers() {\n        type TestBackend = burn_ndarray::NdArray<f32>;\n        let device = Default::default();\n\n        #[derive(Module, Debug)]\n        struct MultiIntModel<B: Backend> {\n            // Note: Burn's Tensor<B, D, Int> uses the backend's default int type\n            // We can't directly specify i8, i16, etc. in the type system\n            // But we can test with different values that would fit in different ranges\n            small_ints: Param<Tensor<B, 1, Int>>, // Values that fit in i8\n            medium_ints: Param<Tensor<B, 1, Int>>, // Values that fit in i16\n            large_ints: Param<Tensor<B, 1, Int>>, // Values that need i32/i64\n        }\n\n        let model = MultiIntModel::<TestBackend> {\n            small_ints: Param::initialized(\n                ParamId::new(),\n                Tensor::from_ints([127i32, -128, 0, 42], &device),\n            ),\n            medium_ints: Param::initialized(\n                ParamId::new(),\n                Tensor::from_ints([32767i32, -32768, 1000, -1000], &device),\n            ),\n            large_ints: Param::initialized(\n                ParamId::new(),\n                Tensor::from_ints([i32::MAX, i32::MIN, 1_000_000, -1_000_000], &device),\n            ),\n        };\n\n        // Save to bytes\n        let mut save_store = SafetensorsStore::from_bytes(None);\n        model.save_into(&mut save_store).expect(\"Failed to save\");\n        let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n        // Load and verify\n        let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n        let mut loaded_model = MultiIntModel::<TestBackend> {\n            small_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)),\n            medium_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)),\n            large_ints: Param::initialized(ParamId::new(), Tensor::zeros([4], &device)),\n        };\n        loaded_model\n            .load_from(&mut load_store)\n            .expect(\"Failed to load\");\n\n        assert_eq!(\n            model.small_ints.val().into_data(),\n            loaded_model.small_ints.val().into_data(),\n            \"Small ints (i8 range) not preserved\"\n        );\n        assert_eq!(\n            model.medium_ints.val().into_data(),\n            loaded_model.medium_ints.val().into_data(),\n            \"Medium ints (i16 range) not preserved\"\n        );\n        assert_eq!(\n            model.large_ints.val().into_data(),\n            loaded_model.large_ints.val().into_data(),\n            \"Large ints (i32 range) not preserved\"\n        );\n    }\n\n    #[test]\n    fn test_comprehensive_mixed_types() {\n        type TestBackend = burn_ndarray::NdArray<f32>;\n        let device = Default::default();\n\n        #[derive(Module, Debug)]\n        struct ComprehensiveModel<B: Backend> {\n            // Neural network layers\n            linear1: nn::Linear<B>,\n            conv2d: nn::conv::Conv2d<B>,\n\n            // Different tensor types\n            float32_weights: Param<Tensor<B, 3>>,\n            integer_indices: Param<Tensor<B, 2, Int>>,\n            boolean_mask: Param<Tensor<B, 2, Bool>>,\n        }\n\n        let model = ComprehensiveModel::<TestBackend> {\n            linear1: nn::LinearConfig::new(4, 8).init(&device),\n            conv2d: nn::conv::Conv2dConfig::new([3, 16], [3, 3]).init(&device),\n\n            float32_weights: Param::from_tensor(Tensor::from_floats(\n                [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],\n                &device,\n            )),\n\n            integer_indices: Param::initialized(\n                ParamId::new(),\n                Tensor::from_ints(\n                    [[0, 1, 2, 3], [10, 20, 30, 40], [100, 200, 300, 400]],\n                    &device,\n                ),\n            ),\n\n            boolean_mask: Param::initialized(\n                ParamId::new(),\n                Tensor::from_bool(\n                    burn::tensor::TensorData::new(\n                        vec![true, false, false, true, false, true, true, false],\n                        [2, 4],\n                    ),\n                    &device,\n                ),\n            ),\n        };\n\n        // Collect all tensors\n        let snapshots = model.collect(None, None, false);\n\n        // Verify we have all expected tensors\n        let paths: Vec<String> = snapshots.iter().map(|s| s.full_path()).collect();\n        assert!(paths.iter().any(|p| p.contains(\"linear1\")));\n        assert!(paths.iter().any(|p| p.contains(\"conv2d\")));\n        assert!(paths.iter().any(|p| p.contains(\"float32_weights\")));\n        assert!(paths.iter().any(|p| p.contains(\"integer_indices\")));\n        assert!(paths.iter().any(|p| p.contains(\"boolean_mask\")));\n\n        // Save to bytes\n        let mut save_store = SafetensorsStore::from_bytes(None);\n        model.save_into(&mut save_store).expect(\"Failed to save\");\n        let bytes = save_store.get_bytes().expect(\"Failed to get bytes\");\n\n        // Load into fresh model\n        let mut load_store = SafetensorsStore::from_bytes(Some(bytes));\n        let mut loaded_model = ComprehensiveModel::<TestBackend> {\n            linear1: nn::LinearConfig::new(4, 8).init(&device),\n            conv2d: nn::conv::Conv2dConfig::new([3, 16], [3, 3]).init(&device),\n            float32_weights: Param::from_tensor(Tensor::zeros([2, 2, 2], &device)),\n            integer_indices: Param::initialized(ParamId::new(), Tensor::zeros([3, 4], &device)),\n            boolean_mask: Param::initialized(\n                ParamId::new(),\n                Tensor::from_bool(\n                    burn::tensor::TensorData::new(vec![false; 8], [2, 4]),\n                    &device,\n                ),\n            ),\n        };\n        loaded_model\n            .load_from(&mut load_store)\n            .expect(\"Failed to load\");\n\n        // Verify all data is preserved\n        assert_eq!(\n            model.float32_weights.val().into_data(),\n            loaded_model.float32_weights.val().into_data(),\n            \"Float32 weights not preserved\"\n        );\n        assert_eq!(\n            model.integer_indices.val().into_data(),\n            loaded_model.integer_indices.val().into_data(),\n            \"Integer indices not preserved\"\n        );\n        assert_eq!(\n            model.boolean_mask.val().into_data(),\n            loaded_model.boolean_mask.val().into_data(),\n            \"Boolean mask not preserved\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/mod.rs",
    "content": "mod adapter;\nmod direct_access;\nmod error_handling;\n#[cfg(feature = \"std\")]\nmod file_io;\nmod filtering;\nmod integration;\nmod metadata;\nmod mixed_datatypes;\nmod multi_layer_verify;\nmod pytorch_import;\nmod round_trip;\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/multi_layer_verify.rs",
    "content": "//! Tests for multi-layer model loading with SafeTensors format\nuse burn_core as burn;\n\nuse burn_core::module::Module;\nuse burn_tensor::{Tensor, backend::Backend};\n\nuse burn_nn::{\n    BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,\n    conv::{Conv2d, Conv2dConfig},\n};\n\n/// Multi-layer neural network model for testing\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n    norm1: BatchNorm<B>,\n    fc1: Linear<B>,\n    relu: Relu,\n}\n\nimpl<B: Backend> Net<B> {\n    /// Create a new network instance\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            conv1: Conv2dConfig::new([3, 4], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .init(device),\n            norm1: BatchNormConfig::new(4).init(device),\n            fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),\n            relu: Relu::new(),\n        }\n    }\n\n    /// Forward pass of the model\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {\n        let x = self.conv1.forward(x);\n        let x = self.norm1.forward(x);\n        let x = self.relu.forward(x);\n        // Flatten all dimensions except the batch dimension\n        let x = x.flatten(1, 3);\n        self.fc1.forward(x)\n    }\n}\n\nuse crate::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};\nuse burn_tensor::Tolerance;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n/// Path to the multi_layer.safetensors test file\nfn get_safetensors_path() -> &'static str {\n    concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    )\n}\n\n#[test]\nfn multi_layer_model() {\n    let device = Default::default();\n    let safetensors_path = get_safetensors_path();\n\n    // Load model from SafeTensors file with PyTorch adapter\n    let mut store = SafetensorsStore::from_file(safetensors_path)\n        .with_from_adapter(PyTorchToBurnAdapter)\n        .validate(false)\n        .allow_partial(true);\n\n    let mut model = Net::<TestBackend>::new(&device);\n    let result = model.load_from(&mut store).unwrap();\n\n    // Verify loading was successful\n    assert!(\n        !result.applied.is_empty(),\n        \"Should have loaded some tensors\"\n    );\n    assert!(\n        result.errors.is_empty(),\n        \"Should have no errors: {:?}\",\n        result.errors\n    );\n\n    // Test forward pass\n    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);\n    let output = model.forward(input);\n\n    // Expected output values from PyTorch model\n    let expected = Tensor::<TestBackend, 2>::from_data(\n        [[\n            0.04971555,\n            -0.16849735,\n            0.05182848,\n            -0.18032673,\n            0.23138367,\n            0.05041867,\n            0.13005908,\n            -0.32202929,\n            -0.07915690,\n            -0.03232457,\n            -0.19790289,\n            -0.17476529,\n            -0.19627589,\n            -0.21757686,\n            -0.31376451,\n            0.08377837,\n        ]],\n        &device,\n    );\n\n    // Verify output matches expected values\n    output\n        .to_data()\n        .assert_approx_eq::<f32>(&expected.to_data(), Tolerance::default());\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/pytorch_import.rs",
    "content": "use burn_core as burn;\n\nuse crate::{ModuleSnapshot, SafetensorsStore};\nuse burn_core::module::Module;\nuse burn_nn::{\n    BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,\n    conv::{Conv2d, Conv2dConfig},\n};\nuse burn_tensor::Tensor;\nuse burn_tensor::backend::Backend;\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[derive(Module, Debug)]\npub struct Net<B: Backend> {\n    conv1: Conv2d<B>,\n    norm1: BatchNorm<B>,\n    fc1: Linear<B>,\n    relu: Relu,\n}\n\nimpl<B: Backend> Net<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            conv1: Conv2dConfig::new([3, 4], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .init(device),\n            norm1: BatchNormConfig::new(4).init(device),\n            fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),\n            relu: Relu::new(),\n        }\n    }\n\n    /// Forward pass of the model.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {\n        let x = self.conv1.forward(x);\n        let x = self.norm1.forward(x);\n        let x = self.relu.forward(x);\n        // Flatten all dimensions except the batch dimension\n        let x = x.flatten(1, 3);\n        self.fc1.forward(x)\n    }\n}\n\n#[test]\n#[cfg(all(feature = \"std\", target_has_atomic = \"ptr\"))]\nfn multi_layer_model_import() {\n    let device = Default::default();\n\n    // Reference the safetensors file from burn-import\n    let safetensors_path = concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    );\n\n    // Load the model using SafetensorsStore\n    // Note: PyTorch and Burn have different conventions for linear layer weights\n    // PyTorch stores as [out_features, in_features], Burn as [in_features, out_features]\n    // Also, tensor names may differ (e.g., PyTorch uses different names for BatchNorm params)\n    let mut store = SafetensorsStore::from_file(safetensors_path)\n        .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format\n        .allow_partial(true); // Allow partial loading due to naming differences\n    let mut model = Net::<TestBackend>::new(&device);\n\n    let result = model.load_from(&mut store).unwrap();\n\n    // With the adapter, weights should load correctly\n    assert!(!result.applied.is_empty());\n    assert!(\n        result.errors.is_empty(),\n        \"Should have no errors with adapter: {:?}\",\n        result.errors\n    );\n\n    // Test forward pass with the loaded weights\n    // Note: Due to shape mismatches (PyTorch vs Burn conventions for linear layers),\n    // we can't directly compare outputs with PyTorch model.\n    // This test mainly verifies that the loading mechanism works.\n    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);\n    let _output = model.forward(input);\n\n    // Verify that some tensors were loaded successfully\n    // Conv and BatchNorm layers should load correctly\n    assert!(result.applied.iter().any(|n| n.contains(\"conv1\")));\n    assert!(result.applied.iter().any(|n| n.contains(\"norm1\")));\n}\n\n#[test]\n#[cfg(all(feature = \"std\", target_has_atomic = \"ptr\"))]\nfn safetensors_round_trip_with_pytorch_model() {\n    let device = Default::default();\n\n    // Reference the safetensors file from burn-import\n    let safetensors_path = concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    );\n\n    // Load the model from PyTorch safetensors\n    let mut load_store = SafetensorsStore::from_file(safetensors_path)\n        .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format\n        .allow_partial(true); // Allow partial loading due to naming differences\n    let mut model = Net::<TestBackend>::new(&device);\n    let load_result = model.load_from(&mut load_store).unwrap();\n    // With the adapter, weights should load correctly\n    assert!(!load_result.applied.is_empty());\n    assert!(\n        load_result.errors.is_empty(),\n        \"Should have no errors with adapter: {:?}\",\n        load_result.errors\n    );\n\n    // Save the model to memory\n    // Note: format, producer and version are automatically added\n    let mut save_store = SafetensorsStore::from_bytes(None).metadata(\"source\", \"pytorch\");\n    model.save_into(&mut save_store).unwrap();\n\n    // Load into a new model\n    let mut model2 = Net::<TestBackend>::new(&device);\n    let mut load_store2 = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store2\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        p.set_data(p_save.data().unwrap().as_ref().clone());\n    }\n\n    let result = model2.load_from(&mut load_store2).unwrap();\n    assert!(!result.applied.is_empty());\n\n    // Verify both models produce the same output\n    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);\n    let output1 = model.forward(input.clone());\n    let output2 = model2.forward(input);\n\n    // Check outputs are identical\n    let output1_data = output1.to_data().to_vec::<f32>().unwrap();\n    let output2_data = output2.to_data().to_vec::<f32>().unwrap();\n\n    for (a, b) in output1_data.iter().zip(output2_data.iter()) {\n        assert!((a - b).abs() < 1e-7, \"Outputs differ after round trip\");\n    }\n}\n\n#[test]\n#[cfg(all(feature = \"std\", target_has_atomic = \"ptr\"))]\nfn partial_load_from_pytorch_model() {\n    let device = Default::default();\n\n    // Reference the safetensors file from burn-import\n    let safetensors_path = concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    );\n\n    // Load only conv1 and norm1 parameters (not fc1)\n    let mut store = SafetensorsStore::from_file(safetensors_path)\n        .validate(false) // Disable validation due to shape differences\n        .allow_partial(true);\n\n    let mut model = Net::<TestBackend>::new(&device);\n\n    // Save initial fc1 weights for comparison\n    let _initial_fc1_weight = model.fc1.weight.val().to_data();\n\n    let result = model.load_from(&mut store).unwrap();\n\n    // Should load available tensors (with some errors due to shape mismatch)\n    assert!(!result.applied.is_empty());\n\n    // fc1 weight should remain unchanged if not in the file\n    // or should be updated if it is in the file\n    // This test verifies that partial loading works correctly\n}\n\n#[test]\n#[cfg(all(feature = \"std\", target_has_atomic = \"ptr\"))]\nfn verify_tensor_names_from_pytorch() {\n    let device = Default::default();\n\n    // Reference the safetensors file from burn-import\n    let safetensors_path = concat!(\n        env!(\"CARGO_MANIFEST_DIR\"),\n        \"/safetensors-tests/tests/multi_layer/multi_layer.safetensors\"\n    );\n\n    // Create a model and load from PyTorch\n    let mut model = Net::<TestBackend>::new(&device);\n    let mut store = SafetensorsStore::from_file(safetensors_path)\n        .validate(false) // Disable validation due to shape differences\n        .allow_partial(true); // Allow partial loading due to naming differences\n    let result = model.load_from(&mut store).unwrap();\n\n    // Check that we loaded some tensors (with errors due to shape mismatch)\n    assert!(!result.applied.is_empty());\n\n    // Collect tensor names from the model\n    let views = model.collect(None, None, false);\n    let tensor_names: Vec<String> = views.iter().map(|v| v.full_path()).collect();\n\n    // Verify expected tensor names are present\n    assert!(tensor_names.iter().any(|n| n.contains(\"conv1\")));\n    assert!(tensor_names.iter().any(|n| n.contains(\"norm1\")));\n    assert!(tensor_names.iter().any(|n| n.contains(\"fc1\")));\n}\n"
  },
  {
    "path": "crates/burn-store/src/safetensors/tests/round_trip.rs",
    "content": "use burn_core as burn;\n\nuse crate::{ModuleSnapshot, SafetensorsStore};\nuse burn_core::module::{Module, Param};\nuse burn_nn::{Linear, LinearConfig};\nuse burn_tensor::backend::Backend;\nuse burn_tensor::{Tensor, shape};\n\ntype TestBackend = burn_ndarray::NdArray;\n\n#[derive(Module, Debug)]\npub(super) struct ComplexModule<B: Backend> {\n    pub encoder: EncoderModule<B>,\n    pub decoder: DecoderModule<B>,\n    pub layers: Vec<Linear<B>>,\n}\n\n#[derive(Module, Debug)]\npub(super) struct EncoderModule<B: Backend> {\n    pub weight: Param<Tensor<B, 3>>,\n    pub bias: Param<Tensor<B, 1>>,\n    pub norm: Param<Tensor<B, 1>>,\n}\n\n#[derive(Module, Debug)]\npub(super) struct DecoderModule<B: Backend> {\n    pub projection: Linear<B>,\n    pub scale: Param<Tensor<B, 2>>,\n}\n\nimpl<B: Backend> ComplexModule<B> {\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            encoder: EncoderModule {\n                weight: Param::from_data(\n                    [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],\n                    device,\n                ),\n                bias: Param::from_data([0.1, 0.2, 0.3], device),\n                norm: Param::from_data([1.0, 1.0, 1.0], device),\n            },\n            decoder: DecoderModule {\n                projection: LinearConfig::new(4, 2).with_bias(true).init(device),\n                scale: Param::from_data([[0.5, 0.5], [0.5, 0.5]], device),\n            },\n            layers: vec![\n                LinearConfig::new(3, 4).with_bias(false).init(device),\n                LinearConfig::new(4, 3).with_bias(true).init(device),\n            ],\n        }\n    }\n\n    pub fn new_zeros(device: &B::Device) -> Self {\n        Self {\n            encoder: EncoderModule {\n                weight: Param::from_tensor(Tensor::zeros([2, 2, 2], device)),\n                bias: Param::from_tensor(Tensor::zeros([3], device)),\n                norm: Param::from_tensor(Tensor::zeros([3], device)),\n            },\n            decoder: DecoderModule {\n                projection: LinearConfig::new(4, 2).with_bias(true).init(device),\n                scale: Param::from_tensor(Tensor::zeros([2, 2], device)),\n            },\n            layers: vec![\n                LinearConfig::new(3, 4).with_bias(false).init(device),\n                LinearConfig::new(4, 3).with_bias(true).init(device),\n            ],\n        }\n    }\n}\n\n#[test]\nfn complex_module_round_trip() {\n    let device = Default::default();\n    let module1 = ComplexModule::<TestBackend>::new(&device);\n    let mut module2 = ComplexModule::<TestBackend>::new_zeros(&device);\n\n    // Save module1 using new store API\n    let mut save_store = SafetensorsStore::from_bytes(None);\n    module1.save_into(&mut save_store).unwrap();\n\n    // Load into module2\n    let mut load_store = SafetensorsStore::from_bytes(None);\n    if let SafetensorsStore::Memory(ref mut p) = load_store\n        && let SafetensorsStore::Memory(ref p_save) = save_store\n    {\n        // Get Arc and extract data\n        let data_arc = p_save.data().unwrap();\n        p.set_data(data_arc.as_ref().clone());\n    }\n    let result = module2.load_from(&mut load_store).unwrap();\n\n    assert!(result.is_success());\n    assert!(result.applied.len() > 5);\n    assert_eq!(result.errors.len(), 0);\n\n    // Verify data was imported correctly\n    let module2_views = module2.collect(None, None, false);\n    let encoder_weight = module2_views\n        .iter()\n        .find(|v| v.full_path() == \"encoder.weight\")\n        .unwrap()\n        .to_data()\n        .unwrap();\n    assert_eq!(encoder_weight.shape, shape![2, 2, 2]);\n}\n"
  },
  {
    "path": "crates/burn-store/src/tensor_snapshot.rs",
    "content": "use alloc::rc::Rc;\nuse alloc::string::String;\nuse alloc::string::ToString;\nuse alloc::vec::Vec;\nuse burn_core::module::ParamId;\nuse burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};\nuse burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};\nuse half::f16;\n\n/// Returns the byte size of a quantization parameter type.\n// TODO: Add `size_bytes()` method to `QuantParam` in cubecl and use it here.\nconst fn quant_param_size(param: QuantParam) -> usize {\n    match param {\n        QuantParam::F32 => core::mem::size_of::<f32>(),\n        QuantParam::F16 | QuantParam::BF16 => core::mem::size_of::<f16>(),\n        QuantParam::UE8M0 | QuantParam::UE4M3 => core::mem::size_of::<u8>(),\n    }\n}\n\n/// Error type for TensorSnapshot operations\n#[derive(Debug, Clone)]\npub enum TensorSnapshotError {\n    /// I/O error occurred while loading tensor data\n    IoError(String),\n    /// Data corruption or invalid format\n    DataError(String),\n    /// Panic occurred while loading tensor data\n    PanicError(String),\n}\n\nimpl core::fmt::Display for TensorSnapshotError {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        match self {\n            Self::IoError(e) => write!(f, \"I/O error: {}\", e),\n            Self::DataError(e) => write!(f, \"Data error: {}\", e),\n            Self::PanicError(e) => write!(f, \"Panic error: {}\", e),\n        }\n    }\n}\n\nimpl core::error::Error for TensorSnapshotError {}\n\n/// A lightweight snapshot of a tensor that can lazily produce TensorData.\n///\n/// TensorSnapshot stores a cloned tensor internally (which is cheap due to reference counting)\n/// and only materializes the actual data when `to_data()` is called. This allows\n/// efficient inspection of module structure without the overhead of copying all tensor data.\n///\n/// The dtype and shape are cached for efficient access without requiring data materialization,\n/// which is particularly useful for serialization formats that need metadata upfront.\npub struct TensorSnapshot {\n    /// Function to get tensor data when needed (Rc allows cloning)\n    data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,\n    /// Data type of the tensor (cached for efficient access)\n    pub dtype: burn_tensor::DType,\n    /// Shape of the tensor (cached for efficient access)\n    pub shape: Shape,\n    /// Path stack representing the module hierarchy\n    pub path_stack: Option<Vec<String>>,\n    /// Container stack representing the container types at each level\n    pub container_stack: Option<Vec<String>>,\n    /// Unique identifier for the tensor parameter\n    pub tensor_id: Option<ParamId>,\n}\n\nimpl TensorSnapshot {\n    /// Create a new tensor snapshot from a float tensor\n    pub fn from_float<B: Backend, const D: usize>(\n        tensor: &Tensor<B, D>,\n        path_stack: Vec<String>,\n        container_stack: Vec<String>,\n        tensor_id: ParamId,\n    ) -> Self {\n        let dtype = tensor.dtype();\n        let shape = tensor.shape();\n        let tensor = tensor.clone(); // Clone is cheap (reference counted)\n        Self {\n            data_fn: Rc::new(move || Ok(tensor.to_data())),\n            dtype,\n            shape,\n            path_stack: Some(path_stack),\n            container_stack: Some(container_stack),\n            tensor_id: Some(tensor_id),\n        }\n    }\n\n    /// Create a new tensor snapshot from an int tensor\n    pub fn from_int<B: Backend, const D: usize>(\n        tensor: &Tensor<B, D, Int>,\n        path_stack: Vec<String>,\n        container_stack: Vec<String>,\n        tensor_id: ParamId,\n    ) -> Self {\n        let dtype = tensor.dtype();\n        let shape = tensor.shape();\n        let tensor = tensor.clone(); // Clone is cheap (reference counted)\n        Self {\n            data_fn: Rc::new(move || Ok(tensor.to_data())),\n            dtype,\n            shape,\n            path_stack: Some(path_stack),\n            container_stack: Some(container_stack),\n            tensor_id: Some(tensor_id),\n        }\n    }\n\n    /// Create a new tensor snapshot from a bool tensor\n    pub fn from_bool<B: Backend, const D: usize>(\n        tensor: &Tensor<B, D, Bool>,\n        path_stack: Vec<String>,\n        container_stack: Vec<String>,\n        tensor_id: ParamId,\n    ) -> Self {\n        let dtype = tensor.dtype();\n        let shape = tensor.shape();\n        let tensor = tensor.clone(); // Clone is cheap (reference counted)\n        Self {\n            data_fn: Rc::new(move || Ok(tensor.to_data())),\n            dtype,\n            shape,\n            path_stack: Some(path_stack),\n            container_stack: Some(container_stack),\n            tensor_id: Some(tensor_id),\n        }\n    }\n\n    /// Convert to TensorData (this is where actual data copy happens)\n    #[cfg(feature = \"std\")]\n    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {\n        // Use AssertUnwindSafe since we're working with Rc which is not UnwindSafe\n        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(\n            |_| {\n                Err(TensorSnapshotError::PanicError(\n                    \"Panic occurred while loading tensor data\".to_string(),\n                ))\n            },\n        )\n    }\n\n    /// Convert to TensorData (this is where actual data copy happens)\n    #[cfg(not(feature = \"std\"))]\n    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {\n        (self.data_fn)() // Can't catch panics in no-std, do it when core::panic::AssertUnwindSafe is available\n    }\n\n    /// Get the full path by joining the path stack\n    pub fn full_path(&self) -> String {\n        self.path_stack\n            .as_ref()\n            .map(|stack| stack.join(\".\"))\n            .unwrap_or_default()\n    }\n\n    /// Get the full container path by joining the container stack\n    pub fn container_path(&self) -> String {\n        self.container_stack\n            .as_ref()\n            .map(|stack| stack.join(\".\"))\n            .unwrap_or_default()\n    }\n\n    /// Get the module type (last Struct/Enum in the hierarchy)\n    ///\n    /// Returns the last user-defined module type, skipping primitive containers\n    /// like \"Vec\", \"Array\". This is useful for determining which user-defined\n    /// module a tensor belongs to.\n    ///\n    /// # Examples\n    /// - `Linear.weight` → `Some(\"Struct:Linear\")`\n    /// - `Vec<Linear>[0].weight` → `Some(\"Struct:Linear\")`\n    /// - `Linear.bias` (Optional) → `Some(\"Struct:Linear\")`\n    /// - `Vec<Param>[0]` (no module) → `None`\n    pub fn module_type(&self) -> Option<String> {\n        self.container_stack.as_ref().and_then(|stack| {\n            // Find the last user-defined type (Struct: or Enum:)\n            stack\n                .iter()\n                .rev()\n                .find(|ct| ct.starts_with(\"Struct:\") || ct.starts_with(\"Enum:\"))\n                .cloned()\n        })\n    }\n\n    /// Get the immediate container type (last in the container stack)\n    ///\n    /// Returns the last element in the container stack, which could be a\n    /// user-defined type (\"Struct:\", \"Enum:\") or a collection type (\"Vec\", \"Array\").\n    /// This is useful for understanding the full container hierarchy.\n    ///\n    /// # Examples\n    /// - `Linear.weight` → `\"Struct:Linear\"`\n    /// - `Vec<Linear>[0].weight` → `\"Struct:Linear\"` (the Linear, not the Vec)\n    /// - `Vec<Param>[0]` → `\"Vec\"`\n    pub fn container_type(&self) -> String {\n        self.container_stack\n            .as_ref()\n            .and_then(|stack| stack.last())\n            .cloned()\n            .unwrap_or_else(|| \"Unknown\".to_string())\n    }\n\n    /// Create a TensorSnapshot from a closure that produces TensorData\n    /// This is used internally for lazy loading\n    pub fn from_closure(\n        data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,\n        dtype: burn_tensor::DType,\n        shape: Shape,\n        path_stack: Vec<String>,\n        container_stack: Vec<String>,\n        tensor_id: ParamId,\n    ) -> Self {\n        Self {\n            data_fn,\n            dtype,\n            shape,\n            path_stack: Some(path_stack),\n            container_stack: Some(container_stack),\n            tensor_id: Some(tensor_id),\n        }\n    }\n\n    /// Create a TensorSnapshot from TensorData directly\n    pub fn from_data(\n        data: TensorData,\n        path_stack: Vec<String>,\n        container_stack: Vec<String>,\n        tensor_id: ParamId,\n    ) -> Self {\n        let dtype = data.dtype;\n        let shape = data.shape.clone();\n        Self {\n            data_fn: Rc::new(move || Ok(data.clone())),\n            dtype,\n            shape,\n            path_stack: Some(path_stack),\n            container_stack: Some(container_stack),\n            tensor_id: Some(tensor_id),\n        }\n    }\n\n    /// Get the size of the tensor data in bytes without materializing it.\n    ///\n    /// For regular (non-quantized) types, this is simply `shape.product() * dtype.size()`.\n    ///\n    /// For quantized types (`QFloat`), this accounts for:\n    /// - The quantized values (packed according to the quantization scheme)\n    /// - Alignment padding (values are aligned to 4-byte boundary)\n    /// - Quantization parameters (scale values appended to the data)\n    pub fn data_len(&self) -> usize {\n        const BITS_PER_BYTE: usize = 8;\n\n        let num_elements: usize = self.shape.iter().product();\n\n        match self.dtype {\n            DType::QFloat(scheme) => {\n                // Calculate value bytes using scheme's packing information\n                let num_storage_elements = num_elements.div_ceil(scheme.num_quants());\n                let value_bytes =\n                    num_storage_elements * (scheme.size_bits_stored() / BITS_PER_BYTE);\n\n                // Calculate number of quantization parameters (scales)\n                let num_params = params_shape(&self.shape, scheme.level).num_elements();\n\n                let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;\n                let scale_bytes = num_params * quant_param_size(scheme.param);\n\n                aligned_value_bytes + scale_bytes\n            }\n            _ => num_elements * self.dtype.size(),\n        }\n    }\n\n    /// Clone the data function for lazy composition\n    pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {\n        self.data_fn.clone()\n    }\n}\n\nimpl Clone for TensorSnapshot {\n    fn clone(&self) -> Self {\n        // Clone lazily - keep the same data function\n        Self {\n            data_fn: self.data_fn.clone(),\n            dtype: self.dtype,\n            shape: self.shape.clone(),\n            path_stack: self.path_stack.clone(),\n            container_stack: self.container_stack.clone(),\n            tensor_id: self.tensor_id,\n        }\n    }\n}\n\nimpl core::fmt::Debug for TensorSnapshot {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.debug_struct(\"TensorSnapshot\")\n            .field(\"dtype\", &self.dtype)\n            .field(\"shape\", &self.shape)\n            .field(\"path_stack\", &self.path_stack)\n            .field(\"container_stack\", &self.container_stack)\n            .field(\"tensor_id\", &self.tensor_id)\n            .finish()\n    }\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use super::*;\n    type TestBackend = burn_ndarray::NdArray;\n    use alloc::string::ToString;\n    use burn_tensor::{BoolStore, DType, shape};\n\n    #[test]\n    fn tensor_view_float() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n\n        let snapshot = TensorSnapshot::from_float(\n            &tensor,\n            vec![\"test\".to_string(), \"weight\".to_string()],\n            vec![\"TestModule\".to_string(), \"Param\".to_string()],\n            ParamId::new(),\n        );\n\n        // Test metadata access without materialization\n        assert_eq!(snapshot.dtype, DType::F32);\n        assert_eq!(snapshot.shape, shape![2, 2]);\n        assert_eq!(snapshot.full_path(), \"test.weight\");\n        assert_eq!(snapshot.container_path(), \"TestModule.Param\");\n\n        // Test data materialization\n        let data = snapshot.to_data().unwrap();\n        assert_eq!(data.shape, shape![2, 2]);\n        assert_eq!(data.dtype, DType::F32);\n    }\n\n    #[test]\n    fn tensor_view_int() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);\n\n        let snapshot = TensorSnapshot::from_int(\n            &tensor,\n            vec![\"test\".to_string(), \"int\".to_string()],\n            vec![\"TestModule\".to_string(), \"Param\".to_string()],\n            ParamId::new(),\n        );\n\n        // Test metadata access without materialization\n        // TestBackend uses I64 for integers\n        assert_eq!(snapshot.dtype, DType::I64);\n        assert_eq!(snapshot.shape, shape![2, 2]);\n\n        let data = snapshot.to_data().unwrap();\n        assert_eq!(data.shape, shape![2, 2]);\n        assert_eq!(data.dtype, DType::I64);\n    }\n\n    #[test]\n    fn tensor_view_bool() {\n        let device = Default::default();\n        let tensor =\n            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);\n\n        let snapshot = TensorSnapshot::from_bool(\n            &tensor,\n            vec![\"test\".to_string(), \"bool\".to_string()],\n            vec![\"TestModule\".to_string(), \"Param\".to_string()],\n            ParamId::new(),\n        );\n\n        // Test metadata access without materialization\n        assert_eq!(snapshot.dtype, DType::Bool(BoolStore::Native));\n        assert_eq!(snapshot.shape, shape![2, 2]);\n\n        let data = snapshot.to_data().unwrap();\n        assert_eq!(data.shape, shape![2, 2]);\n        assert_eq!(data.dtype, DType::Bool(BoolStore::Native));\n    }\n\n    #[test]\n    fn data_len() {\n        let device = Default::default();\n\n        // Test F32 tensor (4 bytes per element)\n        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);\n        let view_f32 = TensorSnapshot::from_float(\n            &tensor_f32,\n            vec![\"test\".to_string()],\n            vec![\"Module\".to_string()],\n            ParamId::new(),\n        );\n        assert_eq!(view_f32.data_len(), 16); // 4 elements * 4 bytes\n\n        // Test I64 tensor (8 bytes per element) - TestBackend uses I64 for Int\n        let tensor_i64 =\n            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);\n        let view_i64 = TensorSnapshot::from_int(\n            &tensor_i64,\n            vec![\"test\".to_string()],\n            vec![\"Module\".to_string()],\n            ParamId::new(),\n        );\n        assert_eq!(view_i64.data_len(), 64); // 8 elements * 8 bytes (I64)\n\n        // Test Bool tensor (1 byte per element)\n        let tensor_bool =\n            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);\n        let view_bool = TensorSnapshot::from_bool(\n            &tensor_bool,\n            vec![\"test\".to_string()],\n            vec![\"Module\".to_string()],\n            ParamId::new(),\n        );\n        assert_eq!(view_bool.data_len(), 4); // 4 elements * 1 byte\n    }\n\n    #[test]\n    fn from_closure() {\n        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);\n        let dtype = data.dtype;\n        let shape = data.shape.clone();\n\n        let snapshot = TensorSnapshot::from_closure(\n            Rc::new(move || Ok(data.clone())),\n            dtype,\n            shape.clone(),\n            vec![\"model\".to_string(), \"layer\".to_string()],\n            vec![\"Model\".to_string(), \"Layer\".to_string()],\n            ParamId::new(),\n        );\n\n        // Test metadata access\n        assert_eq!(snapshot.dtype, DType::F32);\n        assert_eq!(snapshot.shape, shape![4]);\n        assert_eq!(snapshot.full_path(), \"model.layer\");\n        assert_eq!(snapshot.data_len(), 16); // 4 * 4 bytes\n\n        // Test data materialization\n        let materialized = snapshot.to_data().unwrap();\n        assert_eq!(materialized.shape, shape![4]);\n    }\n\n    #[test]\n    fn from_data() {\n        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);\n        let original_dtype = data.dtype;\n        let original_shape = data.shape.clone();\n\n        let snapshot = TensorSnapshot::from_data(\n            data,\n            vec![\"encoder\".to_string(), \"weight\".to_string()],\n            vec![\"Struct:Encoder\".to_string(), \"Struct:Dense\".to_string()],\n            ParamId::new(),\n        );\n\n        // Test metadata\n        assert_eq!(snapshot.dtype, original_dtype);\n        assert_eq!(snapshot.shape, original_shape);\n        assert_eq!(snapshot.full_path(), \"encoder.weight\");\n        assert_eq!(snapshot.container_type(), \"Struct:Dense\");\n        assert_eq!(snapshot.data_len(), 24); // 6 * 4 bytes\n\n        // Test data materialization\n        let materialized = snapshot.to_data().unwrap();\n        assert_eq!(materialized.shape, original_shape);\n    }\n\n    #[test]\n    #[cfg(feature = \"std\")]\n    fn panic_catching_in_to_data() {\n        use alloc::rc::Rc;\n\n        // Create a TensorSnapshot with a closure that panics\n        let snapshot = TensorSnapshot {\n            data_fn: Rc::new(|| panic!(\"Test panic in data_fn\")),\n            dtype: DType::F32,\n            shape: shape![2, 2],\n            path_stack: Some(vec![\"test\".to_string()]),\n            container_stack: Some(vec![\"Test\".to_string()]),\n            tensor_id: Some(ParamId::new()),\n        };\n\n        // When std is available, to_data should catch the panic and return an error\n        let result = snapshot.to_data();\n        assert!(result.is_err());\n\n        match result {\n            Err(TensorSnapshotError::PanicError(msg)) => {\n                assert!(msg.contains(\"Panic occurred\"));\n            }\n            _ => panic!(\"Expected PanicError with panic message\"),\n        }\n    }\n\n    #[test]\n    fn error_propagation_in_closure() {\n        use alloc::rc::Rc;\n\n        // Create a snapshot with a closure that returns an error\n        let snapshot = TensorSnapshot::from_closure(\n            Rc::new(|| Err(TensorSnapshotError::IoError(\"Simulated IO error\".into()))),\n            DType::F32,\n            shape![2, 2],\n            vec![\"error_test\".into()],\n            vec![],\n            ParamId::new(),\n        );\n\n        // Should return an error when trying to get data\n        let result = snapshot.to_data();\n        assert!(result.is_err());\n        match result {\n            Err(TensorSnapshotError::IoError(msg)) => {\n                assert!(msg.contains(\"Simulated IO error\"));\n            }\n            _ => panic!(\"Expected IoError\"),\n        }\n    }\n\n    #[test]\n    fn container_type_extraction() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);\n\n        let snapshot = TensorSnapshot::from_float(\n            &tensor,\n            vec![\n                \"model\".to_string(),\n                \"layer1\".to_string(),\n                \"weight\".to_string(),\n            ],\n            vec![\n                \"Struct:Model\".to_string(),\n                \"Struct:Conv2d\".to_string(),\n                \"Struct:Param\".to_string(),\n            ],\n            ParamId::new(),\n        );\n\n        assert_eq!(snapshot.container_type(), \"Struct:Param\");\n        assert_eq!(snapshot.module_type(), Some(\"Struct:Param\".to_string()));\n        assert_eq!(\n            snapshot.container_path(),\n            \"Struct:Model.Struct:Conv2d.Struct:Param\"\n        );\n        assert_eq!(snapshot.full_path(), \"model.layer1.weight\");\n    }\n\n    #[test]\n    fn container_type_vs_module_type() {\n        let device = Default::default();\n        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);\n\n        // Test case 1: Tensor inside a Vec<Linear>\n        // container_stack: [\"Struct:Model\", \"Vec\", \"Struct:Linear\"]\n        let snapshot = TensorSnapshot::from_float(\n            &tensor,\n            vec![\n                \"model\".to_string(),\n                \"layers\".to_string(),\n                \"0\".to_string(),\n                \"weight\".to_string(),\n            ],\n            vec![\n                \"Struct:Model\".to_string(),\n                \"Vec\".to_string(),\n                \"Struct:Linear\".to_string(),\n            ],\n            ParamId::new(),\n        );\n\n        // container_type() returns the last element (Struct:Linear in this case)\n        assert_eq!(snapshot.container_type(), \"Struct:Linear\");\n        // module_type() also returns Some(Struct:Linear) (skipping Vec)\n        assert_eq!(snapshot.module_type(), Some(\"Struct:Linear\".to_string()));\n\n        // Test case 2: Tensor that's just in a Vec\n        // container_stack: [\"Vec\"]\n        let snapshot2 = TensorSnapshot::from_float(\n            &tensor,\n            vec![\"data\".to_string(), \"0\".to_string()],\n            vec![\"Vec\".to_string()],\n            ParamId::new(),\n        );\n\n        // container_type() returns Vec\n        assert_eq!(snapshot2.container_type(), \"Vec\");\n        // module_type() returns None (no Struct/Enum found)\n        assert_eq!(snapshot2.module_type(), None);\n\n        // Test case 3: Nested collections\n        // container_stack: [\"Struct:Model\", \"Vec\", \"Array\", \"Struct:Linear\"]\n        let snapshot3 = TensorSnapshot::from_float(\n            &tensor,\n            vec![\n                \"model\".to_string(),\n                \"layers\".to_string(),\n                \"0\".to_string(),\n                \"sublayers\".to_string(),\n                \"1\".to_string(),\n                \"weight\".to_string(),\n            ],\n            vec![\n                \"Struct:Model\".to_string(),\n                \"Vec\".to_string(),\n                \"Array\".to_string(),\n                \"Struct:Linear\".to_string(),\n            ],\n            ParamId::new(),\n        );\n\n        // container_type() returns the immediate container\n        assert_eq!(snapshot3.container_type(), \"Struct:Linear\");\n        // module_type() returns the last Struct/Enum\n        assert_eq!(snapshot3.module_type(), Some(\"Struct:Linear\".to_string()));\n    }\n}\n"
  },
  {
    "path": "crates/burn-store/src/traits.rs",
    "content": "use alloc::boxed::Box;\nuse alloc::collections::BTreeMap;\nuse alloc::string::String;\nuse alloc::vec::Vec;\n\nuse super::applier::Applier;\nuse super::apply_result::ApplyResult;\nuse crate::collector::Collector;\nuse crate::{ModuleAdapter, PathFilter, TensorSnapshot};\nuse burn_core::module::Module;\nuse burn_tensor::backend::Backend;\n\n/// Extension trait for modules that provides tensor storage functionality.\n///\n/// This trait provides convenient methods to collect and apply tensor snapshots from any Burn module.\n/// Collection operations create lightweight tensor snapshots without immediately copying data.\n/// Apply operations apply tensor data from snapshots to the corresponding tensors in the module.\npub trait ModuleSnapshot<B: Backend>: Module<B> {\n    /// Collects tensor snapshots for inspection without copying data.\n    ///\n    /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data.\n    /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`.\n    ///\n    /// # Arguments\n    ///\n    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.\n    ///   When `None`, all tensors are collected.\n    /// * `adapter` - Optional adapter to transform tensors based on container types.\n    ///   Applied to all collected tensors before returning.\n    /// * `skip_enum_variants` - Skip enum variant names when building paths.\n    ///   When true, paths will not include enum variant names (e.g., \"feature.weight\"\n    ///   instead of \"feature.BaseConv.weight\"). Useful when exporting to formats\n    ///   like PyTorch/SafeTensors that don't use enum variants.\n    fn collect(\n        &self,\n        filter: Option<PathFilter>,\n        adapter: Option<Box<dyn ModuleAdapter>>,\n        skip_enum_variants: bool,\n    ) -> Vec<TensorSnapshot> {\n        let mut collector = Collector::new(filter, adapter, skip_enum_variants);\n        self.visit(&mut collector);\n        collector.into_tensors()\n    }\n\n    /// Applies tensor snapshots to the module.\n    ///\n    /// This is the primary apply method that applies tensor data from `TensorSnapshot`s\n    /// to the corresponding tensors in the module. The snapshots are typically obtained\n    /// from `collect()` or loaded from storage.\n    ///\n    /// # Arguments\n    ///\n    /// * `snapshots` - A vector of TensorSnapshot objects\n    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.\n    ///   When `None`, all available tensors are applied.\n    /// * `adapter` - Optional adapter to transform tensors based on container types\n    /// * `skip_enum_variants` - Skip enum variant names when matching tensor paths\n    ///\n    /// # Returns\n    ///\n    /// An [`ApplyResult`] containing information about applied, skipped, missing,\n    /// and unused tensors, as well as any errors encountered.\n    ///\n    /// # Examples\n    ///\n    /// ```rust,ignore\n    /// use burn_store::PathFilter;\n    ///\n    /// // Apply all tensors\n    /// let result = model.apply(snapshots, None, None, false);\n    ///\n    /// // Apply only encoder tensors\n    /// let filter = PathFilter::new().with_regex(r\"^encoder\\..*\");\n    /// let result = model.apply(snapshots, Some(filter), None, false);\n    ///\n    /// // Apply with complex filter\n    /// let filter = PathFilter::new()\n    ///     .with_regex(r\"^encoder\\..*\")\n    ///     .with_regex(r\"^decoder\\..*\")\n    ///     .with_full_path(\"head.weight\");\n    /// let result = model.apply(snapshots, Some(filter), None, false);\n    ///\n    /// // Apply with enum variant skipping (for PyTorch models)\n    /// let result = model.apply(snapshots, None, None, true);\n    /// ```\n    fn apply(\n        &mut self,\n        snapshots: Vec<TensorSnapshot>,\n        filter: Option<PathFilter>,\n        adapter: Option<Box<dyn ModuleAdapter>>,\n        skip_enum_variants: bool,\n    ) -> ApplyResult\n    where\n        Self: Sized,\n    {\n        let mut applier = Applier::new(snapshots, filter, adapter, skip_enum_variants);\n\n        // Use unsafe to avoid cloning the entire module, which would double the memory usage\n        // We read the module out, map it, then write it back\n        // See https://github.com/tracel-ai/burn/issues/3754\n        unsafe {\n            // Read the module out of self (moves it, leaving self in undefined state)\n            let module = core::ptr::read(self as *const Self);\n\n            // Map the module to create a new one with updated tensors\n            let new_module = module.map(&mut applier);\n\n            // Write the new module back to self\n            core::ptr::write(self as *mut Self, new_module);\n        }\n\n        applier.into_result()\n    }\n\n    /// Saves tensor snapshots into a [`ModuleStore`].\n    ///\n    /// This method allows using a `ModuleStore` implementation to handle the\n    /// collection and writing logic in a configurable way.\n    ///\n    /// # Arguments\n    ///\n    /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors\n    fn save_into<P>(&self, store: &mut P) -> Result<(), P::Error>\n    where\n        P: ModuleStore,\n    {\n        store.collect_from(self)\n    }\n\n    /// Loads tensor data from a [`ModuleStore`].\n    ///\n    /// This method allows using a `ModuleStore` implementation to handle the\n    /// loading and application logic in a configurable way.\n    ///\n    /// # Arguments\n    ///\n    /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors\n    fn load_from<P>(&mut self, store: &mut P) -> Result<ApplyResult, P::Error>\n    where\n        P: ModuleStore,\n    {\n        store.apply_to(self)\n    }\n}\n\n/// A trait for handling module storage operations.\n///\n/// `ModuleStore` provides a unified interface for saving and loading module\n/// tensor data with support for various storage formats and advanced features like filtering,\n/// remapping, and metadata handling.\npub trait ModuleStore {\n    /// The error type that can be returned during storage operations.\n    ///\n    /// This should be a format-specific error type that provides detailed\n    /// information about what went wrong (e.g., I/O errors, format violations,\n    /// unsupported tensor types).\n    type Error: core::fmt::Debug + core::fmt::Display;\n\n    /// Collect tensor data from a module and store it to storage.\n    ///\n    /// This method traverses the module structure, collects all tensor data\n    /// according to the store's configuration (filters, remapping, etc.),\n    /// and writes it to the underlying storage.\n    ///\n    /// # Arguments\n    ///\n    /// * `module` - The module to collect tensor data from. The module must\n    ///   implement `ModuleSnapshot` to provide tensor access.\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(())` - If all tensors were successfully collected and stored\n    /// * `Err(Self::Error)` - If an error occurred during collection or writing\n    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &M,\n    ) -> Result<(), Self::Error>;\n\n    /// Load stored tensor data and apply it to a module.\n    ///\n    /// This method reads tensor data from storage and applies it to the provided\n    /// module. The operation is flexible and can handle partial matches, missing\n    /// tensors, and extra tensors in the storage.\n    ///\n    /// # Arguments\n    ///\n    /// * `module` - The module to apply tensor data to. The module must\n    ///   implement `ModuleSnapshot` to allow tensor updates.\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(ApplyResult)` - Detailed information about the apply operation:\n    ///   - `applied`: List of successfully applied tensor names\n    ///   - `missing`: Tensors expected by the module but not found in storage\n    ///   - `skipped`: Tensors in storage that were not applied (filtered or not needed)\n    ///   - `errors`: Non-critical errors that occurred during apply\n    /// * `Err(Self::Error)` - If a critical error prevented the apply operation\n    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(\n        &mut self,\n        module: &mut M,\n    ) -> Result<ApplyResult, Self::Error>;\n\n    /// Get a single tensor snapshot by name.\n    ///\n    /// This method provides direct access to individual tensors in storage without\n    /// requiring a module. The returned `TensorSnapshot` uses lazy loading - tensor\n    /// data is only materialized when `to_data()` is called.\n    ///\n    /// **Note:** Key remapping is applied, so use the remapped name if configured.\n    /// Filters are NOT applied - use `apply_to()` for filtered loading.\n    ///\n    /// Results are cached after the first call for efficient repeated access.\n    ///\n    /// # Arguments\n    ///\n    /// * `name` - The tensor name/path (e.g., \"encoder.layer1.weight\")\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(Some(&TensorSnapshot))` - Reference to the tensor snapshot if found\n    /// * `Ok(None)` - If no tensor with that name exists\n    /// * `Err(Self::Error)` - If an error occurred accessing storage\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// let mut store = BurnpackStore::from_file(\"model.bpk\");\n    /// if let Some(snapshot) = store.get_snapshot(\"encoder.weight\")? {\n    ///     println!(\"Shape: {:?}\", snapshot.shape);\n    ///     println!(\"Dtype: {:?}\", snapshot.dtype);\n    ///     let data = snapshot.to_data()?;  // Lazy load\n    /// }\n    /// ```\n    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error>;\n\n    /// Get all tensor snapshots from storage as an ordered map.\n    ///\n    /// This method returns all tensors in storage as lazy-loading snapshots,\n    /// organized in a `BTreeMap` for efficient lookup by name. The map preserves\n    /// alphabetical ordering of tensor names.\n    ///\n    /// **Note:** This returns ALL tensors in storage, regardless of any filter\n    /// settings. Filters are only applied during `apply_to()`. Key remapping\n    /// IS applied, so tensor names reflect any configured remapping.\n    ///\n    /// Results are cached after the first call for efficient repeated access.\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(&BTreeMap<String, TensorSnapshot>)` - Reference to all tensor snapshots\n    /// * `Err(Self::Error)` - If an error occurred accessing storage\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// let mut store = SafetensorsStore::from_file(\"model.safetensors\");\n    /// let snapshots = store.get_all_snapshots()?;\n    /// for (name, snapshot) in snapshots {\n    ///     println!(\"{}: {:?}\", name, snapshot.shape);\n    /// }\n    /// ```\n    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error>;\n\n    /// Get all tensor names/keys in storage.\n    ///\n    /// This method returns the names of all tensors in storage.\n    /// Useful for inspecting storage contents or checking if specific tensors exist.\n    ///\n    /// **Note:** Returns ALL tensor names regardless of filter settings.\n    /// Key remapping IS applied, so names reflect any configured remapping.\n    ///\n    /// # Returns\n    ///\n    /// * `Ok(Vec<String>)` - All tensor names in storage\n    /// * `Err(Self::Error)` - If an error occurred accessing storage\n    ///\n    /// # Example\n    ///\n    /// ```rust,ignore\n    /// let mut store = PytorchStore::from_file(\"model.pth\");\n    /// let keys = store.keys()?;\n    /// println!(\"Tensors in file: {:?}\", keys);\n    /// ```\n    fn keys(&mut self) -> Result<Vec<String>, Self::Error>;\n}\n\n// Blanket implementation for all modules\nimpl<B: Backend, M: Module<B>> ModuleSnapshot<B> for M {}\n"
  },
  {
    "path": "crates/burn-tch/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"LibTorch backend for the Burn framework using the tch bindings.\"\ndocumentation = \"https://docs.rs/burn-tch\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"data\"]\nlicense.workspace = true\nname = \"burn-tch\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-tch\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\nstd = [\"burn-backend/std\"]\ndoc = [\"tch/doc-only\"]\ntracing = [\n    \"burn-backend/tracing\",\n]\n\n[dependencies]\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n\nlibc = { workspace = true }\nlog = { workspace = true }\ntch = { workspace = true, features = [\"download-libtorch\"] }\ntorch-sys = { workspace = true }                             # for build script lib dir detection\n\n[build-dependencies]\ncc = \"1.2.56\"\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-tch/README.md",
    "content": "# Burn Torch Backend\n\n[Burn](https://github.com/tracel-ai/burn) Torch backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tch.svg)](https://crates.io/crates/burn-tch)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tch/blob/master/README.md)\n\nThis crate provides a Torch backend for [Burn](https://github.com/tracel-ai/burn) utilizing the\n[`tch-rs`](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the\n[PyTorch](https://pytorch.org/) C++ API.\n\nThe backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html)\n(multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS).\n\n## Installation\n\n[`tch-rs`](https://github.com/LaurentMazare/tch-rs) requires the C++ PyTorch library (LibTorch) to\nbe available on your system.\n\nBy default, the CPU distribution is installed for LibTorch v2.9.0 as required by `tch-rs`.\n\n<details>\n<summary><strong>CUDA</strong></summary>\n\nTo install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION` environment\nvariable before the `tch-rs` dependency is retrieved with `cargo`.\n\n```shell\nexport TORCH_CUDA_VERSION=cu128\n```\n\nOn Windows:\n\n```powershell\n$Env:TORCH_CUDA_VERSION = \"cu128\"\n```\n\n> Note: `tch` doesn't expose the downloaded libtorch directory on Windows when using the automatic\n> download feature, so the `torch_cuda.dll` cannot be detected properly during build. In this case,\n> you can set the `LIBTORCH` environment variable to point to the `libtorch/` folder in `torch-sys`\n> `OUT_DIR` (or move the downloaded lib to a different folder and point to it).\n\nFor example, running the validation sample for the first time could be done with the following\ncommands:\n\n```shell\nexport TORCH_CUDA_VERSION=cu128\ncargo run --bin cuda --release\n```\n\n**Important:** make sure your driver version is compatible with the selected CUDA version. A CUDA\nToolkit installation is not required since LibTorch ships with the appropriate CUDA runtimes. Having\nthe latest driver version is recommended, but you can always take a look at the\n[toolkit driver version table](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id4)\nor\n[minimum required driver version](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#minor-version-compatibility)\n(limited feature-set, might not work with all operations).\n\n</details><br>\n\nOnce your installation is complete, you should be able to build/run your project. You can also\nvalidate your installation by running the appropriate `cpu`, `cuda` or `mps` sample as below.\n\n```shell\ncargo run --bin cpu --release\ncargo run --bin cuda --release\ncargo run --bin mps --release\n```\n\n_Note: no MPS distribution is available for automatic download at this time, please check out the\n[manual instructions](#metal-mps)._\n\n### Manual Download\n\nTo install `tch-rs` with a different LibTorch distribution, you will have to manually download the\ndesired LibTorch distribution. The instructions are detailed in the sections below for each\nplatform.\n\n| Compute Platform          |              CPU               | GPU | Linux | MacOS | Windows | Android | iOS | WASM |\n| :------------------------ | :----------------------------: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |\n| [CPU](#cpu)               |              Yes               | No  |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |\n| [CUDA](#cuda)             | Yes <sup>[[1]](#cpu-sup)</sup> | Yes |  Yes  |  No   |   Yes   |   No    | No  |  No  |\n| [Metal (MPS)](#metal-mps) |               No               | Yes |  No   |  Yes  |   No    |   No    | No  |  No  |\n| Vulkan                    |              Yes               | Yes |  Yes  |  Yes  |   Yes   |   Yes   | No  |  No  |\n\n<sup><a id=\"cpu-sup\">[1]</a> The LibTorch CUDA distribution also comes with CPU support.</sup>\n\n#### CPU\n\n<details open>\n<summary><strong>🐧 Linux</strong></summary>\n\nFirst, download the LibTorch CPU distribution.\n\n```shell\nwget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.9.0%2Bcpu.zip\nunzip libtorch.zip\n```\n\nThen, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables\nbefore building `burn-tch` or a crate which depends on it.\n\n```shell\nexport LIBTORCH=/absolute/path/to/libtorch/\nexport LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH\n```\n\n</details><br>\n\n<details>\n<summary><strong>🍎 Mac</strong></summary>\n\nFirst, download the LibTorch CPU distribution.\n\n```shell\nwget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.0.zip\nunzip libtorch.zip\n```\n\nThen, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH` environment variables\nbefore building `burn-tch` or a crate which depends on it.\n\n```shell\nexport LIBTORCH=/absolute/path/to/libtorch/\nexport DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH\n```\n\n</details><br>\n\n<details>\n<summary><strong>🪟 Windows</strong></summary>\n\nFirst, download the LibTorch CPU distribution.\n\n```powershell\nwget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.0%2Bcpu.zip -OutFile libtorch.zip\nExpand-Archive libtorch.zip\n```\n\nThen, set the `LIBTORCH` environment variable and append the library to your path as with the\nPowerShell commands below before building `burn-tch` or a crate which depends on it.\n\n```powershell\n$Env:LIBTORCH = \"/absolute/path/to/libtorch/\"\n$Env:Path += \";/absolute/path/to/libtorch/\"\n```\n\n</details><br>\n\n#### CUDA\n\nLibTorch 2.9.0 currently includes binary distributions with CUDA 12.6, 12.8 or 13.0 runtimes. The\nmanual installation instructions are detailed below for CUDA 12.6, but can be applied to the other\nCUDA versions by replacing `cu126` with the corresponding version string (e.g., `cu130`).\n\n<details open>\n<summary><strong>🐧 Linux</strong></summary>\n\nFirst, download the LibTorch CUDA 12.6 distribution.\n\n```shell\nwget -O libtorch.zip https://download.pytorch.org/libtorch/cu126/libtorch-shared-with-deps-2.9.0%2Bcu126.zip\nunzip libtorch.zip\n```\n\nThen, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH` environment variables\nbefore building `burn-tch` or a crate which depends on it.\n\n```shell\nexport LIBTORCH=/absolute/path/to/libtorch/\nexport LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH\n```\n\n**Note:** make sure your CUDA installation is in your `PATH` and `LD_LIBRARY_PATH`.\n\n</details><br>\n\n<details>\n<summary><strong>🪟 Windows</strong></summary>\n\nFirst, download the LibTorch CUDA 12.6 distribution.\n\n```powershell\nwget https://download.pytorch.org/libtorch/cu126/libtorch-win-shared-with-deps-2.9.0%2Bcu126.zip -OutFile libtorch.zip\nExpand-Archive libtorch.zip\n```\n\nThen, set the `LIBTORCH` environment variable and append the library to your path as with the\nPowerShell commands below before building `burn-tch` or a crate which depends on it.\n\n```powershell\n$Env:LIBTORCH = \"/absolute/path/to/libtorch/\"\n$Env:Path += \";/absolute/path/to/libtorch/\"\n```\n\n</details><br>\n\n#### Metal (MPS)\n\nThere is no official LibTorch distribution with MPS support at this time, so the easiest alternative\nis to use a PyTorch installation. This requires a Python installation.\n\n_Note: MPS acceleration is available on MacOS 12.3+._\n\n```shell\npip install torch==2.9.0 numpy==1.26.4 setuptools\nexport LIBTORCH_USE_PYTORCH=1\nexport DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH\n```\n\n**Note:** if `venv` is used, it should be activated during coding and building, or the compiler may\nnot work properly.\n\n## Example Usage\n\nFor a simple example, check out any of the test programs in [`src/bin/`](./src/bin/). Each program\nsets the device to use and performs a simple element-wise addition.\n\nFor a more complete example using the `tch` backend, take a loot at the\n[Burn mnist example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).\n\n## Too many environment variables?\n\nTry `.cargo/config.toml` ([cargo book](https://doc.rust-lang.org/cargo/reference/config.html#env)).\n\nInstead of setting the environments in your shell, you can manually add them to your\n`.cargo/config.toml`:\n\n```toml\n[env]\nLD_LIBRARY_PATH = \"/absolute/path/to/libtorch/lib\"\nLIBTORCH = \"/absolute/path/to/libtorch/libtorch\"\n```\n\nOr use bash commands below:\n\n```bash\nmkdir .cargo\ncat <<EOF > .cargo/config.toml\n[env]\nLD_LIBRARY_PATH = \"/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH\"\nLIBTORCH = \"/absolute/path/to/libtorch/libtorch\"\nEOF\n```\n\nThis will automatically include the old `LD_LIBRARY_PATH` value in the new one.\n"
  },
  {
    "path": "crates/burn-tch/build.rs",
    "content": "// The LIBTORCH environment variable can be used to specify the directory\n// where libtorch has been installed.\n// When not specified this script downloads the cpu version for libtorch\n// and extracts it in OUT_DIR.\n//\n// On Linux, the TORCH_CUDA_VERSION environment variable can be used,\n// like 9.0, 90, or cu90 to specify the version of CUDA to use for libtorch.\n\nuse std::path::{Path, PathBuf};\nuse std::{env, fs};\n\nconst PYTHON_PRINT_PYTORCH_DETAILS: &str = r\"\nimport torch\nfrom torch.utils import cpp_extension\nprint('LIBTORCH_VERSION:', torch.__version__.split('+')[0])\nprint('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI)\nfor include_path in cpp_extension.include_paths():\n  print('LIBTORCH_INCLUDE:', include_path)\nfor library_path in cpp_extension.library_paths():\n  print('LIBTORCH_LIB:', library_path)\n\";\n\n#[derive(Debug, Clone, Copy, PartialEq, Eq)]\nenum Os {\n    Linux,\n    Macos,\n    Windows,\n}\n\n#[allow(dead_code)]\n#[derive(Debug, Clone)]\nstruct SystemInfo {\n    os: Os,\n    cxx11_abi: String,\n    libtorch_include_dirs: Vec<PathBuf>,\n    libtorch_lib_dir: PathBuf,\n}\n\nfn env_var_rerun(name: &str) -> Result<String, env::VarError> {\n    println!(\"cargo:rerun-if-env-changed={name}\");\n    env::var(name)\n}\n\nimpl SystemInfo {\n    fn new() -> Option<Self> {\n        let os = match env::var(\"CARGO_CFG_TARGET_OS\")\n            .expect(\"Unable to get TARGET_OS\")\n            .as_str()\n        {\n            \"linux\" => Os::Linux,\n            \"windows\" => Os::Windows,\n            \"macos\" => Os::Macos,\n            os => panic!(\"unsupported TARGET_OS '{os}'\"),\n        };\n        // Locate the currently active Python binary, similar to:\n        // https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547\n        let python_interpreter = match os {\n            Os::Windows => PathBuf::from(\"python.exe\"),\n            Os::Linux | Os::Macos => {\n                if env::var_os(\"VIRTUAL_ENV\").is_some() {\n                    PathBuf::from(\"python\")\n                } else {\n                    PathBuf::from(\"python3\")\n                }\n            }\n        };\n        let mut libtorch_include_dirs = vec![];\n        let mut libtorch_lib_dir = None;\n        let cxx11_abi = if env_var_rerun(\"LIBTORCH_USE_PYTORCH\").is_ok() {\n            let output = std::process::Command::new(&python_interpreter)\n                .arg(\"-c\")\n                .arg(PYTHON_PRINT_PYTORCH_DETAILS)\n                .output()\n                .expect(\"error running python interpreter\");\n            let mut cxx11_abi = None;\n            for line in String::from_utf8_lossy(&output.stdout).lines() {\n                match line.strip_prefix(\"LIBTORCH_CXX11: \") {\n                    Some(\"True\") => cxx11_abi = Some(\"1\".to_owned()),\n                    Some(\"False\") => cxx11_abi = Some(\"0\".to_owned()),\n                    _ => {}\n                }\n                if let Some(path) = line.strip_prefix(\"LIBTORCH_INCLUDE: \") {\n                    libtorch_include_dirs.push(PathBuf::from(path))\n                }\n                if let Some(path) = line.strip_prefix(\"LIBTORCH_LIB: \") {\n                    libtorch_lib_dir = Some(PathBuf::from(path))\n                }\n            }\n            match cxx11_abi {\n                Some(cxx11_abi) => cxx11_abi,\n                None => panic!(\"no cxx11 abi returned by python {output:?}\"),\n            }\n        } else {\n            let libtorch = Self::prepare_libtorch_dir(os)?;\n            let includes = env_var_rerun(\"LIBTORCH_INCLUDE\")\n                .map(PathBuf::from)\n                .unwrap_or_else(|_| libtorch.clone());\n            let lib = env_var_rerun(\"LIBTORCH_LIB\")\n                .map(PathBuf::from)\n                .unwrap_or_else(|_| libtorch.clone());\n            libtorch_include_dirs.push(includes.join(\"include\"));\n            libtorch_include_dirs.push(includes.join(\"include/torch/csrc/api/include\"));\n            if lib.ends_with(\"lib\") {\n                // DEP_TCH_LIBTORCH_LIB might already point to /lib\n                libtorch_lib_dir = Some(lib);\n            } else {\n                libtorch_lib_dir = Some(lib.join(\"lib\"));\n            }\n            env_var_rerun(\"LIBTORCH_CXX11_ABI\").unwrap_or_else(|_| \"1\".to_owned())\n        };\n        let libtorch_lib_dir = libtorch_lib_dir?;\n        Some(Self {\n            os,\n            cxx11_abi,\n            libtorch_include_dirs,\n            libtorch_lib_dir,\n        })\n    }\n\n    fn check_system_location(os: Os) -> Option<PathBuf> {\n        match os {\n            Os::Linux => Path::new(\"/usr/lib/libtorch.so\")\n                .exists()\n                .then(|| PathBuf::from(\"/usr\")),\n            _ => None,\n        }\n    }\n\n    fn prepare_libtorch_dir(os: Os) -> Option<PathBuf> {\n        if let Ok(libtorch) = env_var_rerun(\"DEP_TCH_LIBTORCH_LIB\") {\n            Some(PathBuf::from(libtorch))\n        } else if let Ok(libtorch) = env_var_rerun(\"LIBTORCH\") {\n            Some(PathBuf::from(libtorch))\n        } else if let Some(pathbuf) = Self::check_system_location(os) {\n            Some(pathbuf)\n        } else {\n            check_out_dir()\n        }\n    }\n\n    fn make(&self, use_cuda: bool, use_hip: bool) {\n        let cuda_dependency = if use_cuda || use_hip {\n            \"src/cuda_hack/dummy_cuda_dependency.cpp\"\n        } else {\n            \"src/cuda_hack/fake_cuda_dependency.cpp\"\n        };\n        println!(\"cargo:rerun-if-changed={cuda_dependency}\");\n\n        match self.os {\n            Os::Linux | Os::Macos => {\n                cc::Build::new()\n                    .cpp(true)\n                    .pic(true)\n                    .warnings(false)\n                    .includes(&self.libtorch_include_dirs)\n                    .flag(format!(\"-Wl,-rpath={}\", self.libtorch_lib_dir.display()))\n                    .flag(\"-std=c++17\")\n                    .flag(format!(\"-D_GLIBCXX_USE_CXX11_ABI={}\", self.cxx11_abi))\n                    .files(&[cuda_dependency])\n                    .compile(\"burn-tch\");\n            }\n            Os::Windows => {\n                cc::Build::new()\n                    .cpp(true)\n                    .pic(true)\n                    .warnings(false)\n                    .includes(&self.libtorch_include_dirs)\n                    .flag(\"/std:c++17\")\n                    .files(&[cuda_dependency])\n                    .compile(\"burn-tch\");\n            }\n        };\n    }\n\n    fn make_cpu() {\n        let cuda_dependency = \"src/cuda_hack/fake_cuda_dependency.cpp\";\n        println!(\"cargo:rerun-if-changed={cuda_dependency}\");\n\n        let os = env::var(\"CARGO_CFG_TARGET_OS\").expect(\"Unable to get TARGET_OS\");\n\n        match os.as_str() {\n            \"windows\" => {\n                cc::Build::new()\n                    .cpp(true)\n                    .pic(true)\n                    .warnings(false)\n                    .flag(\"/std:c++17\")\n                    .files(&[cuda_dependency])\n                    .compile(\"burn-tch\");\n            }\n            _ => {\n                cc::Build::new()\n                    .cpp(true)\n                    .pic(true)\n                    .warnings(false)\n                    .flag(\"-std=c++17\")\n                    .files(&[cuda_dependency])\n                    .compile(\"tch\");\n            }\n        };\n    }\n}\n\nfn check_out_dir() -> Option<PathBuf> {\n    let out_dir = env_var_rerun(\"OUT_DIR\").ok()?;\n    let libtorch_dir = PathBuf::from(out_dir).join(\"libtorch\");\n    libtorch_dir.exists().then_some(libtorch_dir)\n}\n\nfn main() {\n    let system_info = SystemInfo::new();\n    let out_dir = env_var_rerun(\"OUT_DIR\").expect(\"Failed to get out dir\");\n\n    let mut gpu_found = false;\n    let found_dir = system_info.is_some();\n    if let Some(system_info) = &system_info {\n        let si_lib = &system_info.libtorch_lib_dir;\n        let use_cuda =\n            si_lib.join(\"libtorch_cuda.so\").exists() || si_lib.join(\"torch_cuda.dll\").exists();\n        let use_hip =\n            si_lib.join(\"libtorch_hip.so\").exists() || si_lib.join(\"torch_hip.dll\").exists();\n\n        system_info.make(use_cuda, use_hip);\n        gpu_found = use_cuda || use_hip;\n    } else {\n        SystemInfo::make_cpu();\n    }\n    let check_file = PathBuf::from(out_dir).join(\"tch_gpu_check.rs\");\n    if gpu_found {\n        fs::write(check_file, \"#[allow(clippy::no_effect)]\\n()\").unwrap();\n    } else {\n        let message = if !found_dir {\n            r#\"Could not find libtorch dir.\n\n        If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions.\n\n        If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it).\"#\n        } else {\n            \"No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device\"\n        };\n        fs::write(check_file, format!(\"panic!(\\\"{message}\\\")\")).unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/backend.rs",
    "content": "use std::marker::PhantomData;\n\nuse crate::IntoKind;\n\nuse super::TchTensor;\nuse super::element::TchElement;\nuse burn_backend::backend::{Backend, DeviceId, DeviceOps, ExecutionError};\nuse burn_backend::ops::IntTensorOps;\n\n#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n/// The device struct when using the `tch` backend.\n///\n/// Note that you need to provide the device index when using Cuda.\n///\n/// # Example\n///\n/// ```no_run\n/// use burn_tch::LibTorchDevice;\n///\n/// let device_gpu_1 = LibTorchDevice::Cuda(0); // First GPU\n/// let device_gpu_2 = LibTorchDevice::Cuda(1); // Second GPU\n/// let device_cpu = LibTorchDevice::Cpu; // CPU\n/// let device_mps = LibTorchDevice::Mps; // Metal Performance Shaders\n/// let device_vulkan = LibTorchDevice::Vulkan; // Vulkan\n/// ```\n#[derive(Default)]\npub enum LibTorchDevice {\n    /// CPU device.\n    #[default]\n    Cpu,\n\n    /// Cuda device with the given index. The index is the index of the Cuda device in the list of\n    /// all Cuda devices found on the system.\n    Cuda(usize),\n\n    /// Metal Performance Shaders device.\n    Mps,\n\n    /// Vulkan device.\n    Vulkan,\n}\n\nimpl From<LibTorchDevice> for tch::Device {\n    #[allow(\n        unreachable_code,\n        reason = \"CUDA branch always panics if the library is missing\"\n    )]\n    fn from(device: LibTorchDevice) -> Self {\n        match device {\n            LibTorchDevice::Cpu => tch::Device::Cpu,\n            LibTorchDevice::Cuda(_num) => {\n                include!(concat!(env!(\"OUT_DIR\"), \"/tch_gpu_check.rs\"));\n                tch::Device::Cuda(_num)\n            }\n            LibTorchDevice::Mps => tch::Device::Mps,\n            LibTorchDevice::Vulkan => tch::Device::Vulkan,\n        }\n    }\n}\n\nimpl From<tch::Device> for LibTorchDevice {\n    fn from(device: tch::Device) -> Self {\n        match device {\n            tch::Device::Cpu => LibTorchDevice::Cpu,\n            tch::Device::Cuda(num) => LibTorchDevice::Cuda(num),\n            tch::Device::Mps => LibTorchDevice::Mps,\n            tch::Device::Vulkan => LibTorchDevice::Vulkan,\n        }\n    }\n}\n\nimpl burn_backend::Device for LibTorchDevice {\n    fn from_id(device_id: DeviceId) -> Self {\n        match device_id.type_id {\n            0 => Self::Cuda(device_id.index_id as usize),\n            1 => Self::Mps,\n            2 => Self::Cpu,\n            3 => Self::Vulkan,\n            _ => LibTorchDevice::Cpu,\n        }\n    }\n\n    fn to_id(&self) -> DeviceId {\n        match self {\n            LibTorchDevice::Cuda(index) => DeviceId::new(0, *index as u32),\n            LibTorchDevice::Mps => DeviceId::new(1, 0),\n            LibTorchDevice::Cpu => DeviceId::new(2, 0),\n            LibTorchDevice::Vulkan => DeviceId::new(3, 0),\n        }\n    }\n\n    fn device_count(_type_id: u16) -> usize {\n        // TODO: Somehow find the info using the tch API.\n        1\n    }\n}\n\nimpl DeviceOps for LibTorchDevice {}\n\n/// Tensor backend that uses `LibTorch` with the [tch] crate for executing tensor operations.\n///\n/// This backend is compatible with a wide range of hardwares ranging from CPUs to GPUs, but\n/// requires `LibTorch` to be installed correctly. The CPU version can be downloaded\n/// automatically and the CUDA version as well by setting the `TORCH_CUDA_VERSION` environment\n/// variable. For more complex configurations, check out the manual installation for\n/// [burn-tch](https://github.com/tracel-ai/burn/tree/main/crates/burn-tch).\n///\n/// Refer to the [tch] crate for more information.\n#[derive(Clone, Copy, Default, Debug)]\npub struct LibTorch<E = f32> {\n    _e: PhantomData<E>,\n}\n\nimpl<E: TchElement> Backend for LibTorch<E> {\n    type Device = LibTorchDevice;\n\n    type FloatTensorPrimitive = TchTensor;\n    type FloatElem = E;\n\n    type IntTensorPrimitive = TchTensor;\n    type IntElem = i64;\n    type BoolTensorPrimitive = TchTensor;\n    type BoolElem = bool;\n\n    type QuantizedTensorPrimitive = TchTensor;\n\n    fn seed(_device: &Self::Device, seed: u64) {\n        tch::manual_seed(seed as i64);\n    }\n\n    fn ad_enabled(_device: &Self::Device) -> bool {\n        false\n    }\n\n    fn name(device: &Self::Device) -> String {\n        match device {\n            LibTorchDevice::Cpu => \"libtorch<cpu>\",\n            LibTorchDevice::Cuda(_) => \"libtorch<cuda>\",\n            LibTorchDevice::Mps => \"libtorch<metal>\",\n            LibTorchDevice::Vulkan => \"libtorch<vulkan>\",\n        }\n        .to_string()\n    }\n\n    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {\n        match device {\n            LibTorchDevice::Cpu => (),\n            LibTorchDevice::Cuda(index) => {\n                tch::Cuda::synchronize(*index as i64);\n            }\n            _ => {\n                // When there is no explicit way to synchronize, we write and read one value to sync\n                burn_backend::read_sync(Self::int_into_data(Self::int_zeros(\n                    [1].into(),\n                    device,\n                    E::dtype().into(),\n                )))\n                .unwrap();\n            }\n        };\n\n        Ok(())\n    }\n\n    fn dtype_usage(\n        _device: &Self::Device,\n        dtype: burn_backend::DType,\n    ) -> burn_backend::DTypeUsageSet {\n        if dtype.try_into_kind().is_ok() {\n            burn_backend::DTypeUsage::general()\n        } else {\n            burn_backend::DTypeUsageSet::empty()\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/bin/cpu.rs",
    "content": "use burn_backend::{TensorMetadata, ops::FloatTensorOps};\nuse burn_tch::{LibTorch, LibTorchDevice};\n\nfn main() {\n    type B = LibTorch<f32>;\n    let device = LibTorchDevice::Cpu;\n\n    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first\n    let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);\n    let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());\n\n    // Print the element-wise addition of the two tensors.\n    println!(\"{}\", B::float_add(tensor_1, tensor_2));\n}\n"
  },
  {
    "path": "crates/burn-tch/src/bin/cuda.rs",
    "content": "use burn_backend::{TensorMetadata, ops::FloatTensorOps};\nuse burn_tch::{LibTorch, LibTorchDevice};\n\nfn main() {\n    assert!(\n        tch::utils::has_cuda(),\n        \"Could not detect valid CUDA configuration\"\n    );\n\n    type B = LibTorch<f32>;\n    let device = LibTorchDevice::Cuda(0);\n\n    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first\n    let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);\n    let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());\n\n    // Print the element-wise addition of the two tensors.\n    println!(\"{}\", B::float_add(tensor_1, tensor_2));\n}\n"
  },
  {
    "path": "crates/burn-tch/src/bin/mps.rs",
    "content": "use burn_backend::{TensorMetadata, ops::FloatTensorOps};\nuse burn_tch::{LibTorch, LibTorchDevice};\n\nfn main() {\n    assert!(tch::utils::has_mps(), \"Could not detect MPS\");\n\n    type B = LibTorch<f32>;\n    let device = LibTorchDevice::Mps;\n\n    // Creation of two tensors, the first with explicit values and the second one with ones, with the same shape as the first\n    let tensor_1 = B::float_from_data([[2f32, 3.], [4., 5.]].into(), &device);\n    let tensor_2 = B::float_ones(tensor_1.shape(), &device, tensor_1.dtype().into());\n\n    // Print the element-wise addition of the two tensors.\n    println!(\"{}\", B::float_add(tensor_1, tensor_2));\n}\n"
  },
  {
    "path": "crates/burn-tch/src/cuda_hack/dummy_cuda_dependency.cpp",
    "content": "#include <iostream>\n#include <stdexcept>\n#include <stdint.h>\n#include <stdio.h>\nusing namespace std;\nextern \"C\" {\nvoid dummy_cuda_dependency();\n}\n\nstruct cublasContext;\n\nnamespace at {\nnamespace cuda {\ncublasContext *getCurrentCUDABlasHandle();\nint warp_size();\n} // namespace cuda\n} // namespace at\nchar *magma_strerror(int err);\nvoid dummy_cuda_dependency() {\n  try {\n    at::cuda::getCurrentCUDABlasHandle();\n    at::cuda::warp_size();\n  } catch (std::exception &e) {\n    if (getenv(\"TCH_PRINT_CUDA_INIT_ERROR\") != nullptr) {\n      std::cerr << \"error initializing cuda: \" << e.what() << std::endl;\n    }\n  }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/cuda_hack/fake_cuda_dependency.cpp",
    "content": "extern \"C\" {\nvoid dummy_cuda_dependency();\n}\n\nvoid dummy_cuda_dependency() {}\n"
  },
  {
    "path": "crates/burn-tch/src/element.rs",
    "content": "use burn_backend::Element;\nuse burn_backend::{bf16, f16};\n\n/// The element type for the tch backend.\npub trait TchElement: Element + tch::kind::Element {\n    /// Returns the associated tensor kind for [`tch::kind::Element`].\n    fn kind() -> tch::Kind {\n        Self::KIND\n    }\n}\n\nimpl TchElement for f64 {}\nimpl TchElement for f32 {}\nimpl TchElement for f16 {}\nimpl TchElement for bf16 {\n    fn kind() -> tch::Kind {\n        let mut kind = <Self as tch::kind::Element>::KIND;\n        // Incorrect kind mapping in tch definitions, force bfloat16\n        if matches!(Self::dtype(), burn_backend::DType::BF16) && kind == tch::Kind::Half {\n            kind = tch::Kind::BFloat16\n        }\n        kind\n    }\n}\n\nimpl TchElement for i64 {}\nimpl TchElement for i32 {}\nimpl TchElement for i16 {}\nimpl TchElement for i8 {}\n\nimpl TchElement for u8 {}\n\nimpl TchElement for bool {}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_elem_kinds() {\n        assert_eq!(f64::kind(), tch::Kind::Double);\n        assert_eq!(f32::kind(), tch::Kind::Float);\n        assert_eq!(f16::kind(), tch::Kind::Half);\n        assert_eq!(bf16::kind(), tch::Kind::BFloat16);\n        assert_eq!(i64::kind(), tch::Kind::Int64);\n        assert_eq!(i32::kind(), tch::Kind::Int);\n        assert_eq!(i16::kind(), tch::Kind::Int16);\n        assert_eq!(i8::kind(), tch::Kind::Int8);\n        assert_eq!(bool::kind(), tch::Kind::Bool);\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n#![allow(clippy::single_range_in_vec_init)]\n\n//! Burn Tch Backend\n\nmod backend;\nmod element;\nmod ops;\nmod tensor;\n\npub use backend::*;\npub use element::*;\npub use tensor::*;\n"
  },
  {
    "path": "crates/burn-tch/src/ops/activation.rs",
    "content": "use crate::{LibTorch, TchTensor, element::TchElement};\nuse burn_backend::ops::ActivationOps;\n\nimpl<E: TchElement> ActivationOps<Self> for LibTorch<E> {\n    fn relu(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())\n    }\n\n    fn gelu(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.gelu_(\"none\"),\n            |tensor| tensor.gelu(\"none\"),\n        )\n    }\n\n    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.gelu_backward(&grad.tensor, \"none\");\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    fn sigmoid(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())\n    }\n\n    fn log_sigmoid(tensor: TchTensor) -> TchTensor {\n        // NOTE: we don't override log_sigmoid_backward because Torch has a special backward\n        // formula that uses a buffer with computed values from the forward pass\n\n        // no in-place log_sigmoid_\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.log_sigmoid();\n\n        TchTensor::from_existing(tensor, storage)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/base.rs",
    "content": "use burn_backend::{Shape, TensorMetadata};\nuse tch::Scalar;\n\nuse crate::{LibTorchDevice, TchShape, TchTensor};\n\npub struct TchOps {\n    // e: PhantomData<E>,\n}\n\nimpl TchOps {\n    pub fn to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {\n        let device = (*device).into();\n\n        // We have to manually check if the device is the same, since when it's the case, we need to keep\n        // the same storage reference and not create a new one.\n        if tensor.tensor.device() == device {\n            return tensor;\n        }\n\n        TchTensor::new(tensor.tensor.to(device))\n    }\n\n    pub fn reshape(tensor: TchTensor, shape: Shape) -> TchTensor {\n        let shape_tch: TchShape = shape.into();\n\n        TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)\n    }\n\n    pub fn repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {\n        let mut dims = vec![1; tensor.shape().num_dims()];\n        dims[dim] = times as i64;\n        let tensor = tch::Tensor::repeat(&tensor.tensor, dims);\n        TchTensor::new(tensor)\n    }\n\n    pub fn slice_with_steps(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let mut tensor = tensor.tensor.shallow_clone();\n\n        for (dim, slice) in slices.iter().enumerate() {\n            let dim_i64 = dim as i64;\n            // Convert slice to range using a dummy size (we'll use tensor dimensions)\n            let dim_size = tensor.size()[dim];\n            let range = slice.to_range(dim_size as usize);\n            let start = range.start as i64;\n            let end = range.end as i64;\n            let step = slice.step as i64;\n\n            if step > 0 {\n                // Forward stepping - use native slice\n                tensor = tensor.slice(dim_i64, Some(start), Some(end), step);\n            } else {\n                // Negative stepping - we need to handle the semantics correctly\n                // For negative steps, we iterate backwards from end-1\n                // PyTorch's negative step works differently than our semantics\n                // We need to reverse the selected range\n\n                // First get the slice with positive step\n                tensor = tensor.slice(dim_i64, Some(start), Some(end), 1);\n\n                // Then reverse it and apply the step\n                if step == -1 {\n                    // Simple reversal\n                    tensor = tensor.flip([dim_i64]);\n                } else {\n                    // Reverse and then take every nth element\n                    tensor = tensor.flip([dim_i64]);\n                    let abs_step = step.abs();\n                    tensor = tensor.slice(dim_i64, None, None, abs_step);\n                }\n            }\n        }\n\n        TchTensor::partial(tensor, storage)\n    }\n\n    pub fn slice_assign(\n        tensor: TchTensor,\n        slices: &[burn_backend::Slice],\n        value: TchTensor,\n    ) -> TchTensor {\n        // PyTorch's narrow operation only supports contiguous slices (step=1)\n        // For non-unit steps, we use advanced indexing as a workaround\n        let all_unit_steps = slices.iter().all(|s| s.step == 1);\n\n        if all_unit_steps {\n            // Fast path: use narrow and copy_ for unit steps\n            let tch_shape = TchShape::from(tensor.shape());\n\n            // Copy the input tensor if we can't mutate it\n            let tensor_original: TchTensor =\n                tensor.unary_ops(|tensor| tensor, |tensor| tensor.copy());\n            let tensor_original = tensor_original.tensor;\n\n            let mut tensor = tensor_original.view_(tch_shape.dims);\n\n            for (i, slice) in slices.iter().enumerate().take(slices.len()) {\n                // Convert Slice to range for narrow operation\n                let dim_size = tensor.size()[i] as usize;\n                let range = slice.to_range(dim_size);\n                let start = range.start as i64;\n                let length = (range.end - range.start) as i64;\n\n                tensor = tensor.narrow(i as i64, start, length);\n            }\n\n            tensor.copy_(&value.tensor);\n            TchTensor::new(tensor_original)\n        } else {\n            // Workaround for non-unit steps: use PyTorch's index_put operation\n            // This generates explicit indices for the slice and uses advanced indexing\n            let tensor_shape = tensor.shape();\n            let dims = tensor_shape.clone();\n\n            // Copy the tensor since we'll modify it\n            let result_tensor = tensor.tensor.shallow_clone();\n\n            // Use advanced indexing to set the values\n            Self::slice_assign_with_advanced_indexing(result_tensor, slices, value.tensor, &dims)\n        }\n    }\n\n    /// Generate indices for a slice with potentially non-unit step.\n    /// For negative steps, generates indices in reverse order.\n    fn generate_slice_indices(slice: &burn_backend::Slice, dim_size: usize) -> Vec<i64> {\n        let step = slice.step;\n        let range = slice.to_range(dim_size);\n\n        let mut indices = Vec::new();\n\n        if step > 0 {\n            let mut idx = range.start as i64;\n            while idx < range.end as i64 {\n                indices.push(idx);\n                idx += step as i64;\n            }\n        } else if step < 0 {\n            // For negative steps, iterate backwards through the range\n            let mut idx = (range.end - 1) as i64;\n            while idx >= range.start as i64 {\n                indices.push(idx);\n                idx += step as i64; // step is negative, so this decreases\n            }\n        }\n\n        indices\n    }\n\n    /// Implementation using advanced indexing for non-unit steps.\n    /// Uses PyTorch's index_put operation to assign values at specific indices.\n    fn slice_assign_with_advanced_indexing(\n        mut tensor: tch::Tensor,\n        slices: &[burn_backend::Slice],\n        value: tch::Tensor,\n        dims: &[usize],\n    ) -> TchTensor {\n        // Generate all index combinations for the sliced regions\n        let mut index_sets: Vec<Vec<i64>> = Vec::new();\n        for (i, slice) in slices.iter().enumerate() {\n            let dim_size = if i < dims.len() { dims[i] } else { 1 };\n            let indices = Self::generate_slice_indices(slice, dim_size);\n            index_sets.push(indices);\n        }\n\n        // For unsliced dimensions, include all indices\n        for &dim_size in dims.iter().skip(slices.len()) {\n            let indices: Vec<i64> = (0..dim_size as i64).collect();\n            index_sets.push(indices);\n        }\n\n        // Convert index sets to tensors for index_put\n        let mut final_indices = Vec::new();\n        let total_elements = index_sets.iter().map(|s| s.len()).product::<usize>();\n\n        // Build flattened index arrays for each dimension using cartesian product\n        // This creates the index tensors needed for PyTorch's index_put operation\n        for dim_idx in 0..index_sets.len() {\n            let mut dim_indices = Vec::with_capacity(total_elements);\n            let repeat = index_sets[dim_idx + 1..]\n                .iter()\n                .map(|s| s.len())\n                .product::<usize>()\n                .max(1);\n            let tile = index_sets[..dim_idx]\n                .iter()\n                .map(|s| s.len())\n                .product::<usize>()\n                .max(1);\n\n            for _ in 0..tile {\n                for &idx in &index_sets[dim_idx] {\n                    for _ in 0..repeat {\n                        dim_indices.push(idx);\n                    }\n                }\n            }\n\n            let indices_tensor = tch::Tensor::from_slice(&dim_indices).to_device(tensor.device());\n            final_indices.push(indices_tensor);\n        }\n\n        // PyTorch's index_put handles assignment correctly for negative steps\n        // following NumPy semantics: values[i] goes to selected_indices[i]\n        let value_flat = value.view(-1);\n\n        // Use index_put to assign values - convert to Option<Tensor>\n        let final_indices_opt: Vec<Option<tch::Tensor>> =\n            final_indices.into_iter().map(Some).collect();\n        tensor = tensor.index_put(&final_indices_opt, &value_flat, false);\n\n        TchTensor::new(tensor)\n    }\n\n    pub fn gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn scatter(\n        dim: usize,\n        tensor: TchTensor,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor\n            .tensor\n            .scatter_add(dim as i64, &indices.tensor, &value.tensor);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn index_select_dim(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn select_assign(\n        tensor: TchTensor,\n        dim: usize,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        tensor.clone().unary_ops(\n            |mut tensor| tensor.index_add_(dim as i64, &indices.tensor, &value.tensor),\n            |tensor| tensor.index_add(dim as i64, &indices.tensor, &value.tensor),\n        )\n    }\n\n    pub fn cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {\n        let tensors: Vec<tch::Tensor> = tensors\n            .into_iter()\n            .map(|t| t.tensor.shallow_clone())\n            .collect();\n        let tensor = tch::Tensor::cat(&tensors, dim as i64);\n\n        TchTensor::new(tensor)\n    }\n\n    pub fn equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| lhs.eq_tensor(rhs),\n        )\n    }\n\n    pub fn equal_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.eq_(rhs.clone().into()).to_kind(tch::Kind::Bool),\n            |tensor| tensor.eq(rhs.clone().into()),\n        )\n    }\n\n    pub fn greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.greater_tensor_(rhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| rhs.less_tensor_(lhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| lhs.greater_tensor(rhs),\n        )\n    }\n\n    pub fn greater_elem<S: Into<tch::Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.greater_(rhs.clone().into()).to_kind(tch::Kind::Bool),\n            |tensor| tensor.greater(rhs.clone().into()),\n        )\n    }\n\n    pub fn greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.greater_equal_tensor_(rhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| rhs.less_equal_tensor_(lhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| lhs.greater_equal_tensor(rhs),\n        )\n    }\n\n    pub fn greater_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| {\n                tensor\n                    .greater_equal_(rhs.clone().into())\n                    .to_kind(tch::Kind::Bool)\n            },\n            |tensor| tensor.greater_equal(rhs.clone().into()),\n        )\n    }\n\n    pub fn lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.less_tensor_(rhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| rhs.greater_tensor_(lhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| lhs.less_tensor(rhs),\n        )\n    }\n\n    pub fn lower_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.less_(rhs.clone().into()).to_kind(tch::Kind::Bool),\n            |tensor| tensor.less(rhs.clone().into()),\n        )\n    }\n\n    pub fn lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.less_equal_tensor_(rhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| rhs.greater_equal_tensor_(lhs).to_kind(tch::Kind::Bool),\n            |lhs, rhs| lhs.less_equal_tensor(rhs),\n        )\n    }\n\n    pub fn lower_equal_elem<S: Into<Scalar> + Clone>(lhs: TchTensor, rhs: S) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| {\n                tensor\n                    .less_equal_(rhs.clone().into())\n                    .to_kind(tch::Kind::Bool)\n            },\n            |tensor| tensor.less_equal(rhs.clone().into()),\n        )\n    }\n\n    pub fn add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_add_(rhs).unwrap(),\n            |lhs, rhs| rhs.f_add_(lhs).unwrap(),\n            |lhs, rhs| lhs.f_add(rhs).unwrap(),\n        )\n    }\n\n    pub fn sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_sub_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_sub(rhs).unwrap(),\n            |lhs, rhs| lhs.f_sub(rhs).unwrap(),\n        )\n    }\n\n    pub fn mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_mul_(rhs).unwrap(),\n            |lhs, rhs| rhs.f_mul_(lhs).unwrap(),\n            |lhs, rhs| lhs.f_mul(rhs).unwrap(),\n        )\n    }\n\n    pub fn div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_div_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_div(rhs).unwrap(),\n            |lhs, rhs| lhs.f_div(rhs).unwrap(),\n        )\n    }\n\n    pub fn remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_remainder_tensor_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),\n            |lhs, rhs| lhs.f_remainder_tensor(rhs).unwrap(),\n        )\n    }\n\n    pub fn mean(tensor: TchTensor) -> TchTensor {\n        // view as 1d tensor\n        let tensor = tensor.tensor.mean(tensor.tensor.kind()).view(1);\n        TchTensor::new(tensor)\n    }\n\n    pub fn mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchTensor::from_existing(\n            tensor\n                .tensor\n                .mean_dim(Some([dim as i64].as_slice()), true, tensor.tensor.kind()),\n            tensor.storage,\n        )\n    }\n\n    pub fn sum(tensor: TchTensor) -> TchTensor {\n        // view as 1d tensor\n        let tensor = tensor.tensor.sum(tensor.tensor.kind()).view(1);\n        TchTensor::new(tensor)\n    }\n\n    pub fn sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchTensor::from_existing(\n            tensor.tensor.sum_dim_intlist(\n                Some([dim as i64].as_slice()),\n                true,\n                tensor.tensor.kind(),\n            ),\n            tensor.storage,\n        )\n    }\n\n    pub fn prod(tensor: TchTensor) -> TchTensor {\n        // view as 1d tensor\n        let tensor = tensor.tensor.prod(tensor.tensor.kind()).view(1);\n        TchTensor::new(tensor)\n    }\n\n    pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchTensor::from_existing(\n            tensor\n                .tensor\n                .prod_dim_int(dim as i64, true, tensor.tensor.kind()),\n            tensor.storage,\n        )\n    }\n\n    pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchTensor::from_existing(\n            tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),\n            tensor.storage,\n        )\n    }\n\n    pub fn cumprod(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchTensor::from_existing(\n            tensor.tensor.cumprod(dim as i64, tensor.tensor.kind()),\n            tensor.storage,\n        )\n    }\n\n    pub fn cummin(tensor: TchTensor, dim: usize) -> TchTensor {\n        let (values, _indices) = tensor.tensor.cummin(dim as i64);\n        TchTensor::from_existing(values, tensor.storage)\n    }\n\n    pub fn cummax(tensor: TchTensor, dim: usize) -> TchTensor {\n        // cummax returns (values, indices) tuple in PyTorch, we only need values\n        let (values, _indices) = tensor.tensor.cummax(dim as i64);\n        TchTensor::from_existing(values, tensor.storage)\n    }\n\n    pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.argmax(dim as i64, true);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn argmin(tensor: TchTensor, dim: usize) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let tensor = tensor.tensor.argmin(dim as i64, true);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn max_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        let storage = tensor.storage.clone();\n        let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true);\n\n        let tensor = TchTensor::from_existing(tensor, storage);\n        let indices = TchTensor::new(indices);\n\n        (tensor, indices)\n    }\n\n    pub fn min_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n\n    pub fn min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        let storage = tensor.storage.clone();\n        let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true);\n\n        let tensor = TchTensor::from_existing(tensor, storage);\n        let indices = TchTensor::new(indices);\n\n        (tensor, indices)\n    }\n\n    pub fn clamp_min<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, min: S) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.clamp_min_(min),\n            |tensor| tensor.clamp_min(min),\n        )\n    }\n\n    pub fn clamp_max<S: Into<tch::Scalar> + Clone + Copy>(tensor: TchTensor, max: S) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.clamp_max_(max),\n            |tensor| tensor.clamp_max(max),\n        )\n    }\n\n    pub fn clamp<S: Into<tch::Scalar> + Clone + Copy>(\n        tensor: TchTensor,\n        min: S,\n        max: S,\n    ) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.clamp_(min, max),\n            |tensor| tensor.clamp(min, max),\n        )\n    }\n\n    pub fn swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {\n        let tensor = tensor.tensor.transpose(dim1 as i64, dim2 as i64);\n        TchTensor::new(tensor)\n    }\n\n    pub fn permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        let tensor = tensor\n            .tensor\n            .permute(axes.iter().map(|x| *x as i64).collect::<Vec<_>>());\n        TchTensor::new(tensor)\n    }\n\n    pub fn flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        let dims = axes.iter().map(|x| *x as i64).collect::<Vec<_>>();\n        let tensor = tensor.tensor.flip(dims);\n        TchTensor::new(tensor)\n    }\n\n    pub fn powf(tensor: TchTensor, exponent: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            tensor,\n            exponent,\n            |lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_pow(rhs).unwrap(),\n            |lhs, rhs| lhs.f_pow(rhs).unwrap(),\n        )\n    }\n\n    pub fn sign(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.sign_(), |tensor| tensor.sign())\n    }\n\n    pub fn expand(tensor: TchTensor, shape: Shape) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let broadcasted_tensor = tensor.tensor.broadcast_to(TchShape::from(shape).dims);\n        TchTensor::from_existing(broadcasted_tensor, storage)\n    }\n\n    pub fn unfold(tensor: TchTensor, dim: usize, size: usize, step: usize) -> TchTensor {\n        let storage = tensor.storage.clone();\n        let uf_tensor = tensor.tensor.unfold(dim as i64, size as i64, step as i64);\n\n        TchTensor::from_existing(uf_tensor, storage)\n    }\n\n    pub fn sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {\n        TchTensor::new(tensor.tensor.sort(dim as i64, descending).0)\n    }\n\n    pub fn sort_with_indices(\n        tensor: TchTensor,\n        dim: usize,\n        descending: bool,\n    ) -> (TchTensor, TchTensor) {\n        let sorted = tensor.tensor.sort(dim as i64, descending);\n        (TchTensor::new(sorted.0), TchTensor::new(sorted.1))\n    }\n\n    pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {\n        TchTensor::new(tensor.tensor.argsort(dim as i64, descending))\n    }\n\n    pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(),\n            |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(),\n        )\n    }\n\n    pub fn bitwise_and_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(),\n            |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(),\n        )\n    }\n\n    pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(),\n            |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(),\n        )\n    }\n\n    pub fn bitwise_or_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(),\n            |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(),\n        )\n    }\n\n    pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(),\n            |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(),\n        )\n    }\n\n    pub fn bitwise_xor_scalar<S: Into<Scalar> + Clone>(tensor: TchTensor, scalar: S) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(),\n            |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(),\n        )\n    }\n\n    pub fn bitwise_not(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.f_bitwise_not_().unwrap(),\n            |tensor| tensor.f_bitwise_not().unwrap(),\n        )\n    }\n\n    pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),\n        )\n    }\n\n    pub fn bitwise_left_shift_scalar<S: Into<Scalar> + Clone>(\n        tensor: TchTensor,\n        scalar: S,\n    ) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| {\n                tensor\n                    .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into())\n                    .unwrap()\n            },\n            |tensor| {\n                tensor\n                    .f_bitwise_left_shift_tensor_scalar(scalar.clone().into())\n                    .unwrap()\n            },\n        )\n    }\n\n    pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),\n            |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),\n        )\n    }\n\n    pub fn bitwise_right_shift_scalar<S: Into<Scalar> + Clone>(\n        tensor: TchTensor,\n        scalar: S,\n    ) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| {\n                tensor\n                    .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into())\n                    .unwrap()\n            },\n            |tensor| {\n                tensor\n                    .f_bitwise_right_shift_tensor_scalar(scalar.clone().into())\n                    .unwrap()\n            },\n        )\n    }\n\n    pub fn atan2(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.f_atan2_(rhs).unwrap(),\n            |lhs, rhs| lhs.f_atan2(rhs).unwrap(),\n            |lhs, rhs| lhs.f_atan2(rhs).unwrap(),\n        )\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/bool_tensor.rs",
    "content": "use super::TchOps;\nuse crate::{LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};\nuse burn_backend::BoolStore;\nuse burn_backend::ExecutionError;\nuse burn_backend::Scalar;\nuse burn_backend::tensor::BoolTensor;\nuse burn_backend::tensor::IntTensor;\nuse burn_backend::{Shape, TensorData, TensorMetadata, ops::BoolTensorOps};\n\nimpl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {\n    fn bool_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {\n        match data.dtype {\n            burn_backend::DType::Bool(BoolStore::Native) => {\n                TchTensor::from_data::<bool>(data, (*device).into())\n            }\n            _ => unimplemented!(\"Unsupported dtype for `bool_from_data`\"),\n        }\n    }\n\n    fn bool_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {\n        TchOps::repeat_dim(tensor, dim, times)\n    }\n\n    async fn bool_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {\n        let shape = tensor.shape();\n        let tensor = Self::bool_reshape(tensor.clone(), Shape::new([shape.num_elements()]));\n        let values: Result<Vec<bool>, tch::TchError> = tensor.tensor.shallow_clone().try_into();\n        Ok(TensorData::new(values.unwrap(), shape))\n    }\n\n    fn bool_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {\n        TchOps::to_device(tensor, device)\n    }\n\n    fn bool_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {\n        TchOps::reshape(tensor, shape)\n    }\n\n    fn bool_device(tensor: &TchTensor) -> LibTorchDevice {\n        tensor.tensor.device().into()\n    }\n\n    fn bool_empty(shape: Shape, device: &LibTorchDevice) -> TchTensor {\n        let tensor = tch::Tensor::empty(\n            TchShape::from(shape).dims,\n            (tch::Kind::Bool, (*device).into()),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn bool_zeros(shape: Shape, device: &LibTorchDevice) -> TchTensor {\n        let tensor = tch::Tensor::zeros(\n            TchShape::from(shape).dims,\n            (tch::Kind::Bool, (*device).into()),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn bool_ones(shape: Shape, device: &LibTorchDevice) -> TchTensor {\n        let tensor = tch::Tensor::ones(\n            TchShape::from(shape).dims,\n            (tch::Kind::Bool, (*device).into()),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn bool_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {\n        TchOps::slice_with_steps(tensor, slices)\n    }\n\n    fn bool_slice_assign(\n        tensor: TchTensor,\n        slices: &[burn_backend::Slice],\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::slice_assign(tensor, slices, value)\n    }\n\n    fn bool_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {\n        TchOps::cat(tensors, dim)\n    }\n\n    fn bool_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::equal(lhs, rhs)\n    }\n\n    fn bool_not(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.eq_(0).to_kind(tch::Kind::Bool),\n            |tensor| tensor.eq(0),\n        )\n    }\n\n    fn bool_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.logical_and_(rhs),\n            |lhs, rhs| rhs.logical_and_(lhs),\n            |lhs, rhs| lhs.logical_and(rhs),\n        )\n    }\n\n    fn bool_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            lhs,\n            rhs,\n            |lhs, rhs| lhs.logical_or_(rhs),\n            |lhs, rhs| rhs.logical_or_(lhs),\n            |lhs, rhs| lhs.logical_or(rhs),\n        )\n    }\n\n    fn bool_into_int(tensor: TchTensor) -> TchTensor {\n        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);\n        TchTensor::new(tensor)\n    }\n\n    fn bool_into_float(tensor: TchTensor) -> TchTensor {\n        let tensor = tensor.tensor.to_kind(E::kind());\n        TchTensor::new(tensor)\n    }\n\n    fn bool_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {\n        TchOps::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn bool_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        TchOps::permute(tensor, axes)\n    }\n\n    fn bool_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        TchOps::flip(tensor, axes)\n    }\n\n    async fn bool_argwhere(tensor: TchTensor) -> TchTensor {\n        TchTensor::new(tensor.tensor.argwhere())\n    }\n\n    fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {\n        TchOps::index_select_dim(tensor, dim, indices)\n    }\n\n    fn bool_select_or(\n        tensor: TchTensor,\n        dim: usize,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::select_assign(tensor, dim, indices, value)\n    }\n\n    fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {\n        TchOps::expand(tensor, shape)\n    }\n\n    fn bool_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        TchOps::unfold(tensor, dim, size, step)\n    }\n\n    fn bool_mask_where(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        TchTensor::binary_ops_tensor(\n            tensor,\n            value,\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n        )\n    }\n\n    fn bool_mask_fill(\n        tensor: BoolTensor<Self>,\n        mask: BoolTensor<Self>,\n        value: Scalar,\n    ) -> BoolTensor<Self> {\n        tensor.unary_ops(\n            |mut tensor| {\n                tensor\n                    .f_masked_fill_(&mask.tensor, value.elem::<i64>())\n                    .unwrap()\n            },\n            |tensor| {\n                tensor\n                    .f_masked_fill(&mask.tensor, value.elem::<i64>())\n                    .unwrap()\n            },\n        )\n    }\n\n    fn bool_gather(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n    ) -> BoolTensor<Self> {\n        TchOps::gather(dim, tensor, indices)\n    }\n\n    fn bool_scatter_or(\n        dim: usize,\n        tensor: BoolTensor<Self>,\n        indices: IntTensor<Self>,\n        value: BoolTensor<Self>,\n    ) -> BoolTensor<Self> {\n        TchOps::scatter(dim, tensor, indices, value)\n    }\n\n    fn bool_equal_elem(lhs: BoolTensor<Self>, rhs: Scalar) -> BoolTensor<Self> {\n        TchOps::equal_elem(lhs, rhs.elem::<i64>())\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/int_tensor.rs",
    "content": "use std::ops::Range;\n\nuse burn_backend::{\n    Distribution, ExecutionError, IntDType, Scalar, Shape, TensorData, TensorMetadata,\n    ops::{FloatTensorOps, IntTensorOps},\n    tensor::IntTensor,\n};\n\nuse crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};\n\nuse super::TchOps;\n\nimpl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {\n    fn int_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {\n        match data.dtype {\n            burn_backend::DType::I64 => TchTensor::from_data::<i64>(data, (*device).into()),\n            _ => unimplemented!(\"Unsupported dtype for `int_from_data`\"),\n        }\n    }\n\n    fn int_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {\n        TchOps::repeat_dim(tensor, dim, times)\n    }\n\n    async fn int_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {\n        let shape = tensor.shape();\n        let tensor = Self::int_reshape(tensor.clone(), Shape::new([shape.num_elements()]));\n        let values: Result<Vec<i64>, tch::TchError> = tensor.tensor.shallow_clone().try_into();\n        Ok(TensorData::new(values.unwrap(), shape))\n    }\n\n    fn int_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {\n        TchOps::to_device(tensor, device)\n    }\n\n    fn int_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {\n        TchOps::reshape(tensor, shape)\n    }\n\n    fn int_device(tensor: &TchTensor) -> LibTorchDevice {\n        tensor.tensor.device().into()\n    }\n\n    fn int_empty(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {\n        let tensor = tch::Tensor::empty(\n            TchShape::from(shape).dims,\n            (dtype.into_kind(), (*device).into()),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn int_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {\n        TchOps::slice_with_steps(tensor, slices)\n    }\n\n    fn int_slice_assign(\n        tensor: TchTensor,\n        slices: &[burn_backend::Slice],\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::slice_assign(tensor, slices, value)\n    }\n\n    fn int_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {\n        TchOps::cat(tensors, dim)\n    }\n\n    fn int_matmul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        let lhs = Self::int_into_float(lhs);\n        let rhs = Self::int_into_float(rhs);\n        let out = lhs.tensor.f_matmul(&rhs.tensor).unwrap();\n        Self::float_into_int(TchTensor::new(out))\n    }\n\n    fn int_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::equal(lhs, rhs)\n    }\n\n    fn int_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::equal_elem(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::greater(lhs, rhs)\n    }\n\n    fn int_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::greater_elem(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::greater_equal(lhs, rhs)\n    }\n\n    fn int_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::greater_equal_elem(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::lower(lhs, rhs)\n    }\n\n    fn int_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::lower_elem(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::lower_equal(lhs, rhs)\n    }\n\n    fn int_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::lower_equal_elem(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::add(lhs, rhs)\n    }\n\n    fn int_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.f_add_scalar_(rhs.elem::<i64>()).unwrap(),\n            |tensor| tensor.f_add_scalar(rhs.elem::<i64>()).unwrap(),\n        )\n    }\n\n    fn int_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::sub(lhs, rhs)\n    }\n\n    fn int_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.f_sub_scalar_(rhs.elem::<i64>()).unwrap(),\n            |tensor| tensor.f_sub_scalar(rhs.elem::<i64>()).unwrap(),\n        )\n    }\n\n    fn int_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::mul(lhs, rhs)\n    }\n\n    fn int_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        lhs.unary_ops(\n            |mut tensor| tensor.f_mul_scalar_(rhs.elem::<i64>()).unwrap(),\n            |tensor| tensor.f_mul_scalar(rhs.elem::<i64>()).unwrap(),\n        )\n    }\n\n    fn int_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        let dtype = lhs.tensor.kind();\n        let copy = false;\n        let non_blocking = true;\n        let lhs: TchTensor =\n            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));\n        let rhs: TchTensor =\n            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));\n\n        let out = TchOps::div(lhs, rhs);\n\n        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))\n    }\n\n    fn int_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let dtype = lhs.tensor.kind();\n        let copy = false;\n        let non_blocking = true;\n        let lhs: TchTensor =\n            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));\n\n        let out: TchTensor = lhs.unary_ops(\n            |mut tensor| tensor.f_div_scalar_(rhs.elem::<i64>()).unwrap(),\n            |tensor| tensor.f_div_scalar(rhs.elem::<i64>()).unwrap(),\n        );\n\n        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))\n    }\n\n    fn int_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        let dtype = lhs.tensor.kind();\n        let copy = false;\n        let non_blocking = true;\n        let lhs: TchTensor =\n            TchTensor::new(lhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));\n        let rhs: TchTensor =\n            TchTensor::new(rhs.tensor.to_dtype(tch::Kind::Float, non_blocking, copy));\n\n        let out = TchOps::remainder(lhs, rhs);\n\n        TchTensor::new(out.tensor.to_dtype(dtype, non_blocking, copy))\n    }\n\n    fn int_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        lhs.unary_ops(\n            |tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),\n            |tensor| tensor.f_remainder(rhs.elem::<i64>()).unwrap(),\n        )\n    }\n\n    fn int_zeros(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {\n        let shape = TchShape::from(shape);\n        let device: tch::Device = (*device).into();\n\n        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))\n    }\n\n    fn int_ones(shape: Shape, device: &LibTorchDevice, dtype: IntDType) -> TchTensor {\n        let shape = TchShape::from(shape);\n        let device: tch::Device = (*device).into();\n\n        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))\n    }\n\n    fn int_full(\n        shape: Shape,\n        fill_value: Scalar,\n        device: &LibTorchDevice,\n        dtype: IntDType,\n    ) -> TchTensor {\n        let shape = TchShape::from(shape);\n        let device: tch::Device = (*device).into();\n\n        TchTensor::new(tch::Tensor::full(\n            shape.dims,\n            fill_value.elem::<i64>(),\n            (dtype.into_kind(), device),\n        ))\n    }\n\n    fn int_sum(tensor: TchTensor) -> TchTensor {\n        TchOps::sum(tensor)\n    }\n\n    fn int_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::sum_dim(tensor, dim)\n    }\n\n    fn int_prod(tensor: TchTensor) -> TchTensor {\n        TchOps::prod(tensor)\n    }\n\n    fn int_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::prod_dim(tensor, dim)\n    }\n\n    fn int_mean(tensor: TchTensor) -> TchTensor {\n        let dtype = tensor.tensor.kind();\n        let tensor: TchTensor =\n            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));\n        let output: TchTensor = TchTensor::new(TchOps::mean(tensor).tensor);\n\n        TchTensor::new(output.tensor.to_dtype(dtype, true, false))\n    }\n\n    fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        let dtype = tensor.tensor.kind();\n        let tensor: TchTensor =\n            TchTensor::new(tensor.tensor.to_dtype(tch::Kind::Float, true, false));\n\n        let output: TchTensor = TchTensor::new(TchOps::mean_dim(tensor, dim).tensor);\n\n        TchTensor::new(output.tensor.to_dtype(dtype, true, false))\n    }\n\n    fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cumsum(tensor, dim)\n    }\n\n    fn int_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cumprod(tensor, dim)\n    }\n\n    fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cummin(tensor, dim)\n    }\n\n    fn int_cummax(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cummax(tensor, dim)\n    }\n\n    fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {\n        TchOps::gather(dim, tensor, indices)\n    }\n\n    fn int_scatter_add(\n        dim: usize,\n        tensor: TchTensor,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::scatter(dim, tensor, indices, value)\n    }\n\n    fn int_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {\n        TchOps::index_select_dim(tensor, dim, indices)\n    }\n\n    fn int_select_add(\n        tensor: TchTensor,\n        dim: usize,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::select_assign(tensor, dim, indices, value)\n    }\n\n    fn int_mask_where(tensor: TchTensor, mask: TchTensor, source: TchTensor) -> TchTensor {\n        TchTensor::binary_ops_tensor(\n            tensor,\n            source,\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n            |tensor, source| source.f_where_self(&mask.tensor, tensor).unwrap(),\n        )\n    }\n\n    fn int_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {\n        let value = value.elem::<i64>();\n        tensor.unary_ops(\n            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),\n            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),\n        )\n    }\n\n    fn int_argmax(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::argmax(tensor, dim)\n    }\n\n    fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::argmin(tensor, dim)\n    }\n\n    fn int_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::max_dim(tensor, dim)\n    }\n\n    fn int_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        TchOps::max_dim_with_indices(tensor, dim)\n    }\n\n    fn int_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::min_dim(tensor, dim)\n    }\n\n    fn int_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        TchOps::min_dim_with_indices(tensor, dim)\n    }\n\n    fn int_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {\n        TchOps::clamp_min(tensor, min.elem::<i64>())\n    }\n\n    fn int_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {\n        TchOps::clamp_max(tensor, max.elem::<i64>())\n    }\n\n    fn int_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {\n        TchOps::clamp(tensor, min.elem::<i64>(), max.elem::<i64>())\n    }\n\n    fn int_abs(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())\n    }\n\n    fn int_into_float(tensor: TchTensor) -> TchTensor {\n        let tensor = tensor.tensor.to_kind(E::kind());\n        TchTensor::new(tensor)\n    }\n\n    fn int_swap_dims(tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {\n        TchOps::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn int_random(shape: Shape, distribution: Distribution, device: &LibTorchDevice) -> TchTensor {\n        match distribution {\n            Distribution::Default => TchTensor::new(tch::Tensor::randint_low(\n                0,\n                255,\n                shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),\n                (tch::Kind::Int64, (*device).into()),\n            )),\n            Distribution::Bernoulli(prob) => {\n                let mut tensor = TchTensor::empty::<i64>(shape, *device);\n                tensor\n                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())\n                    .unwrap()\n            }\n            Distribution::Uniform(from, to) => TchTensor::new(tch::Tensor::randint_low(\n                from as i64,\n                to as i64,\n                shape.iter().map(|i| *i as i64).collect::<Vec<_>>(),\n                (tch::Kind::Int64, (*device).into()),\n            )),\n            Distribution::Normal(mean, std) => {\n                let mut tensor = TchTensor::empty::<i64>(shape, *device);\n                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()\n            }\n        }\n    }\n\n    fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor {\n        let device: tch::Device = (*device).into();\n        let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));\n\n        if range.start != 0 {\n            tensor = tensor.f_add_scalar_(range.start).unwrap();\n        }\n\n        TchTensor::new(tensor)\n    }\n\n    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        TchOps::permute(tensor, axes)\n    }\n\n    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {\n        TchOps::flip(tensor, axes)\n    }\n\n    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::sign(tensor)\n    }\n\n    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {\n        TchOps::expand(tensor, shape)\n    }\n\n    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        TchOps::sort(tensor, dim, descending)\n    }\n\n    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {\n        TchOps::argsort(tensor, dim, descending)\n    }\n\n    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_and(lhs, rhs)\n    }\n\n    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_or(lhs, rhs)\n    }\n\n    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_xor(lhs, rhs)\n    }\n\n    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_not(tensor)\n    }\n\n    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        TchOps::bitwise_and_scalar(lhs, rhs.elem::<i64>())\n    }\n\n    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        TchOps::bitwise_or_scalar(lhs, rhs.elem::<i64>())\n    }\n\n    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        TchOps::bitwise_xor_scalar(lhs, rhs.elem::<i64>())\n    }\n\n    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_left_shift(lhs, rhs)\n    }\n\n    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {\n        TchOps::bitwise_right_shift(lhs, rhs)\n    }\n\n    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        TchOps::bitwise_left_shift_scalar(lhs, rhs.elem::<i64>())\n    }\n\n    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: Scalar) -> IntTensor<Self> {\n        TchOps::bitwise_right_shift_scalar(lhs, rhs.elem::<i64>())\n    }\n\n    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {\n        // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type\n        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc\n\n        // Type promotion is not automatic on all backends so this behavior might differ\n        let kind = dtype.into_kind();\n\n        if tensor.tensor.kind() == kind {\n            tensor\n        } else {\n            TchTensor::new(tensor.tensor.to_kind(kind))\n        }\n    }\n\n    fn int_unfold(\n        tensor: IntTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> IntTensor<Self> {\n        TchOps::unfold(tensor, dim, size, step)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/mod.rs",
    "content": "mod activation;\nmod base;\nmod bool_tensor;\nmod int_tensor;\nmod module;\nmod qtensor;\nmod tensor;\nmod transaction;\n\npub(crate) use base::*;\n"
  },
  {
    "path": "crates/burn-tch/src/ops/module.rs",
    "content": "use crate::{LibTorch, TchTensor, element::TchElement};\nuse burn_backend::{\n    TensorMetadata,\n    ops::{\n        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward,\n        DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,\n        MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, attention::attention_fallback,\n    },\n};\n\nimpl<E: TchElement> ModuleOps<Self> for LibTorch<E> {\n    fn embedding(weights: TchTensor, indices: TchTensor) -> TchTensor {\n        // Workaround for MPS \"Placeholder storage has not been allocated\" error.\n        // See: https://github.com/pytorch/pytorch/issues/123995\n        // MPS uses lazy allocation and the embedding operation (which uses index_select)\n        // can fail if the tensors haven't been materialized yet.\n        // We work around this by performing the embedding on CPU and transferring back to MPS.\n        if matches!(weights.tensor.device(), tch::Device::Mps) {\n            let cpu_weights = weights.tensor.to(tch::Device::Cpu);\n            let cpu_indices = indices.tensor.to(tch::Device::Cpu);\n            let result = tch::Tensor::embedding(&cpu_weights, &cpu_indices, -1, false, false)\n                .to(tch::Device::Mps);\n            return TchTensor::new(result);\n        }\n\n        let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);\n        TchTensor::new(tensor)\n    }\n\n    fn embedding_backward(weights: TchTensor, output: TchTensor, indices: TchTensor) -> TchTensor {\n        let [n_embedding, _d_model] = weights.shape().dims();\n\n        // Workaround for MPS \"Placeholder storage has not been allocated\" error.\n        // See: https://github.com/pytorch/pytorch/issues/123995\n        if matches!(output.tensor.device(), tch::Device::Mps) {\n            let cpu_output = output.tensor.to(tch::Device::Cpu);\n            let cpu_indices = indices.tensor.to(tch::Device::Cpu);\n            let result = tch::Tensor::embedding_backward(\n                &cpu_output,\n                &cpu_indices,\n                n_embedding as i64,\n                -1,\n                false,\n                false,\n            )\n            .to(tch::Device::Mps);\n            return TchTensor::new(result);\n        }\n\n        let tensor = tch::Tensor::embedding_backward(\n            &output.tensor,\n            &indices.tensor,\n            n_embedding as i64,\n            -1,\n            false,\n            false,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn conv1d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvOptions<1>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv1d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.dilation.map(|i| i as i64),\n            options.groups as i64,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn conv2d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvOptions<2>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv2d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.dilation.map(|i| i as i64),\n            options.groups as i64,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn conv3d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvOptions<3>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv3d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.dilation.map(|i| i as i64),\n            options.groups as i64,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn deform_conv2d(\n        _x: TchTensor,\n        _offset: TchTensor,\n        _weight: TchTensor,\n        _mask: Option<TchTensor>,\n        _bias: Option<TchTensor>,\n        _options: DeformConvOptions<2>,\n    ) -> TchTensor {\n        unimplemented!(\"Torch bindings don't support deform_conv2d\");\n    }\n\n    fn deform_conv2d_backward(\n        _x: TchTensor,\n        _offset: TchTensor,\n        _weight: TchTensor,\n        _mask: Option<TchTensor>,\n        _bias: Option<TchTensor>,\n        _out_grad: TchTensor,\n        _options: DeformConvOptions<2>,\n    ) -> DeformConv2dBackward<Self> {\n        unimplemented!(\"Torch bindings don't support deform_conv2d\");\n    }\n\n    fn conv_transpose1d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvTransposeOptions<1>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv_transpose1d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.padding_out.map(|i| i as i64),\n            options.groups as i64,\n            options.dilation.map(|i| i as i64),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn conv_transpose2d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvTransposeOptions<2>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv_transpose2d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.padding_out.map(|i| i as i64),\n            options.groups as i64,\n            options.dilation.map(|i| i as i64),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn conv_transpose3d(\n        x: TchTensor,\n        weight: TchTensor,\n        bias: Option<TchTensor>,\n        options: ConvTransposeOptions<3>,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::conv_transpose3d(\n            &x.tensor,\n            &weight.tensor,\n            bias.map(|t| t.tensor),\n            options.stride.map(|i| i as i64),\n            options.padding.map(|i| i as i64),\n            options.padding_out.map(|i| i as i64),\n            options.groups as i64,\n            options.dilation.map(|i| i as i64),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn avg_pool1d(\n        x: TchTensor,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::avg_pool1d(\n            &x.tensor,\n            [kernel_size as i64],\n            [stride as i64],\n            [padding as i64],\n            ceil_mode,\n            count_include_pad,\n        );\n\n        TchTensor::new(tensor)\n    }\n    fn avg_pool2d(\n        x: TchTensor,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::avg_pool2d(\n            &x.tensor,\n            [kernel_size[0] as i64, kernel_size[1] as i64],\n            [stride[0] as i64, stride[1] as i64],\n            [padding[0] as i64, padding[1] as i64],\n            ceil_mode,\n            count_include_pad,\n            None,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn avg_pool2d_backward(\n        x: TchTensor,\n        grad: TchTensor,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        count_include_pad: bool,\n        ceil_mode: bool,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::avg_pool2d_backward(\n            &x.tensor,\n            &grad.tensor,\n            [kernel_size[0] as i64, kernel_size[1] as i64],\n            [stride[0] as i64, stride[1] as i64],\n            [padding[0] as i64, padding[1] as i64],\n            ceil_mode,\n            count_include_pad,\n            None,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn max_pool1d(\n        x: TchTensor,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::max_pool1d(\n            &x.tensor,\n            kernel_size as i64,\n            stride as i64,\n            padding as i64,\n            dilation as i64,\n            ceil_mode,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn max_pool1d_with_indices(\n        x: TchTensor,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        dilation: usize,\n        ceil_mode: bool,\n    ) -> MaxPool1dWithIndices<Self> {\n        let (tensor, indices) = tch::Tensor::max_pool1d_with_indices(\n            &x.tensor,\n            kernel_size as i64,\n            stride as i64,\n            padding as i64,\n            dilation as i64,\n            ceil_mode,\n        );\n\n        MaxPool1dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))\n    }\n\n    fn max_pool2d(\n        x: TchTensor,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> TchTensor {\n        let tensor = tch::Tensor::max_pool2d(\n            &x.tensor,\n            [kernel_size[0] as i64, kernel_size[1] as i64],\n            [stride[0] as i64, stride[1] as i64],\n            [padding[0] as i64, padding[1] as i64],\n            [dilation[0] as i64, dilation[1] as i64],\n            ceil_mode,\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn max_pool2d_with_indices(\n        x: TchTensor,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n    ) -> MaxPool2dWithIndices<Self> {\n        let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(\n            &x.tensor,\n            [kernel_size[0] as i64, kernel_size[1] as i64],\n            [stride[0] as i64, stride[1] as i64],\n            [padding[0] as i64, padding[1] as i64],\n            [dilation[0] as i64, dilation[1] as i64],\n            ceil_mode,\n        );\n\n        MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))\n    }\n\n    fn max_pool2d_with_indices_backward(\n        x: TchTensor,\n        kernel_size: [usize; 2],\n        stride: [usize; 2],\n        padding: [usize; 2],\n        dilation: [usize; 2],\n        ceil_mode: bool,\n        output_grad: TchTensor,\n        indices: TchTensor,\n    ) -> MaxPool2dBackward<Self> {\n        let grad = tch::Tensor::max_pool2d_with_indices_backward(\n            &x.tensor,\n            &output_grad.tensor,\n            [kernel_size[0] as i64, kernel_size[1] as i64],\n            [stride[0] as i64, stride[1] as i64],\n            [padding[0] as i64, padding[1] as i64],\n            [dilation[0] as i64, dilation[1] as i64],\n            ceil_mode,\n            &indices.tensor,\n        );\n\n        MaxPool2dBackward::new(TchTensor::new(grad))\n    }\n\n    fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor {\n        let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64));\n\n        TchTensor::new(tensor)\n    }\n\n    fn adaptive_avg_pool2d_backward(x: TchTensor, grad: TchTensor) -> TchTensor {\n        let tensor = tch::Tensor::internal_adaptive_avg_pool2d_backward(&x.tensor, &grad.tensor);\n\n        TchTensor::new(tensor)\n    }\n\n    fn adaptive_avg_pool1d(x: TchTensor, output_size: usize) -> TchTensor {\n        let tensor = tch::Tensor::adaptive_avg_pool1d(&x.tensor, output_size as i64);\n\n        TchTensor::new(tensor)\n    }\n\n    fn interpolate(\n        x: TchTensor,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> TchTensor {\n        let output_size = output_size.map(|e| e as i64);\n\n        let align_corners = options.align_corners;\n        let tensor = match options.mode {\n            InterpolateMode::Nearest => {\n                tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)\n            }\n            InterpolateMode::Bilinear => {\n                tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)\n            }\n            InterpolateMode::Bicubic => {\n                tch::Tensor::upsample_bicubic2d(&x.tensor, output_size, align_corners, None, None)\n            }\n            InterpolateMode::Lanczos3 => {\n                panic!(\"lanczos3 interpolation is not supported by PyTorch/tch backend\")\n            }\n        };\n\n        TchTensor::new(tensor)\n    }\n\n    fn interpolate_backward(\n        x: TchTensor,\n        grad: TchTensor,\n        output_size: [usize; 2],\n        options: InterpolateOptions,\n    ) -> TchTensor {\n        let output_size = output_size.map(|e| e as i64);\n        let [n, c, h_in, w_in] = x.shape().dims();\n        let input_size = [n as i64, c as i64, h_in as i64, w_in as i64];\n        let align_corners = options.align_corners;\n\n        let tensor = match options.mode {\n            InterpolateMode::Nearest => tch::Tensor::upsample_nearest2d_backward(\n                &grad.tensor,\n                output_size,\n                input_size,\n                None,\n                None,\n            ),\n            InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(\n                &grad.tensor,\n                output_size,\n                input_size,\n                align_corners,\n                None,\n                None,\n            ),\n            InterpolateMode::Bicubic => tch::Tensor::upsample_bicubic2d_backward(\n                &grad.tensor,\n                output_size,\n                input_size,\n                align_corners,\n                None,\n                None,\n            ),\n            InterpolateMode::Lanczos3 => {\n                panic!(\"lanczos3 interpolation backward is not supported by PyTorch/tch backend\")\n            }\n        };\n\n        TchTensor::new(tensor)\n    }\n\n    fn attention(\n        query: TchTensor,\n        key: TchTensor,\n        value: TchTensor,\n        mask: Option<TchTensor>,\n        attn_bias: Option<TchTensor>,\n        options: AttentionModuleOptions,\n    ) -> TchTensor {\n        if attn_bias.is_some() {\n            return attention_fallback::<Self>(query, key, value, mask, attn_bias, options);\n        }\n\n        TchTensor::new(tch::Tensor::scaled_dot_product_attention(\n            &query.tensor,\n            &key.tensor,\n            &value.tensor,\n            mask.map(|m| m.tensor),\n            0.,\n            options.is_causal,\n            options.scale,\n            false,\n        ))\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/qtensor.rs",
    "content": "use burn_backend::{\n    ExecutionError, Shape, TensorData,\n    ops::QTensorOps,\n    quantization::{QuantScheme, QuantizationParametersPrimitive},\n    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},\n};\n\nuse crate::{LibTorch, LibTorchDevice, TchElement};\n\nimpl<E: TchElement> QTensorOps<Self> for LibTorch<E> {\n    fn q_from_data(_data: TensorData, _device: &LibTorchDevice) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn quantize(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n        _qparams: QuantizationParametersPrimitive<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn quantize_dynamic(\n        _tensor: FloatTensor<Self>,\n        _scheme: &QuantScheme,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn dequantize(_tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_device(_tensor: &QuantizedTensor<Self>) -> LibTorchDevice {\n        unimplemented!()\n    }\n\n    fn q_to_device(\n        _tensor: QuantizedTensor<Self>,\n        _device: &Device<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_reshape(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    async fn q_into_data(_tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {\n        unimplemented!()\n    }\n    fn q_swap_dims(\n        _tensor: QuantizedTensor<Self>,\n        _dim1: usize,\n        _dim2: usize,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_select(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _indices: IntTensor<Self>,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_slice(\n        _tensor: QuantizedTensor<Self>,\n        _slices: &[burn_backend::Slice],\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_argmax(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_argmin(_tensor: QuantizedTensor<Self>, _dim: usize) -> IntTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_max_dim_with_indices(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {\n        unimplemented!()\n    }\n\n    fn q_max_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_min_dim(_tensor: QuantizedTensor<Self>, _dim: usize) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_min_dim_with_indices(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {\n        unimplemented!()\n    }\n\n    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_sort(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _descending: bool,\n    ) -> QuantizedTensor<Self> {\n        unimplemented!()\n    }\n\n    fn q_sort_with_indices(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _descending: bool,\n    ) -> (QuantizedTensor<Self>, IntTensor<Self>) {\n        unimplemented!()\n    }\n\n    fn q_argsort(\n        _tensor: QuantizedTensor<Self>,\n        _dim: usize,\n        _descending: bool,\n    ) -> IntTensor<Self> {\n        unimplemented!()\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/tensor.rs",
    "content": "use super::TchOps;\nuse crate::{IntoKind, LibTorch, LibTorchDevice, TchShape, TchTensor, element::TchElement};\nuse burn_backend::backend::ExecutionError;\nuse burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor};\nuse burn_backend::{\n    DType, Distribution, FloatDType, Shape, TensorData, TensorMetadata, ops::FloatTensorOps,\n};\nuse burn_backend::{Scalar, bf16, f16};\n\nimpl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {\n    fn float_from_data(data: TensorData, device: &LibTorchDevice) -> TchTensor {\n        match data.dtype {\n            DType::F64 => TchTensor::from_data::<f64>(data, (*device).into()),\n            DType::F32 => TchTensor::from_data::<f32>(data, (*device).into()),\n            DType::F16 => TchTensor::from_data::<f16>(data, (*device).into()),\n            DType::BF16 => TchTensor::from_data::<bf16>(data, (*device).into()),\n            _ => unimplemented!(\"Unsupported dtype for `float_from_data`\"),\n        }\n    }\n\n    fn float_random(\n        shape: Shape,\n        distribution: Distribution,\n        device: &LibTorchDevice,\n    ) -> TchTensor {\n        match distribution {\n            Distribution::Default => {\n                let mut tensor = TchTensor::empty::<E>(shape, *device);\n                tensor\n                    .mut_ops(|tensor| tensor.rand_like_out(tensor))\n                    .unwrap()\n            }\n            Distribution::Bernoulli(prob) => {\n                let mut tensor = TchTensor::empty::<E>(shape, *device);\n                tensor\n                    .mut_ops(|tensor| tensor.f_bernoulli_float_(prob).unwrap())\n                    .unwrap()\n            }\n            Distribution::Uniform(from, to) => {\n                let mut tensor = TchTensor::empty::<E>(shape, *device);\n                tensor.mut_ops(|tensor| tensor.uniform_(from, to)).unwrap()\n            }\n            Distribution::Normal(mean, std) => {\n                let mut tensor = TchTensor::empty::<E>(shape, *device);\n                tensor.mut_ops(|tensor| tensor.normal_(mean, std)).unwrap()\n            }\n        }\n    }\n\n    fn float_repeat_dim(tensor: TchTensor, dim: usize, times: usize) -> TchTensor {\n        TchOps::repeat_dim(tensor, dim, times)\n    }\n\n    fn float_zeros(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {\n        let shape = TchShape::from(shape);\n        let device: tch::Device = (*device).into();\n\n        TchTensor::new(tch::Tensor::zeros(shape.dims, (dtype.into_kind(), device)))\n    }\n\n    fn float_ones(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {\n        let shape = TchShape::from(shape);\n        let device: tch::Device = (*device).into();\n\n        TchTensor::new(tch::Tensor::ones(shape.dims, (dtype.into_kind(), device)))\n    }\n\n    async fn float_into_data(tensor: TchTensor) -> Result<TensorData, ExecutionError> {\n        let shape = tensor.shape();\n        let tensor = Self::float_reshape(tensor.clone(), Shape::new([shape.num_elements()]));\n        Ok(match tensor.tensor.kind() {\n            tch::Kind::Half => {\n                let values = Vec::<f16>::try_from(&tensor).unwrap();\n                TensorData::new(values, shape)\n            }\n            tch::Kind::Float => {\n                let values = Vec::<f32>::try_from(&tensor).unwrap();\n                TensorData::new(values, shape)\n            }\n            tch::Kind::Double => {\n                let values = Vec::<f64>::try_from(&tensor).unwrap();\n                TensorData::new(values, shape)\n            }\n            tch::Kind::BFloat16 => {\n                let values = Vec::<bf16>::try_from(&tensor).unwrap();\n                TensorData::new(values, shape)\n            }\n            _ => panic!(\"Not a valid float kind\"),\n        })\n    }\n\n    fn float_device(tensor: &TchTensor) -> LibTorchDevice {\n        tensor.tensor.device().into()\n    }\n\n    fn float_to_device(tensor: TchTensor, device: &LibTorchDevice) -> TchTensor {\n        TchOps::to_device(tensor, device)\n    }\n\n    fn float_empty(shape: Shape, device: &LibTorchDevice, dtype: FloatDType) -> TchTensor {\n        let tensor = tch::Tensor::empty(\n            TchShape::from(shape).dims,\n            (dtype.into_kind(), (*device).into()),\n        );\n\n        TchTensor::new(tensor)\n    }\n\n    fn float_add(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::add(lhs, rhs)\n    }\n\n    fn float_add_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let rhs: f64 = rhs.elem();\n\n        lhs.unary_ops(\n            |mut tensor| tensor.f_add_scalar_(rhs).unwrap(),\n            |tensor| tensor.f_add_scalar(rhs).unwrap(),\n        )\n    }\n\n    fn float_sub(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::sub(lhs, rhs)\n    }\n\n    fn float_sub_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let rhs: f64 = rhs.elem();\n\n        lhs.unary_ops(\n            |mut tensor| tensor.f_sub_scalar_(rhs).unwrap(),\n            |tensor| tensor.f_sub_scalar(rhs).unwrap(),\n        )\n    }\n\n    fn float_mul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::mul(lhs, rhs)\n    }\n\n    fn float_mul_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let rhs: f64 = rhs.elem();\n\n        lhs.unary_ops(\n            |mut tensor| tensor.f_mul_scalar_(rhs).unwrap(),\n            |tensor| tensor.f_mul_scalar(rhs).unwrap(),\n        )\n    }\n\n    fn float_div(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::div(lhs, rhs)\n    }\n\n    fn float_div_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let rhs: f64 = rhs.elem();\n\n        lhs.unary_ops(\n            |mut tensor| tensor.f_div_scalar_(rhs).unwrap(),\n            |tensor| tensor.f_div_scalar(rhs).unwrap(),\n        )\n    }\n\n    fn float_remainder(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::remainder(lhs, rhs)\n    }\n\n    fn float_remainder_scalar(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        let rhs: f64 = rhs.elem();\n\n        lhs.unary_ops(\n            |tensor| tensor.f_remainder(rhs).unwrap(),\n            |tensor| tensor.f_remainder(rhs).unwrap(),\n        )\n    }\n\n    fn float_matmul(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        let tensor = lhs.tensor.matmul(&rhs.tensor);\n        TchTensor::new(tensor)\n    }\n\n    fn float_cross(lhs: TchTensor, rhs: TchTensor, dim: usize) -> TchTensor {\n        let tensor = lhs.tensor.cross(&rhs.tensor, dim as i64);\n        TchTensor::new(tensor)\n    }\n\n    fn float_recip(tensor: TchTensor) -> TchTensor {\n        TchTensor::new(tensor.tensor.reciprocal())\n    }\n\n    fn float_swap_dims(tensor: TchTensor, dim1: usize, dim2: usize) -> TchTensor {\n        TchOps::swap_dims(tensor, dim1, dim2)\n    }\n\n    fn float_reshape(tensor: TchTensor, shape: Shape) -> TchTensor {\n        TchOps::reshape(tensor, shape)\n    }\n\n    fn float_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {\n        TchOps::gather(dim, tensor, indices)\n    }\n\n    fn float_scatter_add(\n        dim: usize,\n        tensor: TchTensor,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::scatter(dim, tensor, indices, value)\n    }\n\n    fn float_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {\n        TchOps::index_select_dim(tensor, dim, indices)\n    }\n\n    fn float_select_add(\n        tensor: TchTensor,\n        dim: usize,\n        indices: TchTensor,\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::select_assign(tensor, dim, indices, value)\n    }\n\n    fn float_slice(tensor: TchTensor, slices: &[burn_backend::Slice]) -> TchTensor {\n        TchOps::slice_with_steps(tensor, slices)\n    }\n\n    fn float_slice_assign(\n        tensor: TchTensor,\n        slices: &[burn_backend::Slice],\n        value: TchTensor,\n    ) -> TchTensor {\n        TchOps::slice_assign(tensor, slices, value)\n    }\n\n    fn float_mask_where(tensor: TchTensor, mask: TchTensor, value: TchTensor) -> TchTensor {\n        let output = value.tensor.where_self(&mask.tensor, &tensor.tensor);\n\n        TchTensor::new(output)\n    }\n\n    fn float_mask_fill(tensor: TchTensor, mask: TchTensor, value: Scalar) -> TchTensor {\n        let value: f64 = value.elem();\n\n        tensor.unary_ops(\n            |mut tensor| tensor.f_masked_fill_(&mask.tensor, value).unwrap(),\n            |tensor| tensor.f_masked_fill(&mask.tensor, value).unwrap(),\n        )\n    }\n\n    fn float_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::equal(lhs, rhs)\n    }\n\n    fn float_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::equal_elem(lhs, rhs.elem::<f64>())\n    }\n\n    fn float_greater(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::greater(lhs, rhs)\n    }\n\n    fn float_greater_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::greater_elem(lhs, rhs.elem::<f64>())\n    }\n\n    fn float_greater_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::greater_equal(lhs, rhs)\n    }\n\n    fn float_greater_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::greater_equal_elem(lhs, rhs.elem::<f64>())\n    }\n\n    fn float_lower(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::lower(lhs, rhs)\n    }\n\n    fn float_lower_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::lower_elem(lhs, rhs.elem::<f64>())\n    }\n\n    fn float_lower_equal(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::lower_equal(lhs, rhs)\n    }\n\n    fn float_lower_equal_elem(lhs: TchTensor, rhs: Scalar) -> TchTensor {\n        TchOps::lower_equal_elem(lhs, rhs.elem::<f64>())\n    }\n\n    fn float_mean(tensor: TchTensor) -> TchTensor {\n        TchOps::mean(tensor)\n    }\n\n    fn float_sum(tensor: TchTensor) -> TchTensor {\n        TchOps::sum(tensor)\n    }\n\n    fn float_sum_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::sum_dim(tensor, dim)\n    }\n\n    fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::mean_dim(tensor, dim)\n    }\n\n    fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cumsum(tensor, dim)\n    }\n\n    fn float_cumprod(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cumprod(tensor, dim)\n    }\n\n    fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cummin(tensor, dim)\n    }\n\n    fn float_cummax(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::cummax(tensor, dim)\n    }\n\n    fn float_prod(tensor: TchTensor) -> TchTensor {\n        TchOps::prod(tensor)\n    }\n\n    fn float_prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::prod_dim(tensor, dim)\n    }\n\n    fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::argmax(tensor, dim)\n    }\n\n    fn float_argmin(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::argmin(tensor, dim)\n    }\n\n    fn float_max_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::max_dim(tensor, dim)\n    }\n\n    fn float_max_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        TchOps::max_dim_with_indices(tensor, dim)\n    }\n\n    fn float_min_dim(tensor: TchTensor, dim: usize) -> TchTensor {\n        TchOps::min_dim(tensor, dim)\n    }\n\n    fn float_min_dim_with_indices(tensor: TchTensor, dim: usize) -> (TchTensor, TchTensor) {\n        TchOps::min_dim_with_indices(tensor, dim)\n    }\n\n    fn float_exp(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.exp_(), |tensor| tensor.exp())\n    }\n\n    fn float_log(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.log_(), |tensor| tensor.log())\n    }\n\n    fn float_log1p(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.log1p_(), |tensor| tensor.log1p())\n    }\n\n    fn float_powf_scalar_impl(tensor: TchTensor, value: Scalar) -> TchTensor {\n        tensor.unary_ops(\n            |mut tensor| tensor.f_pow_(value.elem::<f64>()).unwrap(),\n            |tensor| tensor.pow_tensor_scalar(value.elem::<f64>()),\n        )\n    }\n\n    fn float_sqrt(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.sqrt_(), |tensor| tensor.sqrt())\n    }\n\n    fn float_abs(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.abs_(), |tensor| tensor.abs())\n    }\n\n    fn float_cos(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.cos_(), |tensor| tensor.cos())\n    }\n\n    fn float_cosh(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.cosh_(), |tensor| tensor.cosh())\n    }\n\n    fn float_sin(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.sin_(), |tensor| tensor.sin())\n    }\n\n    fn float_sinh(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.sinh_(), |tensor| tensor.sinh())\n    }\n\n    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.tan_(), |tensor| tensor.tan())\n    }\n\n    fn float_tanh(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())\n    }\n\n    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.acos_(), |tensor| tensor.acos())\n    }\n\n    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.acosh_(), |tensor| tensor.acosh())\n    }\n\n    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.asin_(), |tensor| tensor.asin())\n    }\n\n    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.asinh_(), |tensor| tensor.asinh())\n    }\n\n    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.atan_(), |tensor| tensor.atan())\n    }\n\n    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {\n        tensor.unary_ops(|mut tensor| tensor.atanh_(), |tensor| tensor.atanh())\n    }\n\n    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {\n        TchOps::atan2(lhs, rhs)\n    }\n\n    fn float_round(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())\n    }\n\n    fn float_floor(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())\n    }\n\n    fn float_ceil(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())\n    }\n\n    fn float_trunc(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.trunc_(), |tensor| tensor.trunc())\n    }\n\n    fn float_erf(tensor: TchTensor) -> TchTensor {\n        tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())\n    }\n\n    fn float_cat(tensors: Vec<TchTensor>, dim: usize) -> TchTensor {\n        TchOps::cat(tensors, dim)\n    }\n\n    fn float_clamp_min(tensor: TchTensor, min: Scalar) -> TchTensor {\n        TchOps::clamp_min(tensor, min.elem::<f64>())\n    }\n\n    fn float_clamp_max(tensor: TchTensor, max: Scalar) -> TchTensor {\n        TchOps::clamp_max(tensor, max.elem::<f64>())\n    }\n\n    fn float_clamp(tensor: TchTensor, min: Scalar, max: Scalar) -> TchTensor {\n        TchOps::clamp(tensor, min.elem::<f64>(), max.elem::<f64>())\n    }\n\n    fn float_into_int(tensor: TchTensor) -> TchTensor {\n        let tensor = tensor.tensor.to_kind(tch::Kind::Int64);\n        TchTensor::new(tensor)\n    }\n\n    fn float_powf(lhs: TchTensor, rhs: TchTensor) -> TchTensor {\n        TchOps::powf(lhs, rhs)\n    }\n\n    fn float_permute(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        TchOps::permute(tensor, axes)\n    }\n\n    fn float_flip(tensor: TchTensor, axes: &[usize]) -> TchTensor {\n        TchOps::flip(tensor, axes)\n    }\n\n    fn float_sign(tensor: TchTensor) -> TchTensor {\n        TchOps::sign(tensor)\n    }\n\n    fn float_expand(tensor: TchTensor, shape: Shape) -> TchTensor {\n        TchOps::expand(tensor, shape)\n    }\n\n    fn float_sort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor {\n        TchOps::sort(tensor, dim, descending)\n    }\n\n    fn float_sort_with_indices(\n        tensor: TchTensor,\n        dim: usize,\n        descending: bool,\n    ) -> (TchTensor, TchTensor) {\n        TchOps::sort_with_indices(tensor, dim, descending)\n    }\n\n    fn float_argsort(tensor: TchTensor, dim: usize, descending: bool) -> IntTensor<Self> {\n        TchOps::argsort(tensor, dim, descending)\n    }\n\n    fn float_cast(tensor: TchTensor, dtype: FloatDType) -> TchTensor {\n        // NOTE: when dtypes of inputs to an arithmetic operation differ, tch handles type\n        // promotion based on a set of rules: https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc\n\n        // Type promotion is not automatic on all backends so this behavior might differ\n        let kind = dtype.into_kind();\n\n        if tensor.tensor.kind() == kind {\n            tensor\n        } else {\n            TchTensor::new(tensor.tensor.to_kind(kind))\n        }\n    }\n\n    fn float_unfold(\n        tensor: FloatTensor<Self>,\n        dim: usize,\n        size: usize,\n        step: usize,\n    ) -> FloatTensor<Self> {\n        TchOps::unfold(tensor, dim, size, step)\n    }\n\n    fn float_is_nan(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        TchTensor::new(tensor.tensor.isnan())\n    }\n\n    fn float_is_inf(tensor: FloatTensor<Self>) -> BoolTensor<Self> {\n        TchTensor::new(tensor.tensor.isinf())\n    }\n}\n"
  },
  {
    "path": "crates/burn-tch/src/ops/transaction.rs",
    "content": "use burn_backend::ops::TransactionOps;\n\nuse crate::{LibTorch, TchElement};\n\nimpl<E: TchElement> TransactionOps<Self> for LibTorch<E> {}\n"
  },
  {
    "path": "crates/burn-tch/src/tensor.rs",
    "content": "use crate::{LibTorchDevice, TchElement};\nuse burn_backend::{BoolStore, DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};\nuse libc::c_void;\nuse std::sync::Arc;\n\n/// A reference to a tensor storage.\n///\n/// We manually implement `Sync` and `Send` unsafely, so even if we could use `Rc`, it isn't safe.\n#[allow(clippy::arc_with_non_send_sync)]\npub type StorageRef = Arc<*mut c_void>;\n\n/// A reference to a tensor storage.\n#[derive(PartialEq, Debug, Clone)]\npub enum Storage {\n    /// When a tensor is a partial view of another tensor.\n    View {\n        /// Storage reference for the whole buffer.\n        buffer_ref: StorageRef,\n        /// Storage reference for the partial buffer.\n        view_ref: StorageRef,\n    },\n    /// When a tensor use all of its buffer.\n    Owned {\n        /// Storage reference for the whole buffer.\n        buffer_ref: StorageRef,\n    },\n}\n\nimpl Storage {\n    /// Check if the storage can be used inplace.\n    pub fn can_mut(&self) -> bool {\n        match self {\n            Storage::View {\n                buffer_ref: start_ref,\n                view_ref,\n            } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,\n            Storage::Owned {\n                buffer_ref: start_ref,\n            } => Arc::strong_count(start_ref) == 1,\n        }\n    }\n\n    /// Get the whole buffer reference.\n    pub fn buffer_ref(&self) -> &StorageRef {\n        match self {\n            Storage::View {\n                buffer_ref: start_ref,\n                view_ref: _,\n            } => start_ref,\n            Storage::Owned {\n                buffer_ref: start_ref,\n            } => start_ref,\n        }\n    }\n}\n\n/// A tensor using the tch backend.\n#[derive(Debug, PartialEq)]\npub struct TchTensor {\n    /// Handle to the tensor. Call methods on this field.\n    pub tensor: tch::Tensor,\n\n    /// The tensor's storage\n    pub storage: Storage,\n}\n\nimpl TensorMetadata for TchTensor {\n    fn dtype(&self) -> DType {\n        match self.tensor.kind() {\n            tch::Kind::Uint8 => DType::U8,\n            tch::Kind::Int8 => DType::I8,\n            tch::Kind::Int16 => DType::I16,\n            tch::Kind::Int => DType::I32,\n            tch::Kind::Int64 => DType::I64,\n            tch::Kind::Half => DType::F16,\n            tch::Kind::Float => DType::F32,\n            tch::Kind::Double => DType::F64,\n            tch::Kind::Bool => DType::Bool(BoolStore::Native),\n            tch::Kind::BFloat16 => DType::BF16,\n            // Complex and quantization types are not valid/implemented.\n            _ => unimplemented!(),\n        }\n    }\n\n    fn shape(&self) -> Shape {\n        Shape::from(self.tensor.size())\n    }\n\n    fn rank(&self) -> usize {\n        self.tensor.dim()\n    }\n}\n\nimpl burn_backend::QTensorPrimitive for TchTensor {\n    fn scheme(&self) -> &burn_backend::quantization::QuantScheme {\n        unimplemented!(\"Quantization is not supported\")\n    }\n}\n\nimpl core::fmt::Display for TchTensor {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        write!(f, \"{}\", self.tensor)\n    }\n}\n\npub(crate) trait IntoKind {\n    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError>;\n    fn into_kind(self) -> tch::Kind\n    where\n        Self: Sized,\n    {\n        self.try_into_kind().unwrap()\n    }\n}\n\nimpl IntoKind for IntDType {\n    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {\n        let dtype: DType = self.into();\n        dtype.try_into_kind()\n    }\n}\n\nimpl IntoKind for FloatDType {\n    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {\n        let dtype: DType = self.into();\n        dtype.try_into_kind()\n    }\n}\n\nimpl IntoKind for DType {\n    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {\n        match self {\n            DType::F64 => Ok(tch::Kind::Double),\n            DType::F32 => Ok(tch::Kind::Float),\n            DType::Flex32 => Ok(tch::Kind::Float),\n            DType::F16 => Ok(tch::Kind::Half),\n            DType::BF16 => Ok(tch::Kind::BFloat16),\n            DType::I64 => Ok(tch::Kind::Int64),\n            DType::I32 => Ok(tch::Kind::Int),\n            DType::I16 => Ok(tch::Kind::Int16),\n            DType::I8 => Ok(tch::Kind::Int8),\n            DType::U8 => Ok(tch::Kind::Uint8),\n            DType::Bool(BoolStore::Native) => Ok(tch::Kind::Bool),\n            other => Err(tch::TchError::Kind(format!(\"Unsupported dtype {other:?}\"))),\n        }\n    }\n}\n\nimpl TchTensor {\n    /// Create a new tensor.\n    ///\n    /// Note that if the tensor was created from an operation that may reuse the same tensor\n    /// storage as the parent, you should use [from_existing](TchTensor::from_existing)\n    /// instead.\n    pub fn new(tensor: tch::Tensor) -> Self {\n        #[allow(clippy::arc_with_non_send_sync)]\n        let storage = Storage::Owned {\n            buffer_ref: Arc::new(tensor.data_ptr()),\n        };\n\n        Self { tensor, storage }\n    }\n\n    /// Create a tensor that was created from an operation executed on a parent tensor.\n    ///\n    /// If the child tensor shared the same storage as its parent, it will be cloned, effectively\n    /// tracking how much tensors point to the same memory space.\n    pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {\n        let storage_child = tensor.data_ptr();\n        let mut is_a_new_tensor = true;\n\n        match &storage_parent {\n            Storage::View {\n                buffer_ref: start_ref,\n                view_ref,\n            } => {\n                if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {\n                    is_a_new_tensor = false;\n                }\n            }\n            Storage::Owned {\n                buffer_ref: start_ref,\n            } => {\n                if storage_child == *start_ref.as_ref() {\n                    is_a_new_tensor = false;\n                }\n            }\n        };\n\n        let storage = match is_a_new_tensor {\n            true => Storage::Owned {\n                #[allow(clippy::arc_with_non_send_sync)]\n                buffer_ref: Arc::new(storage_child),\n            },\n            false => storage_parent.clone(),\n        };\n\n        Self { tensor, storage }\n    }\n\n    /// Create a tensor that uses a part of its parent tensor such as slice and narrow.\n    pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {\n        let storage = Storage::View {\n            buffer_ref: storage_parent.buffer_ref().clone(),\n            #[allow(clippy::arc_with_non_send_sync)]\n            view_ref: Arc::new(tensor.data_ptr()),\n        };\n        Self { tensor, storage }\n    }\n}\n\n// This is safe since we don't use autodiff from LibTorch.\n// Also, atomic reference counting is used to know if the tensor's data can be reused.\n// If there are multiple reference on the same tensor, it becomes read only.\nunsafe impl Send for TchTensor {}\nunsafe impl Sync for TchTensor {}\n\nimpl TchTensor {\n    /// Checks if the tensor can be mutated in-place.\n    ///\n    /// Returns `true` if the tensor's stride does not contain zero (no broadcasting)\n    /// and the storage can be mutated.\n    pub fn can_mut(&self) -> bool {\n        let stride_contains_zero = self.tensor.stride().contains(&0);\n\n        !stride_contains_zero && self.storage.can_mut()\n    }\n\n    /// Executes an operation on a tensor if the data can be reused.\n    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(\n        &mut self,\n        func: F,\n    ) -> Option<TchTensor> {\n        if !self.can_mut() {\n            return None;\n        }\n\n        let data = self.storage.clone();\n        Some(TchTensor::from_existing(func(&mut self.tensor), data))\n    }\n\n    /// Executes a unary operation, reusing the tensor data if possible.\n    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor\n    where\n        FOwn: Fn(tch::Tensor) -> tch::Tensor,\n        FRef: Fn(&tch::Tensor) -> tch::Tensor,\n    {\n        if !self.can_mut() {\n            return TchTensor::from_existing(fref(&self.tensor), self.storage);\n        }\n\n        TchTensor::from_existing(fown(self.tensor), self.storage)\n    }\n\n    /// Executes a binary operation, reusing the tensor data if possible.\n    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(\n        mut lhs: Self,\n        mut rhs: Self,\n        flmut: FLMut,\n        frmut: FRMut,\n        fref: FRef,\n    ) -> TchTensor\n    where\n        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,\n        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,\n        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,\n    {\n        let lhs_shape = lhs.shape();\n        let rhs_shape = rhs.shape();\n\n        // Both lhs and rhs are expected to have the same rank\n        let d_out = lhs_shape.num_dims();\n        let mut out_shape = Shape::from(vec![1usize; d_out]);\n\n        for i in 0..d_out {\n            out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);\n        }\n\n        let num_elements_out = out_shape.num_elements();\n\n        // Attempt to mutate lhs tensor\n        if lhs_shape.num_elements() == num_elements_out\n            && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))\n        {\n            return output;\n        }\n\n        // Attempt to mutate rhs tensor\n        if rhs_shape.num_elements() == num_elements_out\n            && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))\n        {\n            return output;\n        }\n\n        let storage = lhs.storage;\n        let tensor = fref(&lhs.tensor, &rhs.tensor);\n\n        TchTensor::from_existing(tensor, storage)\n    }\n}\n\nimpl Clone for TchTensor {\n    fn clone(&self) -> Self {\n        Self {\n            tensor: self.tensor.shallow_clone(),\n            storage: self.storage.clone(),\n        }\n    }\n}\n\n/// A shape that can be used by LibTorch.\n#[derive(Debug)]\npub struct TchShape {\n    /// The shape's dimensions.\n    pub dims: Vec<i64>,\n}\n\nimpl From<Shape> for TchShape {\n    fn from(shape: Shape) -> Self {\n        TchShape {\n            dims: shape.iter().map(|d| *d as i64).collect(),\n        }\n    }\n}\n\nimpl From<&[usize]> for TchShape {\n    fn from(shape: &[usize]) -> Self {\n        TchShape {\n            dims: shape.iter().map(|d| *d as i64).collect(),\n        }\n    }\n}\n\nimpl TchTensor {\n    /// Creates a new tensor from a shape and a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The tensor's data.\n    /// * `device` - The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor.\n    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {\n        let shape_tch = TchShape::from(data.shape.as_slice());\n        let tensor =\n            tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device);\n\n        Self::new(tensor)\n    }\n}\n\nimpl TchTensor {\n    /// Creates an empty tensor from a shape and a device.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// A new empty tensor.\n    pub fn empty<E: TchElement>(shape: Shape, device: LibTorchDevice) -> Self {\n        let shape_tch = TchShape::from(shape);\n        let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into()));\n\n        Self::new(tensor)\n    }\n}\n\n// Adapted from `tch` to use patched `T::kind()` instead of `T::KIND` which is incorrect for bf16.\n// TODO: remove when fixed in `tch` release (https://github.com/LaurentMazare/tch-rs/pull/996).\nimpl<T: TchElement + Copy> TryFrom<&TchTensor> for Vec<T> {\n    type Error = tch::TchError;\n    fn try_from(tensor: &TchTensor) -> Result<Self, Self::Error> {\n        let tensor = &tensor.tensor;\n        let size = tensor.size();\n        if size.len() != 1 {\n            Err(tch::TchError::Convert(format!(\n                \"Attempting to convert a Tensor with {} dimensions to flat vector\",\n                size.len()\n            )))?;\n        }\n        let numel = size[0] as usize;\n        let mut vec = vec![T::ZERO; numel];\n        // Adapted to use patched `T::kind()` instead\n        // TODO: tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?;\n        f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?;\n        Ok(vec)\n    }\n}\n\nunsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option<String> {\n    if !ptr.is_null() {\n        unsafe {\n            let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();\n            libc::free(ptr as *mut libc::c_void);\n            Some(str)\n        }\n    } else {\n        None\n    }\n}\n\n/// Copies `numel` elements from `self` to `dst`.\nfn f_copy_data<T: TchElement>(\n    tensor: &mut tch::Tensor,\n    dst: &mut [T],\n    numel: usize,\n) -> Result<(), tch::TchError> {\n    if T::kind() != tensor.f_kind()? {\n        return Err(tch::TchError::Kind(format!(\n            \"incoherent elt kind, {:?} != {:?}\",\n            tensor.f_kind(),\n            T::kind()\n        )));\n    }\n    if dst.len() < numel {\n        return Err(tch::TchError::Shape(format!(\"slice len < {numel}\")));\n    }\n\n    unsafe {\n        torch_sys::at_copy_data(\n            tensor.as_mut_ptr(),\n            dst.as_mut_ptr() as *const c_void,\n            numel,\n            T::kind().elt_size_in_bytes(),\n        );\n        match ptr_to_string(torch_sys::get_and_reset_last_err()) {\n            None => Ok(()),\n            Some(c_error) => Err(tch::TchError::Torch(c_error)),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_backend::ops::FloatTensorOps;\n    use burn_backend::{Backend, quantization::QuantScheme, read_sync};\n\n    type B = crate::LibTorch<f32>;\n\n    #[test]\n    fn should_have_bf16_kind() {\n        let data = TensorData::from([4.0, 4.0]);\n        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());\n        let tensor_2 = B::float_cast(tensor_1, DType::BF16.into());\n\n        assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16);\n\n        let out = read_sync(B::float_into_data(tensor_2)).unwrap();\n\n        out.assert_eq(&TensorData::from([4.0, 4.0]), false);\n    }\n\n    #[test]\n    fn should_support_dtypes() {\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F64));\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::Flex32));\n        assert!(B::supports_dtype(&device, DType::F16));\n        assert!(B::supports_dtype(&device, DType::BF16));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::I16));\n        assert!(B::supports_dtype(&device, DType::I8));\n        assert!(B::supports_dtype(&device, DType::U8));\n        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n\n        assert!(!B::supports_dtype(&device, DType::U64));\n        assert!(!B::supports_dtype(&device, DType::U32));\n        assert!(!B::supports_dtype(&device, DType::U16));\n        assert!(!B::supports_dtype(\n            &device,\n            DType::QFloat(QuantScheme::default())\n        ));\n    }\n\n    #[test]\n    fn should_support_from_bf16() {\n        let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16);\n        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());\n        let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16);\n        let tensor_2 = B::float_from_data(data, &Default::default());\n\n        let tensor_3 = B::float_add(tensor_1, tensor_2);\n\n        assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16);\n\n        let out = read_sync(B::float_into_data(tensor_3)).unwrap();\n\n        out.assert_eq(&TensorData::from([[3.0], [3.0]]), false);\n    }\n}\n\nunsafe extern \"C\" {\n    /// Dummy function to get CUDA to link properly\n    pub fn dummy_cuda_dependency();\n}\n\n#[used]\nstatic INIT_ARRAY: [unsafe extern \"C\" fn(); 1] = [dummy_cuda_dependency];\n"
  },
  {
    "path": "crates/burn-tensor/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\", \"no-std\", \"embedded\", \"wasm\"]\ndescription = \"Tensor library with user-friendly APIs and automatic differentiation support\"\ndocumentation = \"https://docs.rs/burn-tensor\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-tensor\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\"]\ndoc = [\"default\"]\nstd = [\n    \"num-traits/std\",\n    \"burn-std/std\",\n    \"burn-backend/std\",\n    \"colored\",\n]\ntracing = [\n    \"burn-std/tracing\",\n    \"burn-backend/tracing\",\n]\n\ncubecl = [\"burn-std/cubecl\", \"burn-backend/cubecl\"]\ncubecl-cuda = [\"burn-backend/cubecl-cuda\"]\ncubecl-hip = [\"burn-backend/cubecl-hip\"]\ncubecl-wgpu = [\"burn-backend/cubecl-wgpu\"]\ncubecl-cpu = [\"burn-backend/cubecl-cpu\"]\n\n[dependencies]\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false }\n\ncolored = { workspace = true, optional = true }\nderive-new = { workspace = true }\nnum-traits = { workspace = true }\n\n# Device\nhashbrown = { workspace = true }\nspin = { workspace = true }\nthiserror = { workspace = true }\n\n# Serialization\nserde = { workspace = true }\n\n[target.'cfg(not(target_has_atomic = \"ptr\"))'.dependencies]\nportable-atomic-util = { workspace = true }\n\n[dev-dependencies]\nserial_test = { workspace = true }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\", \"--html-in-header\", \"katex-header.html\"]\n"
  },
  {
    "path": "crates/burn-tensor/README.md",
    "content": "# Burn Tensor\n\n> [Burn](https://github.com/tracel-ai/burn) Tensor Library\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor.svg)](https://crates.io/crates/burn-tensor)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tensor/blob/master/README.md)\n\nThis library provides the core abstractions required to run tensor operations with Burn.\n\n`Tensor`s are generic over the backend to allow users to perform operations using different\n`Backend` implementations. Burn's tensors also support auto-differentiation thanks to the\n`AutodiffBackend` trait.\n"
  },
  {
    "path": "crates/burn-tensor/src/device.rs",
    "content": "use alloc::format;\nuse alloc::string::String;\nuse burn_backend::{Backend, Device, DeviceId, DeviceOps};\nuse burn_std::stub::RwLock;\nuse burn_std::{DType, FloatDType, IntDType};\n\n#[cfg(target_has_atomic = \"ptr\")]\nuse alloc::sync::Arc;\n\n#[cfg(not(target_has_atomic = \"ptr\"))]\nuse portable_atomic_util::Arc;\nuse thiserror::Error;\n\nuse core::any::TypeId;\n\n#[cfg(feature = \"std\")]\npub use std::collections::HashMap;\n#[cfg(feature = \"std\")]\nuse std::sync::LazyLock;\n\n#[cfg(not(feature = \"std\"))]\npub use hashbrown::HashMap;\n#[cfg(not(feature = \"std\"))]\nuse spin::Lazy as LazyLock;\n\n/// Policy controlling default device behavior.\n///\n/// This includes default data types used for tensor creation.\n#[derive(Debug, Clone, Copy, Default)]\npub(crate) struct DevicePolicy {\n    /// Default floating-point data type for tensor creation.\n    float_dtype: Option<FloatDType>,\n    /// Default integer data type for tensor creation.\n    int_dtype: Option<IntDType>,\n}\n\nimpl DevicePolicy {\n    /// Returns the default floating-point data type used for tensor creation.\n    pub(crate) fn float_dtype(&self) -> Option<FloatDType> {\n        self.float_dtype\n    }\n\n    /// Returns the default integer data type used for tensor creation.\n    pub(crate) fn int_dtype(&self) -> Option<IntDType> {\n        self.int_dtype\n    }\n\n    /// Sets the default floating-point data type.\n    pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) {\n        self.float_dtype = Some(dtype);\n    }\n\n    /// Sets the default integer data type.\n    pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) {\n        self.int_dtype = Some(dtype);\n    }\n}\n\n/// Key for the registry: physical device type + device id\ntype RegistryKey = (DeviceId, TypeId);\n\n/// Global registry mapping devices to their policies.\nstatic REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<DevicePolicy>>>> =\n    LazyLock::new(|| RwLock::new(HashMap::new()));\n\n/// Device policy management for controlling default tensor creation behavior.\n///\n/// # Policy Semantics\n///\n/// Device policies use snapshot semantics: when you retrieve a policy with\n/// [`get_device_policy`], you get an immutable snapshot of the current configuration.\n/// Updates to the policy (via [`set_default_dtypes`], [`set_default_float_dtype`], etc.)\n/// only affect future policy retrievals, not existing references.\n///\n/// This is intended for the common case where policies are set once during\n/// initialization and then read frequently during tensor creation.\nstruct DevicePolicyRegistry;\n\nimpl DevicePolicyRegistry {\n    /// Get the policy for a physical device type and device id.\n    ///\n    /// If no policy exists yet, a default one is created and stored.\n    fn get<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {\n        let key = Self::key(device);\n\n        if let Some(policy) = REGISTRY.read().unwrap().get(&key) {\n            return Arc::clone(policy);\n        }\n\n        let mut map = REGISTRY.write().unwrap();\n        Arc::clone(\n            map.entry(key)\n                .or_insert_with(|| Arc::new(DevicePolicy::default())),\n        )\n    }\n\n    /// Mutate the policy for a given device.\n    fn update<D: DeviceOps>(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) {\n        let key = Self::key(device);\n        let mut map = REGISTRY.write().unwrap();\n\n        let policy = map\n            .entry(key)\n            .or_insert_with(|| Arc::new(DevicePolicy::default()));\n\n        // Update the policy\n        let policy_mut = Arc::make_mut(policy);\n        update_fn(policy_mut);\n    }\n\n    /// Returns the device registry key.\n    fn key<D: Device>(device: &D) -> RegistryKey {\n        (device.to_id(), TypeId::of::<D>())\n    }\n}\n\n/// Get the [`device`'s policy](DevicePolicy).\n///\n/// Returns an immutable snapshot of the device's current policy. If the policy\n/// is updated after retrieval, this snapshot will not reflect those changes.\npub(crate) fn get_device_policy<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {\n    DevicePolicyRegistry::get(device)\n}\n\n/// Errors that can occur during device-related operations.\n///\n/// This covers errors related to hardware capability mismatches, such as\n/// requesting a data type not supported by the device, and configuration\n/// errors like attempting to change a policy in an invalid context.\n#[derive(Debug, Error)]\npub enum DeviceError {\n    /// Unsupported data type by the device.\n    #[error(\"Device {device} does not support the requested data type {dtype:?}\")]\n    UnsupportedDType {\n        /// The string representation of the device.\n        device: String,\n        /// The data type that caused the error.\n        dtype: DType,\n    },\n    // TODO: `InvalidContext` if a device policy cannot be changed after init / during training / etc.\n}\n\nimpl DeviceError {\n    /// Helper to create a [`DeviceError::UnsupportedDType`] from any device.\n    pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {\n        Self::UnsupportedDType {\n            device: format!(\"{device:?}\"),\n            dtype,\n        }\n    }\n}\n\nfn check_dtype_support<B: Backend>(\n    device: &B::Device,\n    dtype: impl Into<DType>,\n) -> Result<(), DeviceError> {\n    let dtype = dtype.into();\n    // Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized\n    // operations should not be used as default.\n    if B::supports_dtype(device, dtype) {\n        Ok(())\n    } else {\n        Err(DeviceError::unsupported_dtype(device, dtype))\n    }\n}\n\n/// Sets the default data types for the device.\n///\n/// This updates the device's default data types used for tensor creation.\n/// The policy should typically be set once during initialization and then\n/// remains global for all subsequent operations on that device.\n///\n/// # Example\n///\n/// ```rust\n/// use burn_tensor::backend::Backend;\n/// use burn_tensor::{DType, Int, Tensor, set_default_dtypes};\n///\n/// fn example<B: Backend>() {\n///     let device = B::Device::default();\n///     \n///     // Update the device policy\n///     set_default_dtypes::<B>(&device, DType::F16, DType::I32);\n///     \n///     // All float tensors created after this will use F16 by default\n///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);\n///     // All int tensors created after this will use I32 default\n///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);\n/// }\n/// ```\npub fn set_default_dtypes<B: Backend>(\n    device: &B::Device,\n    float_dtype: impl Into<FloatDType>,\n    int_dtype: impl Into<IntDType>,\n) -> Result<(), DeviceError> {\n    let float_dtype = float_dtype.into();\n    let int_dtype = int_dtype.into();\n    check_dtype_support::<B>(device, float_dtype)?;\n    check_dtype_support::<B>(device, int_dtype)?;\n\n    set_default_dtypes_unchecked(device, float_dtype, int_dtype);\n    Ok(())\n}\n\n/// Sets the default floating-point data type for the device.\n///\n/// This updates the device's default data types used for tensor creation.\n/// The policy should typically be set once during initialization and then\n/// remains global for all subsequent operations on that device.\n///\n/// # Example\n///\n/// ```rust\n/// use burn_tensor::backend::Backend;\n/// use burn_tensor::{DType, Tensor, set_default_float_dtype};\n///\n/// fn example<B: Backend>() {\n///     let device = B::Device::default();\n///     \n///     // Update the device policy\n///     set_default_float_dtype::<B>(&device, DType::F16);\n///     \n///     // All float tensors created after this will use F16 by default\n///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);\n/// }\n/// ```\npub fn set_default_float_dtype<B: Backend>(\n    device: &B::Device,\n    dtype: impl Into<FloatDType>,\n) -> Result<(), DeviceError> {\n    let dtype = dtype.into();\n    check_dtype_support::<B>(device, dtype)?;\n\n    set_default_float_dtype_unchecked(device, dtype);\n    Ok(())\n}\n\n/// Sets the default integer data type for the device.\n///\n/// This updates the device's default data types used for tensor creation.\n/// The policy should typically be set once during initialization and then\n/// remains global for all subsequent operations on that device.\n///\n/// # Example\n///\n/// ```rust\n/// use burn_tensor::backend::Backend;\n/// use burn_tensor::{DType, Int, Tensor, set_default_int_dtype};\n///\n/// fn example<B: Backend>() {\n///     let device = B::Device::default();\n///     \n///     // Update the device policy\n///     set_default_int_dtype::<B>(&device, DType::I32);\n///     \n///     // All int tensors created after this will use I32 default\n///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);\n/// }\n/// ```\npub fn set_default_int_dtype<B: Backend>(\n    device: &B::Device,\n    dtype: impl Into<IntDType>,\n) -> Result<(), DeviceError> {\n    let dtype = dtype.into();\n    check_dtype_support::<B>(device, dtype)?;\n\n    set_default_int_dtype_unchecked(device, dtype);\n    Ok(())\n}\n\n// Unchecked versions\nfn set_default_dtypes_unchecked<D: DeviceOps>(\n    device: &D,\n    float_dtype: FloatDType,\n    int_dtype: IntDType,\n) {\n    DevicePolicyRegistry::update(device, |p| {\n        p.set_float_dtype(float_dtype);\n        p.set_int_dtype(int_dtype);\n    });\n}\n\nfn set_default_float_dtype_unchecked<D: DeviceOps>(device: &D, dtype: FloatDType) {\n    DevicePolicyRegistry::update(device, |p| {\n        p.set_float_dtype(dtype);\n    });\n}\n\nfn set_default_int_dtype_unchecked<D: DeviceOps>(device: &D, dtype: IntDType) {\n    DevicePolicyRegistry::update(device, |p| {\n        p.set_int_dtype(dtype);\n    });\n}\n\n#[cfg(all(test, feature = \"std\"))]\nmod tests {\n    use serial_test::serial;\n\n    use super::*;\n\n    fn clear_registry() {\n        REGISTRY.write().unwrap().clear();\n    }\n\n    #[derive(Clone, Debug, Default, PartialEq, new)]\n    pub struct TestDeviceA {\n        index: u32,\n    }\n\n    impl Device for TestDeviceA {\n        fn from_id(device_id: DeviceId) -> Self {\n            Self {\n                index: device_id.index_id,\n            }\n        }\n\n        fn to_id(&self) -> DeviceId {\n            DeviceId {\n                type_id: 0,\n                index_id: self.index,\n            }\n        }\n\n        fn device_count(_type_id: u16) -> usize {\n            1\n        }\n    }\n\n    impl DeviceOps for TestDeviceA {}\n\n    #[derive(Clone, Debug, Default, PartialEq, new)]\n    pub struct TestDeviceB {\n        index: u32,\n    }\n\n    impl Device for TestDeviceB {\n        fn from_id(device_id: DeviceId) -> Self {\n            Self {\n                index: device_id.index_id,\n            }\n        }\n\n        fn to_id(&self) -> DeviceId {\n            DeviceId {\n                type_id: 0,\n                index_id: self.index,\n            }\n        }\n\n        fn device_count(_type_id: u16) -> usize {\n            1\n        }\n    }\n\n    impl DeviceOps for TestDeviceB {}\n\n    #[test]\n    #[serial]\n    fn default_policy_is_created_and_shared() {\n        clear_registry(); // reset registry for each test\n\n        let device = TestDeviceA::new(0);\n\n        let p1 = get_device_policy(&device);\n        let p2 = get_device_policy(&device);\n\n        assert!(Arc::ptr_eq(&p1, &p2));\n        // Not explicitly set\n        assert!(p1.float_dtype().is_none());\n        assert!(p1.int_dtype().is_none());\n        assert!(p2.float_dtype().is_none());\n        assert!(p2.int_dtype().is_none());\n    }\n\n    #[test]\n    #[serial]\n    fn updated_policy_is_shared() {\n        clear_registry(); // reset registry for each test\n\n        let device = TestDeviceA::new(0);\n\n        // The device policy is meant to be set once at initialization\n        set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);\n        let p1 = get_device_policy(&device);\n        let p2 = get_device_policy(&device);\n\n        assert!(Arc::ptr_eq(&p1, &p2));\n        assert_eq!(p1.float_dtype(), Some(FloatDType::BF16));\n        assert_eq!(p1.int_dtype(), Some(IntDType::I32));\n        assert_eq!(p2.float_dtype(), Some(FloatDType::BF16));\n        assert_eq!(p2.int_dtype(), Some(IntDType::I32));\n    }\n\n    #[test]\n    #[serial]\n    fn policy_is_device_id_specific() {\n        clear_registry(); // reset registry for each test\n\n        let d1 = TestDeviceA::new(0);\n        let d2 = TestDeviceA::new(1);\n\n        set_default_float_dtype_unchecked(&d1, FloatDType::F16);\n\n        let p1 = get_device_policy(&d1);\n        let p2 = get_device_policy(&d2);\n\n        assert!(!Arc::ptr_eq(&p1, &p2));\n        assert_eq!(p1.float_dtype(), Some(FloatDType::F16));\n        assert!(p1.int_dtype().is_none());\n        assert!(p2.float_dtype().is_none());\n        assert!(p2.int_dtype().is_none());\n    }\n\n    #[test]\n    #[serial]\n    fn policy_is_device_type_specific() {\n        clear_registry(); // reset registry for each test\n\n        let d1 = TestDeviceA::new(0);\n        let d2 = TestDeviceB::new(0);\n\n        set_default_float_dtype_unchecked(&d2, FloatDType::F16);\n\n        let p1 = get_device_policy(&d1);\n        let p2 = get_device_policy(&d2);\n\n        assert!(p1.float_dtype().is_none());\n        assert!(p1.int_dtype().is_none());\n        assert_eq!(p2.float_dtype(), Some(FloatDType::F16));\n        assert!(p2.int_dtype().is_none());\n    }\n\n    #[test]\n    #[serial]\n    fn updating_policy_should_not_affect_snapshot() {\n        clear_registry(); // reset registry for each test\n\n        // The device policy is meant to be set once at initialization\n        let device = TestDeviceA::new(0);\n        let before = get_device_policy(&device);\n\n        set_default_float_dtype_unchecked(&device, FloatDType::BF16);\n\n        let after = get_device_policy(&device);\n\n        assert!(!Arc::ptr_eq(&before, &after));\n        assert_eq!(after.float_dtype(), Some(FloatDType::BF16));\n        assert!(before.float_dtype().is_none());\n    }\n\n    #[test]\n    #[serial]\n    fn set_default_dtypes_overwrites_fields() {\n        clear_registry(); // reset registry for each test\n\n        let device = TestDeviceA::new(0);\n\n        set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64);\n\n        let policy = get_device_policy(&device);\n\n        assert_eq!(policy.float_dtype(), Some(FloatDType::F16));\n        assert_eq!(policy.int_dtype(), Some(IntDType::I64));\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/lib.rs",
    "content": "#![cfg_attr(not(feature = \"std\"), no_std)]\n#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! This library provides the core abstractions required to run tensor operations with Burn.\n//! `Tensor`s are generic over the backend to allow users to perform operations using different `Backend` implementations.\n//! Burn's tensors also support auto-differentiation thanks to the `AutodiffBackend` trait.\n\n#[macro_use]\nextern crate derive_new;\n\nextern crate alloc;\n\nmod tensor;\n\npub(crate) use tensor::check::macros::check;\npub use tensor::*;\n\n// Re-exported types\npub use burn_backend::{AllocationProperty, Bytes, StreamId, bf16, f16, read_sync, try_read_sync};\n\nmod device;\npub use device::*;\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/activation/base.rs",
    "content": "use crate::backend::Backend;\nuse crate::check::TensorCheck;\nuse crate::{Tensor, TensorPrimitive, check, s};\n\n/// Applies the rectified linear unit function element-wise\n/// as described in the paper [Deep Learning using Rectified Linear Units (ReLU)](https://arxiv.org/pdf/1803.08375).\n///\n#[cfg_attr(doc, doc = \"$$\\\\text{ReLU}\\\\(x\\\\) = \\\\(x\\\\)^+ = \\\\max\\\\(0, x\\\\)$$\")]\n#[cfg_attr(not(doc), doc = \"`ReLU(x) = max(0, x)`\")]\npub fn relu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.relu()\n}\n\n/// Applies the leaky rectified linear unit function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{LeakyReLU}\\(x\\) = \\max\\(0,x\\) + \\text{negative\\\\_slope} \\cdot \\min\\(0, x\\)\n$$\n\nor\n\n$$\n\\text{LeakyReLU}(x) =\n \\begin{cases}\n     x & \\text{if } x \\geq 0 \\newline\n     \\text{negative\\\\_slope} \\cdot x & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`f(x) =`\\n- `x for x >= 0`\\n- `negative_slope * x if x < 0`\"\n)]\npub fn leaky_relu<const D: usize, B: Backend>(\n    tensor: Tensor<B, D>,\n    negative_slope: f64,\n) -> Tensor<B, D> {\n    Tensor::from_primitive(TensorPrimitive::Float(B::leaky_relu(\n        tensor.primitive.tensor(),\n        negative_slope.into(),\n    )))\n}\n\n/// Applies the Gaussian Error Linear Units function as described in the paper\n/// [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{GELU}(x)\n= x \\cdot \\Phi(x)\n= x \\cdot \\frac{1}{2}\\left(1 + \\text{erf}\\left(\\frac{x}{\\sqrt{2}}\\right)\\right)\n$$\n\nwhere $\\Phi(x)$ is the cumulative distribution function for the Gaussian distribution.\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = r#\"\n`GELU(x) = x * Φ(x) = x * 1/2 * (1 + erf(x / sqrt(2)))`\n\nwhere `Φ(x)` is the cumulative distribution function for the Gaussian distribution.\n\"#\n)]\npub fn gelu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    Tensor::from_primitive(TensorPrimitive::Float(B::gelu(tensor.primitive.tensor())))\n}\n\n/// Applies the tanh-based approximate GELU function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{GELU\\_approx}(x)\n= \\frac{x}{2}\\left(1 + \\tanh\\left(\\sqrt{\\frac{2}{\\pi}}\\left(x + 0.044715\\,x^3\\right)\\right)\\right)\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`GELU_approx(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`\"\n)]\npub fn gelu_approximate<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    /// sqrt(2/π) precomputed as FRAC_2_SQRT_PI * FRAC_1_SQRT_2\n    const SQRT_2_OVER_PI: f64 =\n        core::f64::consts::FRAC_2_SQRT_PI * core::f64::consts::FRAC_1_SQRT_2;\n\n    let x = tensor;\n    let inner = x.clone() + x.clone().powf_scalar(3.0) * 0.044715;\n    let inner = inner * SQRT_2_OVER_PI;\n    (x.clone() * (inner.tanh() + 1)) * 0.5\n}\n\n/// Applies Parametric ReLu activation function as described in the paper\n/// [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).\n///\n/// - The tensor is assumed to be of shape `[batch_size, channels, ...]`.\n/// - `alpha` is assumed to be of shape `[channels]` or `[1]`.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{PReLU}\\(x\\) = \\max\\(0,x\\) + \\alpha \\cdot \\min\\(0, x\\)\n$$\n\nor\n\n$$\n\\text{PReLU}(x) =\n \\begin{cases}\n     x & \\text{if } x \\geq 0 \\newline\n     \\alpha x & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`PReLu(x) = max(0,x) + alpha * min(0,x)`\")]\npub fn prelu<const D: usize, B: Backend>(\n    tensor: Tensor<B, D>,\n    alpha: Tensor<B, 1>,\n) -> Tensor<B, D> {\n    check!(TensorCheck::check_prelu_shape::<D>(\n        &tensor.shape(),\n        &alpha.shape()\n    ));\n\n    let weight = if alpha.dims()[0] == 1 {\n        // if there is only 1 weight, then reshape it to (1,1,1... D times) so that the rank is D\n        alpha.reshape([1; D])\n    } else {\n        // D>=2 because the case where D==1 and num_weights >1 is handled by check function\n        // there is more than 1 weight and rank is more than 2\n        let num_weights = alpha.dims()[0];\n        let mut s = [1; D];\n        s[1] = num_weights;\n        // reshape the weights to (1, channels,1 ...)\n        alpha.reshape(s)\n    };\n\n    Tensor::from_primitive(TensorPrimitive::Float(B::prelu(\n        tensor.primitive.tensor(),\n        weight.primitive.tensor(),\n    )))\n}\n\n/// Applies the softmax function on the input tensor along the given dimension.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{softmax}\\(x_i\\) = \\frac{\\exp\\(x_i\\)}{\\sum_j \\exp\\(x_j\\)}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`softmax(x_i) = exp(x_i) / sum_j(exp(x_j))`\")]\n///\n/// # Arguments\n/// - `dim`: the dimension along which Softmax will be computed.\n///\n/// # Panics\n/// - If `dim` is outside [0, D)\npub fn softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    check!(TensorCheck::dim_ops::<D>(\"softmax\", dim));\n\n    let tensor = tensor.clone() - tensor.detach().max_dim(dim);\n    let tensor = tensor.exp();\n    let tensor_tmp = tensor.clone().sum_dim(dim);\n\n    tensor.div(tensor_tmp)\n}\n\n/// Applies the softmin function on the input tensor along the given dimension.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{softmin}\\(x_i\\) = \\frac{\\exp\\(-x_i\\)}{\\sum_j \\exp\\(-x_j\\)}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`softmin(x_i) = exp(-x_i) / sum_j(exp(-x_j)`\")]\n///\n/// # Arguments\n/// - `dim`: the dimension along which Softmax will be computed.\n///\n/// # Panics\n/// - If `dim` is outside [0, D)\npub fn softmin<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    check!(TensorCheck::dim_ops::<D>(\"softmin\", dim));\n    softmax(tensor.neg(), dim)\n}\n\n/// Applies the SoftPlus function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{softplus}\\(x\\) = \\frac{1}{\\beta}\\log\\(1 + \\exp\\(\\beta x\\)\\)\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`softplus(x_i) = log(1 + exp(beta * x_i)) / beta`\")]\n///\n/// The SoftPlus function is a smooth approximation of the ReLU function.\npub fn softplus<const D: usize, B: Backend>(tensor: Tensor<B, D>, beta: f64) -> Tensor<B, D> {\n    let tensor = (tensor.mul_scalar(beta).exp() + 1).log();\n    tensor.div_scalar(beta)\n}\n\n/// Applies the \"quiet softmax\" function on the input tensor along the given dimension.\n///\n/// Also referred to as [`softmax1`](https://www.evanmiller.org/attention-is-off-by-one.html).\n///\n/// This function is similar to the softmax function, but it allows for \"no selection\" when\n/// all the outputs are close to zero.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{quiet\\\\_softmax}\\(x_i\\) = \\frac{\\exp\\(x_i\\)}{1 + \\sum_j \\exp\\(x_j\\)}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`quiet_softmax(x_i) = exp(x_i) / [ 1 + sum_j(exp(x_j)) ]`\"\n)]\n///\n/// # Arguments\n/// - `dim`: the dimension along which Softmax will be computed.\n///\n/// # Panics\n/// - If `dim` is outside [0, D)\npub fn quiet_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    check!(TensorCheck::dim_ops::<D>(\"softmax\", dim));\n\n    let max_vals = tensor.clone().detach().max_dim(dim);\n    let exp_x = (tensor - max_vals.clone()).exp();\n    let sum_exp = exp_x.clone().sum_dim(dim);\n\n    exp_x.div(sum_exp + max_vals.neg().exp())\n}\n\n/// Applies the log softmax function on the input tensor along the given dimension.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{log\\\\_softmax}\\(x_i\\)\n= \\log\\left(\\text{softmax}\\(x_i\\)\\right)\n= \\log\\left(\\frac{\\exp\\(x_i\\)}{\\sum_j \\exp\\(x_j\\)}\\right)\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`log_softmax(x_i) = log(softmax(x_i)) = log(exp(x_i) / sum_j(exp(x_j)))`\"\n)]\n///\n/// # Arguments\n/// - `dim`: the dimension along which Softmax will be computed.\n///\n/// # Panics\n/// - If `dim` is outside [0, D)\npub fn log_softmax<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    check!(TensorCheck::dim_ops::<D>(\"log softmax\", dim));\n\n    let tensor = tensor.clone() - tensor.detach().max_dim(dim);\n    let tensor_tmp = tensor.clone().exp().sum_dim(dim).log();\n\n    tensor.sub(tensor_tmp)\n}\n\n/// Applies the sigmoid function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{sigmoid}\\(x\\)\n= \\sigma(x)\n= \\frac{1}{1 + \\exp(-x)}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`sigmoid(x) = 1 / (1 + exp(-x))`\")]\npub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    Tensor::from_primitive(TensorPrimitive::Float(B::sigmoid(\n        tensor.primitive.tensor(),\n    )))\n}\n\n/// Applies the hard sigmoid function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{hard\\\\_sigmoid}\\(x\\) = \\max(0, \\min(1, \\alpha \\cdot x + \\beta))\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`\")]\npub fn hard_sigmoid<const D: usize, B: Backend>(\n    tensor: Tensor<B, D>,\n    alpha: f64,\n    beta: f64,\n) -> Tensor<B, D> {\n    Tensor::from_primitive(TensorPrimitive::Float(B::hard_sigmoid(\n        tensor.primitive.tensor(),\n        alpha.into(),\n        beta.into(),\n    )))\n}\n\n/// Applies the log sigmoid function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{log\\\\_sigmoid}\\(x\\) = \\log\\left(\\frac{1}{1 + \\exp(-x)}\\right)\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`log_sigmoid(x) = log(1 / (1 + exp(-x)))`\")]\npub fn log_sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    Tensor::from_primitive(TensorPrimitive::Float(B::log_sigmoid(\n        tensor.primitive.tensor(),\n    )))\n}\n\n/// Applies the SiLU function (also known as the swish function) element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{SiLU}\\(x\\) = x \\cdot \\sigma(x) = \\frac{x}{1 + \\exp(-x)}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))`\")]\npub fn silu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.clone().mul(sigmoid(tensor))\n}\n\n/// Applies the hard swish function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{hard\\_swish}\\(x\\) = x \\cdot \\text{hard\\_sigmoid}(x) = x \\cdot \\max(0, \\min(1, \\frac{x}{6} + 0.5))\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`hard_swish(x) = x * hard_sigmoid(x) = x * max(0, min(1, x/6 + 0.5))`\"\n)]\npub fn hard_swish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.clone().mul(hard_sigmoid(tensor, 1.0 / 6.0, 0.5))\n}\n\n/// Applies the Mish function as described in the paper in\n/// [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{Mish}\\(x\\)\n= x \\cdot \\tanh(\\text{Softplus}(x))\n= \\tanh\\left(\\log\\(1 + \\exp\\(x\\)\\)\\right)\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`mish(x) = x * tanh(softplus(x)) = tanh(log(1 + exp(x)))`\"\n)]\npub fn mish<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.clone().mul(softplus(tensor, 1.0).tanh())\n}\n\n/// Applies the tanh function element-wise.\npub fn tanh<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.tanh()\n}\n\n/// Applies the Exponential Linear Unit function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{ELU}\\(x\\) =\n \\begin{cases}\n     x & \\text{if } x > 0 \\newline\n     \\alpha \\cdot (\\exp(x) - 1) & \\text{if } x \\leq 0\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`f(x) =`\\n- `x for x > 0`\\n- `alpha * (exp(x) - 1) for x <= 0`\"\n)]\npub fn elu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {\n    let mask = tensor.clone().lower_equal_elem(0);\n    let scaled = tensor.clone().exp().sub_scalar(1).mul_scalar(alpha);\n    tensor.mask_where(mask, scaled)\n}\n\n/// Applies the Continuously Differentiable Exponential Linear Unit function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{CELU}(x) =\n \\begin{cases}\n     x & \\text{if } x \\geq 0 \\newline\n     \\alpha \\cdot \\left(\\exp\\left(\\frac{x}{\\alpha}\\right) - 1\\right) & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))`\"\n)]\n///\n/// See also [CELU](https://pytorch.org/docs/stable/generated/torch.nn.CELU.html)\n///\n/// # Arguments\n/// - `alpha`: scaling parameter for the negative part.\npub fn celu<const D: usize, B: Backend>(tensor: Tensor<B, D>, alpha: f64) -> Tensor<B, D> {\n    let mask = tensor.clone().lower_equal_elem(0);\n    let scaled = tensor\n        .clone()\n        .div_scalar(alpha)\n        .exp()\n        .sub_scalar(1)\n        .mul_scalar(alpha);\n    tensor.mask_where(mask, scaled)\n}\n\n/// Applies the Scaled Exponential Linear Unit function element-wise\n/// as described in the paper [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{SELU}\\(x\\) = \\gamma \\cdot\n \\begin{cases}\n     x & \\text{if } x > 0 \\newline\n     \\alpha \\cdot (\\exp(x) - 1) & \\text{if } x \\leq 0\n \\end{cases}\n$$\n\nwhere $\\alpha \\approx 1.6733$ and $\\gamma \\approx 1.0507$.\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`selu(x) = gamma * x if x > 0, gamma * alpha * (exp(x) - 1) if x <= 0`\"\n)]\npub fn selu<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    // Constants from the SELU paper / ONNX spec\n    const ALPHA: f64 = 1.6732632423543772848170429916717_f64;\n    const GAMMA: f64 = 1.0507009873554804934193349852946_f64;\n\n    let mask = tensor.clone().greater_equal_elem(0.0);\n    let positive = tensor.clone().mul_scalar(GAMMA);\n    let negative = tensor.exp().sub_scalar(1.0).mul_scalar(ALPHA * GAMMA);\n\n    negative.mask_where(mask, positive)\n}\n\n/// Applies the thresholded rectified linear unit function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{ThresholdedReLU}(x) =\n \\begin{cases}\n     x & \\text{if } x > \\alpha \\newline\n     0 & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`f(x) =`\\n- `x if x > alpha`\\n- `0 otherwise`\")]\n///\n/// # Arguments\n/// - `alpha`: threshold value (default in ONNX is 1.0).\npub fn thresholded_relu<const D: usize, B: Backend>(\n    tensor: Tensor<B, D>,\n    alpha: f64,\n) -> Tensor<B, D> {\n    let mask = tensor.clone().lower_equal_elem(alpha);\n    tensor.mask_fill(mask, 0)\n}\n\n/// Applies the gated linear unit function.\n///\n/// GLU(a,b)=a⊗σ(b) where `a` is the first half of the input matrices and `b` is the second half.\n///\n/// **Note**:\n/// * The size of the input tensor along `dim` must be divisible by 2.\n///\n/// ### Arguments\n/// * `tensor` - The input tensor.\n///\n/// ### Returns\n/// * A tensor with the same shape as the input, except the size along `dim` is halved.\npub fn glu<const D: usize, B: Backend>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    // TODO: Handle negative indices with AsIndex for compatibility with Pytorch nn.GLU.\n\n    assert!(\n        tensor.dims()[dim].is_multiple_of(2),\n        \"Input tensor along dimension {dim} must have an even size. N is divisible by 2.\"\n    );\n    let new_len = tensor.dims()[dim] / 2;\n\n    let a = tensor.clone().slice_dim(dim, s![0..new_len]);\n    let b = tensor.slice_dim(dim, s![new_len..new_len * 2]);\n\n    a.mul(sigmoid(b))\n}\n\n/// Applies the Softsign function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{softsign}(x) = \\frac{x}{1 + |x|}\n$$\n\"#\n)]\n#[cfg_attr(not(doc), doc = \"`softsign(x_i) = x_i / (1 + |x_i|)`\")]\npub fn softsign<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    tensor.clone().div(tensor.abs() + 1)\n}\n\n/// Applies the HardShrink function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{hard\\_shrink}(x) =\n \\begin{cases}\n     x & \\text{if } x > \\lambda \\newline\n     x & \\text{if } x < -\\lambda \\newline\n     0 & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`\"\n)]\n/// # Arguments\n/// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5.\npub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {\n    let mask = tensor.clone().abs().lower_equal_elem(lambda);\n    tensor.mask_fill(mask, 0)\n}\n\n/// Applies the SoftShrink function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{soft\\_shrink}(x) =\n \\begin{cases}\n     x - \\lambda & \\text{if } x > \\lambda \\newline\n     x + \\lambda & \\text{if } x < -\\lambda \\newline\n     0 & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`\"\n)]\n/// # Arguments\n/// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5.\npub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {\n    shrink(tensor, lambda, lambda)\n}\n\n/// Applies the Shrink function element-wise.\n///\n#[cfg_attr(\n    doc,\n    doc = r#\"\n$$\n\\text{shrink}(x) =\n \\begin{cases}\n     x - \\text{bias} & \\text{if } x > \\lambda \\newline\n     x + \\text{bias} & \\text{if } x < -\\lambda \\newline\n     0 & \\text{otherwise}\n \\end{cases}\n$$\n\"#\n)]\n#[cfg_attr(\n    not(doc),\n    doc = \"`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`\"\n)]\n/// # Arguments\n/// - `lambda`: the lambda value for the Shrink formulation.\n/// - `bias`: the bias value for the Shrink formulation.\npub fn shrink<const D: usize, B: Backend>(\n    tensor: Tensor<B, D>,\n    lambda: f64,\n    bias: f64,\n) -> Tensor<B, D> {\n    let abs_tensor = tensor.clone().abs();\n    let sign = tensor.clone().sign();\n    let shrunk = tensor.sub(sign.mul_scalar(bias));\n    let mask = abs_tensor.lower_equal_elem(lambda);\n    shrunk.mask_fill(mask, 0)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/activation/mod.rs",
    "content": "mod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/autodiff.rs",
    "content": "pub use burn_backend::tensor::BasicAutodiffOps;\n\nuse crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};\n\nimpl<const D: usize, B: AutodiffBackend> Tensor<B, D> {\n    /// Backward pass of the tensor.\n    pub fn backward(&self) -> B::Gradients {\n        B::backward(self.primitive.clone().tensor())\n    }\n\n    /// Get the gradients of a tensor if it exist.\n    ///\n    /// Returns a new reference to the same tensor. Therefore the same grad tensor can\n    /// be accessed multiple times. If you only need to get the gradients one time,\n    /// consider using [grad_remove](Tensor::grad_remove) for better performance.\n    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {\n        match &self.primitive {\n            TensorPrimitive::Float(tensor) => B::grad(tensor, grads)\n                .map(TensorPrimitive::Float)\n                .map(Tensor::new),\n            TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)\n                .map(TensorPrimitive::Float)\n                .map(Tensor::new),\n        }\n    }\n\n    /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.\n    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {\n        match &self.primitive {\n            TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)\n                .map(TensorPrimitive::Float)\n                .map(Tensor::new),\n            TensorPrimitive::QFloat(_tensor) => {\n                B::grad_remove(&self.primitive.clone().tensor(), grads)\n                    .map(TensorPrimitive::Float)\n                    .map(Tensor::new)\n            }\n        }\n    }\n\n    /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided\n    /// gradient.\n    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {\n        match &self.primitive {\n            TensorPrimitive::Float(tensor) => {\n                B::grad_replace(tensor, grads, grad.primitive.tensor())\n            }\n            TensorPrimitive::QFloat(_tensor) => B::grad_replace(\n                &self.primitive.clone().tensor(),\n                grads,\n                grad.primitive.tensor(),\n            ),\n        }\n    }\n}\n\nimpl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {\n    /// Returns the inner tensor without the autodiff information.\n    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {\n        Tensor::new(K::inner(self.primitive))\n    }\n\n    /// Convert a tensor to the autodiff backend.\n    ///\n    /// # Arguments\n    ///\n    /// * `inner` - The tensor to convert.\n    ///\n    /// # Returns\n    ///\n    /// The tensor converted to the autodiff backend.\n    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {\n        Self::new(K::from_inner(inner.primitive))\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/base.rs",
    "content": "#![allow(clippy::single_range_in_vec_init)]\nuse crate::backend::ExecutionError;\nuse crate::check::unwrap_shape_reshape;\n\nuse burn_backend::Scalar;\npub use burn_backend::tensor::BasicOps;\n\nuse alloc::vec::Vec;\n\nuse alloc::format;\nuse alloc::string::String;\nuse alloc::vec;\n\nuse burn_std::{SliceOps, stub::RwLock};\nuse core::iter::repeat;\nuse core::{fmt::Debug, ops::Range};\nuse serde::{Deserialize, Deserializer};\n\nuse crate::{AsIndex, Slice, SliceArg, wrap_index};\nuse crate::{\n    Bool, ElementConversion, Float, Int, Shape, TensorData, TensorKind, TensorMetadata,\n    backend::Backend, check,\n};\nuse crate::{DType, Element};\nuse crate::{IndexingUpdateOp, TensorCreationOptions};\nuse crate::{cast::ToElement, check::TensorCheck};\nuse serde::{Serialize, Serializer};\n\n/// A tensor with a given backend, shape and data type.\n///\n/// # Indexing\n/// Indexing a tensor can be done using [`slice`](Tensor::slice) for all tensor types\n/// or [`select`](Tensor::select) for numeric types.\n///\n/// ## Example\n///\n/// ```rust\n/// use burn_tensor::backend::Backend;\n/// use burn_tensor::Tensor;\n/// use burn_tensor::Int;\n///\n/// fn example<B: Backend>() {\n///     let device = Default::default();\n///\n///     let tensor = Tensor::<B, 2>::from_data(\n///         [\n///             [3.0, 4.9, 2.0],\n///             [2.0, 1.9, 3.0],\n///             [6.0, 1.5, 7.0],\n///             [3.0, 4.9, 9.0],\n///         ],\n///         &device,\n///     );\n///\n///     // Slice the tensor to get the second and third rows:\n///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]\n///     // The resulting tensor will have dimensions [2, 3].\n///     let slice = tensor.clone().slice([1..3]);\n///     println!(\"{slice}\");\n///\n///     // Slice the tensor to get the first two rows and the first 2 columns:\n///     // [[3.0, 4.9], [2.0, 1.9]]\n///     // The resulting tensor will have dimensions [2, 2].\n///     let slice = tensor.clone().slice([0..2, 0..2]);\n///     println!(\"{slice}\");\n///\n///     // Index the tensor along the dimension 1 to get the elements 0 and 2:\n///     // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]\n///     // The resulting tensor will have dimensions [4, 2]\n///     let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);\n///     let indexed = tensor.select(1, indices);\n///     println!(\"{indexed}\");\n/// }\n/// ```\n#[derive(new, Clone, Debug)]\npub struct Tensor<B, const D: usize, K = Float>\nwhere\n    B: Backend,\n    K: TensorKind<B>,\n{\n    pub(crate) primitive: K::Primitive,\n}\n\nimpl<B, const D: usize, K, T> From<T> for Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n    T: Into<TensorData>,\n{\n    fn from(value: T) -> Self {\n        Tensor::from_data(value.into(), &Default::default())\n    }\n}\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n    K::Elem: Element,\n{\n    /// Executes an operation on the tensor and modifies its value.\n    ///\n    /// # Notes\n    ///\n    /// This won't necessarily reuse the same tensor data/buffer, but it should if there is\n    /// no other reference pointing to the same tensor.\n    ///\n    /// Wrapping operations with inplace is not an optimization, it's mainly there if you\n    /// want to mutate a tensor by using owned operations. A plausible usage would be to\n    /// update the weights of a mutable model reference.\n    pub fn inplace<F: FnOnce(Self) -> Self>(&mut self, func: F) {\n        let mut tensor_owned = Tensor::empty([0; D], &self.device());\n        core::mem::swap(&mut tensor_owned, self);\n\n        let mut tensor_new = func(tensor_owned);\n        core::mem::swap(&mut tensor_new, self);\n    }\n\n    /// Converts the tensor into a primitive tensor.\n    pub fn into_primitive(self) -> K::Primitive {\n        self.primitive\n    }\n\n    /// Converts from a primitive tensor into a tensor.\n    pub fn from_primitive(tensor: K::Primitive) -> Self {\n        Self::new(tensor)\n    }\n\n    /// Returns the number of dimensions of the tensor.\n    pub fn rank(&self) -> usize {\n        self.primitive.rank()\n    }\n\n    /// Returns the tensor primitive data type.\n    ///\n    /// # Note\n    /// Some element types are encoded in different primitive types depending on the backend\n    /// (e.g., bool could be encoded as `u8` or `u32`).\n    pub fn dtype(&self) -> DType {\n        self.primitive.dtype()\n    }\n\n    /// Create an empty tensor of the given shape.\n    ///\n    /// # Arguments\n    ///\n    /// - `shape`: The shape of the tensor.\n    /// - `device`: The device where the tensor will be created.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    // Create an empty tensor with dimensions [2, 3, 4].\n    ///    let tensor = Tensor::<B, 3>::empty([2, 3, 4], &device);\n    /// }\n    /// ```\n    pub fn empty<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {\n        let opt = options.into();\n        let shape = shape.into();\n        let dtype = opt.resolve_policy::<K>();\n        check!(TensorCheck::creation_ops::<D>(\"Empty\", &shape));\n        Self::new(K::empty(shape, &opt.device, dtype))\n    }\n\n    /// Create a tensor of the given shape where each element is zero.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::zeros(Shape::new([2, 3]), &device);\n    ///    println!(\"{tensor}\");\n    ///    // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]\n    /// }\n    /// ```\n    pub fn zeros<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {\n        let opt = options.into();\n        let shape = shape.into();\n        let dtype = opt.resolve_policy::<K>();\n        check!(TensorCheck::creation_ops::<D>(\"Zeros\", &shape));\n        Self::new(K::zeros(shape, &opt.device, dtype))\n    }\n\n    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with zeros.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.zeros_like();\n    ///   println!(\"{tensor}\");\n    ///   // [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]\n    /// }\n    /// ```\n    pub fn zeros_like(&self) -> Self {\n        Self::new(K::zeros(self.shape(), &self.device(), self.dtype()))\n    }\n\n    /// Create a tensor of the given shape where each element is one.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::ones(Shape::new([2, 3]), &device);\n    ///   println!(\"{tensor}\");\n    ///   // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]\n    /// }\n    /// ```\n    pub fn ones<S: Into<Shape>>(shape: S, options: impl Into<TensorCreationOptions<B>>) -> Self {\n        let opt = options.into();\n        let shape = shape.into();\n        let dtype = opt.resolve_policy::<K>();\n        check!(TensorCheck::creation_ops::<D>(\"Ones\", &shape));\n        Self::new(K::ones(shape, &opt.device, dtype))\n    }\n\n    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled with ones.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.ones_like();\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]\n    /// }\n    /// ```\n    pub fn ones_like(&self) -> Self {\n        Self::new(K::ones(self.shape(), &self.device(), self.dtype()))\n    }\n\n    /// Create a tensor of the given shape where each element is equal to the provided value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::full(Shape::new([2, 3]), 5.0, &device);\n    ///   println!(\"{tensor}\");\n    ///   // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]\n    /// }\n    /// ```\n    pub fn full<S: Into<Shape>, E: ElementConversion>(\n        shape: S,\n        fill_value: E,\n        options: impl Into<TensorCreationOptions<B>>,\n    ) -> Self {\n        let opt = options.into();\n        let shape = shape.into();\n        let dtype = opt.resolve_policy::<K>();\n        check!(TensorCheck::creation_ops::<D>(\"Full\", &shape));\n        Self::new(K::full(\n            shape,\n            Scalar::new(fill_value, &dtype),\n            &opt.device,\n            dtype,\n        ))\n    }\n\n    /// Returns a new tensor with the same shape, dtype, and device as the current tensor,\n    /// filled with the provided value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.full_like(5.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[5.0, 5.0, 5.0], [5.0, 5.0, 5.0]]\n    /// }\n    /// ```\n    pub fn full_like<E: ElementConversion>(&self, fill_value: E) -> Self {\n        let dtype = self.dtype();\n        Self::new(K::full(\n            self.shape(),\n            Scalar::new(fill_value, &dtype),\n            &self.device(),\n            dtype,\n        ))\n    }\n\n    /// Returns the dimensions of the current tensor.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = Default::default();\n    ///   let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);\n    ///   let dims = tensor.dims(); // [2, 3, 4]\n    ///   println!(\"{dims:?}\");\n    /// }\n    /// ```\n    pub fn dims(&self) -> [usize; D] {\n        Self::shape(self).dims()\n    }\n\n    /// Returns the shape of the current tensor.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);\n    ///    // Shape { dims: [2, 3, 4] }\n    ///    let shape = tensor.shape();\n    /// }\n    /// ```\n    pub fn shape(&self) -> Shape {\n        self.primitive.shape()\n    }\n\n    /// Reshape the tensor to have the given shape.\n    ///\n    /// The tensor has the same data and number of elements as the input.\n    ///\n    /// A `-1` in the shape is used to infer the remaining dimensions, e.g.: `[2, -1]`\n    /// will reshape the tensor with [2, 3, 4] dimensions to [2, 12].\n    ///\n    /// A `0` in the shape instructs to keep the current dimension from the original tensor,\n    /// e.g.: `[2, 0, 4]` will reshape the tensor with [2, 3, 4] dimensions to [2, 3, 4].\n    /// This is useful when reshaping tensors with unknown dimensions and combining with `-1`\n    /// to infer the remaining dimensions, e.g. `[0, -1]` will reshape the tensor\n    /// with [1, 3, 4] dimensions to [1, 12].\n    ///\n    /// # Arguments\n    /// - `shape`: The new shape of the tensor.\n    ///\n    /// # Panics\n    /// - If the tensor contains more than one `-1` in the shape.\n    /// - If the tensor contains values that are not positive (other than -1).\n    /// - If the shape does not match the number of elements of the original shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    // Create a tensor with dimensions [2, 3, 4]\n    ///    let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);\n    ///    // Reshape it to [2, 12], where 12 is inferred from the number of elements.\n    ///    let reshaped = tensor.reshape([2, -1]);\n    ///    println!(\"{reshaped}\");\n    /// }\n    /// ```\n    pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {\n        // Convert reshape args to shape\n        let shape = shape.into_shape::<D2>(self.shape());\n        Tensor::new(K::reshape(self.primitive, shape))\n    }\n\n    /// Transpose the tensor.\n    ///\n    /// For a 2D tensor, this is the standard matrix transpose. For `D > 2`, the transpose is\n    /// applied on the last two dimensions. For example, the transpose of a tensor with shape\n    /// `[1, 2, 3, 4]` will have shape `[1, 2, 4, 3]`.\n    ///\n    /// See also [`permute`](Tensor::permute).\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to transpose.\n    ///\n    /// # Returns\n    ///\n    /// The transposed tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor of shape [2, 3]\n    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///\n    ///     // Transpose the tensor:\n    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]\n    ///     // The resulting tensor will have dimensions [3, 2].\n    ///     let transposed = tensor.transpose();\n    ///     println!(\"{transposed}\");\n    /// }\n    /// ```\n    pub fn transpose(self) -> Tensor<B, D, K> {\n        Tensor::new(K::transpose(self.primitive))\n    }\n\n    /// Alias for `transpose`.\n    #[inline(always)]\n    pub fn t(self) -> Tensor<B, D, K> {\n        self.transpose()\n    }\n\n    /// Swaps two dimensions of a tensor.\n    ///\n    /// This is a no-op when `dim1 == dim2`, assuming both are within bounds.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to swap the dimensions of.\n    /// * `dim1` - The first dimension to swap, supports negative indexing.\n    /// * `dim2` - The second dimension to swap, supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions swapped.\n    ///\n    /// # Panics\n    ///\n    /// When dimensions are out of bounds.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor of shape [2, 3]\n    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///\n    ///     // Swap the dimensions 0 and -1 (equivalent to `tensor.transpose()`):\n    ///     // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]\n    ///     // The resulting tensor will have dimensions [3, 2].\n    ///     let swapped = tensor.swap_dims(0, -1);\n    ///     println!(\"{swapped}\");\n    /// }\n    /// ```\n    pub fn swap_dims<Dim1, Dim2>(self, dim1: Dim1, dim2: Dim2) -> Tensor<B, D, K>\n    where\n        Dim1: AsIndex,\n        Dim2: AsIndex,\n    {\n        let dim1 = dim1.expect_dim_index(D);\n        let dim2 = dim2.expect_dim_index(D);\n        check!(TensorCheck::swap_dims::<D>(dim1, dim2));\n        if dim1 == dim2 {\n            self\n        } else {\n            Tensor::new(K::swap_dims(self.primitive, dim1, dim2))\n        }\n    }\n\n    /// Permute the dimensions of the tensor.\n    ///\n    /// This is a no-op when the resolved `axes` match the current order.\n    ///\n    /// # Arguments\n    ///\n    /// * `axes` - The new order of the dimensions. The length of the axes\n    ///   must be equal to the number of dimensions of the tensor.\n    ///   The values must be unique and in the range of the number of dimensions.\n    ///   The values can be negative, in which case they are used as an offset from the end.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions permuted.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor of shape [3, 2]\n    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]], &device);\n    ///\n    ///     // Permute the dimensions 1 and 0:\n    ///     // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]\n    ///     // The resulting tensor will have dimensions [3, 2].\n    ///     let permuted = tensor.permute([1, 0]);\n    ///     println!(\"{permuted}\");\n    /// }\n    /// ```\n    pub fn permute<Dim>(self, axes: [Dim; D]) -> Tensor<B, D, K>\n    where\n        Dim: AsIndex,\n    {\n        let mut no_op = true;\n        let mut fixed_axes = [0; D];\n        for (i, axis) in axes.into_iter().enumerate() {\n            let dim = axis.expect_dim_index(D);\n            no_op &= dim == i;\n            fixed_axes[i] = dim;\n        }\n\n        if no_op {\n            self\n        } else {\n            check!(TensorCheck::permute(fixed_axes));\n            Tensor::new(K::permute(self.primitive, &fixed_axes))\n        }\n    }\n\n    /// Moves the dimension(s) of input at the position(s) in source to the position(s) in destination.\n    ///\n    /// Other dimensions of input that are not explicitly moved remain in their original order and appear\n    /// at the positions not specified in destination.\n    ///\n    /// # Arguments\n    ///\n    /// * `src` - The dimension(s) to move. The values must be unique and in the range of the number of dimensions.\n    ///   The values can be negative, in which case they are used as an offset from the end.\n    ///\n    /// * `dst` - Destination positions for each of the original dims. These must also be unique.\n    ///\n    /// # Panics\n    ///\n    /// - If the source and destination dimensions are not of the same length.\n    /// - If the source and destination vectors contain duplicate values.\n    /// - If the source and destination vectors contain values that are out of bounds.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the dimensions moved.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 3D tensor of shape [3, 2, 1]\n    ///     let tensor = Tensor::<B, 3>::from_data([[[1.0], [5.0]], [[-2.0], [9.0]], [[3.0], [6.0]]], &device);\n    ///\n    ///     // Move the dimensions 0 and 1:\n    ///     // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]\n    ///     // The resulting tensor will have dimensions [2, 3, 1].\n    ///     let moved = tensor.movedim(1, 0);\n    ///     println!(\"{moved}\");\n    /// }\n    /// ```\n    ///\n    /// # Note\n    ///\n    /// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op\n    /// for it\n    pub fn movedim<S1: MovedimArgs, S2: MovedimArgs>(self, src: S1, dst: S2) -> Tensor<B, D, K> {\n        let source_dims = src.into_dim_vec::<D>();\n        let destination_dims = dst.into_dim_vec::<D>();\n\n        check!(TensorCheck::movedim_args_length(\n            &source_dims,\n            &destination_dims\n        ));\n\n        let mut m = [-1; D];\n        for (&d, &s) in destination_dims.iter().zip(source_dims.iter()) {\n            m[d] = s as isize;\n        }\n        let mut axes: [isize; D] = [0; D];\n        let mut source_i = 0;\n        for (dest_i, item) in axes.iter_mut().enumerate().take(D) {\n            *item = if m[dest_i] != -1 {\n                m[dest_i]\n            } else {\n                while source_dims.contains(&source_i) {\n                    source_i += 1;\n                }\n                let result = source_i as isize;\n                source_i += 1;\n                result\n            };\n        }\n\n        self.permute(axes)\n    }\n\n    /// Reverse the order of elements in the tensor along the given dimensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.\n    ///   The values can be negative, in which case they are used as an offset from the end.\n    ///\n    /// # Returns\n    ///\n    /// The tensor with the axes flipped.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [4, 3]\n    ///     let tensor = Tensor::<B, 2>::from_data(\n    ///         [\n    ///             [3.0, 4.9, 2.0],\n    ///             [2.0, 1.9, 3.0],\n    ///             [4.0, 5.9, 8.0],\n    ///             [1.4, 5.8, 6.0],\n    ///         ],\n    ///         &device,\n    ///     );\n    ///\n    ///     // Flip the elements in dimensions 0 and 1:\n    ///     // [[6.0, 5.8, 1.4],\n    ///     //  [8.0, 5.9, 4.0],\n    ///     //  [3.0, 1.9, 2.0],\n    ///     //  [2.0, 4.9, 3.0]]\n    ///     // The resulting tensor will have dimensions [4, 3].\n    ///     let flipped = tensor.flip([0, 1]);\n    ///     println!(\"{flipped}\");\n    /// }\n    /// ```\n    pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {\n        // Convert the axes to usize and handle negative values without using vector\n        let mut transformed_axes: [usize; N] = [0; N];\n        for (i, &x) in axes.iter().enumerate() {\n            transformed_axes[i] = if x < 0 {\n                (D as isize + x) as usize\n            } else {\n                x as usize\n            };\n        }\n\n        // Check if the axes are valid\n        check!(TensorCheck::flip(D, &transformed_axes));\n\n        Tensor::new(K::flip(self.primitive, &transformed_axes))\n    }\n\n    /// Flatten the tensor along a given range of dimensions.\n    ///\n    /// This function collapses the specified range of dimensions into a single dimension,\n    /// effectively flattening the tensor in that range.\n    ///\n    /// # Arguments\n    ///\n    /// - `start_dim`: The starting dimension of the range to be flattened,\n    ///   supports negative indexing.\n    /// - `end_dim`: The ending dimension of the range to be flattened (inclusive),\n    ///   supports negative indexing.\n    ///\n    /// # Type Parameters\n    ///\n    /// - `D2`: The resulting number of dimensions in the flattened tensor.\n    ///\n    /// # Returns\n    ///\n    /// A new `Tensor<B, D2, K>` instance with the specified range of dimensions flattened.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    ///\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 3D tensor with dimensions [2, 3, 4]\n    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 4]), &device);\n    ///\n    ///     // Flatten the tensor from dimensions 1 to 2 (inclusive).\n    ///     // The resulting tensor will have dimensions [2, 12]\n    ///     let flattened: Tensor<B, 2> = tensor.flatten(1, 2);\n    ///     println!(\"{flattened}\");\n    /// }\n    /// ```\n    pub fn flatten<const D2: usize>(\n        self,\n        start_dim: impl AsIndex,\n        end_dim: impl AsIndex,\n    ) -> Tensor<B, D2, K> {\n        let start_dim = start_dim.expect_dim_index(D);\n        let end_dim = end_dim.expect_dim_index(D);\n        check!(TensorCheck::flatten::<D, D2>(start_dim, end_dim));\n        let new_shape = self.shape().flatten_dims(start_dim, end_dim);\n\n        Tensor::new(K::reshape(self.primitive, new_shape))\n    }\n\n    /// Squeeze the tensor along all dimensions, removing dimensions\n    /// of size one, and effectively reducing the rank of the tensor.\n    ///\n    /// # Type Parameters\n    ///\n    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.\n    ///\n    /// # Returns\n    ///\n    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    ///\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 4D tensor with dimensions [1, 3, 1, 3]\n    ///     let tensor = Tensor::<B, 4>::from_data(\n    ///         [[[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]]],\n    ///         &device,\n    ///     );\n    ///\n    ///     // Squeeze the tensor dimensions.\n    ///     // The resulting tensor will have dimensions [3, 3].\n    ///     let squeezed = tensor.squeeze::<2>();\n    ///     println!(\"{squeezed}\");\n    /// }\n    /// ```\n    pub fn squeeze<const D2: usize>(self) -> Tensor<B, D2, K> {\n        let new_dims = self\n            .shape()\n            .iter()\n            .filter_map(|&dim| if dim == 1 { None } else { Some(dim) })\n            .collect::<Vec<_>>();\n        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));\n\n        Tensor::new(K::reshape(self.primitive, new_dims.into()))\n    }\n\n    /// Squeeze the tensor along the given dimension, removing the specified dimension\n    /// of size one, and effectively reducing the rank of the tensor by one.\n    ///\n    /// # Arguments\n    ///\n    /// - `dim`: The dimension to be squeezed.\n    ///\n    /// # Type Parameters\n    ///\n    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the size in the squeezed dimension is not 1.\n    ///\n    /// # Returns\n    ///\n    /// A new `Tensor<B, D2, K>` instance with the specified dimension removed.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    ///\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 3D tensor with dimensions [3, 1, 3]\n    ///     let tensor = Tensor::<B, 3>::from_data(\n    ///         [[[3.0, 4.9, 2.0]], [[2.0, 1.9, 3.0]], [[4.0, 5.9, 8.0]]],\n    ///         &device,\n    ///     );\n    ///\n    ///     // Squeeze the dimension 1.\n    ///     // The resulting tensor will have dimensions [3, 3].\n    ///     let squeezed = tensor.squeeze_dim::<2>(1);\n    ///     println!(\"{squeezed}\");\n    /// }\n    /// ```\n    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {\n        check!(TensorCheck::squeeze::<D2>(dim, &self.shape()));\n\n        let current_dims = self.shape();\n        let mut new_dims: [usize; D2] = [0; D2];\n\n        new_dims[..dim].copy_from_slice(&current_dims[..dim]);\n        new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);\n\n        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));\n        Tensor::new(K::reshape(self.primitive, new_dims.into()))\n    }\n\n    /// Removes specified dimensions of size 1 from a tensor's shape. This function takes a tensor and\n    /// an array of dimensions (`dims`) to be squeezed. If `dims` is provided, only the dimensions\n    /// specified in this array will be removed. Each dimension in `dims` should correspond to a size of 1\n    /// in the tensor; otherwise, the dimension will not be squeezed. If `dims` is empty, all single-dimensional entries\n    /// in the tensor will be removed. If entries in `dims` are negative, then dimensions will be counted\n    /// from the back.\n    ///\n    /// # Arguments\n    ///\n    /// - `dims`: The dimension(s) to be squeezed.\n    ///\n    /// # Type Parameters\n    ///\n    ///  - `D2`: The resulting number of dimensions in the squeezed tensor.\n    ///\n    /// # Returns\n    ///\n    /// A new `Tensor<B, D2, K>` instance with the specified dimensions removed.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    ///\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 4D tensor with dimensions [2, 1, 4, 1]\n    ///     let tensor = Tensor::<B, 4>::ones(Shape::new([2, 1, 4, 1]), &device);\n    ///\n    ///     // Squeeze the dimensions 1 and 3.\n    ///     // The resulting tensor will have dimensions [2, 4].\n    ///     let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);\n    ///     println!(\"{squeezed}\");\n    /// }\n    /// ```\n    pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {\n        let current_dims = self.shape();\n        let mut dim_indices: Vec<usize>;\n\n        // Check if dims is empty, if yes then assign dim_indices all single-dimensional entries\n        if dims.is_empty() {\n            dim_indices = current_dims\n                .iter()\n                .enumerate()\n                .filter_map(|(index, &dim)| if dim == 1 { Some(index) } else { None })\n                .collect();\n        } else {\n            // If negative dims, count from the back\n            dim_indices = dims\n                .iter()\n                .map(|&d| {\n                    if d < 0 {\n                        (current_dims.len() as isize + d) as usize\n                    } else {\n                        d as usize\n                    }\n                })\n                .collect();\n        }\n\n        // Sort indices and remove duplicates\n        dim_indices.sort_unstable();\n        dim_indices.dedup();\n\n        // Make sure squeeze_dims doesn't result in a tensor with < 1 dimensions\n        check!(TensorCheck::squeeze_dims_input::<D2>(\n            &dim_indices,\n            &current_dims\n        ));\n\n        // Calculate new dimensions\n        let mut new_dims = Vec::new();\n        for (index, &dim_size) in current_dims.iter().enumerate() {\n            // Exclude the dimension if it's explicitly marked for squeezing\n            if dim_indices.contains(&index) {\n                check!(TensorCheck::squeeze::<D2>(index, &current_dims));\n                continue;\n            }\n            new_dims.push(dim_size);\n        }\n\n        // Check that after squeezing, we still respect the D2 size\n        check!(TensorCheck::squeeze_dims_len::<D2>(new_dims.len()));\n\n        Tensor::new(K::reshape(self.primitive, new_dims.into()))\n    }\n\n    /// Unsqueeze the current tensor. Create new leading dimensions to fit the given size.\n    ///\n    /// # Type Parameters\n    ///\n    ///  - `D2`: The resulting number of dimensions in the unsqueezed tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the output size `D2` is smaller than the current number of dimensions.\n    ///\n    /// # Returns\n    ///\n    /// A new `Tensor<B, D2, K>` instance with the specified dimensions added.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [3, 3]\n    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);\n    ///     // Unsqueeze the tensor up to 4 dimensions.\n    ///     // The resulting tensor will have dimensions [1, 1, 3, 3].\n    ///     let unsqueezed = tensor.unsqueeze::<4>();\n    ///     println!(\"{unsqueezed}\");\n    /// }\n    /// ```\n    pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {\n        check!(TensorCheck::unsqueeze::<D, D2>());\n\n        let mut dims = [1; D2];\n        let num_ones = D2 - D;\n        let shape = self.shape();\n\n        dims[num_ones..(D + num_ones)].copy_from_slice(&shape[..D]);\n\n        let shape = Shape::new(dims);\n        self.reshape(shape)\n    }\n\n    /// Creates a new tensor with a dimension of size one inserted at the specified position.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [3, 3]\n    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([3, 3]), &device);\n    ///     // Unsqueeze the dimension 1.\n    ///     // The resulting tensor will have dimensions [3, 1, 3].\n    ///     let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);\n    ///     println!(\"{unsqueezed}\");\n    /// }\n    /// ```\n    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {\n        check!(TensorCheck::unsqueeze_dim::<D, D2>(dim));\n\n        let mut dims = [1; D2];\n        let shape = self.shape();\n\n        dims[0..dim].copy_from_slice(&shape[0..dim]);\n\n        if dim < D {\n            dims[dim] = 1;\n            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);\n        } else {\n            dims[dim] = 1;\n        }\n\n        let shape = Shape::new(dims);\n        self.reshape(shape)\n    }\n\n    /// Creates a new tensor with added dimensions of size one inserted at the specified indices.\n    /// The indices can be negative, in which case they are counted from the last to the first dimension.\n    /// the axes can contain duplicates, in which case the number of dimensions inserted at the index\n    /// is the number of duplicates.\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 3D tensor with dimensions [3, 4, 5]\n    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([3, 4, 5]), &device);\n    ///     // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.\n    ///     // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].\n    ///     let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);\n    ///     println!(\"{unsqueezed}\");\n    /// }\n    /// ```\n    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[impl AsIndex]) -> Tensor<B, D2, K> {\n        let mut new_dims = [1; D2];\n        let old_dims = self.shape();\n        //for checking if the dimension is in the acceptable range\n\n        //part 1: convert the negative indices to positive\n        let mut neg_offset = D2;\n        let mut dim_indices = axes\n            .iter()\n            .map(|d| {\n                let d = d.as_index();\n                // check if the dimension is in the acceptable range\n                check!(TensorCheck::unsqueeze_dims::<{ D2 }>(d));\n                (if d < 0 {\n                    neg_offset -= 1; // handle multiple negative indices (decrease dim value in reverse)\n                    d + neg_offset as isize + 1\n                } else {\n                    d\n                }) as usize\n            })\n            .collect::<Vec<usize>>();\n\n        //sort the indices\n        dim_indices.sort_unstable();\n\n        //Now use this to copy the chunks of the dims\n        let mut prev_idx: usize = 0;\n        let mut current_left_b: usize = 0;\n        let mut current_right_b: usize = 0;\n        let mut offset: usize = 0;\n        dim_indices.iter().for_each(|d| {\n            //check if there is space for at least one dimension\n            if prev_idx < *d {\n                current_right_b = *d - offset;\n                //copy the chunks of the dims\n                if current_right_b < D {\n                    new_dims[prev_idx..*d]\n                        .copy_from_slice(&old_dims[current_left_b..current_right_b]);\n                } else {\n                    new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);\n                }\n                prev_idx = *d + 1;\n                //offset is equal to the number of extracted elements from the original shape\n                offset += current_right_b - current_left_b;\n                current_left_b = current_right_b;\n            } else {\n                //it's sorted so the only reason this would happen\n                //is if multiple indices are the same\n                prev_idx += 1;\n            }\n        });\n        //copy over anything past the index of the last new dimension\n        if current_left_b < D {\n            new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);\n        }\n\n        //lastly, create the shape and reshape\n        let shape = Shape::new(new_dims);\n        self.reshape(shape)\n    }\n\n    /// Roll operation along a specific dimension; wrapping around the elements.\n    ///\n    /// ## Parameters\n    ///\n    /// - `shift`: The roll extent; supports negative values and wraps around.\n    /// - `dim`: The dimension to roll; supports negative indexing.\n    ///\n    /// ## Returns\n    ///\n    /// A new tensor with the specified dimension rolled by the given shift amount.\n    pub fn roll_dim<Shift, Dim>(self, shift: Shift, dim: Dim) -> Self\n    where\n        Shift: AsIndex,\n        Dim: AsIndex,\n    {\n        let dim = dim.expect_dim_index(D);\n        let size = self.shape()[dim];\n        if size == 0 {\n            // If the dimension is empty, return the tensor as is.\n            return self;\n        }\n\n        let shift = wrap_index(shift, size);\n        if shift == 0 {\n            // If the shift is zero, return the tensor as is.\n            return self;\n        }\n\n        self.unchecked_roll_dim(shift, dim)\n    }\n\n    /// Internal implementation of `roll_dim` that does not canonicalize dimensions or shifts.\n    ///\n    /// ## Parameters\n    ///\n    /// - `shift`: The number of positions to shift; must be (0 < shift < size).\n    /// - `dim`: The dimension to roll; must be a valid index for the tensor's shape.\n    ///\n    /// ## Returns\n    ///\n    /// A new tensor with the specified dimension rolled by the given shift amount.\n    #[inline(always)]\n    fn unchecked_roll_dim(self, shift: usize, dim: usize) -> Self {\n        #[cfg(debug_assertions)]\n        {\n            let size = self.shape()[dim];\n            assert!(\n                0 < shift && shift < size,\n                \"Expected: 0 < shift < size: found shift={shift}, size={size}\",\n            );\n            assert!(\n                dim < self.shape().num_dims(),\n                \"Expected: dim < num_dims: found dim={dim}, num_dims={size}\",\n            );\n        }\n\n        Tensor::cat(\n            vec![\n                self.clone().slice_dim(dim, shift..),\n                self.slice_dim(dim, ..shift),\n            ],\n            dim,\n        )\n    }\n\n    /// Roll operation.\n    ///\n    /// Note: unlike ``pytorch``, `dims` and `shifts` must have the same length.\n    ///\n    /// A given `dim` may be rolled multiple times, and the shifts will be applied sequentially.\n    ///\n    /// ## Parameters\n    ///\n    /// - `shifts`: A slice of shifts corresponding to each dimension;\n    ///   supports negative values and wraps around.\n    /// - `dims`: A slice of dimensions to roll; supports negative indexing.\n    ///\n    /// ## Returns\n    ///\n    /// A new tensor with the specified dimensions rolled by the given shifts.\n    pub fn roll<Shift, Dim>(self, shifts: &[Shift], dims: &[Dim]) -> Self\n    where\n        Shift: AsIndex,\n        Dim: AsIndex,\n    {\n        assert_eq!(\n            dims.len(),\n            shifts.len(),\n            \"Dimensions and shifts must align; found dims={dims:#?}, shifts={shifts:#?}\",\n        );\n\n        // This is a fair amount of complexity, which could be replaced\n        // by a simple canonicalization of `dims` and wrapping of `shifts`.\n        // The work is done here to ensure that any roll operation\n        // which could be a no-op is a no-op; simplifying the accounting\n        // needed by backend-specific implementations of the inner roll op.\n\n        let item_count = dims.len();\n\n        let shape = self.shape();\n\n        // Accumulate the effective shifts for each dimension.\n        let mut accumulated_shifts: Vec<isize> = vec![0; shape.len()];\n        for i in 0..item_count {\n            let dim = dims[i].expect_dim_index(D);\n            accumulated_shifts[dim] += shifts[i].as_index();\n        }\n\n        // Do this after we've checked the validity of `dims` and `shifts`.\n        if self.shape().num_elements() == 0 {\n            // If the tensor is empty, return it as is.\n            return self;\n        }\n\n        // Wrap the accumulated shifts, and filter out empty dimensions.\n        let mut effective_dims: Vec<usize> = Vec::with_capacity(item_count);\n        let mut effective_shifts: Vec<usize> = Vec::with_capacity(item_count);\n        for dim in 0..shape.len() {\n            // `wrap_index` should inline, and has a fast-exit path for zero shifts.\n            let shift = wrap_index(accumulated_shifts[dim], shape[dim]);\n            if shift == 0 {\n                continue;\n            }\n\n            effective_dims.push(dim);\n            effective_shifts.push(shift);\n        }\n\n        // If no shifts are needed, return the original tensor.\n        if effective_shifts.is_empty() {\n            return self;\n        }\n\n        // At this point:\n        // - `dims` contains the effective dimensions to roll, in index order,\n        // - `shifts` contains the effective usize shifts for each dimension.\n        // - Every shift is non-zero, and less than the size of the corresponding dimension.\n        self.unchecked_roll(&effective_shifts, &effective_dims)\n    }\n\n    /// `roll` internal implementation.\n    ///\n    /// ## Parameters\n    ///\n    /// - `shifts`: A slice of shifts corresponding to each dimension;\n    ///   must be non-empty, the same length as `dims`, and all ``1..<size>``.\n    /// - `dims`: A slice of dimensions to roll; must be non-empty;\n    ///   the same length as `shifts`, and must not contain repeats.\n    ///\n    /// ## Panics\n    ///\n    /// Panics if the shifts and dimensions do not align, or if dimensions contain repeats.\n    ///\n    /// ## Returns\n    ///\n    /// A new tensor with the specified dimensions rolled by the given shifts.\n    #[inline(always)]\n    fn unchecked_roll(self, shifts: &[usize], dims: &[usize]) -> Self {\n        #[cfg(debug_assertions)]\n        {\n            assert!(!shifts.is_empty());\n            assert_eq!(\n                shifts.len(),\n                dims.len(),\n                \"Shifts and dimensions must align; found {} shifts and {} dims\",\n                shifts.len(),\n                dims.len()\n            );\n\n            let mut unique_dims = dims.to_vec();\n            unique_dims.dedup();\n\n            assert_eq!(\n                unique_dims.len(),\n                dims.len(),\n                \"Dimensions must not contain repeats; found {} unique dims and {} total dims\",\n                unique_dims.len(),\n                dims.len()\n            )\n        }\n\n        let x = self.unchecked_roll_dim(shifts[0], dims[0]);\n\n        if dims.len() == 1 {\n            x\n        } else {\n            x.unchecked_roll(&shifts[1..], &dims[1..])\n        }\n    }\n\n    /// Returns a tensor containing the elements selected from the given slices.\n    ///\n    /// This method provides flexible tensor slicing with support for various range types,\n    /// negative indices, and stepped slicing. The method accepts both single slices and\n    /// arrays of slices, with the [`s!`] macro providing convenient syntax for complex patterns.\n    ///\n    /// # Arguments\n    ///\n    /// * `slices` - Can be:\n    ///   - A single range for 1D slicing (e.g., `0..5`, `..`, `2..`)\n    ///   - An array of ranges (e.g., `[0..2, 1..4]`)\n    ///   - The [`s!`] macro output for advanced slicing with steps\n    ///   - a `&Vec<Slice>` or `&[Slice]`\n    ///\n    /// # Behavior\n    ///\n    /// - Supports partial and full slicing in any number of dimensions\n    /// - Handles negative indices by wrapping from the end (-1 is the last element)\n    /// - Automatically clamps ranges that exceed tensor dimensions\n    /// - Supports stepped slicing for selecting every nth element\n    /// - Negative steps reverse the selection order\n    ///\n    /// # Panics\n    ///\n    /// - If the number of slices exceeds the tensor's dimensions\n    /// - If a range is descending (e.g., 2..1) or empty (e.g., 1..1) without negative step\n    /// - If a step is zero\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, s};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///\n    ///     // Single dimension slicing - no brackets needed!\n    ///     let tensor = Tensor::<B, 1, burn_tensor::Int>::arange(0..10, &device);\n    ///     let slice = tensor.clone().slice(2..8);  // Simple range\n    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![2, 3, 4, 5, 6, 7]);\n    ///\n    ///     // Using s! macro for single dimension with step\n    ///     let slice = tensor.clone().slice(s![0..10;2]);  // Every 2nd element\n    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![0, 2, 4, 6, 8]);\n    ///\n    ///     // Reverse a dimension with negative step\n    ///     let slice = tensor.slice(s![..;-1]);  // Reverse entire tensor\n    ///     assert_eq!(slice.into_data().to_vec::<i32>().unwrap(), vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);\n    ///\n    ///     // Multi-dimensional slicing\n    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);\n    ///\n    ///     // Array syntax for simple ranges\n    ///     let slice = tensor.clone().slice([1..3, 2..5]);\n    ///     assert_eq!(slice.dims(), [2, 3]);\n    ///\n    ///     // Advanced multi-dimensional with s! macro\n    ///     let slice = tensor.clone().slice(s![0..4;2, ..;-1]);  // Every 2nd row, reverse columns\n    ///     assert_eq!(slice.dims(), [2, 6]);\n    ///\n    ///     // Complex 3D example with mixed slice types\n    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([4, 6, 8]), &device);\n    ///     let slice = tensor.slice(s![1..3, ..;2, -3..]);  // Rows 1-2, every 2nd col, last 3 depth\n    ///     assert_eq!(slice.dims(), [2, 3, 3]);\n    ///\n    ///     // Using negative indices\n    ///     let tensor = Tensor::<B, 2>::ones(Shape::new([4, 6]), &device);\n    ///     let slice = tensor.slice(s![-2.., ..-1]);  // Last 2 rows, all but last column\n    ///     assert_eq!(slice.dims(), [2, 5]);\n    /// }\n    /// ```\n    ///\n    /// # See Also\n    ///\n    /// - [`s!`] - The recommended macro for creating complex slice specifications\n    /// - [`slice_assign`](Self::slice_assign) - Assign values to a slice\n    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value\n    /// - [`slice_dim`](Self::slice_dim) - Slice a single dimension\n    ///\n    /// [`s!`]: crate::s!\n    pub fn slice<S>(self, slices: S) -> Self\n    where\n        S: SliceArg,\n    {\n        let shape = self.shape();\n        let slices = slices.into_slices(&shape);\n\n        // Validate slices\n        check!(TensorCheck::slice::<D>(&shape, &slices));\n\n        // Calculate output shape and check for empty slices\n        let mut output_dims = shape.clone();\n        for (dim, slice) in slices.iter().enumerate() {\n            output_dims[dim] = slice.output_size(shape[dim]);\n        }\n\n        // Return empty tensor if any dimension is 0 (empty slice)\n        if output_dims.contains(&0) {\n            return Self::empty(output_dims, &self.device());\n        }\n        Self::new(K::slice(self.primitive, &slices))\n    }\n\n    /// Assigns values to a slice of the tensor and returns the updated tensor.\n    ///\n    /// This method supports advanced slicing with steps, including negative steps for reverse\n    /// assignment. Like `slice`, it accepts both single slices and arrays, with the [`s!`] macro\n    /// providing powerful syntax for complex patterns.\n    ///\n    /// # Arguments\n    ///\n    /// * `slices` - Slice specification (same format as `slice` method)\n    /// * `values` - Tensor with values to assign (must match slice dimensions)\n    ///\n    /// # Panics\n    ///\n    /// - If slices exceed tensor dimensions\n    /// - If values dimensions don't match the selected slice shape\n    /// - If a step is zero\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, s};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///\n    ///     // Simple assignment to a sub-region\n    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);\n    ///     let values = Tensor::<B, 2>::ones([2, 3], &device);\n    ///     tensor = tensor.slice_assign([1..3, 2..5], values);\n    ///     // Now tensor[1..3, 2..5] contains ones\n    ///\n    ///     // Single dimension assignment with step\n    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);\n    ///     let values = Tensor::<B, 1>::ones([5], &device);\n    ///     tensor = tensor.slice_assign(s![0..10;2], values);\n    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]\n    ///\n    ///     // Reverse assignment with negative step\n    ///     let mut tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);\n    ///     let values = Tensor::<B, 1>::from_data([10.0, 11.0, 12.0, 13.0, 14.0], &device);\n    ///     tensor = tensor.slice_assign(s![..;-1], values);\n    ///     // Assigns in reverse: [14, 13, 12, 11, 10]\n    ///\n    ///     // Complex multi-dimensional assignment\n    ///     let mut tensor = Tensor::<B, 3>::zeros([4, 6, 8], &device);\n    ///     let values = Tensor::<B, 3>::ones([2, 3, 3], &device);\n    ///     tensor = tensor.slice_assign(s![0..4;2, ..;2, -3..], values);\n    ///     // Assigns to every 2nd row, every 2nd column, last 3 in depth\n    ///\n    ///     // Mixed syntax example\n    ///     let mut tensor = Tensor::<B, 2>::zeros([8, 8], &device);\n    ///     let pattern = Tensor::<B, 2>::ones([4, 4], &device);\n    ///     tensor = tensor.slice_assign(s![..;2, ..;2], pattern);\n    ///     // Creates a checkerboard pattern with ones\n    /// }\n    /// ```\n    ///\n    /// # See Also\n    ///\n    /// - [`s!`] - The recommended macro for creating complex slice specifications\n    /// - [`slice`](Self::slice) - Extract a slice from a tensor\n    /// - [`slice_fill`](Self::slice_fill) - Fill a slice with a constant value\n    ///\n    /// [`s!`]: crate::s!\n    pub fn slice_assign<S>(self, slices: S, values: Self) -> Self\n    where\n        S: SliceArg,\n    {\n        let shape = self.shape();\n        let slices = slices.into_slices(&shape);\n\n        // Check if any slice produces 0 elements (empty assignment).\n        // Empty assignments are no-ops and would cause issues in backend implementations.\n        let is_empty_assignment = slices\n            .iter()\n            .enumerate()\n            .any(|(i, slice)| slice.output_size(shape[i]) == 0);\n\n        if is_empty_assignment {\n            return self;\n        }\n\n        check!(TensorCheck::slice_assign::<D>(\n            &shape,\n            &values.shape(),\n            &slices\n        ));\n\n        Self::new(K::slice_assign(self.primitive, &slices, values.primitive))\n    }\n\n    /// Fills a slice of the tensor with a constant value and returns the updated tensor.\n    ///\n    /// Like other slice methods, accepts both single slices and arrays. However, this method\n    /// currently **does not support stepped slicing** - use [`slice_assign`](Self::slice_assign)\n    /// with a constant tensor for stepped patterns.\n    ///\n    /// # Arguments\n    ///\n    /// * `slices` - Slice specification (same format as `slice` method, but no steps)\n    /// * `value` - The value to fill the slice with\n    ///\n    /// # Panics\n    ///\n    /// - If slices exceed tensor dimensions\n    /// - If any slice has a step != 1 (not yet supported)\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, s};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///\n    ///     // Simple fill for a single dimension\n    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);\n    ///     tensor = tensor.slice_fill(2..5, 1.0);\n    ///     // Now tensor is [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]\n    ///\n    ///     // Multi-dimensional fill\n    ///     let mut tensor = Tensor::<B, 2>::zeros([4, 6], &device);\n    ///     tensor = tensor.slice_fill([1..3, 2..5], -1.0);\n    ///     // Fills the rectangle at rows 1-2, columns 2-4 with -1\n    ///\n    ///     // Using negative indices\n    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);\n    ///     tensor = tensor.slice_fill(-3.., 2.0);\n    ///     // Fills the last 3 elements with 2.0\n    ///\n    ///     // Complex multi-dimensional example\n    ///     let mut tensor = Tensor::<B, 3>::ones([4, 6, 8], &device);\n    ///     tensor = tensor.slice_fill(s![1..3, .., -2..], 0.0);\n    ///     // Sets rows 1-2, all columns, last 2 in depth to 0\n    ///\n    ///     // Stepped slicing is supported\n    ///     let mut tensor = Tensor::<B, 1>::zeros([10], &device);\n    ///     tensor = tensor.slice_fill(s![0..10;2], 1.0);\n    ///     // Now every 2nd element is 1: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]\n    /// }\n    /// ```\n    ///\n    /// # See Also\n    ///\n    /// - [`s!`] - The macro for creating slice specifications with steps\n    /// - [`slice`](Self::slice) - Extract a slice from a tensor\n    /// - [`slice_assign`](Self::slice_assign) - Assign tensor values to a slice\n    ///\n    /// [`s!`]: crate::s!\n    pub fn slice_fill<S, E: ElementConversion>(self, slices: S, value: E) -> Self\n    where\n        S: SliceArg,\n    {\n        let shape = self.shape();\n        let slices = slices.into_slices(&shape);\n\n        check!(TensorCheck::slice::<D>(&shape, &slices));\n\n        let slice_shape = shape.slice(&slices).unwrap();\n        let value = Tensor::<B, 1, K>::from_data_dtype(\n            [value.elem::<K::Elem>()],\n            &self.device(),\n            self.dtype(),\n        );\n        let value = value.expand(slice_shape);\n        self.slice_assign(&slices, value)\n    }\n\n    /// Returns a new tensor with the specified dimension sliced.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim`: The dimension to slice.\n    /// * `slice`: The slice specification for the dimension. Can be a range (e.g., `2..5`),\n    ///   slice with step (via `s!` macro, e.g., `s![0..10;2]`), or any type that implements `Into<Slice>`.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the specified dimension sliced.\n    ///\n    /// # Panics\n    ///\n    /// If the slice is out of bounds for the specified dimension.\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    /// # use burn_tensor::{Tensor, s};\n    /// # use burn_tensor::backend::Backend;\n    /// #\n    /// # fn example<B: Backend>() {\n    /// #     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 3>::zeros([3, 4, 5], &device);\n    ///\n    ///     // Simple range slicing\n    ///     let sliced = tensor.clone().slice_dim(1, 1..3);\n    ///     assert_eq!(sliced.shape().as_slice(), [3, 2, 5]);\n    ///\n    ///     // Slicing with step - take every 2nd element\n    ///     let sliced = tensor.clone().slice_dim(2, s![0..5;2]);\n    ///     assert_eq!(sliced.shape().as_slice(), [3, 4, 3]); // Takes indices 0, 2, 4\n    ///\n    ///     // Reverse slicing with negative step\n    ///     let sliced = tensor.clone().slice_dim(1, s![..;-1]);\n    ///     assert_eq!(sliced.shape().as_slice(), [3, 4, 5]); // Reverses dimension 1\n    ///\n    ///     // Select from index 2 with step 3\n    ///     let sliced = tensor.clone().slice_dim(0, s![2..;3]);\n    ///     assert_eq!(sliced.shape().as_slice(), [1, 4, 5]); // Takes only index 2\n    ///\n    ///     // Select single index (reduces dimension to size 1)\n    ///     let sliced = tensor.slice_dim(0, 1);\n    ///     assert_eq!(sliced.shape().as_slice(), [1, 4, 5]);\n    /// # }\n    /// ```\n    ///\n    /// # See Also\n    ///\n    /// - [`slice`](Self::slice) - Slice multiple dimensions simultaneously\n    /// - [`s!`] - The macro for creating complex slice specifications\n    ///\n    /// [`s!`]: crate::s!\n    pub fn slice_dim<S>(self, dim: usize, slice: S) -> Self\n    where\n        S: Into<Slice>,\n    {\n        check!(TensorCheck::check_dim::<D>(dim));\n        let slice: Slice = slice.into();\n\n        let mut slices = vec![Slice::full(); D];\n        slices[dim] = slice;\n\n        self.slice(&slices)\n    }\n\n    /// Returns the device of the current tensor.\n    pub fn device(&self) -> B::Device {\n        K::device(&self.primitive)\n    }\n\n    /// Move the tensor to the given device.\n    pub fn to_device(self, device: &B::Device) -> Self {\n        Self::new(K::to_device(self.primitive, device))\n    }\n\n    /// Select tensor elements along the given dimension corresponding to the given indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to select from. Supports negative indexing.\n    /// * `indices` - The indices of the elements to select.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Int};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    ///   let indices = Tensor::<B, 1, Int>::from_data([0], &device);\n    ///   let tensor = tensor.select(0, indices);\n    ///   println!(\"{tensor}\");\n    ///   //  [[1.0, -2.0, 3.0]]\n    /// }\n    /// ```\n    pub fn select(self, dim: impl AsIndex, indices: Tensor<B, 1, Int>) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::select::<D>(dim));\n        Self::new(K::select(self.primitive, dim, indices.primitive))\n    }\n\n    /// Assign the selected elements along the given dimension corresponding to the given indices\n    /// from the value tensor to the original tensor using sum reduction.\n    ///\n    /// # Note\n    /// For booleans, the sum operator is logical or.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension along which to select. Supports negative indexing.\n    /// * `indices` - The indices to select from the tensor.\n    /// * `values` - The values to assign to the selected indices.\n    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).\n    ///\n    /// # Example\n    ///\n    /// Example using a 3D tensor:\n    ///\n    /// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`\n    /// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`\n    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`\n    /// `input[i, j, indices[k]] += values[i, j, k]; // dim = -1 (same as dim = 2)`\n    ///\n    /// # Warning\n    ///\n    /// Not all backends have runtime bound checks for the indices, so make sure they are valid.\n    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.\n    pub fn select_assign(\n        self,\n        dim: impl AsIndex,\n        indices: Tensor<B, 1, Int>,\n        values: Tensor<B, D, K>,\n        update: IndexingUpdateOp,\n    ) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::select_assign::<D>(\n            dim,\n            &indices.shape(),\n            &values.shape()\n        ));\n\n        Self::new(K::select_assign(\n            self.primitive,\n            dim,\n            indices.primitive,\n            values.primitive,\n            update,\n        ))\n    }\n\n    /// Update the given tensor with the value tensor where the mask is true.\n    ///\n    /// This is similar to [mask_fill](Tensor::mask_fill), however the value is a tensor instead of\n    /// a scalar.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);\n    ///   let value = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///   let tensor = tensor.mask_where(mask, value);\n    ///   println!(\"{tensor}\");\n    ///   // [[2.0, -2.0, 4.0], [5.0, 2.0, 6.0]]\n    /// }\n    /// ```\n    pub fn mask_where(self, mask: Tensor<B, D, Bool>, value: Self) -> Self {\n        Self::new(K::mask_where(\n            self.primitive,\n            mask.primitive,\n            value.primitive,\n        ))\n    }\n\n    /// Update the given tensor with the value where the mask is true.\n    ///\n    /// This is similar to [mask_where](Tensor::mask_where), however the value is a scalar instead of\n    /// a tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let mask = Tensor::<B, 2, Bool>::from_data([[true, false, true], [false, true, false]], &device);\n    ///   let tensor = tensor.mask_fill(mask, 3.0);\n    ///   println!(\"{tensor}\");\n    ///   // [[3.0, -2.0, 3.0], [5.0, 3.0, 6.0]]\n    /// }\n    /// ```\n    pub fn mask_fill<E: ElementConversion>(self, mask: Tensor<B, D, Bool>, value: E) -> Self {\n        let value = Scalar::new(value, &self.dtype());\n        Self::new(K::mask_fill(self.primitive, mask.primitive, value))\n    }\n\n    /// Gather tensor elements corresponding to the given indices from the specified dim.\n    ///\n    /// Example using a 3D tensor:\n    ///\n    /// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`\n    /// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`\n    /// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`\n    ///\n    /// # Notes\n    ///\n    /// The index tensor should have the same shape as the original tensor except for the dim\n    /// specified.\n    ///\n    /// # Warning\n    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.\n    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.\n    pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {\n        check!(TensorCheck::gather::<D>(\n            dim,\n            &self.shape(),\n            &indices.shape()\n        ));\n\n        Self::new(K::gather(dim, self.primitive, indices.primitive))\n    }\n\n    /// Assign the gathered elements corresponding to the given indices along the specified dimension\n    /// from the value tensor to the original tensor using sum reduction.\n    ///\n    /// Example using a 3D tensor:\n    ///\n    /// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`\n    /// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`\n    /// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`\n    ///\n    /// # Arguments\n    /// * `dim` - The axis along which to scatter elements.\n    /// * `indices` - The indices of the elements to scatter.\n    /// * `values` - The values to scatter into the tensor.\n    /// * `update` - The operation used to update the existing values at the indexed positions (e.g., add).\n    ///\n    /// # Notes\n    ///\n    /// The index tensor should have the same shape as the original tensor except for the specified\n    /// dimension. The value and index tensors should have the same shape.\n    ///\n    /// Other references to the input tensor will not be modified by this operation.\n    ///\n    /// # Warning\n    /// Not all backends have runtime bound checks for the indices, so make sure the they are valid.\n    /// Otherwise, out of bounds indices could lead to unexpected results instead of panicking.\n    pub fn scatter(\n        self,\n        dim: usize,\n        indices: Tensor<B, D, Int>,\n        values: Self,\n        update: IndexingUpdateOp,\n    ) -> Self {\n        check!(TensorCheck::scatter::<D>(\n            dim,\n            &self.shape(),\n            &indices.shape(),\n            &values.shape()\n        ));\n\n        Self::new(K::scatter(\n            dim,\n            self.primitive,\n            indices.primitive,\n            values.primitive,\n            update,\n        ))\n    }\n\n    /// Converts the data of the current tensor.\n    ///\n    /// # Note\n    ///\n    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple\n    /// tensors at once. This may improve laziness, especially if executed on a different\n    /// thread in native environments.\n    pub fn into_data(self) -> TensorData {\n        self.try_into_data().expect(\n            \"Error while reading data: use `try_into_data` instead to catch the error at runtime\",\n        )\n    }\n\n    /// Converts the data of the current tensor and returns any error that might have occurred since the\n    /// last time the device was synchronized.\n    ///\n    /// # Note\n    ///\n    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple\n    /// tensors at once. This may improve laziness, especially if executed on a different\n    /// thread in native environments.\n    pub fn try_into_data(self) -> Result<TensorData, ExecutionError> {\n        crate::try_read_sync(self.into_data_async()).expect(\n            \"Failed to read tensor data synchronously.\n        This can happen on platforms that don't support blocking futures like WASM.\n        If possible, try using into_data_async instead.\",\n        )\n    }\n\n    /// Converts the data of the current tensor.\n    ///\n    /// # Note\n    ///\n    /// For better performance, prefer using a [Transaction](crate::Transaction) when reading multiple\n    /// tensors at once. This may improve laziness, especially if executed on a different\n    /// thread in native environments.\n    pub fn to_data(&self) -> TensorData {\n        self.clone().into_data()\n    }\n\n    /// Returns the data of the current tensor.\n    pub async fn into_data_async(self) -> Result<TensorData, ExecutionError> {\n        K::into_data_async(self.primitive).await\n    }\n\n    /// Returns the data of the current tensor.\n    pub async fn to_data_async(&self) -> Result<TensorData, ExecutionError> {\n        self.clone().into_data_async().await\n    }\n\n    /// Create a tensor from the given data on the given device.\n    pub fn from_data<T>(data: T, device: &B::Device) -> Self\n    where\n        T: Into<TensorData>,\n    {\n        let data = data.into();\n        check!(TensorCheck::creation_ops::<D>(\n            \"From Data\",\n            data.shape.as_slice()\n        ));\n        Self::new(K::from_data(data, device))\n    }\n\n    /// Create a tensor from the given data on the given device enforcing the given data type.\n    pub fn from_data_dtype<T>(data: T, device: &B::Device, dtype: DType) -> Self\n    where\n        T: Into<TensorData>,\n    {\n        let data = data.into();\n        check!(TensorCheck::creation_ops::<D>(\n            \"From Data\",\n            data.shape.as_slice()\n        ));\n        Self::new(K::from_data_dtype(data, device, dtype))\n    }\n\n    /// Repeat the tensor along the given dimension.\n    ///\n    /// The output tensor has the same shape, except along the given dimension.\n    ///\n    /// # Arguments\n    /// - `dim`: The dimension to repeat.\n    /// - `times`: The number of times to repeat the tensor along the given dimension in the new tensor.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the given dimension repeated `times` times.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [3, 2]\n    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///\n    ///     // Repeat the tensor along the dimension 0 twice.\n    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]\n    ///     // The resulting tensor will have dimensions [6, 2].\n    ///     let repeated = tensor.repeat_dim(0, 2);\n    ///     println!(\"{repeated}\");\n    /// }\n    /// ```\n    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {\n        if times > 0 {\n            Self::new(K::repeat_dim(self.primitive, dim, times))\n        } else {\n            let shape = self.shape().repeat(dim, times).unwrap();\n            Self::empty(shape, &self.device())\n        }\n    }\n\n    /// Repeat the tensor along the given dimensions.\n    /// # Arguments\n    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the given dimensions repeated `times` times.\n    ///\n    /// # Panics\n    ///\n    /// If `sizes` contains more elements than the number of dimensions.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    ///\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [3, 2]\n    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///\n    ///     // Repeat the tensor along the dimension 0 twice and the dimension 0 once.\n    ///     // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]\n    ///     // The resulting tensor will have dimensions [6, 2].\n    ///     let repeated = tensor.repeat(&[2, 1]);\n    /// }\n    /// ```\n    pub fn repeat(self, sizes: &[usize]) -> Self {\n        if sizes.contains(&0) {\n            let mut shape = self.shape();\n            for (dim, &times) in sizes.iter().enumerate() {\n                shape = shape.repeat(dim, times).unwrap();\n            }\n\n            return Self::empty(shape, &self.device());\n        }\n\n        let mut tensor = self;\n        for (dim, &times) in sizes.iter().enumerate() {\n            if times > 1 {\n                tensor = tensor.repeat_dim(dim, times);\n            }\n        }\n        tensor\n    }\n\n    /// Applies element-wise equal comparison.\n    ///\n    /// # Returns\n    /// A boolean tensor that is `true` where input is equal to `other` and `false` elsewhere.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///     // Compare the elements of the two 2D tensors with dimensions [3, 2].\n    ///     // [[false, true], [true, true], [true, true]]\n    ///     let equal = t1.equal(t2);\n    ///     println!(\"{equal}\");\n    /// }\n    /// ```\n    pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"Equal\", &self, &other));\n        Tensor::new(K::equal(self.primitive, other.primitive))\n    }\n\n    /// Applies element-wise non-equality comparison.\n    ///\n    /// # Returns\n    /// A boolean tensor that is `true` where input is not equal to `other` and `false` elsewhere.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let t1 = Tensor::<B, 2>::from_data([[2.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///     let t2 = Tensor::<B, 2>::from_data([[3.0, 4.9], [2.0, 1.9], [4.0, 5.9]], &device);\n    ///     // Compare the elements of the two 2D tensors for inequality.\n    ///     // [[true, false], [false, false], [false, false]]\n    ///     let not_equal = t1.not_equal(t2);\n    ///     println!(\"{not_equal}\");\n    /// }\n    /// ```\n    pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"NotEqual\", &self, &other));\n        Tensor::new(K::not_equal(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise equal comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.equal_elem(3.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[false, false, true], [false, false, false]]\n    /// }\n    /// ```\n    pub fn equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::equal_elem(self.primitive, other))\n    }\n\n    /// Applies element wise non-equality comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.not_equal_elem(3.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[true, true, false], [true, true, true]]\n    /// }\n    /// ```\n    pub fn not_equal_elem<E: Element>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::not_equal_elem(self.primitive, other))\n    }\n\n    /// Concatenates all tensors into a new one along the given dimension.\n    ///\n    /// # Panics\n    ///\n    /// - If `dim` is higher than the rank.\n    /// - If `tensors` is an empty vector.\n    /// - If all tensors don't have the same shape (the dimension `dim` is ignored).\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0, 1.0], [2.0, 1.9, 3.0, 1.0]], &device);\n    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);\n    ///\n    ///     // Concatenate the two tensors with shapes [2, 4] and [2, 3] along the dimension 1.\n    ///     // [[3.0, 4.9, 2.0, 1.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.0, 1.4, 5.8, 6.0]]\n    ///     // The resulting tensor will have shape [2, 7].\n    ///     let concat = Tensor::cat(vec![t1, t2], 1);\n    ///     println!(\"{concat}\");\n    /// }\n    /// ```\n    pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {\n        check!(TensorCheck::cat(&tensors, dim));\n\n        // Filter out tensors with size 0 along the concatenation dimension.\n        // Empty tensors don't contribute to the output and would cause issues\n        // in backend implementations (e.g., division by zero in slice_assign).\n        // Safety: TensorCheck::cat ensures tensors is non-empty\n        let first_tensor = tensors.first().unwrap();\n        let device = first_tensor.device();\n        let mut shape = first_tensor.shape();\n\n        let non_empty_primitives: Vec<_> = tensors\n            .into_iter()\n            .filter(|t| t.shape()[dim] > 0)\n            .map(|t| t.primitive)\n            .collect();\n\n        // If all tensors were empty, return an empty tensor with size 0 on concat dim\n        if non_empty_primitives.is_empty() {\n            shape[dim] = 0;\n            return Self::empty(shape, &device);\n        }\n\n        Self::new(K::cat(non_empty_primitives, dim))\n    }\n\n    /// Concatenates all tensors into a new one along a new dimension.\n    ///\n    /// # Panics\n    ///\n    /// - If all tensors don't have the same shape.\n    /// - If given dimension is not with range of 0..D2\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let t1 = Tensor::<B, 2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);\n    ///     let t2 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);\n    ///     let t3 = Tensor::<B, 2>::from_data([[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]], &device);\n    ///\n    ///     // Concatenate the three tensors with shape [2, 3] along a new dimension, 0.\n    ///     // [[[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]],\n    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]],\n    ///     //  [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]\n    ///     // The resulting tensor will have shape [3, 2, 3].\n    ///     let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);\n    ///     println!(\"{stacked}\");\n    /// }\n    /// ```\n    pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {\n        check!(TensorCheck::stack::<B, D, K, D2>(&tensors, dim));\n        let tensors = tensors.into_iter().map(|t| t.unsqueeze_dim(dim)).collect();\n        Tensor::<B, D2, K>::cat(tensors, dim)\n    }\n\n    /// Iterate over slices of tensors alongside a given dimension.\n    ///\n    /// # Panics\n    ///\n    /// If given dimension is greater than or equal to tensor rank.\n    ///\n    /// # Returns\n    ///\n    /// A tensor iterator.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    /// fn example<B: Backend>() {\n    ///   let device = Default::default();\n    ///   let tensor = Tensor::<B,2>::from_data([[3.0, 4.9, 2.0], [2.0, 1.9, 3.0]], &device);\n    ///   // Given a 2D tensor with dimensions [2, 3], iterate over slices of tensors along the dimension 0.\n    ///   let iter = tensor.iter_dim(0);\n    ///   for (i,tensor) in iter.enumerate() {\n    ///     println!(\"Tensor {}: {}\", i, tensor);\n    ///     // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }\n    ///     // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }\n    ///  }\n    /// }\n    /// ```\n    pub fn iter_dim(self, dim: usize) -> DimIter<B, D, K> {\n        check!(TensorCheck::dim_ops::<D>(\"iter_dim\", dim));\n        DimIter::new(self, dim)\n    }\n\n    /// Returns a new tensor with the given dimension narrowed to the given range.\n    ///\n    /// # Panics\n    ///\n    /// - If the dimension is greater than the number of dimensions of the tensor.\n    /// - If the given range exceeds the number of elements on the given dimension.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the given dimension narrowed to the given range.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [4, 3]\n    ///     let tensor = Tensor::<B, 2>::from_data(\n    ///         [\n    ///             [3.0, 4.9, 2.0],\n    ///             [2.0, 1.9, 3.0],\n    ///             [6.0, 1.5, 7.0],\n    ///             [3.0, 4.9, 9.0],\n    ///         ],\n    ///         &device,\n    ///     );\n    ///     // Narrow the tensor along the dimension 0, keeping 3 elements starting from index 1.\n    ///     // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]\n    ///     // The resulting tensor will have dimensions [3, 3].\n    ///     let narrowed = tensor.narrow(0, 1, 3);\n    ///     println!(\"{narrowed}\");\n    /// }\n    /// ```\n    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {\n        check!(TensorCheck::dim_ops::<D>(\"narrow\", dim));\n        check!(TensorCheck::narrow(&self, dim, start, length));\n        let dims = self.dims();\n\n        let ranges: [Range<usize>; D] = dims\n            .iter()\n            .enumerate()\n            .map(|(i, d)| {\n                if i == dim {\n                    start..(start + length)\n                } else {\n                    0..*d\n                }\n            })\n            .collect::<Vec<_>>()\n            .try_into()\n            .unwrap();\n\n        Self::slice(self, ranges)\n    }\n\n    /// Attempts to split the tensor into a specified number of chunks along a given dimension.\n    /// May return less chunks than requested if the tensor size is not divisible by the number of chunks.\n    ///\n    /// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.\n    /// Otherwise all chunks will be of equal size except for the last one.\n    ///\n    /// # Panics\n    ///\n    /// If the dimension is greater than the number of dimensions of the tensor.\n    ///\n    /// # Returns\n    /// A vector of tensors.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [4, 3]\n    ///     let tensor = Tensor::<B, 2>::from_data(\n    ///         [\n    ///             [3.0, 4.9, 2.0],\n    ///             [2.0, 1.9, 3.0],\n    ///             [6.0, 1.5, 7.0],\n    ///             [3.0, 4.9, 9.0],\n    ///         ],\n    ///         &device,\n    ///     );\n    ///     // Split the tensor along the dimension 1 into 2 chunks.\n    ///     // The first chuck will have shape [4, 2]:\n    ///     // [[3.0, 4.9], [2.0, 1.9], [6.0, 1.5], [3.0, 4.9]]\n    ///     // The second chunk will have shape [4, 1]:\n    ///     // [[2.0], [3.0], [7.0], [9.0]]\n    ///     let chunks = tensor.chunk(2, 1);\n    ///     println!(\"{chunks:?}\");\n    /// }\n    /// ```\n    pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {\n        check!(TensorCheck::dim_ops::<D>(\"chunk\", dim));\n        let size = self.shape()[dim];\n        if size < chunks {\n            return (0..size)\n                .map(|i| Self::narrow(self.clone(), dim, i, 1))\n                .collect();\n        }\n\n        let mut tensors = Vec::with_capacity(chunks);\n        let mut sum_chunk_size = 0;\n        if size.is_multiple_of(chunks) {\n            let chunk_size = size / chunks;\n            for _ in 0..chunks {\n                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));\n                sum_chunk_size += chunk_size;\n            }\n        } else {\n            let chunk_size = (size / chunks) + 1; // assumes not divisible\n            for _ in 0..chunks - 1 {\n                tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, chunk_size));\n                sum_chunk_size += chunk_size;\n            }\n            let remainder = size % chunk_size;\n            tensors.push(Self::narrow(self.clone(), dim, sum_chunk_size, remainder));\n        }\n\n        tensors\n    }\n\n    /// Splits the tensor into chunks of a specified size along a given dimension.\n    /// Each chunk is a view of the original tensor.\n    ///\n    /// If the tensor size along the given dimension is not divisible by `split_size`,\n    /// then the last chunk will be smaller.\n    ///\n    /// # Panics\n    ///\n    /// If the specified dimension to split along is greater than the number of dimensions of the tensor.\n    ///\n    /// # Returns\n    ///\n    /// A vector of tensors.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 1D tensor with 5 elements\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);\n    ///     // Split the tensor into chunks of size 2 along dimension 0\n    ///     let chunks = tensor.split(2, 0);\n    ///     // The result is a vector of tensors:\n    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0]), Tensor([4.0])]\n    ///     println!(\"{:?}\", chunks);\n    /// }\n    /// ```\n    pub fn split(self, split_size: usize, dim: usize) -> Vec<Self> {\n        check!(TensorCheck::split::<D>(&self.shape(), split_size, dim));\n        let size = self.shape()[dim];\n        let mut tensors = Vec::new();\n\n        let mut start = 0;\n        while start < size {\n            let length = usize::min(split_size, size - start);\n            tensors.push(Self::narrow(self.clone(), dim, start, length));\n            start += length;\n        }\n\n        tensors\n    }\n\n    /// Splits the tensor into chunks with the specified sizes along a given dimension.\n    /// Each chunk is a view of the original tensor.\n    ///\n    /// The sizes of the chunks are specified in the `split_sizes` vector. The sum of the sizes\n    /// in `split_sizes` must equal the size of the tensor along the specified dimension.\n    ///\n    /// # Panics\n    ///\n    /// If the specified dimension to split along is greater than the number of dimensions of the tensor or\n    /// if the sum of `dim_sizes` does not equal the size of the tensor along `dim`.\n    ///\n    /// # Returns\n    ///\n    /// A vector of tensors.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 1D tensor with 5 elements\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, 1.0, 2.0, 3.0, 4.0], &device);\n    ///     // Split the tensor into chunks with sizes [2, 3] along dimension 0\n    ///     let chunks = tensor.split_with_sizes(vec![2, 3], 0);\n    ///     // The result is a vector of tensors:\n    ///     // [Tensor([0.0, 1.0]), Tensor([2.0, 3.0, 4.0])]\n    ///     println!(\"{:?}\", chunks);\n    /// }\n    /// ```\n    pub fn split_with_sizes(self, split_sizes: Vec<usize>, dim: usize) -> Vec<Self> {\n        check!(TensorCheck::split_with_sizes::<D>(\n            &self.shape(),\n            &split_sizes,\n            dim\n        ));\n        let mut tensors = Vec::new();\n\n        let mut start = 0;\n        for length in split_sizes {\n            if length == 0 {\n                continue;\n            }\n            tensors.push(Self::narrow(self.clone(), dim, start, length));\n            start += length;\n        }\n\n        tensors\n    }\n\n    /// Tests if any element in the `tensor` evaluates to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` containing a single element, True if any element in the input tensor\n    /// evaluates to True, False otherwise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = Default::default();\n    ///   let tensor = Tensor::<B,2, Bool>::from_data([[true,false,true],[false,true,false]], &device);\n    ///   let tensor_two = Tensor::<B,2, Bool>::from_data([[false,false,false],[false,false,false]], &device);\n    ///\n    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.\n    ///   let any_tensor = tensor.any();\n    ///   println!(\"{}\", any_tensor);\n    ///   // Tensor { data: [true], ... }\n    ///\n    ///   // Given a 2D tensor with dimensions [2, 3], test if any element in the tensor evaluates to True.\n    ///   let any_tensor_two = tensor_two.any();\n    ///   println!(\"{}\", any_tensor_two);\n    ///   // Tensor { data: [false], ... }\n    /// }\n    /// ```\n    pub fn any(self) -> Tensor<B, 1, Bool> {\n        Tensor::new(K::any(self.primitive))\n    }\n\n    /// Tests if any element in the `tensor` evaluates to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if any element along this dim in the input\n    /// evaluates to True, False otherwise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor =\n    ///         Tensor::<B, 2, Bool>::from_data([[true, false, false], [false, true, false]], &device);\n    ///     // Check if any element in the tensor evaluates to True along the dimension 1.\n    ///     // [[true], [true]],\n    ///     let any_dim = tensor.clone().any_dim(1);\n    ///     println!(\"{any_dim}\");\n    /// }\n    /// ```\n    pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {\n        Tensor::new(K::any_dim(self.primitive, dim))\n    }\n\n    /// Tests if all elements in the `tensor` evaluate to True.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, 1, Bool>` with a single element, True if all elements in the input tensor\n    /// evaluate to True, False otherwise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor =\n    ///         Tensor::<B, 2, Bool>::from_data([[true, false, true], [true, true, true]], &device);\n    ///     // Check if all elements in the tensor evaluate to True (which is not the case).\n    ///     // [false]\n    ///     let all = tensor.all();\n    ///     println!(\"{all}\");\n    /// }\n    /// ```\n    pub fn all(self) -> Tensor<B, 1, Bool> {\n        Tensor::new(K::all(self.primitive))\n    }\n\n    /// Tests if all elements in the `tensor` evaluate to True along a given dimension `dim`.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to test. All input tensor types (Float, Int, Bool) are supported.\n    /// * `dim` - The axis along which to test.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor `Tensor<B, D, Bool>` with the same shape as input `tensor`, except in the `dim` axis\n    /// where the size is 1. The elem in the `dim` axis is True if all elements along this dim in the input\n    /// evaluates to True, False otherwise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor =\n    ///         Tensor::<B, 2, Bool>::from_data([[true, true, false], [true, true, true]], &device);\n    ///     // Check if all elements in the tensor evaluate to True along the dimension 1.\n    ///     // [[true, true, false]]\n    ///     let all_dim = tensor.clone().all_dim(0);\n    ///     println!(\"{all_dim}\");\n    /// }\n    /// ```\n    pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {\n        Tensor::new(K::all_dim(self.primitive, dim))\n    }\n\n    /// Convert the tensor into a scalar.\n    ///\n    /// # Panics\n    ///\n    /// - If the tensor doesn't have one element.\n    /// - If the backend fails to read the tensor data synchronously.\n    ///\n    /// # Returns\n    ///\n    /// The scalar value of the tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);\n    ///     // Convert the tensor with a single element into a scalar.\n    ///     let scalar = tensor.into_scalar();\n    ///     println!(\"{scalar}\");\n    /// }\n    /// ```\n    pub fn into_scalar(self) -> K::Elem {\n        crate::try_read_sync(self.into_scalar_async())\n            .expect(\n            \"Failed to read tensor data synchronously. This can happen on platforms\n            that don't support blocking futures like WASM. Try into_scalar_async instead.\",\n            )\n            .expect(\"Error while reading data: use `try_into_scalar` instead to catch the error at runtime\")\n    }\n\n    /// Convert the tensor into a scalar and returns any error that might have occurred since the\n    /// last time the device was synchronized.\n    ///\n    /// # Panics\n    ///\n    /// - If the tensor doesn't have one element.\n    /// - If the backend fails to read the tensor data synchronously.\n    ///\n    /// # Returns\n    ///\n    /// The scalar value of the tensor.\n    pub fn try_into_scalar(self) -> Result<K::Elem, ExecutionError> {\n        crate::try_read_sync(self.into_scalar_async()).expect(\n            \"Failed to read tensor data synchronously. This can happen on platforms\n            that don't support blocking futures like WASM. Try into_scalar_async instead.\",\n        )\n    }\n\n    /// Convert the tensor into a scalar.\n    ///\n    /// # Panics\n    ///\n    /// If the tensor doesn't have one element.\n    pub async fn into_scalar_async(self) -> Result<K::Elem, ExecutionError> {\n        check!(TensorCheck::into_scalar::<D>(&self.shape()));\n\n        Ok(self.into_data_async().await?.iter().next().unwrap())\n    }\n\n    /// Broadcast the tensor to the given shape.\n    ///\n    /// Only singleton dimensions can be expanded to a larger size. Other dimensions must have the same size\n    /// (which can be inferred with `-1`).\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape to broadcast the tensor to.\n    ///   Can contain -1 for dimensions that should be inferred.\n    ///   The number of elements in the shape must be greater or equal as\n    ///   the number of dimensions of the tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the tensor cannot be broadcasted to the given shape.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the given shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     // Create a 2D tensor with dimensions [3, 1]\n    ///     let tensor = Tensor::<B, 2>::from_data([[1.], [2.], [3.]], &device);\n    ///     // Expand the tensor to a new shape [3, 4]\n    ///     // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]\n    ///     let expanded = tensor.expand([3, 4]);\n    ///     println!(\"{}\", expanded);\n    /// }\n    /// ```\n    pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {\n        let shape = shape.into_shape(&self.shape());\n        check!(TensorCheck::expand::<D, D2>(\n            \"expand\",\n            &self.shape(),\n            &shape,\n        ));\n\n        Tensor::<B, D2, K>::new(K::expand(self.primitive, shape))\n    }\n\n    /// Unfold windows along a dimension.\n    ///\n    /// Returns a view of the tensor with all complete windows of size `size` in dimension `dim`;\n    /// where windows are advanced by `step` at each index.\n    ///\n    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.\n    ///\n    /// The new view will have the unfolded dimension replaced by two dimensions;\n    /// one in the position of the original dimension, with size equal to the number of windows,\n    /// and one appended to the right-most position, with size equal to `size`.\n    ///\n    /// # Warning\n    ///\n    /// For the `ndarray` backend; this is not a view but a copy\n    /// with duplicated data.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - the dimension to unfold.\n    /// * `size` - the size of each unfolded window.\n    /// * `step` - the step between each window.\n    ///\n    /// # Returns\n    ///\n    /// A tensor view with the shape ``[pre=..., windows, post=..., size]``.\n    pub fn unfold<const D2: usize, I: AsIndex>(\n        self,\n        dim: I,\n        size: usize,\n        step: usize,\n    ) -> Tensor<B, D2, K> {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::unfold::<D, D2>(\n            \"unfold\",\n            &self.shape(),\n            dim,\n            size,\n            step,\n        ));\n        Tensor::<B, D2, K>::new(K::unfold(self.primitive, dim, size, step))\n    }\n}\n\n/// Iterator given by (Tensor::iter_dim).\npub struct DimIter<B, const D: usize, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n{\n    start: usize,\n    end: usize,\n    dim: usize,\n    ranges: [Range<usize>; D],\n    tensor: Tensor<B, D, K>,\n}\n\nimpl<B: Backend, const D: usize, K: BasicOps<B>> Iterator for DimIter<B, D, K> {\n    type Item = Tensor<B, D, K>;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        if self.start >= self.end {\n            return None;\n        }\n\n        let mut ranges = self.ranges.clone();\n        ranges[self.dim] = self.start..(self.start + 1);\n\n        let slice = self.tensor.clone().slice(ranges);\n        self.start += 1;\n\n        Some(slice)\n    }\n}\n\nimpl<B: Backend, const D: usize, K: BasicOps<B>> DoubleEndedIterator for DimIter<B, D, K> {\n    fn next_back(&mut self) -> Option<Self::Item> {\n        if self.start >= self.end {\n            return None;\n        }\n\n        let mut ranges = self.ranges.clone();\n        ranges[self.dim] = (self.end - 1)..self.end;\n\n        let slice = self.tensor.clone().slice(ranges);\n        self.end = self.end.saturating_sub(1);\n\n        Some(slice)\n    }\n}\n\nimpl<B: Backend, const D: usize, K: BasicOps<B>> DimIter<B, D, K> {\n    fn new(tensor: Tensor<B, D, K>, dim: usize) -> Self {\n        let dims = tensor.dims();\n        let ranges = dims\n            .iter()\n            .map(|&dim| 0..dim)\n            .collect::<Vec<Range<usize>>>();\n        let ranges: [Range<usize>; D] = ranges.try_into().unwrap();\n        Self {\n            end: dims[dim],\n            ranges,\n            start: 0,\n            dim,\n            tensor,\n        }\n    }\n}\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n    <K as BasicOps<B>>::Elem: Debug,\n{\n    #[inline]\n    fn push_newline_indent(acc: &mut String, indent: usize) {\n        acc.push('\\n');\n        for _ in 0..indent {\n            acc.push(' ');\n        }\n    }\n    fn fmt_inner_tensor(\n        &self,\n        acc: &mut String,\n        depth: usize,\n        multi_index: &mut [usize],\n        range: (usize, usize),\n        precision: Option<usize>,\n    ) {\n        let (start, end) = range;\n        for i in start..end {\n            if i > 0 {\n                acc.push_str(\", \");\n            }\n            multi_index[depth] = i;\n            let range: [Range<usize>; D] =\n                core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);\n\n            let data = burn_std::reader::try_read_sync(self.clone().slice(range).into_data_async());\n\n            if let Some(Ok(data)) = data {\n                let elem = data.iter::<<K as BasicOps<B>>::Elem>().next().unwrap();\n                match (precision, K::name()) {\n                    (Some(p), \"Float\") => acc.push_str(&format!(\"{elem:.p$}\")),\n                    (_, \"Bool\") => acc.push_str(&format!(\"{}\", elem.to_bool())),\n                    _ => acc.push_str(&format!(\"{elem:?}\")),\n                }\n            } else {\n                acc.push_str(\"<Tensor data not available>\");\n            }\n        }\n    }\n\n    fn fmt_outer_tensor(\n        &self,\n        acc: &mut String,\n        depth: usize,\n        multi_index: &mut [usize],\n        print_options: &PrintOptions,\n        summarize: bool,\n        range: (usize, usize),\n    ) {\n        let (start, end) = range;\n        for i in start..end {\n            if i > start {\n                acc.push(',');\n                Self::push_newline_indent(acc, depth + 1);\n            }\n            acc.push('[');\n            multi_index[depth] = i;\n            self.display_recursive(acc, depth + 1, multi_index, print_options, summarize);\n            acc.push(']');\n        }\n    }\n\n    /// Recursively formats the tensor data for display and appends it to the provided accumulator string.\n    ///\n    /// This function is designed to work with tensors of any dimensionality.\n    /// It traverses the tensor dimensions recursively, converting the elements\n    /// to strings and appending them to the accumulator string with the\n    /// appropriate formatting.\n    ///\n    /// # Arguments\n    ///\n    /// * `acc` - A mutable reference to a `String` used as an accumulator for the formatted output.\n    /// * `depth` - The current depth of the tensor dimensions being processed.\n    /// * `multi_index` - A mutable slice of `usize` representing the current indices in each dimension.\n    fn display_recursive(\n        &self,\n        acc: &mut String,\n        depth: usize,\n        multi_index: &mut [usize],\n        print_options: &PrintOptions,\n        summarize: bool,\n    ) {\n        let edge_items = print_options.edge_items;\n\n        if depth == 0 {\n            acc.push('[');\n        }\n\n        if depth == self.dims().len() - 1 {\n            // if we are at the innermost dimension, just push its elements into the accumulator\n            if summarize && self.dims()[depth] > 2 * edge_items {\n                // print the starting `edge_items` elements\n                self.fmt_inner_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    (0, edge_items),\n                    print_options.precision,\n                );\n                acc.push_str(\", ...\");\n                // print the last `edge_items` elements\n                self.fmt_inner_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    (self.dims()[depth] - edge_items, self.dims()[depth]),\n                    print_options.precision,\n                );\n            } else {\n                // print all the elements\n                self.fmt_inner_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    (0, self.dims()[depth]),\n                    print_options.precision,\n                );\n            }\n        } else {\n            // otherwise, iterate through the current dimension and recursively display the inner tensors\n            if summarize && self.dims()[depth] > 2 * edge_items {\n                self.fmt_outer_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    print_options,\n                    summarize,\n                    (0, edge_items),\n                );\n\n                acc.push(',');\n                Self::push_newline_indent(acc, depth + 1);\n                acc.push_str(\"...\");\n                Self::push_newline_indent(acc, depth + 1);\n\n                self.fmt_outer_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    print_options,\n                    summarize,\n                    (self.dims()[depth] - edge_items, self.dims()[depth]),\n                );\n            } else {\n                self.fmt_outer_tensor(\n                    acc,\n                    depth,\n                    multi_index,\n                    print_options,\n                    summarize,\n                    (0, self.dims()[depth]),\n                );\n            }\n        }\n\n        if depth == 0 {\n            acc.push(']');\n        }\n    }\n}\n\n#[derive(Clone, Debug)]\n/// Options for Tensor pretty printing\npub struct PrintOptions {\n    /// number of elements to start summarizing tensor\n    pub threshold: usize,\n\n    /// number of starting elements and ending elements to display\n    pub edge_items: usize,\n\n    /// Precision for floating point numbers\n    pub precision: Option<usize>,\n}\n\nstatic PRINT_OPTS: RwLock<PrintOptions> = RwLock::new(PrintOptions::const_default());\n\nimpl PrintOptions {\n    /// Print options with default values\n    pub const fn const_default() -> Self {\n        Self {\n            threshold: 1000,\n            edge_items: 3,\n            precision: None,\n        }\n    }\n}\n\nimpl Default for PrintOptions {\n    fn default() -> Self {\n        Self::const_default()\n    }\n}\n\n/// Set print options\npub fn set_print_options(options: PrintOptions) {\n    let mut print_opts = PRINT_OPTS.write().unwrap();\n    *print_opts = options;\n}\n\n/// Pretty print tensors\nimpl<B, const D: usize, K> core::fmt::Display for Tensor<B, D, K>\nwhere\n    B: Backend,\n    B::IntElem: core::fmt::Display,\n    K: BasicOps<B>,\n    <K as BasicOps<B>>::Elem: Debug,\n{\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        writeln!(f, \"Tensor {{\")?;\n\n        {\n            // Do not lock the mutex for the whole function\n            let mut po = { PRINT_OPTS.read().unwrap().clone() };\n\n            // Override the precision if it is set from the formatter\n            // This will be possible when the tensor is printed using the `{:.*}` syntax\n            if let Some(precision) = f.precision() {\n                po.precision = Some(precision);\n            }\n\n            let mut acc = String::new();\n            let mut multi_index = vec![0; D];\n            let summarize = self.shape().num_elements() > po.threshold;\n\n            self.display_recursive(&mut acc, 0, &mut multi_index, &po, summarize);\n\n            writeln!(f, \"  data:\")?;\n            write!(f, \"{acc}\")?;\n            writeln!(f, \",\")?;\n        }\n\n        writeln!(f, \"  shape:  {:?},\", self.dims())?;\n        writeln!(f, \"  device:  {:?},\", self.device())?;\n        writeln!(f, \"  backend:  {:?},\", B::name(&self.device()))?;\n        writeln!(f, \"  kind:  {:?},\", K::name())?;\n\n        let dtype = self.primitive.dtype();\n\n        writeln!(f, \"  dtype:  {:?},\", dtype.name())?;\n        write!(f, \"}}\")\n    }\n}\n\n/// Trait used for movedim arguments\npub trait MovedimArgs {\n    /// Converts into a set of dimensions `Vec<usize>` for the `tensor.movedim()` function\n    fn into_dim_vec<const D: usize>(self) -> Vec<usize>;\n}\n\nimpl MovedimArgs for Vec<i32> {\n    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {\n        let set = self\n            .iter()\n            .map(|&dim| {\n                if dim < 0 {\n                    (D as i32 + dim) as usize\n                } else {\n                    dim as usize\n                }\n            })\n            .collect::<Vec<usize>>();\n        check!(TensorCheck::movedim_args_vec::<D>(&set));\n\n        set\n    }\n}\n\nimpl MovedimArgs for Vec<usize> {\n    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {\n        check!(TensorCheck::movedim_args_vec::<D>(&self));\n        self\n    }\n}\n\nimpl MovedimArgs for usize {\n    #[allow(clippy::vec_init_then_push)]\n    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {\n        check!(TensorCheck::movedim_args_usize::<D>(self));\n\n        let mut set = Vec::with_capacity(1);\n        set.push(self);\n\n        set\n    }\n}\n\nimpl MovedimArgs for i32 {\n    #[allow(clippy::vec_init_then_push)]\n    fn into_dim_vec<const D: usize>(self) -> Vec<usize> {\n        check!(TensorCheck::movedim_args_i32::<D>(self));\n\n        let dim = if self < 0 {\n            (D as i32 + self) as usize\n        } else {\n            self as usize\n        };\n\n        let mut set = Vec::with_capacity(1);\n        set.push(dim);\n\n        set\n    }\n}\n\n/// Trait used for reshape arguments.\npub trait ReshapeArgs<const D2: usize>: Debug {\n    /// Converts to a shape.\n    fn into_shape<const D: usize>(self, source: Shape) -> Shape;\n}\n\nimpl<const D2: usize, I: AsIndex> ReshapeArgs<D2> for [I; D2] {\n    fn into_shape<const D: usize>(self, source: Shape) -> Shape {\n        unwrap_shape_reshape(source.reshape(self))\n    }\n}\n\nimpl<const D2: usize> ReshapeArgs<D2> for Shape {\n    fn into_shape<const D: usize>(self, source: Shape) -> Shape {\n        unwrap_shape_reshape(source.reshape(self))\n    }\n}\n\n/// Trait used for broadcast arguments.\npub trait BroadcastArgs<const D1: usize, const D2: usize> {\n    /// Converts to a shape.\n    fn into_shape(self, shape: &Shape) -> Shape;\n}\n\nimpl<const D1: usize, const D2: usize> BroadcastArgs<D1, D2> for Shape {\n    fn into_shape(self, _shape: &Shape) -> Shape {\n        self\n    }\n}\n\nimpl<const D1: usize, const D2: usize, E: AsIndex> BroadcastArgs<D1, D2> for [E; D2] {\n    // Passing -1 as the size for a dimension means not changing the size of that dimension.\n    fn into_shape(self, shape: &Shape) -> Shape {\n        if self.len() < shape.num_dims() {\n            panic!(\"Broadcast arguments must be greater than the number of dimensions\");\n        }\n\n        // Zip the two shapes in reverse order and replace -1 with the actual dimension value.\n        let new_shape: Vec<_> = self\n            .iter()\n            .rev()\n            .map(|x| {\n                let primitive = x.as_index();\n                if primitive < -1 || primitive == 0 {\n                    panic!(\"Broadcast arguments must be positive or -1\");\n                }\n                primitive\n            })\n            .zip(shape.iter().rev().chain(repeat(&0)).take(self.len())) // Pad the original shape with 0s\n            .map(|(x, &y)| if x == -1 { y } else { x as usize })\n            .collect::<Vec<_>>()\n            .into_iter()\n            .rev()\n            .collect();\n\n        if new_shape.contains(&0) {\n            panic!(\"Cannot substitute -1 for a non-existing dimension\");\n        }\n\n        let new_shape: [usize; D2] = new_shape.try_into().unwrap();\n\n        Shape::from(new_shape)\n    }\n}\n\nimpl<B, const D: usize, K> Serialize for Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n    K::Elem: Debug + Copy + Serialize,\n{\n    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {\n        let data = self.to_data();\n        data.serialize(serializer)\n    }\n}\n\nimpl<'de, B, const D: usize, K> Deserialize<'de> for Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n    K::Elem: Debug + Copy + Deserialize<'de>,\n{\n    fn deserialize<De: Deserializer<'de>>(deserializer: De) -> Result<Self, De::Error> {\n        let tensor = Tensor::from_data(\n            TensorData::deserialize(deserializer)?,\n            &<B::Device as Default>::default(),\n        );\n        Ok(tensor)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use burn_std::SliceOps;\n\n    use crate::{Shape, s};\n\n    #[test]\n    fn slice_range_single_dim_leading() {\n        let shape = Shape::new([8, 4]);\n\n        // Half-open range\n        let slices = shape.clone().into_slices([0..5]);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        let slices = shape.clone().into_slices([-3..-1]);\n        assert_eq!(slices[0].to_range(8), 5..7);\n\n        // Inclusive range\n        let slices = shape.clone().into_slices([0..=4]);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        let slices = shape.clone().into_slices([-2..=-1]);\n        assert_eq!(slices[0].to_range(8), 6..8);\n\n        // Unbounded start\n        let slices = shape.clone().into_slices([..3]);\n        assert_eq!(slices[0].to_range(8), 0..3);\n        let slices = shape.clone().into_slices([..-5]);\n        assert_eq!(slices[0].to_range(8), 0..3);\n\n        // Unbounded end\n        let slices = shape.clone().into_slices([5..]);\n        assert_eq!(slices[0].to_range(8), 5..8);\n        let slices = shape.clone().into_slices([-3..]);\n        assert_eq!(slices[0].to_range(8), 5..8);\n\n        // Full range\n        let slices = shape.into_slices([..]);\n        assert_eq!(slices[0].to_range(8), 0..8);\n    }\n\n    #[test]\n    fn test_negative_slice_indices() {\n        use crate::Slice;\n\n        // Test negative indices conversion\n        let slice: Slice = (-3..-1).into();\n        assert_eq!(slice.start, -3);\n        assert_eq!(slice.end, Some(-1));\n\n        // Test to_range conversion with size 8\n        let range = slice.to_range(8);\n        assert_eq!(range, 5..7);\n\n        // Test with shape slice\n        let shape = Shape::new([8, 4]);\n        let result = shape.clone().into_slices([-3..-1]);\n        assert_eq!(result[0].to_range(8), 5..7);\n\n        // Test more negative index cases\n        let slice2: Slice = (-5..).into();\n        assert_eq!(slice2.to_range(10), 5..10);\n\n        let slice3: Slice = (..-2).into();\n        assert_eq!(slice3.to_range(10), 0..8);\n\n        // Test with s! macro - single dimension returns Slice directly\n        let slice4 = s![-3..-1];\n        assert_eq!(slice4.start, -3);\n        assert_eq!(slice4.end, Some(-1));\n    }\n\n    #[test]\n    fn slice_range_multi_dim() {\n        let shape = Shape::new([8, 4]);\n\n        // Multiple ways to provide ranges\n        let slices = shape.clone().into_slices([0..5, 0..4]);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        assert_eq!(slices[1].to_range(4), 0..4);\n\n        let slices = shape.clone().into_slices([0.., 0..]);\n        assert_eq!(slices[0].to_range(8), 0..8);\n        assert_eq!(slices[1].to_range(4), 0..4);\n\n        let slices = shape.clone().into_slices([0..=7, 0..=3]);\n        assert_eq!(slices[0].to_range(8), 0..8);\n        assert_eq!(slices[1].to_range(4), 0..4);\n\n        let slices = shape.clone().into_slices([0..5, 0..3]);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        assert_eq!(slices[1].to_range(4), 0..3);\n\n        let slices = shape.into_slices([0.., 0..]);\n        assert_eq!(slices[0].to_range(8), 0..8);\n        assert_eq!(slices[1].to_range(4), 0..4);\n    }\n\n    #[test]\n    fn slice_range_multi_dim_index() {\n        let shape = Shape::new([8, 4]);\n\n        // Indices (single integer) should also convert to correct range\n        let slices = shape.clone().into_slices([0, 2]);\n        assert_eq!(slices[0].to_range(8), 0..1);\n        assert_eq!(slices[1].to_range(4), 2..3);\n\n        let slices = shape.into_slices([-1, -1]);\n        assert_eq!(slices[0].to_range(8), 7..8);\n        assert_eq!(slices[1].to_range(4), 3..4);\n    }\n\n    #[test]\n    fn slice_range_multi_dim_heterogeneous() {\n        // Slice macro `s![]` can be used to provide different range types\n        let shape = Shape::new([8, 4, 2]);\n        let slice = s![0..5, .., -1];\n        let slices = shape.into_slices(slice);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        assert_eq!(slices[1].to_range(4), 0..4);\n        assert_eq!(slices[2].to_range(2), 1..2);\n\n        let shape = Shape::new([8, 4, 2, 3]);\n        let slice = s![..=4, 0..=3, .., -2..];\n        let slices = shape.into_slices(slice);\n        assert_eq!(slices[0].to_range(8), 0..5);\n        assert_eq!(slices[1].to_range(4), 0..4);\n        assert_eq!(slices[2].to_range(2), 0..2);\n        assert_eq!(slices[3].to_range(3), 1..3);\n\n        let shape = Shape::new([3, 4]);\n        let slice = s![1..-1, ..];\n        let slices = shape.into_slices(slice);\n        assert_eq!(slices[0].to_range(3), 1..2);\n        assert_eq!(slices[1].to_range(4), 0..4);\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/bool.rs",
    "content": "use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};\nuse alloc::{vec, vec::Vec};\n\nuse crate::try_read_sync;\n\n/// The part of the tensor to keep when creating a triangular mask.\nenum TriPart {\n    /// Upper triangular part.\n    Upper,\n\n    /// Lower triangular part.\n    Lower,\n\n    /// Diagonal part.\n    Diagonal,\n}\n\nimpl<B, const D: usize> Tensor<B, D, Bool>\nwhere\n    B: Backend,\n{\n    /// Create a boolean tensor from data on the given device.\n    ///\n    /// # Arguments\n    ///\n    /// * `data` - The tensor data.\n    /// * `device` - The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);\n    ///     println!(\"{tensor}\");\n    /// }\n    /// ```\n    pub fn from_bool(data: TensorData, device: &B::Device) -> Self {\n        Self::new(B::bool_from_data(data.convert::<B::BoolElem>(), device))\n    }\n\n    /// Convert the bool tensor into an int tensor.\n    ///\n    /// # Returns\n    ///\n    /// An integer tensor where `true` is converted to `1` and `false` to `0`.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);\n    ///     let int_tensor = bool_tensor.int();\n    ///     println!(\"{int_tensor}\"); // [1, 0, 1]\n    /// }\n    /// ```\n    pub fn int(self) -> Tensor<B, D, Int> {\n        Tensor::new(B::bool_into_int(self.primitive))\n    }\n\n    /// Convert the bool tensor into a float tensor.\n    ///\n    /// # Returns\n    ///\n    /// A float tensor where `true` is converted to `1.0` and `false` to `0.0`.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let bool_tensor = Tensor::<B, 1, Bool>::from_bool([true, false, true].into(), &device);\n    ///     let float_tensor = bool_tensor.float();\n    ///     println!(\"{float_tensor}\"); // [1.0, 0.0, 1.0]\n    /// }\n    /// ```\n    pub fn float(self) -> Tensor<B, D> {\n        Tensor::new(TensorPrimitive::Float(B::bool_into_float(self.primitive)))\n    }\n\n    /// Inverses boolean values.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 2, Bool>::from_bool([[true, false], [false, true]].into(), &device);\n    ///     let inverted = tensor.bool_not();\n    ///     println!(\"{inverted}\"); // [[false, true], [true, false]]\n    /// }\n    /// ```\n    pub fn bool_not(self) -> Self {\n        Tensor::new(B::bool_not(self.primitive))\n    }\n\n    /// Performs logical and (`&&`) on two boolean tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `rhs` - The right-hand side tensor for the AND operation.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where each element is the result of `self[i] && rhs[i]`.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);\n    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);\n    ///     let result = a.bool_and(b);\n    ///     println!(\"{result}\"); // [[true, false], [false, false]]\n    /// }\n    /// ```\n    pub fn bool_and(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {\n        Tensor::new(B::bool_and(self.primitive, rhs.primitive))\n    }\n\n    /// Performs logical or (`||`) on two boolean tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `rhs` - The right-hand side tensor for the OR operation.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where each element is the result of `self[i] || rhs[i]`.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);\n    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);\n    ///     let result = a.bool_or(b);\n    ///     println!(\"{result}\"); // [[true, true], [true, false]]\n    /// }\n    /// ```\n    pub fn bool_or(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {\n        Tensor::new(B::bool_or(self.primitive, rhs.primitive))\n    }\n\n    /// Performs logical xor (`^`) on two boolean tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `rhs` - The right-hand side tensor for the XOR operation.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where each element is the result of `self[i] ^ rhs[i]`.\n    /// Returns `true` when exactly one of the operands is `true`.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let a = Tensor::<B, 2, Bool>::from_bool([[true, true], [false, false]].into(), &device);\n    ///     let b = Tensor::<B, 2, Bool>::from_bool([[true, false], [true, false]].into(), &device);\n    ///     let result = a.bool_xor(b);\n    ///     println!(\"{result}\"); // [[false, true], [true, false]]\n    /// }\n    /// ```\n    pub fn bool_xor(self, rhs: Tensor<B, D, Bool>) -> Tensor<B, D, Bool> {\n        Tensor::new(B::bool_xor(self.primitive, rhs.primitive))\n    }\n\n    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).\n    ///\n    /// # Returns\n    ///\n    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of\n    /// the non-zero elements in that dimension.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(\n    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),\n    ///         &device,\n    ///     );\n    ///     let indices = tensor.nonzero();\n    ///     println!(\"{}\", indices[0]); // [0, 0, 1, 2]\n    ///     println!(\"{}\", indices[1]); // [0, 2, 1, 1]\n    /// }\n    /// ```\n    pub fn nonzero(self) -> Vec<Tensor<B, 1, Int>> {\n        try_read_sync(self.nonzero_async())\n            .expect(\"Failed to read tensor data synchronously. Try using nonzero_async instead.\")\n    }\n\n    /// Compute the indices of `true` elements in the tensor (i.e., non-zero for boolean tensors).\n    ///\n    /// # Returns\n    ///\n    /// A vector of tensors, one for each dimension of the given tensor, containing the indices of\n    /// the non-zero elements in that dimension.\n    pub async fn nonzero_async(self) -> Vec<Tensor<B, 1, Int>> {\n        let indices = self.argwhere_async().await;\n\n        if indices.shape().num_elements() == 0 {\n            // Return empty vec when all elements are zero\n            return vec![];\n        }\n\n        let dims = indices.shape();\n        indices\n            .chunk(dims[1], 1)\n            .into_iter()\n            .map(|t| t.reshape(Shape::new([dims[0]])))\n            .collect()\n    }\n\n    /// Compute the indices of the elements that are true, grouped by element.\n    ///\n    /// # Returns\n    ///\n    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the\n    /// result contains the indices of a non-zero element.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 2, Bool>::from_bool(\n    ///         [[true, false, true], [false, true, false], [false, true, false]].into(),\n    ///         &device,\n    ///     );\n    ///     let indices = tensor.argwhere();\n    ///     println!(\"{indices}\"); // [[0, 0], [0, 2], [1, 1], [2, 1]]\n    /// }\n    /// ```\n    pub fn argwhere(self) -> Tensor<B, 2, Int> {\n        try_read_sync(self.argwhere_async())\n            .expect(\"Failed to read tensor data synchronously. Try using argwhere_async instead.\")\n    }\n\n    /// Compute the indices of the elements that are true, grouped by element.\n    ///\n    /// # Returns\n    ///\n    /// A tensor containing the indices of all non-zero elements of the given tensor. Each row in the\n    /// result contains the indices of a non-zero element.\n    pub async fn argwhere_async(self) -> Tensor<B, 2, Int> {\n        Tensor::new(B::bool_argwhere(self.primitive).await)\n    }\n\n    /// Creates a mask for the upper, lower triangle, or diagonal of a matrix, which can be used to\n    /// fill the specified area with a value.\n    fn tri_mask<S: Into<Shape>>(\n        shape: S,\n        tri_part: TriPart,\n        offset: i64,\n        device: &B::Device,\n    ) -> Self {\n        let shape: Shape = shape.into();\n        let height = shape[D - 2];\n        let width = shape[D - 1];\n\n        // Generate row and column index tensors.\n        let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, device);\n        let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, device);\n\n        // Prepare shapes for broadcasting.\n        let mut row_shape = [1; D];\n        row_shape[D - 2] = height;\n        let mut col_shape = [1; D];\n        col_shape[D - 1] = width;\n\n        // Reshape for broadcasting.\n        let row_broadcast: Tensor<B, D, Int> = row_indices.reshape(Shape::new(row_shape));\n        let col_broadcast = col_indices.reshape(Shape::new(col_shape));\n\n        // Broadcasting trick to create a matrix that facilitates comparison for mask generation.\n        let matrix = row_broadcast.clone() - (col_broadcast.clone() - offset);\n\n        // Select the appropriate comparison function based on `tri_part`.\n        let compare = match tri_part {\n            TriPart::Upper => Tensor::greater_elem,\n            TriPart::Lower => Tensor::lower_elem,\n            TriPart::Diagonal => Tensor::not_equal_elem,\n        };\n\n        // Generate and return the mask by applying the comparison to the matrix.\n        compare(matrix, 0).unsqueeze()\n    }\n\n    /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified\n    /// area with a value.\n    ///\n    /// This function generates a boolean tensor representing the mask of the upper triangle of a matrix.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape`: The shape of the matrix.\n    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift\n    ///   towards the upper triangle.\n    /// * `device`: The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the\n    /// upper triangle taking into account the specified `offset`. All other elements are `true`.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let mask = Tensor::<B, 2, Bool>::triu_mask([3, 3], 0, &Default::default());\n    ///   println!(\"{mask}\");\n    ///   // [[false, false, false],\n    ///   //  [true, false, false],\n    ///   //  [true, true, false]]\n    /// }\n    /// ```\n    pub fn triu_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {\n        Self::tri_mask(shape, TriPart::Upper, offset, device)\n    }\n\n    /// Creates a mask for the lower triangle of a matrix, which can be used to fill the specified\n    /// area with a value.\n    ///\n    /// This function generates a boolean tensor representing the mask of the lower triangle of a matrix.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape`: The shape of the matrix.\n    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and negative values shift\n    ///   towards the lower triangle.\n    /// * `device`: The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the\n    /// lower triangle taking into account the specified `offset`. All other elements are `true`.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let mask = Tensor::<B, 2, Bool>::tril_mask([3, 3], 0, &Default::default());\n    ///   println!(\"{mask}\");\n    ///   // [[false, true, true],\n    ///   //  [false, false, true],\n    ///   //  [false, false, false]]\n    /// }\n    /// ```\n    pub fn tril_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {\n        Self::tri_mask(shape, TriPart::Lower, offset, device)\n    }\n\n    /// Creates a mask for the diagonal of a matrix, which can be used to fill the specified\n    /// area with a value.\n    ///\n    /// This function generates a boolean tensor representing the mask of the diagonal of a matrix.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape`: The shape of the matrix.\n    /// * `offset`: The offset from the diagonal, where 0 means the diagonal, and positive values shift\n    ///   towards the upper triangle.\n    /// * `device`: The device on which the tensor will be allocated.\n    ///\n    /// # Returns\n    ///\n    /// Returns a boolean tensor where `false` indicates the elements of the matrix that are part of the\n    /// diagonal. All other elements are `true`.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let mask = Tensor::<B, 2, Bool>::diag_mask([3, 3], 0, &Default::default());\n    ///   println!(\"{mask}\");\n    ///   // [[false, true, true],\n    ///   //  [true, false, true],\n    ///   //  [true, true, false]]\n    /// }\n    /// ```\n    pub fn diag_mask<S: Into<Shape>>(shape: S, offset: i64, device: &B::Device) -> Self {\n        Self::tri_mask(shape, TriPart::Diagonal, offset, device)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/cartesian_grid.rs",
    "content": "use crate::{Int, Shape, Tensor, backend::Backend};\nuse alloc::vec::Vec;\n\n/// Generates a cartesian grid for the given tensor shape on the specified device.\n/// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.\n///\n/// # Arguments\n///\n/// * `shape` - The shape specifying the dimensions of the tensor.\n/// * `device` - The device to create the tensor on.\n///\n/// # Panics\n///\n/// Panics if `D2` is not equal to `D+1`.\n///\n/// # Examples\n///\n/// ```rust\n///    use burn_tensor::Int;\n///    use burn_tensor::{backend::Backend, Shape, Tensor};\n///    fn example<B: Backend>() {\n///        let device = Default::default();\n///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);\n///        println!(\"{}\", result);\n///    }\n/// ```\npub fn cartesian_grid<B: Backend, S: Into<Shape>, const D: usize, const D2: usize>(\n    shape: S,\n    device: &B::Device,\n) -> Tensor<B, D2, Int> {\n    if D2 != D + 1 {\n        panic!(\"D2 must equal D + 1 for Tensor::cartesian_grid\")\n    }\n\n    let dims = shape.into();\n    let mut indices: Vec<Tensor<B, D, Int>> = Vec::new();\n\n    for dim in 0..D {\n        let dim_range: Tensor<B, 1, Int> = Tensor::arange(0..dims[dim] as i64, device);\n\n        let mut shape = [1; D];\n        shape[dim] = dims[dim];\n        let mut dim_range = dim_range.reshape(shape);\n\n        for (i, &item) in dims.iter().enumerate() {\n            if i == dim {\n                continue;\n            }\n            dim_range = dim_range.repeat_dim(i, item);\n        }\n\n        indices.push(dim_range);\n    }\n\n    Tensor::stack::<D2>(indices, D)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/check.rs",
    "content": "use crate::ops::FloatElem;\nuse crate::{BasicOps, Shape, Slice, Tensor, backend::Backend, cast::ToElement};\nuse alloc::format;\nuse alloc::string::{String, ToString};\nuse alloc::vec;\nuse alloc::vec::Vec;\nuse burn_backend::tensor::Ordered;\n\n/// The struct should always be used with the [check](crate::check) macro.\n///\n/// This is a simple pub(crate) data structure that efficiently checks tensor operations and\n/// formats clear error messages. It's crucial that the checks are really fast, but it doesn't matter\n/// when a failed check is discovered since the program will panic.\n///\n/// # Notes\n///\n/// Failing tensor checks will always result in a panic.\n/// As mentioned in [The Rust Programming Language book](https://doc.rust-lang.org/book/ch09-03-to-panic-or-not-to-panic.html),\n/// when there is no way to recover, panic should be used instead of a result.\n///\n/// Most users will unwrap the results anyway, which will worsen the clarity of the code. Almost\n/// all checks highlight programming errors, which means invalid programs that should be fixed.\n/// Checks are not the ideal way to help users write correct programs, but they are still better\n/// than backend errors. Other forms of compile-time validation could be developed, such as named\n/// tensors, but we have to carefully evaluate the ease of use of the Tensor API. Adding overly\n/// complex type validation checks might drastically worsen the API and result in harder-to-maintain\n/// programs.\n///\n/// # Design\n///\n/// Maybe the Backend API should return a result for each operation, which would allow handling\n/// all checks, even the ones that can't be efficiently checked before performing an operation,\n/// such as the `index_select` operation. The downside of that approach is that all backend\n/// implementation might re-implement the same checks, which may result in unnecessary code\n/// duplication. Maybe a combination of both strategies could help to cover all use cases.\npub(crate) enum TensorCheck {\n    Ok,\n    Failed(FailedTensorCheck),\n}\n\nimpl TensorCheck {\n    /// Checks device and shape compatibility for element wise binary operations.\n    pub(crate) fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(\n        ops: &str,\n        lhs: &Tensor<B, D, K>,\n        rhs: &Tensor<B, D, K>,\n    ) -> Self {\n        Self::Ok\n            .binary_ops_device(ops, &lhs.device(), &rhs.device())\n            .binary_ops_ew_shape::<D>(ops, &lhs.shape(), &rhs.shape())\n    }\n\n    pub(crate) fn into_scalar<const D: usize>(shape: &Shape) -> Self {\n        let mut check = Self::Ok;\n\n        if shape.num_elements() != 1 {\n            check = check.register(\n                \"Into Scalar\",\n                TensorError::new(\"Only tensors with 1 element can be converted into scalar.\")\n                    .details(format!(\n                        \"Current tensor has {} elements\",\n                        shape.num_elements()\n                    )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn dim_ops<const D: usize>(ops: &str, dim: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim >= D {\n            check = check.register(\n                ops,\n                TensorError::new(\"Given dimension is higher than the tensor rank.\")\n                    .details(format!(\"Tensor rank: '{D}', given dimension: '{dim}'.\")),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn creation_ops<const D: usize>(ops: &str, dims: &[usize]) -> Self {\n        let mut check = Self::Ok;\n\n        if D == 0 {\n            check = check.register(\n                ops,\n                TensorError::new(\"Tried to create a 0-dim tensor, which is invalid.\")\n                    .details(format!(\"Tensor rank: '{D}', given dimensions: '{dims:?}'.\")),\n            );\n        }\n\n        if dims.len() != D {\n            check = check.register(\n                ops,\n                TensorError::new(\"Given dimensions differ from the tensor rank.\")\n                    .details(format!(\"Tensor rank: '{D}', given dimensions: '{dims:?}'.\")),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn narrow<B: Backend, const D: usize, K: BasicOps<B>>(\n        tensor: &Tensor<B, D, K>,\n        dim: usize,\n        start: usize,\n        length: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n\n        if length == 0 {\n            check = check.register(\n                \"Narrow\",\n                TensorError::new(format!(\n                    \"Can't narrow at dimension {dim}, length must be greater than 0\",\n                )),\n            );\n        }\n\n        if start >= tensor.shape()[dim] {\n            check = check.register(\n                \"Narrow\",\n                TensorError::new(format!(\n                    \"Can't narrow at dimension {dim}, start exceeds the size of the tensor along \\\n                     this dimension (Size={})\",\n                    tensor.shape()[dim]\n                )),\n            );\n        }\n\n        if start + length > tensor.shape()[dim] {\n            check = check.register(\n                \"Narrow\",\n                TensorError::new(format!(\n                    \"Can't narrow at dimension {dim}, start + length exceeds the size of the tensor \\\n                     along this dimension (Size={})\",\n                    tensor.shape()[dim]\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn movedim_args_usize<const D: usize>(dim: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim >= D {\n            check = check.register(\n                \"Movedim\",\n                TensorError::new(\n                    \"The given dimension exceeds the number of dimensions of the current tensor.\",\n                )\n                .details(format!(\n                    \"Current tensor has {D} dimensions, but the given dimension is {dim}.\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn movedim_args_i32<const D: usize>(dim: i32) -> Self {\n        let mut check = Self::Ok;\n\n        if dim < -(D as i32) || dim >= D as i32 {\n            check = check.register(\n                \"Movedim\",\n                TensorError::new(\n                    \"The given dimension is out of bounds for the current tensor dimensions.\",\n                )\n                .details(format!(\n                    \"Current tensor has {D} dimensions, but the given dimension is {dim}.\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn movedim_args_vec<const D: usize>(dims: &Vec<usize>) -> Self {\n        let mut check = Self::Ok;\n\n        // Check out of bounds\n        if dims.iter().any(|&x| x >= D) {\n            check = check.register(\n                \"Movedim\",\n                TensorError::new(\"The given dimensions are out of bounds.\").details(format!(\n                    \"Current tensor has {D} dimensions, but the given dimensions are {dims:?}.\",\n                )),\n            );\n        }\n\n        // Check there are no duplicates\n        for (i, &dim_i) in dims.iter().enumerate() {\n            for &dim_j in dims.iter().skip(i + 1) {\n                if dim_i == dim_j {\n                    check = check.register(\n                        \"Movedim\",\n                        TensorError::new(\"The given dimensions contain duplicates.\").details(\n                            format!(\n                            \"The dimension {dim_i} is duplicated in the given dimensions {dims:?}.\",\n                        ),\n                        ),\n                    );\n                }\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn movedim_args_length(\n        source_dims: &Vec<usize>,\n        destination_dims: &Vec<usize>,\n    ) -> Self {\n        let mut check = Self::Ok;\n\n        if source_dims.len() != destination_dims.len() {\n            check = check.register(\n                \"Movedim\",\n                TensorError::new(\n                    \"The number of dimensions in source and destination must be equal.\",\n                )\n                .details(format!(\n                    \"Source dimensions: {source_dims:?}, Destination dimensions: {destination_dims:?}.\",\n                )),\n            )\n        }\n\n        check\n    }\n\n    pub(crate) fn flatten<const D1: usize, const D2: usize>(\n        start_dim: usize,\n        end_dim: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n\n        if start_dim > end_dim {\n            check = check.register(\n                \"Flatten\",\n                TensorError::new(format!(\n                    \"The start dim ({start_dim}) must be smaller than or equal to the end dim ({end_dim})\"\n                )),\n            );\n        }\n\n        if D2 > D1 {\n            check = check.register(\n                \"Flatten\",\n                TensorError::new(format!(\n                    \"Result dim ({D2}) must be smaller than or equal to ({D1})\"\n                )),\n            );\n        }\n\n        if D1 < end_dim + 1 {\n            check = check.register(\n                \"Flatten\",\n                TensorError::new(format!(\n                    \"The end dim ({end_dim}) must be smaller than the tensor dim ({D1})\"\n                )),\n            );\n        }\n\n        if (D2 as i32) < (D1 as i32 - (end_dim as i32 - start_dim as i32)) {\n            check = check.register(\n                \"Flatten\",\n                TensorError::new(format!(\n                    \"The destination dimension ({D2}) must be large enough to accommodate the \\\n                     flattening operation.\"\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn tri<const D: usize>() -> Self {\n        let mut check = Self::Ok;\n\n        if D < 2 {\n            check = check.register(\n                \"Tri\",\n                TensorError::new(format!(\n                    \"The input tensor must have at least 2 dimensions, got {D}\"\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn squeeze<const D2: usize>(dim: usize, tensor_dims: &[usize]) -> Self {\n        let mut check = Self::Ok;\n        // This should actually be to check that the dimension to squeeze\n        // has a size of 1\n        if tensor_dims[dim] != 1 {\n            check = check.register(\n                \"Squeeze\",\n                TensorError::new(format!(\n                    \"Can't squeeze dimension {dim} because its size is not 1\",\n                )),\n            );\n        }\n\n        if dim >= tensor_dims.len() {\n            check = check.register(\n                \"Squeeze\",\n                TensorError::new(format!(\n                    \"Dimension index {dim} is out of bounds for tensor dimensions {tensor_dims:?}.\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn squeeze_dims_input<const D2: usize>(\n        dim_indices: &[usize],\n        current_dims: &[usize],\n    ) -> Self {\n        let mut check = Self::Ok;\n        if dim_indices.len() >= current_dims.len() {\n            check = check.register(\n                \"Squeeze\",\n                TensorError::new(\"Attempted to squeeze too many dimensions!\"),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn squeeze_dims_len<const D2: usize>(new_dims_len: usize) -> Self {\n        let mut check = Self::Ok;\n        if new_dims_len == 0 {\n            // 0-dim tensor not supported\n            check = check.register(\n                \"Squeeze\",\n                TensorError::new(\n                    \"Resulting dimensions cannot be zero. To remove specific singleton dimensions while preserving at least one, use `squeeze_dims` instead.\".to_string()\n                ),\n            );\n        }\n\n        if new_dims_len != D2 {\n            check = check.register(\n                \"Squeeze\",\n                TensorError::new(format!(\n                    \"Resulting dimensions {new_dims_len} do not match the required D2 size {D2}.\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {\n        let mut check = Self::Ok;\n        if D2 < D1 {\n            check = check.register(\n                \"Unsqueeze\",\n                TensorError::new(format!(\n                    \"Can't unsqueeze smaller tensor, got dim {D2}, expected > {D1}\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn unsqueeze_dim<const D1: usize, const D2: usize>(dim: usize) -> Self {\n        let mut check = Self::Ok;\n        if D2 <= D1 {\n            check = check.register(\n                \"Unsqueeze\",\n                TensorError::new(format!(\n                    \"The unsqueezed rank must be greater than the input rank (D={D1}; D2={D2})\",\n                )),\n            );\n        }\n\n        if dim > D1 {\n            check = check.register(\n                \"Unsqueeze\",\n                TensorError::new(format!(\n                    \"Can't unsqueeze at dimension {dim}, exceeds tensor dimensions (D={D1})\",\n                )),\n            );\n        }\n\n        if dim >= D2 {\n            check = check.register(\n                \"Unsqueeze\",\n                TensorError::new(format!(\n                    \"Can't unsqueeze at dimension {dim}, exceeds output tensor dimensions (D2={D2})\",\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn unsqueeze_dims<const D: usize>(dim: isize) -> Self {\n        let mut check = Self::Ok;\n        let output_rank = D as isize;\n        //contains is right exclusive, so this is to spec\n        if !(-output_rank..output_rank).contains(&dim) {\n            check = check.register(\n                \"Unsqueeze\",\n                TensorError::new(format!(\n                    \"unsqueeze arg {dim} is out of range for the output tensor of rank {output_rank}\",\n                )),\n            );\n        }\n        check\n    }\n\n    pub(crate) fn one_hot_tensor<B: Backend, const D: usize, K: Ordered<B>>(\n        index_tensor: Tensor<B, D, K>,\n        num_classes: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n        if index_tensor\n            .clone()\n            .greater_equal_elem(num_classes as i32)\n            .any()\n            .into_scalar()\n            .to_bool()\n        {\n            check = check.register(\n                \"One Hot\",\n                TensorError::new(format!(\n                    \"Can't create a one hot tensor from ({index_tensor:?}) containing indexes greater or equal to the number of classes ({num_classes})\",\n                )),\n            );\n        } else if num_classes <= 1 {\n            check = check.register(\n                \"One Hot\",\n                TensorError::new(\"Can't create a one hot tensor with less then 2 classes\"),\n            )\n        }\n        check\n    }\n\n    pub(crate) fn one_hot_tensor_rank<const D: usize, const D2: usize>() -> Self {\n        let mut check = Self::Ok;\n        if D + 1 != D2 {\n            check = check.register(\n                \"One Hot\",\n                TensorError::new(\n                    \"The one-hot tensor rank must correspond to the rank of the tensor + 1\",\n                )\n                .details(format!(\"Expected D2={}, got {D2}\", D + 1)),\n            );\n        }\n        check\n    }\n\n    pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim1 > D || dim2 > D {\n            check = check.register(\n                \"Swap Dims\",\n                TensorError::new(\"The swap dimensions must be smaller than the tensor dimension\")\n                    .details(format!(\n                        \"Swap dims ({dim1}, {dim2}) on tensor with ({D}) dimensions.\"\n                    )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn permute<const D: usize>(axes: [usize; D]) -> Self {\n        let check = Self::Ok;\n\n        // Check if the axes are within the tensor dimensions\n        if let Some(axis) = axes.iter().find(|&x| *x >= D) {\n            return check.register(\n                \"permute\",\n                TensorError::new(\"The axes must be smaller than the tensor dimension.\")\n                    .details(format!(\"The '{axis}' axis is greater than {D} dimensions.\")),\n            );\n        }\n\n        // Check if the axes are unique\n        let mut seen = [false; D];\n        axes.iter().for_each(|&x| seen[x] = true);\n        if seen.iter().any(|&x| !x) {\n            return check.register(\n                \"permute\",\n                TensorError::new(\"The axes must be unique.\")\n                    .details(format!(\"The axes '{axes:?}' are not unique.\")),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn flip(rank: usize, axes: &[usize]) -> Self {\n        let check = Self::Ok;\n\n        // Check if the axes are within the tensor dimensions\n        if let Some(axis) = axes.iter().find(|&x| *x >= rank) {\n            return check.register(\n                \"flip\",\n                TensorError::new(\"The axes must be smaller than the tensor dimension.\").details(\n                    format!(\"The '{axis}' axis is greater than {rank} dimensions.\"),\n                ),\n            );\n        }\n\n        // Check if the axes are unique\n        let mut dedup = axes.to_vec();\n        dedup.sort_unstable();\n        dedup.dedup();\n        if dedup.len() != axes.len() {\n            return check.register(\n                \"flip\",\n                TensorError::new(\"The axes must be unique.\")\n                    .details(format!(\"The axes '{axes:?}' are not unique.\")),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn matmul<B: Backend, const D: usize, K>(\n        lhs: &Tensor<B, D, K>,\n        rhs: &Tensor<B, D, K>,\n    ) -> Self\n    where\n        K: BasicOps<B>,\n    {\n        let mut check = Self::Ok;\n\n        check = check.binary_ops_device(\"Matmul\", &lhs.device(), &rhs.device());\n\n        if D < 2 {\n            return check;\n        }\n\n        let shape_lhs = lhs.shape();\n        let shape_rhs = rhs.shape();\n\n        let dim_lhs = shape_lhs[D - 1];\n        let dim_rhs = shape_rhs[D - 2];\n\n        if dim_lhs != dim_rhs {\n            check = check.register(\n                \"Matmul\",\n                TensorError::new(format!(\n                    \"The inner dimension of matmul should be the same, but got {dim_lhs} and \\\n                     {dim_rhs}.\"\n                ))\n                .details(format!(\n                    \"Lhs shape {:?}, rhs shape {:?}.\",\n                    shape_lhs, shape_rhs\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn cross<B: Backend, const D: usize, K>(\n        lhs: &Tensor<B, D, K>,\n        rhs: &Tensor<B, D, K>,\n        dim: usize,\n    ) -> Self\n    where\n        K: BasicOps<B>,\n    {\n        let mut check = Self::Ok;\n\n        check = check.binary_ops_device(\"Cross\", &lhs.device(), &rhs.device());\n\n        let shape_lhs = lhs.shape();\n        let shape_rhs = rhs.shape();\n\n        if dim >= D {\n            check = check.register(\n                \"Cross\",\n                TensorError::new(format!(\n                    \"Dimension {dim} is out of bounds for tensors with {D} dimensions.\"\n                )),\n            );\n            return check;\n        }\n\n        let dim_size_lhs = shape_lhs[dim];\n        let dim_size_rhs = shape_rhs[dim];\n\n        if dim_size_lhs != 3 || dim_size_rhs != 3 {\n            check = check.register(\n                \"Cross\",\n                TensorError::new(format!(\n                    \"Cross product requires dimension {dim} to have size 3, but got {dim_size_lhs} and {dim_size_rhs}.\"\n                )),\n            );\n        }\n\n        // Check broadcastability of other dimensions\n        for i in 0..D {\n            if i != dim {\n                let l = shape_lhs[i];\n                let r = shape_rhs[i];\n                if l != r && l != 1 && r != 1 {\n                    check = check.register(\n                        \"Cross\",\n                        TensorError::new(format!(\n                            \"Tensors are not broadcastable along dimension {i}: {l} and {r}.\"\n                        )),\n                    );\n                }\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn stack<B: Backend, const D1: usize, K: BasicOps<B>, const D2: usize>(\n        tensors: &[Tensor<B, D1, K>],\n        dim: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n\n        if dim > D1 {\n            check = check.register(\n                \"Stack\",\n                TensorError::new(\n                    \"Can't stack tensors on a dim that exceeds the tensors dimension (inclusive)\",\n                )\n                .details(format!(\n                    \"Trying to concatenate tensors with {D1} dimensions on axis {dim}.\"\n                )),\n            );\n        }\n\n        if D1 == D2 {\n            check = check.register(\n                \"Stack\",\n                TensorError::new(format!(\n                    \"Can't stack tensors on existing dimension {dim}, the input and output ranks are the same (D={D1}; D2={D2}).\\\n                    If you want to concatenate the tensors along the specified dimension ({dim}), use `Tensor::cat` instead.\",\n                )),\n            );\n        }\n\n        if tensors.is_empty() {\n            return check.register(\n                \"Stack\",\n                TensorError::new(\"Can't stack an empty list of tensors.\"),\n            );\n        }\n\n        let shape_reference = tensors.first().unwrap().shape();\n\n        for tensor in tensors {\n            let shape = tensor.shape();\n\n            if shape_reference != shape {\n                return check.register(\n                    \"Stack\",\n                    TensorError::new(\"Can't stack tensors with different shapes\").details(format!(\n                        \"Provided dimension ({dim}), tensors shapes: {:?}\",\n                        tensors.iter().map(Tensor::shape).collect::<Vec<_>>()\n                    )),\n                );\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn cat<B: Backend, const D: usize, K: BasicOps<B>>(\n        tensors: &[Tensor<B, D, K>],\n        dim: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n\n        if dim >= D {\n            check = check.register(\n                \"Cat\",\n                TensorError::new(\n                    \"Can't concatenate tensors on a dim that exceeds the tensors dimension\",\n                )\n                .details(format!(\n                    \"Trying to concatenate tensors with {D} dimensions on axis {dim}.\"\n                )),\n            );\n        }\n\n        if tensors.is_empty() {\n            return check.register(\n                \"Cat\",\n                TensorError::new(\"Can't concatenate an empty list of tensors.\"),\n            );\n        }\n\n        let mut shape_reference = tensors.first().unwrap().shape();\n        shape_reference[dim] = 1; // We want to check every dims except the one where the\n        // concatenation happens.\n\n        for tensor in tensors {\n            let mut shape = tensor.shape();\n            shape[dim] = 1; // Ignore the concatenate dim.\n\n            if shape_reference != shape {\n                return check.register(\n                    \"Cat\",\n                    TensorError::new(\n                        \"Can't concatenate tensors with different shapes, except for the provided \\\n                         dimension\",\n                    )\n                    .details(format!(\n                        \"Provided dimension ({dim}), tensors shapes: {:?}\",\n                        tensors.iter().map(Tensor::shape).collect::<Vec<_>>()\n                    )),\n                );\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn slice<const R: usize>(shape: &Shape, slices: &[Slice]) -> Self {\n        let mut check = Self::Ok;\n        let n_dims_tensor = R;\n        let n_dims_slices = slices.len();\n\n        if n_dims_tensor < n_dims_slices {\n            check = check.register(\n                \"Slice\",\n                TensorError::new(\n                    \"The provided slices array has a higher number of dimensions than the current \\\n                     tensor.\",\n                )\n                .details(format!(\n                    \"The slices array must be smaller or equal to the tensor number of \\\n                     dimensions. Tensor number of dimensions: {n_dims_tensor}, slices array \\\n                     length {n_dims_slices}.\"\n                )),\n            );\n        }\n\n        for (i, slice) in slices.iter().enumerate().take(R) {\n            let d_tensor = shape[i];\n\n            // Check the raw end value before conversion\n            if let Some(end) = slice.end\n                && end > 0\n                && end as usize > d_tensor\n            {\n                check = check.register(\n                        \"Slice\",\n                        TensorError::new(\n                            \"The provided slice has a range that exceeds the current tensor \\\n                             size.\",\n                        )\n                        .details(format!(\n                            \"The slice end index {} exceeds the size of the tensor ({}) at dimension {}. \\\n                             Tensor shape {:?}.\",\n                            end, d_tensor, i, shape,\n                        )),\n                    );\n            }\n\n            // Empty slices (start >= end) are allowed and produce a tensor with size 0\n            // in that dimension. This matches PyTorch behavior and is required for ONNX\n            // compatibility where dynamic slice ranges may become empty at runtime.\n\n            if slice.step() == 0 {\n                check = check.register(\n                    \"Slice\",\n                    TensorError::new(\"The provided slice has a step of 0.\").details(format!(\n                        \"The slice at dimension '{i}' has a step of 0. Step must be non-zero.\",\n                    )),\n                );\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn slice_assign<const R: usize>(\n        shape: &Shape,\n        shape_value: &Shape,\n        slices: &[crate::Slice],\n    ) -> Self {\n        let mut check = Self::Ok;\n        let n_dims_slices = slices.len();\n\n        if R < n_dims_slices {\n            check = check.register(\n                \"Slice Assign\",\n                TensorError::new(\n                    \"The provided slices array has a higher number of dimensions than the current \\\n                     tensor.\",\n                )\n                .details(format!(\n                    \"The slices array must be smaller or equal to the tensor number of \\\n                     dimensions. Tensor number of dimensions: {R}, slices array length {n_dims_slices}.\"\n                )),\n            );\n        }\n\n        for (i, slice) in slices.iter().enumerate().take(usize::min(R, n_dims_slices)) {\n            let d_tensor = shape[i];\n            let d_tensor_value = shape_value[i];\n            let range = slice.to_range(d_tensor);\n\n            if range.end > d_tensor {\n                check = check.register(\n                    \"Range Assign\",\n                    TensorError::new(\n                        \"The provided slice has a range that exceeds the current tensor \\\n                         size.\",\n                    )\n                    .details(format!(\n                        \"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \\\n                         Current tensor shape {:?}, value tensor shape {:?}.\",\n                        range.start, range.end, d_tensor, i, shape, shape_value,\n                    )),\n                );\n            }\n\n            // Calculate the number of elements selected with the given step\n            let num_elements = slice.output_size(d_tensor);\n\n            if num_elements != d_tensor_value {\n                check = check.register(\n                    \"Slice Assign\",\n                    TensorError::new(\n                        \"The value tensor must match the amount of elements selected with the \\\n                         slices array\",\n                    )\n                    .details(format!(\n                        \"The slice with range ({}..{}) and step {} selects {} elements but the value \\\n                         tensor has {} elements at dimension {}. Current tensor shape {:?}, value tensor \\\n                         shape {:?}.\",\n                        range.start,\n                        range.end,\n                        slice.step,\n                        num_elements,\n                        d_tensor_value,\n                        i,\n                        shape,\n                        shape_value,\n                    )),\n                );\n            }\n\n            // Note: Empty slices (start >= end with positive step) are handled at the API level\n            // by returning the original tensor unchanged, so we don't check for them here.\n        }\n\n        check\n    }\n\n    pub(crate) fn check_dim<const D: usize>(dim: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim >= D {\n            check = check.register(\n                \"Check Dim\",\n                TensorError::new(\"The provided dimension exceeds the tensor dimensions.\").details(\n                    format!(\"Tensor has {D} dimensions, but the provided dimension is {dim}.\"),\n                ),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn gather<const D: usize>(dim: usize, shape: &Shape, shape_indices: &Shape) -> Self {\n        Self::check_gather_scatter_indices::<D>(Self::Ok, \"Gather\", dim, shape, shape_indices)\n    }\n\n    pub(crate) fn scatter<const D: usize>(\n        dim: usize,\n        shape: &Shape,\n        shape_indices: &Shape,\n        shape_value: &Shape,\n    ) -> Self {\n        let ops = \"Scatter\";\n        let mut check =\n            Self::check_gather_scatter_indices::<D>(Self::Ok, ops, dim, shape, shape_indices);\n\n        if shape_indices != shape_value {\n            check = check.register(\n                ops,\n                TensorError::new(\n                    \"Indices tensor shape should be the same as the value tensor shape.\"\n                        .to_string(),\n                )\n                .details(format!(\n                    \"The shape differs: {:?} != {:?}\",\n                    shape_indices, shape_value\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn select<const D: usize>(dim: usize) -> Self {\n        Self::check_select_basic::<D>(Self::Ok, \"select\", dim)\n    }\n\n    pub(crate) fn take<const D: usize, const DI: usize, const DO: usize>(dim: usize) -> Self {\n        let mut check = Self::check_select_basic::<D>(Self::Ok, \"Take\", dim);\n\n        // Calculate expected output dimensions\n        // DO = D - 1 + DI (remove 1 dim, add DI dims)\n        let expected_do = D + DI - 1;\n        if DO != expected_do {\n            check = check.register(\n                \"Take\",\n                TensorError::new(\"Output dimension mismatch\").details(format!(\n                    \"Expected output dimension {} (D={} + DI={} - 1) but got DO={}\",\n                    expected_do, D, DI, DO\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn diag<const D: usize, const DO: usize>() -> Self {\n        let mut check = Self::Ok;\n\n        if D < 2 {\n            check = check.register(\n                \"Diag\",\n                TensorError::new(\n                    \"Diagonal operations require \n                tensors with at least 2 dimensions.\",\n                )\n                .details(format!(\n                    \"Got tensor with {D} dimensions,\n                expected at least 2\"\n                )),\n            );\n        }\n\n        if DO != D - 1 {\n            check = check.register(\n                \"Diag\",\n                TensorError::new(\"Output rank must be input rank minus 1 for diagonal\")\n                    .details(format!(\"Expected output rank {}, got {DO}\", D - 1)),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn select_assign<const D: usize>(\n        dim: usize,\n        shape_indices: &Shape,\n        shape_value: &Shape,\n    ) -> Self {\n        let mut check = Self::check_select_basic::<D>(Self::Ok, \"Select Assign\", dim);\n\n        if shape_value[dim] != shape_indices[0] {\n            check = check.register(\n                \"Select Assign\",\n                TensorError::new(\n                    format!(\n                        \"Number of indices ({}) should be equal to value tensor dimensions {:?} on axis (dim={dim})\",\n                        shape_indices[0],\n                        shape_value\n                    ),\n                )\n            );\n        }\n\n        check\n    }\n\n    fn check_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {\n        if dim > D {\n            check = check.register(\n                ops,\n                TensorError::new(format!(\n                    \"Can't index a tensor with ({D}) dimensions on axis ({dim})\"\n                )),\n            );\n        }\n\n        check\n    }\n    fn check_gather_scatter_indices<const D: usize>(\n        mut check: Self,\n        ops: &str,\n        dim: usize,\n        shape: &Shape,\n        shape_indices: &Shape,\n    ) -> Self {\n        if dim > D {\n            check = check.register(\n                ops,\n                TensorError::new(format!(\n                    \"Can't index a tensor with ({D}) dimensions on axis ({dim})\"\n                )),\n            );\n        }\n\n        for i in 0..D {\n            if i == dim {\n                continue;\n            }\n\n            let tensor_dim_i = shape[i];\n            let indices_dim_i = shape_indices[i];\n\n            if tensor_dim_i != indices_dim_i {\n                check = check.register(\n                    ops,\n                    TensorError::new(\n                        \"The tensor shape should be the same as the index tensor shape.\"\n                            .to_string(),\n                    )\n                    .details(format!(\n                        \"The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}\"\n                    )),\n                );\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn check_prelu_shape<const D: usize>(\n        shape_tensor: &Shape,\n        shape_weight: &Shape,\n    ) -> Self {\n        let mut check = Self::Ok;\n        if shape_weight[0] == 1 {\n            check\n        } else if D >= 2 {\n            let channels = shape_tensor[1];\n            let num_weights = shape_weight[0];\n            if channels != num_weights {\n                check = check.register(\n                    \"PReLu\",\n                    TensorError::new(\n                        \"Number of channels in input tensor and  number of weights must be equal\",\n                    )\n                    .details(format!(\n                        \"Got no. of channels: {channels}, no. of weights: {num_weights}\",\n                    )),\n                );\n                return check;\n            }\n            check\n        } else {\n            check = check.register(\n                \"PReLu\",\n                TensorError::new(\n                    \"Number of channels in input tensor and  number of weights must be equal\",\n                )\n                .details(format!(\n                    \"Got no. of channels: 1, no. of weights: {}\",\n                    shape_weight[0]\n                )),\n            );\n            check\n        }\n    }\n\n    /// Checks aggregate dimension such as mean and sum.\n    pub(crate) fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim > D {\n            check = check.register(\n                ops,\n                TensorError::new(format!(\n                    \"Can't aggregate a tensor with ({D}) dimensions on axis ({dim})\"\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn sort_dim<const D: usize>(ops: &str, dim: usize) -> Self {\n        let mut check = Self::Ok;\n\n        if dim > D {\n            check = check.register(\n                ops,\n                TensorError::new(format!(\n                    \"Can't sort a tensor with ({D}) dimensions on axis ({dim})\"\n                )),\n            );\n        }\n\n        check\n    }\n\n    pub(crate) fn split<const D: usize>(\n        tensor_dims: &[usize],\n        split_size: usize,\n        dim: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n        let op = \"split\";\n        let tensor_rank = tensor_dims.len();\n\n        if dim >= tensor_rank {\n            check = check.register(\n                op,\n                TensorError::new(\"Given dimension is greater than or equal to the tensor rank.\")\n                    .details(format!(\"Tensor rank: '{D}', given dimension: '{dim}'\")),\n            );\n        } else {\n            let tensor_size = tensor_dims[dim];\n            if split_size == 0 && tensor_size != 0 {\n                check = check.register(\n                    op,\n                    TensorError::new(\"split_size must be greater than 0 unless the tensor size along the dimension is 0.\")\n                        .details(format!(\"split_size: '{split_size}', tensor size along dim '{dim}': '{tensor_size}'.\")),\n                );\n            }\n        }\n\n        check\n    }\n\n    pub(crate) fn split_with_sizes<const D: usize>(\n        tensor_dims: &[usize],\n        split_sizes: &[usize],\n        dim: usize,\n    ) -> Self {\n        let mut check = Self::Ok;\n        let op = \"split_with_sizes\";\n        let tensor_rank = tensor_dims.len();\n\n        if dim >= tensor_rank {\n            check = check.register(\n                op,\n                TensorError::new(\"Given dimension is greater than or equal to the tensor rank.\")\n                    .details(format!(\"Tensor rank: '{D}', given dimension: '{dim}'.\")),\n            );\n        } else {\n            // Validate split_sizes add up to size of dimension to split along\n            let tensor_size = tensor_dims[dim];\n            let total_split_size: usize = split_sizes.iter().sum();\n            if total_split_size != tensor_size {\n                check = check.register(\n                    op,\n                    TensorError::new(\"The sum of split_sizes must equal the tensor size along the specified dimension.\")\n                        .details(format!(\"Sum of split_sizes: '{total_split_size}', tensor size along dim '{dim}': '{tensor_size}'.\")),\n                );\n            }\n        }\n\n        check\n    }\n\n    /// The goal is to minimize the cost of checks when there are no error, but it's way less\n    /// important when an error occurred, crafting a comprehensive error message is more important\n    /// than optimizing string manipulation.\n    fn register(self, ops: &str, error: TensorError) -> Self {\n        let errors = match self {\n            Self::Ok => vec![error],\n            Self::Failed(mut failed) => {\n                failed.errors.push(error);\n                failed.errors\n            }\n        };\n\n        Self::Failed(FailedTensorCheck {\n            ops: ops.to_string(),\n            errors,\n        })\n    }\n\n    /// Checks if shapes are compatible for element wise operations supporting broadcasting.\n    pub(crate) fn binary_ops_ew_shape<const D: usize>(\n        self,\n        ops: &str,\n        lhs: &Shape,\n        rhs: &Shape,\n    ) -> Self {\n        let mut check = self;\n\n        for i in 0..D {\n            let d_lhs = lhs[i];\n            let d_rhs = rhs[i];\n\n            if d_lhs != d_rhs {\n                let is_broadcast = d_lhs == 1 || d_rhs == 1;\n\n                if is_broadcast {\n                    continue;\n                }\n\n                check = check.register(\n                    ops,\n                    TensorError::new(\"The provided tensors have incompatible shapes.\").details(\n                        format!(\n                            \"Incompatible size at dimension '{}' => '{} != {}', which can't be \\\n                             broadcasted. Lhs tensor shape {:?}, Rhs tensor shape {:?}.\",\n                            i, d_lhs, d_rhs, lhs, rhs,\n                        ),\n                    ),\n                );\n            }\n        }\n\n        check\n    }\n\n    /// Checks if tensor devices are equal.\n    fn binary_ops_device<Device: PartialEq + core::fmt::Debug>(\n        self,\n        ops: &str,\n        lhs: &Device,\n        rhs: &Device,\n    ) -> Self {\n        match lhs != rhs {\n            true => self.register(\n                ops,\n                TensorError::new(\"The provided tensors are not on the same device.\").details(\n                    format!(\"Lhs tensor device {lhs:?}, Rhs tensor device {rhs:?}.\",),\n                ),\n            ),\n            false => self,\n        }\n    }\n\n    /// Checks if expand operation is possible for the given shapes.\n    pub fn expand<const D1: usize, const D2: usize>(ops: &str, shape: &Shape, to: &Shape) -> Self {\n        let mut check = TensorCheck::Ok;\n        let max_dims = core::cmp::max(D1, D2);\n\n        // Calculate the starting indices for each shape array, ensuring alignment from the right.\n        let start_index_shape = max_dims.saturating_sub(D1);\n        let start_index_to = max_dims.saturating_sub(D2);\n\n        for i in 0..max_dims {\n            // Use 1 as the default dimension size for dimensions beyond the tensor's rank.\n            let d_shape = if i >= start_index_shape {\n                shape[i - start_index_shape]\n            } else {\n                1\n            };\n            let d_to = if i >= start_index_to {\n                to[i - start_index_to]\n            } else {\n                1\n            };\n\n            if d_shape != d_to && d_shape != 1 && d_to != 1 {\n                // Register an incompatibility error.\n                check = check.register(\n                    ops,\n                    TensorError::new(\n                        \"The provided tensor can't be broadcasted to the target shape.\",\n                    )\n                    .details(format!(\n                        \"Incompatible size at dimension '{}' => '{} != {}', which can't be \\\n                         broadcasted. Tensor shape {:?}, Target shape {:?}.\",\n                        max_dims - i - 1,\n                        d_shape,\n                        d_to,\n                        shape,\n                        to,\n                    )),\n                );\n                break; // Incompatibility found, no need to check further.\n            }\n        }\n\n        check\n    }\n\n    /// Checks if unfold operation is possible for the given shapes.\n    pub fn unfold<const D1: usize, const D2: usize>(\n        ops: &str,\n        _shape: &Shape,\n        _dim: usize,\n        _size: usize,\n        _step: usize,\n    ) -> Self {\n        let mut check = TensorCheck::Ok;\n\n        if D2 != D1 + 1 {\n            check = check.register(\n                ops,\n                TensorError::new(\"The unfold rank is incompatible with the input tensor rank.\")\n                    .details(format!(\n                        \"The output rank '{D2}' != the input rank + 1 '{D1}'.\",\n                    )),\n            );\n        }\n\n        check\n    }\n\n    /// Checks if input is compatible with convolution weights.\n    pub fn conv<const D1: usize, const D2: usize>(\n        ops: &str,\n        x: [usize; D1],\n        weight: [usize; D2],\n        groups: usize,\n    ) -> Self {\n        let mut check = TensorCheck::Ok;\n        let channels = x[1];\n        let expected = weight[1] * groups;\n        if channels != expected {\n            check = check.register(\n                ops,\n                TensorError::new(\"Number of channels in input tensor and input channels of convolution must be equal.\")\n                .details(format!(\"got: {channels}, expected: {expected}\")),\n            );\n        }\n        check\n    }\n\n    /// Checks if input is compatible with transposed convolution weights.\n    pub fn conv_transpose<const D1: usize, const D2: usize>(\n        ops: &str,\n        x: [usize; D1],\n        weight: [usize; D2],\n    ) -> Self {\n        let mut check = TensorCheck::Ok;\n        let channels = x[1];\n        let expected = weight[0];\n        if channels != expected {\n            check = check.register(\n                ops,\n                TensorError::new(\"Number of channels in input tensor and input channels of convolution must be equal.\")\n                .details(format!(\"got: {channels}, expected: {expected}\")),\n            );\n        }\n        check\n    }\n\n    /// Check if input is compatible with LU decomposition.\n    pub fn is_square<const D: usize>(ops: &str, shape: &Shape) -> Self {\n        let mut check = TensorCheck::Ok;\n        if shape[D - 1] != shape[D - 2] {\n            check = check.register(\n                ops,\n                TensorError::new(\"The input tensor must be square.\").details(format!(\n                    \"Got tensor with shape {:?}, expected last two dimensions to be equal\",\n                    shape\n                )),\n            );\n        }\n        check\n    }\n\n    /// Check pivot is valid for LU decomposition.\n    pub fn lu_decomposition_pivot<B: Backend>(pivot: FloatElem<B>) -> Self {\n        let mut check = TensorCheck::Ok;\n        if pivot.to_f64().abs() <= 1e-6 {\n            check = check.register(\n                \"lu_decomposition\",\n                TensorError::new(\"LU decomposition requires a valid pivot.\")\n                    .details(format!(\"Got pivot value too close to zero: {}\", pivot)),\n            );\n        }\n        check\n    }\n}\n\npub(crate) struct FailedTensorCheck {\n    ops: String,\n    errors: Vec<TensorError>,\n}\n\nimpl FailedTensorCheck {\n    /// Format all the checks into a single message ready to be printed by a [panic](core::panic).\n    pub(crate) fn format(self) -> String {\n        self.errors.into_iter().enumerate().fold(\n            format!(\n                \"=== Tensor Operation Error ===\\n  Operation: '{}'\\n  Reason:\",\n                self.ops\n            ),\n            |accum, (number, error)| accum + error.format(number + 1).as_str(),\n        ) + \"\\n\"\n    }\n}\n\nstruct TensorError {\n    description: String,\n    details: Option<String>,\n}\n\nimpl TensorError {\n    pub(crate) fn new<S: Into<String>>(description: S) -> Self {\n        TensorError {\n            description: description.into(),\n            details: None,\n        }\n    }\n\n    pub(crate) fn details<S: Into<String>>(mut self, details: S) -> Self {\n        self.details = Some(details.into());\n        self\n    }\n\n    fn format(self, number: usize) -> String {\n        let mut message = format!(\"\\n    {number}. \");\n        message += self.description.as_str();\n        message += \" \";\n\n        if let Some(details) = self.details {\n            message += details.as_str();\n            message += \" \";\n        }\n\n        message\n    }\n}\n\n/// Module where we defined macros that can be used only in the project.\npub(crate) mod macros {\n    /// We use a macro for all checks, since the panic message file and line number will match the\n    /// function that does the check instead of a generic error.rs crate private unrelated file\n    /// and line number.\n    macro_rules! check {\n        ($check:expr) => {\n            if let TensorCheck::Failed(check) = $check {\n                core::panic!(\"{}\", check.format());\n            }\n        };\n    }\n    pub(crate) use check;\n}\n\npub(crate) fn unwrap_shape_reshape(result: Result<Shape, burn_std::MetadataError>) -> Shape {\n    match result {\n        Ok(shape) => shape,\n        // `shape.reshape(new_shape)` should only return `MetadataError::Invalid`.\n        Err(burn_std::MetadataError::Invalid { reason }) => {\n            macros::check!({\n                TensorCheck::Ok.register(\"Reshape\", crate::check::TensorError::new(reason))\n            });\n            unreachable!()\n        }\n        Err(e) => panic!(\"{e:?}\"),\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use macros::check;\n\n    #[test]\n    #[should_panic]\n    fn index_range_exceed_dimension() {\n        let slices = vec![Slice::from(0..2), Slice::from(0..4), Slice::from(1..8)];\n        check!(TensorCheck::slice::<3>(&Shape::new([3, 5, 7]), &slices));\n    }\n\n    #[test]\n    #[should_panic]\n    fn index_range_exceed_number_of_dimensions() {\n        let slices = vec![Slice::from(0..1), Slice::from(0..1), Slice::from(0..1)];\n        check!(TensorCheck::slice::<2>(&Shape::new([3, 5]), &slices));\n    }\n\n    #[test]\n    #[should_panic]\n    fn binary_ops_shapes_no_broadcast() {\n        check!(TensorCheck::binary_ops_ew_shape::<2>(\n            TensorCheck::Ok,\n            \"TestOps\",\n            &Shape::new([3, 5]),\n            &Shape::new([3, 6])\n        ));\n    }\n\n    #[test]\n    fn binary_ops_shapes_with_broadcast() {\n        check!(TensorCheck::binary_ops_ew_shape::<2>(\n            TensorCheck::Ok,\n            \"Test\",\n            &Shape::new([3, 5]),\n            &Shape::new([1, 5])\n        ));\n    }\n\n    #[test]\n    #[should_panic]\n    fn binary_ops_devices() {\n        check!(TensorCheck::binary_ops_device(\n            TensorCheck::Ok,\n            \"Test\",\n            &5, // We can pass anything that implements PartialEq as device\n            &8\n        ));\n    }\n\n    #[test]\n    #[should_panic]\n    fn movedim_args_out_of_bounds() {\n        check!(TensorCheck::movedim_args_usize::<3>(5));\n    }\n\n    #[test]\n    fn movedim_args_i32() {\n        check!(TensorCheck::movedim_args_i32::<3>(-3));\n    }\n\n    #[test]\n    #[should_panic]\n    fn movedim_args_too_negative() {\n        check!(TensorCheck::movedim_args_i32::<3>(-4));\n    }\n\n    #[test]\n    #[should_panic]\n    fn movedim_args_vec_out_of_bounds() {\n        check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 3]));\n    }\n\n    #[test]\n    #[should_panic]\n    fn movedim_args_vec_duplicates() {\n        check!(TensorCheck::movedim_args_vec::<3>(&vec![0, 1, 1]));\n    }\n\n    #[test]\n    #[should_panic]\n    fn movedim_args_length() {\n        check!(TensorCheck::movedim_args_length(\n            &vec![0, 1],\n            &vec![0, 1, 2]\n        ));\n    }\n\n    #[test]\n    #[should_panic]\n    fn unsqueeze_dim_same_rank() {\n        check!(TensorCheck::unsqueeze_dim::<3, 3>(2));\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/float.rs",
    "content": "use crate::AsIndex;\nuse crate::FloatDType;\nuse crate::Tensor;\nuse crate::cast::ToElement;\nuse crate::check;\nuse crate::check::TensorCheck;\nuse crate::ops::GridSampleOptions;\nuse crate::quantization::{QuantScheme, QuantizationParameters};\nuse crate::tensor::backend::Backend;\nuse crate::tensor::stats;\nuse crate::tensor::{Distribution, TensorData};\nuse crate::{Bool, Int, TensorPrimitive};\nuse burn_backend::ElementConversion;\nuse burn_backend::Scalar;\nuse burn_backend::tensor::quantization::QuantizationParametersPrimitive;\nuse core::f32;\n\n/// Default RTOL value for `is_close` and `all_close`.\npub const DEFAULT_RTOL: f64 = 1e-5;\n\n/// Default ATOL value for `is_close` and `all_close`.\npub const DEFAULT_ATOL: f64 = 1e-8;\n\nimpl<const D: usize, B> Tensor<B, D>\nwhere\n    B: Backend,\n{\n    /// Applies element wise exponential operation.\n    ///\n    #[cfg_attr(doc, doc = \"$y_i = e^{x_i}$\")]\n    #[cfg_attr(not(doc), doc = \"`y = e^x`\")]\n    pub fn exp(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_exp(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise natural log operation *ln*.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\log_e\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = log(x_i)`\")]\n    pub fn log(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_log(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies the natural logarithm of one plus the input tensor, element-wise.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\log_e\\(x_i + 1\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = log(x_i + 1)`\")]\n    pub fn log1p(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_log1p(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies the [error function](https://en.wikipedia.org/wiki/Error_function) element wise.\n    ///\n    #[cfg_attr(\n        doc,\n        doc = r#\"\n$y_i = \\text{erf}\\(x_i\\)$\n\nThe error function is defined as:\n\n$$\\text{erf}\\(x\\) = \\frac{2}{\\sqrt{\\pi}} \\int_0^x e^{-t^2} dt$$\n\"#\n    )]\n    #[cfg_attr(not(doc), doc = \"`y_i = erf(x_i)`\")]\n    pub fn erf(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_erf(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies [reciprocal operation](https://en.wikipedia.org/wiki/Multiplicative_inverse)\n    /// (or multiplicative inverse) element wise.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\frac{1}{x_i}$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = 1/x_i`\")]\n    pub fn recip(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_recip(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise square operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = x_i * x_i$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = x_i * x_i`\")]\n    pub fn square(self) -> Self {\n        self.powi_scalar(2)\n    }\n\n    /// Applies element wise root square operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\sqrt{x_i}$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = sqrt(x_i)`\")]\n    pub fn sqrt(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_sqrt(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise cosine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\cos\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = cos(x_i)`\")]\n    pub fn cos(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_cos(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise sine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\sin\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = sin(x_i)`\")]\n    pub fn sin(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_sin(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise tangent operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\tan\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = tan(x_i)`\")]\n    pub fn tan(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_tan(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise hyperbolic cosine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\cosh\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = cosh(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);\n    ///     println!(\"{}\", tensor.cosh()); // [1.0, 1.5430, 3.7621]\n    /// }\n    /// ```\n    pub fn cosh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_cosh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise hyperbolic sine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\sinh\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = sinh(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);\n    ///     println!(\"{}\", tensor.sinh()); // [0.0, -1.1752, 3.6269]\n    /// }\n    /// ```\n    pub fn sinh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_sinh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise hyperbolic tangent operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\tanh\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = tanh(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);\n    ///     println!(\"{}\", tensor.tanh()); // [0.0, -0.7616, 0.9640]\n    /// }\n    /// ```\n    pub fn tanh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_tanh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse sine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\asin\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = asin(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);\n    ///     println!(\"{}\", tensor.asin()); // [ 0.0000, -1.5708,  1.5708]\n    /// }\n    /// ```\n    pub fn asin(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_asin(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse hyperbolic sine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\asinh\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = asinh(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);\n    ///     println!(\"{}\", tensor.asinh()); // [ 0.0000, -0.8814,  0.8814]\n    /// }\n    /// ```\n    pub fn asinh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_asinh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse cosine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\acos\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = acos(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 1.0], &device);\n    ///     println!(\"{}\", tensor.acos()); // [1.5708, 3.1416, 0.0]\n    /// }\n    /// ```\n    pub fn acos(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_acos(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse hyperbolic cosine operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\acosh\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = acosh(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([1.0, 2.0, 3.0], &device);\n    ///     println!(\"{}\", tensor.sinh()); // [0.0000, 1.3170, 1.7627]\n    /// }\n    /// ```\n    pub fn acosh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_acosh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse tangent operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\atan\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = atan(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -1.0, 2.0], &device);\n    ///     println!(\"{}\", tensor.sinh()); // [ 0.0, -0.7854,  1.1071]\n    /// }\n    /// ```\n    pub fn atan(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_atan(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse hyperbolic tangent operation.\n    ///\n    #[cfg_attr(doc, doc = r#\"$y_i = \\atan\\(x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`y_i = atan(x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let tensor = Tensor::<B, 1>::from_data([0.0, -0.5, 0.5], &device);\n    ///     println!(\"{}\", tensor.sinh()); // [ 0.0, -0.5493,  0.5493]\n    /// }\n    /// ```\n    pub fn atanh(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_atanh(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise inverse tangent operation using the signs of arguments to determine the correct quadrant.\n    ///\n    #[cfg_attr(doc, doc = r#\"$z_i = \\atan2\\(y_i, x_i\\)$\"#)]\n    #[cfg_attr(not(doc), doc = \"`z_i = atan2(y_i, x_i)`\")]\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///\n    ///     let lhs = Tensor::<B, 1>::from_data([-2.0, 2.0, -2.0], &device);\n    ///     let rhs = Tensor::<B, 1>::from_data([1.0, -1.0, -1.0], &device);\n    ///     println!(\"{}\", lhs.atan2(rhs)); // [-1.1071,  2.0344, -2.0344]\n    /// }\n    /// ```\n    pub fn atan2(self, other: Self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_atan2(\n            self.primitive.tensor(),\n            other.primitive.tensor(),\n        )))\n    }\n\n    /// Converts each of the elements of the input tensor from angles in degrees to radians.\n    ///\n    /// # Example\n    /// ```ignore\n    /// let tensor_in_radians = tensor.deg2rad();\n    /// ```\n    pub fn deg2rad(self) -> Self {\n        self.mul_scalar(f32::consts::PI / 180.0)\n    }\n\n    /// Converts each of the elements of the input tensor from angles in radians to degrees.\n    ///\n    /// # Example\n    /// ```ignore\n    /// let tensor_in_degrees = tensor.rad2deg();\n    /// ```\n    pub fn rad2deg(self) -> Self {\n        self.mul_scalar(180.0 / f32::consts::PI)\n    }\n\n    /// Applies element wise round operation.\n    ///\n    /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)\n    /// strategy, with halfway cases rounded to the nearest even integer value.\n    pub fn round(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_round(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise floor operation.\n    pub fn floor(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_floor(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Applies element wise ceil operation.\n    pub fn ceil(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_ceil(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Create a tensor from floats (f32) on a given device.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let _ = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);\n    ///     let _ = Tensor::<B, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);\n    /// }\n    /// ```\n    pub fn from_floats<A: Into<TensorData>>(floats: A, device: &B::Device) -> Self {\n        Self::from_data(floats.into().convert::<f32>(), device)\n    }\n\n    /// Returns a new tensor with the same shape and device as the current tensor and the data\n    /// cast to Integer.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let float_tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device);\n    ///     let int_tensor = float_tensor.int();\n    /// }\n    /// ```\n    pub fn int(self) -> Tensor<B, D, Int> {\n        Tensor::new(B::float_into_int(self.primitive.tensor()))\n    }\n\n    /// Returns a new tensor with the same shape, dtype, and device as the current tensor filled random\n    /// values sampled from the given distribution.\n    pub fn random_like(&self, distribution: Distribution) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_random(\n            self.shape(),\n            distribution,\n            &self.device(),\n        )))\n        .cast(self.dtype())\n    }\n\n    /// Calculate the variance along the given dimension.\n    pub fn var(self, dim: usize) -> Self {\n        stats::var(self, dim)\n    }\n\n    /// Calculate the variance along the given dimension without applying the Bessel’s correction.\n    pub fn var_bias(self, dim: usize) -> Self {\n        stats::var_bias(self, dim)\n    }\n\n    /// Calculate the variance along the given dimension and also returns the mean.\n    pub fn var_mean(self, dim: usize) -> (Self, Self) {\n        let mean = self.clone().mean_dim(dim);\n        let var = stats::var_with_mean(self, mean.clone(), dim);\n        (var, mean)\n    }\n\n    /// Calculate the variance along the given dimension without applying the Bessel’s correction and also returns the mean.\n    pub fn var_mean_bias(self, dim: usize) -> (Self, Self) {\n        let mean = self.clone().mean_dim(dim);\n        let var = stats::var_with_mean_bias(self, mean.clone(), dim);\n        (var, mean)\n    }\n\n    /// Returns the median value along the specified dimension.\n    ///\n    /// The median is not unique for input tensors with an even number of elements\n    /// in the reduced dimension. In this case, the lower of the two medians is returned,\n    /// following PyTorch's behavior.\n    ///\n    /// # Note\n    ///\n    /// The current implementation performs a full sort along the specified dimension,\n    /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back\n    /// to CPU for the sort operation, which may result in slower performance compared\n    /// to native GPU operations.\n    ///\n    /// # Arguments\n    ///\n    /// - `dim` - The dimension along which to compute the median.\n    ///\n    /// # Returns\n    ///\n    /// - A tensor containing the median values along the specified dimension.\n    ///\n    /// # Example 1\n    ///\n    /// ```ignore\n    /// // Assuming backend B\n    /// let device = B::Device::default();\n    /// let tensor = Tensor::<B, 2>::from_data(\n    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],\n    ///     &device,\n    /// );\n    ///\n    /// // Median along dimension 0:\n    /// // sorted columns are [1.0, 8.0], [4.0, 5.0], [3.0, 6.0], [2.0, 7.0]\n    /// let median = tensor.median(0);\n    /// // Result: [[1.0, 4.0, 3.0, 2.0]]\n    ///\n    /// // Median along dimension 1:\n    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]\n    /// let median = tensor.median(1);\n    /// // Result: [[2.0], [6.0]]\n    /// ```\n    ///\n    /// # Example 2\n    ///\n    /// The median across all elements can be calculated as follows:\n    ///\n    /// ```ignore\n    /// // D is the number of dimensions of the tensor\n    /// let flattened_tensor: Tensor<B, 1> = tensor.flatten(0, D - 1);\n    ///\n    /// // Calculate median for dim 0 since the tensor has become 1 dimensional\n    /// let median = flattened_tensor.median(0);\n    /// // Result: [4.0]\n    /// ```\n    pub fn median(self, dim: usize) -> Self {\n        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl\n        // instead of leveraging a full sort to get the median.\n        stats::median(self, dim)\n    }\n\n    /// Returns the median value along the specified dimension and its index.\n    ///\n    /// The median is not unique for input tensors with an even number of elements\n    /// in the reduced dimension. In this case, the lower of the two medians is returned,\n    /// following PyTorch's behavior.\n    ///\n    /// # Note\n    ///\n    /// The current implementation performs a full sort along the specified dimension,\n    /// which has O(nlog(n)) complexity. Additionally, most backends currently fall back\n    /// to CPU for the sort operation, which may result in slower performance compared\n    /// to native GPU operations.\n    ///\n    /// # Arguments\n    ///\n    /// - `dim` - The dimension along which to compute the median.\n    ///\n    /// # Returns\n    ///\n    /// A tuple containing:\n    /// - A tensor with the median values.\n    /// - A tensor with the indices of the median values in the original tensor.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// // Assuming backend B\n    /// let device = B::Device::default();\n    /// let tensor = Tensor::<B, 2>::from_data(\n    ///     [[1.0, 5.0, 3.0, 2.0], [8.0, 4.0, 6.0, 7.0]],\n    ///     &device,\n    /// );\n    ///\n    /// // Median along dimension 1:\n    /// // sorted rows are [1.0, 2.0, 3.0, 5.0] and [4.0, 6.0, 7.0, 8.0]\n    /// let (values, indices) = tensor.median_with_indices(1);\n    /// // values: [[2.0], [6.0]], indices: [[3], [2]] (position in the original tensor)\n    /// ```\n    pub fn median_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {\n        // TODO: Allow backend specialization. Optimally, implement a median kernel for cubecl\n        // instead of leveraging a full sort to get the median.\n        stats::median_with_indices(self, dim)\n    }\n\n    /// Converts a tensor to the specified floating point data type.\n    ///\n    /// This is always a no-op when casting to the current dtype.\n    ///\n    /// # Warning\n    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors\n    /// have the same floating point precision data type for operations multiple input tensors (e.g., binary ops).\n    pub fn cast<F: Into<FloatDType>>(self, dtype: F) -> Tensor<B, D> {\n        let dtype = dtype.into();\n        let self_type: FloatDType = self.dtype().into();\n        if dtype == self_type {\n            // no-op.\n            return self;\n        }\n\n        Tensor::new(TensorPrimitive::Float(B::float_cast(\n            self.primitive.tensor(),\n            dtype,\n        )))\n    }\n\n    /// Detach the current tensor from the autodiff graph.\n    ///\n    /// This function does nothing when autodiff is not enabled.\n    /// This can be used in batchers or elsewhere to ensure that previous operations are not\n    /// considered in the autodiff graph.\n    pub fn detach(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_detach(\n            self.primitive.tensor(),\n        )))\n    }\n\n    /// Mark the tensor to keep gradients during the backward pass.\n    ///\n    /// This function does nothing when autodiff is not enabled.\n    pub fn require_grad(self) -> Self {\n        self.set_require_grad(true)\n    }\n\n    /// Returns true if the tensor requires gradients during the backward pass.\n    pub fn is_require_grad(&self) -> bool {\n        match &self.primitive {\n            TensorPrimitive::Float(tensor) => B::float_is_require_grad(tensor),\n            TensorPrimitive::QFloat(tensor) => B::q_is_require_grad(tensor),\n        }\n    }\n\n    /// Mark the tensor as tracked or untracked depending on the require_grad argument.\n    /// When tracked, the gradients will be available after the backward pass.\n    ///\n    /// This function does nothing when autodiff is not enabled.\n    pub fn set_require_grad(self, require_grad: bool) -> Self {\n        let primitive = match self.primitive {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_set_require_grad(tensor, require_grad))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_set_require_grad(tensor, require_grad))\n            }\n        };\n        Self::new(primitive)\n    }\n\n    /// Applies the relu function to the tensor.\n    pub(crate) fn relu(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::relu(self.primitive.tensor())))\n    }\n\n    /// Calculate covaraince matrix between different entries alongside a given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `size` - The size of the square matrix.\n    /// * `correction_factor` - Is usually 1 for samples and 0 for population.\n    pub fn cov(self, dim: usize, correction_factor: usize) -> Tensor<B, D> {\n        let n = self.dims()[dim];\n        let centered = (self.clone() - self.mean_dim(dim)).swap_dims(dim, 0);\n        centered\n            .clone()\n            .transpose()\n            .matmul(centered)\n            .div_scalar(n as f32 - correction_factor as f32)\n    }\n\n    /// Convert the tensor to a lower precision data type based on the quantization scheme.\n    ///\n    /// # Arguments\n    ///\n    /// * `scheme` - The quantization scheme.\n    /// * `qparams` - The pre-computed quantization parameters.\n    ///\n    /// # Returns\n    ///\n    /// The quantized tensor.\n    pub fn quantize(\n        self,\n        scheme: &QuantScheme,\n        qparams: QuantizationParameters<B>,\n    ) -> Tensor<B, D> {\n        Tensor::new(TensorPrimitive::QFloat(B::quantize(\n            self.primitive.tensor(),\n            scheme,\n            QuantizationParametersPrimitive {\n                scales: qparams.scales.primitive.tensor(),\n            },\n        )))\n    }\n\n    /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme.\n    ///\n    /// # Arguments\n    ///\n    /// * `scheme` - The quantization scheme.\n    ///\n    /// # Returns\n    ///\n    /// The quantized tensor.\n    ///\n    /// # Notes\n    /// This uses [min-max calibration](crate::quantization::Calibration::MinMax).\n    pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor<B, D> {\n        Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic(\n            self.primitive.tensor(),\n            scheme,\n        )))\n    }\n\n    /// Convert the tensor back to a higher precision data type.\n    ///\n    /// If the tensor is not quantized, its value is simply returned.\n    ///\n    /// # Returns\n    ///\n    /// The dequantized tensor.\n    pub fn dequantize(self) -> Tensor<B, D> {\n        Tensor::new(TensorPrimitive::Float(self.primitive.tensor()))\n    }\n\n    /// Checks element wise if the tensor is close to another tensor.\n    ///\n    /// The tolerance is defined by the following equation:\n    ///\n    /// ```text\n    /// abs(a - b) <= (atol + rtol * abs(b))\n    ///\n    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,\n    /// and `atol` is the absolute tolerance.\n    /// ```\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to compare with.\n    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.\n    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensors.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor1.is_close(tensor2, None, None);\n    ///    println!(\"{tensor}\");\n    ///    // [[true, true, true], [true, true, true]]\n    /// }\n    /// ```\n    pub fn is_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> Tensor<B, D, Bool> {\n        let rtol = rtol.unwrap_or(DEFAULT_RTOL);\n        let atol = atol.unwrap_or(DEFAULT_ATOL);\n\n        // check finite difference is close\n        let is_close_finite_val = self\n            .clone()\n            .sub(other.clone())\n            .abs()\n            .lower_equal(other.clone().abs().mul_scalar(rtol).add_scalar(atol))\n            .bool_and(self.clone().is_finite())\n            .bool_and(other.clone().is_finite());\n\n        // check if both are infinite and have same sign\n        let inf_same_sign = self\n            .clone()\n            .is_finite()\n            .bool_not()\n            .bool_and(other.clone().is_finite().bool_not())\n            .bool_and(self.equal(other));\n\n        is_close_finite_val.bool_or(inf_same_sign)\n    }\n\n    /// Checks if all elements are close to another tensor.\n    ///\n    /// The tolerance is defined by the following equation:\n    ///\n    /// ```text\n    ///\n    /// abs(a - b) <= (atol + rtol * abs(b))\n    ///\n    /// where `a` is the first tensor, `b` is the second tensor, `rtol` is the relative tolerance,\n    /// and `atol` is the absolute tolerance.\n    ///\n    /// ```\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to compare with.\n    /// * `rtol` - Optional relative tolerance. Default is 1e-5; see `DEFAULT_RTOL`.\n    /// * `atol` - Optional absolute tolerance. Default is 1e-8; see `DEFAULT_ATOL`.\n    ///\n    /// # Returns\n    ///\n    /// A boolean scalar.\n    ///\n    /// # Remarks\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let result = tensor1.all_close(tensor2, None, None);\n    ///    println!(\"{}\", result);\n    ///    // true\n    /// }\n    /// ```\n    pub fn all_close(self, other: Self, rtol: Option<f64>, atol: Option<f64>) -> bool {\n        self.is_close(other, rtol, atol)\n            .all()\n            .into_scalar()\n            .to_bool()\n    }\n\n    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::NAN, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.is_nan();\n    ///    println!(\"{tensor}\");\n    ///    // [[false, true, false], [false, false, false]]\n    /// }\n    /// ```\n    pub fn is_nan(self) -> Tensor<B, D, Bool> {\n        Tensor::new(B::float_is_nan(self.primitive.tensor()))\n    }\n\n    /// Checks if the tensor contains any NaN values.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [f64::NAN, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.contains_nan();\n    ///   println!(\"{tensor}\");\n    ///   // [true]\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.contains_nan();\n    ///   println!(\"{tensor}\");\n    ///   // [false]\n    /// }\n    /// ```\n    pub fn contains_nan(self) -> Tensor<B, 1, Bool> {\n        // Summing the tensor will result in NaN if the tensor contains any NaN values\n        // This is faster than checking each element individually\n        // because it rolls up the NaN values into a single value\n        let sum = self.sum();\n\n        sum.is_nan()\n    }\n\n    /// Returns a new tensor with boolean elements indicating whether each element of the input is infinite (either +INF or -INF).\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where `true` indicates that the value is infinite\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.is_finite();\n    ///    println!(\"{tensor}\");\n    ///    // [[false, true, false], [false, false, false]]\n    /// }\n    /// ```\n    pub fn is_inf(self) -> Tensor<B, D, Bool> {\n        Tensor::new(B::float_is_inf(self.primitive.tensor()))\n    }\n\n    /// Returns a new tensor with boolean elements indicating whether each element of the input is finite\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor where `true` indicates that the value is finite and `false` indicates\n    /// either INF, -INF or NAN\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Bool, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, f64::INFINITY, 3.0], [f64::NAN, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.is_finite();\n    ///    println!(\"{tensor}\");\n    ///    // [[true, false, true], [false, true, true]]\n    /// }\n    /// ```\n    pub fn is_finite(self) -> Tensor<B, D, Bool> {\n        self.clone()\n            .is_nan()\n            .bool_not()\n            .bool_and(self.is_inf().bool_not())\n    }\n\n    /// Samples tensor as a two-dimensional spatial grid of (possibly multi-channel) values,\n    /// using the given locations in [-1, 1].\n    ///\n    /// # Arguments\n    ///\n    /// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].\n    ///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right\n    /// * `options` - Grid sampling options (mode, padding_mode, align_corners)\n    ///\n    /// # Returns\n    ///\n    /// A tensor with shape (N, C, H_out, W_out)\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// use burn_tensor::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};\n    ///\n    /// // Default options (bilinear, zeros padding, align_corners=false)\n    /// let output = tensor.grid_sample_2d(grid, GridSampleOptions::default());\n    ///\n    /// // Custom options\n    /// let options = GridSampleOptions::new(InterpolateMode::Bilinear)\n    ///     .with_padding_mode(GridSamplePaddingMode::Border)\n    ///     .with_align_corners(true);\n    /// let output = tensor.grid_sample_2d(grid, options);\n    /// ```\n    pub fn grid_sample_2d(\n        self,\n        grid: Tensor<B, D>,\n        options: impl Into<GridSampleOptions>,\n    ) -> Tensor<B, D> {\n        Tensor::new(TensorPrimitive::Float(B::float_grid_sample_2d(\n            self.primitive.tensor(),\n            grid.primitive.tensor(),\n            options.into(),\n        )))\n    }\n\n    /// Computes the cross product of `self` and another tensor along a given dimension.\n    ///\n    /// Both `self` and `other` **must have size 3** along the specified `dim`,\n    /// because the cross product is only defined in three-dimensional space.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The other tensor to take the cross product with.\n    /// * `dim`   - The dimension along which to compute the cross product.\n    ///\n    /// # Returns\n    ///\n    /// A tensor containing the cross product of `self` and `other` along `dim`.\n    pub fn cross<Dim: AsIndex>(self, other: Tensor<B, D>, dim: Dim) -> Tensor<B, D> {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::cross(&self, &other, dim));\n        Tensor::new(TensorPrimitive::Float(B::float_cross(\n            self.primitive.tensor(),\n            other.primitive.tensor(),\n            dim,\n        )))\n    }\n\n    /// Applies element wise power operation with a float Tensor\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to apply the power operation with.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.powf(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, 8.0, 81.0], [5.0, 81.0, 216.0]]\n    /// }\n    /// ```\n    pub fn powf(self, other: Self) -> Self {\n        let primitive = match (self.primitive, other.primitive) {\n            (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {\n                TensorPrimitive::Float(B::float_powf(lhs, rhs))\n            }\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs),\n            (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => {\n                TensorPrimitive::Float(B::float_powf(B::dequantize(lhs), rhs))\n            }\n            (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {\n                TensorPrimitive::Float(B::float_powf(lhs, B::dequantize(rhs)))\n            }\n        };\n\n        Tensor::new(primitive)\n    }\n\n    /// Applies element wise power operation with a float scalar\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to apply the power operation with.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.powf_scalar(2.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, 4.0, 9.0], [25.0, 81.0, 36.0]]\n    /// }\n    /// ```\n    pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let rhs = Scalar::new(other, &self.dtype());\n\n        let primitive = match self.primitive {\n            TensorPrimitive::Float(lhs) => TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs)),\n            TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs),\n        };\n\n        Tensor::new(primitive)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/fmod.rs",
    "content": "use crate::{Float, Tensor, backend::Backend};\n\nimpl<B, const D: usize> Tensor<B, D, Float>\nwhere\n    B: Backend,\n{\n    /// Computes the floating-point remainder of dividing `self` by `other`.\n    ///\n    /// The result has the same sign as `self` and magnitude less than `other`.\n    /// This is equivalent to the IEEE 754 remainder operation.\n    ///\n    /// # Special Cases (IEEE 754 compliant)\n    ///\n    /// - If `self` is ±∞ and `other` is not NaN, NaN is returned\n    /// - If `other` is ±0 and `self` is not NaN, NaN is returned\n    /// - If `other` is ±∞ and `self` is finite, `self` is returned\n    /// - If either argument is NaN, NaN is returned\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The divisor tensor. Must have the same shape as `self`.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the floating-point remainder.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let dividend = Tensor::<B, 1>::from_data([5.3, -5.3, 5.3, -5.3], &device);\n    ///     let divisor = Tensor::<B, 1>::from_data([2.0, 2.0, -2.0, -2.0], &device);\n    ///     let result = dividend.fmod(divisor);\n    ///\n    ///     // Result: [1.3, -1.3, 1.3, -1.3]\n    /// }\n    /// ```\n    pub fn fmod(self, other: Self) -> Self {\n        // Normal case: fmod(x, y) = x - y * trunc(x / y)\n        let quotient = self.clone().div(other.clone());\n        let truncated = quotient.trunc();\n        let product = other.clone() * truncated.clone();\n\n        // When divisor is infinity and dividend is finite:\n        // - quotient is 0, truncated is 0\n        // - but 0 * infinity = NaN, which is wrong\n        // We need to handle this case by replacing NaN with 0 when appropriate\n\n        // Check if the product is NaN due to 0 * inf\n        let is_zero_times_inf = truncated.equal_elem(0.0).bool_and(other.is_inf());\n        let zero_tensor = self.clone().mul_scalar(0.0);\n        let corrected_product = product.mask_where(is_zero_times_inf, zero_tensor);\n\n        self - corrected_product\n    }\n\n    /// Computes the floating-point remainder of dividing `self` by a scalar.\n    ///\n    /// The result has the same sign as `self` and magnitude less than the scalar.\n    ///\n    /// # Special Cases (IEEE 754 compliant)\n    ///\n    /// - If `self` is ±∞ and scalar is not NaN, NaN is returned\n    /// - If scalar is ±0 and `self` is not NaN, NaN is returned\n    /// - If scalar is ±∞ and `self` is finite, `self` is returned\n    /// - If either argument is NaN, NaN is returned\n    ///\n    /// # Arguments\n    ///\n    /// * `scalar` - The scalar divisor.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element is the floating-point remainder.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 1>::from_data([5.3, -5.3, 7.5, -7.5], &device);\n    ///     let result = tensor.fmod_scalar(2.0);\n    ///\n    ///     // Result: [1.3, -1.3, 1.5, -1.5]\n    /// }\n    /// ```\n    pub fn fmod_scalar(self, scalar: f32) -> Self {\n        // Normal case: fmod(x, y) = x - y * trunc(x / y)\n        let quotient = self.clone().div_scalar(scalar);\n        let truncated = quotient.trunc();\n        let product = truncated.mul_scalar(scalar);\n\n        // Handle the special case where scalar is infinity\n        // When scalar is ±∞ and self is finite, quotient is 0, truncated is 0\n        // but 0 * infinity = NaN, which is wrong - it should be 0\n        if scalar.is_infinite() {\n            // For finite values, fmod(x, ±∞) = x\n            // For infinite values, fmod(±∞, ±∞) = NaN (which is handled by arithmetic)\n            return self;\n        }\n\n        self - product\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/int.rs",
    "content": "use burn_backend::Scalar;\n\nuse crate::{\n    Float, Int, IntDType, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend,\n    cartesian_grid,\n};\n\nuse core::ops::Range;\n\nimpl<B> Tensor<B, 1, Int>\nwhere\n    B: Backend,\n{\n    /// Returns a new integer tensor on the specified device.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range of values to generate.\n    /// * `device` - The device to create the tensor on.\n    pub fn arange(range: Range<i64>, device: &B::Device) -> Self {\n        Tensor::new(B::int_arange(range, device))\n    }\n\n    /// Returns a new integer tensor on the specified device.\n    ///\n    /// # Arguments\n    ///\n    /// * `range` - The range of values to generate.\n    /// * `step` - The step between each value.\n    pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {\n        Tensor::new(B::int_arange_step(range, step, device))\n    }\n}\n\nimpl<const D: usize, B> Tensor<B, D, Int>\nwhere\n    B: Backend,\n{\n    /// Create a tensor from integers (i32), placing it on a given device.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Int};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2], &device);\n    ///     let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]], &device);\n    /// }\n    /// ```\n    pub fn from_ints<A: Into<TensorData>>(ints: A, device: &B::Device) -> Self {\n        Self::from_data(ints.into().convert::<i32>(), device)\n    }\n\n    /// Returns a new tensor with the same shape and device as the current tensor and the data\n    /// cast to Float.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let int_tensor = Tensor::<B, 1, Int>::arange(0..5, &device);\n    ///     let float_tensor = int_tensor.float();\n    /// }\n    /// ```\n    pub fn float(self) -> Tensor<B, D, Float> {\n        Tensor::new(TensorPrimitive::Float(B::int_into_float(self.primitive)))\n    }\n\n    /// Generates a cartesian grid for the given tensor shape on the specified device.\n    /// The generated tensor is of dimension `D2 = D + 1`, where each element at dimension D contains the cartesian grid coordinates for that element.\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape specifying the dimensions of the tensor.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Panics\n    ///\n    /// Panics if `D2` is not equal to `D+1`.\n    ///\n    /// # Examples\n    ///\n    /// ```rust\n    ///    use burn_tensor::Int;\n    ///    use burn_tensor::{backend::Backend, Shape, Tensor};\n    ///    fn example<B: Backend>() {\n    ///        let device = Default::default();\n    ///        let result: Tensor<B, 3, _> = Tensor::<B, 2, Int>::cartesian_grid([2, 3], &device);\n    ///        println!(\"{}\", result);\n    ///    }\n    /// ```\n    pub fn cartesian_grid<S: Into<Shape>, const D2: usize>(\n        shape: S,\n        device: &B::Device,\n    ) -> Tensor<B, D2, Int> {\n        cartesian_grid::<B, S, D, D2>(shape, device)\n    }\n\n    /// Applies the bitwise logical and operation with each bit representing the integer.\n    pub fn bitwise_and(self, other: Self) -> Self {\n        Self::new(B::bitwise_and(self.primitive, other.primitive))\n    }\n\n    /// Applies the bitwise logical or operation with another tensor.\n    pub fn bitwise_or(self, other: Self) -> Self {\n        Self::new(B::bitwise_or(self.primitive, other.primitive))\n    }\n\n    /// Applies the bitwise logical xor operation with another tensor.\n    pub fn bitwise_xor(self, other: Self) -> Self {\n        Self::new(B::bitwise_xor(self.primitive, other.primitive))\n    }\n\n    /// Applies the bitwise logical not operation.\n    pub fn bitwise_not(self) -> Self {\n        Self::new(B::bitwise_not(self.primitive))\n    }\n\n    /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor.\n    pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(B::bitwise_and_scalar(self.primitive, other))\n    }\n\n    /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor.\n    pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(B::bitwise_or_scalar(self.primitive, other))\n    }\n\n    /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor.\n    pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(B::bitwise_xor_scalar(self.primitive, other))\n    }\n\n    /// Applies the bitwise left shift operation with the integers in the tensor.\n    pub fn bitwise_left_shift(self, other: Self) -> Self {\n        Self::new(B::bitwise_left_shift(self.primitive, other.primitive))\n    }\n\n    /// Applies the bitwise right shift operation with the integers in the tensor.\n    pub fn bitwise_right_shift(self, other: Self) -> Self {\n        Self::new(B::bitwise_right_shift(self.primitive, other.primitive))\n    }\n\n    /// Applies the bitwise left shift operation with the scalar.\n    pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(B::bitwise_left_shift_scalar(self.primitive, other))\n    }\n\n    /// Applies the bitwise right shift operation with the scalar.\n    pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(B::bitwise_right_shift_scalar(self.primitive, other))\n    }\n\n    /// Converts a tensor to the specified integer data type.\n    ///\n    /// This is always a no-op when casting to the current dtype.\n    ///\n    /// # Warning\n    /// Most backends don't have automatic type promotion at this time, so make sure that all tensors\n    /// have the same integer data type for operations multiple input tensors (e.g., binary ops).\n    pub fn cast<F: Into<IntDType>>(self, dtype: F) -> Tensor<B, D, Int> {\n        let dtype = dtype.into();\n        let self_dtype: IntDType = self.dtype().into();\n        if dtype == self_dtype {\n            // no-op.\n            return self;\n        }\n        Tensor::new(B::int_cast(self.primitive, dtype))\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/mod.rs",
    "content": "pub(crate) mod check;\n\nmod autodiff;\nmod base;\nmod bool;\nmod cartesian_grid;\nmod float;\nmod fmod;\nmod int;\nmod numeric;\nmod options;\nmod orderable;\nmod pad;\npub use pad::IntoPadding;\nmod take;\nmod transaction;\nmod trunc;\n\npub use autodiff::*;\npub use base::*;\npub use cartesian_grid::cartesian_grid;\npub use float::{DEFAULT_ATOL, DEFAULT_RTOL};\npub use numeric::*;\npub use options::*;\npub use transaction::*;\n\npub use burn_backend::tensor::IndexingUpdateOp;\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/numeric.rs",
    "content": "use burn_backend::Scalar;\npub use burn_backend::tensor::Numeric;\n\nuse crate::alloc::borrow::ToOwned;\nuse alloc::vec::Vec;\n\nuse crate::IndexingUpdateOp;\nuse crate::{\n    AsIndex, Bool, Distribution, Element, ElementConversion, Int, Shape, Tensor, backend::Backend,\n    check, check::TensorCheck,\n};\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    /// Applies element wise addition operation.\n    ///\n    /// `y = x2 + x1`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to add.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1 + tensor2;\n    ///    println!(\"{tensor}\");\n    ///    // [[3.0, 1.0, 7.0], [6.0, 11.0, 9.0]]\n    /// }\n    /// ```\n    #[allow(clippy::should_implement_trait)]\n    pub fn add(self, other: Self) -> Self {\n        check!(TensorCheck::binary_ops_ew(\"Add\", &self, &other));\n        Self::new(K::add(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise addition operation with a scalar.\n    ///\n    /// `y = x + s`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to add, element wise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let scalar = 2.0;\n    ///   let tensor = tensor + scalar;\n    ///   println!(\"{tensor}\");\n    ///   // [[3.0, 0.0, 5.0], [7.0, 11.0, 8.0]]\n    /// }\n    /// ```\n    pub fn add_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::add_scalar(self.primitive, other))\n    }\n\n    /// Applies element wise subtraction operation.\n    ///\n    /// `y = x2 - x1`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to subtract.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///   let tensor = tensor1 - tensor2;\n    ///   println!(\"{tensor}\");\n    ///   // [[-1.0, -5.0, -1.0], [4.0, 7.0, 3.0]]\n    /// }\n    /// ```\n    #[allow(clippy::should_implement_trait)]\n    pub fn sub(self, other: Self) -> Self {\n        check!(TensorCheck::binary_ops_ew(\"Sub\", &self, &other));\n        Self::new(K::sub(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise subtraction operation with a scalar.\n    ///\n    /// `y = x - s`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to subtract, element wise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let scalar = 2.0;\n    ///    let tensor = tensor - scalar;\n    ///    println!(\"{tensor}\");\n    ///    // [[-1.0, -4.0, 1.0], [3.0, 7.0, 4.0]]\n    /// }\n    /// ```\n    pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::sub_scalar(self.primitive, other))\n    }\n\n    /// Applies element wise division operation.\n    ///\n    /// `y = x2 / x1`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to divide.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1 / tensor2;\n    ///    println!(\"{tensor}\");\n    ///    // [[0.5, -0.6666667, 0.75], [5.0, 4.5, 2.0]]\n    /// }\n    /// ```\n    #[allow(clippy::should_implement_trait)]\n    pub fn div(self, other: Self) -> Self {\n        check!(TensorCheck::binary_ops_ew(\"Div\", &self, &other));\n        Self::new(K::div(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise division operation with a scalar.\n    ///\n    /// `y = x / s`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to divide, element wise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let scalar = 2.0;\n    ///    let tensor = tensor / scalar;\n    ///    println!(\"{tensor}\");\n    ///    // [[0.5, -1.0, 1.5], [2.5, 4.5, 3.0]]\n    /// }\n    /// ```\n    pub fn div_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::div_scalar(self.primitive, other))\n    }\n\n    /// Applies element wise the remainder operation with a scalar.\n    ///\n    /// `y = x2 % x1`\n    pub fn remainder(self, other: Self) -> Self {\n        Self::new(K::remainder(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise the remainder operation with a scalar.\n    ///\n    /// `y = x % s`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to divide, element wise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let scalar = 2.0;\n    ///    let tensor = tensor1 % scalar;\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]\n    /// }\n    /// ```\n    pub fn remainder_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::remainder_scalar(self.primitive, other))\n    }\n\n    /// Applies element wise multiplication operation.\n    ///\n    /// `y = x2 * x1`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to multiply.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1 * tensor2;\n    ///    println!(\"{tensor}\");\n    ///    // [[2.0, -6.0, 12.0], [5.0, 18.0, 18.0]]\n    /// }\n    /// ```\n    #[allow(clippy::should_implement_trait)]\n    pub fn mul(self, other: Self) -> Self {\n        check!(TensorCheck::binary_ops_ew(\"Mul\", &self, &other));\n        Self::new(K::mul(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise multiplication operation with a scalar.\n    ///\n    /// `y = x * s`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to multiply, element wise.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let scalar = 2.0;\n    ///    let tensor = tensor * scalar;\n    ///    println!(\"{tensor}\");\n    ///    // [[2.0, -4.0, 6.0], [10.0, 18.0, 12.0]]\n    /// }\n    /// ```\n    pub fn mul_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::mul_scalar(self.primitive, other))\n    }\n\n    /// Switch sign of each element in the tensor.\n    ///\n    /// `y = -x`\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = -tensor;\n    ///    println!(\"{tensor}\");\n    ///    // [[-1.0, 2.0, -3.0], [-5.0, -9.0, -6.0]]\n    /// }\n    /// ```\n    #[allow(clippy::should_implement_trait)]\n    pub fn neg(self) -> Self {\n        Self::new(K::neg(self.primitive))\n    }\n\n    /// Returns the signs of the elements of the input tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.sign();\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, -1.0, 1.0], [1.0, 1.0, 1.0]]\n    /// }\n    /// ```\n    pub fn sign(self) -> Self {\n        Self::new(K::sign(self.primitive))\n    }\n\n    /// Aggregate all elements in the tensor with the mean operation.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.mean();\n    ///    println!(\"{tensor}\");\n    ///    // [3.6666667]\n    /// }\n    /// ```\n    pub fn mean(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::mean(self.primitive))\n    }\n\n    /// Aggregate all elements in the tensor with the sum operation.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.sum();\n    ///   println!(\"{tensor}\");\n    ///   // [22.0]\n    /// }\n    /// ```\n    pub fn sum(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::sum(self.primitive))\n    }\n\n    /// Aggregate all elements along the given *dimension* or *axis*\n    /// in the tensor with the mean operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.clone().mean_dim(0);\n    ///   println!(\"{tensor}\");\n    ///   // [[3.0, 3.5, 4.5]]\n    ///   let tensor = tensor.clone().mean_dim(1);\n    ///   println!(\"{tensor}\");\n    ///   // [[0.6666667], [6.6666665]]\n    /// }\n    /// ```\n    pub fn mean_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Mean\", dim));\n        Self::new(K::mean_dim(self.primitive, dim))\n    }\n\n    /// Aggregate all elements along the given *axes*\n    /// in the tensor with the mean operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - the dimensions to aggregate; supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[2.0, 4.0], [6.0, -4.0]], &device);\n    ///    let tensor = tensor.clone().mean_dims(&[0, 1]);\n    ///    println!(\"{tensor}\");\n    ///    // [[2.0]]\n    /// }\n    /// ```\n    pub fn mean_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter().fold(self, |tensor, &dim| tensor.mean_dim(dim))\n    }\n\n    /// Aggregate all elements along the given *dimension* or *axis*\n    /// in the tensor with the sum operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.clone().sum_dim(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[6.0, 7.0, 9.0]]\n    ///    let tensor = tensor.clone().sum_dim(1);\n    ///    println!(\"{tensor}\");\n    ///    // [[2.0], [20.0]]\n    /// }\n    /// ```\n    pub fn sum_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Sum\", dim));\n        Self::new(K::sum_dim(self.primitive, dim))\n    }\n\n    /// Aggregate all elements along the given *axes*\n    /// in the tensor with the sum operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - the dimensions to aggregate; supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);\n    ///    println!(\"{tensor}\");\n    ///    // [[27]]\n    /// }\n    /// ```\n    pub fn sum_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter().fold(self, |tensor, &dim| tensor.sum_dim(dim))\n    }\n\n    /// Aggregate and squeeze along the given dimensions.\n    ///\n    /// This is equivalent to ``tensor.sum_dims(dims).squeeze_dims(dims)``\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - the dimensions to aggregate; supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 3>::from_data([\n    ///         [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]],\n    ///         [[9.0, 2.0, 5.0], [5.0, 7.0, 7.0]],\n    ///     ], &device);\n    ///     let tensor = tensor.clone().sum_dims_squeeze::<1, _>(&[0, 1]);\n    ///     println!(\"{tensor}\");\n    ///     // [20.0, 16.0, 21.0]\n    /// }\n    /// ```\n    pub fn sum_dims_squeeze<const D2: usize, I: AsIndex>(self, dims: &[I]) -> Tensor<B, D2, K> {\n        // TODO: remove idims when squeeze_dims uses AsIndex.\n        let idims = dims\n            .iter()\n            .map(|&dim| (dim.expect_dim_index(D)) as isize)\n            .collect::<Vec<_>>();\n        self.sum_dims(dims).squeeze_dims::<D2>(&idims)\n    }\n\n    /// Aggregate all elements in the tensor with the product operation.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.prod();\n    ///    println!(\"{tensor}\");\n    ///    // [-1620.0]\n    /// }\n    /// ```\n    pub fn prod(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::prod(self.primitive))\n    }\n\n    /// Aggregate all elements along the given *dimension* or *axis*\n    /// in the tensor with the product operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements,\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimension will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.clone().prod_dim(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[5.0, -18.0, 18.0]]\n    ///    let tensor = tensor.clone().prod_dim(1);\n    ///    println!(\"{tensor}\");\n    ///    // [[-6.0], [270.0]]\n    /// }\n    /// ```\n    pub fn prod_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Prod\", dim));\n        Self::new(K::prod_dim(self.primitive, dim))\n    }\n\n    /// Aggregate all elements along the given *axes*\n    /// in the tensor with the prod operation.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - the dimensions to aggregate, supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.clone().sum_dims(&[0, 1]);\n    ///    println!(\"{tensor}\");\n    ///    // [[-1620.0]]\n    /// }\n    /// ```\n    pub fn prod_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter().fold(self, |tensor, &dim| tensor.prod_dim(dim))\n    }\n\n    /// Computes the cumulative sum of elements along the given *dimension* or *axis*.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to compute the cumulative sum.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    ///    let result = tensor.clone().cumsum(0);\n    ///    println!(\"{result}\");\n    ///    // [[1.0, 2.0, 3.0], [5.0, 7.0, 9.0]]\n    ///    let result = tensor.cumsum(1);\n    ///    println!(\"{result}\");\n    ///    // [[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]\n    /// }\n    /// ```\n    pub fn cumsum(self, dim: usize) -> Self {\n        check!(TensorCheck::aggregate_dim::<D>(\"CumSum\", dim));\n        Self::new(K::cumsum(self.primitive, dim))\n    }\n\n    /// Computes the cumulative product of elements along the given *dimension* or *axis*.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to compute the cumulative product.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    ///    let result = tensor.clone().cumprod(0);\n    ///    println!(\"{result}\");\n    ///    // [[1.0, 2.0, 3.0], [4.0, 10.0, 18.0]]\n    ///    let result = tensor.cumprod(1);\n    ///    println!(\"{result}\");\n    ///    // [[1.0, 2.0, 6.0], [4.0, 20.0, 120.0]]\n    /// }\n    /// ```\n    pub fn cumprod(self, dim: usize) -> Self {\n        check!(TensorCheck::aggregate_dim::<D>(\"CumProd\", dim));\n        Self::new(K::cumprod(self.primitive, dim))\n    }\n\n    /// Apply element wise absolute value operation.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = Default::default();\n    ///   let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [4, -5, 6], [7, -8, 9]], &device);\n    ///   let tensor = tensor.abs();\n    ///   println!(\"{tensor}\");\n    ///   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]\n    /// }\n    /// ```\n    ///\n    /// # Notes\n    ///\n    /// For signed integer dtypes, this operation uses two's-complement wraparound semantics, similar to\n    /// `x.wrapping_abs()`. For example, `abs(i64::MIN) == i64::MIN`.\n    pub fn abs(self) -> Self {\n        Self::new(K::abs(self.primitive))\n    }\n\n    /// Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input,\n    /// the other elements of the result tensor out are set to 0.\n    ///\n    /// See also [`triu_mask`](Tensor::triu_mask).\n    ///\n    /// # Arguments\n    ///\n    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift\n    ///   towards the upper triangle.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    let tensor = Tensor::<B, 2, Int>::from_ints(\n    ///        [\n    ///          [1, 2, 3],\n    ///          [4, 5, 6],\n    ///          [7, 8, 9]\n    ///        ],\n    ///        &device\n    ///    );\n    ///    let tensor = tensor.triu(1);\n    ///    println!(\"{tensor}\");\n    ///    // [\n    ///    //   [0, 2, 3],\n    ///    //   [0, 0, 6],\n    ///    //   [0, 0, 0]\n    ///    // ]\n    /// }\n    /// ```\n    pub fn triu(self, diagonal: i64) -> Self {\n        check!(TensorCheck::tri::<{ D }>());\n\n        // last two dimensions\n        let shape = &self.shape()[D - 2..].to_owned();\n\n        let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();\n        self.mask_fill(mask, 0)\n    }\n\n    /// Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input,\n    /// the other elements of the result tensor out are set to 0.\n    ///\n    /// See also [`tril_mask`](Tensor::tril_mask).\n    ///\n    /// # Arguments\n    ///\n    /// * `diagonal` - The offset from the diagonal, where 0 means the diagonal, and positive values shift\n    ///   towards the upper triangle.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    let tensor = Tensor::<B, 2, Int>::from_ints(\n    ///        [\n    ///          [1, 2, 3],\n    ///          [4, 5, 6],\n    ///          [7, 8, 9]\n    ///        ],\n    ///        &device\n    ///    );\n    ///\n    ///    let tensor = tensor.tril(-1);\n    ///    println!(\"{tensor}\");\n    ///    // [\n    ///    //   [0, 0, 0],\n    ///    //   [4, 0, 0],\n    ///    //   [7, 8, 0]\n    ///    // ]\n    /// }\n    /// ```\n    pub fn tril(self, diagonal: i64) -> Self {\n        check!(TensorCheck::tri::<{ D }>());\n\n        // last two dimensions\n        let shape = &self.shape()[D - 2..].to_owned();\n        let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();\n\n        self.mask_fill(mask, 0)\n    }\n\n    /// Applies element wise power operation with a integer Tensor\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to apply the power operation with.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, Int};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);\n    ///    let tensor2 = Tensor::<B, 2, Int>::from_ints([[2, 3, 4], [1, 2, 3]], &device);\n    ///    let tensor = tensor1.powi(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[1, -8, 81], [5, 81, 216]]\n    /// }\n    /// ```\n    pub fn powi(self, other: Self) -> Self {\n        Self::new(K::powi(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise power operation with a integer scalar\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The scalar to apply the power operation with.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, Int};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2, Int>::from_ints([[1, -2, 3], [5, 9, 6]], &device);\n    ///    let tensor = tensor.powi_scalar(2);\n    ///    println!(\"{tensor}\");\n    ///\n    ///    // [[1, 4, 9], [25, 81, 36]]\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.5, -2., 3.], [5., 9., 6.]], &device);\n    ///    let tensor = tensor.powi_scalar(2);\n    ///    println!(\"{tensor}\");\n    ///    // [[2.25, 4., 9.], [25., 81., 36.]]\n    /// }\n    /// ```\n    pub fn powi_scalar<E: ElementConversion>(self, other: E) -> Self {\n        let other = Scalar::new(other, &self.dtype());\n        Self::new(K::powi_scalar(self.primitive, other))\n    }\n\n    /// Converts the tensor to a boolean tensor by checking if the elements are non-zero.\n    ///\n    /// # Returns\n    ///\n    /// A boolean tensor with the same shape as the input tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [0.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.bool();\n    ///   println!(\"{tensor}\");\n    ///   // [\n    ///   //   [true, true, true],\n    ///   //   [false, true, true]\n    ///   // ]\n    /// }\n    pub fn bool(self) -> Tensor<B, D, Bool> {\n        self.not_equal_elem(0)\n    }\n\n    /// Create a random tensor of the given shape on the given device where each element is\n    /// sampled from the given distribution.\n    ///\n    /// See also [`random_like`](Tensor::random_like).\n    ///\n    /// # Arguments\n    ///\n    /// * `shape` - The shape of the tensor.\n    /// * `distribution` - The distribution to sample from.\n    /// * `device` - The device to create the tensor on.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the given shape and elements sampled from the given distribution.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape, Distribution};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let distribution = Distribution::Uniform(0.0, 1.0); // Any random value between 0.0 and 1.0\n    ///   let tensor = Tensor::<B, 2>::random(Shape::new([2, 3]), distribution, &device);\n    ///   println!(\"{tensor}\");\n    ///   // [\n    ///   //   [0.08347523, 0.70498955, 0.60332155],\n    ///   //   [0.08173251, 0.18028641, 0.97942924]\n    ///   // ]\n    /// }\n    /// ```\n    pub fn random<S: Into<Shape>>(\n        shape: S,\n        distribution: Distribution,\n        device: &B::Device,\n    ) -> Self {\n        Self::new(K::random(shape.into(), distribution, device))\n    }\n\n    /// Applies the matrix multiplication operation.\n    ///\n    /// ```math\n    /// C = AB\n    /// ```\n    ///\n    /// Shapes of the form `[..., B, 1, K] @ [..., 1, K, N]` are reinterpreted as\n    /// `[..., 1, B, K] @ [..., 1, K, N]`, turning a batched vec-mat into a general\n    /// matmul, which is often faster.\n    pub fn matmul(self, other: Self) -> Self {\n        check!(TensorCheck::matmul(&self, &other));\n\n        if D >= 3 {\n            let batch_index = D - 3;\n            let vector_index = D - 2;\n            let lhs_dims = &self.shape()[batch_index..D];\n            let rhs_dims = &other.shape()[batch_index..D];\n\n            if let ([_, 1, k1], [1, k2, _]) = (lhs_dims, rhs_dims)\n                && k1 == k2\n            {\n                return Tensor::new(K::matmul(\n                    self.swap_dims(batch_index, vector_index).primitive,\n                    other.primitive,\n                ))\n                .swap_dims(batch_index, vector_index);\n            }\n        }\n\n        Tensor::new(K::matmul(self.primitive, other.primitive))\n    }\n}\n\nimpl<B, K> Tensor<B, 1, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    /// Calculates the dot product with another tensor.\n    ///\n    /// `y = x2.dot(x1)`\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The tensor to compute dot product with.\n    ///\n    /// # Notes\n    ///\n    /// Both tensors must have the same number of elements.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 1>::from_data([1.0, 2.0], &device);\n    ///    let tensor2 = Tensor::<B, 1>::from_data([-2.0, 3.0], &device);\n    ///    let tensor = tensor1.dot(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [4]\n    /// }\n    /// ```\n    pub fn dot(self, other: Self) -> Self {\n        self.mul(other).sum()\n    }\n}\n\nimpl<B, K> Tensor<B, 2, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    /// Creates a new 2D tensor with ones on the diagonal and zeros elsewhere.\n    ///\n    /// # Arguments\n    ///\n    /// * `size` - The size of the square matrix.\n    pub fn eye(size: usize, device: &B::Device) -> Self {\n        let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();\n        let ones = Self::ones([1, size], device);\n        let zeros = Self::zeros([size, size], device);\n\n        zeros.scatter(0, indices, ones, IndexingUpdateOp::Add)\n    }\n}\n\n// Tensor + tensor\nimpl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Add<Self> for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn add(self, rhs: Self) -> Self::Output {\n        Self::add(self, rhs)\n    }\n}\n\n// Tensor + scalar\nimpl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<E>\n    for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn add(self, other: E) -> Self::Output {\n        Tensor::add_scalar(self, other)\n    }\n}\n\n// Scalar + tensor\nmacro_rules! impl_tensor_scalar_add {\n    ($($t:ty),*) => {\n        $(\n            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Add<Tensor<B, D, K>> for $t\n            where\n                K::Elem: Element,\n            {\n                type Output = Tensor<B, D, K>;\n\n                fn add(self, tensor: Tensor<B, D, K>) -> Self::Output {\n                    Tensor::add_scalar(tensor, self)\n                }\n            }\n        )*\n    }\n}\nimpl_tensor_scalar_add!(f32, f64, i32, i64, u32, u64);\n\n// Tensor - tensor\nimpl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Sub<Self> for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn sub(self, rhs: Self) -> Self::Output {\n        Tensor::sub(self, rhs)\n    }\n}\n\n// Tensor - scalar\nimpl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<E>\n    for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn sub(self, other: E) -> Self::Output {\n        Tensor::sub_scalar(self, other)\n    }\n}\n\n// Scalar - tensor\nmacro_rules! impl_tensor_scalar_sub {\n    ($($t:ty),*) => {\n        $(\n            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Sub<Tensor<B, D, K>> for $t\n            where\n                K::Elem: Element,\n            {\n                type Output = Tensor<B, D, K>;\n\n                fn sub(self, tensor: Tensor<B, D, K>) -> Self::Output {\n                    Tensor::add_scalar(Tensor::neg(tensor), self)\n                }\n            }\n        )*\n    }\n}\nimpl_tensor_scalar_sub!(f32, f64, i32, i64, u32, u64);\n\n// Tensor / tensor\nimpl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Div<Self> for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn div(self, rhs: Self) -> Self::Output {\n        Tensor::div(self, rhs)\n    }\n}\n\n// Tensor / scalar\nimpl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Div<E>\n    for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn div(self, other: E) -> Self::Output {\n        Tensor::div_scalar(self, other)\n    }\n}\n\n// Scalar / tensor (float only)\nmacro_rules! impl_tensor_scalar_div {\n    ($($t:ty),*) => {\n        $(\n            impl<const D: usize, B: Backend> core::ops::Div<Tensor<B, D>> for $t\n            {\n                type Output = Tensor<B, D>;\n\n                fn div(self, tensor: Tensor<B, D>) -> Self::Output {\n                    tensor.recip().mul_scalar(self)\n                }\n            }\n        )*\n    }\n}\n\nimpl_tensor_scalar_div!(f32, f64);\n\n// Tensor % tensor.\nimpl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<Self> for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn rem(self, rhs: Self) -> Self::Output {\n        Tensor::remainder(self, rhs)\n    }\n}\n\n// Tensor % scalar.\nimpl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Rem<E>\n    for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn rem(self, other: E) -> Self::Output {\n        Tensor::remainder_scalar(self, other)\n    }\n}\n\n// Tensor * tensor.\nimpl<B: Backend, const D: usize, K: Numeric<B>> core::ops::Mul<Self> for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn mul(self, rhs: Self) -> Self::Output {\n        Tensor::mul(self, rhs)\n    }\n}\n\n// Tensor * scalar.\nimpl<E: ElementConversion, const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<E>\n    for Tensor<B, D, K>\nwhere\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn mul(self, other: E) -> Self::Output {\n        Tensor::mul_scalar(self, other)\n    }\n}\n\nmacro_rules! impl_tensor_scalar_mul {\n    ($($t:ty),*) => {\n        $(\n            impl<const D: usize, B: Backend, K: Numeric<B>> core::ops::Mul<Tensor<B, D, K>> for $t\n            where\n                K::Elem: Element,\n            {\n                type Output = Tensor<B, D, K>;\n\n                fn mul(self, other: Tensor<B, D, K>) -> Self::Output {\n                    Tensor::mul_scalar(other, self)\n                }\n            }\n        )*\n    }\n}\n\nimpl_tensor_scalar_mul!(f32, f64, i32, i64, u32, u64);\n\nimpl<B, const D: usize, K> core::ops::Neg for Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    type Output = Self;\n\n    fn neg(self) -> Self::Output {\n        Tensor::neg(self)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/options.rs",
    "content": "use burn_backend::{\n    Backend, Element,\n    tensor::{BasicOps, Device},\n};\nuse burn_std::DType;\n\nuse crate::get_device_policy;\n\n/// Options for tensor creation.\n///\n/// This struct allows specifying the `device` and overriding the data type when creating a tensor.\n/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.\n#[derive(Debug, Clone)]\npub struct TensorCreationOptions<B: Backend> {\n    /// Device where the tensor will be created.\n    pub device: Device<B>,\n    /// Optional data type.\n    /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).\n    pub dtype: Option<DType>,\n}\n\nimpl<B: Backend> Default for TensorCreationOptions<B> {\n    /// Returns new options with the backend's default device.\n    fn default() -> Self {\n        Self::new(Default::default())\n    }\n}\n\nimpl<B: Backend> TensorCreationOptions<B> {\n    /// Create new options with a specific device.\n    ///\n    /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.\n    pub fn new(device: Device<B>) -> Self {\n        Self {\n            device,\n            dtype: None,\n        }\n    }\n\n    /// Set the tensor creation data type.\n    pub fn with_dtype(mut self, dtype: DType) -> Self {\n        self.dtype = Some(dtype);\n\n        self\n    }\n\n    /// Set the tensor creation device.\n    pub fn with_device(mut self, device: Device<B>) -> Self {\n        self.device = device;\n\n        self\n    }\n\n    /// Create options with backend's default device and float dtype.\n    pub fn float() -> Self {\n        Self::default().with_dtype(<B::FloatElem as Element>::dtype())\n    }\n\n    /// Create options with backend's default device and int dtype.\n    pub fn int() -> Self {\n        Self::default().with_dtype(<B::IntElem as Element>::dtype())\n    }\n\n    /// Create options with backend's default device and bool dtype.\n    pub fn bool() -> Self {\n        Self::default().with_dtype(<B::BoolElem as Element>::dtype())\n    }\n\n    /// Returns the tensor data type, or a provided default if not set.\n    ///\n    /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.\n    pub fn dtype_or(&self, dtype: DType) -> DType {\n        self.dtype.unwrap_or(dtype)\n    }\n\n    /// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes).\n    pub(crate) fn resolve_policy<K: BasicOps<B>>(&self) -> DType {\n        let dtype = K::Elem::dtype();\n        let kind_name = K::name();\n        // TODO: tensor kind enum?\n        self.dtype.unwrap_or_else(|| {\n            let policy = get_device_policy(&self.device);\n            if dtype.is_float()\n                && kind_name == \"Float\"\n                && let Some(float_dtype) = policy.float_dtype()\n            {\n                float_dtype.into()\n            } else if (dtype.is_int() || dtype.is_uint())\n                && kind_name == \"Int\"\n                && let Some(int_dtype) = policy.int_dtype()\n            {\n                int_dtype.into()\n            } else {\n                // If policy was not explicitly set, use the fallback dtype (default backend elem type)\n                dtype\n            }\n        })\n    }\n}\n\nimpl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {\n    /// Convenience conversion from a reference to a device.\n    ///\n    /// Example:\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::TensorCreationOptions;\n    ///\n    /// fn example<B: Backend>(device: B::Device) {\n    ///     let options: TensorCreationOptions<B> = (&device).into();\n    /// }\n    /// ```\n    fn from(device: &Device<B>) -> Self {\n        TensorCreationOptions::new(device.clone())\n    }\n}\n\nimpl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {\n    /// Convenience conversion for a specified `(&device, dtype)` tuple.\n    fn from(args: (&Device<B>, DType)) -> Self {\n        TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/orderable.rs",
    "content": "use burn_backend::{\n    Backend, ElementConversion, Scalar,\n    tensor::{Bool, IndexingUpdateOp, Int, Ordered},\n};\nuse burn_std::AsIndex;\n\nuse crate::check;\nuse crate::{Tensor, check::TensorCheck};\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Ordered<B>,\n{\n    /// Sort the elements by value in ascending order along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the elements sorted in ascending order along the given dimension.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///   let tensor = tensor.sort(0);\n    ///   println!(\"{tensor}\");\n    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]\n    ///   let tensor = tensor.sort(1);\n    ///   println!(\"{tensor}\");\n    ///   // [[-2.0, 3.0, 12.0], [3.0, 5.0, 6.0]]\n    /// }\n    /// ```\n    pub fn sort(self, dim: usize) -> Self {\n        check!(TensorCheck::sort_dim::<D>(\"Sort\", dim));\n        Tensor::new(K::sort(self.primitive, dim, /*descending*/ false))\n    }\n\n    /// Sort the elements by value in descending order along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the elements sorted in descending order along the given dimension.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///    let tensor = tensor.sort_descending(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]\n    ///    let tensor = tensor.sort_descending(1);\n    ///    println!(\"{tensor}\");\n    ///    // [[12.0, 3.0, -2.0], [6.0, 5.0, 3.0]]\n    /// }\n    /// ```\n    pub fn sort_descending(self, dim: usize) -> Self {\n        check!(TensorCheck::sort_dim::<D>(\"Sort\", dim));\n        Tensor::new(K::sort(self.primitive, dim, /*descending*/ true))\n    }\n\n    /// Sort the elements by value in ascending order along a given dimension.\n    /// Also returns the indices.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Returns\n    ///\n    /// A tuple containing the sorted tensor and the indices tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///   let (tensor, indices) = tensor.sort_with_indices(0);\n    ///   println!(\"{tensor}\");\n    ///   // [[5.0, -2.0, 3.0], [12.0, 3.0, 6.0]]\n    ///   println!(\"{}\", indices);\n    ///   // [[1, 0, 0], [0, 1, 1]]\n    /// }\n    /// ```\n    pub fn sort_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {\n        check!(TensorCheck::sort_dim::<D>(\"Sort_with_indices\", dim));\n        let (values, indices) =\n            K::sort_with_indices(self.primitive, dim, /*descending*/ false);\n        (Tensor::new(values), Tensor::new(indices))\n    }\n\n    /// Sort the elements by value in descending order along a given dimension.\n    /// Also returns the indices.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///    let (tensor, indices) = tensor.sort_descending_with_indices(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]\n    ///    println!(\"{}\", indices);\n    ///    // [[0, 1, 1], [1, 0, 0]]\n    /// }\n    /// ```\n    pub fn sort_descending_with_indices(self, dim: usize) -> (Self, Tensor<B, D, Int>) {\n        check!(TensorCheck::sort_dim::<D>(\"Sort_with_indices\", dim));\n        let (values, indices) = K::sort_with_indices(self.primitive, dim, /*descending*/ true);\n        (Tensor::new(values), Tensor::new(indices))\n    }\n\n    /// Returns the indices that sort the elements by value in ascending order along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///    let tensor = tensor.argsort(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[1, 0, 0], [0, 1, 1]]\n    /// }\n    /// ```\n    pub fn argsort(self, dim: usize) -> Tensor<B, D, Int> {\n        check!(TensorCheck::sort_dim::<D>(\"Argsort\", dim));\n        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ false))\n    }\n\n    /// Returns the indices that sort the elements by value in descending order along a given dimension.\n    ///\n    /// This sort is unstable (i.e., may reorder equal elements).\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///    let tensor = tensor.argsort_descending(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[0, 1, 1], [1, 0, 0]]\n    ///    let tensor = tensor.argsort_descending(1);\n    ///    println!(\"{tensor}\");\n    ///    // [[0, 2, 1], [2, 0, 1]]\n    /// }\n    /// ```\n    pub fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {\n        check!(TensorCheck::sort_dim::<D>(\"Argsort\", dim));\n        Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))\n    }\n\n    /// Returns the `k` largest elements of the given input tensor along a given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `k` - The number of elements to return.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the `k` largest elements along the given dimension.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///   let tensor = tensor.topk(2, 0);\n    ///   println!(\"{tensor}\");\n    ///   // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]\n    ///   let tensor = tensor.topk(1, 1);\n    ///   println!(\"{tensor}\");\n    ///   // [[12.0], [6.0]]\n    /// }\n    /// ```\n    pub fn topk(self, k: usize, dim: usize) -> Self {\n        let k_indices = Tensor::arange(0..k as i64, &self.device());\n        self.sort_descending(dim).select(dim, k_indices)\n    }\n\n    /// Returns the `k` largest elements of the given input tensor along a given dimension.\n    /// Also returns the indices.\n    ///\n    /// # Arguments\n    ///\n    /// * `k` - The number of elements to return.\n    /// * `dim` - The dimension to sort along.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///    let (tensor, indices) = tensor.topk_with_indices(2, 0);\n    ///    println!(\"{tensor}\");\n    ///    // [[12.0, 3.0, 6.0], [5.0, -2.0, 3.0]]\n    ///    println!(\"{}\", indices);\n    ///    // [[0, 1, 1], [1, 0, 0]]\n    ///    let (tensor, indices) = tensor.topk_with_indices(1, 1);\n    ///    println!(\"{tensor}\");\n    ///    // [[12.0], [6.0]]\n    ///    println!(\"{indices}\");\n    ///    // [[0], [2]]\n    /// }\n    /// ```\n    pub fn topk_with_indices(self, k: usize, dim: usize) -> (Self, Tensor<B, D, Int>) {\n        let k_indices = Tensor::arange(0..k as i64, &self.device());\n        let (values, indices) = self.sort_descending_with_indices(dim);\n        (\n            values.select(dim, k_indices.clone()),\n            indices.select(dim, k_indices),\n        )\n    }\n\n    /// Create a one hot tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>(){\n    ///     let device = Default::default();\n    ///     let indices: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device);\n    ///     let one_hot: Tensor<B, 2> = indices.one_hot(4);\n    ///     println!(\"{}\", one_hot.to_data());\n    ///     // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]\n    /// }\n    /// ```\n    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {\n        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));\n        self.one_hot_fill(num_classes, 1.0, 0.0, -1)\n    }\n\n    /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors.\n    ///\n    /// # Arguments\n    ///\n    /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension.\n    /// * `on_value`: The value to assign for active positions (corresponding to indices).\n    /// * `off_value`: The value to assign for inactive positions.\n    /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`.\n    ///\n    /// # Example\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Float};\n    /// fn example<B: Backend<FloatElem: From<f32>>>() {\n    ///     let device = B::Device::default();\n    ///     let indices: Tensor<B, 2, Float> = Tensor::from_floats([[0., 2.], [1., -1.]], &device);\n    ///     // One-hot encoding\n    ///     let tensor:Tensor<B, 3, Float> = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1);\n    ///     println!(\"{tensor}\");\n    ///     // [[[5.0, 0.0, 0.0],\n    ///     // [0.0, 0.0, 5.0]],\n    ///     // [[0.0, 5.0, 0.0],\n    ///     // [0.0, 0.0, 5.0]]]\n    /// }\n    /// ```\n    pub fn one_hot_fill<const D2: usize>(\n        self,\n        num_classes: usize,\n        on_value: f32,\n        off_value: f32,\n        axis: i64,\n    ) -> Tensor<B, D2, K> {\n        check!(TensorCheck::one_hot_tensor_rank::<D, D2>());\n        // Initialize shape from the current tensor dimensions and prepare for modification\n        let mut shape = self.shape();\n        let device = self.device();\n        let rank = self.dims().len();\n\n        // Adjust negative axis to a positive index\n        let axis = if axis < 0 {\n            axis + rank as i64 + 1\n        } else {\n            axis\n        };\n\n        // Ensure axis is within valid range\n        if axis < 0 || axis > rank as i64 {\n            panic!(\"Axis out of range. Accepted range is [-r-1, r] where r = rank(indices).\");\n        }\n        // Convert the input tensor to integer indices\n        let indices: Tensor<B, D, Int> =\n            Tensor::from_data(self.to_data().convert::<i64>(), &device);\n        // Insert the new dimension for the one-hot representation\n        shape.insert(axis as usize, num_classes);\n        // Adjust indices to valid range and handle invalid indices\n        let adjusted_indices = indices\n            .clone()\n            .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices\n            .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices\n        // Unsqueeze the indices tensor along the specified axis\n        let indices_unsqueezed: Tensor<B, D2, Int> = adjusted_indices.unsqueeze_dim(axis as usize);\n\n        // Initialize the output tensor with the off_value\n        let output = Tensor::full(shape.clone(), off_value, &device);\n\n        // Prepare scatter tensor for on_value and off_value adjustments\n        let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device)\n            - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device());\n\n        // Scatter on_value at the appropriate indices to create the one-hot representation\n        output.scatter(\n            axis as usize,\n            indices_unsqueezed,\n            scatter_on_values,\n            IndexingUpdateOp::Add,\n        )\n    }\n\n    /// Applies element wise greater comparison and returns a boolean tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///   let tensor = tensor1.greater(tensor2);\n    ///   println!(\"{tensor}\");\n    ///   // [[false, false, false], [true, true, true]]\n    /// }\n    /// ```\n    pub fn greater(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"Greater\", &self, &other));\n        Tensor::new(K::greater(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise greater-equal comparison and returns a boolean tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.greater_equal(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[true, false, false], [true, true, true]]\n    /// }\n    /// ```\n    pub fn greater_equal(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"Greater_equal\", &self, &other));\n        Tensor::new(K::greater_equal(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise lower comparison and returns a boolean tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.lower(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[false, true, true], [false, false, false]]\n    /// }\n    /// ```\n    pub fn lower(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"Lower\", &self, &other));\n        Tensor::new(K::lower(self.primitive, other.primitive))\n    }\n\n    /// Applies element wise lower-equal comparison and returns a boolean tensor.\n    ///\n    /// # Panics\n    ///\n    /// If the two tensors don't have the same shape.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[1.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.lower_equal(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[true, true, true], [false, false, false]]\n    /// }\n    /// ```\n    pub fn lower_equal(self, other: Self) -> Tensor<B, D, Bool> {\n        check!(TensorCheck::binary_ops_ew(\"Lower_equal\", &self, &other));\n        Tensor::new(K::lower_equal(self.primitive, other.primitive))\n    }\n\n    /// Applies greater than `other` comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.greater_elem(3.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[false, false, true], [true, true, true]]\n    /// }\n    /// ```\n    pub fn greater_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::greater_elem(self.primitive, other))\n    }\n\n    /// Applies greater-equal than `other` comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.greater_equal_elem(3.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[false, false, true], [true, true, true]]\n    /// }\n    /// ```\n    pub fn greater_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::greater_equal_elem(self.primitive, other))\n    }\n\n    /// Applies lower than `other` comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///     let tensor = tensor.lower_elem(3.0);\n    ///     println!(\"{tensor}\");\n    ///     // [[true, true, false], [false, false, false]]\n    /// }\n    /// ```\n    pub fn lower_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::lower_elem(self.primitive, other))\n    }\n\n    /// Applies lower-equal than `other` comparison and returns a boolean tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - The element to compare.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.lower_equal_elem(3.0);\n    ///    println!(\"{tensor}\");\n    ///    // [[true, true, true], [false, false, false]]\n    /// }\n    /// ```\n    pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {\n        let other = Scalar::new(other, &self.dtype());\n        Tensor::new(K::lower_equal_elem(self.primitive, other))\n    }\n\n    /// Applies the argmax function along the given dimension and returns an integer tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);\n    ///     let tensor = tensor.argmax(1);\n    ///     println!(\"{:?}\", tensor.shape());\n    ///     // Shape { dims: [2, 1, 3] }\n    /// }\n    /// ```\n    pub fn argmax(self, dim: usize) -> Tensor<B, D, Int> {\n        Tensor::new(K::argmax(self.primitive, dim))\n    }\n\n    /// Find the maximum value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.max();\n    ///   println!(\"{tensor}\");\n    ///   // [9.0]\n    /// }\n    /// ```\n    pub fn max(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::max(self.primitive))\n    }\n\n    /// Find the maximum value along the given dimension.\n    ///\n    /// Also returns the indices.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let (tensor, index) = tensor.max_dim_with_indices(0);\n    ///    // [[5.0, 9.0, 6.0]]\n    ///    println!(\"{tensor}\");\n    ///    // [[1, 1, 1]]\n    ///    println!(\"{index}\");\n    /// }\n    /// ```\n    pub fn max_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Max\", dim));\n\n        let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);\n\n        let tensor = Tensor::new(tensor);\n        let index = Tensor::new(index);\n\n        (tensor, index)\n    }\n\n    /// Find the maximum absolute value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -7.0, 3.0], [5.0, -1.0, 6.0]], &device);\n    ///   let tensor = tensor.max_abs();\n    ///   println!(\"{tensor}\");\n    ///   // [7.0]\n    /// }\n    /// ```\n    pub fn max_abs(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::max_abs(self.primitive))\n    }\n\n    /// Finds the maximum pair wise values with another tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - Other tensor to find maximum elements with\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensors containing the maximum value found\n    /// in the input tensors.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.max_pair(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[2.0, 3.0, 4.0], [5.0, 9.0, 6.0]]\n    /// }\n    /// ```\n    pub fn max_pair(self, other: Self) -> Self {\n        let mask = self.clone().lower(other.clone());\n        self.mask_where(mask, other)\n    }\n\n    /// Find the maximum absolute value along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements,\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimension will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.max_dim(0);\n    ///   println!(\"{tensor}\");\n    ///   // [[5.0, 9.0, 6.0]]\n    /// }\n    /// ```\n    pub fn max_abs_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"MaxAbs\", dim));\n\n        Tensor::new(K::max_abs_dim(self.primitive, dim))\n    }\n\n    /// Find the maximum absolute value along the given dimensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - The dimensions or axes along which to aggregate the elements,\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.max_abs_dims(&[0, 1]);\n    ///   println!(\"{tensor}\");\n    ///   // [[9.0]]\n    /// }\n    /// ```\n    pub fn max_abs_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter()\n            .fold(self, |tensor, &dim| tensor.max_abs_dim(dim))\n    }\n\n    /// Applies the argmin function along the given dimension and returns an integer tensor.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = Default::default();\n    ///     let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]), &device);\n    ///     let tensor = tensor.argmin(1);\n    ///     println!(\"{:?}\", tensor.shape());\n    ///     // Shape { dims: [2, 1, 3] }\n    /// }\n    /// ```\n    pub fn argmin(self, dim: usize) -> Tensor<B, D, Int> {\n        Tensor::new(K::argmin(self.primitive, dim))\n    }\n\n    /// Find the minimum value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.min();\n    ///    println!(\"{tensor}\");\n    ///    // [-2.0]\n    /// }\n    /// ```\n    pub fn min(self) -> Tensor<B, 1, K> {\n        Tensor::new(K::min(self.primitive))\n    }\n\n    /// Find the minimum value along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimension will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor = tensor.min_dim(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, -2.0, 3.0]]\n    /// }\n    /// ```\n    pub fn min_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Min\", dim));\n        Tensor::new(K::min_dim(self.primitive, dim))\n    }\n\n    /// Find the minimum value along the given dimensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - The dimensions or axes along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.min_dims(&[0, 1]);\n    ///   println!(\"{tensor}\");\n    ///   // [[-2.0]]\n    /// }\n    /// ```\n    pub fn min_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter().fold(self, |tensor, &dim| tensor.min_dim(dim))\n    }\n\n    /// Find the minimum value along the given dimension.\n    ///\n    /// Also returns the indices.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[7.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let (tensor, index) = tensor.min_dim_with_indices(0);\n    ///    println!(\"{tensor}\");\n    ///    // [[5.0, -2.0, 3.0]]\n    ///    println!(\"{}\", index);\n    ///    // [[1, 0, 0]]\n    /// }\n    /// ```\n    pub fn min_dim_with_indices<I: AsIndex>(self, dim: I) -> (Self, Tensor<B, D, Int>) {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Min\", dim));\n\n        let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);\n\n        let tensor = Tensor::new(tensor);\n        let index = Tensor::new(index);\n\n        (tensor, index)\n    }\n\n    /// Finds the minimum pair wise values with another tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `other` - Other tensor to find minimum elements with\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape as the input tensors containing the minimum value found\n    /// between each element of the two source tensors.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor1 = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///    let tensor2 = Tensor::<B, 2>::from_data([[2.0, 3.0, 4.0], [1.0, 2.0, 3.0]], &device);\n    ///    let tensor = tensor1.min_pair(tensor2);\n    ///    println!(\"{tensor}\");\n    ///    // [[1.0, -2.0, 3.0], [1.0, 2.0, 3.0]]\n    /// }\n    pub fn min_pair(self, other: Self) -> Self {\n        let mask = other.clone().lower(self.clone());\n        self.mask_where(mask, other)\n    }\n\n    /// Clamp element wise between the given min and max values.\n    ///\n    /// # Arguments\n    ///\n    /// * `min` - The minimum value.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped between the given min and max values.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = Default::default();\n    ///   let tensor = Tensor::<B, 2, Int>::from_ints(\n    ///    [\n    ///     [1, 2, 3],\n    ///     [4, 5, 6],\n    ///     [7, 8, 9]\n    ///    ],\n    ///    &device);\n    ///    let tensor = tensor.clamp(2, 6);\n    ///    println!(\"{tensor}\");\n    ///    // [[2, 2, 3], [4, 5, 6], [6, 6, 6]]\n    /// }\n    /// ```\n    pub fn clamp<E: ElementConversion>(self, min: E, max: E) -> Self {\n        let dtype = self.dtype();\n        Self::new(K::clamp(\n            self.primitive,\n            Scalar::new(min, &dtype),\n            Scalar::new(max, &dtype),\n        ))\n    }\n\n    /// Clamp element wise under a minimum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `min` - The minimum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped under the given min value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    let tensor = Tensor::<B, 2, Int>::from_ints(\n    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],\n    ///    &device);\n    ///    let tensor = tensor.clamp_min(4);\n    ///    println!(\"{tensor}\");\n    ///    // [[4, 4, 4], [4, 5, 6], [7, 8, 9]]\n    /// }\n    /// ```\n    pub fn clamp_min<E: ElementConversion>(self, min: E) -> Self {\n        let min = Scalar::new(min, &self.dtype());\n        Self::new(K::clamp_min(self.primitive, min))\n    }\n\n    /// Clamp element wise over a maximum value.\n    ///\n    /// # Arguments\n    ///\n    /// * `tensor` - The tensor to clamp.\n    /// * `max` - The maximum value.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the values clamped over the given max value.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Int, Tensor};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = Default::default();\n    ///    let tensor = Tensor::<B, 2, Int>::from_ints(\n    ///    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],\n    ///    &device);\n    ///    let tensor = tensor.clamp_max(5);\n    ///    println!(\"{tensor}\");\n    ///    // [[1, 2, 3], [4, 5, 5], [5, 5, 5]]\n    /// }\n    /// ```\n    pub fn clamp_max<E: ElementConversion>(self, max: E) -> Self {\n        let max = Scalar::new(max, &self.dtype());\n        Self::new(K::clamp_max(self.primitive, max))\n    }\n\n    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);\n    ///    let result = tensor.clone().cummin(0);\n    ///    println!(\"{result}\");\n    ///    // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]\n    ///    let result = tensor.cummin(1);\n    ///    println!(\"{result}\");\n    ///    // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]\n    /// }\n    /// ```\n    pub fn cummin(self, dim: usize) -> Self {\n        check!(TensorCheck::aggregate_dim::<D>(\"CumMin\", dim));\n        Self::new(K::cummin(self.primitive, dim))\n    }\n\n    /// Computes the cumulative maximum of elements along the given *dimension* or *axis*.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to compute the cumulative maximum.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]], &device);\n    ///    let result = tensor.clone().cummax(0);\n    ///    println!(\"{result}\");\n    ///    // [[3.0, 1.0, 2.0], [4.0, 5.0, 2.0]]\n    ///    let result = tensor.cummax(1);\n    ///    println!(\"{result}\");\n    ///    // [[3.0, 3.0, 3.0], [4.0, 5.0, 5.0]]\n    /// }\n    /// ```\n    pub fn cummax(self, dim: usize) -> Self {\n        check!(TensorCheck::aggregate_dim::<D>(\"CumMax\", dim));\n        Self::new(K::cummax(self.primitive, dim))\n    }\n    /// Find the maximum value along the given dimension.\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension or axis along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimension will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.max_dim(0);\n    ///   println!(\"{tensor}\");\n    ///   // [[5.0, 9.0, 6.0]]\n    /// }\n    /// ```\n    pub fn max_dim<I: AsIndex>(self, dim: I) -> Self {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::aggregate_dim::<D>(\"Max\", dim));\n        Tensor::new(K::max_dim(self.primitive, dim))\n    }\n\n    /// Find the maximum value along the given dimensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `dims` - The dimensions or axis along which to aggregate the elements;\n    ///   supports negative indexing.\n    ///\n    /// # Returns\n    ///\n    /// The returned tensor will have the same rank,\n    /// but the aggregated dimensions will have size 1.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]], &device);\n    ///   let tensor = tensor.max_dims(&[0, 1]);\n    ///   println!(\"{tensor}\");\n    ///   // [[9.0]]\n    /// }\n    /// ```\n    pub fn max_dims<I: AsIndex>(self, dims: &[I]) -> Self {\n        dims.iter().fold(self, |tensor, &dim| tensor.max_dim(dim))\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/pad.rs",
    "content": "use alloc::vec::Vec;\nuse core::ops::Range;\n\nuse crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};\n\nuse super::Numeric;\n\n/// Trait for types that can be used as padding specifications.\n///\n/// Padding is specified as `(before, after)` pairs per dimension, returned as a\n/// fixed-size array `[(usize, usize); D]`. If fewer pairs than dimensions are provided,\n/// they apply to the **last** N dimensions (earlier dimensions are left unpadded).\npub trait IntoPadding<const D: usize> {\n    /// Converts into a fixed-size array of `(before, after)` padding pairs.\n    fn into_padding(self) -> [(usize, usize); D];\n}\n\nimpl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {\n    fn into_padding(self) -> [(usize, usize); D] {\n        assert!(\n            N <= D,\n            \"Padding has {} pairs but tensor only has {} dimensions\",\n            N,\n            D\n        );\n        let mut result = [(0usize, 0usize); D];\n        let offset = D - N;\n        for (i, pair) in self.into_iter().enumerate() {\n            result[offset + i] = pair;\n        }\n        result\n    }\n}\n\n/// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions.\n///\n/// Equivalent to `[(top, bottom), (left, right)]`.\nimpl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {\n    fn into_padding(self) -> [(usize, usize); D] {\n        let (left, right, top, bottom) = self;\n        let mut result = [(0usize, 0usize); D];\n        result[D - 2] = (top, bottom);\n        result[D - 1] = (left, right);\n        result\n    }\n}\n\nimpl<const D: usize> IntoPadding<D> for &[(usize, usize)] {\n    fn into_padding(self) -> [(usize, usize); D] {\n        assert!(\n            self.len() <= D,\n            \"Padding has {} pairs but tensor only has {} dimensions\",\n            self.len(),\n            D\n        );\n        let mut result = [(0usize, 0usize); D];\n        let offset = D - self.len();\n        for (i, &pair) in self.iter().enumerate() {\n            result[offset + i] = pair;\n        }\n        result\n    }\n}\n\nimpl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {\n    fn into_padding(self) -> [(usize, usize); D] {\n        assert!(\n            self.len() <= D,\n            \"Padding has {} pairs but tensor only has {} dimensions\",\n            self.len(),\n            D\n        );\n        let mut result = [(0usize, 0usize); D];\n        let offset = D - self.len();\n        for (i, pair) in self.into_iter().enumerate() {\n            result[offset + i] = pair;\n        }\n        result\n    }\n}\n\n/// Helper to build a range array for slice_assign, selecting a portion of one dimension.\nfn build_slice_ranges<const D: usize>(\n    dims: [usize; D],\n    target_dim: usize,\n    start: usize,\n    len: usize,\n) -> [Range<usize>; D] {\n    dims.iter()\n        .enumerate()\n        .map(|(i, &size)| {\n            if i == target_dim {\n                start..start + len\n            } else {\n                0..size\n            }\n        })\n        .collect::<Vec<Range<usize>>>()\n        .try_into()\n        .unwrap()\n}\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    /// Pads the tensor using the specified padding mode.\n    ///\n    /// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions\n    /// are provided, they apply to the **last** N dimensions (unspecified leading dimensions\n    /// are left unpadded).\n    ///\n    /// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted,\n    /// which pads the last two dimensions.\n    ///\n    /// # Arguments\n    ///\n    /// * `padding` - Padding specification. Accepts:\n    ///   - `[(before, after); N]` fixed-size array of pairs (N <= D)\n    ///   - `&[(before, after)]` slice of pairs per dimension\n    ///   - `Vec<(before, after)>` vector of pairs\n    ///   - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility\n    /// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.\n    ///\n    /// # Returns\n    ///\n    /// A new tensor with the specified padding applied.\n    ///\n    /// # Panics\n    ///\n    /// - Panics if more padding pairs are provided than tensor dimensions.\n    /// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.\n    /// - `Edge` mode panics if padding is applied to a zero-sized dimension.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Shape};\n    /// use burn_tensor::ops::PadMode;\n    ///\n    /// fn example<B: Backend<FloatElem: From<f32>>>() {\n    ///    let device = B::Device::default();\n    ///    let tensor = Tensor::<B, 2>::from_data([[12.0, -2.0, 3.0], [5.0, 3.0, 6.0]], &device);\n    ///\n    ///    // Constant padding with value 0.0 (backward-compatible tuple)\n    ///    let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));\n    ///\n    ///    // Pad arbitrary dimensions with slice of (before, after) pairs\n    ///    let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0));\n    ///\n    ///    // Pad only the last dimension\n    ///    let padded = tensor.pad([(1, 1)], PadMode::Reflect);\n    /// }\n    /// ```\n    pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {\n        let pairs = padding.into_padding();\n        match mode.into() {\n            PadMode::Constant(value) => pad_constant(self, &pairs, value),\n            PadMode::Reflect => pad_reflect(self, &pairs),\n            PadMode::Edge => pad_edge(self, &pairs),\n        }\n    }\n}\n\n/// Pad with a constant value.\nfn pad_constant<B, const D: usize, K, E>(\n    tensor: Tensor<B, D, K>,\n    padding: &[(usize, usize); D],\n    value: E,\n) -> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n    E: ElementConversion,\n{\n    let mut padded_dims: [usize; D] = tensor.dims();\n\n    for (i, &(before, after)) in padding.iter().enumerate() {\n        padded_dims[i] += before + after;\n    }\n\n    let ranges: [Range<usize>; D] = padded_dims\n        .iter()\n        .enumerate()\n        .map(|(i, &dim)| {\n            let (before, after) = padding[i];\n            before..dim - after\n        })\n        .collect::<Vec<Range<usize>>>()\n        .try_into()\n        .unwrap();\n\n    let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());\n\n    padded_tensor.slice_assign(ranges, tensor)\n}\n\n/// Pad using reflection at the boundaries (excluding edge values).\n///\n/// For ONNX \"reflect\" mode: mirrors from index 1, not index 0.\n/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`\nfn pad_reflect<B, const D: usize, K>(\n    tensor: Tensor<B, D, K>,\n    padding: &[(usize, usize); D],\n) -> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    let dims = tensor.dims();\n\n    for (i, &(before, after)) in padding.iter().enumerate() {\n        if before > 0 || after > 0 {\n            assert!(\n                before < dims[i] && after < dims[i],\n                \"Reflect padding ({}, {}) must be less than dimension {} size ({})\",\n                before,\n                after,\n                i,\n                dims[i]\n            );\n        }\n    }\n\n    let mut result = tensor;\n\n    for (i, &(before, after)) in padding.iter().enumerate() {\n        if before > 0 || after > 0 {\n            result = pad_reflect_dim(result, i, before, after);\n        }\n    }\n\n    result\n}\n\n/// Helper to pad a single dimension using reflection.\nfn pad_reflect_dim<B, const D: usize, K>(\n    tensor: Tensor<B, D, K>,\n    dim: usize,\n    pad_before: usize,\n    pad_after: usize,\n) -> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    let dims = tensor.dims();\n    let dim_size = dims[dim];\n\n    // Calculate output dimensions\n    let mut output_dims = dims;\n    output_dims[dim] += pad_before + pad_after;\n\n    // Create output tensor and place original in the center\n    let output = Tensor::zeros(output_dims, &tensor.device());\n    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);\n    let mut output = output.slice_assign(original_range, tensor.clone());\n\n    // Assign reflected \"before\" padding (e.g., top or left)\n    // Reflect excludes the edge, so we take indices [1..pad_before+1] and flip\n    if pad_before > 0 {\n        let before_slice = tensor.clone().narrow(dim, 1, pad_before);\n        let before_flipped = before_slice.flip([dim as isize]);\n        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);\n        output = output.slice_assign(before_range, before_flipped);\n    }\n\n    // Assign reflected \"after\" padding (e.g., bottom or right)\n    // Take indices [dim_size - pad_after - 1..dim_size - 1] and flip\n    if pad_after > 0 {\n        let start = dim_size - pad_after - 1;\n        let after_slice = tensor.narrow(dim, start, pad_after);\n        let after_flipped = after_slice.flip([dim as isize]);\n        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);\n        output = output.slice_assign(after_range, after_flipped);\n    }\n\n    output\n}\n\n/// Pad by replicating edge values.\n///\n/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`\nfn pad_edge<B, const D: usize, K>(\n    tensor: Tensor<B, D, K>,\n    padding: &[(usize, usize); D],\n) -> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    let dims = tensor.dims();\n\n    for (i, &(before, after)) in padding.iter().enumerate() {\n        if before > 0 || after > 0 {\n            assert!(\n                dims[i] > 0,\n                \"Cannot apply edge padding to zero-sized dimension {}\",\n                i\n            );\n        }\n    }\n\n    let mut result = tensor;\n\n    for (i, &(before, after)) in padding.iter().enumerate() {\n        if before > 0 || after > 0 {\n            result = pad_edge_dim(result, i, before, after);\n        }\n    }\n\n    result\n}\n\n/// Helper to pad a single dimension by replicating edge values.\nfn pad_edge_dim<B, const D: usize, K>(\n    tensor: Tensor<B, D, K>,\n    dim: usize,\n    pad_before: usize,\n    pad_after: usize,\n) -> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: Numeric<B>,\n    K::Elem: Element,\n{\n    let dims = tensor.dims();\n    let dim_size = dims[dim];\n\n    // Calculate output dimensions\n    let mut output_dims = dims;\n    output_dims[dim] += pad_before + pad_after;\n\n    // Create output tensor and place original in the center\n    let output = Tensor::zeros(output_dims, &tensor.device());\n    let original_range = build_slice_ranges(output_dims, dim, pad_before, dim_size);\n    let mut output = output.slice_assign(original_range, tensor.clone());\n\n    // Assign \"before\" padding by repeating the first element\n    if pad_before > 0 {\n        let first_slice = tensor.clone().narrow(dim, 0, 1);\n        let before_pad = first_slice.repeat_dim(dim, pad_before);\n        let before_range = build_slice_ranges(output_dims, dim, 0, pad_before);\n        output = output.slice_assign(before_range, before_pad);\n    }\n\n    // Assign \"after\" padding by repeating the last element\n    if pad_after > 0 {\n        let last_slice = tensor.narrow(dim, dim_size - 1, 1);\n        let after_pad = last_slice.repeat_dim(dim, pad_after);\n        let after_range = build_slice_ranges(output_dims, dim, pad_before + dim_size, pad_after);\n        output = output.slice_assign(after_range, after_pad);\n    }\n\n    output\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/take.rs",
    "content": "use crate::{AsIndex, BasicOps, Int, Tensor, backend::Backend, check, check::TensorCheck};\nuse alloc::vec::Vec;\n\nimpl<B, const D: usize, K> Tensor<B, D, K>\nwhere\n    B: Backend,\n    K: BasicOps<B>,\n{\n    /// Takes elements from the tensor along the given dimension using indices of any dimensionality.\n    ///\n    /// This behaves like numpy's take function. When indices is multi-dimensional,\n    /// the output shape will be: input.shape\\[:dim\\] + indices.shape + input.shape\\[dim+1:\\]\n    ///\n    /// # Arguments\n    ///\n    /// * `dim` - The dimension along which to select elements. Supports negative indexing.\n    /// * `indices` - The indices of elements to select. Can be any dimensionality.\n    ///   Must be valid indices in the range [0, dim_size).\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::{Tensor, Int};\n    ///\n    /// fn example<B: Backend>() {\n    ///   let device = B::Device::default();\n    ///\n    ///   // Example with 1D indices\n    ///   let tensor = Tensor::<B, 2>::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);\n    ///   let indices = Tensor::<B, 1, Int>::from_data([2, 0, 1], &device);\n    ///   let result: Tensor<B, 2> = tensor.clone().take::<1, 2>(-1, indices);  // -1 refers to last dimension\n    ///   println!(\"{result}\");\n    ///   // [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]\n    ///\n    ///   // Example with 2D indices - output will have +1 dimension (2D -> 3D)\n    ///   let indices_2d = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], &device);\n    ///   let result: Tensor<B, 3> = tensor.take::<2, 3>(1, indices_2d);\n    ///   println!(\"{result}\");\n    ///   // [[[1.0, 3.0], [2.0, 1.0]], [[4.0, 6.0], [5.0, 4.0]]]\n    /// }\n    /// ```\n    pub fn take<const DI: usize, const DO: usize>(\n        self,\n        dim: impl AsIndex,\n        indices: Tensor<B, DI, Int>,\n    ) -> Tensor<B, DO, K> {\n        let dim = dim.expect_dim_index(D);\n        check!(TensorCheck::take::<D, DI, DO>(dim));\n\n        // Store the indices shape for reshaping later\n        let indices_shape = indices.shape();\n        let indices_dims = indices_shape.clone();\n\n        // Flatten indices to 1D for processing\n        let indices_flat = indices.reshape([indices_shape.num_elements()]);\n\n        // Perform the selection with the flattened indices\n        let selected = self.select(dim, indices_flat);\n\n        // Build the output shape\n        // Output shape = input.shape[:dim] + indices.shape + input.shape[dim+1:]\n        let selected_shape = selected.shape();\n        let mut new_shape = Vec::with_capacity(DO);\n\n        // Add dimensions before the selected dimension\n        for i in 0..dim {\n            new_shape.push(selected_shape[i]);\n        }\n\n        // Add all indices dimensions\n        for &idx_dim in indices_dims.iter() {\n            new_shape.push(idx_dim);\n        }\n\n        // Add dimensions after the selected dimension\n        for i in (dim + 1)..D {\n            new_shape.push(selected_shape[i]);\n        }\n\n        // Verify we have the correct number of dimensions\n        assert_eq!(\n            new_shape.len(),\n            DO,\n            \"Internal error: shape calculation resulted in {} dims but expected {}\",\n            new_shape.len(),\n            DO\n        );\n\n        // Convert to fixed-size array for reshape\n        let mut shape_array = [0; DO];\n        for (i, &s) in new_shape.iter().enumerate() {\n            shape_array[i] = s;\n        }\n\n        selected.reshape(shape_array)\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/transaction.rs",
    "content": "use super::{BasicOps, Tensor};\nuse crate::{\n    TensorData,\n    backend::{Backend, ExecutionError},\n    ops::TransactionPrimitive,\n};\nuse alloc::vec::Vec;\n\n#[derive(Default)]\n/// A transaction can [read](Self::register) multiple tensors at once with a single operation improving\n/// compute utilization with optimized laziness.\n///\n/// # Example\n///\n/// ```rust,ignore\n///  let [output_data, loss_data, targets_data] = Transaction::default()\n///    .register(output)\n///    .register(loss)\n///    .register(targets)\n///    .execute()\n///    .try_into()\n///    .expect(\"Correct amount of tensor data\");\n/// ```\npub struct Transaction<B: Backend> {\n    op: TransactionPrimitive<B>,\n}\n\nimpl<B: Backend> Transaction<B> {\n    /// Add a [tensor](Tensor) to the transaction to be read.\n    pub fn register<const D: usize, K: BasicOps<B>>(mut self, tensor: Tensor<B, D, K>) -> Self {\n        K::register_transaction(&mut self.op, tensor.into_primitive());\n        self\n    }\n\n    /// Executes the transaction synchronously and returns the [data](TensorData) in the same order\n    /// in which they were [registered](Self::register).\n    pub fn execute(self) -> Vec<TensorData> {\n        burn_std::future::block_on(self.execute_async())\n            .expect(\"Error while reading data: use `try_execute` to handle error at runtime\")\n    }\n\n    /// Executes the transaction synchronously and returns the [data](TensorData) in the same\n    /// order in which they were [registered](Self::register).\n    ///\n    /// # Returns\n    ///\n    /// Any error that might have occurred since the last time the device was synchronized.\n    pub fn try_execute(self) -> Result<Vec<TensorData>, ExecutionError> {\n        burn_std::future::block_on(self.execute_async())\n    }\n\n    /// Executes the transaction asynchronously and returns the [data](TensorData) in the same order\n    /// in which they were [registered](Self::register).\n    pub async fn execute_async(self) -> Result<Vec<TensorData>, ExecutionError> {\n        self.op.execute_async().await\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/api/trunc.rs",
    "content": "use crate::{Float, Tensor, TensorPrimitive, backend::Backend};\n\nimpl<B, const D: usize> Tensor<B, D, Float>\nwhere\n    B: Backend,\n{\n    /// Truncates the tensor element-wise, rounding toward zero.\n    ///\n    /// This function returns a new tensor with the same shape as the input tensor,\n    /// where each element is truncated toward zero. For positive values, this is\n    /// equivalent to floor, and for negative values, it's equivalent to ceil.\n    ///\n    /// # Special Cases (IEEE 754 compliant)\n    ///\n    /// - `trunc(±0)` returns ±0 (preserves sign of zero)\n    /// - `trunc(±∞)` returns ±∞\n    /// - `trunc(NaN)` returns NaN\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same shape where each element has been truncated toward zero.\n    ///\n    /// # Example\n    ///\n    /// ```rust\n    /// use burn_tensor::backend::Backend;\n    /// use burn_tensor::Tensor;\n    ///\n    /// fn example<B: Backend>() {\n    ///     let device = B::Device::default();\n    ///     let tensor = Tensor::<B, 1>::from_data([2.3, -1.7, 0.5, -0.5, 3.9], &device);\n    ///     let truncated = tensor.trunc();\n    ///\n    ///     // Result: [2.0, -1.0, 0.0, -0.0, 3.0]\n    /// }\n    /// ```\n    pub fn trunc(self) -> Self {\n        Self::new(TensorPrimitive::Float(B::float_trunc(\n            self.primitive.tensor(),\n        )))\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/grid/affine_grid.rs",
    "content": "use crate::ElementConversion;\nuse crate::backend::Backend;\nuse crate::s;\nuse crate::tensor::{Int, Tensor};\nuse alloc::vec;\n\n/// Generate a tensor with homogeonous coordinates of each element's\n/// transformed location\n///\n///\n/// See:\n///  - [torch.nn.functional.affine_grid](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html)\n///\n/// * `transform` - Transformation with shape (batch_size, 2, 3)\n/// * `dims` - dimensions as (batch_size, channels, height, width)\n///\n/// # Returns\n///\n/// Tensor with shape (batch_size, height, width, 2), where dim 2 is (x, y)\n/// All coordinates are broadcast on the batch dim\npub fn affine_grid_2d<B: Backend>(transform: Tensor<B, 3>, dims: [usize; 4]) -> Tensor<B, 4> {\n    let [batch_size, _c, height, width] = dims;\n\n    let device = &transform.device();\n\n    let x = Tensor::<B, 1, Int>::arange(0..width as i64, device)\n        .reshape([1, width])\n        .expand([height, width]);\n    let y = Tensor::<B, 1, Int>::arange(0..height as i64, device)\n        .reshape([height, 1])\n        .expand([height, width]);\n\n    // from ints (0..(width-1)) and (0..(height-1)), to (-1.0..1.0)\n    let x = x\n        .float()\n        .div_scalar(((width - 1) as f32 / 2.0).elem::<f32>())\n        .sub_scalar((1_f32).elem::<f32>());\n    let y = y\n        .float()\n        .div_scalar(((height - 1) as f32 / 2.0).elem::<f32>())\n        .sub_scalar((1_f32).elem::<f32>());\n\n    // Broadcast to batch dimension\n    let x = x.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]\n    let y = y.unsqueeze_dim::<3>(0).expand([batch_size, height, width]); // [B, H, W]\n\n    // Apply affine transform\n    let a_11 = transform.clone().slice(s![.., 0, 0]);\n    let a_12 = transform.clone().slice(s![.., 0, 1]);\n    let trans_x = transform.clone().slice(s![.., 0, 2]);\n\n    let a_21 = transform.clone().slice(s![.., 1, 0]);\n    let a_22 = transform.clone().slice(s![.., 1, 1]);\n    let trans_y = transform.slice(s![.., 1, 2]);\n\n    let grid_x = a_11.mul(x.clone()).add(a_12.mul(y.clone())).add(trans_x);\n    let grid_y = a_21.mul(x).add(a_22.mul(y)).add(trans_y);\n\n    Tensor::stack(vec![grid_x, grid_y], 3)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/grid/meshgrid.rs",
    "content": "use crate::backend::Backend;\nuse crate::tensor::grid::{GridIndexing, GridOptions, GridSparsity, IndexPos};\nuse crate::tensor::{BasicOps, Tensor};\nuse alloc::vec::Vec;\n\n/// Return a collection of coordinate matrices for coordinate vectors.\n///\n/// Takes N 1D tensors and returns N tensors where each tensor represents the coordinates\n/// in one dimension across an N-dimensional grid.\n///\n/// Based upon `options.sparse`, the generated coordinate tensors can either be `Sparse` or `Dense`:\n/// * In `Sparse` mode, output tensors will have shape 1 everywhere except their cardinal dimension.\n/// * In `Dense` mode, output tensors will be expanded to the full grid shape.\n///\n/// Based upon `options.indexing`, the generated coordinate tensors will use either:\n/// * `Matrix` indexing, where dimensions are in the same order as their cardinality.\n/// * `Cartesian` indexing; where the first two dimensions are swapped.\n///\n/// See:\n///  - [numpy.meshgrid](https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html)\n///  - [torch.meshgrid](https://pytorch.org/docs/stable/generated/torch.meshgrid.html)\n///\n/// # Arguments\n///\n/// * `tensors` - A slice of 1D tensors\n/// * `options` - the options.\n///\n/// # Returns\n///\n/// A vector of N N-dimensional tensors representing the grid coordinates.\npub fn meshgrid<B: Backend, const N: usize, K, O>(\n    tensors: &[Tensor<B, 1, K>; N],\n    options: O,\n) -> [Tensor<B, N, K>; N]\nwhere\n    K: BasicOps<B>,\n    O: Into<GridOptions>,\n{\n    let options = options.into();\n    let swap_dims = options.indexing == GridIndexing::Cartesian && N > 1;\n    let dense = options.sparsity == GridSparsity::Dense;\n\n    let grid_shape: [usize; N] = tensors\n        .iter()\n        .map(|t| t.dims()[0])\n        .collect::<Vec<_>>()\n        .try_into()\n        .unwrap();\n\n    tensors\n        .iter()\n        .enumerate()\n        .map(|(i, tensor)| {\n            let mut coord_tensor_shape = [1; N];\n            coord_tensor_shape[i] = grid_shape[i];\n\n            // Reshape the tensor to have singleton dimensions in all but the i-th dimension\n            let mut tensor = tensor.clone().reshape(coord_tensor_shape);\n\n            if dense {\n                tensor = tensor.expand(grid_shape);\n            }\n            if swap_dims {\n                tensor = tensor.swap_dims(0, 1);\n            }\n\n            tensor\n        })\n        .collect::<Vec<_>>()\n        .try_into()\n        .unwrap()\n}\n\n/// Return a coordinate matrix for a given set of 1D coordinate tensors.\n///\n/// Equivalent to stacking a dense matrix `meshgrid`,\n/// where the stack is along the first or last dimension.\n///\n/// # Arguments\n///\n/// * `tensors`: A slice of 1D tensors.\n/// * `index_pos`: The position of the index in the output tensor.\n///\n/// # Returns\n///\n/// A tensor of either ``(N, ..., |T[i]|, ...)`` or ``(..., |T[i]|, ..., N)``,\n/// of coordinates, indexed on the first or last dimension.\npub fn meshgrid_stack<B: Backend, const D: usize, const D2: usize, K>(\n    tensors: &[Tensor<B, 1, K>; D],\n    index_pos: IndexPos,\n) -> Tensor<B, D2, K>\nwhere\n    K: BasicOps<B>,\n{\n    assert_eq!(D2, D + 1, \"D2 ({D2}) != D ({D}) + 1\");\n\n    let xs: Vec<Tensor<B, D, K>> = meshgrid(tensors, GridOptions::default())\n        .into_iter()\n        .collect();\n\n    let dim = match index_pos {\n        IndexPos::First => 0,\n        IndexPos::Last => D,\n    };\n\n    Tensor::stack(xs, dim)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/grid/mod.rs",
    "content": "mod affine_grid;\nmod meshgrid;\n\npub use meshgrid::*;\n\npub use affine_grid::*;\n\n/// Enum to specify index cardinal layout.\n#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]\npub enum GridIndexing {\n    /// Dimensions are in the same order as the cardinality of the inputs.\n    /// Equivalent to \"ij\" indexing in NumPy and PyTorch.\n    #[default]\n    Matrix,\n\n    /// The same as Matrix, but the first two dimensions are swapped.\n    /// Equivalent to \"xy\" indexing in NumPy and PyTorch.\n    Cartesian,\n}\n\n/// Enum to specify grid sparsity mode.\n#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]\npub enum GridSparsity {\n    /// The grid is fully expanded to the full cartesian product shape.\n    #[default]\n    Dense,\n\n    /// The grid is sparse, expanded only at the cardinal dimensions.\n    Sparse,\n}\n\n/// Grid policy options.\n#[derive(new, Default, Debug, Copy, Clone)]\npub struct GridOptions {\n    /// Indexing mode.\n    pub indexing: GridIndexing,\n\n    /// Sparsity mode.\n    pub sparsity: GridSparsity,\n}\n\nimpl From<GridIndexing> for GridOptions {\n    fn from(value: GridIndexing) -> Self {\n        Self {\n            indexing: value,\n            ..Default::default()\n        }\n    }\n}\nimpl From<GridSparsity> for GridOptions {\n    fn from(value: GridSparsity) -> Self {\n        Self {\n            sparsity: value,\n            ..Default::default()\n        }\n    }\n}\n\n/// Enum to specify the index dimension position.\n#[derive(Default, Debug, Copy, Clone)]\npub enum IndexPos {\n    /// The index is in the first dimension.\n    #[default]\n    First,\n\n    /// The index is in the last dimension.\n    Last,\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/cosine_similarity.rs",
    "content": "use crate::ElementConversion;\nuse crate::backend::Backend;\nuse crate::tensor::Tensor;\n\nuse super::l2_norm;\n\n/// Default epsilon value to avoid division by zero\npub const DEFAULT_EPSILON: f64 = 1e-8;\n/// Computes the cosine similarity between two tensors along a specified dimension.\n///\n/// Calculates the cosine of the angle between inputs as their dot product divided\n/// by the product of their L2 norms.\n///\n/// # Arguments\n///\n/// * `x1` - First input tensor\n/// * `x2` - Second input tensor\n/// * `dim` - Dimension along which to compute the similarity\n///   (negative indices allowed: -1 for last dimension)\n/// * `eps` - Small value to avoid division by zero (default: 1e-8)\n///\n/// # Returns\n///\n/// Tensor containing the cosine similarity between x1 and x2\npub fn cosine_similarity<B: Backend, const D: usize>(\n    x1: Tensor<B, D>,\n    x2: Tensor<B, D>,\n    dim: i32,\n    eps: Option<B::FloatElem>,\n) -> Tensor<B, D> {\n    let eps = eps.unwrap_or_else(|| B::FloatElem::from_elem(DEFAULT_EPSILON));\n\n    // Convert negative dimension to positive\n    let dim_idx = if dim < 0 { D as i32 + dim } else { dim } as usize;\n\n    // Compute dot product: sum(x1 * x2) along the specified dimension\n    let dot_product = (x1.clone() * x2.clone()).sum_dim(dim_idx);\n\n    // Compute L2 norms: ||x1|| and ||x2||\n    let norm_x1 = l2_norm(x1, dim_idx);\n    let norm_x2 = l2_norm(x2, dim_idx);\n\n    // Calculate the denominator (product of the norms) with epsilon to avoid division by zero\n    let denominator = norm_x1.clamp_min(eps) * norm_x2.clamp_min(eps);\n\n    // Return the cosine similarity (dot product divided by the product of norms)\n    dot_product / denominator\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/diag.rs",
    "content": "use crate::backend::Backend;\nuse crate::check;\nuse crate::check::TensorCheck;\nuse crate::tensor::{Int, Tensor};\nuse crate::{BasicOps, TensorKind};\n\n/// Returns the diag of a matrix.\n///\n/// For batched inputs, returns of each matrix in the batch independently.\n///\n/// The diag operation extracts the diagonal elements of the last two dimensions,\n/// treating them as the matrix dimensions, while preserving all leading batch dimensions.\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor with at least 2 dimensions.\n///\n/// # Returns\n/// A tensor of rank `D - 1`, where the last dimension contains the diagonal elements of the input.\npub fn diag<B: Backend, const D: usize, const DO: usize, K>(\n    tensor: Tensor<B, D, K>,\n) -> Tensor<B, DO, K>\nwhere\n    K: TensorKind<B> + BasicOps<B>,\n{\n    check!(TensorCheck::diag::<D, DO>());\n\n    let shape = tensor.shape();\n    let rows = shape[D - 2];\n    let cols = shape[D - 1];\n    let diag_len = rows.min(cols);\n    let device = tensor.device();\n\n    // create the indices for the diag\n    let mut flat_shape = shape.clone();\n    flat_shape[D - 2] = rows * cols;\n    flat_shape[D - 1] = 1;\n    let flat: Tensor<B, D, K> = tensor.reshape(flat_shape);\n\n    let range = Tensor::<B, 1, Int>::arange(0..diag_len as i64, &device);\n    let step_tensor = Tensor::<B, 1, Int>::from_data([cols as i64 + 1], &device);\n    let indices = range * step_tensor;\n    flat.take::<1, D>(D - 2, indices).squeeze_dim(D - 1)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/lu_decomposition.rs",
    "content": "use crate::{\n    Int, backend::Backend, cast::ToElement, check, check::TensorCheck, linalg::swap_slices, s,\n    tensor::Tensor,\n};\n/// Performs PLU decomposition of a square matrix.\n///\n/// The function decomposes a given square matrix `A` into three matrices: a permutation vector `p`,\n/// a lower triangular matrix `L`, and an upper triangular matrix `U`, such that `PA = LU`.\n/// The permutation vector `p` represents the row swaps made during the decomposition process.\n/// The lower triangular matrix `L` has ones on its diagonal and contains the multipliers used\n/// during the elimination process below the diagonal. The upper triangular matrix `U` contains\n/// the resulting upper triangular form of the matrix after the elimination process.\n///\n/// # Arguments\n/// * `tensor` - A square matrix to decompose, represented as a 2D tensor.\n///\n/// # Returns\n/// A tuple containing:\n/// - A 2D tensor representing the combined `L` and `U` matrices.\n/// - A 1D tensor representing the permutation vector `p`.\n///\n/// # Panics and numerical issues\n/// - The function will panic if the input matrix is singular or near-singular.\n/// - The function will panic if the input matrix is not square.\n/// # Performance note (synchronization / device transfers)\n/// This function may involve multiple synchronizations and device transfers, especially\n/// when determining pivot elements and performing row swaps. This can impact performance,\npub fn lu_decomposition<B: Backend>(tensor: Tensor<B, 2>) -> (Tensor<B, 2>, Tensor<B, 1, Int>) {\n    check!(TensorCheck::is_square::<2>(\n        \"lu_decomposition\",\n        &tensor.shape()\n    ));\n    let dims = tensor.shape().dims::<2>();\n    let n = dims[0];\n\n    let mut permutations = Tensor::arange(0..n as i64, &tensor.device());\n    let mut tensor = tensor;\n\n    for k in 0..n {\n        // Find the pivot row\n        let p = tensor\n            .clone()\n            .slice(s![k.., k])\n            .abs()\n            .argmax(0)\n            .into_scalar()\n            .to_usize()\n            + k;\n        let max = tensor.clone().slice(s![p, k]).abs();\n\n        // Avoid division by zero\n        let pivot = max.into_scalar();\n        check!(TensorCheck::lu_decomposition_pivot::<B>(pivot));\n\n        if p != k {\n            tensor = swap_slices(tensor, s![k, ..], s![p, ..]);\n            permutations = swap_slices(permutations, s![k], s![p]);\n        }\n\n        // Normalize k-th column under the diagonal\n        if k < n - 1 {\n            let a_kk = tensor.clone().slice(s![k, k]);\n            let column = tensor.clone().slice(s![(k + 1).., k]) / a_kk;\n            tensor = tensor.slice_assign(s![(k + 1).., k], column);\n        }\n\n        // Update the trailing submatrix\n        for i in (k + 1)..n {\n            // a[i, k+1..] -=  a[i, k] * a[k, k+1..]\n            let a_ik = tensor.clone().slice(s![i, k]);\n            let row_k = tensor.clone().slice(s![k, (k + 1)..]);\n            let update = a_ik * row_k;\n            let row_i = tensor.clone().slice(s![i, (k + 1)..]);\n            tensor = tensor.slice_assign(s![i, (k + 1)..], row_i - update);\n        }\n    }\n\n    (tensor, permutations)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/matvec.rs",
    "content": "use crate::Numeric;\nuse crate::backend::Backend;\nuse crate::tensor::{BasicOps, Shape, Tensor};\n\n/// Performs matrix-vector multiplication with optional batch dimensions.\n///\n/// The `matrix` tensor is expected to have rank `DM` with the last two dimensions representing\n/// the matrix rows and columns. The `vector` tensor should have rank `DV = DM - 1`, sharing\n/// broadcast-compatible batch dimensions and matching the last dimension of the matrix.\n///\n/// # Panics\n///\n/// * If the matrix rank is lower than 2.\n/// * If the vector rank isn't one less than the matrix rank.\n/// * If batch dimensions differ between the operands.\n/// * If the inner dimensions are incompatible for multiplication.\npub fn matvec<B: Backend, const DM: usize, const DV: usize, K>(\n    matrix: Tensor<B, DM, K>,\n    vector: Tensor<B, DV, K>,\n) -> Tensor<B, DV, K>\nwhere\n    K: BasicOps<B> + Numeric<B>,\n{\n    assert!(\n        DM >= 2,\n        \"matvec expects the matrix to be at least rank 2 (got {DM})\"\n    );\n    assert!(\n        DM == DV + 1,\n        \"matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})\",\n    );\n\n    let matrix_dims = matrix.shape().dims::<DM>();\n    let vector_dims = vector.shape().dims::<DV>();\n\n    // Validate batch dimensions (all leading dimensions prior to the matrix axes).\n    let batch_rank = DM.saturating_sub(2);\n    if batch_rank > 0 {\n        let matrix_batch = Shape::from(&matrix_dims[..batch_rank]);\n        let vector_batch = Shape::from(&vector_dims[..batch_rank]);\n\n        assert!(\n            matrix_batch.broadcast(&vector_batch).is_ok(),\n            \"Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}\",\n            &matrix_dims[..batch_rank],\n            &vector_dims[..batch_rank]\n        );\n    }\n\n    let matrix_inner = matrix_dims[DM - 1];\n    let vector_inner = vector_dims[DV - 1];\n    assert!(\n        matrix_inner == vector_inner,\n        \"Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries\",\n    );\n\n    let vector_expanded = vector.unsqueeze_dim::<DM>(DV);\n    matrix.matmul(vector_expanded).squeeze_dim::<DV>(DM - 1)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/mod.rs",
    "content": "mod cosine_similarity;\nmod diag;\nmod lu_decomposition;\nmod matvec;\nmod outer;\nmod trace;\nmod vector_norm;\n\npub use cosine_similarity::*;\npub use diag::*;\npub use lu_decomposition::*;\npub use matvec::*;\npub use outer::*;\npub use trace::*;\npub use vector_norm::*;\n\nuse crate::{BasicOps, SliceArg, Tensor, TensorKind, backend::Backend};\n\n/// Swaps two slices of a tensor.\n/// # Arguments\n/// * `tensor` - The input tensor.\n/// * `slices1` - The first slice to swap.\n/// * `slices2` - The second slice to swap.\n/// # Returns\n/// A new tensor with the specified slices swapped.\n/// # Notes\n/// This method will be useful for matrix factorization algorithms.\nfn swap_slices<B: Backend, const D: usize, K, S>(\n    tensor: Tensor<B, D, K>,\n    slices1: S,\n    slices2: S,\n) -> Tensor<B, D, K>\nwhere\n    S: SliceArg + Clone,\n    K: TensorKind<B> + BasicOps<B>,\n{\n    let temporary = tensor.clone().slice(slices1.clone());\n    let tensor = tensor\n        .clone()\n        .slice_assign(slices1, tensor.slice(slices2.clone()));\n    tensor.slice_assign(slices2, temporary)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/outer.rs",
    "content": "use crate::backend::Backend;\nuse crate::tensor::{BasicOps, Tensor};\nuse crate::{AsIndex, Numeric};\n\n/// Computes the outer product for the last columns of 2 tensors.\n///\n/// See also: [`outer_dim`].\n///\n/// # Arguments\n/// - `lhs`: the \"row\" tensor, with shape ``[..., i]``.\n/// - `rhs`: the \"col\" tensor, with shape ``[..., j]``.\n///\n/// # Returns\n///\n/// A tensor of rank `R = D + 1`, where:\n///\n/// ``\n/// result[..., i, j] = lhs[..., i] * rhs[..., j]\n/// ``\npub fn outer<B: Backend, const D: usize, const R: usize, K>(\n    lhs: Tensor<B, D, K>,\n    rhs: Tensor<B, D, K>,\n) -> Tensor<B, R, K>\nwhere\n    K: BasicOps<B> + Numeric<B>,\n{\n    outer_dim(lhs, rhs, -1)\n}\n\n/// Computes the outer product along a specific dimension, broadcasting over others.\n///\n/// For the given `dim`, computes the outer product of elements along that dimension,\n/// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``.\n///\n/// # Arguments\n///\n/// - `lhs`: left operand, the \"row\" tensor, with size `M` at dimension `dim`.\n/// - `rhs`: right operand, the \"col\" tensor, with size `N` at dimension `dim`.\n/// - `dim`: dimension to compute the outer product along (supports negative indexing).\n///\n/// # Returns\n///\n/// A tensor of rank `R = D + 1`, where:\n///\n/// ``\n/// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...]\n/// ``\n//\n// Notes:\n// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant\n//   than broadcasted elemwise multiply; benchmarking needed to confirm.\npub fn outer_dim<B: Backend, const D: usize, const R: usize, Dim: AsIndex, K>(\n    lhs: Tensor<B, D, K>,\n    rhs: Tensor<B, D, K>,\n    dim: Dim,\n) -> Tensor<B, R, K>\nwhere\n    K: BasicOps<B> + Numeric<B>,\n{\n    assert_eq!(\n        R,\n        D + 1,\n        \"`outer` with D={D} expects R={} (got R={R})\",\n        D + 1\n    );\n    let dim = dim.expect_dim_index(D);\n\n    // (..., i, 1, ...)\n    let x = lhs.unsqueeze_dim::<R>(dim + 1);\n\n    // (..., 1, j, ...)\n    let y = rhs.unsqueeze_dim::<R>(dim);\n\n    // (..., i, j, ...)\n    x * y\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/trace.rs",
    "content": "use super::diag;\nuse crate::backend::Backend;\nuse crate::tensor::Tensor;\n\n/// Computes the trace of a matrix.\n///\n/// For batched inputs, computes the trace of each matrix in the batch independently.\n///\n/// The trace operation sums the diagonal elements of the last two dimensions,\n/// treating them as the matrix dimensions, while preserving all leading batch dimensions.\n///\n/// # Arguments\n///\n/// * `tensor` - The input tensor with at least 2 dimensions.\n///\n/// # Returns\n///\n/// A tensor of rank `D - 1`, where the last dimension contains the sum along the diagonals\n/// of the input.\npub fn trace<B: Backend, const D: usize, const DO: usize>(tensor: Tensor<B, D>) -> Tensor<B, DO> {\n    let diag_tensor = diag::<_, D, DO, _>(tensor);\n\n    diag_tensor.sum_dim(DO - 1)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/linalg/vector_norm.rs",
    "content": "use burn_backend::tensor::Ordered;\n\nuse crate::backend::Backend;\nuse crate::tensor::{BasicOps, Tensor};\nuse crate::{ElementConversion, Numeric};\n#[allow(unused_imports)]\nuse num_traits::float::Float;\n/// Specifies the type of norm to compute.\n#[derive(Debug, Clone, Copy, PartialEq)]\npub enum Norm {\n    /// L0 norm (count of non-zero elements)\n    L0,\n\n    /// L1 norm (sum of absolute values)\n    L1,\n\n    /// L2 norm (Euclidean norm)\n    L2,\n\n    /// L:INFINITY norm (maximum absolute value)\n    LInf,\n\n    /// L:NEG_INFINITY norm (minimum absolute value)\n    LNegInf,\n\n    /// Lp norm (generalized norm)\n    Lp(f64),\n}\n\nimpl Norm {\n    /// Get the exponent of the norm.\n    pub fn to_exponent(self) -> f64 {\n        use Norm::*;\n        match self {\n            L0 => 0.0,\n            L1 => 1.0,\n            L2 => 2.0,\n            LInf => f64::INFINITY,\n            LNegInf => f64::NEG_INFINITY,\n            Lp(p) => p,\n        }\n    }\n}\n\nimpl From<u32> for Norm {\n    fn from(value: u32) -> Self {\n        use Norm::*;\n        match value {\n            0 => L0,\n            1 => L1,\n            2 => L2,\n            u32::MAX => LInf,\n            _ => Lp(value as f64),\n        }\n    }\n}\n\nimpl From<i32> for Norm {\n    fn from(value: i32) -> Self {\n        use Norm::*;\n        match value {\n            0 => L0,\n            1 => L1,\n            2 => L2,\n            i32::MAX => LInf,\n            i32::MIN => LNegInf,\n            _ => Lp(value as f64),\n        }\n    }\n}\n\nimpl From<f32> for Norm {\n    fn from(value: f32) -> Self {\n        use Norm::*;\n        match value {\n            0.0 => L0,\n            1.0 => L1,\n            2.0 => L2,\n            f32::INFINITY => LInf,\n            f32::NEG_INFINITY => LNegInf,\n            _ => Lp(value as f64),\n        }\n    }\n}\n\nimpl From<f64> for Norm {\n    fn from(value: f64) -> Self {\n        use Norm::*;\n        match value {\n            0.0 => L0,\n            1.0 => L1,\n            2.0 => L2,\n            f64::INFINITY => LInf,\n            f64::NEG_INFINITY => LNegInf,\n            _ => Lp(value),\n        }\n    }\n}\n\n/// Computes the vector norm of a tensor along a specified dimension.\n///\n/// Generic dispatch wrapper over specialized / optimized norms.\n///\n/// See:\n/// - [torch.linalg.vector_norm](https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)\n/// - [numpy.linalg.vector_norm](https://numpy.org/doc/stable/reference/generated/numpy.linalg.vector_norm.html)\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `norm` - The selected norm.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The vector norm of the input tensor.\npub fn vector_norm<B: Backend, const D: usize>(\n    x: Tensor<B, D>,\n    norm: impl Into<Norm>,\n    dim: usize,\n) -> Tensor<B, D> {\n    lp_norm(x, norm.into().to_exponent(), dim)\n}\n\n/// Computes the general ``L(p)`` norm of a tensor along a specified dimension.\n///\n/// Uses the specialized implementations for:\n/// * 0.0\n/// * 1.0\n/// * 2.0\n/// * 2 * N for integral N,\n/// * f64::INFINITY,\n/// * f64::NEG_INFINITY,\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `p` - The exponent of the Lp norm.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The ``L(p)`` norm of the input tensor.\npub fn lp_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {\n    match p {\n        0.0 => l0_norm(x, dim),\n        1.0 => l1_norm(x, dim),\n        2.0 => l2_norm(x, dim),\n        p if is_even_integer(p) => lp_signed_norm(x, p as u32, dim),\n        f64::INFINITY => max_abs_norm(x, dim),\n        f64::NEG_INFINITY => min_abs_norm(x, dim),\n        _ => lp_norm_base(x, p, dim),\n    }\n}\n\n/// Normalize a tensor versus its `vector_norm`.\n///\n/// Equivalent to ``x.clone() / vector_norm(x, norm, dim).clamp_min(eps)``.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `norm` - The selected norm.\n/// * `dim` - The dimension to compute the norm over.\n/// * `eps` - The epsilon for the norm.\n///\n/// # Returns\n///\n/// The normalized tensor.\npub fn vector_normalize<B: Backend, const D: usize, E: ElementConversion>(\n    x: Tensor<B, D>,\n    norm: impl Into<Norm>,\n    dim: usize,\n    eps: E,\n) -> Tensor<B, D> {\n    let norm = vector_norm(x.clone(), norm, dim).clamp_min(eps);\n    x / norm\n}\n\n/// Computes the L0 norm of a tensor along a specified dimension.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The L0 norm of the input tensor.\npub fn l0_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>\nwhere\n    K: BasicOps<B> + Numeric<B>,\n{\n    x.zeros_like()\n        .mask_fill(x.not_equal_elem(0), 1)\n        .sum_dim(dim)\n}\n\n/// Computes the L1 norm of a tensor along a specified dimension.\n///\n/// This is a convenience function that wraps `vector_norm` with `p = 1.0`.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The L1 norm of the input tensor.\npub fn l1_norm<B: Backend, const D: usize, K>(x: Tensor<B, D, K>, dim: usize) -> Tensor<B, D, K>\nwhere\n    K: BasicOps<B> + Numeric<B>,\n{\n    x.abs().sum_dim(dim)\n}\n\n/// Computes the L2 norm of a tensor along a specified dimension.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The L2 norm of the input tensor.\npub fn l2_norm<B: Backend, const D: usize>(x: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    x.square().sum_dim(dim).sqrt()\n}\n\nfn is_even_integer(x: f64) -> bool {\n    x.fract() == 0.0 && (x as i64) % 2 == 0\n}\n\n/// Computes ``L(2*n)`` for even integer ``n``.\n///\n/// This lets us skip the abs.\nfn lp_signed_norm<B: Backend, const D: usize>(x: Tensor<B, D>, p: u32, dim: usize) -> Tensor<B, D> {\n    x.powi_scalar(p).sum_dim(dim).powf_scalar(1. / (p as f64))\n}\n\n/// Computes the general ``L(p)`` using the generalized method.\n///\n/// This uses no specialized implementations and cannot handle:\n/// * 0.0\n/// * f64::INFINITY,\n/// * f64::NEG_INFINITY,\nfn lp_norm_base<B: Backend, const D: usize>(x: Tensor<B, D>, p: f64, dim: usize) -> Tensor<B, D> {\n    x.abs().powf_scalar(p).sum_dim(dim).powf_scalar(1. / p)\n}\n\n/// Computes the L:INFINITY norm of a tensor along a specified dimension.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The L:INFINITY norm of the input tensor.\npub fn max_abs_norm<B: Backend, const D: usize, K>(\n    x: Tensor<B, D, K>,\n    dim: usize,\n) -> Tensor<B, D, K>\nwhere\n    K: Ordered<B>,\n{\n    x.max_abs_dim(dim)\n}\n\n/// Computes the L:NEG_INFINITY norm of a tensor along a specified dimension.\n///\n/// # Arguments\n///\n/// * `x` - The input tensor.\n/// * `dim` - The dimension to compute the norm over.\n///\n/// # Returns\n///\n/// The L:NEG_INFINITY norm of the input tensor.\npub fn min_abs_norm<B: Backend, const D: usize, K>(\n    x: Tensor<B, D, K>,\n    dim: usize,\n) -> Tensor<B, D, K>\nwhere\n    K: Ordered<B>,\n{\n    x.abs().min_dim(dim)\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/loss/mod.rs",
    "content": "use crate::backend::Backend;\nuse crate::{Tensor, activation};\n\n/// Computes the log softmax cross entropy between logits and target probabilities.\n///\n/// # Arguments\n///\n/// * `logits` - The logits.\n/// * `target_probs` - The target probabilities.\n///\n/// # Returns\n///\n/// The log softmax cross entropy.\npub fn cross_entropy_with_logits<B: Backend, const D: usize>(\n    logits: Tensor<B, D>,\n    target_probs: Tensor<B, D>,\n) -> Tensor<B, 1> {\n    let tensor = activation::log_softmax(logits, D - 1);\n    let tensor = tensor.mul(target_probs);\n    let tensor = tensor.sum_dim(D - 1);\n\n    tensor.mean().neg()\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/mod.rs",
    "content": "pub(crate) mod stats;\n\nmod api;\n\npub use api::*;\n\n// Re-exported types\npub use burn_backend::{\n    BoolDType, BoolStore, DType, DataError, FloatDType, IntDType, TensorData, TensorMetadata,\n    TensorPrimitive, Tolerance,\n    distribution::*,\n    element::*,\n    indexing::*,\n    ops::TransactionPrimitive,\n    shape::*,\n    slice::*,\n    tensor::{Bool, Float, Int, TensorKind},\n};\n\n/// The activation module.\npub mod activation;\n\n/// The backend module.\npub mod backend {\n    pub use burn_backend::backend::*;\n}\n\n/// The container module.\npub mod container {\n    pub use burn_backend::tensor::TensorContainer;\n}\n\n/// The grid module.\npub mod grid;\n\n/// The linalg module.\npub mod linalg;\n\n/// The loss module.\npub mod loss;\n\n/// The neural network module.\npub mod module;\n\n/// Operations on tensors module.\npub mod ops {\n    pub use burn_backend::backend::ops::*;\n    pub use burn_backend::tensor::{\n        BoolElem, BoolTensor, Device, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor,\n    };\n}\n\n/// Tensor quantization module.\npub mod quantization;\n\n#[cfg(feature = \"std\")]\npub use report::*;\n\n#[cfg(feature = \"std\")]\nmod report;\n\npub use ops::Device; // Re-export device so that it's available from `burn_tensor::Device`.\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/module.rs",
    "content": "use crate::{\n    Bool, Int, Tensor, TensorPrimitive,\n    backend::Backend,\n    check,\n    check::TensorCheck,\n    ops::{\n        AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,\n        PaddedConvOptions, UnfoldOptions,\n    },\n};\n\nuse super::ops::DeformConvOptions;\n\n/// Applies the [embedding module](crate::ops::ModuleOps::embedding).\npub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::embedding(\n        weights.primitive.tensor(),\n        indices.primitive,\n    )))\n}\n\n/// Applies a [1D convolution](crate::ops::ModuleOps::conv1d).\n///\n/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for\n/// asymmetric padding. When asymmetric padding is specified, an explicit pad\n/// operation is applied before the convolution backend op.\npub fn conv1d<B>(\n    x: Tensor<B, 3>,\n    weight: Tensor<B, 3>,\n    bias: Option<Tensor<B, 1>>,\n    options: impl Into<PaddedConvOptions<1>>,\n) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    let padded_options = options.into();\n    check!(TensorCheck::conv(\n        \"conv1d\",\n        x.dims(),\n        weight.dims(),\n        padded_options.options.groups,\n    ));\n\n    if let Some(padding_end) = padded_options.padding_end {\n        let left = padded_options.options.padding[0];\n        let right = padding_end[0];\n        // For 1D (NCL format), pad the length dimension\n        let padded = x.pad((left, right, 0, 0), PadMode::Constant(0.0));\n        let zero_options = ConvOptions::new(\n            padded_options.options.stride,\n            [0],\n            padded_options.options.dilation,\n            padded_options.options.groups,\n        );\n        Tensor::new(TensorPrimitive::Float(B::conv1d(\n            padded.primitive.tensor(),\n            weight.primitive.tensor(),\n            bias.map(|b| b.primitive.tensor()),\n            zero_options,\n        )))\n    } else {\n        Tensor::new(TensorPrimitive::Float(B::conv1d(\n            x.primitive.tensor(),\n            weight.primitive.tensor(),\n            bias.map(|b| b.primitive.tensor()),\n            padded_options.options,\n        )))\n    }\n}\n\n/// Applies a [2D convolution](crate::ops::ModuleOps::conv2d).\n///\n/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for\n/// asymmetric padding. When asymmetric padding is specified, an explicit pad\n/// operation is applied before the convolution backend op.\npub fn conv2d<B>(\n    x: Tensor<B, 4>,\n    weight: Tensor<B, 4>,\n    bias: Option<Tensor<B, 1>>,\n    options: impl Into<PaddedConvOptions<2>>,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    let padded_options = options.into();\n    check!(TensorCheck::conv(\n        \"conv2d\",\n        x.dims(),\n        weight.dims(),\n        padded_options.options.groups,\n    ));\n\n    if let Some(padding_end) = padded_options.padding_end {\n        let top = padded_options.options.padding[0];\n        let left = padded_options.options.padding[1];\n        let bottom = padding_end[0];\n        let right = padding_end[1];\n        // For 2D (NCHW format), pad height and width\n        let padded = x.pad((left, right, top, bottom), PadMode::Constant(0.0));\n        let zero_options = ConvOptions::new(\n            padded_options.options.stride,\n            [0, 0],\n            padded_options.options.dilation,\n            padded_options.options.groups,\n        );\n        Tensor::new(TensorPrimitive::Float(B::conv2d(\n            padded.primitive.tensor(),\n            weight.primitive.tensor(),\n            bias.map(|b| b.primitive.tensor()),\n            zero_options,\n        )))\n    } else {\n        Tensor::new(TensorPrimitive::Float(B::conv2d(\n            x.primitive.tensor(),\n            weight.primitive.tensor(),\n            bias.map(|b| b.primitive.tensor()),\n            padded_options.options,\n        )))\n    }\n}\n\n/// Applies a [3D convolution](crate::ops::ModuleOps::conv3d).\n///\n/// Accepts [`ConvOptions`] for symmetric padding, or [`PaddedConvOptions`] for\n/// asymmetric padding. Asymmetric 3D padding is not yet supported.\npub fn conv3d<B>(\n    x: Tensor<B, 5>,\n    weight: Tensor<B, 5>,\n    bias: Option<Tensor<B, 1>>,\n    options: impl Into<PaddedConvOptions<3>>,\n) -> Tensor<B, 5>\nwhere\n    B: Backend,\n{\n    let padded_options = options.into();\n    check!(TensorCheck::conv(\n        \"conv3d\",\n        x.dims(),\n        weight.dims(),\n        padded_options.options.groups,\n    ));\n\n    if padded_options.is_asymmetric() {\n        panic!(\"Asymmetric padding is not yet supported for conv3d\");\n    }\n\n    Tensor::new(TensorPrimitive::Float(B::conv3d(\n        x.primitive.tensor(),\n        weight.primitive.tensor(),\n        bias.map(|b| b.primitive.tensor()),\n        padded_options.options,\n    )))\n}\n\n/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).\npub fn deform_conv2d<B>(\n    x: Tensor<B, 4>,\n    offset: Tensor<B, 4>,\n    weight: Tensor<B, 4>,\n    mask: Option<Tensor<B, 4>>,\n    bias: Option<Tensor<B, 1>>,\n    options: DeformConvOptions<2>,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    check!(TensorCheck::conv(\n        \"deform_conv2d\",\n        x.dims(),\n        weight.dims(),\n        options.weight_groups,\n    ));\n    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(\n        x.primitive.tensor(),\n        offset.primitive.tensor(),\n        weight.primitive.tensor(),\n        mask.map(|m| m.primitive.tensor()),\n        bias.map(|b| b.primitive.tensor()),\n        options,\n    )))\n}\n\n/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).\npub fn conv_transpose1d<B>(\n    x: Tensor<B, 3>,\n    weight: Tensor<B, 3>,\n    bias: Option<Tensor<B, 1>>,\n    options: ConvTransposeOptions<1>,\n) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    check!(TensorCheck::conv_transpose(\n        \"conv_transpose1d\",\n        x.dims(),\n        weight.dims(),\n    ));\n    Tensor::new(TensorPrimitive::Float(B::conv_transpose1d(\n        x.primitive.tensor(),\n        weight.primitive.tensor(),\n        bias.map(|b| b.primitive.tensor()),\n        options,\n    )))\n}\n\n/// Applies a [2D transposed convolution](crate::ops::ModuleOps::conv_transpose2d).\npub fn conv_transpose2d<B>(\n    x: Tensor<B, 4>,\n    weight: Tensor<B, 4>,\n    bias: Option<Tensor<B, 1>>,\n    options: ConvTransposeOptions<2>,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    check!(TensorCheck::conv_transpose(\n        \"conv_transpose2d\",\n        x.dims(),\n        weight.dims(),\n    ));\n    Tensor::new(TensorPrimitive::Float(B::conv_transpose2d(\n        x.primitive.tensor(),\n        weight.primitive.tensor(),\n        bias.map(|b| b.primitive.tensor()),\n        options,\n    )))\n}\n\n/// Applies a 3D transposed convolution](crate::ops::ModuleOps::conv_transpose3d).\npub fn conv_transpose3d<B>(\n    x: Tensor<B, 5>,\n    weight: Tensor<B, 5>,\n    bias: Option<Tensor<B, 1>>,\n    options: ConvTransposeOptions<3>,\n) -> Tensor<B, 5>\nwhere\n    B: Backend,\n{\n    check!(TensorCheck::conv_transpose(\n        \"conv_transpose3d\",\n        x.dims(),\n        weight.dims(),\n    ));\n    Tensor::new(TensorPrimitive::Float(B::conv_transpose3d(\n        x.primitive.tensor(),\n        weight.primitive.tensor(),\n        bias.map(|b| b.primitive.tensor()),\n        options,\n    )))\n}\n\n/// Applies a [4D to 3D unfold](crate::ops::ModuleOps::unfold4d).\npub fn unfold4d<B>(x: Tensor<B, 4>, kernel_size: [usize; 2], options: UnfoldOptions) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::unfold4d(\n        x.primitive.tensor(),\n        kernel_size,\n        options,\n    )))\n}\n\n/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).\npub fn max_pool1d<B>(\n    x: Tensor<B, 3>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    ceil_mode: bool,\n) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::max_pool1d(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        ceil_mode,\n    )))\n}\n\n/// Applies a [2D max pooling](crate::ops::ModuleOps::max_pool2d).\npub fn max_pool2d<B>(\n    x: Tensor<B, 4>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::max_pool2d(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        ceil_mode,\n    )))\n}\n\n/// Applies a [2D avg pooling](crate::ops::ModuleOps::avg_pool2d).\npub fn avg_pool2d<B>(\n    x: Tensor<B, 4>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::avg_pool2d(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        count_include_pad,\n        ceil_mode,\n    )))\n}\n\n/// Applies a [1D avg pooling](crate::ops::ModuleOps::avg_pool1d).\npub fn avg_pool1d<B>(\n    x: Tensor<B, 3>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    count_include_pad: bool,\n    ceil_mode: bool,\n) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::avg_pool1d(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        count_include_pad,\n        ceil_mode,\n    )))\n}\n\n/// Applies a [1D max pooling](crate::ops::ModuleOps::max_pool1d).\npub fn max_pool1d_with_indices<B>(\n    x: Tensor<B, 3>,\n    kernel_size: usize,\n    stride: usize,\n    padding: usize,\n    dilation: usize,\n    ceil_mode: bool,\n) -> (Tensor<B, 3>, Tensor<B, 3, Int>)\nwhere\n    B: Backend,\n{\n    let output = B::max_pool1d_with_indices(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        ceil_mode,\n    );\n\n    (\n        Tensor::new(TensorPrimitive::Float(output.output)),\n        Tensor::new(output.indices),\n    )\n}\n\n/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).\npub fn max_pool2d_with_indices<B>(\n    x: Tensor<B, 4>,\n    kernel_size: [usize; 2],\n    stride: [usize; 2],\n    padding: [usize; 2],\n    dilation: [usize; 2],\n    ceil_mode: bool,\n) -> (Tensor<B, 4>, Tensor<B, 4, Int>)\nwhere\n    B: Backend,\n{\n    let output = B::max_pool2d_with_indices(\n        x.primitive.tensor(),\n        kernel_size,\n        stride,\n        padding,\n        dilation,\n        ceil_mode,\n    );\n\n    (\n        Tensor::new(TensorPrimitive::Float(output.output)),\n        Tensor::new(output.indices),\n    )\n}\n\n/// Applies a [2D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool2d).\npub fn adaptive_avg_pool2d<B>(x: Tensor<B, 4>, output_size: [usize; 2]) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool2d(\n        x.primitive.tensor(),\n        output_size,\n    )))\n}\n\n/// Applies a [1D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool1d).\npub fn adaptive_avg_pool1d<B>(x: Tensor<B, 3>, output_size: usize) -> Tensor<B, 3>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool1d(\n        x.primitive.tensor(),\n        output_size,\n    )))\n}\n\n/// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate).\npub fn interpolate<B>(\n    x: Tensor<B, 4>,\n    output_size: [usize; 2],\n    options: InterpolateOptions,\n) -> Tensor<B, 4>\nwhere\n    B: Backend,\n{\n    Tensor::new(TensorPrimitive::Float(B::interpolate(\n        x.primitive.tensor(),\n        output_size,\n        options,\n    )))\n}\n\n/// Applies a linear transformation to the input tensor using the given weight and bias.\n///\n/// ```math\n/// y = x @ weight + [bias]\n/// ```\n///\n/// # Arguments:\n///\n/// - `input` is the input tensor, ``[..., d_input]``.\n/// - `weight` is the weight tensor, ``[d_input, d_output]``.\n/// - `bias` is the bias tensor (optional), ``[d_output]``.\n///\n/// # Returns:\n///\n/// The transformed tensor, ``[..., d_output]``.\n///\n/// # Compatibility\n///\n/// This function differs from PyTorch's ``torch.nn.functional.linear`` in that it does not\n/// transpose the weight matrix. In PyTorch, the weight matrix is transposed before\n/// multiplication:\n///\n/// ```math\n/// y = x @ weight^T + [bias]\n/// ```\npub fn linear<B: Backend, const D: usize>(\n    input: Tensor<B, D>,\n    weight: Tensor<B, 2>,\n    bias: Option<Tensor<B, 1>>,\n) -> Tensor<B, D> {\n    if D == 1 {\n        // Insert and remove an extra batch dimension for the batch matmul to work.\n        let input = input.unsqueeze::<2>();\n        let output = linear(input, weight, bias);\n        return output.squeeze_dim(0);\n    }\n\n    // Perform broadcasting\n    //\n    // Important to be done before doing operations to easily fuse.\n    let weight = weight.unsqueeze::<D>();\n    let bias = bias.map(|bias| bias.unsqueeze::<D>());\n\n    let output = input.matmul(weight);\n    match bias {\n        Some(bias) => output.add(bias),\n        None => output,\n    }\n}\n\n/// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,\n/// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`).\n/// Optionally applies masking, additive bias, causal masking, and softcap.\n///\n/// # Arguments\n/// - `query`: Query tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`\n/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`\n/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`\n/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,\n///   where `true` indicates positions to mask (i.e. set to -inf before softmax).\n/// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`\n///   added to the attention scores before softmax (e.g. ALiBi, relative position biases).\n/// - `options`: Additional attention options (custom scale, softcap, causal masking).\n///\n/// # Returns\n/// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`\n/// representing the attended context per head.\n///\n/// # Note\n/// This implementation does not support dropout and is intended for inference or\n/// use cases where dropout is not needed.\npub fn attention<B: Backend>(\n    query: Tensor<B, 4>,\n    key: Tensor<B, 4>,\n    value: Tensor<B, 4>,\n    mask: Option<Tensor<B, 4, Bool>>,\n    attn_bias: Option<Tensor<B, 4>>,\n    options: AttentionModuleOptions,\n) -> Tensor<B, 4> {\n    Tensor::new(TensorPrimitive::Float(B::attention(\n        query.primitive.tensor(),\n        key.primitive.tensor(),\n        value.primitive.tensor(),\n        mask.map(|mask| mask.primitive),\n        attn_bias.map(|bias| bias.primitive.tensor()),\n        options,\n    )))\n}\n\n/// Exports attention fallback to test backend's attention against.\npub fn attention_fallback<B: Backend>(\n    query: Tensor<B, 4>,\n    key: Tensor<B, 4>,\n    value: Tensor<B, 4>,\n    mask: Option<Tensor<B, 4, Bool>>,\n    attn_bias: Option<Tensor<B, 4>>,\n    options: AttentionModuleOptions,\n) -> Tensor<B, 4> {\n    Tensor::new(TensorPrimitive::Float(\n        crate::ops::attention::attention_fallback::<B>(\n            query.primitive.tensor(),\n            key.primitive.tensor(),\n            value.primitive.tensor(),\n            mask.map(|mask| mask.primitive),\n            attn_bias.map(|bias| bias.primitive.tensor()),\n            options,\n        ),\n    ))\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/quantization.rs",
    "content": "use crate::{Tensor, TensorPrimitive, backend::Backend};\nuse burn_backend::tensor::quantization;\n\n// We re-export those types.\npub use burn_backend::{QTensorPrimitive, quantization::*};\n\n/// The tensor quantization parameters.\npub type QuantizationParameters<B> = QParams<Tensor<B, 1>>;\n\n/// The observed input calibration range.\n#[derive(Clone, Debug)]\npub struct CalibrationRange<B: Backend> {\n    /// Minimum observed value(s).\n    pub min: Tensor<B, 1>,\n    /// Maximum observed value(s).\n    pub max: Tensor<B, 1>,\n}\n\n/// Compute the quantization range mapping.\npub fn compute_range<B: Backend, const D: usize>(\n    scheme: &QuantScheme,\n    tensor: &Tensor<B, D>,\n    calibration: &Calibration,\n) -> CalibrationRange<B> {\n    let (min, max) = match &tensor.primitive {\n        TensorPrimitive::Float(tensor) => {\n            quantization::compute_range::<B>(scheme, tensor.clone(), calibration)\n        }\n        TensorPrimitive::QFloat(_) => unreachable!(),\n    };\n\n    CalibrationRange {\n        min: Tensor::from_primitive(TensorPrimitive::Float(min)),\n        max: Tensor::from_primitive(TensorPrimitive::Float(max)),\n    }\n}\n\n/// Compute the quantization parameters.\npub fn compute_q_params<B: Backend>(\n    scheme: &QuantScheme,\n    range: CalibrationRange<B>,\n) -> QuantizationParameters<B> {\n    match (range.min.primitive, range.max.primitive) {\n        (TensorPrimitive::Float(min), TensorPrimitive::Float(max)) => {\n            let qparams = quantization::compute_q_params::<B>(scheme, min, max);\n            QuantizationParameters {\n                scales: Tensor::from_primitive(TensorPrimitive::Float(qparams.scales)),\n            }\n        }\n        _ => unreachable!(),\n    }\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/report.rs",
    "content": "use super::{Tensor, backend::Backend};\n\nuse colored::*;\n\n/// Checks the closeness of two tensors and prints the results.\n///\n/// Compares tensors by checking the absolute difference between each element.\n/// Prints the percentage of elements within specified tolerances.\n///\n/// # Arguments\n///\n/// * `output` - The output tensor.\n/// * `expected` - The expected tensor.\n///\n/// # Example\n///\n/// ```no_run\n/// use burn_tensor::backend::Backend;\n/// use burn_tensor::{check_closeness, Tensor};\n///\n/// fn example<B: Backend>() {\n///     let device = Default::default();\n///     let tensor1 = Tensor::<B, 1>::from_floats(\n///         [1.0, 2.0, 3.0, 4.0, 5.0, 6.001, 7.002, 8.003, 9.004, 10.1],\n///         &device,\n///     );\n///     let tensor2 = Tensor::<B, 1>::from_floats(\n///         [1.0, 2.0, 3.0, 4.000, 5.0, 6.0, 7.001, 8.002, 9.003, 10.004],\n///         &device,\n///     );\n///    check_closeness(&tensor1, &tensor2);\n///}\n/// ```\n///\n/// # Output\n///\n/// ```text\n/// Tensor Closeness Check Results:\n/// ===============================\n/// Epsilon: 1e-1\n///   Close elements: 10/10 (100.00%)\n///   [PASS] All elements are within tolerance\n///\n/// Epsilon: 1e-2\n///   Close elements: 10/10 (100.00%)\n///   [PASS] All elements are within tolerance\n///\n/// Epsilon: 1e-3\n///   Close elements: 9/10 (90.00%)\n///   [WARN] Most elements are within tolerance\n///\n/// Epsilon: 1e-4\n///   Close elements: 6/10 (60.00%)\n///   [FAIL] Significant differences detected\n///\n/// Epsilon: 1e-5\n///   Close elements: 5/10 (50.00%)\n///   [FAIL] Significant differences detected\n///\n/// Epsilon: 1e-6\n///   Close elements: 5/10 (50.00%)\n///   [FAIL] Significant differences detected\n///\n/// Epsilon: 1e-7\n///   Close elements: 5/10 (50.00%)\n///   [FAIL] Significant differences detected\n///\n/// Epsilon: 1e-8\n///   Close elements: 5/10 (50.00%)\n///   [FAIL] Significant differences detected\n///\n/// Closeness check complete.\n/// ```\npub fn check_closeness<B: Backend, const D: usize>(output: &Tensor<B, D>, expected: &Tensor<B, D>) {\n    println!(\"{}\", \"Tensor Closeness Check Results:\".bold());\n    println!(\"===============================\");\n\n    for epsilon in [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8].iter() {\n        println!(\"{} {:e}\", \"Epsilon:\".bold(), epsilon);\n\n        let close = output\n            .clone()\n            .is_close(expected.clone(), Some(*epsilon), Some(*epsilon));\n        let data = close.clone().into_data();\n        let num_elements = data.num_elements();\n\n        // Count the number of elements that are close (true)\n        let count = data.iter::<bool>().filter(|x| *x).count();\n\n        let percentage = (count as f64 / num_elements as f64) * 100.0;\n\n        println!(\"  Close elements: {count}/{num_elements} ({percentage:.2}%)\");\n\n        if percentage == 100.0 {\n            println!(\"  {} All elements are within tolerance\", \"[PASS]\".green());\n        } else if percentage >= 90.0 {\n            println!(\"  {} Most elements are within tolerance\", \"[WARN]\".yellow());\n        } else {\n            println!(\"  {} Significant differences detected\", \"[FAIL]\".red());\n        }\n\n        println!();\n    }\n\n    println!(\"{}\", \"Closeness check complete.\".bold());\n}\n"
  },
  {
    "path": "crates/burn-tensor/src/tensor/stats/mod.rs",
    "content": "use crate::{Tensor, backend::Backend};\nuse burn_backend::tensor::Int;\n\npub fn var<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    let mean = tensor.clone().mean_dim(dim);\n    var_with_mean(tensor, mean, dim)\n}\n\npub fn var_with_mean<B: Backend, const D: usize>(\n    tensor: Tensor<B, D>,\n    mean: Tensor<B, D>,\n    dim: usize,\n) -> Tensor<B, D> {\n    let n = tensor.shape()[dim] - 1;\n    var_with_mean_n(tensor, mean, dim, n)\n}\n\npub fn var_bias<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    let mean = tensor.clone().mean_dim(dim);\n    var_with_mean_bias(tensor, mean, dim)\n}\n\npub fn var_with_mean_bias<B: Backend, const D: usize>(\n    tensor: Tensor<B, D>,\n    mean: Tensor<B, D>,\n    dim: usize,\n) -> Tensor<B, D> {\n    let n = tensor.shape()[dim];\n    var_with_mean_n(tensor, mean, dim, n)\n}\n\npub fn var_with_mean_n<B: Backend, const D: usize>(\n    tensor: Tensor<B, D>,\n    mean: Tensor<B, D>,\n    dim: usize,\n    n: usize,\n) -> Tensor<B, D> {\n    tensor.sub(mean).square().sum_dim(dim).div_scalar(n as f32)\n}\n\npub fn median<B: Backend, const D: usize>(tensor: Tensor<B, D>, dim: usize) -> Tensor<B, D> {\n    let total_elem_numbers = tensor.dims()[dim];\n    let sorted_tensor = tensor.sort(dim);\n\n    // Following the PyTorch behavior:\n    // - Odd count: the median\n    // - Even count: the lower of the two median elements\n    //\n    // Example:\n    // - 5 elements: (5 - 1) / 2 = 4 / 2 = 2\n    // - 4 elements: (4 - 1) / 2 = 3 / 2 = 1\n    let median_index = (total_elem_numbers - 1) / 2;\n    sorted_tensor.narrow(dim, median_index, 1)\n}\n\npub fn median_with_indices<B: Backend, const D: usize>(\n    tensor: Tensor<B, D>,\n    dim: usize,\n) -> (Tensor<B, D>, Tensor<B, D, Int>) {\n    let total_elem_numbers = tensor.dims()[dim];\n    let (sorted_tensor, indices) = tensor.sort_with_indices(dim);\n\n    // Following the PyTorch behavior:\n    // - Odd count: the median\n    // - Even count: the lower of the two median elements\n    //\n    // Example:\n    // - 5 elements: (5 - 1) / 2 = 4 / 2 = 2\n    // - 4 elements: (4 - 1) / 2 = 3 / 2 = 1\n    let median_index = (total_elem_numbers - 1) / 2;\n    let median_values = sorted_tensor.narrow(dim, median_index, 1);\n    let median_indices = indices.narrow(dim, median_index, 1);\n    (median_values, median_indices)\n}\n"
  },
  {
    "path": "crates/burn-tensor-testgen/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ndescription = \"Test generation crate for burn-tensor\"\nedition.workspace = true\nlicense.workspace = true\nname = \"burn-tensor-testgen\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-tensor-testgen\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[lib]\nproc-macro = true\n\n[dependencies]\nproc-macro2 = { workspace = true }\nquote = { workspace = true }\nsyn = { workspace = true }\n"
  },
  {
    "path": "crates/burn-tensor-testgen/README.md",
    "content": "# Burn Tensor Test Generation\n\n> [Burn](https://github.com/tracel-ai/burn) tensor test generation\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tensor-testgen.svg)](https://crates.io/crates/burn-tensor-testgen)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-tensor-testgen/blob/master/README.md)\n"
  },
  {
    "path": "crates/burn-tensor-testgen/src/lib.rs",
    "content": "use proc_macro::TokenStream;\nuse quote::{format_ident, quote};\n\nuse syn::parse::{Parse, ParseStream};\nuse syn::punctuated::Punctuated;\nuse syn::token::Comma;\nuse syn::{Attribute, Expr, ItemFn, Lit, Meta, MetaNameValue, parse_macro_input};\n\n// Define a structure to parse the attribute arguments\nstruct AttributeArgs {\n    args: Punctuated<Meta, Comma>,\n}\n\nimpl Parse for AttributeArgs {\n    fn parse(input: ParseStream) -> syn::Result<Self> {\n        Ok(AttributeArgs {\n            args: Punctuated::parse_terminated(input)?,\n        })\n    }\n}\n\n#[allow(clippy::test_attr_in_doctest)]\n/// **This is only meaningful when the `reason` is specific and clear.**\n///\n/// A proc macro attribute that adds panic handling to test functions.\n///\n/// # Usage\n/// ```rust, ignore\n/// #[might_panic(reason = \"expected panic message prefix\")]\n/// #[test]\n/// fn test_that_might_panic() {\n///     // test code that might panic (with acceptable reason)\n/// }\n/// ```\n///\n/// # Behavior\n/// - If the test does not panic, it passes.\n/// - If the test panics with a message starting with the expected prefix, the failure is ignored.\n/// - If the test panics with a different message, the test fails.\n///\n/// # Note\n/// This proc macro uses [`std::panic::catch_unwind`]. As such, it does not work in a no-std environment.\n/// Make sure it is feature gated when an `\"std\"` feature is available.\n#[proc_macro_attribute]\npub fn might_panic(args: TokenStream, input: TokenStream) -> TokenStream {\n    // Parse the attribute arguments\n    let args = parse_macro_input!(args as AttributeArgs);\n    let input_fn = parse_macro_input!(input as ItemFn);\n\n    // Extract the expected panic reason\n    let mut expected_reason = None;\n    for arg in args.args.iter() {\n        if let Meta::NameValue(MetaNameValue { path, value, .. }) = arg\n            && path.is_ident(\"reason\")\n            && let Expr::Lit(lit) = value\n            && let Lit::Str(ref lit_str) = lit.lit\n        {\n            expected_reason = Some(lit_str.value());\n        }\n    }\n\n    let expected_reason = match expected_reason {\n        Some(reason) => reason,\n        None => {\n            return syn::Error::new(\n                proc_macro2::Span::call_site(),\n                \"The #[might_panic] attribute requires a 'reason' parameter\",\n            )\n            .to_compile_error()\n            .into();\n        }\n    };\n\n    let fn_name = &input_fn.sig.ident;\n    let fn_vis = &input_fn.vis;\n    let fn_generics = &input_fn.sig.generics;\n    let fn_block = &input_fn.block;\n    let fn_attrs = input_fn\n        .attrs\n        .iter()\n        .filter(|attr| !attr.path().is_ident(\"test\"))\n        .collect::<Vec<&Attribute>>();\n\n    // Create a wrapped test function\n    let wrapper_name = format_ident!(\"{}_might_panic\", fn_name);\n\n    quote! {\n        #(#fn_attrs)*\n        #fn_vis fn #fn_name #fn_generics() { #fn_block }\n\n        #[test]\n        #fn_vis fn #wrapper_name #fn_generics() {\n            use std::panic::{self, AssertUnwindSafe};\n            use std::sync::{Arc, Mutex, OnceLock};\n\n            let get_msg = |p: &(dyn std::any::Any + Send)| -> String {\n                p.downcast_ref::<String>().cloned()\n                    .or_else(|| p.downcast_ref::<&str>().map(|s| s.to_string()))\n                    .unwrap_or_else(|| \"Unknown panic\".to_string())\n            };\n\n            // An append-only list of all panic messages across the entire process.\n            // This is required because cubecl's `CallError` hides the original panic message\n            // occurring in the device threads.\n            //\n            // A global log also prevents parallel tests from overwriting each other's panic hooks.\n            static PANIC_LOG: OnceLock<Mutex<Vec<String>>> = OnceLock::new();\n            let log = PANIC_LOG.get_or_init(|| Mutex::new(Vec::new()));\n\n            static HOOK: OnceLock<()> = OnceLock::new();\n            HOOK.get_or_init(|| {\n                let prev = panic::take_hook();\n                panic::set_hook(Box::new(move |info| {\n                    if let Ok(mut v) = log.lock() {\n                        v.push(get_msg(info.payload()));\n                    }\n                    prev(info);\n                }));\n            });\n\n            // We only care about panics that occur during this test's execution window, so\n            // we start at the number of panics logged before this test starts.\n            let start_idx = log.lock().unwrap().len();\n            let result = panic::catch_unwind(AssertUnwindSafe(|| #fn_name()));\n\n            if let Err(e) = result {\n                let main_msg = get_msg(&*e);\n                let panic_logs = log.lock().unwrap();\n                let window = &panic_logs[start_idx..];\n\n                let matched = window.iter().chain(std::iter::once(&main_msg))\n                    .any(|m| m.contains(#expected_reason));\n\n                if !matched {\n                    let all = window.iter().chain(std::iter::once(&main_msg))\n                        .map(|m| format!(\"- {m}\")).collect::<Vec<_>>().join(\"\\n\");\n                    panic!(\"\\nTest '{}' failed.\\nExpected: '{}'\\nFound:\\n{}\\n\",\n                           stringify!(#fn_name), #expected_reason, all);\n                } else {\n                    let all = window.iter().chain(std::iter::once(&main_msg))\n                        .map(|m| format!(\"- {m}\")).collect::<Vec<_>>().join(\"\\n\");\n                    println!(\"\\nTest '{}' failed.\\nExpected: '{}'\\nFound:\\n{}\\n\",\n                           stringify!(#fn_name), #expected_reason, all);\n                }\n            }\n        }\n    }\n    .into()\n}\n"
  },
  {
    "path": "crates/burn-train/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"Training crate for the Burn framework\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"tensor\", \"pytorch\", \"ndarray\"]\nlicense.workspace = true\nname = \"burn-train\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-train\"\ndocumentation = \"https://docs.rs/burn-train\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"sys-metrics\", \"tui\", \"rl\"]\ndoc = [\"default\"]\nvision = [\"burn-nn\", \"burn-store/pytorch\", \"burn-std/network\", \"dirs\"]\ntracing = [\n    \"burn-core/tracing\",\n    \"burn-optim/tracing\",\n    \"burn-collective?/tracing\",\n]\n\n\nsys-metrics = [\"nvml-wrapper\", \"sysinfo\", \"systemstat\"]\ntui = [\"ratatui\"]\nrl = [\"burn-rl\"]\n# Distributed Data Parallel\nddp = [\"burn-collective\", \"burn-optim/collective\"]\n\n[dependencies]\nburn-core = { path = \"../burn-core\", version = \"=0.21.0-pre.2\", features = [\n    \"dataset\",\n    \"std\",\n], default-features = false }\nburn-optim = { path = \"../burn-optim\", version = \"=0.21.0-pre.2\", features = [\n    \"std\",\n], default-features = false }\nburn-rl = { path = \"../burn-rl\", version = \"=0.21.0-pre.2\", optional = true, default-features = false }\nburn-collective = { path = \"../burn-collective\", version = \"=0.21.0-pre.2\", optional = true }\nburn-nn = { path = \"../burn-nn\", version = \"=0.21.0-pre.2\", optional = true, default-features = false, features = [\"std\"] }\nburn-store = { path = \"../burn-store\", version = \"=0.21.0-pre.2\", optional = true, default-features = false, features = [\"std\"] }\nburn-std = { path = \"../burn-std\", version = \"=0.21.0-pre.2\", optional = true, default-features = false, features = [\"std\"] }\ndirs = { workspace = true, optional = true }\n\nlog = { workspace = true }\ntracing-subscriber = { workspace = true }\ntracing-appender = { workspace = true }\ntracing-core = { workspace = true }\n\n# System Metrics\nnvml-wrapper = { workspace = true, optional = true }\nsysinfo = { workspace = true, optional = true }\nsystemstat = { workspace = true, optional = true }\n\n# Text UI\nratatui = { workspace = true, optional = true, features = [\n    \"all-widgets\",\n    \"crossterm\",\n] }\n\n# Utilities\nderive-new = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\nasync-channel = { workspace = true }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nrstest.workspace = true\nthiserror.workspace = true\nrand.workspace = true\n\n[dev-dependencies]\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-autodiff = { path = \"../burn-autodiff\", version = \"=0.21.0-pre.2\" }\n\n[package.metadata.docs.rs]\nfeatures = [\"doc\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-train/README.md",
    "content": "# Burn Train\n\nThis crate should be used with [burn](https://github.com/tracel-ai/burn).\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-train.svg)](https://crates.io/crates/burn-train)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-train/blob/master/README.md)\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/async_checkpoint.rs",
    "content": "use super::{Checkpointer, CheckpointerError};\nuse crate::Interrupter;\nuse burn_core::{record::Record, tensor::backend::Backend};\nuse std::sync::mpsc;\n\nenum Message<R, B: Backend> {\n    Restore(\n        usize,\n        B::Device,\n        mpsc::SyncSender<Result<R, CheckpointerError>>,\n        Option<Interrupter>,\n    ),\n    Save(usize, R, Option<Interrupter>),\n    Delete(usize, Option<Interrupter>),\n    End,\n}\n\n#[derive(new)]\nstruct CheckpointerThread<C, R, B: Backend> {\n    checkpointer: C,\n    receiver: mpsc::Receiver<Message<R, B>>,\n}\n\nimpl<C, R, B> CheckpointerThread<C, R, B>\nwhere\n    C: Checkpointer<R, B>,\n    R: Record<B>,\n    B: Backend,\n{\n    fn run(self) {\n        for item in self.receiver.iter() {\n            match item {\n                Message::Restore(epoch, device, callback, interrupter) => {\n                    let record = self.checkpointer.restore(epoch, &device);\n                    callback.send(record).unwrap_or_else(|err| {\n                        interrupter.map_or_else(\n                            || {\n                                panic!(\n                                    \"Error when sending response through callback channel: {err}\"\n                                )\n                            },\n                            |int| int.stop(Some(&err.to_string())),\n                        )\n                    });\n                }\n                Message::Save(epoch, state, interrupter) => {\n                    self.checkpointer.save(epoch, state).unwrap_or_else(|err| {\n                        interrupter.map_or_else(\n                            || panic!(\"Error when saving the state: {err}\"),\n                            |int| int.stop(Some(&err.to_string())),\n                        )\n                    });\n                }\n                Message::Delete(epoch, interrupter) => {\n                    self.checkpointer.delete(epoch).unwrap_or_else(|err| {\n                        interrupter.map_or_else(\n                            || panic!(\"Error when deleting the state: {err}\"),\n                            |int| int.stop(Some(&err.to_string())),\n                        )\n                    });\n                }\n\n                Message::End => {\n                    return;\n                }\n            };\n        }\n    }\n}\n\n/// Async checkpointer.\npub struct AsyncCheckpointer<Record, B: Backend> {\n    sender: mpsc::SyncSender<Message<Record, B>>,\n    handler: Option<std::thread::JoinHandle<()>>,\n    interrupter: Option<Interrupter>,\n}\n\nimpl<R, B> AsyncCheckpointer<R, B>\nwhere\n    R: Record<B> + 'static,\n    B: Backend,\n{\n    /// Create a new async checkpointer.\n    ///\n    /// # Arguments\n    ///\n    /// * `checkpointer` - The checkpointer.\n    ///\n    /// # Returns\n    ///\n    /// The async checkpointer.\n    pub fn new<C>(checkpointer: C) -> Self\n    where\n        C: Checkpointer<R, B> + Send + 'static,\n    {\n        // Only on checkpoint can be done in advance.\n        let (sender, receiver) = mpsc::sync_channel(0);\n        let thread = CheckpointerThread::new(checkpointer, receiver);\n        let handler = Some(std::thread::spawn(move || thread.run()));\n\n        Self {\n            sender,\n            handler,\n            interrupter: None,\n        }\n    }\n\n    /// Assign a handle used to interrupt training in case of checkpointing error.\n    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {\n        self.interrupter = Some(interrupter);\n        self\n    }\n}\n\nimpl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>\nwhere\n    R: Record<B> + 'static,\n    B: Backend,\n{\n    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {\n        self.sender\n            .send(Message::Save(epoch, record, self.interrupter.clone()))\n            .expect(\"Can send message to checkpointer thread.\");\n\n        Ok(())\n    }\n\n    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {\n        let (sender, receiver) = mpsc::sync_channel(1);\n        self.sender\n            .send(Message::Restore(\n                epoch,\n                device.clone(),\n                sender,\n                self.interrupter.clone(),\n            ))\n            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;\n\n        if let Ok(record) = receiver.recv() {\n            return record;\n        };\n\n        Err(CheckpointerError::Unknown(\"Channel error.\".to_string()))\n    }\n\n    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {\n        self.sender\n            .send(Message::Delete(epoch, self.interrupter.clone()))\n            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;\n\n        Ok(())\n    }\n}\n\nimpl<E, B> Drop for AsyncCheckpointer<E, B>\nwhere\n    B: Backend,\n{\n    fn drop(&mut self) {\n        self.sender\n            .send(Message::End)\n            .expect(\"Can send the end message to the checkpointer thread.\");\n        let handler = self.handler.take();\n\n        if let Some(handler) = handler {\n            handler\n                .join()\n                .expect(\"The checkpointer thread should stop.\");\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/base.rs",
    "content": "use burn_core::{\n    record::{Record, RecorderError},\n    tensor::backend::Backend,\n};\nuse thiserror::Error;\n\n/// The error type for checkpointer.\n#[derive(Error, Debug)]\npub enum CheckpointerError {\n    /// IO error.\n    #[error(\"I/O Error: `{0}`\")]\n    IOError(std::io::Error),\n\n    /// Recorder error.\n    #[error(\"Recorder error: `{0}`\")]\n    RecorderError(RecorderError),\n\n    /// Other errors.\n    #[error(\"Unknown error: `{0}`\")]\n    Unknown(String),\n}\n\n/// The trait for checkpointer.\npub trait Checkpointer<R, B>: Send + Sync\nwhere\n    R: Record<B>,\n    B: Backend,\n{\n    /// Save the record.\n    ///\n    /// # Arguments\n    ///\n    /// * `epoch` - The epoch.\n    /// * `record` - The record.\n    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;\n\n    /// Delete the record at the given epoch if present.\n    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;\n\n    /// Restore the record.\n    ///\n    /// # Arguments\n    ///\n    /// * `epoch` - The epoch.\n    /// * `device` - The device used to restore the record.\n    ///\n    /// # Returns\n    ///\n    /// The record.\n    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/file.rs",
    "content": "use std::path::{Path, PathBuf};\n\nuse super::{Checkpointer, CheckpointerError};\nuse burn_core::{\n    record::{FileRecorder, Record},\n    tensor::backend::Backend,\n};\n\n/// The file checkpointer.\npub struct FileCheckpointer<FR> {\n    directory: PathBuf,\n    name: String,\n    recorder: FR,\n}\n\nimpl<FR> FileCheckpointer<FR> {\n    /// Creates a new file checkpointer.\n    ///\n    /// # Arguments\n    ///\n    /// * `recorder` - The file recorder.\n    /// * `directory` - The directory to save the checkpoints.\n    /// * `name` - The name of the checkpoint.\n    pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {\n        let directory = directory.as_ref();\n        std::fs::create_dir_all(directory).ok();\n\n        Self {\n            directory: directory.to_path_buf(),\n            name: name.to_string(),\n            recorder,\n        }\n    }\n\n    fn path_for_epoch(&self, epoch: usize) -> PathBuf {\n        self.directory.join(format!(\"{}-{}\", self.name, epoch))\n    }\n}\n\nimpl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>\nwhere\n    R: Record<B>,\n    FR: FileRecorder<B>,\n    B: Backend,\n{\n    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {\n        let file_path = self.path_for_epoch(epoch);\n        log::trace!(\"Saving checkpoint {} to {}\", epoch, file_path.display());\n\n        self.recorder\n            .record(record, file_path)\n            .map_err(CheckpointerError::RecorderError)?;\n\n        Ok(())\n    }\n\n    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {\n        let file_path = self.path_for_epoch(epoch);\n        log::info!(\n            \"Restoring checkpoint {} from {}\",\n            epoch,\n            file_path.display()\n        );\n        let record = self\n            .recorder\n            .load(file_path, device)\n            .map_err(CheckpointerError::RecorderError)?;\n\n        Ok(record)\n    }\n\n    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {\n        let file_to_remove = format!(\n            \"{}.{}\",\n            self.path_for_epoch(epoch).display(),\n            FR::file_extension(),\n        );\n\n        if std::path::Path::new(&file_to_remove).exists() {\n            log::trace!(\"Removing checkpoint {file_to_remove}\");\n            std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?;\n        }\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/mod.rs",
    "content": "mod async_checkpoint;\nmod base;\nmod file;\nmod strategy;\n\npub use async_checkpoint::*;\npub use base::*;\npub use file::*;\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/strategy/base.rs",
    "content": "use std::ops::DerefMut;\n\nuse crate::metric::store::EventStoreClient;\n\n/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).\n#[derive(Clone, PartialEq, Debug)]\npub enum CheckpointingAction {\n    /// Delete the given epoch.\n    Delete(usize),\n    /// Save the current record.\n    Save,\n}\n\n/// Define when checkpoint should be saved and deleted.\npub trait CheckpointingStrategy: Send {\n    /// Based on the epoch, determine if the checkpoint should be saved.\n    fn checkpointing(\n        &mut self,\n        epoch: usize,\n        collector: &EventStoreClient,\n    ) -> Vec<CheckpointingAction>;\n}\n\n// We make dyn box implement the checkpointing strategy so that it can be used with generic, but\n// still be dynamic.\nimpl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {\n    fn checkpointing(\n        &mut self,\n        epoch: usize,\n        collector: &EventStoreClient,\n    ) -> Vec<CheckpointingAction> {\n        self.deref_mut().checkpointing(epoch, collector)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/strategy/composed.rs",
    "content": "use crate::metric::store::EventStoreClient;\n\nuse super::{CheckpointingAction, CheckpointingStrategy};\nuse std::collections::HashSet;\n\n/// Compose multiple checkpointing strategy and only delete checkpoints when both strategy flag an\n/// epoch to be deleted.\npub struct ComposedCheckpointingStrategy {\n    strategies: Vec<Box<dyn CheckpointingStrategy>>,\n    deleted: Vec<HashSet<usize>>,\n}\n\n/// Help building a [checkpointing strategy](CheckpointingStrategy) by combining multiple ones.\n#[derive(Default)]\npub struct ComposedCheckpointingStrategyBuilder {\n    strategies: Vec<Box<dyn CheckpointingStrategy>>,\n}\n\nimpl ComposedCheckpointingStrategyBuilder {\n    /// Add a new [checkpointing strategy](CheckpointingStrategy).\n    #[allow(clippy::should_implement_trait)]\n    pub fn add<S>(mut self, strategy: S) -> Self\n    where\n        S: CheckpointingStrategy + 'static,\n    {\n        self.strategies.push(Box::new(strategy));\n        self\n    }\n\n    /// Create a new [composed checkpointing strategy](ComposedCheckpointingStrategy).\n    pub fn build(self) -> ComposedCheckpointingStrategy {\n        ComposedCheckpointingStrategy::new(self.strategies)\n    }\n}\n\nimpl ComposedCheckpointingStrategy {\n    fn new(strategies: Vec<Box<dyn CheckpointingStrategy>>) -> Self {\n        Self {\n            deleted: strategies.iter().map(|_| HashSet::new()).collect(),\n            strategies,\n        }\n    }\n    /// Create a new builder which help compose multiple\n    /// [checkpointing strategies](CheckpointingStrategy).\n    pub fn builder() -> ComposedCheckpointingStrategyBuilder {\n        ComposedCheckpointingStrategyBuilder::default()\n    }\n}\n\nimpl CheckpointingStrategy for ComposedCheckpointingStrategy {\n    fn checkpointing(\n        &mut self,\n        epoch: usize,\n        collector: &EventStoreClient,\n    ) -> Vec<CheckpointingAction> {\n        let mut saved = false;\n        let mut actions = Vec::new();\n        let mut epochs_to_check = Vec::new();\n\n        for (i, strategy) in self.strategies.iter_mut().enumerate() {\n            let actions = strategy.checkpointing(epoch, collector);\n            // We assume that the strategy would not want the current epoch to be saved.\n            // So we flag it as deleted.\n            if actions.is_empty() {\n                self.deleted\n                    .get_mut(i)\n                    .expect(\"As many 'deleted' as 'strategies'.\")\n                    .insert(epoch);\n            }\n\n            for action in actions {\n                match action {\n                    CheckpointingAction::Delete(epoch) => {\n                        self.deleted\n                            .get_mut(i)\n                            .expect(\"As many 'deleted' as 'strategies'.\")\n                            .insert(epoch);\n                        epochs_to_check.push(epoch);\n                    }\n                    CheckpointingAction::Save => saved = true,\n                }\n            }\n        }\n\n        if saved {\n            actions.push(CheckpointingAction::Save);\n        }\n\n        for epoch in epochs_to_check.into_iter() {\n            let mut num_true = 0;\n            for i in 0..self.strategies.len() {\n                if self\n                    .deleted\n                    .get(i)\n                    .expect(\"Ad many 'deleted' as 'strategies'.\")\n                    .contains(&epoch)\n                {\n                    num_true += 1;\n                }\n            }\n\n            if num_true == self.strategies.len() {\n                actions.push(CheckpointingAction::Delete(epoch));\n\n                for i in 0..self.strategies.len() {\n                    self.deleted\n                        .get_mut(i)\n                        .expect(\"As many 'deleted' as 'strategies'.\")\n                        .remove(&epoch);\n                }\n            }\n        }\n\n        actions\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{checkpoint::KeepLastNCheckpoints, metric::store::LogEventStore};\n\n    #[test]\n    fn should_delete_when_both_deletes() {\n        let store = EventStoreClient::new(LogEventStore::default());\n        let mut strategy = ComposedCheckpointingStrategy::builder()\n            .add(KeepLastNCheckpoints::new(1))\n            .add(KeepLastNCheckpoints::new(2))\n            .build();\n\n        assert_eq!(\n            vec![CheckpointingAction::Save],\n            strategy.checkpointing(1, &store)\n        );\n\n        assert_eq!(\n            vec![CheckpointingAction::Save],\n            strategy.checkpointing(2, &store)\n        );\n\n        assert_eq!(\n            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],\n            strategy.checkpointing(3, &store)\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/strategy/lastn.rs",
    "content": "use super::CheckpointingStrategy;\nuse crate::{checkpoint::CheckpointingAction, metric::store::EventStoreClient};\n\n/// Keep the last N checkpoints.\n///\n/// Very useful when training, minimizing disk space while ensuring that the training can be\n/// resumed even if something goes wrong.\n#[derive(new)]\npub struct KeepLastNCheckpoints {\n    num_keep: usize,\n}\n\nimpl CheckpointingStrategy for KeepLastNCheckpoints {\n    fn checkpointing(\n        &mut self,\n        epoch: usize,\n        _store: &EventStoreClient,\n    ) -> Vec<CheckpointingAction> {\n        let mut actions = vec![CheckpointingAction::Save];\n\n        if let Some(epoch) = usize::checked_sub(epoch, self.num_keep)\n            && epoch > 0\n        {\n            actions.push(CheckpointingAction::Delete(epoch));\n        }\n\n        actions\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::metric::store::LogEventStore;\n\n    #[test]\n    fn should_always_delete_lastn_epoch_if_higher_than_one() {\n        let mut strategy = KeepLastNCheckpoints::new(2);\n        let store = EventStoreClient::new(LogEventStore::default());\n\n        assert_eq!(\n            vec![CheckpointingAction::Save],\n            strategy.checkpointing(1, &store)\n        );\n\n        assert_eq!(\n            vec![CheckpointingAction::Save],\n            strategy.checkpointing(2, &store)\n        );\n\n        assert_eq!(\n            vec![CheckpointingAction::Save, CheckpointingAction::Delete(1)],\n            strategy.checkpointing(3, &store)\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/strategy/metric.rs",
    "content": "use super::CheckpointingStrategy;\nuse crate::{\n    checkpoint::CheckpointingAction,\n    metric::{\n        Metric, MetricName,\n        store::{Aggregate, Direction, EventStoreClient, Split},\n    },\n};\n\n/// Keep the best checkpoint based on a metric.\npub struct MetricCheckpointingStrategy {\n    current: Option<usize>,\n    aggregate: Aggregate,\n    direction: Direction,\n    split: Split,\n    name: MetricName,\n}\n\nimpl MetricCheckpointingStrategy {\n    /// Create a new metric checkpointing strategy.\n    pub fn new<M>(metric: &M, aggregate: Aggregate, direction: Direction, split: Split) -> Self\n    where\n        M: Metric,\n    {\n        Self {\n            current: None,\n            name: metric.name(),\n            aggregate,\n            direction,\n            split,\n        }\n    }\n}\n\nimpl CheckpointingStrategy for MetricCheckpointingStrategy {\n    fn checkpointing(\n        &mut self,\n        epoch: usize,\n        store: &EventStoreClient,\n    ) -> Vec<CheckpointingAction> {\n        let best_epoch =\n            match store.find_epoch(&self.name, self.aggregate, self.direction, &self.split) {\n                Some(epoch_best) => epoch_best,\n                None => epoch,\n            };\n\n        let mut actions = Vec::new();\n\n        if let Some(current) = self.current\n            && current != best_epoch\n        {\n            actions.push(CheckpointingAction::Delete(current));\n        }\n\n        if best_epoch == epoch {\n            actions.push(CheckpointingAction::Save);\n        }\n\n        self.current = Some(best_epoch);\n\n        actions\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::{\n        EventProcessorTraining, TestBackend,\n        logger::InMemoryMetricLogger,\n        metric::{\n            LossMetric,\n            processor::{\n                MetricsTraining, MinimalEventProcessor,\n                test_utils::{end_epoch, process_train},\n            },\n            store::LogEventStore,\n        },\n    };\n\n    use super::*;\n    use std::sync::Arc;\n\n    #[test]\n    fn always_keep_the_best_epoch() {\n        let loss = LossMetric::<TestBackend>::new();\n        let mut store = LogEventStore::default();\n        let mut strategy = MetricCheckpointingStrategy::new(\n            &loss,\n            Aggregate::Mean,\n            Direction::Lowest,\n            Split::Train,\n        );\n        let mut metrics = MetricsTraining::<f64, f64>::default();\n        // Register an in memory logger.\n        store.register_logger(InMemoryMetricLogger::default());\n        // Register the loss metric.\n        metrics.register_train_metric_numeric(loss);\n        let store = Arc::new(EventStoreClient::new(store));\n        let mut processor = MinimalEventProcessor::new(metrics, store.clone());\n        processor.process_train(crate::LearnerEvent::Start);\n\n        // Two points for the first epoch. Mean 0.75\n        let mut epoch = 1;\n        process_train(&mut processor, 1.0, epoch);\n        process_train(&mut processor, 0.5, epoch);\n        end_epoch(&mut processor, epoch);\n\n        // Should save the current record.\n        assert_eq!(\n            vec![CheckpointingAction::Save],\n            strategy.checkpointing(epoch, &store)\n        );\n\n        // Two points for the second epoch. Mean 0.4\n        epoch += 1;\n        process_train(&mut processor, 0.5, epoch);\n        process_train(&mut processor, 0.3, epoch);\n        end_epoch(&mut processor, epoch);\n\n        // Should save the current record and delete the previous one.\n        assert_eq!(\n            vec![CheckpointingAction::Delete(1), CheckpointingAction::Save],\n            strategy.checkpointing(epoch, &store)\n        );\n\n        // Two points for the last epoch. Mean 2.0\n        epoch += 1;\n        process_train(&mut processor, 1.0, epoch);\n        process_train(&mut processor, 3.0, epoch);\n        end_epoch(&mut processor, epoch);\n\n        // Should not delete the previous record, since it's the best one, and should not save a\n        // new one.\n        assert!(strategy.checkpointing(epoch, &store).is_empty());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/checkpoint/strategy/mod.rs",
    "content": "mod base;\nmod composed;\nmod lastn;\nmod metric;\n\npub use base::*;\npub use composed::*;\npub use lastn::*;\npub use metric::*;\n"
  },
  {
    "path": "crates/burn-train/src/components.rs",
    "content": "use crate::{InferenceStep, TrainStep};\nuse burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};\nuse burn_optim::{Optimizer, lr_scheduler::LrScheduler};\nuse std::marker::PhantomData;\n\n/// Components used for a model to learn, grouped in one trait.\npub trait LearningComponentsTypes {\n    /// The backend used for training.\n    type Backend: AutodiffBackend;\n    /// The learning rate scheduler used for training.\n    type LrScheduler: LrScheduler + 'static;\n    /// The model to train.\n    type TrainingModel: TrainStep\n        + AutodiffModule<Self::Backend, InnerModule = Self::InferenceModel>\n        + core::fmt::Display\n        + 'static;\n    /// The non-autodiff type of the model.\n    type InferenceModel: InferenceStep;\n    /// The optimizer used for training.\n    type Optimizer: Optimizer<Self::TrainingModel, Self::Backend> + 'static;\n}\n\n/// Concrete type that implements the [LearningComponentsTypes](LearningComponentsTypes) trait.\npub struct LearningComponentsMarker<B, LR, M, O> {\n    _backend: PhantomData<B>,\n    _lr_scheduler: PhantomData<LR>,\n    _model: PhantomData<M>,\n    _optimizer: PhantomData<O>,\n}\n\nimpl<B, LR, M, O> LearningComponentsTypes for LearningComponentsMarker<B, LR, M, O>\nwhere\n    B: AutodiffBackend,\n    LR: LrScheduler + 'static,\n    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,\n    M::InnerModule: InferenceStep,\n    O: Optimizer<M, B> + 'static,\n{\n    type Backend = B;\n    type LrScheduler = LR;\n    type TrainingModel = M;\n    type InferenceModel = M::InnerModule;\n    type Optimizer = O;\n}\n\n/// The training backend.\npub type TrainingBackend<LC> = <LC as LearningComponentsTypes>::Backend;\n/// The inference backend.\npub(crate) type InferenceBackend<LC> =\n    <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend;\n/// The model used for training.\npub type TrainingModel<LC> = <LC as LearningComponentsTypes>::TrainingModel;\n/// The non-autodiff model.\npub(crate) type InferenceModel<LC> = <LC as LearningComponentsTypes>::InferenceModel;\n/// Type for training input.\npub(crate) type TrainingModelInput<LC> =\n    <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Input;\n/// Type for inference input.\npub(crate) type InferenceModelInput<LC> =\n    <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Input;\n/// Type for training output.\npub(crate) type TrainingModelOutput<LC> =\n    <<LC as LearningComponentsTypes>::TrainingModel as TrainStep>::Output;\n/// Type for inference output.\npub(crate) type InferenceModelOutput<LC> =\n    <<LC as LearningComponentsTypes>::InferenceModel as InferenceStep>::Output;\n"
  },
  {
    "path": "crates/burn-train/src/evaluator/base.rs",
    "content": "use crate::{\n    AsyncProcessorEvaluation, EvaluationItem, FullEventProcessorEvaluation, InferenceStep,\n    Interrupter, LearnerSummaryConfig,\n    evaluator::components::EvaluatorComponentTypes,\n    metric::processor::{EvaluatorEvent, EventProcessorEvaluation},\n    renderer::{EvaluationName, MetricsRenderer},\n};\nuse burn_core::{data::dataloader::DataLoader, module::Module};\nuse std::sync::Arc;\n\npub(crate) type TestBackend<EC> = <EC as EvaluatorComponentTypes>::Backend;\npub(crate) type TestInput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Input;\npub(crate) type TestOutput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Output;\n\npub(crate) type TestLoader<EC> = Arc<dyn DataLoader<TestBackend<EC>, TestInput<EC>>>;\n\n/// Evaluates a model on a specific dataset.\npub struct Evaluator<EC: EvaluatorComponentTypes> {\n    pub(crate) model: EC::Model,\n    pub(crate) interrupter: Interrupter,\n    pub(crate) event_processor:\n        AsyncProcessorEvaluation<FullEventProcessorEvaluation<TestOutput<EC>>>,\n    /// Config for creating a summary of the evaluation\n    pub summary: Option<LearnerSummaryConfig>,\n}\n\nimpl<EC: EvaluatorComponentTypes> Evaluator<EC> {\n    /// Run the evaluation on the given dataset.\n    ///\n    /// The data will be stored and displayed under the provided name.\n    pub fn eval<S: core::fmt::Display>(\n        mut self,\n        name: S,\n        dataloader: TestLoader<EC>,\n    ) -> Box<dyn MetricsRenderer> {\n        // Move dataloader to the model device\n        let dataloader = dataloader.to_device(self.model.devices().first().unwrap());\n\n        let name = EvaluationName::new(name);\n        let mut iterator = dataloader.iter();\n        let mut iteration = 0;\n\n        self.event_processor.process_test(EvaluatorEvent::Start);\n        while let Some(item) = iterator.next() {\n            let progress = iterator.progress();\n            iteration += 1;\n\n            let item = self.model.step(item);\n            let item = EvaluationItem::new(item, progress, Some(iteration));\n\n            self.event_processor\n                .process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));\n\n            if self.interrupter.should_stop() {\n                log::info!(\"Testing interrupted.\");\n                break;\n            }\n        }\n\n        let summary = self.summary.and_then(|summary| {\n            summary\n                .init()\n                .map(|summary| summary.with_model(self.model.to_string()))\n                .ok()\n        });\n\n        self.event_processor\n            .process_test(EvaluatorEvent::End(summary));\n\n        self.event_processor.renderer()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/evaluator/builder.rs",
    "content": "use crate::{\n    ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, InferenceStep,\n    Interrupter, LearnerSummaryConfig, TestOutput,\n    evaluator::components::{EvaluatorComponentTypes, EvaluatorComponentTypesMarker},\n    logger::FileMetricLogger,\n    metric::{\n        Adaptor, ItemLazy, Metric, Numeric,\n        processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation},\n        store::{EventStoreClient, LogEventStore},\n    },\n    renderer::{MetricsRenderer, default_renderer},\n};\nuse burn_core::{module::Module, prelude::Backend};\nuse std::{\n    collections::BTreeSet,\n    path::{Path, PathBuf},\n    sync::Arc,\n};\n\n/// Struct to configure and create an [evaluator](Evaluator).\n///\n/// The generics components of the builder should probably not be set manually, as they are\n/// optimized for Rust type inference.\npub struct EvaluatorBuilder<EC: EvaluatorComponentTypes> {\n    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    event_store: LogEventStore,\n    summary_metrics: BTreeSet<String>,\n    renderer: Option<Box<dyn MetricsRenderer + 'static>>,\n    interrupter: Interrupter,\n    metrics: MetricsEvaluation<TestOutput<EC>>,\n    directory: PathBuf,\n    summary: bool,\n}\n\nimpl<B, M> EvaluatorBuilder<EvaluatorComponentTypesMarker<B, M>>\nwhere\n    B: Backend,\n    M: Module<B> + InferenceStep + core::fmt::Display + 'static,\n{\n    /// Creates a new evaluator builder.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory to save the checkpoints.\n    pub fn new(directory: impl AsRef<Path>) -> Self {\n        let directory = directory.as_ref().to_path_buf();\n        let log_file = directory.join(\"evaluation.log\");\n\n        Self {\n            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(log_file))),\n            event_store: LogEventStore::default(),\n            summary_metrics: Default::default(),\n            renderer: None,\n            interrupter: Interrupter::new(),\n            summary: false,\n            metrics: MetricsEvaluation::default(),\n            directory,\n        }\n    }\n}\n\nimpl<EC: EvaluatorComponentTypes> EvaluatorBuilder<EC> {\n    /// Registers [numeric](crate::metric::Numeric) test [metrics](Metric).\n    pub fn metrics<Me: EvalMetricRegistration<EC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Registers text [metrics](Metric).\n    pub fn metrics_text<Me: EvalTextMetricRegistration<EC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// By default, Rust logs are captured and written into\n    /// `evaluation.log`. If disabled, standard Rust log handling\n    /// will apply.\n    pub fn with_application_logger(\n        mut self,\n        logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    ) -> Self {\n        self.tracing_logger = logger;\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) test [metric](Metric).\n    pub fn metric_numeric<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + Numeric + 'static,\n        <TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_test_metric_numeric(metric);\n        self\n    }\n\n    /// Register a text test [metric](Metric).\n    pub fn metric<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + 'static,\n        <TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_test_metric(metric);\n        self\n    }\n\n    /// Replace the default CLI renderer with a custom one.\n    ///\n    /// # Arguments\n    ///\n    /// * `renderer` - The custom renderer.\n    pub fn renderer(mut self, renderer: Box<dyn MetricsRenderer + 'static>) -> Self {\n        self.renderer = Some(renderer);\n        self\n    }\n\n    /// Enable the evaluation summary report.\n    ///\n    /// The summary will be displayed at the end of `.eval()`.\n    pub fn summary(mut self) -> Self {\n        self.summary = true;\n        self\n    }\n\n    /// Builds the evaluator.\n    #[allow(clippy::type_complexity)]\n    pub fn build(mut self, model: EC::Model) -> Evaluator<EC> {\n        let renderer = self\n            .renderer\n            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), None));\n\n        self.event_store\n            .register_logger(FileMetricLogger::new_eval(self.directory.clone()));\n        let event_store = Arc::new(EventStoreClient::new(self.event_store));\n\n        let event_processor = AsyncProcessorEvaluation::new(FullEventProcessorEvaluation::new(\n            self.metrics,\n            renderer,\n            event_store,\n        ));\n\n        let summary = if self.summary {\n            Some(LearnerSummaryConfig {\n                directory: self.directory,\n                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),\n            })\n        } else {\n            None\n        };\n\n        Evaluator {\n            model,\n            interrupter: self.interrupter,\n            event_processor,\n            summary,\n        }\n    }\n}\n\n/// Trait to fake variadic generics.\npub trait EvalMetricRegistration<EC: EvaluatorComponentTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;\n}\n\n/// Trait to fake variadic generics.\npub trait EvalTextMetricRegistration<EC: EvaluatorComponentTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: EvaluatorBuilder<EC>) -> EvaluatorBuilder<EC>;\n}\n\nmacro_rules! gen_tuple {\n    ($($M:ident),*) => {\n        impl<$($M,)* EC: EvaluatorComponentTypes> EvalTextMetricRegistration<EC> for ($($M,)*)\n        where\n            $(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: EvaluatorBuilder<EC>,\n            ) -> EvaluatorBuilder<EC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric($M);)*\n                builder\n            }\n        }\n\n        impl<$($M,)* EC: EvaluatorComponentTypes> EvalMetricRegistration<EC> for ($($M,)*)\n        where\n            $(<TestOutput<EC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + $crate::metric::Numeric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: EvaluatorBuilder<EC>,\n            ) -> EvaluatorBuilder<EC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_numeric($M);)*\n                builder\n            }\n        }\n    };\n}\n\ngen_tuple!(M1);\ngen_tuple!(M1, M2);\ngen_tuple!(M1, M2, M3);\ngen_tuple!(M1, M2, M3, M4);\ngen_tuple!(M1, M2, M3, M4, M5);\ngen_tuple!(M1, M2, M3, M4, M5, M6);\n"
  },
  {
    "path": "crates/burn-train/src/evaluator/components.rs",
    "content": "use crate::InferenceStep;\nuse burn_core::{module::Module, prelude::Backend};\nuse std::marker::PhantomData;\n\n/// All components necessary to evaluate a model grouped in one trait.\npub trait EvaluatorComponentTypes {\n    /// The backend in used for the evaluation.\n    type Backend: Backend;\n    /// The model to evaluate.\n    type Model: Module<Self::Backend> + InferenceStep + core::fmt::Display + 'static;\n}\n\n/// A marker type used to provide [evaluation components](EvaluatorComponentTypes).\npub struct EvaluatorComponentTypesMarker<B, M> {\n    _p: PhantomData<(B, M)>,\n}\n\nimpl<B, M> EvaluatorComponentTypes for EvaluatorComponentTypesMarker<B, M>\nwhere\n    B: Backend,\n    M: Module<B> + InferenceStep + core::fmt::Display + 'static,\n{\n    type Backend = B;\n    type Model = M;\n}\n"
  },
  {
    "path": "crates/burn-train/src/evaluator/mod.rs",
    "content": "mod base;\nmod builder;\n\npub(crate) mod components;\n\npub use base::*;\npub use builder::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/application_logger.rs",
    "content": "use std::path::{Path, PathBuf};\nuse tracing_core::{Level, LevelFilter};\nuse tracing_subscriber::filter::filter_fn;\nuse tracing_subscriber::prelude::*;\nuse tracing_subscriber::{Layer, registry};\n\n/// This trait is used to install an application logger.\npub trait ApplicationLoggerInstaller {\n    /// Install the application logger.\n    fn install(&self) -> Result<(), String>;\n}\n\n/// This struct is used to install a local file application logger to output logs to a given file path.\npub struct FileApplicationLoggerInstaller {\n    path: PathBuf,\n}\n\nimpl FileApplicationLoggerInstaller {\n    /// Create a new file application logger.\n    pub fn new(path: impl AsRef<Path>) -> Self {\n        Self {\n            path: path.as_ref().to_path_buf(),\n        }\n    }\n}\n\nimpl ApplicationLoggerInstaller for FileApplicationLoggerInstaller {\n    fn install(&self) -> Result<(), String> {\n        let path = Path::new(&self.path);\n        let writer = tracing_appender::rolling::never(\n            path.parent().unwrap_or_else(|| Path::new(\".\")),\n            path.file_name().unwrap_or_else(|| {\n                panic!(\"The path '{}' to point to a file.\", self.path.display())\n            }),\n        );\n        let layer = tracing_subscriber::fmt::layer()\n            .with_ansi(false)\n            .with_writer(writer)\n            .with_filter(LevelFilter::INFO)\n            .with_filter(filter_fn(|m| {\n                if let Some(path) = m.module_path() {\n                    // The wgpu crate is logging too much, so we skip `info` level.\n                    if path.starts_with(\"wgpu\") && *m.level() >= Level::INFO {\n                        return false;\n                    }\n                }\n                true\n            }));\n\n        if registry().with(layer).try_init().is_err() {\n            return Err(\"Failed to install the file logger.\".to_string());\n        }\n\n        let hook = std::panic::take_hook();\n        let file_path = self.path.to_owned();\n\n        std::panic::set_hook(Box::new(move |info| {\n            log::error!(\"PANIC => {info}\");\n            eprintln!(\n                \"=== PANIC ===\\nA fatal error happened, you can check the experiment logs here => \\\n                    '{}'\\n=============\",\n                file_path.display()\n            );\n            hook(info);\n        }));\n\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/base.rs",
    "content": "use crate::LearningComponentsMarker;\nuse crate::checkpoint::{\n    AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,\n};\nuse crate::components::{LearningComponentsTypes, TrainingBackend};\nuse crate::metric::store::EventStoreClient;\nuse crate::{\n    CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,\n    TrainingModelOutput,\n};\nuse burn_core::module::{AutodiffModule, Module};\nuse burn_core::prelude::Backend;\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_optim::lr_scheduler::LrScheduler;\nuse burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};\nuse std::sync::atomic::{AtomicBool, Ordering};\nuse std::sync::{Arc, Mutex};\n\n/// The record of the learner's model.\npub type LearnerModelRecord<LC> =\n    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;\n/// The record of the optimizer.\npub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<\n    <LC as LearningComponentsTypes>::TrainingModel,\n    TrainingBackend<LC>,\n>>::Record;\n/// The record of the LR scheduler.\npub type LearnerSchedulerRecord<LC> =\n    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;\n\n/// Learner struct encapsulating all components necessary to train a Neural Network model.\npub struct Learner<LC: LearningComponentsTypes> {\n    pub(crate) model: LC::TrainingModel,\n    optim: LC::Optimizer,\n    lr_scheduler: LC::LrScheduler,\n    lr: f64,\n}\n\nimpl<LC: LearningComponentsTypes> Clone for Learner<LC> {\n    fn clone(&self) -> Self {\n        Self {\n            model: self.model.clone(),\n            optim: self.optim.clone(),\n            lr_scheduler: self.lr_scheduler.clone(),\n            lr: self.lr,\n        }\n    }\n}\n\nimpl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>\nwhere\n    B: AutodiffBackend,\n    LR: LrScheduler + 'static,\n    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,\n    M::InnerModule: InferenceStep,\n    O: Optimizer<M, B> + 'static,\n{\n    /// Create a learner.\n    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {\n        Self {\n            model,\n            optim,\n            lr_scheduler,\n            lr: 0.0,\n        }\n    }\n}\n\nimpl<LC: LearningComponentsTypes> Learner<LC> {\n    /// Fork the learner's model to the given device.\n    pub fn fork(&mut self, device: &<TrainingBackend<LC> as Backend>::Device) {\n        self.model = self.model().fork(device);\n    }\n\n    /// Returns the current model.\n    pub fn model(&self) -> LC::TrainingModel {\n        self.model.clone()\n    }\n\n    /// Returns the current learning rate.\n    pub fn lr_current(&self) -> f64 {\n        self.lr\n    }\n\n    /// Executes a step of the learning rate scheduler.\n    pub fn lr_step(&mut self) {\n        self.lr = self.lr_scheduler.step();\n    }\n\n    /// Runs a step of the model for training, which executes the forward and backward passes.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The input for the model.\n    ///\n    /// # Returns\n    ///\n    /// The output containing the model output and the gradients.\n    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {\n        self.model.step(item)\n    }\n\n    /// Optimize the current module with the provided gradients and learning rate.\n    ///\n    /// # Arguments\n    ///\n    /// * `optim`: Optimizer used for learning.\n    /// * `lr`: The learning rate used for this step.\n    /// * `grads`: The gradients of each parameter in the current model.\n    pub fn optimizer_step(&mut self, grads: GradientsParams) {\n        self.model = self.model().optimize(&mut self.optim, self.lr, grads);\n    }\n\n    /// Optimize the current module with the provided gradients and learning rate.\n    ///\n    /// # Arguments\n    ///\n    /// * `optim`: Optimizer used for learning.\n    /// * `lr`: The learning rate used for this step.\n    /// * `grads`: Multiple gradients associated to each parameter in the current model.\n    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {\n        self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);\n    }\n\n    /// Load the module state from a [record](LearnerModelRecord<LC>).\n    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {\n        self.model = self.model.clone().load_record(record);\n    }\n\n    /// Load the state of the learner's optimizer as a [record](LearnerOptimizerRecord<LC>).\n    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {\n        self.optim = self.optim.clone().load_record(record);\n    }\n\n    /// Load the state of the learner's scheduler as a [record](LearnerSchedulerRecord<LC>).\n    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {\n        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);\n    }\n}\n\n#[derive(new)]\n/// Used to create, delete, or load checkpoints of the training process.\npub struct LearningCheckpointer<LC: LearningComponentsTypes> {\n    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,\n    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,\n    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,\n    strategy: Box<dyn CheckpointingStrategy>,\n}\n\nimpl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {\n    /// Create checkpoint for the training process.\n    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {\n        let actions = self.strategy.checkpointing(epoch, store);\n\n        for action in actions {\n            match action {\n                CheckpointingAction::Delete(epoch) => {\n                    self.model\n                        .delete(epoch)\n                        .expect(\"Can delete model checkpoint.\");\n                    self.optim\n                        .delete(epoch)\n                        .expect(\"Can delete optimizer checkpoint.\");\n                    self.lr_scheduler\n                        .delete(epoch)\n                        .expect(\"Can delete learning rate scheduler checkpoint.\");\n                }\n                CheckpointingAction::Save => {\n                    self.model\n                        .save(epoch, learner.model.clone().into_record())\n                        .expect(\"Can save model checkpoint.\");\n                    self.optim\n                        .save(epoch, learner.optim.to_record())\n                        .expect(\"Can save optimizer checkpoint.\");\n                    self.lr_scheduler\n                        .save(epoch, learner.lr_scheduler.to_record())\n                        .expect(\"Can save learning rate scheduler checkpoint.\");\n                }\n            }\n        }\n    }\n\n    /// Load a training checkpoint.\n    pub fn load_checkpoint(\n        &self,\n        mut learner: Learner<LC>,\n        device: &Device<LC::Backend>,\n        epoch: usize,\n    ) -> Learner<LC> {\n        let record = self\n            .model\n            .restore(epoch, device)\n            .expect(\"Can load model checkpoint.\");\n        learner.load_model(record);\n\n        let record = self\n            .optim\n            .restore(epoch, device)\n            .expect(\"Can load optimizer checkpoint.\");\n        learner.load_optim(record);\n\n        let record = self\n            .lr_scheduler\n            .restore(epoch, device)\n            .expect(\"Can load learning rate scheduler checkpoint.\");\n        learner.load_scheduler(record);\n\n        learner\n    }\n}\n\n/// Cloneable reference to an early stopping strategy\npub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;\n\n#[derive(Clone, Default)]\n/// A handle that allows aborting the training/evaluation process early.\npub struct Interrupter {\n    state: Arc<AtomicBool>,\n    message: Arc<Mutex<Option<String>>>,\n}\n\nimpl Interrupter {\n    /// Create a new instance.\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    /// Notify the learner that it should stop.\n    /// # Arguments\n    /// * `reason` - A string describing the reason the training was stopped.\n    pub fn stop(&self, reason: Option<&str>) {\n        self.state.store(true, Ordering::Relaxed);\n        reason.inspect(|r| {\n            let mut message = self.message.lock().unwrap();\n            *message = Some(String::from(*r));\n        });\n    }\n\n    /// Reset the interrupter.\n    pub fn reset(&self) {\n        self.state.store(false, Ordering::Relaxed);\n    }\n\n    /// True if .stop() has been called.\n    pub fn should_stop(&self) -> bool {\n        self.state.load(Ordering::Relaxed)\n    }\n\n    /// Get the message associated with the interrupt.\n    pub fn get_message(&self) -> Option<String> {\n        let message = self.message.lock().unwrap();\n        message.clone()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/classification.rs",
    "content": "use crate::metric::{\n    AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,\n    PerplexityInput, TopKAccuracyInput, processor::ItemLazy,\n};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{Int, Tensor, Transaction};\nuse burn_ndarray::NdArray;\n\n/// Simple classification output adapted for multiple metrics.\n///\n/// Supported metrics:\n/// - Accuracy\n/// - AUROC\n/// - TopKAccuracy\n/// - Perplexity\n/// - Precision (via ConfusionStatsInput)\n/// - Recall (via ConfusionStatsInput)\n/// - FBetaScore (via ConfusionStatsInput)\n/// - Loss.\n#[derive(new)]\npub struct ClassificationOutput<B: Backend> {\n    /// The loss.\n    pub loss: Tensor<B, 1>,\n\n    /// The class logits or probabilities. Shape: \\[batch_size, num_classes\\].\n    pub output: Tensor<B, 2>,\n\n    /// The ground truth class index for each sample. Shape: \\[batch_size\\].\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> ItemLazy for ClassificationOutput<B> {\n    type ItemSync = ClassificationOutput<NdArray>;\n\n    fn sync(self) -> Self::ItemSync {\n        let [output, loss, targets] = Transaction::default()\n            .register(self.output)\n            .register(self.loss)\n            .register(self.targets)\n            .execute()\n            .try_into()\n            .expect(\"Correct amount of tensor data\");\n\n        let device = &Default::default();\n\n        ClassificationOutput {\n            output: Tensor::from_data(output, device),\n            loss: Tensor::from_data(loss, device),\n            targets: Tensor::from_data(targets, device),\n        }\n    }\n}\n\nimpl<B: Backend> Adaptor<AccuracyInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> AccuracyInput<B> {\n        AccuracyInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<AurocInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> AurocInput<B> {\n        AurocInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> LossInput<B> {\n        LossInput::new(self.loss.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<TopKAccuracyInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> TopKAccuracyInput<B> {\n        TopKAccuracyInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<PerplexityInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> PerplexityInput<B> {\n        PerplexityInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {\n    fn adapt(&self) -> ConfusionStatsInput<B> {\n        let [_, num_classes] = self.output.dims();\n        if num_classes > 1 {\n            ConfusionStatsInput::new(\n                self.output.clone(),\n                self.targets.clone().one_hot(num_classes).bool(),\n            )\n        } else {\n            ConfusionStatsInput::new(\n                self.output.clone(),\n                self.targets.clone().unsqueeze_dim(1).bool(),\n            )\n        }\n    }\n}\n\n/// Multi-label classification output adapted for multiple metrics.\n///\n/// Supported metrics:\n/// - HammingScore\n/// - Precision (via ConfusionStatsInput)\n/// - Recall (via ConfusionStatsInput)\n/// - FBetaScore (via ConfusionStatsInput)\n/// - Loss\n#[derive(new)]\npub struct MultiLabelClassificationOutput<B: Backend> {\n    /// The loss.\n    pub loss: Tensor<B, 1>,\n\n    /// The label logits or probabilities. Shape: \\[batch_size, num_classes\\].\n    pub output: Tensor<B, 2>,\n\n    /// The ground truth labels. Shape: \\[batch_size, num_classes\\].\n    pub targets: Tensor<B, 2, Int>,\n}\n\nimpl<B: Backend> ItemLazy for MultiLabelClassificationOutput<B> {\n    type ItemSync = MultiLabelClassificationOutput<NdArray>;\n\n    fn sync(self) -> Self::ItemSync {\n        let [output, loss, targets] = Transaction::default()\n            .register(self.output)\n            .register(self.loss)\n            .register(self.targets)\n            .execute()\n            .try_into()\n            .expect(\"Correct amount of tensor data\");\n\n        let device = &Default::default();\n\n        MultiLabelClassificationOutput {\n            output: Tensor::from_data(output, device),\n            loss: Tensor::from_data(loss, device),\n            targets: Tensor::from_data(targets, device),\n        }\n    }\n}\n\nimpl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {\n    fn adapt(&self) -> HammingScoreInput<B> {\n        HammingScoreInput::new(self.output.clone(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {\n    fn adapt(&self) -> LossInput<B> {\n        LossInput::new(self.loss.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {\n    fn adapt(&self) -> ConfusionStatsInput<B> {\n        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/early_stopping.rs",
    "content": "use crate::metric::{\n    Metric, MetricName,\n    store::{Aggregate, Direction, EventStoreClient, Split},\n};\n\n/// The condition that [early stopping strategies](EarlyStoppingStrategy) should follow.\n#[derive(Clone)]\npub enum StoppingCondition {\n    /// When no improvement has happened since the given number of epochs.\n    NoImprovementSince {\n        /// The number of epochs allowed to worsen before it gets better.\n        n_epochs: usize,\n    },\n}\n\n/// A strategy that checks if the training should be stopped.\npub trait EarlyStoppingStrategy: Send {\n    /// Update its current state and returns if the training should be stopped.\n    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool;\n}\n\n/// A helper trait to provide type-erased cloning.\npub trait CloneEarlyStoppingStrategy: EarlyStoppingStrategy + Send {\n    /// Clone into a boxed trait object.\n    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy>;\n}\n\n/// Blanket-implement `CloneEarlyStoppingStrategy` for any `T` that\n/// already implements your strategy + `Clone` + `Send` + `'static`.\nimpl<T> CloneEarlyStoppingStrategy for T\nwhere\n    T: EarlyStoppingStrategy + Clone + Send + 'static,\n{\n    fn clone_box(&self) -> Box<dyn CloneEarlyStoppingStrategy> {\n        Box::new(self.clone())\n    }\n}\n\n/// Now you can `impl Clone` for the boxed trait object.\nimpl Clone for Box<dyn CloneEarlyStoppingStrategy> {\n    fn clone(&self) -> Box<dyn CloneEarlyStoppingStrategy> {\n        self.clone_box()\n    }\n}\n\n/// An [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected\n/// during training or validation.\n#[derive(Clone)]\npub struct MetricEarlyStoppingStrategy {\n    condition: StoppingCondition,\n    metric_name: MetricName,\n    aggregate: Aggregate,\n    direction: Direction,\n    split: Split,\n    best_epoch: usize,\n    best_value: f64,\n    warmup_epochs: Option<usize>,\n}\n\nimpl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {\n    fn should_stop(&mut self, epoch: usize, store: &EventStoreClient) -> bool {\n        let current_value =\n            match store.find_metric(&self.metric_name, epoch, self.aggregate, &self.split) {\n                Some(value) => value,\n                None => {\n                    log::warn!(\"Can't find metric for early stopping.\");\n                    return false;\n                }\n            };\n\n        let is_best = match self.direction {\n            Direction::Lowest => current_value < self.best_value,\n            Direction::Highest => current_value > self.best_value,\n        };\n\n        if is_best {\n            log::info!(\n                \"New best epoch found {} {}: {}\",\n                epoch,\n                self.metric_name,\n                current_value\n            );\n            self.best_value = current_value;\n            self.best_epoch = epoch;\n            return false;\n        }\n\n        if let Some(warmup_epochs) = self.warmup_epochs\n            && epoch <= warmup_epochs\n        {\n            return false;\n        }\n\n        match self.condition {\n            StoppingCondition::NoImprovementSince { n_epochs } => {\n                let should_stop = epoch - self.best_epoch >= n_epochs;\n\n                if should_stop {\n                    log::info!(\n                        \"Stopping training loop, no improvement since epoch {}, {}: {},  current \\\n                         epoch {}, {}: {}\",\n                        self.best_epoch,\n                        self.metric_name,\n                        self.best_value,\n                        epoch,\n                        self.metric_name,\n                        current_value\n                    );\n                }\n\n                should_stop\n            }\n        }\n    }\n}\n\nimpl MetricEarlyStoppingStrategy {\n    /// Create a new [early stopping strategy](EarlyStoppingStrategy) based on a metrics collected\n    /// during training or validation.\n    ///\n    /// # Notes\n    ///\n    /// The metric should be registered for early stopping to work, otherwise no data is collected.\n    pub fn new<Me: Metric>(\n        metric: &Me,\n        aggregate: Aggregate,\n        direction: Direction,\n        split: Split,\n        condition: StoppingCondition,\n    ) -> Self {\n        let init_value = match direction {\n            Direction::Lowest => f64::MAX,\n            Direction::Highest => f64::MIN,\n        };\n\n        Self {\n            metric_name: metric.name(),\n            condition,\n            aggregate,\n            direction,\n            split,\n            best_epoch: 1,\n            best_value: init_value,\n            warmup_epochs: None,\n        }\n    }\n\n    /// Get the warmup period.\n    ///\n    /// Early stopping will not trigger during the warmup epochs.\n    pub fn warmup_epochs(&self) -> Option<usize> {\n        self.warmup_epochs\n    }\n\n    /// Set the warmup epochs.\n    ///\n    /// Early stopping will not trigger during the warmup epochs.\n    ///\n    /// # Arguments\n    /// - `warmup`: the number of warmup epochs, or None.\n    pub fn with_warmup_epochs(self, warmup: Option<usize>) -> Self {\n        Self {\n            warmup_epochs: warmup,\n            ..self\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::sync::Arc;\n\n    use crate::{\n        EventProcessorTraining, TestBackend,\n        logger::InMemoryMetricLogger,\n        metric::{\n            LossMetric,\n            processor::{\n                MetricsTraining, MinimalEventProcessor,\n                test_utils::{end_epoch, process_train},\n            },\n            store::LogEventStore,\n        },\n    };\n\n    use super::*;\n\n    #[test]\n    fn never_early_stop_while_it_is_improving() {\n        test_early_stopping(\n            None,\n            1,\n            &[\n                (&[0.5, 0.3], false, \"Should not stop first epoch\"),\n                (&[0.4, 0.3], false, \"Should not stop when improving\"),\n                (&[0.3, 0.3], false, \"Should not stop when improving\"),\n                (&[0.2, 0.3], false, \"Should not stop when improving\"),\n            ],\n        );\n    }\n\n    #[test]\n    fn early_stop_when_no_improvement_since_two_epochs() {\n        test_early_stopping(\n            None,\n            2,\n            &[\n                (&[1.0, 0.5], false, \"Should not stop first epoch\"),\n                (&[0.5, 0.3], false, \"Should not stop when improving\"),\n                (\n                    &[1.0, 3.0],\n                    false,\n                    \"Should not stop first time it gets worse\",\n                ),\n                (\n                    &[1.0, 2.0],\n                    true,\n                    \"Should stop since two following epochs didn't improve\",\n                ),\n            ],\n        );\n    }\n\n    #[test]\n    fn early_stopping_with_warmup() {\n        test_early_stopping(\n            Some(3),\n            2,\n            &[\n                (&[1.0, 0.5], false, \"Should not stop during warmup\"),\n                (&[1.0, 0.5], false, \"Should not stop during warmup\"),\n                (&[1.0, 0.5], false, \"Should not stop during warmup\"),\n                (\n                    &[1.0, 0.5],\n                    true,\n                    \"Should stop when not improving after warmup\",\n                ),\n            ],\n        )\n    }\n\n    #[test]\n    fn early_stop_when_stays_equal() {\n        test_early_stopping(\n            None,\n            2,\n            &[\n                (&[0.5, 0.3], false, \"Should not stop first epoch\"),\n                (\n                    &[0.5, 0.3],\n                    false,\n                    \"Should not stop first time it stars the same\",\n                ),\n                (\n                    &[0.5, 0.3],\n                    true,\n                    \"Should stop since two following epochs didn't improve\",\n                ),\n            ],\n        );\n    }\n\n    fn test_early_stopping(warmup: Option<usize>, n_epochs: usize, data: &[(&[f64], bool, &str)]) {\n        let loss = LossMetric::<TestBackend>::new();\n        let mut early_stopping = MetricEarlyStoppingStrategy::new(\n            &loss,\n            Aggregate::Mean,\n            Direction::Lowest,\n            Split::Train,\n            StoppingCondition::NoImprovementSince { n_epochs },\n        )\n        .with_warmup_epochs(warmup);\n        let mut store = LogEventStore::default();\n        let mut metrics = MetricsTraining::<f64, f64>::default();\n\n        store.register_logger(InMemoryMetricLogger::default());\n        metrics.register_train_metric_numeric(loss);\n\n        let store = Arc::new(EventStoreClient::new(store));\n        let mut processor = MinimalEventProcessor::new(metrics, store.clone());\n\n        let mut epoch = 1;\n        processor.process_train(crate::LearnerEvent::Start);\n        for (points, should_start, comment) in data {\n            for point in points.iter() {\n                process_train(&mut processor, *point, epoch);\n            }\n            end_epoch(&mut processor, epoch);\n\n            assert_eq!(\n                *should_start,\n                early_stopping.should_stop(epoch, &store),\n                \"{comment}\"\n            );\n            epoch += 1;\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/mod.rs",
    "content": "#[cfg(feature = \"rl\")]\nmod rl;\n#[cfg(feature = \"rl\")]\npub use rl::*;\n\nmod application_logger;\nmod base;\nmod classification;\nmod early_stopping;\nmod regression;\nmod sequence;\nmod summary;\nmod supervised;\nmod train_val;\n\npub use application_logger::*;\npub use base::*;\npub use classification::*;\npub use early_stopping::*;\npub use regression::*;\npub use sequence::*;\npub use summary::*;\npub use supervised::*;\npub use train_val::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/regression.rs",
    "content": "use crate::metric::processor::ItemLazy;\nuse crate::metric::{Adaptor, LossInput};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{Tensor, Transaction};\nuse burn_ndarray::NdArray;\n\n/// Regression output adapted for the loss metric.\n#[derive(new)]\npub struct RegressionOutput<B: Backend> {\n    /// The loss.\n    pub loss: Tensor<B, 1>,\n\n    /// The predicted values. Shape: \\[batch_size, num_targets\\].\n    pub output: Tensor<B, 2>,\n\n    /// The ground truth values. Shape: \\[batch_size, num_targets\\].\n    pub targets: Tensor<B, 2>,\n}\n\nimpl<B: Backend> Adaptor<LossInput<B>> for RegressionOutput<B> {\n    fn adapt(&self) -> LossInput<B> {\n        LossInput::new(self.loss.clone())\n    }\n}\n\nimpl<B: Backend> ItemLazy for RegressionOutput<B> {\n    type ItemSync = RegressionOutput<NdArray>;\n\n    fn sync(self) -> Self::ItemSync {\n        let [output, loss, targets] = Transaction::default()\n            .register(self.output)\n            .register(self.loss)\n            .register(self.targets)\n            .execute()\n            .try_into()\n            .expect(\"Correct amount of tensor data\");\n\n        let device = &Default::default();\n\n        RegressionOutput {\n            output: Tensor::from_data(output, device),\n            loss: Tensor::from_data(loss, device),\n            targets: Tensor::from_data(targets, device),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/checkpointer.rs",
    "content": "use burn_core::tensor::Device;\nuse burn_rl::{Policy, PolicyLearner, PolicyState};\n\nuse crate::RLAgentRecord;\nuse crate::{\n    RLComponentsTypes, RLPolicyRecord,\n    checkpoint::Checkpointer,\n    checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},\n    metric::store::EventStoreClient,\n};\n\n#[derive(new)]\n/// Used to create, delete, or load checkpoints of the training process.\npub struct RLCheckpointer<RLC: RLComponentsTypes> {\n    policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,\n    learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,\n    strategy: Box<dyn CheckpointingStrategy>,\n}\n\nimpl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {\n    /// Create checkpoint for the training process.\n    pub fn checkpoint(\n        &mut self,\n        policy: &RLC::PolicyState,\n        learning_agent: &RLC::LearningAgent,\n        epoch: usize,\n        store: &EventStoreClient,\n    ) {\n        let actions = self.strategy.checkpointing(epoch, store);\n\n        for action in actions {\n            match action {\n                CheckpointingAction::Delete(epoch) => {\n                    self.policy\n                        .delete(epoch)\n                        .expect(\"Can delete policy checkpoint.\");\n                    self.learning_agent\n                        .delete(epoch)\n                        .expect(\"Can delete learning agent checkpoint.\")\n                }\n                CheckpointingAction::Save => {\n                    self.policy\n                        .save(epoch, policy.clone().into_record())\n                        .expect(\"Can save policy checkpoint.\");\n                    self.learning_agent\n                        .save(epoch, learning_agent.record())\n                        .expect(\"Can save learning agent checkpoint.\");\n                }\n            }\n        }\n    }\n\n    /// Load a training checkpoint.\n    pub fn load_checkpoint(\n        &self,\n        learning_agent: RLC::LearningAgent,\n        device: &Device<RLC::Backend>,\n        epoch: usize,\n    ) -> RLC::LearningAgent {\n        let record = self\n            .policy\n            .restore(epoch, device)\n            .expect(\"Can load model checkpoint.\");\n        let policy = learning_agent.policy().load_record(record);\n\n        let record = self\n            .learning_agent\n            .restore(epoch, device)\n            .expect(\"Can load learning agent checkpoint.\");\n        let mut learning_agent = learning_agent.load_record(record);\n        learning_agent.update_policy(policy);\n\n        learning_agent\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/components.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, PolicyState};\n\nuse crate::{AgentEvaluationEvent, AsyncProcessorTraining, ItemLazy, RLEvent};\n\n/// All components used by the reinforcement learning paradigm, grouped in one trait.\npub trait RLComponentsTypes {\n    /// The backend used for training.\n    type Backend: AutodiffBackend;\n    /// The learning environment.\n    type Env: Environment<State = Self::State, Action = Self::Action> + 'static;\n    /// Specifies how to initialize the environment.\n    type EnvInit: EnvironmentInit<Self::Env> + Send + 'static;\n    /// The type of the environment state.\n    type State: Into<<Self::Policy as Policy<Self::Backend>>::Observation> + Clone + Send + 'static;\n    /// The type of the environment action.\n    type Action: From<<Self::Policy as Policy<Self::Backend>>::Action>\n        + Into<<Self::Policy as Policy<Self::Backend>>::Action>\n        + Clone\n        + Send\n        + 'static;\n\n    /// The policy used to take actions in the environment.\n    type Policy: Policy<\n            Self::Backend,\n            Observation = Self::PolicyObs,\n            ActionDistribution = Self::PolicyAD,\n            Action = Self::PolicyAction,\n            ActionContext = Self::ActionContext,\n            PolicyState = Self::PolicyState,\n        > + Send\n        + 'static;\n    /// The policy's observation type.\n    type PolicyObs: Clone + Send + Batchable + 'static;\n    /// The policy's action distribution type.\n    type PolicyAD: Clone + Send + Batchable;\n    /// The policy's action type.\n    type PolicyAction: Clone + Send + Batchable;\n    /// Additional data as context for an agent's action.\n    type ActionContext: ItemLazy + Clone + Send + 'static;\n    /// The state of the parameterized policy.\n    type PolicyState: Clone + Send + PolicyState<Self::Backend> + 'static;\n\n    /// The learning agent.\n    type LearningAgent: PolicyLearner<\n            Self::Backend,\n            TrainContext = Self::TrainingOutput,\n            InnerPolicy = Self::Policy,\n        > + Send\n        + 'static;\n    /// The output data of a training step.\n    type TrainingOutput: ItemLazy + Clone + Send;\n}\n\n/// Concrete type that implements the [RLComponentsTypes](RLComponentsTypes) trait.\npub struct RLComponentsMarker<B, E, EI, A> {\n    _backend: PhantomData<B>,\n    _env: PhantomData<E>,\n    _env_init: PhantomData<EI>,\n    _agent: PhantomData<A>,\n}\n\nimpl<B, E, EI, A> RLComponentsTypes for RLComponentsMarker<B, E, EI, A>\nwhere\n    B: AutodiffBackend,\n    E: Environment + 'static,\n    EI: EnvironmentInit<E> + Send + 'static,\n    A: PolicyLearner<B> + Send + 'static,\n    A::TrainContext: ItemLazy + Clone + Send,\n    A::InnerPolicy: Policy<B> + Send,\n    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,\n    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,\n    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,\n    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>\n        + Into<<A::InnerPolicy as Policy<B>>::Action>\n        + Clone\n        + Send\n        + 'static,\n{\n    type Backend = B;\n    type Env = E;\n    type EnvInit = EI;\n    type LearningAgent = A;\n    type Policy = A::InnerPolicy;\n    type PolicyObs = <A::InnerPolicy as Policy<B>>::Observation;\n    type PolicyAD = <A::InnerPolicy as Policy<B>>::ActionDistribution;\n    type PolicyAction = <A::InnerPolicy as Policy<B>>::Action;\n    type ActionContext = <A::InnerPolicy as Policy<B>>::ActionContext;\n    type PolicyState = <A::InnerPolicy as Policy<B>>::PolicyState;\n    type TrainingOutput = A::TrainContext;\n    type State = E::State;\n    type Action = E::Action;\n}\n\npub(crate) type RlPolicy<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<\n    <RLC as RLComponentsTypes>::Backend,\n>>::InnerPolicy;\n/// The event processor type for reinforcement learning.\npub type RLEventProcessorType<RLC> = AsyncProcessorTraining<\n    RLEvent<<RLC as RLComponentsTypes>::TrainingOutput, <RLC as RLComponentsTypes>::ActionContext>,\n    AgentEvaluationEvent<<RLC as RLComponentsTypes>::ActionContext>,\n>;\n/// The record of the policy.\npub type RLPolicyRecord<RLC> = <<<RLC as RLComponentsTypes>::Policy as Policy<\n    <RLC as RLComponentsTypes>::Backend,\n>>::PolicyState as PolicyState<<RLC as RLComponentsTypes>::Backend>>::Record;\n/// The record of the learning agent.\npub type RLAgentRecord<RLC> = <<RLC as RLComponentsTypes>::LearningAgent as PolicyLearner<\n    <RLC as RLComponentsTypes>::Backend,\n>>::Record;\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/env_runner/async_runner.rs",
    "content": "use rand::prelude::SliceRandom;\nuse std::{\n    sync::mpsc::{Receiver, Sender},\n    thread::spawn,\n};\n\nuse burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};\nuse burn_rl::EnvironmentInit;\nuse burn_rl::Policy;\nuse burn_rl::Transition;\nuse burn_rl::{AsyncPolicy, Environment};\n\nuse crate::{\n    AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,\n    Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,\n    RlPolicy, TimeStep, Trajectory,\n};\n\nenum RequestMessage {\n    Step(),\n    Episode(),\n}\n\n/// Configuration for an async agent/environment loop.\npub struct AsyncAgentEnvLoopConfig {\n    /// If the loop is used for evaluation (as opposed to training).\n    pub eval: bool,\n    /// If the agent should take action deterministically.\n    pub deterministic: bool,\n    /// An arbitrary ID for the loop.\n    pub id: usize,\n}\n\n/// An asynchronous agent/environement interface.\npub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {\n    eval: bool,\n    agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,\n    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,\n    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,\n    request_sender: Sender<RequestMessage>,\n}\n\nimpl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {\n    /// Create a new asynchronous runner.\n    ///\n    /// # Arguments\n    ///\n    /// * `env_init` - A function returning an environment instance.\n    /// * `agent` - An [AsyncPolicy](AsyncPolicy) taking actions in the loop.\n    /// * `config` - An [AsyncAgentEnvLoopConfig](AsyncAgentEnvLoopConfig).\n    /// * `transition_sender` - Optional sender for transitions if you want to drive the requests from outside of the loop instance.\n    /// * `trajectory_sender` - Optional sender for trajectories if you want to drive the requests from outside of the loop instance.\n    ///\n    /// # Returns\n    ///\n    /// An async Agent/Environement loop.\n    pub fn new(\n        env_init: RLC::EnvInit,\n        agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,\n        config: AsyncAgentEnvLoopConfig,\n        transition_device: &Device<BT>,\n        transition_sender: Option<Sender<RLTimeStep<BT, RLC>>>,\n        trajectory_sender: Option<Sender<RLTrajectory<BT, RLC>>>,\n    ) -> Self {\n        let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel();\n        let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();\n        let (request_sender, request_receiver) = std::sync::mpsc::channel();\n        let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender);\n        let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender);\n\n        let device = transition_device.clone();\n        let mut loop_agent = agent.clone();\n        let eval = config.eval;\n\n        let mut current_steps = vec![];\n        let mut current_reward = 0.0;\n        let mut step_num = 0;\n        spawn(move || {\n            let mut env = env_init.init();\n            env.reset();\n\n            let mut request_episode = false;\n            loop {\n                let state = env.state();\n                let (action, context) =\n                    loop_agent.action(state.clone().into(), config.deterministic);\n\n                let env_action = RLC::Action::from(action);\n                let step_result = env.step(env_action.clone());\n\n                current_reward += step_result.reward;\n                step_num += 1;\n\n                let transition = Transition::new(\n                    state.clone(),\n                    step_result.next_state,\n                    env_action,\n                    Tensor::from_data([step_result.reward], &device),\n                    Tensor::from_data(\n                        [(step_result.done || step_result.truncated) as i32 as f64],\n                        &device,\n                    ),\n                );\n\n                if !request_episode {\n                    loop_agent.decrement_agents(1);\n                    let request = match request_receiver.recv() {\n                        Ok(req) => req,\n                        Err(err) => {\n                            log::error!(\"Error in env runner : {}\", err);\n                            break;\n                        }\n                    };\n                    loop_agent.increment_agents(1);\n\n                    match request {\n                        RequestMessage::Step() => (),\n                        RequestMessage::Episode() => request_episode = true,\n                    }\n                }\n\n                let time_step = TimeStep {\n                    env_id: config.id,\n                    transition,\n                    done: step_result.done,\n                    ep_len: step_num,\n                    cum_reward: current_reward,\n                    action_context: context[0].clone(),\n                };\n                current_steps.push(time_step.clone());\n\n                if !request_episode && let Err(err) = loop_transition_sender.send(time_step) {\n                    log::error!(\"Error in env runner : {}\", err);\n                    break;\n                }\n\n                if step_result.done || step_result.truncated {\n                    if request_episode {\n                        request_episode = false;\n                        loop_trajectory_sender\n                            .send(Trajectory {\n                                timesteps: current_steps.clone(),\n                            })\n                            .expect(\"Can send trajectory to main thread.\");\n                    }\n                    current_steps.clear();\n\n                    env.reset();\n                    current_reward = 0.;\n                    step_num = 0;\n                }\n            }\n        });\n\n        Self {\n            eval,\n            agent,\n            transition_receiver,\n            trajectory_receiver,\n            request_sender,\n        }\n    }\n}\n\nimpl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>\nwhere\n    BT: Backend,\n    RLC: RLComponentsTypes,\n{\n    fn run_steps(\n        &mut self,\n        num_steps: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTimeStep<BT, RLC>> {\n        let mut items = vec![];\n        for _ in 0..num_steps {\n            self.request_sender\n                .send(RequestMessage::Step())\n                .expect(\"Can request transitions.\");\n            let transition = self\n                .transition_receiver\n                .recv()\n                .expect(\"Can receive transitions.\");\n            items.push(transition.clone());\n\n            if !self.eval {\n                progress.items_processed += 1;\n                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(\n                    transition.action_context,\n                    progress.clone(),\n                    None,\n                )));\n\n                if transition.done {\n                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(\n                        EpisodeSummary {\n                            episode_length: transition.ep_len,\n                            cum_reward: transition.cum_reward,\n                        },\n                        progress.clone(),\n                        None,\n                    )));\n                }\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        items\n    }\n\n    fn run_episodes(\n        &mut self,\n        num_episodes: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        _progress: &mut Progress,\n    ) -> Vec<RLTrajectory<BT, RLC>> {\n        let mut items = vec![];\n        self.agent.increment_agents(1);\n        for episode_num in 0..num_episodes {\n            self.request_sender\n                .send(RequestMessage::Episode())\n                .expect(\"Can request episodes.\");\n            let trajectory = self\n                .trajectory_receiver\n                .recv()\n                .expect(\"Main thread can receive trajectory.\");\n\n            for (i, step) in trajectory.timesteps.iter().enumerate() {\n                // TODO : clean this.\n                if self.eval {\n                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(\n                        step.action_context.clone(),\n                        Progress::new(i, i),\n                        None,\n                    )));\n\n                    if step.done {\n                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(\n                            EvaluationItem::new(\n                                EpisodeSummary {\n                                    episode_length: step.ep_len,\n                                    cum_reward: step.cum_reward,\n                                },\n                                Progress::new(episode_num + 1, num_episodes),\n                                None,\n                            ),\n                        ));\n                    }\n                } else {\n                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(\n                        step.action_context.clone(),\n                        Progress::new(i, i),\n                        None,\n                    )));\n\n                    if step.done {\n                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(\n                            EpisodeSummary {\n                                episode_length: step.ep_len,\n                                cum_reward: step.cum_reward,\n                            },\n                            Progress::new(episode_num + 1, num_episodes),\n                            None,\n                        )));\n                    }\n                }\n            }\n\n            items.push(trajectory);\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        self.agent.decrement_agents(1);\n        items\n    }\n\n    fn update_policy(&mut self, update: RLC::PolicyState) {\n        self.agent.update(update);\n    }\n\n    fn policy(&self) -> RLC::PolicyState {\n        self.agent.state()\n    }\n}\n\n/// An asynchronous runner for multiple agent/environement interfaces.\npub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {\n    num_envs: usize,\n    eval: bool,\n    agent: AsyncPolicy<RLC::Backend, RLC::Policy>,\n    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,\n    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,\n    request_senders: Vec<Sender<RequestMessage>>,\n}\n\nimpl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {\n    /// Create a new asynchronous runner for multiple agent/environement interfaces.\n    pub fn new(\n        num_envs: usize,\n        env_init: RLC::EnvInit,\n        agent: AsyncPolicy<RLC::Backend, RLC::Policy>,\n        eval: bool,\n        deterministic: bool,\n        device: &Device<BT>,\n    ) -> Self {\n        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();\n        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();\n        let mut request_senders = vec![];\n\n        // Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.\n        agent.increment_agents(num_envs);\n\n        for i in 0..num_envs {\n            let config = AsyncAgentEnvLoopConfig {\n                eval,\n                deterministic,\n                id: i,\n            };\n            let runner = AgentEnvAsyncLoop::<BT, RLC>::new(\n                env_init.clone(),\n                agent.clone(),\n                config,\n                &device.clone(),\n                Some(transition_sender.clone()),\n                Some(trajectory_sender.clone()),\n            );\n            request_senders.push(runner.request_sender.clone());\n        }\n\n        // Double batching : The environments are always one step ahead.\n        request_senders.iter().for_each(|s| {\n            s.send(RequestMessage::Step())\n                .expect(\"Main thread can send step requests.\")\n        });\n\n        Self {\n            num_envs,\n            eval,\n            agent: agent.clone(),\n            transition_receiver,\n            trajectory_receiver,\n            request_senders,\n        }\n    }\n}\n\nimpl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>\nwhere\n    BT: Backend,\n    RLC: RLComponentsTypes,\n{\n    fn run_steps(\n        &mut self,\n        num_steps: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTimeStep<BT, RLC>> {\n        let mut items = vec![];\n        for _ in 0..num_steps {\n            let transition = self\n                .transition_receiver\n                .recv()\n                .expect(\"Can receive transitions.\");\n            items.push(transition.clone());\n\n            self.request_senders[transition.env_id]\n                .send(RequestMessage::Step())\n                .expect(\"Main thread can request steps.\");\n\n            if !self.eval {\n                progress.items_processed += 1;\n                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(\n                    transition.action_context,\n                    progress.clone(),\n                    None,\n                )));\n\n                if transition.done {\n                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(\n                        EpisodeSummary {\n                            episode_length: transition.ep_len,\n                            cum_reward: transition.cum_reward,\n                        },\n                        progress.clone(),\n                        None,\n                    )));\n                }\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        items\n    }\n\n    fn update_policy(&mut self, update: RLC::PolicyState) {\n        self.agent.update(update);\n    }\n\n    fn run_episodes(\n        &mut self,\n        num_episodes: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        _progress: &mut Progress,\n    ) -> Vec<RLTrajectory<BT, RLC>> {\n        // Send `num_episodes` initial requests.\n        let mut idx = vec![];\n        if num_episodes < self.num_envs {\n            let mut rng = rand::rng();\n            let mut vec: Vec<usize> = (0..self.num_envs).collect();\n            vec.shuffle(&mut rng);\n            idx = vec.into_iter().take(num_episodes).collect();\n        } else {\n            idx = (0..self.num_envs).collect();\n        }\n        let num_requests = self.num_envs.min(num_episodes);\n        idx.into_iter().for_each(|i| {\n            self.request_senders[i]\n                .send(RequestMessage::Episode())\n                .expect(\"Main thread can request steps.\");\n        });\n\n        let mut items = vec![];\n        for episode_num in 0..num_episodes {\n            let trajectory = self\n                .trajectory_receiver\n                .recv()\n                .expect(\"Can receive trajectory.\");\n            items.push(trajectory.clone());\n            if items.len() + num_requests <= num_episodes {\n                self.request_senders[trajectory.timesteps[0].env_id]\n                    .send(RequestMessage::Episode())\n                    .expect(\"Main thread can request steps.\");\n            }\n            for (i, step) in trajectory.timesteps.iter().enumerate() {\n                if self.eval {\n                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(\n                        step.action_context.clone(),\n                        Progress::new(i, i),\n                        None,\n                    )));\n\n                    if step.done {\n                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(\n                            EvaluationItem::new(\n                                EpisodeSummary {\n                                    episode_length: step.ep_len,\n                                    cum_reward: step.cum_reward,\n                                },\n                                Progress::new(episode_num + 1, num_episodes),\n                                None,\n                            ),\n                        ));\n                    }\n                } else {\n                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(\n                        step.action_context.clone(),\n                        Progress::new(i, i),\n                        None,\n                    )));\n\n                    if step.done {\n                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(\n                            EpisodeSummary {\n                                episode_length: step.ep_len,\n                                cum_reward: step.cum_reward,\n                            },\n                            Progress::new(episode_num + 1, num_episodes),\n                            None,\n                        )));\n                    }\n                }\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n\n        items\n    }\n\n    fn policy(&self) -> RLC::PolicyState {\n        self.agent.state()\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::needless_range_loop)]\nmod tests {\n    use burn_core::data::dataloader::Progress;\n    use burn_rl::AsyncPolicy;\n\n    use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig;\n    use crate::learner::rl::env_runner::base::AgentEnvLoop;\n    use crate::learner::tests::{MockPolicyState, MockProcessor};\n    use crate::{\n        AgentEnvAsyncLoop, TestBackend,\n        learner::tests::{MockEnvInit, MockPolicy, MockRLComponents},\n    };\n    use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop};\n\n    fn setup_async_loop(\n        state: usize,\n        eval: bool,\n        deterministic: bool,\n    ) -> AgentEnvAsyncLoop<TestBackend, MockRLComponents> {\n        let env_init = MockEnvInit;\n        let agent = MockPolicy(state);\n        let config = AsyncAgentEnvLoopConfig {\n            eval,\n            deterministic,\n            id: 0,\n        };\n        AgentEnvAsyncLoop::<TestBackend, MockRLComponents>::new(\n            env_init,\n            AsyncPolicy::new(1, agent),\n            config,\n            &Default::default(),\n            None,\n            None,\n        )\n    }\n\n    fn setup_multi_loop(\n        num_envs: usize,\n        autobatch_size: usize,\n        state: usize,\n        eval: bool,\n        deterministic: bool,\n    ) -> MultiAgentEnvLoop<TestBackend, MockRLComponents> {\n        let env_init = MockEnvInit;\n        let agent = MockPolicy(state);\n        MultiAgentEnvLoop::<TestBackend, MockRLComponents>::new(\n            num_envs,\n            env_init,\n            AsyncPolicy::new(autobatch_size, agent),\n            eval,\n            deterministic,\n            &Default::default(),\n        )\n    }\n\n    #[test]\n    fn test_policy_async_loop() {\n        let runner = setup_async_loop(1000, false, false);\n        let policy_state = runner.policy();\n        assert_eq!(policy_state.0, 1000);\n    }\n\n    #[test]\n    fn test_update_policy_async_loop() {\n        let mut runner = setup_async_loop(0, false, false);\n\n        runner.update_policy(MockPolicyState(1));\n        assert_eq!(runner.policy().0, 1);\n    }\n\n    #[test]\n    fn run_steps_returns_requested_number_async_loop() {\n        let mut runner = setup_async_loop(0, false, false);\n        let mut processor = AsyncProcessorTraining::new(MockProcessor);\n        let interrupter = Interrupter::new();\n        let mut progress = Progress {\n            items_processed: 0,\n            items_total: 1,\n        };\n\n        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);\n        assert_eq!(steps.len(), 1);\n        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);\n        assert_eq!(steps.len(), 8);\n    }\n\n    #[test]\n    fn run_episodes_returns_requested_number_async_loop() {\n        let mut runner = setup_async_loop(0, false, false);\n        let mut processor = AsyncProcessorTraining::new(MockProcessor);\n        let interrupter = Interrupter::new();\n        let mut progress = Progress {\n            items_processed: 0,\n            items_total: 1,\n        };\n\n        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);\n        assert_eq!(trajectories.len(), 1);\n        assert_ne!(trajectories[0].timesteps.len(), 0);\n        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);\n        assert_eq!(trajectories.len(), 8);\n        for i in 0..8 {\n            assert_ne!(trajectories[i].timesteps.len(), 0);\n        }\n    }\n\n    #[test]\n    fn test_policy_multi_loop() {\n        let runner = setup_multi_loop(4, 4, 1000, false, false);\n        let policy_state = runner.policy();\n        assert_eq!(policy_state.0, 1000);\n    }\n\n    #[test]\n    fn test_update_policy_multi_loop() {\n        let mut runner = setup_multi_loop(4, 4, 0, false, false);\n\n        runner.update_policy(MockPolicyState(1));\n        assert_eq!(runner.policy().0, 1);\n    }\n\n    #[test]\n    fn run_steps_returns_requested_number_multi_loop() {\n        fn run_test(num_envs: usize, autobatch_size: usize) {\n            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);\n            let mut processor = AsyncProcessorTraining::new(MockProcessor);\n            let interrupter = Interrupter::new();\n            let mut progress = Progress {\n                items_processed: 0,\n                items_total: 1,\n            };\n\n            // Kickstart tests by running some steps to make sure it's not a double batching edge case success.\n            let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);\n            assert_eq!(steps.len(), 8);\n\n            for i in 0..16 {\n                let steps = runner.run_steps(i, &mut processor, &interrupter, &mut progress);\n                assert_eq!(steps.len(), i);\n            }\n        }\n\n        // num_envs == autobatch_size\n        run_test(1, 1);\n        run_test(4, 4);\n        // num_envs < autobatch_size\n        run_test(1, 2);\n        run_test(1, 3);\n        run_test(2, 3);\n        run_test(2, 4);\n        run_test(5, 19);\n        // num_envs > autobatch_size\n        run_test(2, 1);\n        run_test(8, 1);\n        run_test(3, 2);\n        run_test(8, 2);\n        run_test(8, 3);\n        run_test(8, 7);\n    }\n\n    #[test]\n    fn run_episodes_returns_requested_number_multi_loop() {\n        fn run_test(num_envs: usize, autobatch_size: usize) {\n            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);\n            let mut processor = AsyncProcessorTraining::new(MockProcessor);\n            let interrupter = Interrupter::new();\n            let mut progress = Progress {\n                items_processed: 0,\n                items_total: 1,\n            };\n\n            // Kickstart tests by running some episodes to make sure it's not a double batching edge case success.\n            let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);\n            assert_eq!(trajectories.len(), 8);\n            for j in 0..8 {\n                assert_ne!(trajectories[j].timesteps.len(), 0);\n            }\n\n            for i in 0..16 {\n                let trajectories =\n                    runner.run_episodes(i, &mut processor, &interrupter, &mut progress);\n                assert_eq!(trajectories.len(), i);\n                for j in 0..i {\n                    assert_ne!(trajectories[j].timesteps.len(), 0);\n                }\n            }\n        }\n\n        // num_envs == autobatch_size\n        run_test(1, 1);\n        run_test(4, 4);\n        // num_envs < autobatch_size\n        run_test(1, 2);\n        run_test(1, 3);\n        run_test(2, 3);\n        run_test(2, 4);\n        run_test(5, 19);\n        // num_envs > autobatch_size\n        run_test(2, 1);\n        run_test(8, 1);\n        run_test(3, 2);\n        run_test(8, 2);\n        run_test(8, 3);\n        run_test(8, 7);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/env_runner/base.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn_core::data::dataloader::Progress;\nuse burn_core::{Tensor, prelude::Backend};\nuse burn_rl::Policy;\nuse burn_rl::Transition;\nuse burn_rl::{Environment, EnvironmentInit};\n\nuse crate::RLEvent;\nuse crate::{\n    AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,\n    RLEventProcessorType,\n};\nuse crate::{Interrupter, RLComponentsTypes};\n\n/// A trajectory, i.e. a list of ordered [TimeStep](TimeStep).\n#[derive(Clone, new)]\npub struct Trajectory<B: Backend, S, A, C> {\n    /// A list of ordered [TimeStep](TimeStep)s.\n    pub timesteps: Vec<TimeStep<B, S, A, C>>,\n}\n\n/// A timestep debscribing an iteration of the state/decision process.\n#[derive(Clone)]\npub struct TimeStep<B: Backend, S, A, C> {\n    /// The environment id.\n    pub env_id: usize,\n    /// The [burn_rl::Transition](burn_rl::Transition).\n    pub transition: Transition<B, S, A>,\n    /// True if the environment reaches a terminal state.\n    pub done: bool,\n    /// The running length of the current episode.\n    pub ep_len: usize,\n    /// The running cumulative reward.\n    pub cum_reward: f64,\n    /// The action's context for this timestep.\n    pub action_context: C,\n}\n\npub(crate) type RLTimeStep<B, RLC> = TimeStep<\n    B,\n    <RLC as RLComponentsTypes>::State,\n    <RLC as RLComponentsTypes>::Action,\n    <RLC as RLComponentsTypes>::ActionContext,\n>;\n\npub(crate) type RLTrajectory<B, RLC> = Trajectory<\n    B,\n    <RLC as RLComponentsTypes>::State,\n    <RLC as RLComponentsTypes>::Action,\n    <RLC as RLComponentsTypes>::ActionContext,\n>;\n\n/// Trait for a structure that implements an agent/environement interface.\npub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {\n    /// Run a certain number of timesteps.\n    ///\n    /// # Arguments\n    ///\n    /// * `num_steps` - The number of time_steps to run.\n    /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).\n    /// * `interrupter` - An [crate::Interrupter](crate::Interrupter).\n    /// * `num_steps` - The number of time_steps to run.\n    /// * `progress` - A mutable reference to the learning progress.\n    ///\n    /// # Returns\n    ///\n    /// A list of ordered timesteps.\n    fn run_steps(\n        &mut self,\n        num_steps: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTimeStep<BT, RLC>>;\n    /// Run a certain number of episodes.\n    ///\n    /// # Arguments\n    ///\n    /// * `num_episodes` - The number of episodes to run.\n    /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).\n    /// * `interrupter` - An [crate::Interrupter](crate::Interrupter).\n    /// * `progress` - A mutable reference to the learning progress.\n    ///\n    /// # Returns\n    ///\n    /// A list of ordered timesteps.\n    fn run_episodes(\n        &mut self,\n        num_episodes: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTrajectory<BT, RLC>>;\n    /// Update the runner's agent.\n    fn update_policy(&mut self, update: RLC::PolicyState);\n    /// Get the state of the runner's agent.\n    fn policy(&self) -> RLC::PolicyState;\n}\n\n/// A simple, synchronized agent/environement interface.\npub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {\n    env: RLC::Env,\n    eval: bool,\n    agent: RLC::Policy,\n    deterministic: bool,\n    current_reward: f64,\n    run_num: usize,\n    step_num: usize,\n    _backend: PhantomData<B>,\n}\n\nimpl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {\n    /// Create a new base runner.\n    pub fn new(\n        env_init: RLC::EnvInit,\n        agent: RLC::Policy,\n        eval: bool,\n        deterministic: bool,\n    ) -> Self {\n        let mut env = env_init.init();\n        env.reset();\n\n        Self {\n            env,\n            eval,\n            agent: agent.clone(),\n            deterministic,\n            current_reward: 0.0,\n            run_num: 0,\n            step_num: 0,\n            _backend: PhantomData,\n        }\n    }\n}\n\nimpl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>\nwhere\n    BT: Backend,\n    RLC: RLComponentsTypes,\n{\n    fn run_steps(\n        &mut self,\n        num_steps: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTimeStep<BT, RLC>> {\n        let mut items = vec![];\n        let device = Default::default();\n        for _ in 0..num_steps {\n            let state = self.env.state();\n            let (action, context) = self.agent.action(state.clone().into(), self.deterministic);\n\n            let step_result = self.env.step(RLC::Action::from(action.clone()));\n\n            self.current_reward += step_result.reward;\n            self.step_num += 1;\n\n            let transition = Transition::new(\n                state.clone(),\n                step_result.next_state,\n                RLC::Action::from(action),\n                Tensor::from_data([step_result.reward], &device),\n                Tensor::from_data(\n                    [(step_result.done || step_result.truncated) as i32 as f64],\n                    &device,\n                ),\n            );\n            items.push(TimeStep {\n                env_id: 0,\n                transition,\n                done: step_result.done,\n                ep_len: self.step_num,\n                cum_reward: self.current_reward,\n                action_context: context[0].clone(),\n            });\n\n            if !self.eval {\n                progress.items_processed += 1;\n                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(\n                    context[0].clone(),\n                    progress.clone(),\n                    None,\n                )));\n\n                if step_result.done {\n                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(\n                        EpisodeSummary {\n                            episode_length: self.step_num,\n                            cum_reward: self.current_reward,\n                        },\n                        progress.clone(),\n                        None,\n                    )));\n                }\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n\n            if step_result.done || step_result.truncated {\n                self.env.reset();\n                self.current_reward = 0.;\n                self.step_num = 0;\n                self.run_num += 1;\n            }\n        }\n        items\n    }\n\n    fn update_policy(&mut self, update: RLC::PolicyState) {\n        self.agent.update(update);\n    }\n\n    fn run_episodes(\n        &mut self,\n        num_episodes: usize,\n        processor: &mut RLEventProcessorType<RLC>,\n        interrupter: &Interrupter,\n        progress: &mut Progress,\n    ) -> Vec<RLTrajectory<BT, RLC>> {\n        self.env.reset();\n\n        let mut items = vec![];\n        for ep in 0..num_episodes {\n            let mut steps = vec![];\n            loop {\n                let step = self.run_steps(1, processor, interrupter, progress)[0].clone();\n                steps.push(step.clone());\n\n                if self.eval {\n                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(\n                        step.action_context.clone(),\n                        Progress::new(steps.len() + 1, steps.len() + 1),\n                        None,\n                    )));\n\n                    if step.done {\n                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(\n                            EvaluationItem::new(\n                                EpisodeSummary {\n                                    episode_length: step.ep_len,\n                                    cum_reward: step.cum_reward,\n                                },\n                                Progress::new(ep + 1, num_episodes),\n                                None,\n                            ),\n                        ));\n                    }\n                }\n\n                if interrupter.should_stop() || step.done {\n                    break;\n                }\n            }\n            items.push(Trajectory::new(steps));\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        items\n    }\n\n    fn policy(&self) -> RLC::PolicyState {\n        self.agent.state()\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::needless_range_loop)]\nmod tests {\n    use crate::{AsyncProcessorTraining, TestBackend};\n\n    use crate::learner::tests::{\n        MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents,\n    };\n\n    use super::*;\n\n    fn setup(\n        state: usize,\n        eval: bool,\n        deterministic: bool,\n    ) -> AgentEnvBaseLoop<TestBackend, MockRLComponents> {\n        let env_init = MockEnvInit;\n        let agent = MockPolicy(state);\n        AgentEnvBaseLoop::<TestBackend, MockRLComponents>::new(env_init, agent, eval, deterministic)\n    }\n\n    #[test]\n    fn test_policy_returns_agent_state() {\n        let runner = setup(1000, false, false);\n        let policy_state = runner.policy();\n        assert_eq!(policy_state.0, 1000);\n    }\n\n    #[test]\n    fn test_update_policy() {\n        let mut runner = setup(0, false, false);\n\n        runner.update_policy(MockPolicyState(1));\n        assert_eq!(runner.policy().0, 1);\n    }\n\n    #[test]\n    fn run_steps_returns_requested_number() {\n        let mut runner = setup(0, false, false);\n        let mut processor = AsyncProcessorTraining::new(MockProcessor);\n        let interrupter = Interrupter::new();\n        let mut progress = Progress {\n            items_processed: 0,\n            items_total: 1,\n        };\n\n        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);\n        assert_eq!(steps.len(), 1);\n        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);\n        assert_eq!(steps.len(), 8);\n    }\n\n    #[test]\n    fn run_episodes_returns_requested_number() {\n        let mut runner = setup(0, false, false);\n        let mut processor = AsyncProcessorTraining::new(MockProcessor);\n        let interrupter = Interrupter::new();\n        let mut progress = Progress {\n            items_processed: 0,\n            items_total: 1,\n        };\n\n        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);\n        assert_eq!(trajectories.len(), 1);\n        assert_ne!(trajectories[0].timesteps.len(), 0);\n        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);\n        assert_eq!(trajectories.len(), 8);\n        for i in 0..8 {\n            assert_ne!(trajectories[i].timesteps.len(), 0);\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/env_runner/mod.rs",
    "content": "mod async_runner;\nmod base;\n\npub use async_runner::*;\npub use base::*;\n\n#[cfg(test)]\npub(crate) mod tests {\n    use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyState};\n\n    use crate::tests::TestAutodiffBackend;\n    use crate::{\n        AgentEvaluationEvent, EventProcessorTraining, ItemLazy, RLComponentsTypes, RLEvent,\n    };\n    use burn_rl::{LearnerTransitionBatch, PolicyLearner, RLTrainOutput, StepResult};\n\n    /// Mock policy for testing\n    ///\n    /// Calling `forward()` with a [MockObservation](MockObservation) (list of f32) returns a [MockActionDistribution](MockActionDistribution)\n    /// containing a list of 0s of the same length as the observation.\n    ///\n    /// Calling `action()` with a [MockObservation](MockObservation) (list of f32) returns a [MockPolicyAction](MockPolicyAction) with a list of actions of the same length as the observation.\n    /// The actions are all 1 if the call is requested as deterministic, or else 0.\n\n    #[derive(Clone)]\n    pub(crate) struct MockPolicy(pub usize);\n\n    impl Policy<TestAutodiffBackend> for MockPolicy {\n        type Observation = MockObservation;\n        type ActionDistribution = MockActionDistribution;\n        type Action = MockPolicyAction;\n        type ActionContext = MockActionContext;\n        type PolicyState = MockPolicyState;\n\n        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {\n            let mut dists = vec![];\n            for _ in obs.0 {\n                dists.push(MockActionDistribution(vec![0.]));\n            }\n            MockActionDistribution::batch(dists)\n        }\n\n        fn action(\n            &mut self,\n            obs: Self::Observation,\n            deterministic: bool,\n        ) -> (Self::Action, Vec<Self::ActionContext>) {\n            let mut actions = vec![];\n            let mut contexts = vec![];\n\n            for _ in obs.0 {\n                if deterministic {\n                    actions.push(MockPolicyAction(vec![1]));\n                } else {\n                    actions.push(MockPolicyAction(vec![0]));\n                }\n                contexts.push(MockActionContext);\n            }\n\n            (MockPolicyAction::batch(actions), contexts)\n        }\n\n        fn update(&mut self, update: Self::PolicyState) {\n            self.0 = update.0;\n        }\n\n        fn state(&self) -> Self::PolicyState {\n            MockPolicyState(self.0)\n        }\n\n        fn load_record(\n            self,\n            _record: <Self::PolicyState as PolicyState<TestAutodiffBackend>>::Record,\n        ) -> Self {\n            self\n        }\n    }\n\n    /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockObservation(pub Vec<f32>);\n\n    /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockPolicyAction(pub Vec<i32>);\n\n    /// Mock action distribution for testing represented as a vector of i32. Can call `batch()` and `unbatch` on it.\n    #[derive(Clone)]\n    pub(crate) struct MockActionDistribution(Vec<f32>);\n\n    #[derive(Clone)]\n    pub(crate) struct MockActionContext;\n\n    /// Mock policy state for testing represented as an arbitrary `usize` that has no effect on the policy.\n    #[derive(Clone)]\n    pub(crate) struct MockPolicyState(pub usize);\n\n    impl PolicyState<TestAutodiffBackend> for MockPolicyState {\n        type Record = ();\n\n        fn into_record(self) -> Self::Record {}\n\n        fn load_record(&self, _record: Self::Record) -> Self {\n            self.clone()\n        }\n    }\n\n    impl Batchable for MockObservation {\n        fn batch(items: Vec<Self>) -> Self {\n            MockObservation(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            vec![MockObservation(self.0)]\n        }\n    }\n\n    impl Batchable for MockPolicyAction {\n        fn batch(items: Vec<Self>) -> Self {\n            MockPolicyAction(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            let mut actions = vec![];\n            for a in self.0 {\n                actions.push(MockPolicyAction(vec![a]));\n            }\n            actions\n        }\n    }\n\n    impl Batchable for MockActionDistribution {\n        fn batch(items: Vec<Self>) -> Self {\n            MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect())\n        }\n\n        fn unbatch(self) -> Vec<Self> {\n            let mut dists = vec![];\n            for _ in self.0 {\n                dists.push(MockActionDistribution(vec![0.]));\n            }\n            dists\n        }\n    }\n\n    /// Mock environment for testing\n    #[derive(Clone)]\n    pub(crate) struct MockEnv {\n        counter: usize,\n    }\n\n    #[derive(Clone, Debug)]\n    pub(crate) struct MockState;\n\n    #[derive(Clone, Debug)]\n    pub(crate) struct MockAction(pub i32);\n\n    impl From<MockState> for MockObservation {\n        fn from(_value: MockState) -> Self {\n            MockObservation(vec![0.])\n        }\n    }\n\n    impl From<MockPolicyAction> for MockAction {\n        fn from(value: MockPolicyAction) -> Self {\n            MockAction(value.0[0])\n        }\n    }\n\n    impl From<MockAction> for MockPolicyAction {\n        fn from(value: MockAction) -> Self {\n            MockPolicyAction(vec![value.0])\n        }\n    }\n\n    impl ItemLazy for MockActionContext {\n        type ItemSync = MockActionContext;\n\n        fn sync(self) -> Self::ItemSync {\n            self\n        }\n    }\n\n    impl MockEnv {\n        fn new() -> Self {\n            Self { counter: 0 }\n        }\n    }\n\n    impl Environment for MockEnv {\n        type State = MockState;\n        type Action = MockAction;\n        const MAX_STEPS: usize = 5;\n\n        fn reset(&mut self) {\n            self.counter = 0;\n        }\n\n        fn step(&mut self, _action: Self::Action) -> StepResult<Self::State> {\n            self.counter += 1;\n            let done = self.counter >= Self::MAX_STEPS;\n\n            burn_rl::StepResult {\n                next_state: MockState,\n                reward: 1.0,\n                done,\n                truncated: false,\n            }\n        }\n\n        fn state(&self) -> Self::State {\n            MockState\n        }\n    }\n\n    /// Mock environment init for testing\n    #[derive(Clone)]\n    pub(crate) struct MockEnvInit;\n\n    impl EnvironmentInit<MockEnv> for MockEnvInit {\n        fn init(&self) -> MockEnv {\n            MockEnv::new()\n        }\n    }\n\n    // Mock RLComponentsTypes for testing\n    pub(crate) struct MockRLComponents;\n\n    impl RLComponentsTypes for MockRLComponents {\n        type Backend = TestAutodiffBackend;\n        type Env = MockEnv;\n        type EnvInit = MockEnvInit;\n        type State = MockState;\n        type Action = MockAction;\n        type Policy = MockPolicy;\n        type PolicyObs = MockObservation;\n        type PolicyAD = MockActionDistribution;\n        type PolicyAction = MockPolicyAction;\n        type ActionContext = MockActionContext;\n        type PolicyState = MockPolicyState;\n        type LearningAgent = MockLearningAgent;\n        type TrainingOutput = ();\n    }\n\n    // Mock learning agent for testing\n    #[derive(Clone)]\n    pub(crate) struct MockLearningAgent;\n\n    impl PolicyLearner<TestAutodiffBackend> for MockLearningAgent {\n        type InnerPolicy = MockPolicy;\n        type TrainContext = ();\n        type Record = ();\n\n        fn train(\n            &mut self,\n            _input: LearnerTransitionBatch<TestAutodiffBackend, Self::InnerPolicy>,\n        ) -> RLTrainOutput<\n            Self::TrainContext,\n            <Self::InnerPolicy as Policy<TestAutodiffBackend>>::PolicyState,\n        > {\n            unimplemented!()\n        }\n\n        fn policy(&self) -> Self::InnerPolicy {\n            unimplemented!()\n        }\n\n        fn update_policy(&mut self, _update: Self::InnerPolicy) {\n            unimplemented!()\n        }\n\n        fn record(&self) -> Self::Record {\n            unimplemented!()\n        }\n\n        fn load_record(self, _record: Self::Record) -> Self {\n            unimplemented!()\n        }\n    }\n\n    // Mock event processor for testing\n    pub(crate) struct MockProcessor;\n\n    impl\n        EventProcessorTraining<\n            RLEvent<(), MockActionContext>,\n            AgentEvaluationEvent<MockActionContext>,\n        > for MockProcessor\n    {\n        fn process_train(&mut self, _event: RLEvent<(), MockActionContext>) {\n            // Mock process train\n        }\n\n        fn process_valid(&mut self, _event: AgentEvaluationEvent<MockActionContext>) {\n            // Mock process valid\n        }\n\n        fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {\n            unimplemented!()\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/mod.rs",
    "content": "mod checkpointer;\nmod components;\nmod env_runner;\nmod off_policy;\nmod output;\nmod paradigm;\nmod strategy;\n\npub use checkpointer::*;\npub use components::*;\npub use env_runner::*;\npub use off_policy::*;\npub use output::*;\npub use paradigm::*;\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/off_policy.rs",
    "content": "use std::marker::PhantomData;\n\nuse crate::{\n    AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,\n    EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,\n    RLEventProcessorType, RLStrategy,\n};\nuse burn_core::{self as burn};\nuse burn_core::{config::Config, data::dataloader::Progress};\nuse burn_ndarray::NdArray;\nuse burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};\n\n/// Parameters of an on policy training with multi environments and double-batching.\n#[derive(Config, Debug)]\npub struct OffPolicyConfig {\n    /// The number of environments to run simultaneously for experience collection.\n    #[config(default = 1)]\n    pub num_envs: usize,\n    /// Number of environment state to accumulate before running one step of inference with the policy.\n    /// Must be equal or less than the number of simultaneous environments.\n    #[config(default = 1)]\n    pub autobatch_size: usize,\n    /// Max number of transitions stored in the replay buffer.\n    #[config(default = 1024)]\n    pub replay_buffer_size: usize,\n    /// The number of steps to collect between each step of training.\n    #[config(default = 1)]\n    pub train_interval: usize,\n    /// Number of optimization steps done each `train_interval`.\n    #[config(default = 1)]\n    pub train_steps: usize,\n    /// The number of steps to collect between each evaluation.\n    #[config(default = 10_000)]\n    pub eval_interval: usize,\n    /// The number of episodes to run for each evaluation.\n    #[config(default = 1)]\n    pub eval_episodes: usize,\n    /// The number of transition to train on.\n    #[config(default = 32)]\n    pub train_batch_size: usize,\n    /// Number of steps to collect before starting to train.\n    #[config(default = 0)]\n    pub warmup_steps: usize,\n}\n\n/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.\npub struct OffPolicyStrategy<RLC: RLComponentsTypes> {\n    config: OffPolicyConfig,\n    _components: PhantomData<RLC>,\n}\nimpl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {\n    /// Create a new off-policy base strategy.\n    pub fn new(config: OffPolicyConfig) -> Self {\n        Self {\n            config,\n            _components: PhantomData,\n        }\n    }\n}\n\nimpl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>\nwhere\n    RLC: RLComponentsTypes,\n    RLC::PolicyObs: SliceAccess<RLC::Backend>,\n    RLC::PolicyAction: SliceAccess<RLC::Backend>,\n{\n    fn train_loop(\n        &self,\n        training_components: RLComponents<RLC>,\n        learner_agent: &mut RLC::LearningAgent,\n        starting_epoch: usize,\n        env_init: RLC::EnvInit,\n    ) -> (RLC::Policy, RLEventProcessorType<RLC>) {\n        let mut event_processor = training_components.event_processor;\n        let mut checkpointer = training_components.checkpointer;\n        let num_steps_total = training_components.num_steps;\n\n        let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(\n            self.config.num_envs,\n            env_init.clone(),\n            AsyncPolicy::new(\n                self.config.num_envs.min(self.config.autobatch_size),\n                learner_agent.policy(),\n            ),\n            false,\n            false,\n            &Default::default(),\n        );\n        let runner_config = AsyncAgentEnvLoopConfig {\n            eval: true,\n            deterministic: true,\n            id: 0,\n        };\n        let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(\n            env_init,\n            AsyncPolicy::new(1, learner_agent.policy()),\n            runner_config,\n            &Default::default(),\n            None,\n            None,\n        );\n\n        let device: <RLC::Backend as burn_core::prelude::Backend>::Device = Default::default();\n        let mut transition_buffer = TransitionBuffer::<\n            RLC::Backend,\n            RLC::PolicyObs,\n            RLC::PolicyAction,\n        >::new(self.config.replay_buffer_size, &device);\n\n        let mut valid_next = self.config.eval_interval + starting_epoch - 1;\n        let mut progress = Progress {\n            items_processed: starting_epoch,\n            items_total: num_steps_total,\n        };\n\n        let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =\n            None;\n        while progress.items_processed < num_steps_total {\n            if training_components.interrupter.should_stop() {\n                let reason = training_components\n                    .interrupter\n                    .get_message()\n                    .unwrap_or(String::from(\"Reason unknown\"));\n                log::info!(\"Training interrupted: {reason}\");\n                break;\n            }\n\n            let previous_steps = progress.items_processed;\n            let items = env_runner.run_steps(\n                self.config.train_interval,\n                &mut event_processor,\n                &training_components.interrupter,\n                &mut progress,\n            );\n\n            for item in &items {\n                let t = &item.transition;\n                let state: RLC::PolicyObs = t.state.clone().into();\n                let next_state: RLC::PolicyObs = t.next_state.clone().into();\n                let action: RLC::PolicyAction = t.action.clone().into();\n                let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];\n                let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;\n                transition_buffer.push(state, next_state, action, reward, done);\n            }\n\n            if transition_buffer.len() >= self.config.train_batch_size\n                && progress.items_processed >= self.config.warmup_steps\n            {\n                if let Some(ref u) = intermediary_update {\n                    env_runner.update_policy(u.clone());\n                }\n                for _ in 0..self.config.train_steps {\n                    let batch = transition_buffer.sample(self.config.train_batch_size);\n                    let train_item = learner_agent.train(batch);\n                    intermediary_update = Some(train_item.policy);\n\n                    event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(\n                        train_item.item,\n                        progress.clone(),\n                        None,\n                    )));\n                }\n            }\n\n            if valid_next > previous_steps && valid_next <= progress.items_processed {\n                env_runner_valid.update_policy(learner_agent.policy().state());\n                env_runner_valid.run_episodes(\n                    self.config.eval_episodes,\n                    &mut event_processor,\n                    &training_components.interrupter,\n                    &mut progress,\n                );\n\n                if let Some(checkpointer) = &mut checkpointer {\n                    checkpointer.checkpoint(\n                        &env_runner.policy(),\n                        learner_agent,\n                        valid_next,\n                        &training_components.event_store,\n                    );\n                }\n\n                valid_next += self.config.eval_interval;\n            }\n        }\n\n        (learner_agent.policy(), event_processor)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/output.rs",
    "content": "use crate::{\n    ItemLazy,\n    metric::{Adaptor, CumulativeRewardInput, EpisodeLengthInput},\n};\n\n/// Summary of an episode.\npub struct EpisodeSummary {\n    /// The total length of the episode.\n    pub episode_length: usize,\n    /// The final cumulative reward.\n    pub cum_reward: f64,\n}\n\nimpl ItemLazy for EpisodeSummary {\n    type ItemSync = EpisodeSummary;\n\n    fn sync(self) -> Self::ItemSync {\n        self\n    }\n}\n\nimpl Adaptor<EpisodeLengthInput> for EpisodeSummary {\n    fn adapt(&self) -> EpisodeLengthInput {\n        EpisodeLengthInput::new(self.episode_length as f64)\n    }\n}\n\nimpl Adaptor<CumulativeRewardInput> for EpisodeSummary {\n    fn adapt(&self) -> CumulativeRewardInput {\n        CumulativeRewardInput::new(self.cum_reward)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/paradigm.rs",
    "content": "use crate::checkpoint::{\n    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,\n    KeepLastNCheckpoints, MetricCheckpointingStrategy,\n};\nuse crate::learner::base::Interrupter;\nuse crate::logger::{FileMetricLogger, MetricLogger};\nuse crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};\nuse crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};\nuse crate::renderer::{MetricsRenderer, default_renderer};\nuse crate::{\n    ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,\n    LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,\n    RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,\n    RLPolicyRecord, RLStrategy,\n};\nuse crate::{EpisodeSummary, RLStrategies};\nuse burn_core::record::FileRecorder;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess};\nuse std::collections::BTreeSet;\nuse std::path::{Path, PathBuf};\nuse std::sync::Arc;\n\n/// Structure to configure and launch reinforcement learning trainings.\npub struct RLTraining<RLC: RLComponentsTypes> {\n    // Not that complex. Extracting into yet another type would only make it more confusing.\n    #[allow(clippy::type_complexity)]\n    checkpointers: Option<(\n        AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,\n        AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,\n    )>,\n    num_steps: usize,\n    checkpoint: Option<usize>,\n    directory: PathBuf,\n    grad_accumulation: Option<usize>,\n    renderer: Option<Box<dyn MetricsRenderer + 'static>>,\n    metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,\n    event_store: LogEventStore,\n    interrupter: Interrupter,\n    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    checkpointer_strategy: Box<dyn CheckpointingStrategy>,\n    learning_strategy: RLStrategies<RLC>,\n    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order\n    summary_metrics: BTreeSet<String>,\n    summary: bool,\n    env_initializer: RLC::EnvInit,\n}\n\nimpl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>\nwhere\n    B: AutodiffBackend,\n    E: Environment + 'static,\n    EI: EnvironmentInit<E> + Send + 'static,\n    A: PolicyLearner<B> + Send + 'static,\n    A::TrainContext: ItemLazy + Clone + Send,\n    A::InnerPolicy: Policy<B> + Send,\n    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,\n    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,\n    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,\n    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,\n    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>\n        + Into<<A::InnerPolicy as Policy<B>>::Action>\n        + Clone\n        + Send\n        + 'static,\n{\n    /// Creates a new runner for reinforcement learning.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory to save the checkpoints.\n    /// * `env_init` - Specifies how to initialize the environment.\n    pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {\n        let directory = directory.as_ref().to_path_buf();\n        let experiment_log_file = directory.join(\"experiment.log\");\n        Self {\n            num_steps: 1,\n            checkpoint: None,\n            checkpointers: None,\n            directory,\n            grad_accumulation: None,\n            metrics: RLMetrics::default(),\n            event_store: LogEventStore::default(),\n            renderer: None,\n            interrupter: Interrupter::new(),\n            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(\n                experiment_log_file,\n            ))),\n            checkpointer_strategy: Box::new(\n                ComposedCheckpointingStrategy::builder()\n                    .add(KeepLastNCheckpoints::new(2))\n                    .add(MetricCheckpointingStrategy::new(\n                        &EpisodeLengthMetric::new(), // default to evaluations' cumulative reward.\n                        Aggregate::Mean,\n                        Direction::Lowest,\n                        Split::Valid,\n                    ))\n                    .build(),\n            ),\n            learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),\n            summary_metrics: BTreeSet::new(),\n            summary: false,\n            env_initializer,\n        }\n    }\n}\n\nimpl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {\n    /// Replace the default learning strategy (Off Policy learning) with the provided one.\n    ///\n    /// # Arguments\n    ///\n    /// * `training_strategy` - The training strategy.\n    pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {\n        self.learning_strategy = learning_strategy;\n        self\n    }\n\n    /// Replace the default metric loggers with the provided ones.\n    ///\n    /// # Arguments\n    ///\n    /// * `logger` - The training logger.\n    pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self\n    where\n        ML: MetricLogger + 'static,\n    {\n        self.event_store.register_logger(logger);\n        self\n    }\n\n    /// Update the checkpointing_strategy.\n    pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(\n        mut self,\n        strategy: CS,\n    ) -> Self {\n        self.checkpointer_strategy = Box::new(strategy);\n        self\n    }\n\n    /// Replace the default CLI renderer with a custom one.\n    ///\n    /// # Arguments\n    ///\n    /// * `renderer` - The custom renderer.\n    pub fn renderer<MR>(mut self, renderer: MR) -> Self\n    where\n        MR: MetricsRenderer + 'static,\n    {\n        self.renderer = Some(Box::new(renderer));\n        self\n    }\n\n    /// Register numerical metrics for a training step of the agent.\n    pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register textual metrics for a training step of the agent.\n    pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register numerical metrics for each action of the agent.\n    pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register textual metrics for each action of the agent.\n    pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register numerical metrics for a completed episode.\n    pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register textual metrics for a completed episode.\n    pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register a textual metric for a training step.\n    pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self\n    where\n        <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.metrics.register_text_metric_train(metric);\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.\n    pub fn metric_train<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + Numeric + 'static,\n        <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_metric_train(metric);\n        self\n    }\n\n    /// Register a textual metric for each action taken by the agent.\n    pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self\n    where\n        <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.metrics.register_text_metric_agent(metric.clone());\n        self.metrics.register_text_metric_agent_valid(metric);\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.\n    pub fn metric_agent<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + Numeric + 'static,\n        <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_agent_metric(metric.clone());\n        self.metrics.register_agent_metric_valid(metric);\n        self\n    }\n\n    /// Register a textual metric for a completed episode.\n    pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self\n    where\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        self.metrics.register_text_metric_episode(metric.clone());\n        self.metrics.register_text_metric_episode_valid(metric);\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.\n    pub fn metric_episode<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + Numeric + 'static,\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_episode_metric(metric.clone());\n        self.metrics.register_episode_metric_valid(metric);\n        self\n    }\n\n    /// The number of environment steps to train for.\n    pub fn num_steps(mut self, num_steps: usize) -> Self {\n        self.num_steps = num_steps;\n        self\n    }\n\n    /// The step from which the training must resume.\n    pub fn checkpoint(mut self, checkpoint: usize) -> Self {\n        self.checkpoint = Some(checkpoint);\n        self\n    }\n\n    /// Provides a handle that can be used to interrupt training.\n    pub fn interrupter(&self) -> Interrupter {\n        self.interrupter.clone()\n    }\n\n    /// Override the handle for stopping training with an externally provided handle\n    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {\n        self.interrupter = interrupter;\n        self\n    }\n\n    /// By default, Rust logs are captured and written into\n    /// `experiment.log`. If disabled, standard Rust log handling\n    /// will apply.\n    pub fn with_application_logger(\n        mut self,\n        logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    ) -> Self {\n        self.tracing_logger = logger;\n        self\n    }\n\n    /// Register a checkpointer that will save the environment runner's [policy](Policy)\n    /// and the [PolicyLearner](PolicyLearner) state to different files.\n    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self\n    where\n        FR: FileRecorder<RLC::Backend> + 'static,\n        FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,\n    {\n        let checkpoint_dir = self.directory.join(\"checkpoint\");\n        let checkpointer_policy =\n            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, \"policy\");\n        let checkpointer_learning =\n            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, \"learning-agent\");\n\n        self.checkpointers = Some((\n            AsyncCheckpointer::new(checkpointer_policy),\n            AsyncCheckpointer::new(checkpointer_learning),\n        ));\n\n        self\n    }\n\n    /// Enable the training summary report.\n    ///\n    /// The summary will be displayed after `.launch()`, when the renderer is dropped.\n    pub fn summary(mut self) -> Self {\n        self.summary = true;\n        self\n    }\n\n    /// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.\n    pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy>\n    where\n        RLC::PolicyObs: SliceAccess<RLC::Backend>,\n        RLC::PolicyAction: SliceAccess<RLC::Backend>,\n    {\n        if self.tracing_logger.is_some()\n            && let Err(e) = self.tracing_logger.as_ref().unwrap().install()\n        {\n            log::warn!(\"Failed to install the experiment logger: {e}\");\n        }\n        let renderer = self\n            .renderer\n            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));\n\n        if !self.event_store.has_loggers() {\n            self.event_store\n                .register_logger(FileMetricLogger::new(self.directory.clone()));\n        }\n\n        let event_store = Arc::new(EventStoreClient::new(self.event_store));\n        let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(\n            self.metrics,\n            renderer,\n            event_store.clone(),\n        ));\n\n        let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {\n            RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)\n        });\n\n        let summary = if self.summary {\n            Some(LearnerSummaryConfig {\n                directory: self.directory,\n                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),\n            })\n        } else {\n            None\n        };\n\n        let components = RLComponents::<RLC> {\n            checkpoint: self.checkpoint,\n            checkpointer,\n            interrupter: self.interrupter,\n            event_processor,\n            event_store,\n            num_steps: self.num_steps,\n            grad_accumulation: self.grad_accumulation,\n            summary,\n        };\n\n        match self.learning_strategy {\n            RLStrategies::OffPolicyStrategy(config) => {\n                let strategy = OffPolicyStrategy::new(config);\n                strategy.train(learner_agent, components, self.env_initializer)\n            }\n            RLStrategies::Custom(strategy) => {\n                strategy.train(learner_agent, components, self.env_initializer)\n            }\n        }\n    }\n}\n\n/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).\npub struct RLResult<P> {\n    /// The learned policy.\n    pub policy: P,\n    /// The renderer that can be used for follow up training and evaluation.\n    pub renderer: Box<dyn MetricsRenderer>,\n}\n\n/// Trait to fake variadic generics for train step metrics.\npub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\n/// Trait to fake variadic generics for train step text metrics.\npub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\n/// Trait to fake variadic generics for env step metrics.\npub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\n/// Trait to fake variadic generics for env step text metrics.\npub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\n/// Trait to fake variadic generics for episode metrics.\npub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\n/// Trait to fake variadic generics for episode text metrics.\npub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;\n}\n\nmacro_rules! gen_tuple {\n    ($($M:ident),*) => {\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.text_metric_train($M.clone());)*\n                builder\n            }\n        }\n\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + Numeric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_train($M.clone());)*\n                builder\n            }\n        }\n\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.text_metric_agent($M.clone());)*\n                builder\n            }\n        }\n\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + Numeric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_agent($M.clone());)*\n                builder\n            }\n        }\n\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*\n            $($M: Metric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.text_metric_episode($M.clone());)*\n                builder\n            }\n        }\n\n        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)\n        where\n            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*\n            $($M: Metric + Numeric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: RLTraining<RLC>,\n            ) -> RLTraining<RLC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_episode($M.clone());)*\n                builder\n            }\n        }\n    };\n}\n\ngen_tuple!(M1);\ngen_tuple!(M1, M2);\ngen_tuple!(M1, M2, M3);\ngen_tuple!(M1, M2, M3, M4);\ngen_tuple!(M1, M2, M3, M4, M5);\ngen_tuple!(M1, M2, M3, M4, M5, M6);\n"
  },
  {
    "path": "crates/burn-train/src/learner/rl/strategy.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,\n    RLEventProcessorType, RLResult,\n    metric::{processor::EventProcessorTraining, store::EventStoreClient},\n};\n\n/// Struct to minimise parameters passed to [RLStrategy::train].\npub struct RLComponents<RLC: RLComponentsTypes> {\n    /// The total number of environment steps.\n    pub num_steps: usize,\n    /// The step number from which to continue the training.\n    pub checkpoint: Option<usize>,\n    /// A checkpointer used to load and save learning checkpoints.\n    pub checkpointer: Option<RLCheckpointer<RLC>>,\n    /// Enables gradients accumulation.\n    pub grad_accumulation: Option<usize>,\n    /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.\n    pub interrupter: Interrupter,\n    /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.\n    pub event_processor: RLEventProcessorType<RLC>,\n    /// A reference to an [EventStoreClient](EventStoreClient).\n    pub event_store: Arc<EventStoreClient>,\n    /// Config for creating a summary of the learning\n    pub summary: Option<LearnerSummaryConfig>,\n}\n\n/// The strategy for reinforcement learning.\n#[derive(Clone)]\npub enum RLStrategies<RLC: RLComponentsTypes> {\n    /// Training on one device\n    OffPolicyStrategy(OffPolicyConfig),\n    /// Training using a custom learning strategy\n    Custom(CustomRLStrategy<RLC>),\n}\n\n/// A reference to an implementation of [RLStrategy].\npub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;\n\n/// Provides the `fit` function for any learning strategy\npub trait RLStrategy<RLC: RLComponentsTypes> {\n    /// Train the learner agent with this strategy.\n    fn train(\n        &self,\n        mut learner_agent: RLC::LearningAgent,\n        mut training_components: RLComponents<RLC>,\n        env_init: RLC::EnvInit,\n    ) -> RLResult<RLC::Policy> {\n        let starting_epoch = match training_components.checkpoint {\n            Some(checkpoint) => {\n                if let Some(checkpointer) = &mut training_components.checkpointer {\n                    learner_agent = checkpointer.load_checkpoint(\n                        learner_agent,\n                        &Default::default(),\n                        checkpoint,\n                    );\n                }\n                checkpoint + 1\n            }\n            None => 1,\n        };\n\n        let summary_config = training_components.summary.clone();\n\n        // Event processor start training\n        training_components\n            .event_processor\n            .process_train(RLEvent::Start);\n\n        // Training loop\n        let (policy, mut event_processor) = self.train_loop(\n            training_components,\n            &mut learner_agent,\n            starting_epoch,\n            env_init,\n        );\n\n        let summary = summary_config.and_then(|summary| summary.init().ok());\n\n        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.\n        // TODO: summary makes sense for RL?\n        event_processor.process_train(RLEvent::End(summary));\n\n        // let model = model.valid();\n        let renderer = event_processor.renderer();\n\n        RLResult { policy, renderer }\n    }\n\n    /// Training loop for this strategy\n    fn train_loop(\n        &self,\n        training_components: RLComponents<RLC>,\n        learner_agent: &mut RLC::LearningAgent,\n        starting_epoch: usize,\n        env_init: RLC::EnvInit,\n    ) -> (RLC::Policy, RLEventProcessorType<RLC>);\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/sequence.rs",
    "content": "use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput};\nuse crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{Int, Tensor, Transaction};\nuse burn_ndarray::NdArray;\n\n/// Sequence prediction output adapted for multiple metrics.\n///\n/// Supported metrics:\n/// - Accuracy\n/// - TopKAccuracy\n/// - Perplexity\n/// - Loss\n/// - CER\n/// - WER\n#[derive(new)]\npub struct SequenceOutput<B: Backend> {\n    /// The loss.\n    pub loss: Tensor<B, 1>,\n\n    /// Raw logits. Shape: `[batch_size, seq_len, vocab_size]`\n    pub logits: Tensor<B, 3>,\n\n    /// Optional predicted token indices. Shape: `[batch_size, seq_length]`.\n    /// If not provided, predictions default to argmax of `logits` along the last dimension.\n    pub predictions: Option<Tensor<B, 2, Int>>,\n\n    /// The target token indices. Shape: `[batch_size, seq_length]`\n    pub targets: Tensor<B, 2, Int>,\n}\n\nimpl<B: Backend> SequenceOutput<B> {\n    fn predicted_tokens(&self) -> Tensor<B, 2, Int> {\n        match &self.predictions {\n            Some(preds) => preds.clone(),\n            None => self.logits.clone().argmax(2).squeeze_dim::<2>(2),\n        }\n    }\n\n    fn flat_logits(&self) -> Tensor<B, 2> {\n        let [batch_size, seq_len, vocab_size] = self.logits.dims();\n        self.logits\n            .clone()\n            .reshape([batch_size * seq_len, vocab_size])\n    }\n\n    fn flat_targets(&self) -> Tensor<B, 1, Int> {\n        let [batch_size, seq_len] = self.targets.dims();\n        self.targets.clone().reshape([batch_size * seq_len])\n    }\n}\n\nimpl<B: Backend> ItemLazy for SequenceOutput<B> {\n    type ItemSync = SequenceOutput<NdArray>;\n\n    fn sync(self) -> Self::ItemSync {\n        let device = &Default::default();\n\n        match self.predictions {\n            Some(preds) => {\n                let [logits, loss, targets, predictions] = Transaction::default()\n                    .register(self.logits)\n                    .register(self.loss)\n                    .register(self.targets)\n                    .register(preds)\n                    .execute()\n                    .try_into()\n                    .expect(\"Correct amount of tensor data\");\n\n                SequenceOutput {\n                    logits: Tensor::from_data(logits, device),\n                    loss: Tensor::from_data(loss, device),\n                    targets: Tensor::from_data(targets, device),\n                    predictions: Some(Tensor::from_data(predictions, device)),\n                }\n            }\n            None => {\n                let [logits, loss, targets] = Transaction::default()\n                    .register(self.logits)\n                    .register(self.loss)\n                    .register(self.targets)\n                    .execute()\n                    .try_into()\n                    .expect(\"Correct amount of tensor data\");\n\n                SequenceOutput {\n                    logits: Tensor::from_data(logits, device),\n                    loss: Tensor::from_data(loss, device),\n                    targets: Tensor::from_data(targets, device),\n                    predictions: None,\n                }\n            }\n        }\n    }\n}\n\nimpl<B: Backend> Adaptor<LossInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> LossInput<B> {\n        LossInput::new(self.loss.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<CerInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> CerInput<B> {\n        CerInput::new(self.predicted_tokens(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<WerInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> WerInput<B> {\n        WerInput::new(self.predicted_tokens(), self.targets.clone())\n    }\n}\n\nimpl<B: Backend> Adaptor<AccuracyInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> AccuracyInput<B> {\n        AccuracyInput::new(self.flat_logits(), self.flat_targets())\n    }\n}\n\nimpl<B: Backend> Adaptor<TopKAccuracyInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> TopKAccuracyInput<B> {\n        TopKAccuracyInput::new(self.flat_logits(), self.flat_targets())\n    }\n}\n\nimpl<B: Backend> Adaptor<PerplexityInput<B>> for SequenceOutput<B> {\n    fn adapt(&self) -> PerplexityInput<B> {\n        PerplexityInput::new(self.flat_logits(), self.flat_targets())\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/summary.rs",
    "content": "use core::cmp::Ordering;\nuse std::{\n    collections::{HashMap, hash_map::Entry},\n    fmt::Display,\n    path::{Path, PathBuf},\n};\n\nuse crate::{\n    logger::FileMetricLogger,\n    metric::store::{Aggregate, EventStore, LogEventStore, Split},\n};\n\n/// Contains the metric value at a given time.\n#[derive(Debug)]\npub struct MetricEntry {\n    /// The step at which the metric was recorded (i.e., epoch).\n    pub step: usize,\n    /// The metric value.\n    pub value: f64,\n}\n\n/// Contains the summary of recorded values for a given metric.\n#[derive(Debug)]\npub struct MetricSummary {\n    /// The metric name.\n    pub name: String,\n    /// The metric entries.\n    pub entries: Vec<MetricEntry>,\n}\n\nimpl MetricSummary {\n    fn collect<E: EventStore>(\n        event_store: &mut E,\n        metric: &str,\n        split: &Split,\n        num_epochs: usize,\n    ) -> Option<Self> {\n        let entries = (1..=num_epochs)\n            .filter_map(|epoch| {\n                event_store\n                    .find_metric(metric, epoch, Aggregate::Mean, split)\n                    .map(|value| MetricEntry { step: epoch, value })\n            })\n            .collect::<Vec<_>>();\n\n        if entries.is_empty() {\n            None\n        } else {\n            Some(Self {\n                name: metric.to_string(),\n                entries,\n            })\n        }\n    }\n}\n\n/// Contains the summary of recorded metrics for the training and validation steps.\npub struct SummaryMetrics {\n    /// Training metrics summary.\n    pub train: Vec<MetricSummary>,\n    /// Validation metrics summary.\n    pub valid: Vec<MetricSummary>,\n    /// Test metrics summary per test split tag.\n    ///\n    /// Each key corresponds to a `Split::Test(Some(tag))`.\n    /// The empty string represents `Split::Test(None)`.\n    pub test: HashMap<String, Vec<MetricSummary>>,\n}\n\n/// Detailed training summary.\npub struct LearnerSummary {\n    /// The number of epochs completed.\n    pub epochs: usize,\n    /// The summary of recorded metrics during training.\n    pub metrics: SummaryMetrics,\n    /// The model name (only recorded within the learner).\n    pub(crate) model: Option<String>,\n}\n\nimpl LearnerSummary {\n    /// Creates a new learner summary for the specified metrics.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory containing the training artifacts (checkpoints and logs).\n    /// * `metrics` - The list of metrics to collect for the summary.\n    pub fn new<S: AsRef<str>>(directory: impl AsRef<Path>, metrics: &[S]) -> Result<Self, String> {\n        let directory = directory.as_ref();\n        if !directory.exists() {\n            return Err(format!(\n                \"Artifact directory does not exist at: {}\",\n                directory.display()\n            ));\n        }\n\n        let mut event_store = LogEventStore::default();\n        let train_split = Split::Train;\n        let valid_split = Split::Valid;\n\n        let logger = FileMetricLogger::new(directory);\n        let test_split_root = logger.split_dir(&Split::Test(None));\n        if !logger.split_exists(&train_split)\n            && !logger.split_exists(&valid_split)\n            && test_split_root.is_none()\n        {\n            return Err(format!(\n                \"No training, validation or test artifacts found at: {}\",\n                directory.display()\n            ));\n        }\n\n        // Number of recorded epochs\n        let epochs = logger.epochs();\n\n        event_store.register_logger(logger);\n\n        let train_summary = metrics\n            .iter()\n            .filter_map(|metric| {\n                MetricSummary::collect(&mut event_store, metric.as_ref(), &train_split, epochs)\n            })\n            .collect::<Vec<_>>();\n\n        let valid_summary = metrics\n            .iter()\n            .filter_map(|metric| {\n                MetricSummary::collect(&mut event_store, metric.as_ref(), &valid_split, epochs)\n            })\n            .collect::<Vec<_>>();\n\n        let test_summary = match test_split_root {\n            Some(root) => collect_test_split_metrics(root, metrics, &mut event_store, epochs),\n            None => Default::default(),\n        };\n\n        Ok(Self {\n            epochs,\n            metrics: SummaryMetrics {\n                train: train_summary,\n                valid: valid_summary,\n                test: test_summary,\n            },\n            model: None,\n        })\n    }\n\n    pub(crate) fn with_model(mut self, name: String) -> Self {\n        self.model = Some(name);\n        self\n    }\n\n    /// Merges another summary into this one, combining all metric entries.\n    pub(crate) fn merge(mut self, other: LearnerSummary) -> Self {\n        fn merge_metrics(\n            base: Vec<MetricSummary>,\n            incoming: Vec<MetricSummary>,\n        ) -> Vec<MetricSummary> {\n            let mut map: HashMap<String, MetricSummary> =\n                base.into_iter().map(|m| (m.name.clone(), m)).collect();\n\n            for metric in incoming {\n                match map.entry(metric.name.clone()) {\n                    Entry::Occupied(mut entry) => {\n                        entry.get_mut().entries.extend(metric.entries);\n                    }\n                    Entry::Vacant(entry) => {\n                        entry.insert(metric);\n                    }\n                }\n            }\n            map.into_values().collect()\n        }\n\n        self.metrics.train = merge_metrics(self.metrics.train, other.metrics.train);\n        self.metrics.valid = merge_metrics(self.metrics.valid, other.metrics.valid);\n\n        for (tag, metrics) in other.metrics.test {\n            match self.metrics.test.entry(tag) {\n                Entry::Occupied(mut entry) => {\n                    let current = std::mem::take(entry.get_mut());\n                    let merged = merge_metrics(current, metrics);\n                    *entry.get_mut() = merged;\n                }\n                Entry::Vacant(entry) => {\n                    entry.insert(metrics);\n                }\n            }\n        }\n\n        if self.model != other.model {\n            self.model = None;\n        }\n\n        self\n    }\n}\n\nfn collect_test_split_metrics<P: AsRef<Path>, S: AsRef<str>>(\n    root: P,\n    metrics: &[S],\n    event_store: &mut LogEventStore,\n    epochs: usize,\n) -> HashMap<String, Vec<MetricSummary>> {\n    // Collect immediate child directories\n    let dirs = match std::fs::read_dir(root) {\n        Ok(entries) => entries\n            .filter_map(|entry| {\n                let entry = entry.ok()?;\n                let file_type = entry.file_type().ok()?;\n                if file_type.is_dir() {\n                    Some(entry.file_name().to_string_lossy().to_string())\n                } else {\n                    None\n                }\n            })\n            .collect::<Vec<_>>(),\n        Err(_) => Vec::new(),\n    };\n\n    let mut map = HashMap::new();\n\n    if dirs.is_empty() {\n        return map;\n    }\n\n    // Detect if all directories are epoch directories\n    let all_epochs = dirs.iter().all(FileMetricLogger::is_epoch_dir);\n\n    if all_epochs {\n        // Single untagged test split\n        let split = Split::Test(None);\n\n        let summaries = metrics\n            .iter()\n            .filter_map(|metric| {\n                MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)\n            })\n            .collect::<Vec<_>>();\n\n        // Untagged marked with empty string\n        map.insert(\"\".to_string(), summaries);\n    } else {\n        // Tagged splits\n        for tag in dirs {\n            let split = Split::Test(Some(tag.clone().into()));\n\n            let summaries = metrics\n                .iter()\n                .filter_map(|metric| {\n                    MetricSummary::collect(event_store, metric.as_ref(), &split, epochs)\n                })\n                .collect::<Vec<_>>();\n\n            map.insert(tag, summaries);\n        }\n    }\n\n    map\n}\n\nimpl Display for LearnerSummary {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        // Compute the max length for each column\n        let mut max_split_len = 5; // \"Train\"\n        let mut max_metric_len = \"Metric\".len();\n        for metric in self.metrics.train.iter() {\n            max_metric_len = max_metric_len.max(metric.name.len());\n        }\n        for metric in self.metrics.valid.iter() {\n            max_metric_len = max_metric_len.max(metric.name.len());\n        }\n        for (tag, metrics) in self.metrics.test.iter() {\n            let split_name = if tag.is_empty() {\n                \"Test\".to_string()\n            } else {\n                format!(\"Test ({tag})\")\n            };\n\n            max_split_len = max_split_len.max(split_name.len());\n\n            for metric in metrics {\n                max_metric_len = max_metric_len.max(metric.name.len());\n            }\n        }\n\n        // Summary header\n        writeln!(\n            f,\n            \"{:=>width_symbol$} Learner Summary {:=>width_symbol$}\",\n            \"\",\n            \"\",\n            width_symbol = 24,\n        )?;\n\n        if let Some(model) = &self.model {\n            writeln!(f, \"Model:\\n{model}\")?;\n        }\n        writeln!(f, \"Total Epochs: {epochs}\\n\\n\", epochs = self.epochs)?;\n\n        // Metrics table header\n        writeln!(\n            f,\n            \"| {:<width_split$} | {:<width_metric$} | Min.     | Epoch    | Max.     | Epoch    |\\n|{:->width_split$}--|{:->width_metric$}--|----------|----------|----------|----------|\",\n            \"Split\",\n            \"Metric\",\n            \"\",\n            \"\",\n            width_split = max_split_len,\n            width_metric = max_metric_len,\n        )?;\n\n        // Table entries\n        fn cmp_f64(a: &f64, b: &f64) -> Ordering {\n            match (a.is_nan(), b.is_nan()) {\n                (true, true) => Ordering::Equal,\n                (true, false) => Ordering::Greater,\n                (false, true) => Ordering::Less,\n                _ => a.partial_cmp(b).unwrap(),\n            }\n        }\n\n        fn fmt_val(val: f64) -> String {\n            if val < 1e-2 {\n                // Use scientific notation for small values which would otherwise be truncated\n                format!(\"{val:<9.3e}\")\n            } else {\n                format!(\"{val:<9.3}\")\n            }\n        }\n\n        let mut write_metrics_summary =\n            |metrics: &[MetricSummary], split: String| -> std::fmt::Result {\n                for metric in metrics.iter() {\n                    if metric.entries.is_empty() {\n                        continue; // skip metrics with no recorded values\n                    }\n\n                    // Compute the min & max for each metric\n                    let metric_min = metric\n                        .entries\n                        .iter()\n                        .min_by(|a, b| cmp_f64(&a.value, &b.value))\n                        .unwrap();\n                    let metric_max = metric\n                        .entries\n                        .iter()\n                        .max_by(|a, b| cmp_f64(&a.value, &b.value))\n                        .unwrap();\n\n                    writeln!(\n                        f,\n                        \"| {:<width_split$} | {:<width_metric$} | {}| {:<9?}| {}| {:<9?}|\",\n                        split,\n                        metric.name,\n                        fmt_val(metric_min.value),\n                        metric_min.step,\n                        fmt_val(metric_max.value),\n                        metric_max.step,\n                        width_split = max_split_len,\n                        width_metric = max_metric_len,\n                    )?;\n                }\n\n                Ok(())\n            };\n\n        write_metrics_summary(&self.metrics.train, format!(\"{:?}\", Split::Train))?;\n        write_metrics_summary(&self.metrics.valid, format!(\"{:?}\", Split::Valid))?;\n\n        for (tag, metrics) in &self.metrics.test {\n            let split_name = if tag.is_empty() {\n                \"Test\".to_string()\n            } else {\n                format!(\"Test ({tag})\")\n            };\n\n            write_metrics_summary(metrics, split_name)?;\n        }\n\n        Ok(())\n    }\n}\n\n// TODO: rename to `ExperimentSummary`? Used in learner + evaluator.\n\n#[derive(Clone)]\n/// Learning summary config.\npub struct LearnerSummaryConfig {\n    pub(crate) directory: PathBuf,\n    pub(crate) metrics: Vec<String>,\n}\n\nimpl LearnerSummaryConfig {\n    /// Create the learning summary.\n    pub fn init(&self) -> Result<LearnerSummary, String> {\n        LearnerSummary::new(&self.directory, &self.metrics[..])\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    #[should_panic = \"Summary artifacts should exist\"]\n    fn test_artifact_dir_should_exist() {\n        let dir = \"/tmp/learner-summary-not-found\";\n        let _summary = LearnerSummary::new(dir, &[\"Loss\"]).expect(\"Summary artifacts should exist\");\n    }\n\n    #[test]\n    #[should_panic = \"Summary artifacts should exist\"]\n    fn test_train_valid_artifacts_should_exist() {\n        let dir = \"/tmp/test-learner-summary-empty\";\n        std::fs::create_dir_all(dir).ok();\n        let _summary = LearnerSummary::new(dir, &[\"Loss\"]).expect(\"Summary artifacts should exist\");\n    }\n\n    #[test]\n    fn test_summary_should_be_empty() {\n        let dir = Path::new(\"/tmp/test-learner-summary-empty-metrics\");\n        std::fs::create_dir_all(dir).unwrap();\n        std::fs::create_dir_all(dir.join(\"train/epoch-1\")).unwrap();\n        std::fs::create_dir_all(dir.join(\"valid/epoch-1\")).unwrap();\n        let summary = LearnerSummary::new(dir.to_str().unwrap(), &[\"Loss\"])\n            .expect(\"Summary artifacts should exist\");\n\n        assert_eq!(summary.epochs, 1);\n\n        assert_eq!(summary.metrics.train.len(), 0);\n        assert_eq!(summary.metrics.valid.len(), 0);\n\n        std::fs::remove_dir_all(dir).unwrap();\n    }\n\n    #[test]\n    fn test_summary_should_be_collected() {\n        let dir = Path::new(\"/tmp/test-learner-summary\");\n        let train_dir = dir.join(\"train/epoch-1\");\n        let valid_dir = dir.join(\"valid/epoch-1\");\n        std::fs::create_dir_all(dir).unwrap();\n        std::fs::create_dir_all(&train_dir).unwrap();\n        std::fs::create_dir_all(&valid_dir).unwrap();\n\n        std::fs::write(train_dir.join(\"Loss.log\"), \"1.0\\n2.0\").expect(\"Unable to write file\");\n        std::fs::write(valid_dir.join(\"Loss.log\"), \"1.0\").expect(\"Unable to write file\");\n\n        let summary = LearnerSummary::new(dir.to_str().unwrap(), &[\"Loss\"])\n            .expect(\"Summary artifacts should exist\");\n\n        assert_eq!(summary.epochs, 1);\n\n        // Only Loss metric\n        assert_eq!(summary.metrics.train.len(), 1);\n        assert_eq!(summary.metrics.valid.len(), 1);\n\n        // Aggregated train metric entries for 1 epoch\n        let train_metric = &summary.metrics.train[0];\n        assert_eq!(train_metric.name, \"Loss\");\n        assert_eq!(train_metric.entries.len(), 1);\n        let entry = &train_metric.entries[0];\n        assert_eq!(entry.step, 1); // epoch = 1\n        assert_eq!(entry.value, 1.5); // (1 + 2) / 2\n\n        // Aggregated valid metric entries for 1 epoch\n        let valid_metric = &summary.metrics.valid[0];\n        assert_eq!(valid_metric.name, \"Loss\");\n        assert_eq!(valid_metric.entries.len(), 1);\n        let entry = &valid_metric.entries[0];\n        assert_eq!(entry.step, 1); // epoch = 1\n        assert_eq!(entry.value, 1.0);\n\n        std::fs::remove_dir_all(dir).unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/mod.rs",
    "content": "mod paradigm;\nmod step;\nmod strategies;\n\npub use paradigm::*;\npub use step::*;\npub use strategies::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/paradigm.rs",
    "content": "use crate::checkpoint::{\n    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,\n    KeepLastNCheckpoints, MetricCheckpointingStrategy,\n};\nuse crate::components::{InferenceModelOutput, TrainingModelOutput};\nuse crate::learner::EarlyStoppingStrategy;\nuse crate::learner::base::Interrupter;\nuse crate::logger::{FileMetricLogger, MetricLogger};\nuse crate::metric::processor::{\n    AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,\n};\nuse crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};\nuse crate::metric::{Adaptor, LossMetric, Metric, Numeric};\nuse crate::multi::MultiDeviceLearningStrategy;\nuse crate::renderer::{MetricsRenderer, default_renderer};\nuse crate::single::SingleDeviceTrainingStrategy;\nuse crate::{\n    ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,\n    InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerEvent,\n    LearnerModelRecord, LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig,\n    LearningCheckpointer, LearningComponentsMarker, LearningComponentsTypes, LearningResult,\n    TrainStep, TrainingBackend, TrainingComponents, TrainingModelInput, TrainingStrategy,\n};\nuse crate::{Learner, SupervisedLearningStrategy};\nuse burn_core::data::dataloader::DataLoader;\nuse burn_core::module::{AutodiffModule, Module};\nuse burn_core::record::FileRecorder;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_optim::Optimizer;\nuse burn_optim::lr_scheduler::LrScheduler;\nuse std::collections::BTreeSet;\nuse std::path::{Path, PathBuf};\nuse std::sync::Arc;\n\n/// A reference to the training split [DataLoader](DataLoader).\npub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;\n/// A reference to the validation split [DataLoader](DataLoader).\npub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;\n/// The event processor type for supervised learning.\npub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<\n    LearnerEvent<TrainingModelOutput<LC>>,\n    LearnerEvent<InferenceModelOutput<LC>>,\n>;\n\n/// Structure to configure and launch supervised learning trainings.\npub struct SupervisedTraining<LC>\nwhere\n    LC: LearningComponentsTypes,\n{\n    // Not that complex. Extracting into another type would only make it more confusing.\n    #[allow(clippy::type_complexity)]\n    checkpointers: Option<(\n        AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,\n        AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,\n        AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,\n    )>,\n    num_epochs: usize,\n    checkpoint: Option<usize>,\n    directory: PathBuf,\n    grad_accumulation: Option<usize>,\n    renderer: Option<Box<dyn MetricsRenderer + 'static>>,\n    metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,\n    event_store: LogEventStore,\n    interrupter: Interrupter,\n    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    checkpointer_strategy: Box<dyn CheckpointingStrategy>,\n    early_stopping: Option<EarlyStoppingStrategyRef>,\n    training_strategy: Option<TrainingStrategy<LC>>,\n    dataloader_train: TrainLoader<LC>,\n    dataloader_valid: ValidLoader<LC>,\n    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order\n    summary_metrics: BTreeSet<String>,\n    summary: bool,\n}\n\nimpl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>\nwhere\n    B: AutodiffBackend,\n    LR: LrScheduler + 'static,\n    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,\n    M::InnerModule: InferenceStep,\n    O: Optimizer<M, B> + 'static,\n{\n    /// Creates a new runner for a supervised training.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory to save the checkpoints.\n    /// * `dataloader_train` - The dataloader for the training split.\n    /// * `dataloader_valid` - The dataloader for the validation split.\n    pub fn new(\n        directory: impl AsRef<Path>,\n        dataloader_train: Arc<dyn DataLoader<B, M::Input>>,\n        dataloader_valid: Arc<\n            dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,\n        >,\n    ) -> Self {\n        let directory = directory.as_ref().to_path_buf();\n        let experiment_log_file = directory.join(\"experiment.log\");\n        Self {\n            num_epochs: 1,\n            checkpoint: None,\n            checkpointers: None,\n            directory,\n            grad_accumulation: None,\n            metrics: MetricsTraining::default(),\n            event_store: LogEventStore::default(),\n            renderer: None,\n            interrupter: Interrupter::new(),\n            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(\n                experiment_log_file,\n            ))),\n            checkpointer_strategy: Box::new(\n                ComposedCheckpointingStrategy::builder()\n                    .add(KeepLastNCheckpoints::new(2))\n                    .add(MetricCheckpointingStrategy::new(\n                        &LossMetric::<B>::new(), // default to valid loss\n                        Aggregate::Mean,\n                        Direction::Lowest,\n                        Split::Valid,\n                    ))\n                    .build(),\n            ),\n            early_stopping: None,\n            training_strategy: None,\n            summary_metrics: BTreeSet::new(),\n            summary: false,\n            dataloader_train,\n            dataloader_valid,\n        }\n    }\n}\n\nimpl<LC: LearningComponentsTypes> SupervisedTraining<LC> {\n    /// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.\n    ///\n    /// # Arguments\n    ///\n    /// * `training_strategy` - The training strategy.\n    pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {\n        self.training_strategy = Some(training_strategy);\n        self\n    }\n\n    /// Replace the default metric loggers with the provided ones.\n    ///\n    /// # Arguments\n    ///\n    /// * `logger` - The training logger.\n    pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self\n    where\n        ML: MetricLogger + 'static,\n    {\n        self.event_store.register_logger(logger);\n        self\n    }\n\n    /// Update the checkpointing_strategy.\n    pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(\n        mut self,\n        strategy: CS,\n    ) -> Self {\n        self.checkpointer_strategy = Box::new(strategy);\n        self\n    }\n\n    /// Replace the default CLI renderer with a custom one.\n    ///\n    /// # Arguments\n    ///\n    /// * `renderer` - The custom renderer.\n    pub fn renderer<MR>(mut self, renderer: MR) -> Self\n    where\n        MR: MetricsRenderer + 'static,\n    {\n        self.renderer = Some(Box::new(renderer));\n        self\n    }\n\n    /// Register all metrics as numeric for the training and validation set.\n    pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register all metrics as text for the training and validation set.\n    pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {\n        metrics.register(self)\n    }\n\n    /// Register a training metric.\n    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self\n    where\n        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.metrics.register_train_metric(metric);\n        self\n    }\n\n    /// Register a validation metric.\n    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self\n    where\n        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.metrics.register_valid_metric(metric);\n        self\n    }\n\n    /// Enable gradients accumulation.\n    ///\n    /// # Notes\n    ///\n    /// When you enable gradients accumulation, the gradients object used by the optimizer will be\n    /// the sum of all gradients generated by each backward pass. It might be a good idea to\n    /// reduce the learning to compensate.\n    ///\n    /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`\n    /// amount.\n    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {\n        self.grad_accumulation = Some(accumulation);\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).\n    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self\n    where\n        Me: Metric + Numeric + 'static,\n        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_train_metric_numeric(metric);\n        self\n    }\n\n    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).\n    pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self\n    where\n        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,\n    {\n        self.summary_metrics.insert(metric.name().to_string());\n        self.metrics.register_valid_metric_numeric(metric);\n        self\n    }\n\n    /// The number of epochs the training should last.\n    pub fn num_epochs(mut self, num_epochs: usize) -> Self {\n        self.num_epochs = num_epochs;\n        self\n    }\n\n    /// The epoch from which the training must resume.\n    pub fn checkpoint(mut self, checkpoint: usize) -> Self {\n        self.checkpoint = Some(checkpoint);\n        self\n    }\n\n    /// Provides a handle that can be used to interrupt training.\n    pub fn interrupter(&self) -> Interrupter {\n        self.interrupter.clone()\n    }\n\n    /// Override the handle for stopping training with an externally provided handle\n    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {\n        self.interrupter = interrupter;\n        self\n    }\n\n    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the\n    /// conditions are meet.\n    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self\n    where\n        Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,\n    {\n        self.early_stopping = Some(Box::new(strategy));\n        self\n    }\n\n    /// By default, Rust logs are captured and written into\n    /// `experiment.log`. If disabled, standard Rust log handling\n    /// will apply.\n    pub fn with_application_logger(\n        mut self,\n        logger: Option<Box<dyn ApplicationLoggerInstaller>>,\n    ) -> Self {\n        self.tracing_logger = logger;\n        self\n    }\n\n    /// Register a checkpointer that will save the [optimizer](Optimizer), the\n    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.\n    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self\n    where\n        FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,\n        FR: FileRecorder<\n                <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,\n            > + 'static,\n    {\n        let checkpoint_dir = self.directory.join(\"checkpoint\");\n        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, \"model\");\n        let checkpointer_optimizer =\n            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, \"optim\");\n        let checkpointer_scheduler: FileCheckpointer<FR> =\n            FileCheckpointer::new(recorder, &checkpoint_dir, \"scheduler\");\n\n        self.checkpointers = Some((\n            AsyncCheckpointer::new(checkpointer_model),\n            AsyncCheckpointer::new(checkpointer_optimizer),\n            AsyncCheckpointer::new(checkpointer_scheduler),\n        ));\n\n        self\n    }\n\n    /// Enable the training summary report.\n    ///\n    /// The summary will be displayed after `.fit()`, when the renderer is dropped.\n    pub fn summary(mut self) -> Self {\n        self.summary = true;\n        self\n    }\n}\n\nimpl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC> {\n    /// Launch this training with the given [Learner](Learner).\n    pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {\n        if self.tracing_logger.is_some()\n            && let Err(e) = self.tracing_logger.as_ref().unwrap().install()\n        {\n            log::warn!(\"Failed to install the experiment logger: {e}\");\n        }\n        let renderer = self\n            .renderer\n            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));\n\n        if !self.event_store.has_loggers() {\n            self.event_store\n                .register_logger(FileMetricLogger::new(self.directory.clone()));\n        }\n\n        let event_store = Arc::new(EventStoreClient::new(self.event_store));\n        let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(\n            self.metrics,\n            renderer,\n            event_store.clone(),\n        ));\n\n        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {\n            LearningCheckpointer::new(\n                model.with_interrupter(self.interrupter.clone()),\n                optim.with_interrupter(self.interrupter.clone()),\n                scheduler.with_interrupter(self.interrupter.clone()),\n                self.checkpointer_strategy,\n            )\n        });\n\n        let summary = if self.summary {\n            Some(LearnerSummaryConfig {\n                directory: self.directory,\n                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),\n            })\n        } else {\n            None\n        };\n\n        let components = TrainingComponents {\n            checkpoint: self.checkpoint,\n            checkpointer,\n            interrupter: self.interrupter,\n            early_stopping: self.early_stopping,\n            event_processor,\n            event_store,\n            num_epochs: self.num_epochs,\n            grad_accumulation: self.grad_accumulation,\n            summary,\n        };\n\n        // Default to single device based on model\n        let training_strategy = self\n            .training_strategy\n            .unwrap_or(TrainingStrategy::SingleDevice(\n                learner.model.devices()[0].clone(),\n            ));\n\n        match training_strategy {\n            TrainingStrategy::SingleDevice(device) => {\n                let single_device: SingleDeviceTrainingStrategy<LC> =\n                    SingleDeviceTrainingStrategy::new(device);\n                single_device.train(\n                    learner,\n                    self.dataloader_train,\n                    self.dataloader_valid,\n                    components,\n                )\n            }\n            TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(\n                learner,\n                self.dataloader_train,\n                self.dataloader_valid,\n                components,\n            ),\n            TrainingStrategy::MultiDevice(devices, multi_device_optim) => {\n                let strategy: Box<dyn SupervisedLearningStrategy<LC>> = match devices.len() == 1 {\n                    true => Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone())),\n                    false => Box::new(MultiDeviceLearningStrategy::new(\n                        devices,\n                        multi_device_optim,\n                    )),\n                };\n                strategy.train(\n                    learner,\n                    self.dataloader_train,\n                    self.dataloader_valid,\n                    components,\n                )\n            }\n            #[cfg(feature = \"ddp\")]\n            TrainingStrategy::DistributedDataParallel { devices, config } => {\n                use crate::ddp::DdpTrainingStrategy;\n\n                let ddp = DdpTrainingStrategy::new(devices.clone(), config.clone());\n                ddp.train(\n                    learner,\n                    self.dataloader_train,\n                    self.dataloader_valid,\n                    components,\n                )\n            }\n        }\n    }\n}\n\n/// Trait to fake variadic generics.\npub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;\n}\n\n/// Trait to fake variadic generics.\npub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {\n    /// Register the metrics.\n    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;\n}\n\nmacro_rules! gen_tuple {\n    ($($M:ident),*) => {\n        impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)\n        where\n            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: SupervisedTraining<LC>,\n            ) -> SupervisedTraining<LC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_train($M.clone());)*\n                $(let builder = builder.metric_valid($M);)*\n                builder\n            }\n        }\n\n        impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)\n        where\n            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*\n            $($M: Metric + Numeric + 'static,)*\n        {\n            #[allow(non_snake_case)]\n            fn register(\n                self,\n                builder: SupervisedTraining<LC>,\n            ) -> SupervisedTraining<LC> {\n                let ($($M,)*) = self;\n                $(let builder = builder.metric_train_numeric($M.clone());)*\n                $(let builder = builder.metric_valid_numeric($M);)*\n                builder\n            }\n        }\n    };\n}\n\ngen_tuple!(M1);\ngen_tuple!(M1, M2);\ngen_tuple!(M1, M2, M3);\ngen_tuple!(M1, M2, M3, M4);\ngen_tuple!(M1, M2, M3, M4, M5);\ngen_tuple!(M1, M2, M3, M4, M5, M6);\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/step/mod.rs",
    "content": "/// The trainer module.\npub mod train;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/step/train.rs",
    "content": "use crate::{LearningComponentsTypes, TrainingModel};\nuse crate::{TrainOutput, TrainStep, TrainingBackend, TrainingModelInput, TrainingModelOutput};\nuse burn_core::data::dataloader::DataLoaderIterator;\nuse burn_core::data::dataloader::Progress;\nuse burn_core::module::Module;\nuse burn_core::prelude::DeviceOps;\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::DeviceId;\nuse std::sync::mpsc::{Receiver, Sender};\nuse std::thread::spawn;\n\n/// Multi devices train step.\npub struct MultiDevicesTrainStep<LC: LearningComponentsTypes> {\n    workers: Vec<Worker<LC>>,\n    receiver: Receiver<MultiTrainOutput<TrainingModelOutput<LC>>>,\n}\n\nstruct Message<M, TI> {\n    item: TI,\n    model: M,\n}\n\nstruct Worker<LC: LearningComponentsTypes> {\n    // Not that complex. Extracting into another type would only make it more confusing.\n    #[allow(clippy::type_complexity)]\n    sender_input: Sender<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,\n    device: Device<TrainingBackend<LC>>,\n}\n\nimpl<LC: LearningComponentsTypes> Worker<LC> {\n    fn register(&self, item: TrainingModelInput<LC>, model: &TrainingModel<LC>) {\n        let message = Message {\n            item,\n            model: model.clone(),\n        };\n        self.sender_input.send(message).unwrap();\n    }\n\n    // Not that complex. Extracting into another type would only make it more confusing.\n    #[allow(clippy::type_complexity)]\n    fn start(\n        &self,\n        sender_output: Sender<MultiTrainOutput<TrainingModelOutput<LC>>>,\n        receiver_input: Receiver<Message<TrainingModel<LC>, TrainingModelInput<LC>>>,\n    ) {\n        let device = self.device.clone();\n\n        spawn(move || {\n            loop {\n                match receiver_input.recv() {\n                    Ok(item) => {\n                        let model = item.model.fork(&device);\n                        let output = model.step(item.item);\n                        let item = MultiTrainOutput {\n                            output,\n                            device: device.to_id(),\n                        };\n\n                        sender_output.send(item).unwrap();\n                    }\n                    Err(_err) => {\n                        log::info!(\"Closing thread on device {device:?}\");\n                        break;\n                    }\n                }\n            }\n        });\n    }\n}\n\n/// Multiple output items.\npub struct MultiTrainOutput<TO> {\n    /// The training output.\n    pub output: TrainOutput<TO>,\n    /// The device on which the computing happened.\n    pub device: DeviceId,\n}\n\nimpl<LC: LearningComponentsTypes> MultiDevicesTrainStep<LC> {\n    /// Create a new multi devices train step.\n    ///\n    /// # Arguments\n    ///\n    /// * `devices` - Devices.\n    ///\n    /// # Returns\n    ///\n    /// MultiDevicesTrainStep instance.\n    pub fn new(devices: &[Device<TrainingBackend<LC>>]) -> Self {\n        let (sender_output, receiver_output) = std::sync::mpsc::channel();\n        let workers = devices\n            .iter()\n            .map(|device| {\n                let (sender_input, receiver_input) = std::sync::mpsc::channel();\n                let worker = Worker {\n                    sender_input,\n                    device: device.clone(),\n                };\n\n                worker.start(sender_output.clone(), receiver_input);\n                worker\n            })\n            .collect();\n\n        Self {\n            workers,\n            receiver: receiver_output,\n        }\n    }\n\n    /// Collect outputs from workers for one step.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - Model.\n    /// * `dataloaders` - The data loader for each worker.\n    ///\n    /// # Returns\n    ///\n    /// Outputs.\n    pub fn step<'a>(\n        &self,\n        dataloaders: &mut [Box<dyn DataLoaderIterator<TrainingModelInput<LC>> + 'a>],\n        model: &TrainingModel<LC>,\n    ) -> (Vec<MultiTrainOutput<TrainingModelOutput<LC>>>, Progress) {\n        let mut num_send = 0;\n\n        let mut items_total = 0;\n        let mut items_processed = 0;\n\n        for (i, worker) in self.workers.iter().enumerate() {\n            let dataloader = &mut dataloaders[i];\n            if let Some(item) = dataloader.next() {\n                worker.register(item, model);\n                num_send += 1;\n                let progress = dataloader.progress();\n                items_total += progress.items_total;\n                items_processed += progress.items_processed;\n            }\n        }\n\n        let mut outputs = Vec::with_capacity(num_send);\n\n        for _ in 0..num_send {\n            let output = self.receiver.recv().unwrap();\n            outputs.push(output);\n        }\n\n        (outputs, Progress::new(items_processed, items_total))\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/base.rs",
    "content": "use std::sync::Arc;\n\n#[cfg(feature = \"ddp\")]\nuse burn_collective::CollectiveConfig;\nuse burn_core::{module::AutodiffModule, prelude::Backend};\n\nuse crate::{\n    EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,\n    LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,\n    TrainingModel, ValidLoader,\n    components::LearningComponentsTypes,\n    metric::{\n        processor::{EventProcessorTraining, LearnerEvent},\n        store::EventStoreClient,\n    },\n};\n\ntype LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;\n\n/// A reference to an implementation of SupervisedLearningStrategy.\npub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;\n\n#[derive(Clone, Copy, Debug)]\n/// Determine how the optimization is performed when training with multiple devices.\npub enum MultiDeviceOptim {\n    /// The optimization is done on an elected device.\n    OptimMainDevice,\n    /// The optimization is sharded across all devices.\n    OptimSharded,\n}\n\n/// How should the learner run the learning for the model\n#[derive(Clone)]\npub enum TrainingStrategy<LC: LearningComponentsTypes> {\n    /// Training on one device\n    SingleDevice(LearnerDevice<LC>),\n    /// Performs data-parallel distributed training where the optimization is\n    /// done on an elected master device.\n    MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),\n    /// Training using a custom learning strategy\n    Custom(CustomLearningStrategy<LC>),\n    /// Training with input distributed across devices, each device has its own copy of the model.\n    /// Collective ops are used to sync the gradients after each pass.\n    #[cfg(feature = \"ddp\")]\n    DistributedDataParallel {\n        /// Devices on this node for the DDP\n        devices: Vec<LearnerDevice<LC>>,\n\n        /// The configuration for collective operations\n        /// num_devices is ignored\n        config: CollectiveConfig,\n    },\n}\n\n/// Constructor for a distributed data parallel (DDP) learning strategy\n#[cfg(feature = \"ddp\")]\npub fn ddp<LC: LearningComponentsTypes>(\n    devices: Vec<LearnerDevice<LC>>,\n    config: CollectiveConfig,\n) -> TrainingStrategy<LC> {\n    TrainingStrategy::DistributedDataParallel { devices, config }\n}\n\nimpl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {\n    fn default() -> Self {\n        Self::SingleDevice(Default::default())\n    }\n}\n\n/// Struct to minimise parameters passed to [SupervisedLearningStrategy::train].\n/// These components are used during training.\npub struct TrainingComponents<LC: LearningComponentsTypes> {\n    /// The total number of epochs\n    pub num_epochs: usize,\n    /// The epoch number from which to continue the training.\n    pub checkpoint: Option<usize>,\n    /// A checkpointer used to load and save learner checkpoints.\n    pub checkpointer: Option<LearningCheckpointer<LC>>,\n    /// Enables gradients accumulation.\n    pub grad_accumulation: Option<usize>,\n    /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.\n    pub interrupter: Interrupter,\n    /// Cloneable reference to an early stopping strategy.\n    pub early_stopping: Option<EarlyStoppingStrategyRef>,\n    /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation.\n    pub event_processor: SupervisedTrainingEventProcessor<LC>,\n    /// A reference to an [EventStoreClient](EventStoreClient).\n    pub event_store: Arc<EventStoreClient>,\n    /// Config for creating a summary of the learning\n    pub summary: Option<LearnerSummaryConfig>,\n}\n\n/// Provides the `fit` function for any learning strategy\npub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {\n    /// Train the learner's model with this strategy.\n    fn train(\n        &self,\n        mut learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        mut training_components: TrainingComponents<LC>,\n    ) -> LearningResult<InferenceModel<LC>> {\n        let starting_epoch = match training_components.checkpoint {\n            Some(checkpoint) => {\n                if let Some(checkpointer) = &mut training_components.checkpointer {\n                    learner =\n                        checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);\n                }\n                checkpoint + 1\n            }\n            None => 1,\n        };\n\n        let summary_config = training_components.summary.clone();\n\n        // Event processor start training\n        training_components\n            .event_processor\n            .process_train(LearnerEvent::Start);\n        // Training loop\n        let (model, mut event_processor) = self.fit(\n            training_components,\n            learner,\n            dataloader_train,\n            dataloader_valid,\n            starting_epoch,\n        );\n\n        let summary = summary_config.and_then(|summary| {\n            summary\n                .init()\n                .map(|summary| summary.with_model(model.to_string()))\n                .ok()\n        });\n\n        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.\n        event_processor.process_train(LearnerEvent::End(summary));\n\n        let model = model.valid();\n        let renderer = event_processor.renderer();\n\n        LearningResult::<InferenceModel<LC>> { model, renderer }\n    }\n\n    /// Training loop for this strategy\n    fn fit(\n        &self,\n        training_components: TrainingComponents<LC>,\n        learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        starting_epoch: usize,\n    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/ddp/README.md",
    "content": "## DDP\nDistributed Data Parallel\n\nThe DDP is a learning strategy that trains a replica of the model on each device.\n\nThe DDP launches threads for each local device. Each thread on each node will run the model.\nAfter the forward and backward passes, the gradients are synced between all peers on all nodes \nwith an `all-reduce` operation.\n\nWhile the DDP launches threads for each local device, it is the user's responsibility to launch the \nDDP on each node, and assure the collective configuration matches.\n\n## Main device vs secondary devices \n\nThe main device is responsible for validation, as well as event processing, which is used in the UI.\n\nThe first device is chosen as the main device.\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/ddp/epoch.rs",
    "content": "use burn_collective::{PeerId, ReduceOperation};\nuse burn_core::data::dataloader::Progress;\nuse burn_core::module::AutodiffModule;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_optim::GradientsAccumulator;\nuse burn_optim::GradientsParams;\nuse std::marker::PhantomData;\nuse std::sync::mpsc::{Receiver, SyncSender};\nuse std::sync::{Arc, Mutex};\n\nuse crate::SupervisedTrainingEventProcessor;\nuse crate::learner::base::Interrupter;\nuse crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};\nuse crate::{\n    InferenceStep, Learner, LearningComponentsTypes, TrainLoader, TrainingBackend, ValidLoader,\n};\n\n/// A validation epoch.\n#[derive(new)]\npub struct DdpValidEpoch<LC: LearningComponentsTypes> {\n    dataloader: ValidLoader<LC>,\n}\n\n/// A training epoch.\n#[derive(new)]\npub struct DdpTrainEpoch<LC: LearningComponentsTypes> {\n    dataloader: TrainLoader<LC>,\n    grad_accumulation: Option<usize>,\n}\n\nimpl<LC: LearningComponentsTypes> DdpValidEpoch<LC> {\n    /// Runs the validation epoch.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - The model to validate.\n    /// * `processor` - The event processor to use.\n    pub fn run(\n        &self,\n        model: &<LC as LearningComponentsTypes>::TrainingModel,\n        global_progress: &Progress,\n        processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\"Executing validation step for epoch {}\", epoch);\n        let model = model.valid();\n\n        let mut iterator = self.dataloader.iter();\n        let mut iteration = 0;\n\n        while let Some(item) = iterator.next() {\n            let progress = iterator.progress();\n            iteration += 1;\n\n            let item = model.step(item);\n            let item = TrainingItem::new(\n                item,\n                progress,\n                global_progress.clone(),\n                Some(iteration),\n                None,\n            );\n\n            processor.process_valid(LearnerEvent::ProcessedItem(item));\n\n            if interrupter.should_stop() {\n                log::info!(\"Training interrupted.\");\n                break;\n            }\n        }\n        processor.process_valid(LearnerEvent::EndEpoch(epoch));\n    }\n}\n\nimpl<LC: LearningComponentsTypes> DdpTrainEpoch<LC> {\n    /// Runs the training epoch.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - The model to train.\n    /// * `optim` - The optimizer to use.\n    /// * `scheduler` - The learning rate scheduler to use.\n    /// * `processor` - The event processor to use.\n    ///\n    /// # Returns\n    ///\n    /// The trained model and the optimizer.\n    #[allow(clippy::too_many_arguments)]\n    pub fn run(\n        &self,\n        learner: &mut Learner<LC>,\n        global_progress: &Progress,\n        processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,\n        interrupter: &Interrupter,\n        peer_id: PeerId,\n        peer_count: usize,\n        is_main: bool,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\"Executing training step for epoch {}\", epoch,);\n\n        let mut iterator = self.dataloader.iter();\n        let mut iteration = 0;\n        let mut accumulator = GradientsAccumulator::new();\n        let mut accumulation_current = 0;\n\n        let grads_syncer = GradsSyncer::<\n            TrainingBackend<LC>,\n            <LC as LearningComponentsTypes>::TrainingModel,\n        >::new(false, peer_id);\n\n        while let Some(item) = iterator.next() {\n            for _ in 0..peer_count {\n                iteration += 1;\n                learner.lr_step();\n            }\n            log::info!(\"Iteration {iteration}\");\n\n            let mut progress = iterator.progress();\n            progress.items_processed *= peer_count;\n            progress.items_total *= peer_count;\n\n            let item = learner.train_step(item);\n\n            match self.grad_accumulation {\n                Some(accumulation) => {\n                    accumulator.accumulate(&learner.model(), item.grads);\n                    accumulation_current += 1;\n\n                    if accumulation <= accumulation_current {\n                        let grads = accumulator.grads();\n\n                        // With double buffering, these are the previous iteration's gradients\n                        let grads = grads_syncer.sync(grads);\n                        if let Some(grads) = grads {\n                            learner.optimizer_step(grads);\n                        }\n\n                        accumulation_current = 0;\n                    }\n                }\n                None => {\n                    // With double buffering, these are the previous iteration's gradients\n                    let grads = grads_syncer.sync(item.grads);\n\n                    if let Some(grads) = grads {\n                        learner.optimizer_step(grads);\n                    }\n                }\n            }\n\n            let item = TrainingItem::new(\n                item.item,\n                progress,\n                global_progress.clone(),\n                Some(iteration),\n                Some(learner.lr_current()),\n            );\n\n            {\n                let mut processor = processor.lock().unwrap();\n                processor.process_train(LearnerEvent::ProcessedItem(item));\n            }\n\n            if interrupter.should_stop() {\n                log::info!(\"Training interrupted.\");\n                break;\n            }\n        }\n\n        if is_main {\n            let mut processor = processor.lock().unwrap();\n            processor.process_train(LearnerEvent::EndEpoch(epoch));\n        }\n    }\n}\n\n/// Worker that is responsible for syncing gradients for the DDP worker. With double buffering,\n/// this allows for more optimization.\nstruct GradsSyncer<B: AutodiffBackend, M: AutodiffModule<B> + 'static> {\n    msg_send: SyncSender<GradientsParams>,\n    // Optional because with double buffering, the first iteration yields no gradients.\n    result_recv: Receiver<Option<GradientsParams>>,\n\n    _p: PhantomData<(B, M)>,\n}\n\nimpl<B: AutodiffBackend, M: AutodiffModule<B> + 'static> GradsSyncer<B, M> {\n    fn new(double_buffering: bool, peer_id: PeerId) -> Self {\n        let (msg_send, msg_recv) = std::sync::mpsc::sync_channel::<GradientsParams>(1);\n        let (result_send, result_recv) =\n            std::sync::mpsc::sync_channel::<Option<GradientsParams>>(1);\n        std::thread::spawn(move || {\n            Self::run_worker(double_buffering, peer_id, result_send, msg_recv)\n        });\n        Self {\n            msg_send,\n            result_recv,\n            _p: PhantomData,\n        }\n    }\n\n    fn sync(&self, grads: GradientsParams) -> Option<GradientsParams> {\n        self.msg_send.send(grads).unwrap();\n        self.result_recv.recv().unwrap()\n    }\n\n    fn run_worker(\n        double_buffering: bool,\n        peer_id: PeerId,\n        send: SyncSender<Option<GradientsParams>>,\n        recv: Receiver<GradientsParams>,\n    ) {\n        let mut grads_buffer = None;\n\n        while let Ok(new_grads) = recv.recv() {\n            // Sync grads with collective\n            let new_grads = new_grads\n                .all_reduce::<B::InnerBackend>(peer_id, ReduceOperation::Mean)\n                .expect(\"DDP worker could not sync gradients!\");\n\n            if double_buffering {\n                let old_grads = grads_buffer.take();\n                grads_buffer = Some(new_grads);\n\n                send.send(old_grads).unwrap();\n            } else {\n                send.send(Some(new_grads)).unwrap();\n            }\n        }\n        // GradsSyncer dropped, channel closed, this thread can end\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/ddp/mod.rs",
    "content": "mod epoch;\nmod strategy;\nmod worker;\n\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/ddp/strategy.rs",
    "content": "use core::panic;\nuse std::sync::{Arc, Mutex};\n\nuse burn_collective::CollectiveConfig;\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::DeviceOps;\n\nuse crate::ddp::worker::DdpWorker;\nuse crate::metric::store::EventStoreClient;\nuse crate::{\n    EarlyStoppingStrategyRef, Interrupter, Learner, LearningComponentsTypes,\n    SupervisedLearningStrategy, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend,\n    TrainingComponents, TrainingModel, ValidLoader,\n};\nuse burn_core::data::dataloader::split::split_dataloader;\n\n#[derive(Clone)]\npub(crate) struct WorkerComponents {\n    /// The total number of epochs\n    pub num_epochs: usize,\n    /// Enables gradients accumulation.\n    pub grad_accumulation: Option<usize>,\n    /// An [Interupter](Interrupter) that allows aborting the training/evaluation process early.\n    pub interrupter: Interrupter,\n    /// Cloneable reference to an early stopping strategy.\n    pub early_stopping: Option<EarlyStoppingStrategyRef>,\n    /// A reference to an [EventStoreClient](EventStoreClient).\n    pub event_store: Arc<EventStoreClient>,\n}\n\npub struct DdpTrainingStrategy<LC: LearningComponentsTypes> {\n    devices: Vec<Device<TrainingBackend<LC>>>,\n    config: CollectiveConfig,\n}\nimpl<LC: LearningComponentsTypes> DdpTrainingStrategy<LC> {\n    pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, config: CollectiveConfig) -> Self {\n        let config = config.with_num_devices(devices.len());\n        Self { devices, config }\n    }\n}\n\nimpl<LC: LearningComponentsTypes + Send + 'static> SupervisedLearningStrategy<LC>\n    for DdpTrainingStrategy<LC>\n{\n    fn fit(\n        &self,\n        training_components: TrainingComponents<LC>,\n        learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        starting_epoch: usize,\n    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {\n        // The reference model is always on the first device provided.\n        let main_device = self.devices.first().unwrap();\n        // One worker per device, so we use a fixed device strategy\n        // for each (worker) data loader. This matches the expected device on the worker, so we\n        // don't have to move the data between devices.\n        let mut dataloaders_train = split_dataloader(dataloader_train, &self.devices);\n        let dataloader_valid = dataloader_valid.to_device(main_device.inner());\n\n        let main_device = self.devices[0].clone();\n        let peer_count = self.devices.len();\n        let event_processor = Arc::new(Mutex::new(training_components.event_processor));\n\n        let interrupter = training_components.interrupter;\n        let worker_components = WorkerComponents {\n            num_epochs: training_components.num_epochs,\n            grad_accumulation: training_components.grad_accumulation,\n            interrupter: interrupter.clone(),\n            early_stopping: training_components.early_stopping,\n            event_store: training_components.event_store,\n        };\n\n        // Start worker for main device\n        // First training dataloader corresponds to main device\n        let main_handle = DdpWorker::<LC>::start(\n            0.into(),\n            main_device,\n            learner.clone(),\n            event_processor.clone(),\n            worker_components.clone(),\n            training_components.checkpointer,\n            dataloaders_train.remove(0),\n            Some(dataloader_valid),\n            self.config.clone(),\n            starting_epoch,\n            peer_count,\n            true,\n        );\n\n        // Spawn other workers for the other devices, starting with peer id 1\n        let mut peer_id = 1;\n        let mut secondary_workers = vec![];\n        for device in &self.devices[1..] {\n            let handle = DdpWorker::<LC>::start(\n                peer_id.into(),\n                device.clone(),\n                learner.clone(),\n                event_processor.clone(),\n                worker_components.clone(),\n                None,\n                dataloaders_train.remove(0),\n                None,\n                self.config.clone(),\n                starting_epoch,\n                peer_count,\n                false,\n            );\n\n            peer_id += 1;\n\n            secondary_workers.push(handle);\n        }\n\n        // Wait for all devices to finish\n        for worker in secondary_workers {\n            worker\n                .join()\n                .expect(\"Distributed data parallel worker failed\");\n        }\n        // Main worker had the event processor\n        let model = main_handle\n            .join()\n            .expect(\"Distributed data parallel main worker failed\");\n\n        if interrupter.should_stop() {\n            let reason = interrupter\n                .get_message()\n                .unwrap_or(String::from(\"Reason unknown\"));\n            log::info!(\"Training interrupted: {reason}\");\n        }\n        let Ok(event_processor) = Arc::try_unwrap(event_processor) else {\n            panic!(\"Event processor still held!\");\n        };\n        let Ok(event_processor) = event_processor.into_inner() else {\n            panic!(\"Event processor lock poisoned\");\n        };\n        (model, event_processor)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/ddp/worker.rs",
    "content": "use crate::ddp::epoch::{DdpTrainEpoch, DdpValidEpoch};\nuse crate::ddp::strategy::WorkerComponents;\nuse crate::single::TrainingLoop;\nuse crate::{\n    Learner, LearningCheckpointer, LearningComponentsTypes, SupervisedTrainingEventProcessor,\n    TrainLoader, TrainingBackend, ValidLoader,\n};\nuse burn_collective::{self, CollectiveConfig, PeerId};\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse std::sync::{Arc, Mutex};\nuse std::thread::JoinHandle;\n\n/// A worker runs the model, syncing gradients using collective operations.\n/// Event processing and validation is optional too.\npub(crate) struct DdpWorker<LC>\nwhere\n    LC: LearningComponentsTypes + Send + 'static,\n{\n    peer_id: PeerId,\n    device: Device<TrainingBackend<LC>>,\n    learner: Learner<LC>,\n    event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,\n    components: WorkerComponents,\n    checkpointer: Option<LearningCheckpointer<LC>>,\n    dataloader_train: TrainLoader<LC>,\n    dataloader_valid: Option<ValidLoader<LC>>,\n    collective_config: CollectiveConfig,\n    starting_epoch: usize,\n    peer_count: usize,\n    is_main: bool,\n}\n\nimpl<LC> DdpWorker<LC>\nwhere\n    LC: LearningComponentsTypes + Send + 'static,\n{\n    /// Starts a worker that runs the model in a data distributed parallel\n    #[allow(clippy::too_many_arguments)]\n    pub fn start(\n        peer_id: PeerId,\n        device: Device<TrainingBackend<LC>>,\n        learner: Learner<LC>,\n        event_processor: Arc<Mutex<SupervisedTrainingEventProcessor<LC>>>,\n        components: WorkerComponents,\n        checkpointer: Option<LearningCheckpointer<LC>>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: Option<ValidLoader<LC>>,\n        collective_config: CollectiveConfig,\n        starting_epoch: usize,\n        peer_count: usize,\n        is_main: bool,\n    ) -> JoinHandle<<LC as LearningComponentsTypes>::TrainingModel> {\n        let worker = Self {\n            peer_id,\n            device,\n            learner,\n            event_processor,\n            components,\n            checkpointer,\n            dataloader_train,\n            dataloader_valid,\n            collective_config,\n            starting_epoch,\n            peer_count,\n            is_main,\n        };\n\n        std::thread::spawn(|| worker.fit())\n    }\n\n    /// Fits the model,\n    pub fn fit(mut self) -> <LC as LearningComponentsTypes>::TrainingModel {\n        burn_collective::register::<<TrainingBackend<LC> as AutodiffBackend>::InnerBackend>(\n            self.peer_id,\n            self.device.clone(),\n            self.collective_config.clone(),\n        )\n        .expect(\"Couldn't register for collective operations!\");\n\n        let num_epochs = self.components.num_epochs;\n        let interrupter = self.components.interrupter;\n\n        // Changed the train epoch to keep the dataloaders\n        let epoch_train = DdpTrainEpoch::<LC>::new(\n            self.dataloader_train.clone(),\n            self.components.grad_accumulation,\n        );\n        let epoch_valid = self\n            .dataloader_valid\n            .map(|dataloader| DdpValidEpoch::<LC>::new(dataloader));\n        self.learner.fork(&self.device);\n\n        for training_progress in TrainingLoop::new(self.starting_epoch, num_epochs) {\n            let epoch = training_progress.items_processed;\n\n            epoch_train.run(\n                &mut self.learner,\n                &training_progress,\n                self.event_processor.clone(),\n                &interrupter,\n                self.peer_id,\n                self.peer_count,\n                self.is_main,\n            );\n\n            if interrupter.should_stop() {\n                break;\n            }\n\n            // Validation\n            if let Some(runner) = &epoch_valid {\n                let mut event_processor = self.event_processor.lock().unwrap();\n                runner.run(\n                    &self.learner.model(),\n                    &training_progress,\n                    &mut event_processor,\n                    &interrupter,\n                );\n            }\n\n            if let Some(checkpointer) = &mut self.checkpointer {\n                checkpointer.checkpoint(&self.learner, epoch, &self.components.event_store);\n            }\n\n            if let Some(early_stopping) = &mut self.components.early_stopping\n                && early_stopping.should_stop(epoch, &self.components.event_store)\n            {\n                break;\n            }\n        }\n\n        self.learner.model()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/mod.rs",
    "content": "mod base;\n\n#[cfg(feature = \"ddp\")]\npub(crate) mod ddp;\npub(crate) mod multi;\npub(crate) mod single;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/multi/epoch.rs",
    "content": "use crate::learner::base::Interrupter;\nuse crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};\nuse crate::train::MultiDevicesTrainStep;\nuse crate::{\n    Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedTrainingEventProcessor,\n    TrainLoader, TrainingBackend,\n};\nuse burn_core::data::dataloader::Progress;\nuse burn_core::prelude::DeviceOps;\nuse burn_core::tensor::Device;\nuse burn_core::tensor::backend::DeviceId;\nuse burn_optim::GradientsAccumulator;\nuse burn_optim::MultiGradientsParams;\nuse std::collections::HashMap;\n\n/// A training epoch.\n#[derive(new)]\npub struct MultiDeviceTrainEpoch<LC: LearningComponentsTypes> {\n    dataloaders: Vec<TrainLoader<LC>>,\n    grad_accumulation: Option<usize>,\n}\n\nimpl<LC: LearningComponentsTypes> MultiDeviceTrainEpoch<LC> {\n    /// Runs the training epoch on multiple devices.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - The model to train.\n    /// * `optim` - The optimizer to use.\n    /// * `lr_scheduler` - The learning rate scheduler to use.\n    /// * `processor` - The event processor to use.\n    /// * `devices` - The devices to use.\n    ///\n    /// # Returns\n    ///\n    /// The trained model and the optimizer.\n    #[allow(clippy::too_many_arguments)]\n    pub fn run(\n        &self,\n        learner: &mut Learner<LC>,\n        global_progress: &Progress,\n        event_processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n        devices: Vec<Device<TrainingBackend<LC>>>,\n        strategy: MultiDeviceOptim,\n    ) {\n        match strategy {\n            MultiDeviceOptim::OptimMainDevice => self.run_optim_main(\n                learner,\n                global_progress,\n                event_processor,\n                interrupter,\n                devices,\n            ),\n            MultiDeviceOptim::OptimSharded => self.run_optim_distr(\n                learner,\n                global_progress,\n                event_processor,\n                interrupter,\n                devices,\n            ),\n        }\n    }\n\n    fn run_optim_main(\n        &self,\n        learner: &mut Learner<LC>,\n        global_progress: &Progress,\n        event_processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n        devices: Vec<Device<TrainingBackend<LC>>>,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\n            \"Executing training step for epoch {} on devices {:?}\",\n            epoch,\n            devices\n        );\n\n        let mut iterators = self\n            .dataloaders\n            .iter()\n            .map(|d| d.iter())\n            .collect::<Vec<_>>();\n        let mut iteration = 0;\n        let mut accumulator = GradientsAccumulator::new();\n        let mut accumulation_current = 0;\n\n        let accumulation = self.grad_accumulation.unwrap_or(1);\n        let step = MultiDevicesTrainStep::<LC>::new(&devices);\n\n        // The main device is always the first in the list.\n        let device_main = devices.first().expect(\"A minimum of one device.\").clone();\n\n        loop {\n            let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());\n            if items.is_empty() {\n                break;\n            }\n\n            learner.lr_step();\n\n            let mut progress_items = Vec::with_capacity(items.len());\n            for item in items.into_iter() {\n                let grads = item.output.grads.to_device(&device_main, &learner.model());\n                accumulator.accumulate(&learner.model(), grads);\n                progress_items.push(item.output.item);\n            }\n\n            accumulation_current += 1;\n\n            if accumulation <= accumulation_current {\n                let grads = accumulator.grads();\n                learner.optimizer_step(grads);\n                accumulation_current = 0;\n            }\n\n            for item in progress_items {\n                iteration += 1;\n                let item = TrainingItem::new(\n                    item,\n                    progress.clone(),\n                    global_progress.clone(),\n                    Some(iteration),\n                    Some(learner.lr_current()),\n                );\n\n                event_processor.process_train(LearnerEvent::ProcessedItem(item));\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n\n        event_processor.process_train(LearnerEvent::EndEpoch(epoch));\n    }\n\n    fn run_optim_distr(\n        &self,\n        learner: &mut Learner<LC>,\n        global_progress: &Progress,\n        event_processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n        devices: Vec<Device<TrainingBackend<LC>>>,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\n            \"Executing training step for epoch {} on devices {:?}\",\n            epoch,\n            devices\n        );\n\n        let mut iterators = self\n            .dataloaders\n            .iter()\n            .map(|d| d.iter())\n            .collect::<Vec<_>>();\n        let mut iteration = 0;\n        let mut accumulators = HashMap::<\n            DeviceId,\n            GradientsAccumulator<<LC as LearningComponentsTypes>::TrainingModel>,\n        >::new();\n        for device in devices.iter() {\n            accumulators.insert(device.to_id(), GradientsAccumulator::new());\n        }\n        let mut accumulation_current = 0;\n\n        let accumulation = self.grad_accumulation.unwrap_or(1);\n        let step = MultiDevicesTrainStep::<LC>::new(&devices);\n\n        loop {\n            let (items, progress) = step.step(iterators.as_mut_slice(), &learner.model());\n            if items.is_empty() {\n                break;\n            }\n\n            learner.lr_step();\n\n            let mut progress_items = Vec::with_capacity(items.len());\n            for item in items.into_iter() {\n                let accumulator = accumulators.get_mut(&item.device).unwrap();\n                accumulator.accumulate(&learner.model(), item.output.grads);\n                progress_items.push(item.output.item);\n            }\n\n            accumulation_current += 1;\n\n            if accumulation <= accumulation_current {\n                let mut grads = MultiGradientsParams::default();\n                for (device_id, accumulator) in accumulators.iter_mut() {\n                    let grad = accumulator.grads();\n                    grads.grads.push((grad, *device_id));\n                }\n                learner.optimizer_step_multi(grads);\n                accumulation_current = 0;\n            }\n\n            for item in progress_items {\n                iteration += 1;\n                let item = TrainingItem::new(\n                    item,\n                    progress.clone(),\n                    global_progress.clone(),\n                    Some(iteration),\n                    Some(learner.lr_current()),\n                );\n\n                event_processor.process_train(LearnerEvent::ProcessedItem(item));\n            }\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n\n        event_processor.process_train(LearnerEvent::EndEpoch(epoch));\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/multi/mod.rs",
    "content": "pub(crate) mod epoch;\nmod strategy;\n\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/multi/strategy.rs",
    "content": "use crate::{\n    Learner, LearningComponentsTypes, MultiDeviceOptim, SupervisedLearningStrategy,\n    SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend, TrainingComponents,\n    TrainingModel, ValidLoader,\n    multi::epoch::MultiDeviceTrainEpoch,\n    single::{TrainingLoop, epoch::SingleDeviceValidEpoch},\n};\nuse burn_core::{\n    data::dataloader::split::split_dataloader,\n    tensor::{Device, backend::DeviceOps},\n};\n\npub struct MultiDeviceLearningStrategy<LC: LearningComponentsTypes> {\n    devices: Vec<Device<TrainingBackend<LC>>>,\n    optim: MultiDeviceOptim,\n}\nimpl<LC: LearningComponentsTypes> MultiDeviceLearningStrategy<LC> {\n    pub fn new(devices: Vec<Device<TrainingBackend<LC>>>, optim: MultiDeviceOptim) -> Self {\n        Self { devices, optim }\n    }\n}\n\nimpl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>\n    for MultiDeviceLearningStrategy<LC>\n{\n    fn fit(\n        &self,\n        training_components: TrainingComponents<LC>,\n        mut learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        starting_epoch: usize,\n    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {\n        let main_device = self.devices.first().unwrap();\n\n        // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy\n        // for each (worker) data loader. This matches the expected device on the worker, so we\n        // don't have to move the data between devices.\n        let dataloader_train = split_dataloader(dataloader_train, &self.devices);\n        let dataloader_valid = dataloader_valid.to_device(main_device.inner());\n\n        learner.fork(main_device);\n        let mut event_processor = training_components.event_processor;\n        let mut checkpointer = training_components.checkpointer;\n        let mut early_stopping = training_components.early_stopping;\n\n        let epoch_train = MultiDeviceTrainEpoch::<LC>::new(\n            dataloader_train.clone(),\n            training_components.grad_accumulation,\n        );\n        let epoch_valid: SingleDeviceValidEpoch<LC> =\n            SingleDeviceValidEpoch::new(dataloader_valid.clone());\n\n        for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {\n            let epoch = training_progress.items_processed;\n            epoch_train.run(\n                &mut learner,\n                &training_progress,\n                &mut event_processor,\n                &training_components.interrupter,\n                self.devices.to_vec(),\n                self.optim,\n            );\n\n            if training_components.interrupter.should_stop() {\n                let reason = training_components\n                    .interrupter\n                    .get_message()\n                    .unwrap_or(String::from(\"Reason unknown\"));\n                log::info!(\"Training interrupted: {reason}\");\n                break;\n            }\n\n            // After OptimSharded training, model parameters are scattered across\n            // devices. Fork back to main_device before single-device validation.\n            if matches!(self.optim, MultiDeviceOptim::OptimSharded) {\n                learner.fork(main_device);\n            }\n\n            epoch_valid.run(\n                &learner,\n                &training_progress,\n                &mut event_processor,\n                &training_components.interrupter,\n            );\n\n            if let Some(checkpointer) = &mut checkpointer {\n                checkpointer.checkpoint(&learner, epoch, &training_components.event_store);\n            }\n\n            if let Some(early_stopping) = &mut early_stopping\n                && early_stopping.should_stop(epoch, &training_components.event_store)\n            {\n                break;\n            }\n        }\n\n        (learner.model(), event_processor)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/single/epoch.rs",
    "content": "use crate::learner::base::Interrupter;\nuse crate::metric::processor::{EventProcessorTraining, LearnerEvent, TrainingItem};\nuse crate::{\n    InferenceStep, Learner, LearningComponentsTypes, SupervisedTrainingEventProcessor, TrainLoader,\n    ValidLoader,\n};\nuse burn_core::data::dataloader::Progress;\nuse burn_core::module::AutodiffModule;\nuse burn_optim::GradientsAccumulator;\n\n/// A validation epoch.\n#[derive(new)]\npub struct SingleDeviceValidEpoch<LC: LearningComponentsTypes> {\n    dataloader: ValidLoader<LC>,\n}\n\n/// A training epoch.\n#[derive(new)]\npub struct SingleDeviceTrainEpoch<LC: LearningComponentsTypes> {\n    dataloader: TrainLoader<LC>,\n    grad_accumulation: Option<usize>,\n}\n\nimpl<LC: LearningComponentsTypes> SingleDeviceValidEpoch<LC> {\n    /// Runs the validation epoch.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - The model to validate.\n    /// * `processor` - The event processor to use.\n    pub fn run(\n        &self,\n        learner: &Learner<LC>,\n        global_progress: &Progress,\n        processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\"Executing validation step for epoch {}\", epoch);\n        let model = learner.model().valid();\n\n        let mut iterator = self.dataloader.iter();\n        let mut iteration = 0;\n\n        while let Some(item) = iterator.next() {\n            let progress = iterator.progress();\n            iteration += 1;\n\n            let item = model.step(item);\n            let item = TrainingItem::new(\n                item,\n                progress,\n                global_progress.clone(),\n                Some(iteration),\n                None,\n            );\n\n            processor.process_valid(LearnerEvent::ProcessedItem(item));\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        processor.process_valid(LearnerEvent::EndEpoch(epoch));\n    }\n}\n\nimpl<LC: LearningComponentsTypes> SingleDeviceTrainEpoch<LC> {\n    /// Runs the training epoch.\n    ///\n    /// # Arguments\n    ///\n    /// * `model` - The model to train.\n    /// * `optim` - The optimizer to use.\n    /// * `scheduler` - The learning rate scheduler to use.\n    /// * `processor` - The event processor to use.\n    ///\n    /// # Returns\n    ///\n    /// The trained model and the optimizer.\n    pub fn run(\n        &self,\n        learner: &mut Learner<LC>,\n        global_progress: &Progress,\n        processor: &mut SupervisedTrainingEventProcessor<LC>,\n        interrupter: &Interrupter,\n    ) {\n        let epoch = global_progress.items_processed;\n        log::info!(\"Executing training step for epoch {}\", epoch,);\n\n        // Single device / dataloader\n        let mut iterator = self.dataloader.iter();\n        let mut iteration = 0;\n        let mut accumulator = GradientsAccumulator::new();\n        let mut accumulation_current = 0;\n\n        while let Some(item) = iterator.next() {\n            iteration += 1;\n            learner.lr_step();\n            log::info!(\"Iteration {iteration}\");\n\n            let progress = iterator.progress();\n            let item = learner.train_step(item);\n\n            match self.grad_accumulation {\n                Some(accumulation) => {\n                    accumulator.accumulate(&learner.model(), item.grads);\n                    accumulation_current += 1;\n\n                    if accumulation <= accumulation_current {\n                        let grads = accumulator.grads();\n\n                        learner.optimizer_step(grads);\n                        accumulation_current = 0;\n                    }\n                }\n                None => learner.optimizer_step(item.grads),\n            }\n\n            let item = TrainingItem::new(\n                item.item,\n                progress,\n                global_progress.clone(),\n                Some(iteration),\n                Some(learner.lr_current()),\n            );\n\n            processor.process_train(LearnerEvent::ProcessedItem(item));\n\n            if interrupter.should_stop() {\n                break;\n            }\n        }\n        processor.process_train(LearnerEvent::EndEpoch(epoch));\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/single/mod.rs",
    "content": "pub(crate) mod epoch;\nmod strategy;\n\npub use strategy::*;\n"
  },
  {
    "path": "crates/burn-train/src/learner/supervised/strategies/single/strategy.rs",
    "content": "use crate::{\n    Learner, LearningComponentsTypes, SupervisedLearningStrategy, SupervisedTrainingEventProcessor,\n    TrainLoader, TrainingBackend, TrainingComponents, TrainingModel, ValidLoader,\n    single::epoch::{SingleDeviceTrainEpoch, SingleDeviceValidEpoch},\n};\nuse burn_core::{\n    data::dataloader::Progress,\n    tensor::{Device, backend::DeviceOps},\n};\n\n/// Simplest learning strategy possible, with only a single devices doing both the training and\n/// validation.\npub struct SingleDeviceTrainingStrategy<LC: LearningComponentsTypes> {\n    device: Device<TrainingBackend<LC>>,\n}\nimpl<LC: LearningComponentsTypes> SingleDeviceTrainingStrategy<LC> {\n    pub fn new(device: Device<TrainingBackend<LC>>) -> Self {\n        Self { device }\n    }\n}\n\n#[derive(new)]\npub(crate) struct TrainingLoop {\n    next_iteration: usize,\n    total_iteration: usize,\n}\n\n/// An iterator that returns the progress of the training.\nimpl Iterator for TrainingLoop {\n    type Item = Progress;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        if self.next_iteration > self.total_iteration {\n            return None;\n        }\n\n        let progress = Progress {\n            items_processed: self.next_iteration,\n            items_total: self.total_iteration,\n        };\n\n        self.next_iteration += 1;\n        Some(progress)\n    }\n}\n\nimpl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC>\n    for SingleDeviceTrainingStrategy<LC>\n{\n    fn fit(\n        &self,\n        training_components: TrainingComponents<LC>,\n        mut learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        starting_epoch: usize,\n    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {\n        let dataloader_train = dataloader_train.to_device(&self.device);\n        let dataloader_valid = dataloader_valid.to_device(self.device.inner());\n        learner.fork(&self.device);\n        let mut event_processor = training_components.event_processor;\n        let mut checkpointer = training_components.checkpointer;\n        let mut early_stopping = training_components.early_stopping;\n\n        let epoch_train: SingleDeviceTrainEpoch<LC> =\n            SingleDeviceTrainEpoch::new(dataloader_train, training_components.grad_accumulation);\n        let epoch_valid: SingleDeviceValidEpoch<LC> =\n            SingleDeviceValidEpoch::new(dataloader_valid.clone());\n\n        for training_progress in TrainingLoop::new(starting_epoch, training_components.num_epochs) {\n            let epoch = training_progress.items_processed;\n            epoch_train.run(\n                &mut learner,\n                &training_progress,\n                &mut event_processor,\n                &training_components.interrupter,\n            );\n\n            if training_components.interrupter.should_stop() {\n                let reason = training_components\n                    .interrupter\n                    .get_message()\n                    .unwrap_or(String::from(\"Reason unknown\"));\n                log::info!(\"Training interrupted: {reason}\");\n                break;\n            }\n\n            epoch_valid.run(\n                &learner,\n                &training_progress,\n                &mut event_processor,\n                &training_components.interrupter,\n            );\n\n            if let Some(checkpointer) = &mut checkpointer {\n                checkpointer.checkpoint(&learner, epoch, &training_components.event_store);\n            }\n\n            if let Some(early_stopping) = &mut early_stopping\n                && early_stopping.should_stop(epoch, &training_components.event_store)\n            {\n                break;\n            }\n        }\n\n        (learner.model(), event_processor)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/learner/train_val.rs",
    "content": "use crate::{ItemLazy, renderer::MetricsRenderer};\nuse burn_core::module::AutodiffModule;\nuse burn_core::tensor::backend::AutodiffBackend;\nuse burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};\n\n/// A training output.\npub struct TrainOutput<TO> {\n    /// The gradients.\n    pub grads: GradientsParams,\n\n    /// The item.\n    pub item: TO,\n}\n\nimpl<TO> TrainOutput<TO> {\n    /// Creates a new training output.\n    ///\n    /// # Arguments\n    ///\n    /// * `module` - The module.\n    /// * `grads` - The gradients.\n    /// * `item` - The item.\n    ///\n    /// # Returns\n    ///\n    /// A new training output.\n    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(\n        module: &M,\n        grads: B::Gradients,\n        item: TO,\n    ) -> Self {\n        let grads = GradientsParams::from_grads(grads, module);\n        Self { grads, item }\n    }\n}\n\n/// Trait to be implemented for models to be able to be trained.\n///\n/// The [step](TrainStep::step) method needs to be manually implemented for all structs.\n///\n/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the\n/// optimizer is used to update the model. This can be useful if you want to call custom mutable\n/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.\n///\n/// # Notes\n///\n/// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must\n/// also implement the [AutodiffModule] trait, which is done automatically with the\n/// [Module](burn_core::module::Module) derive.\npub trait TrainStep {\n    /// Type of input for a step of the training stage.\n    type Input: Send + 'static;\n    /// Type of output for a step of the training stage.\n    type Output: ItemLazy + 'static;\n    /// Runs a step for training, which executes the forward and backward passes.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The input for the model.\n    ///\n    /// # Returns\n    ///\n    /// The output containing the model output and the gradients.\n    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;\n    /// Optimize the current module with the provided gradients and learning rate.\n    ///\n    /// # Arguments\n    ///\n    /// * `optim`: Optimizer used for learning.\n    /// * `lr`: The learning rate used for this step.\n    /// * `grads`: The gradients of each parameter in the current model.\n    ///\n    /// # Returns\n    ///\n    /// The updated model.\n    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self\n    where\n        B: AutodiffBackend,\n        O: Optimizer<Self, B>,\n        Self: AutodiffModule<B>,\n    {\n        optim.step(lr, self, grads)\n    }\n    /// Optimize the current module with the provided gradients and learning rate.\n    ///\n    /// # Arguments\n    ///\n    /// * `optim`: Optimizer used for learning.\n    /// * `lr`: The learning rate used for this step.\n    /// * `grads`: Multiple gradients associated to each parameter in the current model.\n    ///\n    /// # Returns\n    ///\n    /// The updated model.\n    fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self\n    where\n        B: AutodiffBackend,\n        O: Optimizer<Self, B>,\n        Self: AutodiffModule<B>,\n    {\n        optim.step_multi(lr, self, grads)\n    }\n}\n\n/// Trait to be implemented for validating models.\npub trait InferenceStep {\n    /// Type of input for an inference step.\n    type Input: Send + 'static;\n    /// Type of output for an inference step.\n    type Output: ItemLazy + 'static;\n    /// Runs a validation step.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The item to validate on.\n    ///\n    /// # Returns\n    ///\n    /// The validation output.\n    fn step(&self, item: Self::Input) -> Self::Output;\n}\n\n/// The result of a training, containing the model along with the [renderer](MetricsRenderer).\npub struct LearningResult<M> {\n    /// The model with the learned weights.\n    pub model: M,\n    /// The renderer that can be used for follow up training and evaluation.\n    pub renderer: Box<dyn MetricsRenderer>,\n}\n"
  },
  {
    "path": "crates/burn-train/src/lib.rs",
    "content": "#![warn(missing_docs)]\n#![cfg_attr(docsrs, feature(doc_cfg))]\n\n//! A library for training neural networks using the burn crate.\n\n#[macro_use]\nextern crate derive_new;\n\n/// The checkpoint module.\npub mod checkpoint;\n\npub(crate) mod components;\n\n/// Renderer modules to display metrics and training information.\npub mod renderer;\n\n/// The logger module.\npub mod logger;\n\n/// The metric module.\npub mod metric;\n\npub use metric::processor::*;\n\nmod learner;\n\npub use learner::*;\n\nmod evaluator;\n\npub use evaluator::*;\n\npub use components::*;\n\n#[cfg(test)]\npub(crate) type TestBackend = burn_ndarray::NdArray<f32>;\n\n#[cfg(test)]\npub(crate) mod tests {\n    use crate::TestBackend;\n    use burn_core::{prelude::Tensor, tensor::Bool};\n    use std::default::Default;\n\n    pub type TestAutodiffBackend = burn_autodiff::Autodiff<TestBackend>;\n\n    /// Probability of tp before adding errors\n    pub const THRESHOLD: f64 = 0.5;\n\n    #[derive(Debug, Default)]\n    pub enum ClassificationType {\n        #[default]\n        Binary,\n        Multiclass,\n        Multilabel,\n    }\n\n    /// Sample x Class shaped matrix for use in\n    /// classification metrics testing\n    pub fn dummy_classification_input(\n        classification_type: &ClassificationType,\n    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2, Bool>) {\n        match classification_type {\n            ClassificationType::Binary => {\n                (\n                    Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),\n                    // targets\n                    Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),\n                    // predictions @ threshold=0.5\n                    //                     [[0], [0], [1], [0], [1]]\n                )\n            }\n            ClassificationType::Multiclass => {\n                (\n                    Tensor::from_data(\n                        [\n                            [0.2, 0.8, 0.0],\n                            [0.3, 0.6, 0.1],\n                            [0.7, 0.25, 0.05],\n                            [0.1, 0.15, 0.8],\n                            [0.9, 0.03, 0.07],\n                        ],\n                        &Default::default(),\n                    ),\n                    Tensor::from_data(\n                        // targets\n                        [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],\n                        // predictions @ top_k=1\n                        //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0,  0]]\n                        // predictions @ top_k=2\n                        //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0,  1]]\n                        &Default::default(),\n                    ),\n                )\n            }\n            ClassificationType::Multilabel => {\n                (\n                    Tensor::from_data(\n                        [\n                            [0.1, 0.7, 0.6],\n                            [0.3, 0.9, 0.05],\n                            [0.8, 0.9, 0.4],\n                            [0.7, 0.5, 0.9],\n                            [1.0, 0.3, 0.2],\n                        ],\n                        &Default::default(),\n                    ),\n                    // targets\n                    Tensor::from_data(\n                        [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],\n                        // predictions @ threshold=0.5\n                        //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]\n                        &Default::default(),\n                    ),\n                )\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/async_logger.rs",
    "content": "use super::Logger;\nuse std::sync::mpsc;\n\nenum Message<T> {\n    Log(T),\n    End,\n    Sync(mpsc::Sender<()>),\n}\n/// Async logger.\npub struct AsyncLogger<T> {\n    sender: mpsc::Sender<Message<T>>,\n    handler: Option<std::thread::JoinHandle<()>>,\n}\n\n#[derive(new)]\nstruct LoggerThread<T, L: Logger<T>> {\n    logger: L,\n    receiver: mpsc::Receiver<Message<T>>,\n}\n\nimpl<T, L> LoggerThread<T, L>\nwhere\n    L: Logger<T>,\n{\n    fn run(mut self) {\n        for item in self.receiver.iter() {\n            match item {\n                Message::Log(item) => {\n                    self.logger.log(item);\n                }\n                Message::End => {\n                    return;\n                }\n                Message::Sync(callback) => {\n                    callback\n                        .send(())\n                        .expect(\"Can return result with the callback channel.\");\n                }\n            }\n        }\n    }\n}\n\nimpl<T: Send + Sync + 'static> AsyncLogger<T> {\n    /// Create a new async logger.\n    pub fn new<L>(logger: L) -> Self\n    where\n        L: Logger<T> + 'static,\n    {\n        let (sender, receiver) = mpsc::channel();\n        let thread = LoggerThread::new(logger, receiver);\n\n        let handler = Some(std::thread::spawn(move || thread.run()));\n\n        Self { sender, handler }\n    }\n\n    /// Sync the async logger.\n    pub(crate) fn sync(&self) {\n        let (sender, receiver) = mpsc::channel();\n\n        self.sender\n            .send(Message::Sync(sender))\n            .expect(\"Can send message to logger thread.\");\n\n        receiver\n            .recv()\n            .expect(\"Should sync, otherwise the thread is dead.\");\n    }\n}\n\nimpl<T: Send> Logger<T> for AsyncLogger<T> {\n    fn log(&mut self, item: T) {\n        self.sender\n            .send(Message::Log(item))\n            .expect(\"Can log using the logger thread.\");\n    }\n}\n\nimpl<T> Drop for AsyncLogger<T> {\n    fn drop(&mut self) {\n        self.sender\n            .send(Message::End)\n            .expect(\"Can send the end message to the logger thread.\");\n        let handler = self.handler.take();\n\n        if let Some(handler) = handler {\n            handler.join().expect(\"The logger thread should stop.\");\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/base.rs",
    "content": "/// The logger trait.\npub trait Logger<T>: Send {\n    /// Logs an item.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The item.\n    fn log(&mut self, item: T);\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/file.rs",
    "content": "use super::Logger;\nuse std::{fs::File, io::Write, path::Path};\n\n/// File logger.\npub struct FileLogger {\n    file: File,\n}\n\nimpl FileLogger {\n    /// Create a new file logger.\n    ///\n    /// # Arguments\n    ///\n    /// * `path` - The path.\n    ///\n    /// # Returns\n    ///\n    /// The file logger.\n    pub fn new(path: impl AsRef<Path>) -> Self {\n        let path = path.as_ref();\n        let mut options = std::fs::File::options();\n        let file = options\n            .write(true)\n            .truncate(true)\n            .create(true)\n            .open(path)\n            .unwrap_or_else(|err| {\n                panic!(\n                    \"Should be able to create the new file '{}': {}\",\n                    path.display(),\n                    err\n                )\n            });\n\n        Self { file }\n    }\n}\n\nimpl<T> Logger<T> for FileLogger\nwhere\n    T: std::fmt::Display,\n{\n    fn log(&mut self, item: T) {\n        writeln!(&mut self.file, \"{item}\").expect(\"Can log an item.\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/in_memory.rs",
    "content": "use super::Logger;\n\n/// In memory logger.\n#[derive(Default)]\npub struct InMemoryLogger {\n    pub(crate) values: Vec<String>,\n}\n\nimpl<T> Logger<T> for InMemoryLogger\nwhere\n    T: std::fmt::Display,\n{\n    fn log(&mut self, item: T) {\n        self.values.push(item.to_string());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/metric.rs",
    "content": "use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};\nuse crate::metric::{\n    MetricDefinition, MetricEntry, MetricId, NumericEntry,\n    store::{EpochSummary, MetricsUpdate, Split},\n};\nuse std::{\n    collections::HashMap,\n    fs,\n    path::{Path, PathBuf},\n};\n\nconst EPOCH_PREFIX: &str = \"epoch-\";\n\n/// Metric logger.\npub trait MetricLogger: Send {\n    /// Logs an item.\n    ///\n    /// # Arguments\n    ///\n    /// * `update` - Update information for all registered metrics.\n    /// * `epoch` - Current epoch.\n    /// * `split` - Current dataset split.\n    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);\n\n    /// Read the logs for an epoch.\n    fn read_numeric(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        split: &Split,\n    ) -> Result<Vec<NumericEntry>, String>;\n\n    /// Logs the metric definition information (name, description, unit, etc.)\n    fn log_metric_definition(&mut self, definition: MetricDefinition);\n\n    /// Logs summary at the end of the epoch.\n    fn log_epoch_summary(&mut self, summary: EpochSummary);\n}\n\n/// The file metric logger.\npub struct FileMetricLogger {\n    loggers: HashMap<String, AsyncLogger<String>>,\n    directory: PathBuf,\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n    is_eval: bool,\n    last_epoch: Option<usize>,\n}\n\nimpl FileMetricLogger {\n    /// Create a new file metric logger.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory.\n    ///\n    /// # Returns\n    ///\n    /// The file metric logger.\n    pub fn new(directory: impl AsRef<Path>) -> Self {\n        Self {\n            loggers: HashMap::new(),\n            directory: directory.as_ref().to_path_buf(),\n            metric_definitions: HashMap::default(),\n            is_eval: false,\n            last_epoch: None,\n        }\n    }\n\n    /// Create a new file metric logger.\n    ///\n    /// # Arguments\n    ///\n    /// * `directory` - The directory.\n    ///\n    /// # Returns\n    ///\n    /// The file metric logger.\n    pub fn new_eval(directory: impl AsRef<Path>) -> Self {\n        Self {\n            loggers: HashMap::new(),\n            directory: directory.as_ref().to_path_buf(),\n            metric_definitions: HashMap::default(),\n            is_eval: true,\n            last_epoch: None,\n        }\n    }\n\n    pub(crate) fn split_exists(&self, split: &Split) -> bool {\n        self.split_dir(split).is_some()\n    }\n\n    pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {\n        let split_path = match split {\n            Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()),\n            other => self.directory.join(other.to_string()),\n        };\n        (split_path.exists() && split_path.is_dir()).then_some(split_path)\n    }\n\n    pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {\n        dirname.as_ref().starts_with(EPOCH_PREFIX)\n    }\n\n    /// Number of epochs recorded.\n    pub(crate) fn epochs(&self) -> usize {\n        if self.is_eval {\n            log::warn!(\"Number of epochs not available when testing.\");\n            return 0;\n        }\n\n        let mut max_epoch = 0;\n\n        // with split\n        for path in fs::read_dir(&self.directory).unwrap() {\n            let path = path.unwrap();\n\n            if fs::metadata(path.path()).unwrap().is_dir() {\n                for split_path in fs::read_dir(path.path()).unwrap() {\n                    let split_path = split_path.unwrap();\n\n                    if fs::metadata(split_path.path()).unwrap().is_dir() {\n                        let dir_name = split_path.file_name().into_string().unwrap();\n\n                        if !dir_name.starts_with(EPOCH_PREFIX) {\n                            continue;\n                        }\n\n                        let epoch = dir_name.replace(EPOCH_PREFIX, \"\").parse::<usize>().ok();\n\n                        if let Some(epoch) = epoch\n                            && epoch > max_epoch\n                        {\n                            max_epoch = epoch;\n                        }\n                    }\n                }\n            }\n        }\n\n        max_epoch\n    }\n\n    fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {\n        let name = format!(\"{EPOCH_PREFIX}{epoch}\");\n\n        match split {\n            Split::Train | Split::Valid | Split::Test(None) => {\n                self.directory.join(split.to_string()).join(name)\n            }\n            Split::Test(Some(tag)) => {\n                let tag = format_tag(tag);\n                self.directory.join(split.to_string()).join(tag).join(name)\n            }\n        }\n    }\n\n    fn eval_directory(&self, split: &Split) -> PathBuf {\n        match split {\n            Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),\n            Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)),\n        }\n    }\n\n    fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {\n        let directory = match epoch {\n            Some(epoch) => self.train_directory(epoch, split),\n            None => self.eval_directory(split),\n        };\n        let name = name.replace(' ', \"_\");\n        let name = format!(\"{name}.log\");\n        directory.join(name)\n    }\n\n    fn create_directory(&self, epoch: Option<usize>, split: &Split) {\n        let directory = match epoch {\n            Some(epoch) => self.train_directory(epoch, split),\n            None => self.eval_directory(split),\n        };\n        std::fs::create_dir_all(directory).ok();\n    }\n}\n\nimpl FileMetricLogger {\n    fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {\n        let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;\n        let key = logger_key(name, split);\n        let value = &item.serialized_entry.serialized;\n\n        let logger = match self.loggers.get_mut(&key) {\n            Some(val) => val,\n            None => {\n                self.create_directory(epoch, split);\n\n                let file_path = self.file_path(name, epoch, split);\n                let logger = FileLogger::new(file_path);\n                let logger = AsyncLogger::new(logger);\n\n                self.loggers.insert(key.clone(), logger);\n                self.loggers\n                    .get_mut(&key)\n                    .expect(\"Can get the previously saved logger.\")\n            }\n        };\n\n        logger.log(value.clone());\n    }\n}\n\nfn format_tag(tag: &str) -> String {\n    tag.trim().replace(' ', \"-\").to_lowercase()\n}\n\nimpl MetricLogger for FileMetricLogger {\n    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {\n        if !self.is_eval && self.last_epoch != Some(epoch) {\n            self.loggers.clear();\n            self.last_epoch = Some(epoch);\n        }\n\n        let entries: Vec<_> = update\n            .entries\n            .iter()\n            .chain(\n                update\n                    .entries_numeric\n                    .iter()\n                    .map(|numeric_update| &numeric_update.entry),\n            )\n            .cloned()\n            .collect();\n\n        for item in entries.iter() {\n            self.log_item(item, Some(epoch), split);\n        }\n    }\n\n    fn read_numeric(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        split: &Split,\n    ) -> Result<Vec<NumericEntry>, String> {\n        if let Some(value) = self.loggers.get(name) {\n            value.sync()\n        }\n\n        let file_path = self.file_path(name, Some(epoch), split);\n\n        let mut errors = false;\n\n        let data = std::fs::read_to_string(file_path)\n            .unwrap_or_default()\n            .split('\\n')\n            .filter_map(|value| {\n                if value.is_empty() {\n                    None\n                } else {\n                    match NumericEntry::deserialize(value) {\n                        Ok(value) => Some(value),\n                        Err(err) => {\n                            log::error!(\"{err}\");\n                            errors = true;\n                            None\n                        }\n                    }\n                }\n            })\n            .collect();\n\n        if errors {\n            Err(\"Parsing numeric entry errors\".to_string())\n        } else {\n            Ok(data)\n        }\n    }\n\n    fn log_metric_definition(&mut self, definition: MetricDefinition) {\n        self.metric_definitions\n            .insert(definition.metric_id.clone(), definition);\n    }\n\n    fn log_epoch_summary(&mut self, _summary: EpochSummary) {\n        if !self.is_eval {\n            self.loggers.clear();\n        }\n    }\n}\n\nfn logger_key(name: &str, split: &Split) -> String {\n    format!(\"{name}_{split}\")\n}\n\n/// In memory metric logger, useful when testing and debugging.\n#[derive(Default)]\npub struct InMemoryMetricLogger {\n    values: HashMap<String, Vec<InMemoryLogger>>,\n    last_epoch: Option<usize>,\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n}\n\nimpl InMemoryMetricLogger {\n    /// Create a new in-memory metric logger.\n    pub fn new() -> Self {\n        Self::default()\n    }\n}\n\nimpl MetricLogger for InMemoryMetricLogger {\n    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {\n        if self.last_epoch != Some(epoch) {\n            self.values\n                .values_mut()\n                .for_each(|loggers| loggers.push(InMemoryLogger::default()));\n            self.last_epoch = Some(epoch);\n        }\n\n        let entries: Vec<_> = update\n            .entries\n            .iter()\n            .chain(\n                update\n                    .entries_numeric\n                    .iter()\n                    .map(|numeric_update| &numeric_update.entry),\n            )\n            .cloned()\n            .collect();\n\n        for item in entries.iter() {\n            let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;\n            let key = logger_key(name, split);\n\n            if !self.values.contains_key(&key) {\n                self.values\n                    .insert(key.to_string(), vec![InMemoryLogger::default()]);\n            }\n\n            let values = self.values.get_mut(&key).unwrap();\n\n            values\n                .last_mut()\n                .unwrap()\n                .log(item.serialized_entry.serialized.clone());\n        }\n    }\n\n    fn read_numeric(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        split: &Split,\n    ) -> Result<Vec<NumericEntry>, String> {\n        let key = logger_key(name, split);\n        let values = match self.values.get(&key) {\n            Some(values) => values,\n            None => return Ok(Vec::new()),\n        };\n\n        match values.get(epoch - 1) {\n            Some(logger) => Ok(logger\n                .values\n                .iter()\n                .filter_map(|value| NumericEntry::deserialize(value).ok())\n                .collect()),\n            None => Ok(Vec::new()),\n        }\n    }\n\n    fn log_metric_definition(&mut self, definition: MetricDefinition) {\n        self.metric_definitions\n            .insert(definition.metric_id.clone(), definition);\n    }\n\n    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}\n}\n"
  },
  {
    "path": "crates/burn-train/src/logger/mod.rs",
    "content": "mod async_logger;\nmod base;\nmod file;\nmod in_memory;\nmod metric;\n\npub use async_logger::*;\npub use base::*;\npub use file::*;\npub use in_memory::*;\npub use metric::*;\n"
  },
  {
    "path": "crates/burn-train/src/metric/acc.rs",
    "content": "use core::marker::PhantomData;\n\nuse super::MetricMetadata;\nuse super::state::{FormatOptions, NumericMetricState};\nuse crate::metric::{Metric, MetricAttributes, MetricName, Numeric, SerializedEntry};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{ElementConversion, Int, Tensor};\n\n/// The accuracy metric.\n#[derive(Clone)]\npub struct AccuracyMetric<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    pad_token: Option<usize>,\n    _b: PhantomData<B>,\n}\n\n/// The [accuracy metric](AccuracyMetric) input type.\n#[derive(new)]\npub struct AccuracyInput<B: Backend> {\n    outputs: Tensor<B, 2>,\n    targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Default for AccuracyMetric<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> AccuracyMetric<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self {\n            name: MetricName::new(\"Accuracy\".to_string()),\n            state: Default::default(),\n            pad_token: Default::default(),\n            _b: PhantomData,\n        }\n    }\n\n    /// Sets the pad token.\n    pub fn with_pad_token(mut self, index: usize) -> Self {\n        self.pad_token = Some(index);\n        self\n    }\n}\n\nimpl<B: Backend> Metric for AccuracyMetric<B> {\n    type Input = AccuracyInput<B>;\n\n    fn update(&mut self, input: &AccuracyInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {\n        let targets = input.targets.clone();\n        let outputs = input.outputs.clone();\n\n        let [batch_size, _n_classes] = outputs.dims();\n\n        let outputs = outputs.argmax(1).reshape([batch_size]);\n\n        let accuracy = match self.pad_token {\n            Some(pad_token) => {\n                let mask = targets.clone().equal_elem(pad_token as i64);\n                let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0);\n                let num_pad = mask.float().sum();\n\n                let acc = matches.sum() / (num_pad.neg() + batch_size as f32);\n\n                acc.into_scalar().elem::<f64>()\n            }\n            None => {\n                outputs\n                    .equal(targets)\n                    .int()\n                    .sum()\n                    .into_scalar()\n                    .elem::<f64>()\n                    / batch_size as f64\n            }\n        };\n\n        self.state.update(\n            100.0 * accuracy,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        super::NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for AccuracyMetric<B> {\n    fn value(&self) -> super::NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> super::NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn test_accuracy_without_padding() {\n        let device = Default::default();\n        let mut metric = AccuracyMetric::<TestBackend>::new();\n        let input = AccuracyInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.2, 0.8], // 2\n                    [1.0, 2.0, 0.5], // 1\n                    [0.4, 0.1, 0.2], // 0\n                    [0.6, 0.7, 0.2], // 1\n                ],\n                &device,\n            ),\n            Tensor::from_data([2, 2, 1, 1], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    #[test]\n    fn test_accuracy_with_padding() {\n        let device = Default::default();\n        let mut metric = AccuracyMetric::<TestBackend>::new().with_pad_token(3);\n        let input = AccuracyInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.2, 0.8, 0.0], // 2\n                    [1.0, 2.0, 0.5, 0.0], // 1\n                    [0.4, 0.1, 0.2, 0.0], // 0\n                    [0.6, 0.7, 0.2, 0.0], // 1\n                    [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count\n                    [0.0, 0.1, 0.2, 0.0], // Error on padding should not count\n                    [0.6, 0.0, 0.2, 0.0], // Error on padding should not count\n                ],\n                &device,\n            ),\n            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/auroc.rs",
    "content": "use core::f64;\nuse core::marker::PhantomData;\n\nuse super::MetricMetadata;\nuse super::state::{FormatOptions, NumericMetricState};\nuse crate::metric::{Metric, MetricName, Numeric, SerializedEntry};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{ElementConversion, Int, Tensor};\n\n/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.\n#[derive(Clone)]\npub struct AurocMetric<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    _b: PhantomData<B>,\n}\n\n/// The [AUROC metric](AurocMetric) input type.\n#[derive(new)]\npub struct AurocInput<B: Backend> {\n    outputs: Tensor<B, 2>,\n    targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Default for AurocMetric<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> AurocMetric<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self {\n            name: MetricName::new(\"AUROC\".to_string()),\n            state: Default::default(),\n            _b: PhantomData,\n        }\n    }\n\n    fn binary_auroc(&self, probabilities: &Tensor<B, 1>, targets: &Tensor<B, 1, Int>) -> f64 {\n        let n = targets.dims()[0];\n\n        let n_pos = targets.clone().sum().into_scalar().elem::<u64>() as usize;\n\n        // Early return if we don't have both positive and negative samples\n        if n_pos == 0 || n_pos == n {\n            if n_pos == 0 {\n                log::warn!(\"Metric cannot be computed because all target values are negative.\")\n            } else {\n                log::warn!(\"Metric cannot be computed because all target values are positive.\")\n            }\n            return 0.0;\n        }\n\n        let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);\n        let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);\n\n        let valid_pairs = pos_mask * neg_mask;\n\n        let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);\n        let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);\n\n        let correct_order = prob_i.clone().greater(prob_j.clone()).int();\n\n        let ties = prob_i.equal(prob_j).int();\n\n        // Calculate AUC components\n        let num_pairs = valid_pairs.clone().sum().into_scalar().elem::<f64>();\n        let correct_pairs = (correct_order * valid_pairs.clone())\n            .sum()\n            .into_scalar()\n            .elem::<f64>();\n        let tied_pairs = (ties * valid_pairs).sum().into_scalar().elem::<f64>();\n\n        (correct_pairs + 0.5 * tied_pairs) / num_pairs\n    }\n}\n\nimpl<B: Backend> Metric for AurocMetric<B> {\n    type Input = AurocInput<B>;\n\n    fn update(&mut self, input: &AurocInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [batch_size, num_classes] = input.outputs.dims();\n\n        assert_eq!(\n            num_classes, 2,\n            \"Currently only binary classification is supported\"\n        );\n\n        let probabilities = {\n            let exponents = input.outputs.clone().exp();\n            let sum = exponents.clone().sum_dim(1);\n            (exponents / sum)\n                .select(1, Tensor::arange(1..2, &input.outputs.device()))\n                .squeeze_dim(1)\n        };\n\n        let area_under_curve = self.binary_auroc(&probabilities, &input.targets);\n\n        self.state.update(\n            100.0 * area_under_curve,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n}\n\nimpl<B: Backend> Numeric for AurocMetric<B> {\n    fn value(&self) -> super::NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> super::NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn test_auroc() {\n        let device = Default::default();\n        let mut metric = AurocMetric::<TestBackend>::new();\n\n        let input = AurocInput::new(\n            Tensor::from_data(\n                [\n                    [0.1, 0.9], // High confidence positive\n                    [0.7, 0.3], // Low confidence negative\n                    [0.6, 0.4], // Low confidence negative\n                    [0.2, 0.8], // High confidence positive\n                ],\n                &device,\n            ),\n            Tensor::from_data([1, 0, 0, 1], &device), // True labels\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(metric.value().current(), 100.0);\n    }\n\n    #[test]\n    fn test_auroc_perfect_separation() {\n        let device = Default::default();\n        let mut metric = AurocMetric::<TestBackend>::new();\n\n        let input = AurocInput::new(\n            Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),\n            Tensor::from_data([1, 0, 0, 1], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(metric.value().current(), 100.0); // Perfect AUC\n    }\n\n    #[test]\n    fn test_auroc_random() {\n        let device = Default::default();\n        let mut metric = AurocMetric::<TestBackend>::new();\n\n        let input = AurocInput::new(\n            Tensor::from_data(\n                [\n                    [0.5, 0.5], // Random predictions\n                    [0.5, 0.5],\n                    [0.5, 0.5],\n                    [0.5, 0.5],\n                ],\n                &device,\n            ),\n            Tensor::from_data([1, 0, 0, 1], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(metric.value().current(), 50.0);\n    }\n\n    #[test]\n    fn test_auroc_all_one_class() {\n        let device = Default::default();\n        let mut metric = AurocMetric::<TestBackend>::new();\n\n        let input = AurocInput::new(\n            Tensor::from_data(\n                [\n                    [0.1, 0.9], // All positives predictions\n                    [0.2, 0.8],\n                    [0.3, 0.7],\n                    [0.4, 0.6],\n                ],\n                &device,\n            ),\n            Tensor::from_data([1, 1, 1, 1], &device), // All positive class\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(metric.value().current(), 0.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Currently only binary classification is supported\")]\n    fn test_auroc_multiclass_error() {\n        let device = Default::default();\n        let mut metric = AurocMetric::<TestBackend>::new();\n\n        let input = AurocInput::new(\n            Tensor::from_data(\n                [\n                    [0.1, 0.2, 0.7], // More than 2 classes not supported\n                    [0.3, 0.5, 0.2],\n                ],\n                &device,\n            ),\n            Tensor::from_data([2, 1], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/base.rs",
    "content": "use std::sync::Arc;\n\nuse burn_core::data::dataloader::Progress;\nuse burn_optim::LearningRate;\n\n/// Metric metadata that can be used when computing metrics.\npub struct MetricMetadata {\n    /// The current progress.\n    pub progress: Progress,\n\n    /// The global progress of the training (e.g. epochs).\n    pub global_progress: Progress,\n\n    /// The current iteration.\n    pub iteration: Option<usize>,\n\n    /// The current learning rate.\n    pub lr: Option<LearningRate>,\n}\n\nimpl MetricMetadata {\n    /// Fake metric metadata\n    #[cfg(test)]\n    pub fn fake() -> Self {\n        Self {\n            progress: Progress {\n                items_processed: 1,\n                items_total: 1,\n            },\n            global_progress: Progress {\n                items_processed: 0,\n                items_total: 1,\n            },\n            iteration: Some(0),\n            lr: None,\n        }\n    }\n}\n\n/// Metric id that can be used to compare metrics and retrieve entries of the same metric.\n/// For now we take the name as id to make sure that the same metric has the same id across different runs.\n#[derive(Debug, Clone, new, PartialEq, Eq, Hash)]\npub struct MetricId {\n    /// The metric id.\n    id: Arc<String>,\n}\n\n/// Metric attributes define the properties intrinsic to different types of metric.\n#[derive(Clone, Debug)]\npub enum MetricAttributes {\n    /// Numeric attributes.\n    Numeric(NumericAttributes),\n    /// No attributes.\n    None,\n}\n\n/// Definition of a metric.\n#[derive(Clone, Debug)]\npub struct MetricDefinition {\n    /// The metric's id.\n    pub metric_id: MetricId,\n    /// The name of the metric.\n    pub name: String,\n    /// The description of the metric.\n    pub description: Option<String>,\n    /// The attributes of the metric.\n    pub attributes: MetricAttributes,\n}\n\nimpl MetricDefinition {\n    /// Create a new metric definition given the metric and a unique id.\n    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {\n        Self {\n            metric_id,\n            name: metric.name().to_string(),\n            description: metric.description(),\n            attributes: metric.attributes(),\n        }\n    }\n}\n\n/// Metric trait.\n///\n/// # Notes\n///\n/// Implementations should define their own input type only used by the metric.\n/// This is important since some conflict may happen when the model output is adapted for each\n/// metric's input type.\npub trait Metric: Send + Sync + Clone {\n    /// The input type of the metric.\n    type Input;\n\n    /// The parameterized name of the metric.\n    ///\n    /// This should be unique, so avoid using short generic names, prefer using the long name.\n    ///\n    /// For a metric that can exist at different parameters (e.g., top-k accuracy for different\n    /// values of k), the name should be unique for each instance.\n    fn name(&self) -> MetricName;\n\n    /// A short description of the metric.\n    fn description(&self) -> Option<String> {\n        None\n    }\n\n    /// Attributes of the metric.\n    ///\n    /// By default, metrics have no attributes.\n    fn attributes(&self) -> MetricAttributes {\n        MetricAttributes::None\n    }\n\n    /// Update the metric state and returns the current metric entry.\n    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;\n\n    /// Clear the metric state.\n    fn clear(&mut self);\n}\n\n/// Type used to store metric names efficiently.\npub type MetricName = Arc<String>;\n\n/// Adaptor are used to transform types so that they can be used by metrics.\n///\n/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are\n/// registered with the specific learning paradigm (i.e. [SupervisedTraining](crate::SupervisedTraining)).\npub trait Adaptor<T> {\n    /// Adapt the type to be passed to a [metric](Metric).\n    fn adapt(&self) -> T;\n}\n\nimpl<T> Adaptor<()> for T {\n    fn adapt(&self) {}\n}\n\n/// Attributes that describe intrinsic properties of a numeric metric.\n#[derive(Clone, Debug)]\npub struct NumericAttributes {\n    /// Optional unit (e.g. \"%\", \"ms\", \"pixels\")\n    pub unit: Option<String>,\n    /// Whether larger values are better (true) or smaller are better (false).\n    pub higher_is_better: bool,\n}\n\nimpl From<NumericAttributes> for MetricAttributes {\n    fn from(attr: NumericAttributes) -> Self {\n        MetricAttributes::Numeric(attr)\n    }\n}\n\nimpl Default for NumericAttributes {\n    fn default() -> Self {\n        Self {\n            unit: None,\n            higher_is_better: true,\n        }\n    }\n}\n\n/// Declare a metric to be numeric.\n///\n/// This is useful to plot the values of a metric during training.\npub trait Numeric {\n    /// Returns the numeric value of the metric.\n    fn value(&self) -> NumericEntry;\n    /// Returns the current aggregated value of the metric over the global step (epoch).\n    fn running_value(&self) -> NumericEntry;\n}\n\n/// Serialized form of a metric entry.\n#[derive(Debug, Clone, new)]\npub struct SerializedEntry {\n    /// The string to be displayed.\n    pub formatted: String,\n    /// The string to be saved.\n    pub serialized: String,\n}\n\n/// Data type that contains the current state of a metric at a given time.\n#[derive(Debug, Clone)]\npub struct MetricEntry {\n    /// Id of the entry's metric.\n    pub metric_id: MetricId,\n    /// The serialized form of the entry.\n    pub serialized_entry: SerializedEntry,\n}\n\nimpl MetricEntry {\n    /// Create a new metric.\n    pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {\n        Self {\n            metric_id,\n            serialized_entry,\n        }\n    }\n}\n\n/// Numeric metric entry.\n#[derive(Debug, Clone)]\npub enum NumericEntry {\n    /// Single numeric value.\n    Value(f64),\n    /// Aggregated numeric (value, number of elements).\n    Aggregated {\n        /// The aggregated value of all entries.\n        aggregated_value: f64,\n        /// The number of entries present in the aggregated value.\n        count: usize,\n    },\n}\n\nimpl NumericEntry {\n    /// Gets the current aggregated value of the metric.\n    pub fn current(&self) -> f64 {\n        match self {\n            NumericEntry::Value(val) => *val,\n            NumericEntry::Aggregated {\n                aggregated_value, ..\n            } => *aggregated_value,\n        }\n    }\n\n    /// Returns a String representing the NumericEntry\n    pub fn serialize(&self) -> String {\n        match self {\n            Self::Value(v) => v.to_string(),\n            Self::Aggregated {\n                aggregated_value,\n                count,\n            } => format!(\"{aggregated_value},{count}\"),\n        }\n    }\n\n    /// De-serializes a string representing a NumericEntry and returns a Result containing the corresponding NumericEntry.\n    pub fn deserialize(entry: &str) -> Result<Self, String> {\n        // Check for comma separated values\n        let values = entry.split(',').collect::<Vec<_>>();\n        let num_values = values.len();\n\n        if num_values == 1 {\n            // Numeric value\n            match values[0].parse::<f64>() {\n                Ok(value) => Ok(NumericEntry::Value(value)),\n                Err(err) => Err(err.to_string()),\n            }\n        } else if num_values == 2 {\n            // Aggregated numeric (value, number of elements)\n            let (value, numel) = (values[0], values[1]);\n            match value.parse::<f64>() {\n                Ok(value) => match numel.parse::<usize>() {\n                    Ok(numel) => Ok(NumericEntry::Aggregated {\n                        aggregated_value: value,\n                        count: numel,\n                    }),\n                    Err(err) => Err(err.to_string()),\n                },\n                Err(err) => Err(err.to_string()),\n            }\n        } else {\n            Err(\"Invalid number of values for numeric entry\".to_string())\n        }\n    }\n\n    /// Compare this numeric metric's value with another one using the specified direction.\n    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {\n        (self.current() > other.current()) == higher_is_better\n    }\n}\n\n/// Format a float with the given precision. Will use scientific notation if necessary.\npub fn format_float(float: f64, precision: usize) -> String {\n    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);\n\n    match scientific_notation_threshold >= float {\n        true => format!(\"{float:.precision$e}\"),\n        false => format!(\"{float:.precision$}\"),\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/cer.rs",
    "content": "use super::state::{FormatOptions, NumericMetricState};\nuse super::{MetricMetadata, SerializedEntry};\nuse crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{Int, Tensor};\nuse core::marker::PhantomData;\nuse std::sync::Arc;\n\n/// Computes the edit distance (Levenshtein distance) between two sequences of integers.\n///\n/// The edit distance is defined as the minimum number of single-element edits (insertions,\n/// deletions, or substitutions) required to change one sequence into the other. This\n/// implementation is optimized for space, using only two rows of the dynamic programming table.\n///\npub(crate) fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {\n    let mut prev = (0..=prediction.len()).collect::<Vec<_>>();\n    let mut curr = vec![0; prediction.len() + 1];\n\n    for (i, &r) in reference.iter().enumerate() {\n        curr[0] = i + 1;\n        for (j, &p) in prediction.iter().enumerate() {\n            curr[j + 1] = if r == p {\n                prev[j] // no operation needed\n            } else {\n                1 + prev[j].min(prev[j + 1]).min(curr[j]) // substitution, insertion, deletion\n            };\n        }\n        core::mem::swap(&mut prev, &mut curr);\n    }\n    prev[prediction.len()]\n}\n\n/// Character error rate (CER) is defined as the edit distance (e.g. Levenshtein distance) between the predicted\n/// and reference character sequences, divided by the total number of characters in the reference.\n/// This metric is commonly used in tasks such as speech recognition, OCR, or text generation\n/// to quantify how closely the predicted output matches the ground truth at a character level.\n///\n#[derive(Clone)]\npub struct CharErrorRate<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    pad_token: Option<usize>,\n    _b: PhantomData<B>,\n}\n\n/// The [character error rate metric](CharErrorRate) input type.\n#[derive(new)]\npub struct CerInput<B: Backend> {\n    /// The predicted token sequences (as a 2-D tensor of token indices).\n    pub outputs: Tensor<B, 2, Int>,\n    /// The target token sequences (as a 2-D tensor of token indices).\n    pub targets: Tensor<B, 2, Int>,\n}\n\nimpl<B: Backend> Default for CharErrorRate<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> CharErrorRate<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"CER\".to_string()),\n            state: NumericMetricState::default(),\n            pad_token: None,\n            _b: PhantomData,\n        }\n    }\n\n    /// Sets the pad token.\n    pub fn with_pad_token(mut self, index: usize) -> Self {\n        self.pad_token = Some(index);\n        self\n    }\n}\n\n/// The [character error rate metric](CharErrorRate) implementation.\nimpl<B: Backend> Metric for CharErrorRate<B> {\n    type Input = CerInput<B>;\n\n    fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {\n        let outputs = &input.outputs;\n        let targets = &input.targets;\n        let [batch_size, seq_len] = targets.dims();\n\n        let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {\n            // Create boolean masks for non-padding tokens.\n            let output_mask = outputs.clone().not_equal_elem(pad as i64);\n            let target_mask = targets.clone().not_equal_elem(pad as i64);\n\n            let output_lengths_tensor = output_mask.int().sum_dim(1);\n            let target_lengths_tensor = target_mask.int().sum_dim(1);\n\n            (\n                output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),\n                target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),\n            )\n        } else {\n            // If there's no padding, all sequences have the full length.\n            (\n                vec![seq_len as i64; batch_size],\n                vec![seq_len as i64; batch_size],\n            )\n        };\n\n        let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();\n        let targets_data = targets.to_data().to_vec::<i64>().unwrap();\n\n        let total_edit_distance: usize = (0..batch_size)\n            .map(|i| {\n                let start = i * seq_len;\n\n                // Get pre-calculated lengths for the current sequence.\n                let output_len = output_lengths[i] as usize;\n                let target_len = target_lengths[i] as usize;\n\n                let output_seq_slice = &outputs_data[start..(start + output_len)];\n                let target_seq_slice = &targets_data[start..(start + target_len)];\n                let output_seq: Vec<i32> = output_seq_slice.iter().map(|&x| x as i32).collect();\n                let target_seq: Vec<i32> = target_seq_slice.iter().map(|&x| x as i32).collect();\n\n                edit_distance(&target_seq, &output_seq)\n            })\n            .sum();\n\n        let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::<f64>();\n\n        let value = if total_target_length > 0.0 {\n            100.0 * total_edit_distance as f64 / total_target_length\n        } else {\n            0.0\n        };\n\n        self.state.update(\n            value,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        super::NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for CharErrorRate<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    /// Perfect match ⇒ CER = 0 %.\n    #[test]\n    fn test_cer_without_padding() {\n        let device = Default::default();\n        let mut metric = CharErrorRate::<TestBackend>::new();\n\n        // Batch size = 2, sequence length = 2\n        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);\n        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);\n\n        metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());\n\n        assert_eq!(0.0, metric.value().current());\n    }\n\n    /// Two edits in four target tokens ⇒ 50 %.\n    #[test]\n    fn test_cer_without_padding_two_errors() {\n        let device = Default::default();\n        let mut metric = CharErrorRate::<TestBackend>::new();\n\n        // One substitution in each sequence.\n        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);\n        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);\n\n        metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());\n\n        // 2 edits / 4 tokens = 50 %\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    /// Same scenario as above, but with right-padding (token 9) ignored.\n    #[test]\n    fn test_cer_with_padding() {\n        let device = Default::default();\n        let pad = 9_i64;\n        let mut metric = CharErrorRate::<TestBackend>::new().with_pad_token(pad as usize);\n\n        // Each row has three columns, last one is the pad token.\n        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);\n        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);\n\n        metric.update(&CerInput::new(preds, tgts), &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    /// `clear()` must reset the running statistics to zero.\n    #[test]\n    fn test_clear_resets_state() {\n        let device = Default::default();\n        let mut metric = CharErrorRate::<TestBackend>::new();\n\n        let preds = Tensor::from_data([[1, 2]], &device);\n        let tgts = Tensor::from_data([[1, 3]], &device); // one error\n\n        metric.update(\n            &CerInput::new(preds.clone(), tgts.clone()),\n            &MetricMetadata::fake(),\n        );\n        assert!(metric.value().current() > 0.0);\n\n        metric.clear();\n        assert!(metric.value().current().is_nan());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/classification.rs",
    "content": "use std::num::NonZeroUsize;\n\n/// Necessary data for classification metrics.\n#[derive(Default, Debug, Clone)]\npub struct ClassificationMetricConfig {\n    pub decision_rule: DecisionRule,\n    pub class_reduction: ClassReduction,\n}\n\n/// The prediction decision rule for classification metrics.\n#[derive(Debug, Clone)]\npub enum DecisionRule {\n    /// Consider a class predicted if its probability exceeds the threshold.\n    Threshold(f64),\n    /// Consider a class predicted correctly if it is within the top k predicted classes based on scores.\n    TopK(NonZeroUsize),\n}\n\nimpl Default for DecisionRule {\n    fn default() -> Self {\n        Self::Threshold(0.5)\n    }\n}\n\n/// The reduction strategy for classification metrics.\n#[derive(Copy, Clone, Default, Debug)]\npub enum ClassReduction {\n    /// Computes the statistics over all classes before averaging\n    Micro,\n    /// Computes the statistics independently for each class before averaging\n    #[default]\n    Macro,\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/confusion_stats.rs",
    "content": "use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};\nuse burn_core::{\n    prelude::{Backend, Bool, Int, Tensor},\n    tensor::IndexingUpdateOp,\n};\nuse std::fmt::{self, Debug};\n\n/// Input for confusion statistics error types.\n#[derive(new, Debug, Clone)]\npub struct ConfusionStatsInput<B: Backend> {\n    /// Sample x Class Non thresholded normalized predictions.\n    pub predictions: Tensor<B, 2>,\n    /// Sample x Class one-hot encoded target.\n    pub targets: Tensor<B, 2, Bool>,\n}\n\nimpl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {\n    fn from(input: ConfusionStatsInput<B>) -> Self {\n        (input.predictions, input.targets)\n    }\n}\n\nimpl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {\n    fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {\n        Self::new(value.0, value.1)\n    }\n}\n\n#[derive(Clone)]\npub struct ConfusionStats<B: Backend> {\n    confusion_classes: Tensor<B, 2, Int>,\n    class_reduction: ClassReduction,\n}\n\nimpl<B: Backend> Debug for ConfusionStats<B> {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        let to_vec = |tensor_data: Tensor<B, 1>| {\n            tensor_data\n                .to_data()\n                .to_vec::<f32>()\n                .expect(\"A vector representation of the input Tensor is expected\")\n        };\n        let ratio_of_support_vec =\n            |metric: Tensor<B, 1>| to_vec(self.clone().ratio_of_support(metric));\n        f.debug_struct(\"ConfusionStats\")\n            .field(\"tp\", &ratio_of_support_vec(self.clone().true_positive()))\n            .field(\"fp\", &ratio_of_support_vec(self.clone().false_positive()))\n            .field(\"tn\", &ratio_of_support_vec(self.clone().true_negative()))\n            .field(\"fn\", &ratio_of_support_vec(self.clone().false_negative()))\n            .field(\"support\", &to_vec(self.clone().support()))\n            .finish()\n    }\n}\n\nimpl<B: Backend> ConfusionStats<B> {\n    /// Expects `predictions` to be normalized.\n    pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {\n        let prediction_mask = match config.decision_rule {\n            DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),\n            DecisionRule::TopK(top_k) => {\n                let mask = input.predictions.zeros_like();\n                let indexes =\n                    input\n                        .predictions\n                        .clone()\n                        .argsort_descending(1)\n                        .narrow(1, 0, top_k.get());\n                let values = indexes.ones_like().float();\n                mask.scatter(1, indexes, values, IndexingUpdateOp::Add)\n                    .bool()\n            }\n        };\n        Self {\n            confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,\n            class_reduction: config.class_reduction,\n        }\n    }\n\n    /// sum over samples\n    fn aggregate(\n        sample_class_mask: Tensor<B, 2, Bool>,\n        class_reduction: ClassReduction,\n    ) -> Tensor<B, 1> {\n        use ClassReduction::{Macro, Micro};\n        match class_reduction {\n            Micro => sample_class_mask.float().sum(),\n            Macro => sample_class_mask.float().sum_dim(0).squeeze_dim(0),\n        }\n    }\n\n    pub fn true_positive(self) -> Tensor<B, 1> {\n        Self::aggregate(self.confusion_classes.equal_elem(3), self.class_reduction)\n    }\n\n    pub fn true_negative(self) -> Tensor<B, 1> {\n        Self::aggregate(self.confusion_classes.equal_elem(0), self.class_reduction)\n    }\n\n    pub fn false_positive(self) -> Tensor<B, 1> {\n        Self::aggregate(self.confusion_classes.equal_elem(1), self.class_reduction)\n    }\n\n    pub fn false_negative(self) -> Tensor<B, 1> {\n        Self::aggregate(self.confusion_classes.equal_elem(2), self.class_reduction)\n    }\n\n    pub fn positive(self) -> Tensor<B, 1> {\n        self.clone().true_positive() + self.false_negative()\n    }\n\n    pub fn negative(self) -> Tensor<B, 1> {\n        self.clone().true_negative() + self.false_positive()\n    }\n\n    pub fn predicted_positive(self) -> Tensor<B, 1> {\n        self.clone().true_positive() + self.false_positive()\n    }\n\n    pub fn support(self) -> Tensor<B, 1> {\n        self.clone().positive() + self.negative()\n    }\n\n    pub fn ratio_of_support(self, metric: Tensor<B, 1>) -> Tensor<B, 1> {\n        metric / self.clone().support()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::{ConfusionStats, ConfusionStatsInput};\n    use crate::{\n        TestBackend,\n        metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},\n        tests::{ClassificationType, THRESHOLD, dummy_classification_input},\n    };\n    use burn_core::prelude::TensorData;\n    use rstest::{fixture, rstest};\n    use std::num::NonZeroUsize;\n\n    fn top_k_config(\n        top_k: NonZeroUsize,\n        class_reduction: ClassReduction,\n    ) -> ClassificationMetricConfig {\n        ClassificationMetricConfig {\n            decision_rule: DecisionRule::TopK(top_k),\n            class_reduction,\n        }\n    }\n    #[fixture]\n    #[once]\n    fn top_k_config_k1_micro() -> ClassificationMetricConfig {\n        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)\n    }\n\n    #[fixture]\n    #[once]\n    fn top_k_config_k1_macro() -> ClassificationMetricConfig {\n        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)\n    }\n    #[fixture]\n    #[once]\n    fn top_k_config_k2_micro() -> ClassificationMetricConfig {\n        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)\n    }\n    #[fixture]\n    #[once]\n    fn top_k_config_k2_macro() -> ClassificationMetricConfig {\n        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)\n    }\n\n    fn threshold_config(\n        threshold: f64,\n        class_reduction: ClassReduction,\n    ) -> ClassificationMetricConfig {\n        ClassificationMetricConfig {\n            decision_rule: DecisionRule::Threshold(threshold),\n            class_reduction,\n        }\n    }\n    #[fixture]\n    #[once]\n    fn threshold_config_micro() -> ClassificationMetricConfig {\n        threshold_config(THRESHOLD, ClassReduction::Micro)\n    }\n    #[fixture]\n    #[once]\n    fn threshold_config_macro() -> ClassificationMetricConfig {\n        threshold_config(THRESHOLD, ClassReduction::Macro)\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]\n    fn test_true_positive(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .true_positive()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]\n    fn test_true_negative(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .true_negative()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]\n    fn test_false_positive(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .false_positive()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]\n    fn test_false_negatives(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .false_negative()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]\n    fn test_positive(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .positive()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]\n    fn test_negative(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .negative()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n\n    #[rstest]\n    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]\n    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]\n    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]\n    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]\n    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]\n    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]\n    fn test_predicted_positive(\n        #[case] classification_type: ClassificationType,\n        #[case] config: ClassificationMetricConfig,\n        #[case] expected: Vec<i64>,\n    ) {\n        let input: ConfusionStatsInput<TestBackend> =\n            dummy_classification_input(&classification_type).into();\n        ConfusionStats::new(&input, &config)\n            .predicted_positive()\n            .int()\n            .into_data()\n            .assert_eq(&TensorData::from(expected.as_slice()), true);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/cpu_temp.rs",
    "content": "use std::sync::Arc;\n\n/// CPU Temperature metric\nuse super::MetricMetadata;\nuse crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};\nuse systemstat::{Platform, System};\n\n/// CPU Temperature in celsius degrees\n#[derive(Clone)]\npub struct CpuTemperature {\n    name: MetricName,\n    temp_celsius: f32,\n    sys: Arc<System>,\n}\n\nimpl CpuTemperature {\n    /// Creates a new CPU temp metric\n    pub fn new() -> Self {\n        let name = Arc::new(\"CPU Temperature\".to_string());\n\n        Self {\n            name,\n            temp_celsius: 0.,\n            sys: Arc::new(System::new()),\n        }\n    }\n}\n\nimpl Default for CpuTemperature {\n    fn default() -> Self {\n        CpuTemperature::new()\n    }\n}\n\nimpl Metric for CpuTemperature {\n    type Input = ();\n\n    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        match self.sys.cpu_temp() {\n            Ok(temp) => self.temp_celsius = temp,\n            Err(_) => self.temp_celsius = f32::NAN,\n        }\n\n        let formatted = match self.temp_celsius.is_nan() {\n            true => format!(\"{}: NaN °C\", self.name()),\n            false => format!(\"{}: {:.2} °C\", self.name(), self.temp_celsius),\n        };\n        let raw = format!(\"{:.2}\", self.temp_celsius);\n\n        SerializedEntry::new(formatted, raw)\n    }\n\n    fn clear(&mut self) {}\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        super::NumericAttributes {\n            unit: Some(\"°C\".to_string()),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for CpuTemperature {\n    fn value(&self) -> NumericEntry {\n        NumericEntry::Value(self.temp_celsius as f64)\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        NumericEntry::Value(self.temp_celsius as f64)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/cpu_use.rs",
    "content": "use super::MetricMetadata;\nuse crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericEntry, SerializedEntry};\nuse std::{\n    sync::Arc,\n    time::{Duration, Instant},\n};\nuse sysinfo::{CpuRefreshKind, RefreshKind, System};\n\n/// General CPU Usage metric\npub struct CpuUse {\n    name: MetricName,\n    last_refresh: Instant,\n    refresh_frequency: Duration,\n    sys: System,\n    current: f64,\n}\n\nimpl Clone for CpuUse {\n    fn clone(&self) -> Self {\n        Self {\n            name: self.name.clone(),\n            last_refresh: self.last_refresh,\n            refresh_frequency: self.refresh_frequency,\n            sys: System::new(),\n            current: self.current,\n        }\n    }\n}\n\nimpl CpuUse {\n    /// Creates a new CPU metric\n    pub fn new() -> Self {\n        let mut sys = System::new();\n        let current = Self::refresh(&mut sys);\n        let name = \"CPU Usage\".to_string();\n\n        Self {\n            name: Arc::new(name),\n            last_refresh: Instant::now(),\n            refresh_frequency: Duration::from_millis(200),\n            sys,\n            current,\n        }\n    }\n\n    fn refresh(sys: &mut System) -> f64 {\n        sys.refresh_specifics(\n            RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()),\n        );\n\n        let cpus = sys.cpus();\n        let num_cpus = cpus.len();\n        let use_percentage = cpus.iter().fold(0.0, |acc, cpu| acc + cpu.cpu_usage()) as f64;\n\n        use_percentage / num_cpus as f64\n    }\n}\n\nimpl Default for CpuUse {\n    fn default() -> Self {\n        CpuUse::new()\n    }\n}\n\nimpl Metric for CpuUse {\n    type Input = ();\n\n    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        if self.last_refresh.elapsed() >= self.refresh_frequency {\n            self.current = Self::refresh(&mut self.sys);\n            self.last_refresh = Instant::now();\n        }\n\n        let formatted = format!(\"{}: {:.2} %\", self.name(), self.current);\n        let raw = format!(\"{:.2}\", self.current);\n\n        SerializedEntry::new(formatted, raw)\n    }\n\n    fn clear(&mut self) {}\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        super::NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for CpuUse {\n    fn value(&self) -> NumericEntry {\n        NumericEntry::Value(self.current)\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        NumericEntry::Value(self.current)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/cuda.rs",
    "content": "use std::sync::Arc;\n\nuse super::MetricMetadata;\nuse crate::metric::{Metric, MetricName, SerializedEntry};\nuse nvml_wrapper::Nvml;\n\n/// Track basic cuda infos.\n#[derive(Clone)]\npub struct CudaMetric {\n    name: MetricName,\n    nvml: Arc<Option<Nvml>>,\n}\n\nimpl CudaMetric {\n    /// Creates a new metric for CUDA.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Cuda\".to_string()),\n            nvml: Arc::new(Nvml::init().map(Some).unwrap_or_else(|err| {\n                log::warn!(\"Unable to initialize CUDA Metric: {err}\");\n                None\n            })),\n        }\n    }\n}\n\nimpl Default for CudaMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Metric for CudaMetric {\n    type Input = ();\n\n    fn update(&mut self, _item: &(), _metadata: &MetricMetadata) -> SerializedEntry {\n        let not_available =\n            || SerializedEntry::new(\"Unavailable\".to_string(), \"Unavailable\".to_string());\n\n        let available = |nvml: &Nvml| {\n            let mut formatted = String::new();\n            let mut raw_running = String::new();\n\n            let device_count = match nvml.device_count() {\n                Ok(val) => val,\n                Err(err) => {\n                    log::warn!(\"Unable to get the number of cuda devices: {err}\");\n                    return not_available();\n                }\n            };\n\n            for index in 0..device_count {\n                let device = match nvml.device_by_index(index) {\n                    Ok(val) => val,\n                    Err(err) => {\n                        log::warn!(\"Unable to get device {index}: {err}\");\n                        return not_available();\n                    }\n                };\n                let memory_info = match device.memory_info() {\n                    Ok(info) => info,\n                    Err(err) => {\n                        log::warn!(\"Unable to get memory info from device {index}: {err}\");\n                        return not_available();\n                    }\n                };\n\n                let used_gb = memory_info.used as f64 * 1e-9;\n                let total_gb = memory_info.total as f64 * 1e-9;\n\n                let memory_info_formatted = format!(\"{used_gb:.2}/{total_gb:.2} Gb\");\n                let memory_info_raw = format!(\"{used_gb}/{total_gb}\");\n\n                formatted = format!(\"{formatted} GPU #{index} - Memory {memory_info_formatted}\");\n                raw_running = format!(\"{memory_info_raw} \");\n\n                let utilization_rates = match device.utilization_rates() {\n                    Ok(rate) => rate,\n                    Err(err) => {\n                        log::warn!(\"Unable to get utilization rates from device {index}: {err}\");\n                        return not_available();\n                    }\n                };\n                let utilization_rate_formatted = format!(\"{}%\", utilization_rates.gpu);\n                formatted = format!(\"{formatted} - Usage {utilization_rate_formatted}\");\n\n                // Power is the currency for perf/W. NVML reports milliwatts.\n                if let Ok(power_mw) = device.power_usage() {\n                    let power_w = power_mw as f64 / 1000.0;\n                    formatted = format!(\"{formatted} - Power {power_w:.1} W\");\n                }\n            }\n\n            SerializedEntry::new(formatted, raw_running)\n        };\n\n        match self.nvml.as_ref() {\n            Some(nvml) => available(nvml),\n            None => not_available(),\n        }\n    }\n\n    fn clear(&mut self) {}\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/fbetascore.rs",
    "content": "use crate::metric::{MetricName, Numeric};\n\nuse super::{\n    Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,\n    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},\n    confusion_stats::{ConfusionStats, ConfusionStatsInput},\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::cast::ToElement,\n};\nuse core::marker::PhantomData;\nuse std::{num::NonZeroUsize, sync::Arc};\n\n/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric.\n///\n/// The `beta` parameter represents the ratio of recall importance to precision importance.\n/// `beta > 1` gives more weight to recall, while `beta < 1` favors precision.\n#[derive(Clone)]\npub struct FBetaScoreMetric<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    _b: PhantomData<B>,\n    config: ClassificationMetricConfig,\n    beta: f64,\n}\n\nimpl<B: Backend> Default for FBetaScoreMetric<B> {\n    fn default() -> Self {\n        Self::new(Default::default(), Default::default())\n    }\n}\n\nimpl<B: Backend> FBetaScoreMetric<B> {\n    #[allow(dead_code)]\n    fn new(config: ClassificationMetricConfig, beta: f64) -> Self {\n        let name = Arc::new(format!(\n            \"FBetaScore ({}) @ {:?} [{:?}]\",\n            beta, config.decision_rule, config.class_reduction\n        ));\n        Self {\n            name,\n            config,\n            beta,\n            state: Default::default(),\n            _b: PhantomData,\n        }\n    }\n\n    /// F-beta score metric for binary classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `beta` - Positive real factor to weight recall's importance.\n    /// * `threshold` - The threshold to transform a probability into a binary prediction.\n    #[allow(dead_code)]\n    pub fn binary(beta: f64, threshold: f64) -> Self {\n        Self::new(\n            ClassificationMetricConfig {\n                decision_rule: DecisionRule::Threshold(threshold),\n                // binary classification results are the same independently of class_reduction\n                ..Default::default()\n            },\n            beta,\n        )\n    }\n\n    /// F-beta score metric for multiclass classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `beta` - Positive real factor to weight recall's importance.\n    /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self {\n        Self::new(\n            ClassificationMetricConfig {\n                decision_rule: DecisionRule::TopK(\n                    NonZeroUsize::new(top_k).expect(\"top_k must be non-zero\"),\n                ),\n                class_reduction,\n            },\n            beta,\n        )\n    }\n\n    /// F-beta score metric for multi-label classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `beta` - Positive real factor to weight recall's importance.\n    /// * `threshold` - The threshold to transform a probability into a binary prediction.\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self {\n        Self::new(\n            ClassificationMetricConfig {\n                decision_rule: DecisionRule::Threshold(threshold),\n                class_reduction,\n            },\n            beta,\n        )\n    }\n\n    fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {\n        use ClassReduction::{Macro, Micro};\n        let avg_tensor = match self.config.class_reduction {\n            Micro => aggregated_metric,\n            Macro => {\n                if aggregated_metric\n                    .clone()\n                    .contains_nan()\n                    .any()\n                    .into_scalar()\n                    .to_bool()\n                {\n                    let nan_mask = aggregated_metric.clone().is_nan();\n                    aggregated_metric = aggregated_metric\n                        .clone()\n                        .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))\n                }\n                aggregated_metric.mean()\n            }\n        };\n        avg_tensor.into_scalar().to_f64()\n    }\n}\n\nimpl<B: Backend> Metric for FBetaScoreMetric<B> {\n    type Input = ConfusionStatsInput<B>;\n\n    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [sample_size, _] = input.predictions.dims();\n\n        let cf_stats = ConfusionStats::new(input, &self.config);\n        let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2));\n        let metric = self.class_average(\n            scaled_true_positive.clone()\n                / (scaled_true_positive\n                    + cf_stats.clone().false_negative() * self.beta.powi(2)\n                    + cf_stats.false_positive()),\n        );\n\n        self.state.update(\n            100.0 * metric,\n            sample_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for FBetaScoreMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::{\n        ClassReduction::{self, *},\n        FBetaScoreMetric, Metric, MetricMetadata,\n    };\n    use crate::metric::Numeric;\n    use crate::{\n        TestBackend,\n        tests::{ClassificationType, THRESHOLD, dummy_classification_input},\n    };\n    use burn_core::tensor::TensorData;\n    use burn_core::tensor::Tolerance;\n    use rstest::rstest;\n\n    #[rstest]\n    #[case::binary_b1(1.0, THRESHOLD, 0.5)]\n    #[case::binary_b2(2.0, THRESHOLD, 0.5)]\n    fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) {\n        let input = dummy_classification_input(&ClassificationType::Binary).into();\n        let mut metric = FBetaScoreMetric::binary(beta, threshold);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)]\n    #[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))]\n    #[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)]\n    #[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)]\n    #[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)]\n    #[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))]\n    #[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)]\n    #[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)]\n    fn test_multiclass_fscore(\n        #[case] beta: f64,\n        #[case] class_reduction: ClassReduction,\n        #[case] top_k: usize,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multiclass).into();\n        let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))]\n    #[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)]\n    #[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))]\n    #[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)]\n    fn test_multilabel_fscore(\n        #[case] beta: f64,\n        #[case] class_reduction: ClassReduction,\n        #[case] threshold: f64,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multilabel).into();\n        let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f32>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[test]\n    fn test_parameterized_unique_name() {\n        let metric_a = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);\n        let metric_b = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 2, ClassReduction::Macro);\n        let metric_c = FBetaScoreMetric::<TestBackend>::multiclass(0.5, 1, ClassReduction::Macro);\n\n        assert_ne!(metric_a.name(), metric_b.name());\n        assert_eq!(metric_a.name(), metric_c.name());\n\n        let metric_a = FBetaScoreMetric::<TestBackend>::binary(0.5, 0.5);\n        let metric_b = FBetaScoreMetric::<TestBackend>::binary(0.75, 0.5);\n        assert_ne!(metric_a.name(), metric_b.name());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/hamming.rs",
    "content": "use core::marker::PhantomData;\nuse std::sync::Arc;\n\nuse super::state::{FormatOptions, NumericMetricState};\nuse super::{MetricMetadata, SerializedEntry};\nuse crate::metric::{\n    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,\n};\nuse burn_core::tensor::{ElementConversion, Int, Tensor, activation::sigmoid, backend::Backend};\n\n/// The hamming score, sometimes referred to as multi-label or label-based accuracy.\n#[derive(Clone)]\npub struct HammingScore<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    threshold: f32,\n    sigmoid: bool,\n    _b: PhantomData<B>,\n}\n\n/// The [hamming score](HammingScore) input type.\n#[derive(new)]\npub struct HammingScoreInput<B: Backend> {\n    outputs: Tensor<B, 2>,\n    targets: Tensor<B, 2, Int>,\n}\n\nimpl<B: Backend> HammingScore<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self::default()\n    }\n\n    fn update_name(&mut self) {\n        self.name = Arc::new(format!(\"Hamming Score @ Threshold({})\", self.threshold));\n    }\n\n    /// Sets the threshold.\n    pub fn with_threshold(mut self, threshold: f32) -> Self {\n        self.threshold = threshold;\n        self.update_name();\n        self\n    }\n\n    /// Sets the sigmoid activation function usage.\n    pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {\n        self.sigmoid = sigmoid;\n        self.update_name();\n        self\n    }\n}\n\nimpl<B: Backend> Default for HammingScore<B> {\n    /// Creates a new metric instance with default values.\n    fn default() -> Self {\n        let threshold = 0.5;\n        let name = Arc::new(format!(\"Hamming Score @ Threshold({})\", threshold));\n\n        Self {\n            name,\n            state: NumericMetricState::default(),\n            threshold,\n            sigmoid: false,\n            _b: PhantomData,\n        }\n    }\n}\n\nimpl<B: Backend> Metric for HammingScore<B> {\n    type Input = HammingScoreInput<B>;\n\n    fn update(\n        &mut self,\n        input: &HammingScoreInput<B>,\n        _metadata: &MetricMetadata,\n    ) -> SerializedEntry {\n        let [batch_size, _n_classes] = input.outputs.dims();\n\n        let targets = input.targets.clone();\n\n        let mut outputs = input.outputs.clone();\n\n        if self.sigmoid {\n            outputs = sigmoid(outputs);\n        }\n\n        let score = outputs\n            .greater_elem(self.threshold)\n            .equal(targets.bool())\n            .float()\n            .mean()\n            .into_scalar()\n            .elem::<f64>();\n\n        self.state.update(\n            100.0 * score,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for HammingScore<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn test_hamming_score() {\n        let device = Default::default();\n        let mut metric = HammingScore::<TestBackend>::new();\n\n        let x = Tensor::from_data(\n            [\n                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]\n                [0.43, 0.31, 0.21, 0.63, 0.53], //               [0, 0, 0, 1, 1]\n                [0.44, 0.25, 0.71, 0.39, 0.73], //               [0, 0, 1, 0, 1]\n                [0.49, 0.37, 0.68, 0.39, 0.31], //               [0, 0, 1, 0, 0]\n            ],\n            &device,\n        );\n        let y = Tensor::from_data(\n            [\n                [0, 1, 0, 1, 1],\n                [0, 0, 0, 1, 1],\n                [0, 0, 1, 0, 1],\n                [0, 0, 1, 0, 0],\n            ],\n            &device,\n        );\n\n        let _entry = metric.update(\n            &HammingScoreInput::new(x.clone(), y.clone()),\n            &MetricMetadata::fake(),\n        );\n        assert_eq!(100.0, metric.value().current());\n\n        // Invert all targets: y = (1 - y)\n        let y = y.neg().add_scalar(1);\n        let _entry = metric.update(\n            &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)\n            &MetricMetadata::fake(),\n        );\n        assert_eq!(0.0, metric.value().current());\n\n        // Invert 5 target values -> 1 - (5/20) = 0.75\n        let y = Tensor::from_data(\n            [\n                [0, 1, 1, 0, 1],\n                [0, 0, 0, 0, 1],\n                [0, 0, 0, 0, 1],\n                [0, 1, 1, 0, 0],\n            ],\n            &device,\n        );\n        let _entry = metric.update(\n            &HammingScoreInput::new(x, y), // invert targets (1 - y)\n            &MetricMetadata::fake(),\n        );\n        assert_eq!(75.0, metric.value().current());\n    }\n\n    #[test]\n    fn test_parameterized_unique_name() {\n        let metric_a = HammingScore::<TestBackend>::new().with_threshold(0.5);\n        let metric_b = HammingScore::<TestBackend>::new().with_threshold(0.75);\n        let metric_c = HammingScore::<TestBackend>::new().with_threshold(0.5);\n\n        assert_ne!(metric_a.name(), metric_b.name());\n        assert_eq!(metric_a.name(), metric_c.name());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/iteration.rs",
    "content": "use std::sync::Arc;\n\nuse super::MetricMetadata;\nuse super::SerializedEntry;\nuse super::state::FormatOptions;\nuse super::state::NumericMetricState;\nuse crate::metric::MetricName;\nuse crate::metric::Numeric;\nuse crate::metric::{Metric, MetricAttributes, NumericAttributes, NumericEntry};\n\n/// The loss metric.\n#[derive(Clone)]\npub struct IterationSpeedMetric {\n    name: MetricName,\n    state: NumericMetricState,\n    instant: Option<std::time::Instant>,\n}\n\nimpl Default for IterationSpeedMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl IterationSpeedMetric {\n    /// Create the metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Iteration Speed\".to_string()),\n            state: Default::default(),\n            instant: Default::default(),\n        }\n    }\n}\n\nimpl Metric for IterationSpeedMetric {\n    type Input = ();\n\n    fn update(&mut self, _: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry {\n        let raw = match self.instant {\n            Some(val) => {\n                // If iteration is not logged, compute the speed over the number of items processed.\n                // 1 iteration should equal 1 item when iteration is not logged.\n                metadata\n                    .iteration\n                    .unwrap_or(metadata.progress.items_processed) as f64\n                    / val.elapsed().as_secs_f64()\n            }\n            None => {\n                self.instant = Some(std::time::Instant::now());\n                0.0\n            }\n        };\n\n        self.state.update(\n            raw,\n            1,\n            FormatOptions::new(self.name())\n                .unit(\"iter/sec\")\n                .precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.instant = None;\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"iter/sec\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for IterationSpeedMetric {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/learning_rate.rs",
    "content": "use std::sync::Arc;\n\nuse super::{\n    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse crate::metric::{Metric, MetricName, Numeric, SerializedEntry};\n\n/// Track the learning rate across iterations.\n#[derive(Clone)]\npub struct LearningRateMetric {\n    name: MetricName,\n    state: NumericMetricState,\n}\n\nimpl LearningRateMetric {\n    /// Creates a new learning rate metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Learning Rate\".to_string()),\n            state: NumericMetricState::new(),\n        }\n    }\n}\n\nimpl Default for LearningRateMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Metric for LearningRateMetric {\n    type Input = ();\n\n    fn update(&mut self, _item: &(), metadata: &MetricMetadata) -> SerializedEntry {\n        let lr = metadata.lr.unwrap_or(0.0);\n\n        self.state\n            .update(lr, 1, FormatOptions::new(self.name()).precision(2))\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for LearningRateMetric {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/loss.rs",
    "content": "use std::sync::Arc;\n\nuse super::MetricMetadata;\nuse super::SerializedEntry;\nuse super::state::FormatOptions;\nuse super::state::NumericMetricState;\nuse crate::metric::MetricName;\nuse crate::metric::{Metric, MetricAttributes, Numeric, NumericAttributes, NumericEntry};\nuse burn_core::tensor::Tensor;\nuse burn_core::tensor::backend::Backend;\n\n/// The loss metric.\n#[derive(Clone)]\npub struct LossMetric<B: Backend> {\n    name: Arc<String>,\n    state: NumericMetricState,\n    _b: B,\n}\n\n/// The [loss metric](LossMetric) input type.\n#[derive(new)]\npub struct LossInput<B: Backend> {\n    tensor: Tensor<B, 1>,\n}\n\nimpl<B: Backend> Default for LossMetric<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> LossMetric<B> {\n    /// Create the metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Loss\".to_string()),\n            state: NumericMetricState::default(),\n            _b: Default::default(),\n        }\n    }\n}\n\nimpl<B: Backend> Metric for LossMetric<B> {\n    type Input = LossInput<B>;\n\n    fn update(&mut self, loss: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [batch_size] = loss.tensor.dims();\n        let loss = loss\n            .tensor\n            .clone()\n            .mean()\n            .into_data()\n            .iter::<f64>()\n            .next()\n            .unwrap();\n\n        self.state.update(\n            loss,\n            batch_size,\n            FormatOptions::new(self.name()).precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for LossMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/memory_use.rs",
    "content": "/// RAM use metric\nuse super::{MetricAttributes, MetricMetadata, NumericAttributes};\nuse crate::metric::{Metric, Numeric, NumericEntry, SerializedEntry};\nuse std::{\n    sync::Arc,\n    time::{Duration, Instant},\n};\nuse sysinfo::System;\n\n/// Memory information\npub struct CpuMemory {\n    name: Arc<String>,\n    last_refresh: Instant,\n    refresh_frequency: Duration,\n    sys: System,\n    ram_bytes_total: u64,\n    ram_bytes_used: u64,\n}\n\nimpl Clone for CpuMemory {\n    fn clone(&self) -> Self {\n        Self {\n            name: self.name.clone(),\n            last_refresh: self.last_refresh,\n            refresh_frequency: self.refresh_frequency,\n            sys: System::new(),\n            ram_bytes_total: self.ram_bytes_total,\n            ram_bytes_used: self.ram_bytes_used,\n        }\n    }\n}\n\nimpl CpuMemory {\n    /// Creates a new memory metric\n    pub fn new() -> Self {\n        let mut metric = Self {\n            name: Arc::new(\"CPU Memory\".into()),\n            last_refresh: Instant::now(),\n            refresh_frequency: Duration::from_millis(200),\n            sys: System::new(),\n            ram_bytes_total: 0,\n            ram_bytes_used: 0,\n        };\n        metric.refresh();\n        metric\n    }\n\n    fn refresh(&mut self) {\n        self.sys.refresh_memory();\n        self.last_refresh = Instant::now();\n\n        // bytes of RAM available\n        self.ram_bytes_total = self.sys.total_memory();\n\n        // bytes of RAM in use\n        self.ram_bytes_used = self.sys.used_memory();\n    }\n}\n\nimpl Default for CpuMemory {\n    fn default() -> Self {\n        CpuMemory::new()\n    }\n}\n\nimpl Metric for CpuMemory {\n    type Input = ();\n\n    fn update(&mut self, _item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        if self.last_refresh.elapsed() >= self.refresh_frequency {\n            self.refresh();\n        }\n\n        let raw = bytes2gb(self.ram_bytes_used);\n        let formatted = format!(\n            \"RAM Used: {:.2} / {:.2} Gb\",\n            raw,\n            bytes2gb(self.ram_bytes_total),\n        );\n\n        SerializedEntry::new(formatted, raw.to_string())\n    }\n\n    fn clear(&mut self) {}\n\n    fn name(&self) -> Arc<String> {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"Gb\".to_string()),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for CpuMemory {\n    fn value(&self) -> NumericEntry {\n        NumericEntry::Value(bytes2gb(self.ram_bytes_used))\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        NumericEntry::Value(bytes2gb(self.ram_bytes_used))\n    }\n}\n\nfn bytes2gb(bytes: u64) -> f64 {\n    bytes as f64 / 1e9\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/mod.rs",
    "content": "/// State module.\npub mod state;\n/// Module responsible to save and exposes data collected during training.\npub mod store;\n/// Metrics module for vision tasks.\n#[cfg(feature = \"vision\")]\npub mod vision;\n\n//Metrics for reinforcement learning.\n#[cfg(feature = \"rl\")]\nmod rl;\n#[cfg(feature = \"rl\")]\npub use rl::*;\n\n// System metrics\n#[cfg(feature = \"sys-metrics\")]\nmod cpu_temp;\n#[cfg(feature = \"sys-metrics\")]\nmod cpu_use;\n#[cfg(feature = \"sys-metrics\")]\nmod cuda;\n#[cfg(feature = \"sys-metrics\")]\nmod memory_use;\n#[cfg(feature = \"sys-metrics\")]\npub use cpu_temp::*;\n#[cfg(feature = \"sys-metrics\")]\npub use cpu_use::*;\n#[cfg(feature = \"sys-metrics\")]\npub use cuda::*;\n#[cfg(feature = \"sys-metrics\")]\npub use memory_use::*;\n\n// Training metrics\nmod acc;\nmod auroc;\nmod base;\nmod cer;\nmod confusion_stats;\nmod fbetascore;\nmod hamming;\nmod iteration;\nmod learning_rate;\nmod loss;\nmod perplexity;\nmod precision;\nmod recall;\nmod top_k_acc;\nmod wer;\n\npub use acc::*;\npub use auroc::*;\npub use base::*;\npub use cer::*;\npub use confusion_stats::ConfusionStatsInput;\npub use fbetascore::*;\npub use hamming::*;\npub use iteration::*;\npub use learning_rate::*;\npub use loss::*;\npub use perplexity::*;\npub use precision::*;\npub use recall::*;\npub use top_k_acc::*;\npub use wer::*;\n\npub(crate) mod classification;\npub(crate) mod processor;\n\npub use crate::metric::classification::ClassReduction;\n// Expose `ItemLazy` so it can be implemented for custom types\npub use processor::ItemLazy;\n"
  },
  {
    "path": "crates/burn-train/src/metric/perplexity.rs",
    "content": "use core::marker::PhantomData;\n\nuse super::state::FormatOptions;\nuse super::{MetricMetadata, NumericEntry, SerializedEntry, format_float};\nuse crate::metric::{Metric, MetricAttributes, MetricName, Numeric, NumericAttributes};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{ElementConversion, Int, Tensor};\n\n/// Custom state for perplexity metric that correctly accumulates negative log-likelihood.\n///\n/// Unlike other metrics that can be averaged, perplexity requires special handling:\n/// - Accumulate total negative log-likelihood across all tokens\n/// - Accumulate total number of effective tokens\n/// - Compute perplexity as exp(total_nll / total_tokens) only at the end\n#[derive(Clone)]\nstruct PerplexityState {\n    /// Sum of negative log-likelihood across all tokens\n    sum_nll: f64,\n    /// Total number of effective tokens (excluding padding)\n    total_tokens: usize,\n    /// Current batch perplexity (for display purposes)\n    current: f64,\n}\n\nimpl PerplexityState {\n    fn new() -> Self {\n        Self {\n            sum_nll: 0.0,\n            total_tokens: 0,\n            current: f64::NAN,\n        }\n    }\n\n    fn reset(&mut self) {\n        self.sum_nll = 0.0;\n        self.total_tokens = 0;\n        self.current = f64::NAN;\n    }\n\n    /// Update state with negative log-likelihood and token count from current batch\n    fn update(\n        &mut self,\n        sum_log_prob: f64,\n        effective_tokens: usize,\n        format: FormatOptions,\n    ) -> SerializedEntry {\n        // sum_log_prob is already the sum of log probabilities (negative values)\n        // We need to negate it to get negative log-likelihood\n        let batch_nll = -sum_log_prob;\n\n        // Accumulate across batches\n        self.sum_nll += batch_nll;\n        self.total_tokens += effective_tokens;\n\n        // Compute current batch perplexity for display\n        let batch_perplexity = if effective_tokens > 0 {\n            (batch_nll / effective_tokens as f64).exp()\n        } else {\n            f64::INFINITY\n        };\n        self.current = batch_perplexity;\n\n        // Compute running epoch perplexity\n        let epoch_perplexity = if self.total_tokens > 0 {\n            (self.sum_nll / self.total_tokens as f64).exp()\n        } else {\n            f64::INFINITY\n        };\n\n        // Format for display\n        let (formatted_current, formatted_running) = match format.precision_value() {\n            Some(precision) => (\n                format_float(batch_perplexity, precision),\n                format_float(epoch_perplexity, precision),\n            ),\n            None => (format!(\"{batch_perplexity}\"), format!(\"{epoch_perplexity}\")),\n        };\n\n        let formatted = match format.unit_value() {\n            Some(unit) => {\n                format!(\"epoch {formatted_running} {unit} - batch {formatted_current} {unit}\")\n            }\n            None => format!(\"epoch {formatted_running} - batch {formatted_current}\"),\n        };\n\n        // Serialize the state for aggregation\n        let serialized = NumericEntry::Aggregated {\n            aggregated_value: epoch_perplexity,\n            count: self.total_tokens,\n        }\n        .serialize();\n\n        SerializedEntry::new(formatted, serialized)\n    }\n\n    fn value(&self) -> NumericEntry {\n        let perplexity = if self.total_tokens > 0 {\n            (self.sum_nll / self.total_tokens as f64).exp()\n        } else {\n            f64::INFINITY\n        };\n\n        NumericEntry::Aggregated {\n            aggregated_value: perplexity,\n            count: self.total_tokens,\n        }\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.value()\n    }\n}\n\n/// The perplexity metric.\n///\n/// Perplexity is a measure of how well a probability distribution or probability model\n/// predicts a sample. It's commonly used to evaluate language models. A lower perplexity\n/// indicates that the model is more confident in its predictions.\n///\n/// Mathematically, perplexity is defined as the exponentiation of the cross-entropy loss:\n/// PPL = exp(H(p, q)) = exp(-1/N * Σ log(p(x_i)))\n///\n/// where:\n/// - H(p, q) is the cross-entropy between the true distribution p and predicted distribution q\n/// - N is the number of tokens\n/// - p(x_i) is the predicted probability of the i-th token\n///\n/// # Aggregation\n/// Unlike other metrics, perplexity cannot be simply averaged across batches.\n/// This implementation correctly accumulates the total negative log-likelihood and\n/// total token count across batches, then computes perplexity as exp(total_nll / total_tokens).\n#[derive(Clone)]\npub struct PerplexityMetric<B: Backend> {\n    name: MetricName,\n    state: PerplexityState,\n    pad_token: Option<usize>,\n    _b: PhantomData<B>,\n}\n\n/// The [perplexity metric](PerplexityMetric) input type.\n#[derive(new)]\npub struct PerplexityInput<B: Backend> {\n    /// Logits tensor of shape [batch_size * sequence_length, vocab_size]\n    outputs: Tensor<B, 2>,\n    /// Target tokens tensor of shape [batch_size * sequence_length]\n    targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Default for PerplexityMetric<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> PerplexityMetric<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self {\n            name: MetricName::new(\"Perplexity\".to_string()),\n            state: PerplexityState::new(),\n            pad_token: Default::default(),\n            _b: PhantomData,\n        }\n    }\n\n    /// Sets the pad token to exclude from perplexity calculation.\n    ///\n    /// When a pad token is set, predictions for padding tokens are masked out\n    /// and do not contribute to the perplexity calculation. This is important\n    /// for variable-length sequences where padding is used.\n    pub fn with_pad_token(mut self, index: usize) -> Self {\n        self.pad_token = Some(index);\n        self\n    }\n}\n\nimpl<B: Backend> Metric for PerplexityMetric<B> {\n    type Input = PerplexityInput<B>;\n\n    fn update(\n        &mut self,\n        input: &PerplexityInput<B>,\n        _metadata: &MetricMetadata,\n    ) -> SerializedEntry {\n        let targets = input.targets.clone();\n        let outputs = input.outputs.clone();\n\n        let [total_tokens, _vocab_size] = outputs.dims();\n\n        // Convert logits to log probabilities using log_softmax for numerical stability\n        let log_probs = burn_core::tensor::activation::log_softmax(outputs, 1);\n\n        // Gather the log probabilities for the target tokens\n        let target_log_probs = log_probs\n            .gather(1, targets.clone().unsqueeze_dim(1))\n            .squeeze_dim(1);\n\n        let (sum_log_prob, effective_tokens) = match self.pad_token {\n            Some(pad_token) => {\n                // Create a mask for non-padding tokens\n                let mask = targets.clone().not_equal_elem(pad_token as i64);\n\n                // Apply mask to log probabilities (set padding log probs to 0)\n                let masked_log_probs = target_log_probs.mask_fill(mask.clone().bool_not(), 0.0);\n\n                // Sum the log probabilities and count effective tokens\n                let sum_log_prob = masked_log_probs.sum().into_scalar().elem::<f64>();\n                let effective_tokens = mask.int().sum().into_scalar().elem::<i64>() as usize;\n\n                (sum_log_prob, effective_tokens)\n            }\n            None => {\n                // No padding, use all tokens\n                let sum_log_prob = target_log_probs.sum().into_scalar().elem::<f64>();\n                (sum_log_prob, total_tokens)\n            }\n        };\n\n        // Pass the sum_log_prob and effective_tokens to the state\n        // The state will handle the correct accumulation and perplexity calculation\n        self.state.update(\n            sum_log_prob,\n            effective_tokens,\n            FormatOptions::new(self.name()).precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for PerplexityMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn test_perplexity_perfect_prediction() {\n        let device = Default::default();\n        let mut metric = PerplexityMetric::<TestBackend>::new();\n\n        // Perfect prediction: target is always the highest probability class\n        let input = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [10.0, 0.0, 0.0], // Very confident prediction for class 0\n                    [0.0, 10.0, 0.0], // Very confident prediction for class 1\n                    [0.0, 0.0, 10.0], // Very confident prediction for class 2\n                ],\n                &device,\n            ),\n            Tensor::from_data([0, 1, 2], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        let perplexity = metric.value().current();\n\n        // Perfect predictions should result in very low perplexity (close to 1.0)\n        assert!(\n            perplexity < 1.1,\n            \"Perfect predictions should have low perplexity, got {}\",\n            perplexity\n        );\n    }\n\n    #[test]\n    fn test_perplexity_uniform_prediction() {\n        let device = Default::default();\n        let mut metric = PerplexityMetric::<TestBackend>::new();\n\n        // Uniform prediction: all classes have equal probability\n        let input = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)\n                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)\n                    [0.0, 0.0, 0.0], // Uniform distribution (after softmax)\n                ],\n                &device,\n            ),\n            Tensor::from_data([0, 1, 2], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        let perplexity = metric.value().current();\n\n        // Uniform distribution over 3 classes should have perplexity ≈ 3.0\n        assert!(\n            (perplexity - 3.0).abs() < 0.1,\n            \"Uniform distribution perplexity should be ~3.0, got {}\",\n            perplexity\n        );\n    }\n\n    #[test]\n    fn test_perplexity_with_padding() {\n        let device = Default::default();\n        let mut metric = PerplexityMetric::<TestBackend>::new().with_pad_token(3);\n\n        let input = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [10.0, 0.0, 0.0, 0.0], // Good prediction for class 0\n                    [0.0, 10.0, 0.0, 0.0], // Good prediction for class 1\n                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored\n                    [0.0, 0.0, 0.0, 1.0],  // This is padding - should be ignored\n                ],\n                &device,\n            ),\n            Tensor::from_data([0, 1, 3, 3], &device), // 3 is pad token\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        let perplexity = metric.value().current();\n\n        // Should only consider the first two predictions, both of which are confident\n        assert!(\n            perplexity < 1.1,\n            \"Good predictions with padding should have low perplexity, got {}\",\n            perplexity\n        );\n    }\n\n    #[test]\n    fn test_perplexity_wrong_prediction() {\n        let device = Default::default();\n        let mut metric = PerplexityMetric::<TestBackend>::new();\n\n        // Wrong predictions: target class has very low probability\n        let input = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 10.0, 0.0], // Predicts class 1, but target is 0\n                    [10.0, 0.0, 0.0], // Predicts class 0, but target is 1\n                    [0.0, 0.0, 10.0], // Predicts class 2, but target is 0\n                ],\n                &device,\n            ),\n            Tensor::from_data([0, 1, 0], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        let perplexity = metric.value().current();\n\n        // Wrong predictions should result in high perplexity\n        assert!(\n            perplexity > 10.0,\n            \"Wrong predictions should have high perplexity, got {}\",\n            perplexity\n        );\n    }\n\n    #[test]\n    fn test_perplexity_multi_batch_aggregation() {\n        let device = Default::default();\n        let mut metric = PerplexityMetric::<TestBackend>::new();\n\n        // First batch: 2 tokens with uniform distribution (log_prob ≈ -1.0986 each)\n        let input1 = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)\n                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)\n                ],\n                &device,\n            ),\n            Tensor::from_data([0, 1], &device),\n        );\n\n        // Second batch: 1 token with uniform distribution\n        let input2 = PerplexityInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.0, 0.0], // Uniform distribution (log_prob ≈ -1.0986)\n                ],\n                &device,\n            ),\n            Tensor::from_data([2], &device),\n        );\n\n        // Update with both batches\n        let _entry1 = metric.update(&input1, &MetricMetadata::fake());\n        let _entry2 = metric.update(&input2, &MetricMetadata::fake());\n\n        let aggregated_perplexity = metric.value().current();\n\n        // For uniform distribution over 3 classes: log_prob ≈ -log(3) ≈ -1.0986\n        // Total negative log-likelihood: 3 * 1.0986 ≈ 3.2958\n        // Total tokens: 3\n        // Expected perplexity: exp(3.2958 / 3) = exp(1.0986) ≈ 3.0\n        assert!(\n            (aggregated_perplexity - 3.0).abs() < 0.1,\n            \"Multi-batch aggregated perplexity should be ~3.0, got {}\",\n            aggregated_perplexity\n        );\n\n        // Compare with single batch containing all data\n        let mut single_batch_metric = PerplexityMetric::<TestBackend>::new();\n        let single_input = PerplexityInput::new(\n            Tensor::from_data([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], &device),\n            Tensor::from_data([0, 1, 2], &device),\n        );\n\n        let _single_entry = single_batch_metric.update(&single_input, &MetricMetadata::fake());\n        let single_batch_perplexity = single_batch_metric.value().current();\n\n        // Multi-batch and single-batch should give the same result\n        assert!(\n            (aggregated_perplexity - single_batch_perplexity).abs() < 0.01,\n            \"Multi-batch ({}) and single-batch ({}) perplexity should match\",\n            aggregated_perplexity,\n            single_batch_perplexity\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/precision.rs",
    "content": "use crate::metric::{MetricName, Numeric};\n\nuse super::{\n    Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,\n    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},\n    confusion_stats::{ConfusionStats, ConfusionStatsInput},\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::cast::ToElement,\n};\nuse core::marker::PhantomData;\nuse std::{num::NonZeroUsize, sync::Arc};\n\n/// The Precision Metric\n#[derive(Clone)]\npub struct PrecisionMetric<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    _b: PhantomData<B>,\n    config: ClassificationMetricConfig,\n}\n\nimpl<B: Backend> Default for PrecisionMetric<B> {\n    fn default() -> Self {\n        Self::new(Default::default())\n    }\n}\n\nimpl<B: Backend> PrecisionMetric<B> {\n    fn new(config: ClassificationMetricConfig) -> Self {\n        let state = Default::default();\n        let name = Arc::new(format!(\n            \"Precision @ {:?} [{:?}]\",\n            config.decision_rule, config.class_reduction\n        ));\n\n        Self {\n            state,\n            config,\n            name,\n            _b: Default::default(),\n        }\n    }\n    /// Precision metric for binary classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `threshold` - The threshold to transform a probability into a binary prediction.\n    #[allow(dead_code)]\n    pub fn binary(threshold: f64) -> Self {\n        Self::new(ClassificationMetricConfig {\n            decision_rule: DecisionRule::Threshold(threshold),\n            // binary classification results are the same independently of class_reduction\n            ..Default::default()\n        })\n    }\n\n    /// Precision metric for multiclass classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {\n        Self::new(ClassificationMetricConfig {\n            decision_rule: DecisionRule::TopK(\n                NonZeroUsize::new(top_k).expect(\"top_k must be non-zero\"),\n            ),\n            class_reduction,\n        })\n    }\n\n    /// Precision metric for multi-label classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `threshold` - The threshold to transform a probability into a binary value.\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {\n        Self {\n            config: ClassificationMetricConfig {\n                decision_rule: DecisionRule::Threshold(threshold),\n                class_reduction,\n            },\n            ..Default::default()\n        }\n    }\n\n    fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {\n        use ClassReduction::{Macro, Micro};\n        let avg_tensor = match self.config.class_reduction {\n            Micro => aggregated_metric,\n            Macro => {\n                if aggregated_metric\n                    .clone()\n                    .contains_nan()\n                    .any()\n                    .into_scalar()\n                    .to_bool()\n                {\n                    let nan_mask = aggregated_metric.clone().is_nan();\n                    aggregated_metric = aggregated_metric\n                        .clone()\n                        .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))\n                }\n                aggregated_metric.mean()\n            }\n        };\n        avg_tensor.into_scalar().to_f64()\n    }\n}\n\nimpl<B: Backend> Metric for PrecisionMetric<B> {\n    type Input = ConfusionStatsInput<B>;\n\n    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [sample_size, _] = input.predictions.dims();\n\n        let cf_stats = ConfusionStats::new(input, &self.config);\n        let metric =\n            self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive());\n\n        self.state.update(\n            100.0 * metric,\n            sample_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for PrecisionMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::{\n        ClassReduction::{self, *},\n        Metric, MetricMetadata, PrecisionMetric,\n    };\n    use crate::metric::Numeric;\n    use crate::{\n        TestBackend,\n        tests::{ClassificationType, THRESHOLD, dummy_classification_input},\n    };\n    use burn_core::tensor::TensorData;\n    use burn_core::tensor::Tolerance;\n    use rstest::rstest;\n\n    #[rstest]\n    #[case::binary(THRESHOLD, 0.5)]\n    fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {\n        let input = dummy_classification_input(&ClassificationType::Binary).into();\n        let mut metric = PrecisionMetric::binary(threshold);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]\n    #[case::multiclass_micro_k2(Micro, 2, 4.0/10.0)]\n    #[case::multiclass_macro_k1(Macro, 1, (0.5 + 0.5 + 1.0)/3.0)]\n    #[case::multiclass_macro_k2(Macro, 2, (0.5 + 1.0/4.0 + 0.5)/3.0)]\n    fn test_multiclass_precision(\n        #[case] class_reduction: ClassReduction,\n        #[case] top_k: usize,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multiclass).into();\n        let mut metric = PrecisionMetric::multiclass(top_k, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)]\n    #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)]\n    fn test_multilabel_precision(\n        #[case] class_reduction: ClassReduction,\n        #[case] threshold: f64,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multilabel).into();\n        let mut metric = PrecisionMetric::multilabel(threshold, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[test]\n    fn test_parameterized_unique_name() {\n        let metric_a = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);\n        let metric_b = PrecisionMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);\n        let metric_c = PrecisionMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);\n\n        assert_ne!(metric_a.name(), metric_b.name());\n        assert_eq!(metric_a.name(), metric_c.name());\n\n        let metric_a = PrecisionMetric::<TestBackend>::binary(0.5);\n        let metric_b = PrecisionMetric::<TestBackend>::binary(0.75);\n        assert_ne!(metric_a.name(), metric_b.name());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/async_wrapper.rs",
    "content": "use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};\n\nuse super::EventProcessorTraining;\nuse async_channel::{Receiver, Sender};\n\n/// Event processor for the training process.\npub struct AsyncProcessorTraining<ET, EV> {\n    sender: Sender<Message<ET, EV>>,\n}\n\n/// Event processor for the model evaluation.\npub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {\n    sender: Sender<EvalMessage<P>>,\n}\n\nstruct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {\n    processor: P,\n    rec: Receiver<Message<ET, EV>>,\n}\n\nstruct WorkerEvaluation<P: EventProcessorEvaluation> {\n    processor: P,\n    rec: Receiver<EvalMessage<P>>,\n}\n\nimpl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>\n    WorkerTraining<ET, EV, P>\n{\n    pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {\n        let mut worker = Self { processor, rec };\n        std::thread::Builder::new()\n            .name(\"train-worker\".into())\n            .spawn(move || {\n                while let Ok(msg) = worker.rec.recv_blocking() {\n                    match msg {\n                        Message::Train(event) => worker.processor.process_train(event),\n                        Message::Valid(event) => worker.processor.process_valid(event),\n                        Message::Renderer(callback) => {\n                            callback.send_blocking(worker.processor.renderer()).unwrap();\n                            return;\n                        }\n                    }\n                }\n            })\n            .unwrap();\n    }\n}\nimpl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {\n    pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {\n        let mut worker = Self { processor, rec };\n\n        std::thread::Builder::new()\n            .name(\"evel-worker\".into())\n            .spawn(move || {\n                while let Ok(event) = worker.rec.recv_blocking() {\n                    match event {\n                        EvalMessage::Test(event) => worker.processor.process_test(event),\n                        EvalMessage::Renderer(sender) => {\n                            sender.send_blocking(worker.processor.renderer()).unwrap();\n                            return;\n                        }\n                    }\n                }\n            })\n            .unwrap();\n    }\n}\n\nimpl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {\n    /// Create an event processor for training.\n    pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {\n        let (sender, rec) = async_channel::bounded(1);\n\n        WorkerTraining::start(processor, rec);\n\n        Self { sender }\n    }\n}\n\nimpl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {\n    /// Create an event processor for model evaluation.\n    pub fn new(processor: P) -> Self {\n        let (sender, rec) = async_channel::bounded(1);\n\n        WorkerEvaluation::start(processor, rec);\n\n        Self { sender }\n    }\n}\n\nenum Message<EventTrain, EventValid> {\n    Train(EventTrain),\n    Valid(EventValid),\n    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),\n}\n\nenum EvalMessage<P: EventProcessorEvaluation> {\n    Test(EvaluatorEvent<P::ItemTest>),\n    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),\n}\n\nimpl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {\n    fn process_train(&mut self, event: ET) {\n        self.sender.send_blocking(Message::Train(event)).unwrap();\n    }\n\n    fn process_valid(&mut self, event: EV) {\n        self.sender.send_blocking(Message::Valid(event)).unwrap();\n    }\n\n    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {\n        let (sender, rec) = async_channel::bounded(1);\n        self.sender\n            .send_blocking(Message::Renderer(sender))\n            .unwrap();\n\n        match rec.recv_blocking() {\n            Ok(value) => value,\n            Err(err) => panic!(\"{err:?}\"),\n        }\n    }\n}\n\nimpl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {\n    type ItemTest = P::ItemTest;\n\n    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {\n        self.sender.send_blocking(EvalMessage::Test(event)).unwrap();\n    }\n\n    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {\n        let (sender, rec) = async_channel::bounded(1);\n        self.sender\n            .send_blocking(EvalMessage::Renderer(sender))\n            .unwrap();\n\n        match rec.recv_blocking() {\n            Ok(value) => value,\n            Err(err) => panic!(\"{err:?}\"),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/base.rs",
    "content": "use burn_core::data::dataloader::Progress;\nuse burn_optim::LearningRate;\n\nuse crate::{\n    LearnerSummary,\n    renderer::{EvaluationName, MetricsRenderer},\n};\n\n/// Event happening during the training/validation process.\npub enum LearnerEvent<T> {\n    /// Signal the start of the process (e.g., training start)\n    Start,\n    /// Signal that an item have been processed.\n    ProcessedItem(TrainingItem<T>),\n    /// Signal the end of an epoch.\n    EndEpoch(usize),\n    /// Signal the end of the process (e.g., training end).\n    End(Option<LearnerSummary>),\n}\n\n/// Event happening during the evaluation process.\npub enum EvaluatorEvent<T> {\n    /// Signal the start of the process (e.g., evaluation start)\n    Start,\n    /// Signal that an item have been processed.\n    ProcessedItem(EvaluationName, EvaluationItem<T>),\n    /// Signal the end of the process (e.g., evaluation end).\n    End(Option<LearnerSummary>),\n}\n\n/// Items that are lazy are not ready to be processed by metrics.\n///\n/// We want to sync them on a different thread to avoid blocking training.\npub trait ItemLazy: Send {\n    /// Item that is properly synced and ready to be processed by metrics.\n    type ItemSync: Send;\n\n    /// Sync the item.\n    fn sync(self) -> Self::ItemSync;\n}\n\n/// Process events happening during training and validation.\npub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {\n    /// Collect a training event.\n    fn process_train(&mut self, event: TrainEvent);\n    /// Collect a validation event.\n    fn process_valid(&mut self, event: ValidEvent);\n    /// Returns the renderer used for training.\n    fn renderer(self) -> Box<dyn MetricsRenderer>;\n}\n\n/// Process events happening during evaluation.\npub trait EventProcessorEvaluation: Send {\n    /// The test item.\n    type ItemTest: ItemLazy;\n\n    /// Collect a test event.\n    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);\n\n    /// Returns the renderer used for evaluation.\n    fn renderer(self) -> Box<dyn MetricsRenderer>;\n}\n\n/// A learner item.\n#[derive(new)]\npub struct TrainingItem<T> {\n    /// The item.\n    pub item: T,\n\n    /// The progress.\n    pub progress: Progress,\n\n    /// The global progress of the training (e.g. epochs).\n    pub global_progress: Progress,\n\n    /// The iteration, if it it different from the items processed.\n    pub iteration: Option<usize>,\n\n    /// The learning rate.\n    pub lr: Option<LearningRate>,\n}\n\nimpl<T: ItemLazy> ItemLazy for TrainingItem<T> {\n    type ItemSync = TrainingItem<T::ItemSync>;\n\n    fn sync(self) -> Self::ItemSync {\n        TrainingItem {\n            item: self.item.sync(),\n            progress: self.progress,\n            global_progress: self.global_progress,\n            iteration: self.iteration,\n            lr: self.lr,\n        }\n    }\n}\n\n/// An evaluation item.\n#[derive(new)]\npub struct EvaluationItem<T> {\n    /// The item.\n    pub item: T,\n\n    /// The progress.\n    pub progress: Progress,\n\n    /// The iteration, if it it different from the items processed.\n    pub iteration: Option<usize>,\n}\n\nimpl<T: ItemLazy> ItemLazy for EvaluationItem<T> {\n    type ItemSync = EvaluationItem<T::ItemSync>;\n\n    fn sync(self) -> Self::ItemSync {\n        EvaluationItem {\n            item: self.item.sync(),\n            progress: self.progress,\n            iteration: self.iteration,\n        }\n    }\n}\n\nimpl ItemLazy for () {\n    type ItemSync = ();\n\n    fn sync(self) -> Self::ItemSync {}\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/full.rs",
    "content": "use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};\nuse crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};\nuse crate::metric::store::{EpochSummary, EventStoreClient, Split};\nuse crate::renderer::{\n    EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,\n};\nuse std::sync::Arc;\n\n/// An [event processor](EventProcessorTraining) that handles:\n///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).\n///   - Render metrics using a [metrics renderer](MetricsRenderer).\npub struct FullEventProcessorTraining<T: ItemLazy, V: ItemLazy> {\n    metrics: MetricsTraining<T, V>,\n    renderer: Box<dyn MetricsRenderer>,\n    store: Arc<EventStoreClient>,\n}\n\n/// An [event processor](EventProcessorEvaluation) that handles:\n///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).\n///   - Render metrics using a [metrics renderer](MetricsRenderer).\npub struct FullEventProcessorEvaluation<T: ItemLazy> {\n    metrics: MetricsEvaluation<T>,\n    renderer: Box<dyn MetricsRenderer>,\n    store: Arc<EventStoreClient>,\n}\n\nimpl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {\n    pub(crate) fn new(\n        metrics: MetricsTraining<T, V>,\n        renderer: Box<dyn MetricsRenderer>,\n        store: Arc<EventStoreClient>,\n    ) -> Self {\n        Self {\n            metrics,\n            renderer,\n            store,\n        }\n    }\n\n    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {\n        let mut indicators = vec![];\n        indicators.push(ProgressType::Detailed {\n            tag: String::from(\"Epoch\"),\n            progress: progress.global_progress.clone(),\n        });\n\n        if let Some(iteration) = progress.iteration {\n            indicators.push(ProgressType::Value {\n                tag: String::from(\"Iteration\"),\n                value: iteration,\n            });\n        };\n\n        if let Some(p) = &progress.progress {\n            indicators.push(ProgressType::Detailed {\n                tag: String::from(\"Items\"),\n                progress: p.clone(),\n            });\n        };\n\n        indicators\n    }\n}\n\nimpl<T: ItemLazy> FullEventProcessorEvaluation<T> {\n    pub(crate) fn new(\n        metrics: MetricsEvaluation<T>,\n        renderer: Box<dyn MetricsRenderer>,\n        store: Arc<EventStoreClient>,\n    ) -> Self {\n        Self {\n            metrics,\n            renderer,\n            store,\n        }\n    }\n\n    fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {\n        let mut indicators = vec![];\n        if let Some(iteration) = progress.iteration {\n            indicators.push(ProgressType::Value {\n                tag: String::from(\"Iteration\"),\n                value: iteration,\n            });\n        };\n\n        indicators.push(ProgressType::Detailed {\n            tag: String::from(\"Items\"),\n            progress: progress.progress.clone(),\n        });\n\n        indicators\n    }\n}\n\nimpl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {\n    type ItemTest = T;\n\n    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {\n        match event {\n            EvaluatorEvent::Start => {\n                let definitions = self.metrics.metric_definitions();\n                self.store\n                    .add_event_train(crate::metric::store::Event::MetricsInit(\n                        definitions.clone(),\n                    ));\n                definitions\n                    .iter()\n                    .for_each(|definition| self.renderer.register_metric(definition.clone()));\n            }\n            EvaluatorEvent::ProcessedItem(name, item) => {\n                let item = item.sync();\n                let progress = (&item).into();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_test(&item, &metadata);\n\n                self.store.add_event_test(\n                    crate::metric::store::Event::MetricsUpdate(update.clone()),\n                    name.name.clone(),\n                );\n\n                update.entries.into_iter().for_each(|entry| {\n                    self.renderer\n                        .update_test(name.clone(), MetricState::Generic(entry))\n                });\n\n                update\n                    .entries_numeric\n                    .into_iter()\n                    .for_each(|numeric_update| {\n                        self.renderer.update_test(\n                            name.clone(),\n                            MetricState::Numeric(\n                                numeric_update.entry,\n                                numeric_update.numeric_entry,\n                            ),\n                        )\n                    });\n\n                let indicators = self.progress_indicators(&progress);\n                self.renderer.render_test(progress, indicators);\n            }\n            EvaluatorEvent::End(summary) => {\n                self.renderer.on_test_end(summary).ok();\n            }\n        }\n    }\n\n    fn renderer(self) -> Box<dyn MetricsRenderer> {\n        self.renderer\n    }\n}\n\nimpl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>\n    for FullEventProcessorTraining<T, V>\n{\n    fn process_train(&mut self, event: LearnerEvent<T>) {\n        match event {\n            LearnerEvent::Start => {\n                let definitions = self.metrics.metric_definitions();\n                self.store\n                    .add_event_train(crate::metric::store::Event::MetricsInit(\n                        definitions.clone(),\n                    ));\n                definitions\n                    .iter()\n                    .for_each(|definition| self.renderer.register_metric(definition.clone()));\n            }\n            LearnerEvent::ProcessedItem(item) => {\n                let item = item.sync();\n                let progress = (&item).into();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_train(&item, &metadata);\n\n                self.store\n                    .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));\n\n                update\n                    .entries\n                    .into_iter()\n                    .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));\n\n                update\n                    .entries_numeric\n                    .into_iter()\n                    .for_each(|numeric_update| {\n                        self.renderer.update_train(MetricState::Numeric(\n                            numeric_update.entry,\n                            numeric_update.numeric_entry,\n                        ))\n                    });\n\n                let indicators = self.progress_indicators(&progress);\n                self.renderer.render_train(progress, indicators);\n            }\n            LearnerEvent::EndEpoch(epoch) => {\n                self.store\n                    .add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new(\n                        epoch,\n                        Split::Train,\n                    )));\n                self.metrics.end_epoch_train();\n            }\n            LearnerEvent::End(summary) => {\n                self.renderer.on_train_end(summary).ok();\n            }\n        }\n    }\n\n    fn process_valid(&mut self, event: LearnerEvent<V>) {\n        match event {\n            LearnerEvent::Start => {} // no-op for now\n            LearnerEvent::ProcessedItem(item) => {\n                let item = item.sync();\n                let progress = (&item).into();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_valid(&item, &metadata);\n\n                self.store\n                    .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));\n\n                update\n                    .entries\n                    .into_iter()\n                    .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));\n\n                update\n                    .entries_numeric\n                    .into_iter()\n                    .for_each(|numeric_update| {\n                        self.renderer.update_valid(MetricState::Numeric(\n                            numeric_update.entry,\n                            numeric_update.numeric_entry,\n                        ))\n                    });\n\n                let indicators = self.progress_indicators(&progress);\n                self.renderer.render_valid(progress, indicators);\n            }\n            LearnerEvent::EndEpoch(epoch) => {\n                self.store\n                    .add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new(\n                        epoch,\n                        Split::Valid,\n                    )));\n                self.metrics.end_epoch_valid();\n            }\n            LearnerEvent::End(_) => {} // no-op for now\n        }\n    }\n    fn renderer(self) -> Box<dyn MetricsRenderer> {\n        self.renderer\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/metrics.rs",
    "content": "use std::collections::HashMap;\n\nuse super::{ItemLazy, TrainingItem};\nuse crate::{\n    EvaluationItem,\n    metric::{\n        Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,\n        store::{MetricsUpdate, NumericMetricUpdate},\n    },\n    renderer::{EvaluationProgress, TrainingProgress},\n};\n\npub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {\n    train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,\n    valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,\n    train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,\n    valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n}\n\npub(crate) struct MetricsEvaluation<T: ItemLazy> {\n    test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,\n    test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n}\n\nimpl<T: ItemLazy> Default for MetricsEvaluation<T> {\n    fn default() -> Self {\n        Self {\n            test: Default::default(),\n            test_numeric: Default::default(),\n            metric_definitions: HashMap::default(),\n        }\n    }\n}\n\nimpl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {\n    fn default() -> Self {\n        Self {\n            train: Vec::default(),\n            valid: Vec::default(),\n            train_numeric: Vec::default(),\n            valid_numeric: Vec::default(),\n            metric_definitions: HashMap::default(),\n        }\n    }\n}\n\nimpl<T: ItemLazy> MetricsEvaluation<T> {\n    /// Register a testing metric.\n    pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        T::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.test.push(Box::new(metric))\n    }\n\n    /// Register a numeric testing metric.\n    pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(\n        &mut self,\n        metric: Me,\n    ) where\n        T::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.test_numeric.push(Box::new(metric))\n    }\n\n    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {\n        self.metric_definitions.insert(\n            metric.id.clone(),\n            MetricDefinition::new(metric.id.clone(), &metric.metric),\n        );\n    }\n\n    /// Get metric definitions.\n    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {\n        self.metric_definitions.values().cloned().collect()\n    }\n\n    /// Update the testing information from the testing item.\n    pub(crate) fn update_test(\n        &mut self,\n        item: &EvaluationItem<T::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.test.len());\n        let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());\n\n        for metric in self.test.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.test_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n}\n\nimpl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {\n    /// Register a training metric.\n    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        T::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.train.push(Box::new(metric))\n    }\n\n    /// Register a validation metric.\n    pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        V::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.valid.push(Box::new(metric))\n    }\n\n    /// Register a numeric training metric.\n    pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(\n        &mut self,\n        metric: Me,\n    ) where\n        T::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.train_numeric.push(Box::new(metric))\n    }\n\n    /// Register a numeric validation metric.\n    pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)\n    where\n        V::ItemSync: Adaptor<Me::Input> + 'static,\n        Me: Metric + Numeric + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.valid_numeric.push(Box::new(metric))\n    }\n\n    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {\n        self.metric_definitions.insert(\n            metric.id.clone(),\n            MetricDefinition::new(metric.id.clone(), &metric.metric),\n        );\n    }\n\n    /// Get metric definitions for all splits\n    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {\n        self.metric_definitions.values().cloned().collect()\n    }\n\n    /// Update the training information from the training item.\n    pub(crate) fn update_train(\n        &mut self,\n        item: &TrainingItem<T::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.train.len());\n        let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());\n\n        for metric in self.train.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.train_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Update the training information from the validation item.\n    pub(crate) fn update_valid(\n        &mut self,\n        item: &TrainingItem<V::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.valid.len());\n        let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());\n\n        for metric in self.valid.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.valid_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Signal the end of a training epoch.\n    pub(crate) fn end_epoch_train(&mut self) {\n        for metric in self.train.iter_mut() {\n            metric.clear();\n        }\n        for metric in self.train_numeric.iter_mut() {\n            metric.clear();\n        }\n    }\n\n    /// Signal the end of a validation epoch.\n    pub(crate) fn end_epoch_valid(&mut self) {\n        for metric in self.valid.iter_mut() {\n            metric.clear();\n        }\n        for metric in self.valid_numeric.iter_mut() {\n            metric.clear();\n        }\n    }\n}\n\nimpl<T> From<&TrainingItem<T>> for TrainingProgress {\n    fn from(item: &TrainingItem<T>) -> Self {\n        Self {\n            progress: Some(item.progress.clone()),\n            global_progress: item.global_progress.clone(),\n            iteration: item.iteration,\n        }\n    }\n}\n\nimpl<T> From<&EvaluationItem<T>> for TrainingProgress {\n    fn from(item: &EvaluationItem<T>) -> Self {\n        Self {\n            progress: None,\n            global_progress: item.progress.clone(),\n            iteration: item.iteration,\n        }\n    }\n}\n\nimpl<T> From<&EvaluationItem<T>> for EvaluationProgress {\n    fn from(item: &EvaluationItem<T>) -> Self {\n        Self {\n            progress: item.progress.clone(),\n            iteration: item.iteration,\n        }\n    }\n}\n\nimpl<T> From<&TrainingItem<T>> for MetricMetadata {\n    fn from(item: &TrainingItem<T>) -> Self {\n        Self {\n            progress: item.progress.clone(),\n            global_progress: item.global_progress.clone(),\n            iteration: item.iteration,\n            lr: item.lr,\n        }\n    }\n}\n\nimpl<T> From<&EvaluationItem<T>> for MetricMetadata {\n    fn from(item: &EvaluationItem<T>) -> Self {\n        Self {\n            progress: item.progress.clone(),\n            global_progress: item.progress.clone(),\n            iteration: item.iteration,\n            lr: None,\n        }\n    }\n}\n\npub(crate) trait NumericMetricUpdater<T>: Send + Sync {\n    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;\n    fn clear(&mut self);\n}\n\npub(crate) trait MetricUpdater<T>: Send + Sync {\n    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;\n    fn clear(&mut self);\n}\n\npub(crate) struct MetricWrapper<M> {\n    pub id: MetricId,\n    pub metric: M,\n}\n\nimpl<M: Metric> MetricWrapper<M> {\n    pub fn new(metric: M) -> Self {\n        Self {\n            id: MetricId::new(metric.name()),\n            metric,\n        }\n    }\n}\n\nimpl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>\nwhere\n    T: 'static,\n    M: Metric + Numeric + 'static,\n    T: Adaptor<M::Input>,\n{\n    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {\n        let serialized_entry = self.metric.update(&item.adapt(), metadata);\n        let update = MetricEntry::new(self.id.clone(), serialized_entry);\n        let numeric = self.metric.value();\n        let running = self.metric.running_value();\n\n        NumericMetricUpdate {\n            entry: update,\n            numeric_entry: numeric,\n            running_entry: running,\n        }\n    }\n\n    fn clear(&mut self) {\n        self.metric.clear()\n    }\n}\n\nimpl<T, M> MetricUpdater<T> for MetricWrapper<M>\nwhere\n    T: 'static,\n    M: Metric + 'static,\n    T: Adaptor<M::Input>,\n{\n    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {\n        let serialized_entry = self.metric.update(&item.adapt(), metadata);\n        MetricEntry::new(self.id.clone(), serialized_entry)\n    }\n\n    fn clear(&mut self) {\n        self.metric.clear()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/minimal.rs",
    "content": "use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};\nuse crate::{\n    metric::store::{EpochSummary, EventStoreClient, Split},\n    renderer::cli::CliMetricsRenderer,\n};\nuse std::sync::Arc;\n\n/// An [event processor](EventProcessor) that handles:\n///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).\n#[allow(dead_code)]\n#[derive(new)]\npub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {\n    metrics: MetricsTraining<T, V>,\n    store: Arc<EventStoreClient>,\n}\n\nimpl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>\n    for MinimalEventProcessor<T, V>\n{\n    fn process_train(&mut self, event: LearnerEvent<T>) {\n        match event {\n            LearnerEvent::Start => {\n                let definitions = self.metrics.metric_definitions();\n                self.store\n                    .add_event_train(crate::metric::store::Event::MetricsInit(definitions));\n            }\n\n            LearnerEvent::ProcessedItem(item) => {\n                let item = item.sync();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_train(&item, &metadata);\n\n                self.store\n                    .add_event_train(crate::metric::store::Event::MetricsUpdate(update));\n            }\n            LearnerEvent::EndEpoch(epoch) => {\n                self.metrics.end_epoch_train();\n                self.store\n                    .add_event_train(crate::metric::store::Event::EndEpoch(EpochSummary::new(\n                        epoch,\n                        Split::Train,\n                    )));\n            }\n            LearnerEvent::End(_summary) => {} // no-op for now\n        }\n    }\n\n    fn process_valid(&mut self, event: LearnerEvent<V>) {\n        match event {\n            LearnerEvent::Start => {} // no-op for now\n            LearnerEvent::ProcessedItem(item) => {\n                let item = item.sync();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_valid(&item, &metadata);\n\n                self.store\n                    .add_event_valid(crate::metric::store::Event::MetricsUpdate(update));\n            }\n            LearnerEvent::EndEpoch(epoch) => {\n                self.metrics.end_epoch_valid();\n                self.store\n                    .add_event_valid(crate::metric::store::Event::EndEpoch(EpochSummary::new(\n                        epoch,\n                        Split::Valid,\n                    )));\n            }\n            LearnerEvent::End(_) => {} // no-op for now\n        }\n    }\n    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {\n        // TODO: Check for another default.\n        Box::new(CliMetricsRenderer::new())\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/mod.rs",
    "content": "mod async_wrapper;\nmod base;\nmod full;\nmod metrics;\nmod minimal;\n#[cfg(feature = \"rl\")]\nmod rl_metrics;\n#[cfg(feature = \"rl\")]\nmod rl_processor;\n\npub use base::*;\npub(crate) use full::*;\npub(crate) use metrics::*;\n#[cfg(feature = \"rl\")]\npub(crate) use rl_metrics::*;\n#[cfg(feature = \"rl\")]\npub(crate) use rl_processor::*;\n\n#[cfg(test)]\npub(crate) use minimal::*;\n\npub use async_wrapper::{AsyncProcessorEvaluation, AsyncProcessorTraining};\n\n#[cfg(test)]\npub(crate) mod test_utils {\n    use crate::metric::{\n        Adaptor, LossInput,\n        processor::{EventProcessorTraining, LearnerEvent, MinimalEventProcessor, TrainingItem},\n    };\n    use burn_core::tensor::{ElementConversion, Tensor, backend::Backend};\n\n    use super::ItemLazy;\n\n    impl ItemLazy for f64 {\n        type ItemSync = f64;\n\n        fn sync(self) -> Self::ItemSync {\n            self\n        }\n    }\n\n    impl<B: Backend> Adaptor<LossInput<B>> for f64 {\n        fn adapt(&self) -> LossInput<B> {\n            let device = B::Device::default();\n            LossInput::new(Tensor::from_data([self.elem::<B::FloatElem>()], &device))\n        }\n    }\n\n    pub(crate) fn process_train(\n        processor: &mut MinimalEventProcessor<f64, f64>,\n        value: f64,\n        epoch: usize,\n    ) {\n        let dummy_progress = burn_core::data::dataloader::Progress {\n            items_processed: 1,\n            items_total: 10,\n        };\n        let dummy_global_progress = burn_core::data::dataloader::Progress {\n            items_processed: epoch,\n            items_total: 3,\n        };\n        let dummy_iteration = Some(1);\n\n        processor.process_train(LearnerEvent::ProcessedItem(TrainingItem::new(\n            value,\n            dummy_progress,\n            dummy_global_progress,\n            dummy_iteration,\n            None,\n        )));\n    }\n\n    pub(crate) fn end_epoch(processor: &mut MinimalEventProcessor<f64, f64>, epoch: usize) {\n        processor.process_train(LearnerEvent::EndEpoch(epoch));\n        processor.process_valid(LearnerEvent::EndEpoch(epoch));\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/rl_metrics.rs",
    "content": "use std::collections::HashMap;\n\nuse crate::{\n    EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,\n    metric::{\n        Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,\n    },\n};\n\npub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {\n    train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,\n    env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,\n    env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,\n    episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,\n    episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,\n\n    train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,\n    env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,\n    env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,\n    episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,\n    episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,\n\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n}\n\nimpl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {\n    fn default() -> Self {\n        Self {\n            train_step: Vec::default(),\n            env_step: Vec::default(),\n            env_step_valid: Vec::default(),\n            episode_end: Vec::default(),\n            episode_end_valid: Vec::default(),\n            train_step_numeric: Vec::default(),\n            env_step_numeric: Vec::default(),\n            env_step_valid_numeric: Vec::default(),\n            episode_end_numeric: Vec::default(),\n            episode_end_valid_numeric: Vec::default(),\n            metric_definitions: HashMap::default(),\n        }\n    }\n}\n\nimpl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {\n    /// Register a training metric.\n    pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        ES::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.env_step.push(Box::new(metric))\n    }\n\n    /// Register a training metric.\n    pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)\n    where\n        ES::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.env_step_numeric.push(Box::new(metric))\n    }\n\n    /// Register a training metric.\n    pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        TS::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.train_step.push(Box::new(metric))\n    }\n\n    /// Register a training metric.\n    pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)\n    where\n        TS::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.train_step_numeric.push(Box::new(metric))\n    }\n\n    /// Register a validation env-step metric.\n    pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        ES::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.env_step_valid.push(Box::new(metric))\n    }\n\n    /// Register a validation env-step numeric metric.\n    pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)\n    where\n        ES::ItemSync: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.env_step_valid_numeric.push(Box::new(metric))\n    }\n\n    /// Register an episode-end metric.\n    pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.episode_end.push(Box::new(metric))\n    }\n\n    /// Register an episode-end numeric metric.\n    pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)\n    where\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.episode_end_numeric.push(Box::new(metric))\n    }\n\n    /// Register an episode-end metric for validation.\n    pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)\n    where\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.episode_end_valid.push(Box::new(metric))\n    }\n\n    /// Register an episode-end numeric metric for validation.\n    pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(\n        &mut self,\n        metric: Me,\n    ) where\n        EpisodeSummary: Adaptor<Me::Input> + 'static,\n    {\n        let metric = MetricWrapper::new(metric);\n        self.register_definition(&metric);\n        self.episode_end_valid_numeric.push(Box::new(metric))\n    }\n\n    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {\n        self.metric_definitions.insert(\n            metric.id.clone(),\n            MetricDefinition::new(metric.id.clone(), &metric.metric),\n        );\n    }\n\n    /// Get metric definitions for all splits\n    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {\n        self.metric_definitions.values().cloned().collect()\n    }\n\n    /// Update the training information from the training item.\n    pub(crate) fn update_train_step(\n        &mut self,\n        item: &EvaluationItem<TS::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.train_step.len());\n        let mut entries_numeric = Vec::with_capacity(self.train_step_numeric.len());\n\n        for metric in self.train_step.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.train_step_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Update the env-step metrics from an environment step item.\n    pub(crate) fn update_env_step(\n        &mut self,\n        item: &EvaluationItem<ES::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.env_step.len());\n        let mut entries_numeric = Vec::with_capacity(self.env_step_numeric.len());\n\n        for metric in self.env_step.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.env_step_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Update the env-step metrics for validation from an environment step item.\n    pub(crate) fn update_env_step_valid(\n        &mut self,\n        item: &EvaluationItem<ES::ItemSync>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.env_step_valid.len());\n        let mut entries_numeric = Vec::with_capacity(self.env_step_valid_numeric.len());\n\n        for metric in self.env_step_valid.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.env_step_valid_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Update the episode-end metrics from an episode summary.\n    pub(crate) fn update_episode_end(\n        &mut self,\n        item: &EvaluationItem<EpisodeSummary>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.episode_end.len());\n        let mut entries_numeric = Vec::with_capacity(self.episode_end_numeric.len());\n\n        for metric in self.episode_end.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.episode_end_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n\n    /// Update the episode-end metrics for validation from an episode summary.\n    pub(crate) fn update_episode_end_valid(\n        &mut self,\n        item: &EvaluationItem<EpisodeSummary>,\n        metadata: &MetricMetadata,\n    ) -> MetricsUpdate {\n        let mut entries = Vec::with_capacity(self.episode_end_valid.len());\n        let mut entries_numeric = Vec::with_capacity(self.episode_end_valid_numeric.len());\n\n        for metric in self.episode_end_valid.iter_mut() {\n            let state = metric.update(&item.item, metadata);\n            entries.push(state);\n        }\n\n        for metric in self.episode_end_valid_numeric.iter_mut() {\n            let numeric_update = metric.update(&item.item, metadata);\n            entries_numeric.push(numeric_update);\n        }\n\n        MetricsUpdate::new(entries, entries_numeric)\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/processor/rl_processor.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    EpisodeSummary, EvaluationItem, EventProcessorTraining, ItemLazy, LearnerSummary, RLMetrics,\n    metric::store::{Event, EventStoreClient, MetricsUpdate},\n    renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},\n};\n\n/// Event happening during reinforcement learning.\npub enum RLEvent<TS, ES> {\n    /// Signal the start of the process (e.g., learning starts).\n    Start,\n    /// Signal an agent's training step.\n    TrainStep(EvaluationItem<TS>),\n    /// Signal a timestep of the agent-environment interface.\n    TimeStep(EvaluationItem<ES>),\n    /// Signal an episode end.\n    EpisodeEnd(EvaluationItem<EpisodeSummary>),\n    /// Signal the end of the process (e.g., learning ends).\n    End(Option<LearnerSummary>),\n}\n\n/// Event happening during evaluation of a reinforcement learning's agent.\npub enum AgentEvaluationEvent<T> {\n    /// Signal the start of the process (e.g., training start)\n    Start,\n    /// Signal a timestep of the agent-environment interface.\n    TimeStep(EvaluationItem<T>),\n    /// Signal an episode end.\n    EpisodeEnd(EvaluationItem<EpisodeSummary>),\n    /// Signal the end of the process (e.g., training end).\n    End,\n}\n\n/// An [event processor](EventProcessorTraining) that handles:\n///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).\n///   - Render metrics using a [metrics renderer](MetricsRenderer).\n#[derive(new)]\npub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {\n    metrics: RLMetrics<TS, ES>,\n    renderer: Box<dyn MetricsRenderer>,\n    store: Arc<EventStoreClient>,\n}\n\nimpl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {\n    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {\n        let indicators = vec![ProgressType::Detailed {\n            tag: String::from(\"Step\"),\n            progress: progress.global_progress.clone(),\n        }];\n\n        indicators\n    }\n\n    fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {\n        let indicators = vec![ProgressType::Detailed {\n            tag: String::from(\"Step\"),\n            progress: progress.global_progress.clone(),\n        }];\n\n        indicators\n    }\n}\n\nimpl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {\n    fn process_update_train(&mut self, update: MetricsUpdate) {\n        self.store\n            .add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));\n\n        update\n            .entries\n            .into_iter()\n            .for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));\n\n        update\n            .entries_numeric\n            .into_iter()\n            .for_each(|numeric_update| {\n                self.renderer.update_train(MetricState::Numeric(\n                    numeric_update.entry,\n                    numeric_update.numeric_entry,\n                ))\n            });\n    }\n\n    fn process_update_valid(&mut self, update: MetricsUpdate) {\n        self.store\n            .add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));\n\n        update\n            .entries\n            .into_iter()\n            .for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));\n\n        update\n            .entries_numeric\n            .into_iter()\n            .for_each(|numeric_update| {\n                self.renderer.update_valid(MetricState::Numeric(\n                    numeric_update.entry,\n                    numeric_update.numeric_entry,\n                ))\n            });\n    }\n}\n\nimpl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>\n    for RLEventProcessor<TS, ES>\n{\n    fn process_train(&mut self, event: RLEvent<TS, ES>) {\n        match event {\n            RLEvent::Start => {\n                let definitions = self.metrics.metric_definitions();\n                self.store\n                    .add_event_train(Event::MetricsInit(definitions.clone()));\n                definitions\n                    .iter()\n                    .for_each(|definition| self.renderer.register_metric(definition.clone()));\n            }\n            RLEvent::TrainStep(item) => {\n                let item = item.sync();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_train_step(&item, &metadata);\n                self.process_update_train(update);\n            }\n            RLEvent::TimeStep(item) => {\n                let item = item.sync();\n                let progress = (&item).into();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_env_step(&item, &metadata);\n                self.process_update_train(update);\n                let status = self.progress_indicators(&progress);\n                self.renderer.render_train(progress, status);\n            }\n            RLEvent::EpisodeEnd(item) => {\n                let item = item.sync();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_episode_end(&item, &metadata);\n                self.process_update_train(update);\n            }\n            RLEvent::End(learner_summary) => {\n                self.renderer.on_train_end(learner_summary).ok();\n            }\n        }\n    }\n\n    fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {\n        match event {\n            AgentEvaluationEvent::Start => {} // no-op for now\n            AgentEvaluationEvent::TimeStep(item) => {\n                let item = item.sync();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_env_step_valid(&item, &metadata);\n                self.process_update_valid(update);\n            }\n            AgentEvaluationEvent::EpisodeEnd(item) => {\n                let item = item.sync();\n                let progress = (&item).into();\n                let metadata = (&item).into();\n\n                let update = self.metrics.update_episode_end_valid(&item, &metadata);\n                self.process_update_valid(update);\n                let status = self.progress_indicators_eval(&progress);\n                self.renderer.render_valid(progress, status);\n            }\n            AgentEvaluationEvent::End => {} // no-op for now\n        }\n    }\n\n    fn renderer(self) -> Box<dyn MetricsRenderer> {\n        self.renderer\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/recall.rs",
    "content": "use crate::metric::{MetricName, Numeric};\n\nuse super::{\n    Metric, MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry, SerializedEntry,\n    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},\n    confusion_stats::{ConfusionStats, ConfusionStatsInput},\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::cast::ToElement,\n};\nuse core::marker::PhantomData;\nuse std::{num::NonZeroUsize, sync::Arc};\n\n///The Recall Metric\n#[derive(Clone)]\npub struct RecallMetric<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    _b: PhantomData<B>,\n    config: ClassificationMetricConfig,\n}\n\nimpl<B: Backend> Default for RecallMetric<B> {\n    fn default() -> Self {\n        Self::new(Default::default())\n    }\n}\n\nimpl<B: Backend> RecallMetric<B> {\n    fn new(config: ClassificationMetricConfig) -> Self {\n        let state = Default::default();\n        let name = Arc::new(format!(\n            \"Recall @ {:?} [{:?}]\",\n            config.decision_rule, config.class_reduction\n        ));\n\n        Self {\n            state,\n            config,\n            name,\n            _b: Default::default(),\n        }\n    }\n    /// Recall metric for binary classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `threshold` - The threshold to transform a probability into a binary prediction.\n    #[allow(dead_code)]\n    pub fn binary(threshold: f64) -> Self {\n        Self::new(ClassificationMetricConfig {\n            decision_rule: DecisionRule::Threshold(threshold),\n            // binary classification results are the same independently of class_reduction\n            ..Default::default()\n        })\n    }\n\n    /// Recall metric for multiclass classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {\n        Self::new(ClassificationMetricConfig {\n            decision_rule: DecisionRule::TopK(\n                NonZeroUsize::new(top_k).expect(\"top_k must be non-zero\"),\n            ),\n            class_reduction,\n        })\n    }\n\n    /// Recall metric for multi-label classification.\n    ///\n    /// # Arguments\n    ///\n    /// * `threshold` - The threshold to transform a probability into a binary prediction.\n    /// * `class_reduction` - [Class reduction](ClassReduction) type.\n    #[allow(dead_code)]\n    pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {\n        Self::new(ClassificationMetricConfig {\n            decision_rule: DecisionRule::Threshold(threshold),\n            class_reduction,\n        })\n    }\n\n    fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {\n        use ClassReduction::{Macro, Micro};\n        let avg_tensor = match self.config.class_reduction {\n            Micro => aggregated_metric,\n            Macro => {\n                if aggregated_metric\n                    .clone()\n                    .contains_nan()\n                    .any()\n                    .into_scalar()\n                    .to_bool()\n                {\n                    let nan_mask = aggregated_metric.clone().is_nan();\n                    aggregated_metric = aggregated_metric\n                        .clone()\n                        .select(0, nan_mask.bool_not().argwhere().squeeze_dim(1))\n                }\n                aggregated_metric.mean()\n            }\n        };\n        avg_tensor.into_scalar().to_f64()\n    }\n}\n\nimpl<B: Backend> Metric for RecallMetric<B> {\n    type Input = ConfusionStatsInput<B>;\n\n    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let [sample_size, _] = input.predictions.dims();\n\n        let cf_stats = ConfusionStats::new(input, &self.config);\n        let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive());\n\n        self.state.update(\n            100.0 * metric,\n            sample_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for RecallMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::{\n        ClassReduction::{self, *},\n        Metric, MetricMetadata, RecallMetric,\n    };\n    use crate::metric::Numeric;\n    use crate::{\n        TestBackend,\n        tests::{ClassificationType, THRESHOLD, dummy_classification_input},\n    };\n    use burn_core::tensor::{TensorData, Tolerance};\n    use rstest::rstest;\n\n    #[rstest]\n    #[case::binary(THRESHOLD, 0.5)]\n    fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {\n        let input = dummy_classification_input(&ClassificationType::Binary).into();\n        let mut metric = RecallMetric::binary(threshold);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]\n    #[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)]\n    #[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)]\n    #[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)]\n    fn test_multiclass_recall(\n        #[case] class_reduction: ClassReduction,\n        #[case] top_k: usize,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multiclass).into();\n        let mut metric = RecallMetric::multiclass(top_k, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[rstest]\n    #[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)]\n    #[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)]\n    fn test_multilabel_recall(\n        #[case] class_reduction: ClassReduction,\n        #[case] threshold: f64,\n        #[case] expected: f64,\n    ) {\n        let input = dummy_classification_input(&ClassificationType::Multilabel).into();\n        let mut metric = RecallMetric::multilabel(threshold, class_reduction);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        TensorData::from([metric.value().current()])\n            .assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default())\n    }\n\n    #[test]\n    fn test_parameterized_unique_name() {\n        let metric_a = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);\n        let metric_b = RecallMetric::<TestBackend>::multiclass(2, ClassReduction::Macro);\n        let metric_c = RecallMetric::<TestBackend>::multiclass(1, ClassReduction::Macro);\n\n        assert_ne!(metric_a.name(), metric_b.name());\n        assert_eq!(metric_a.name(), metric_c.name());\n\n        let metric_a = RecallMetric::<TestBackend>::binary(0.5);\n        let metric_b = RecallMetric::<TestBackend>::binary(0.75);\n        assert_ne!(metric_a.name(), metric_b.name());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/rl/cum_reward.rs",
    "content": "use std::sync::Arc;\n\nuse super::super::{\n    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse crate::metric::{Metric, MetricName, Numeric, SerializedEntry};\n\n/// Metric for the cumulative reward of the last completed episode.\n#[derive(Clone)]\npub struct CumulativeRewardMetric {\n    name: MetricName,\n    state: NumericMetricState,\n}\n\nimpl CumulativeRewardMetric {\n    /// Creates a new episode length metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Cum. Reward\".to_string()),\n            state: NumericMetricState::new(),\n        }\n    }\n}\n\nimpl Default for CumulativeRewardMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\n/// The [CumulativeRewardMetric](CumulativeRewardMetric) input type.\n#[derive(new)]\npub struct CumulativeRewardInput {\n    cum_reward: f64,\n}\n\nimpl Metric for CumulativeRewardMetric {\n    type Input = CumulativeRewardInput;\n\n    fn update(\n        &mut self,\n        item: &CumulativeRewardInput,\n        _metadata: &MetricMetadata,\n    ) -> SerializedEntry {\n        self.state.update(\n            item.cum_reward,\n            1,\n            FormatOptions::new(self.name()).precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for CumulativeRewardMetric {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/rl/ep_len.rs",
    "content": "use std::sync::Arc;\n\nuse super::super::{\n    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse crate::metric::{Metric, MetricName, Numeric, SerializedEntry};\n\n/// Metric for the length of the last completed episode.\n#[derive(Clone)]\npub struct EpisodeLengthMetric {\n    name: MetricName,\n    state: NumericMetricState,\n}\n\nimpl EpisodeLengthMetric {\n    /// Creates a new episode length metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Episode length\".to_string()),\n            state: NumericMetricState::new(),\n        }\n    }\n}\n\nimpl Default for EpisodeLengthMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\n/// The [EpisodeLengthMetric](EpisodeLengthMetric) input type.\n#[derive(new)]\npub struct EpisodeLengthInput {\n    ep_len: f64,\n}\n\nimpl Metric for EpisodeLengthMetric {\n    type Input = EpisodeLengthInput;\n\n    fn update(&mut self, item: &EpisodeLengthInput, _metadata: &MetricMetadata) -> SerializedEntry {\n        self.state\n            .update(item.ep_len, 1, FormatOptions::new(self.name()).precision(0))\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(String::from(\"steps\")),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for EpisodeLengthMetric {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/rl/exploration_rate.rs",
    "content": "use std::sync::Arc;\n\nuse super::super::{\n    MetricAttributes, MetricMetadata, NumericAttributes, NumericEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse crate::metric::{Metric, MetricName, Numeric, SerializedEntry};\n\n/// Metric for the length of the last completed episode.\n#[derive(Clone)]\npub struct ExplorationRateMetric {\n    name: MetricName,\n    state: NumericMetricState,\n}\n\nimpl ExplorationRateMetric {\n    /// Creates a new episode length metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"Exploration rate\".to_string()),\n            state: NumericMetricState::new(),\n        }\n    }\n}\n\nimpl Default for ExplorationRateMetric {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\n/// The [ExplorationRateMetric](ExplorationRateMetric) input type.\n#[derive(new)]\npub struct ExplorationRateInput {\n    exploration_rate: f64,\n}\n\nimpl Metric for ExplorationRateMetric {\n    type Input = ExplorationRateInput;\n\n    fn update(\n        &mut self,\n        item: &ExplorationRateInput,\n        _metadata: &MetricMetadata,\n    ) -> SerializedEntry {\n        self.state.update(\n            item.exploration_rate,\n            1,\n            FormatOptions::new(self.name()).precision(3),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(String::from(\"%\")),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl Numeric for ExplorationRateMetric {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/rl/mod.rs",
    "content": "mod cum_reward;\nmod ep_len;\nmod exploration_rate;\n\npub use cum_reward::*;\npub use ep_len::*;\npub use exploration_rate::*;\n"
  },
  {
    "path": "crates/burn-train/src/metric/state.rs",
    "content": "use std::sync::Arc;\n\nuse crate::metric::{MetricName, NumericEntry, SerializedEntry, format_float};\n\n/// Useful utility to implement numeric metrics.\n///\n/// # Notes\n///\n/// The numeric metric store values inside floats.\n/// Even if some metric are integers, their mean are floats.\n#[derive(Clone)]\npub struct NumericMetricState {\n    sum: f64,\n    count: usize,\n    current: f64,\n    current_count: usize,\n}\n\n/// Formatting options for the [numeric metric state](NumericMetricState).\npub struct FormatOptions {\n    name: Arc<String>,\n    unit: Option<String>,\n    precision: Option<usize>,\n}\n\nimpl FormatOptions {\n    /// Create the [formatting options](FormatOptions) with a name.\n    pub fn new(name: MetricName) -> Self {\n        Self {\n            name: name.clone(),\n            unit: None,\n            precision: None,\n        }\n    }\n\n    /// Specify the metric unit.\n    pub fn unit(mut self, unit: &str) -> Self {\n        self.unit = Some(unit.to_string());\n        self\n    }\n\n    /// Specify the floating point precision.\n    pub fn precision(mut self, precision: usize) -> Self {\n        self.precision = Some(precision);\n        self\n    }\n\n    /// Get the metric name.\n    pub fn name(&self) -> &Arc<String> {\n        &self.name\n    }\n\n    /// Get the metric unit.\n    pub fn unit_value(&self) -> &Option<String> {\n        &self.unit\n    }\n\n    /// Get the precision.\n    pub fn precision_value(&self) -> Option<usize> {\n        self.precision\n    }\n}\n\nimpl NumericMetricState {\n    /// Create a new [numeric metric state](NumericMetricState).\n    pub fn new() -> Self {\n        Self {\n            sum: 0.0,\n            count: 0,\n            current: f64::NAN,\n            current_count: 0,\n        }\n    }\n\n    /// Reset the state.\n    pub fn reset(&mut self) {\n        self.sum = 0.0;\n        self.count = 0;\n        self.current = f64::NAN;\n        self.current_count = 0;\n    }\n\n    /// Update the state.\n    pub fn update(\n        &mut self,\n        value: f64,\n        batch_size: usize,\n        format: FormatOptions,\n    ) -> SerializedEntry {\n        self.sum += value * batch_size as f64;\n        self.count += batch_size;\n        self.current = value;\n        self.current_count = batch_size;\n\n        let value_current = value;\n        let value_running = self.sum / self.count as f64;\n        // Numeric metric state is an aggregated value\n        let serialized = NumericEntry::Aggregated {\n            aggregated_value: value_current,\n            count: batch_size,\n        }\n        .serialize();\n\n        let (formatted_current, formatted_running) = match format.precision {\n            Some(precision) => (\n                format_float(value_current, precision),\n                format_float(value_running, precision),\n            ),\n            None => (format!(\"{value_current}\"), format!(\"{value_running}\")),\n        };\n\n        // TODO: naming inconsistent with RL.\n        let formatted = match format.unit {\n            Some(unit) => {\n                format!(\"epoch {formatted_running} {unit} - batch {formatted_current} {unit}\")\n            }\n            None => format!(\"epoch {formatted_running} - batch {formatted_current}\"),\n        };\n\n        SerializedEntry::new(formatted, serialized)\n    }\n\n    /// Get the numeric value.\n    pub fn current_value(&self) -> NumericEntry {\n        NumericEntry::Aggregated {\n            aggregated_value: self.current,\n            count: self.current_count,\n        }\n    }\n\n    /// Get the running aggregated value.\n    pub fn running_value(&self) -> NumericEntry {\n        NumericEntry::Aggregated {\n            aggregated_value: self.sum / self.count as f64,\n            count: self.count,\n        }\n    }\n}\n\nimpl Default for NumericMetricState {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/store/aggregate.rs",
    "content": "use crate::{\n    logger::MetricLogger,\n    metric::{NumericEntry, store::Split},\n};\nuse std::collections::HashMap;\n\nuse super::{Aggregate, Direction};\n\n/// Type that can be used to fetch and use numeric metric aggregates.\n#[derive(Default, Debug)]\npub(crate) struct NumericMetricsAggregate {\n    value_for_each_epoch: HashMap<Key, f64>,\n}\n\n#[derive(new, Hash, PartialEq, Eq, Debug)]\nstruct Key {\n    name: String,\n    epoch: usize,\n    split: Split,\n    aggregate: Aggregate,\n}\n\nimpl NumericMetricsAggregate {\n    pub(crate) fn aggregate(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        split: &Split,\n        aggregate: Aggregate,\n        loggers: &mut [Box<dyn MetricLogger>],\n    ) -> Option<f64> {\n        let key = Key::new(name.to_string(), epoch, split.clone(), aggregate);\n\n        if let Some(value) = self.value_for_each_epoch.get(&key) {\n            return Some(*value);\n        }\n\n        let points = || {\n            let mut errors = Vec::new();\n            for logger in loggers {\n                match logger.read_numeric(name, epoch, split) {\n                    Ok(points) => return Ok(points),\n                    Err(err) => errors.push(err),\n                };\n            }\n\n            Err(errors.join(\" \"))\n        };\n\n        let points = points().expect(\"Can read values\");\n\n        if points.is_empty() {\n            return None;\n        }\n\n        // Accurately compute the aggregated value based on the *actual* number of points\n        // since not all mini-batches are guaranteed to have the specified batch size\n        let (sum, num_points) = points\n            .into_iter()\n            .map(|entry| match entry {\n                NumericEntry::Value(v) => (v, 1),\n                // Right now the mean is the only aggregate available, so we can assume that the sum\n                // of an entry corresponds to (value * number of elements)\n                NumericEntry::Aggregated {\n                    aggregated_value,\n                    count,\n                } => (aggregated_value * count as f64, count),\n            })\n            .reduce(|(acc_v, acc_n), (v, n)| (acc_v + v, acc_n + n))\n            .unwrap();\n        let value = match aggregate {\n            Aggregate::Mean => sum / num_points as f64,\n        };\n\n        self.value_for_each_epoch.insert(key, value);\n        Some(value)\n    }\n\n    pub(crate) fn find_epoch(\n        &mut self,\n        name: &str,\n        split: &Split,\n        aggregate: Aggregate,\n        direction: Direction,\n        loggers: &mut [Box<dyn MetricLogger>],\n    ) -> Option<usize> {\n        let mut data = Vec::new();\n        let mut current_epoch = 1;\n\n        while let Some(value) = self.aggregate(name, current_epoch, split, aggregate, loggers) {\n            data.push(value);\n            current_epoch += 1;\n        }\n\n        if data.is_empty() {\n            return None;\n        }\n\n        let mut current_value = match &direction {\n            Direction::Lowest => f64::MAX,\n            Direction::Highest => f64::MIN,\n        };\n\n        for (i, value) in data.into_iter().enumerate() {\n            match &direction {\n                Direction::Lowest => {\n                    if value < current_value {\n                        current_value = value;\n                        current_epoch = i + 1;\n                    }\n                }\n                Direction::Highest => {\n                    if value > current_value {\n                        current_value = value;\n                        current_epoch = i + 1;\n                    }\n                }\n            }\n        }\n\n        Some(current_epoch)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::sync::Arc;\n\n    use crate::{\n        logger::{FileMetricLogger, InMemoryMetricLogger},\n        metric::{MetricDefinition, MetricEntry, MetricId, SerializedEntry, store::MetricsUpdate},\n    };\n\n    use super::*;\n\n    struct TestLogger {\n        logger: FileMetricLogger,\n        epoch: usize,\n    }\n    const NAME: &str = \"test-logger\";\n\n    impl TestLogger {\n        fn new() -> Self {\n            Self {\n                logger: FileMetricLogger::new(\"/tmp\"),\n                epoch: 1,\n            }\n        }\n        fn log(&mut self, num: f64) {\n            let entry = MetricEntry::new(\n                MetricId::new(Arc::new(NAME.into())),\n                SerializedEntry::new(num.to_string(), num.to_string()),\n            );\n            let entries = Vec::from([entry]);\n            let metrics_update = MetricsUpdate::new(entries, vec![]);\n            self.logger.log(metrics_update, self.epoch, &Split::Train);\n        }\n        fn log_definition(&mut self) {\n            let definition = MetricDefinition {\n                metric_id: MetricId::new(Arc::new(NAME.into())),\n                name: NAME.into(),\n                attributes: crate::metric::MetricAttributes::None,\n                description: None,\n            };\n            self.logger.log_metric_definition(definition);\n        }\n        fn new_epoch(&mut self) {\n            self.epoch += 1;\n        }\n    }\n\n    #[test]\n    fn should_find_epoch() {\n        let mut logger = TestLogger::new();\n        let mut aggregate = NumericMetricsAggregate::default();\n        logger.log_definition();\n\n        logger.log(500.); // Epoch 1\n        logger.log(1000.); // Epoch 1\n        logger.new_epoch();\n        logger.log(200.); // Epoch 2\n        logger.log(1000.); // Epoch 2\n        logger.new_epoch();\n        logger.log(10000.); // Epoch 3\n\n        let value = aggregate\n            .find_epoch(\n                NAME,\n                &Split::Train,\n                Aggregate::Mean,\n                Direction::Lowest,\n                &mut [Box::new(logger.logger)],\n            )\n            .unwrap();\n\n        assert_eq!(value, 2);\n    }\n\n    #[test]\n    fn should_aggregate_numeric_entry() {\n        let mut logger = InMemoryMetricLogger::default();\n        let mut aggregate = NumericMetricsAggregate::default();\n        let metric_name = Arc::new(\"Loss\".to_string());\n        let metric_id = MetricId::new(metric_name.clone());\n        let definition = MetricDefinition {\n            metric_id: metric_id.clone(),\n            name: metric_name.to_string(),\n            attributes: crate::metric::MetricAttributes::None,\n            description: None,\n        };\n        logger.log_metric_definition(definition);\n\n        // Epoch 1\n        let loss_1 = 0.5;\n        let loss_2 = 1.25; // (1.5 + 1.0) / 2 = 2.5 / 2\n        let entry = MetricEntry::new(\n            metric_id.clone(),\n            SerializedEntry::new(loss_1.to_string(), NumericEntry::Value(loss_1).serialize()),\n        );\n        let entries = Vec::from([entry]);\n        let metrics_update = MetricsUpdate::new(entries, vec![]);\n        logger.log(metrics_update, 1, &Split::Train);\n        let entry = MetricEntry::new(\n            metric_id.clone(),\n            SerializedEntry::new(\n                loss_2.to_string(),\n                NumericEntry::Aggregated {\n                    aggregated_value: loss_2,\n                    count: 2,\n                }\n                .serialize(),\n            ),\n        );\n        let entries = Vec::from([entry]);\n        let metrics_update = MetricsUpdate::new(entries, vec![]);\n        logger.log(metrics_update, 1, &Split::Train);\n\n        let value = aggregate\n            .aggregate(\n                &metric_name,\n                1,\n                &Split::Train,\n                Aggregate::Mean,\n                &mut [Box::new(logger)],\n            )\n            .unwrap();\n\n        // Average should be (0.5 + 1.25 * 2) / 3 = 1.0, not (0.5 + 1.25) / 2 = 0.875\n        assert_eq!(value, 1.0);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/store/base.rs",
    "content": "use std::sync::Arc;\n\nuse crate::metric::{MetricDefinition, MetricEntry, NumericEntry};\n\n/// Event happening during the training/validation process.\npub enum Event {\n    /// Signal the iniialization of the metrics\n    MetricsInit(Vec<MetricDefinition>),\n    /// Signal that metrics have been updated.\n    MetricsUpdate(MetricsUpdate),\n    /// Signal the end of an epoch.\n    EndEpoch(EpochSummary),\n}\n\n/// Contains all metric information.\n#[derive(new, Clone, Debug)]\npub struct NumericMetricUpdate {\n    /// Generic metric information.\n    pub entry: MetricEntry,\n    /// The numeric information.\n    pub numeric_entry: NumericEntry,\n    /// Numeric value averaged over the global step (epoch).\n    pub running_entry: NumericEntry,\n}\n\n/// Contains all metric information.\n#[derive(new, Clone, Debug)]\npub struct MetricsUpdate {\n    /// Metrics information related to non-numeric metrics.\n    pub entries: Vec<MetricEntry>,\n    /// Metrics information related to numeric metrics.\n    pub entries_numeric: Vec<NumericMetricUpdate>,\n}\n\n/// Summary information about a given epoch\n#[derive(new, Clone, Debug)]\npub struct EpochSummary {\n    /// Epoch number.\n    pub epoch_number: usize,\n    /// Dataset split (train, valid, test).\n    pub split: Split,\n}\n\n/// Defines how training and validation events are collected and searched.\n///\n/// This trait also exposes methods that uses the collected data to compute useful information.\npub trait EventStore: Send {\n    /// Collect a training/validation event.\n    fn add_event(&mut self, event: Event, split: Split);\n\n    /// Find the epoch following the given criteria from the collected data.\n    fn find_epoch(\n        &mut self,\n        name: &str,\n        aggregate: Aggregate,\n        direction: Direction,\n        split: &Split,\n    ) -> Option<usize>;\n\n    /// Find the metric value for the current epoch following the given criteria.\n    fn find_metric(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        aggregate: Aggregate,\n        split: &Split,\n    ) -> Option<f64>;\n}\n\n#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]\n/// How to aggregate the metric.\npub enum Aggregate {\n    /// Compute the average.\n    Mean,\n}\n\n#[derive(Clone, Debug, Hash, PartialEq, Eq)]\n/// The split to use.\npub enum Split {\n    /// The training split.\n    Train,\n    /// The validation split.\n    Valid,\n    /// The testing split, which might be tagged.\n    Test(Option<Arc<String>>),\n}\n\nimpl std::fmt::Display for Split {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Split::Train => write!(f, \"train\"),\n            Split::Valid => write!(f, \"valid\"),\n            Split::Test(_) => write!(f, \"test\"),\n        }\n    }\n}\n\n#[derive(Copy, Clone)]\n/// The direction of the query.\npub enum Direction {\n    /// Lower is better.\n    Lowest,\n    /// Higher is better.\n    Highest,\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/store/client.rs",
    "content": "use super::EventStore;\nuse super::{Aggregate, Direction, Event, Split};\nuse std::sync::Arc;\nuse std::{sync::mpsc, thread::JoinHandle};\n\n/// Type that allows to communicate with an [event store](EventStore).\npub struct EventStoreClient {\n    sender: mpsc::Sender<Message>,\n    handler: Option<JoinHandle<()>>,\n}\n\nimpl EventStoreClient {\n    /// Create a new [event store](EventStore) client.\n    pub(crate) fn new<C>(store: C) -> Self\n    where\n        C: EventStore + 'static,\n    {\n        let (sender, receiver) = mpsc::channel();\n        let thread = WorkerThread::new(store, receiver);\n\n        let handler = std::thread::spawn(move || thread.run());\n        let handler = Some(handler);\n\n        Self { sender, handler }\n    }\n}\n\nimpl EventStoreClient {\n    /// Add a training event to the [event store](EventStore).\n    pub(crate) fn add_event_train(&self, event: Event) {\n        self.sender\n            .send(Message::OnEventTrain(event))\n            .expect(\"Can send event to event store thread.\");\n    }\n\n    /// Add a validation event to the [event store](EventStore).\n    pub(crate) fn add_event_valid(&self, event: Event) {\n        self.sender\n            .send(Message::OnEventValid(event))\n            .expect(\"Can send event to event store thread.\");\n    }\n\n    /// Add a testing event to the [event store](EventStore).\n    pub(crate) fn add_event_test(&self, event: Event, tag: Arc<String>) {\n        self.sender\n            .send(Message::OnEventTest(event, tag))\n            .expect(\"Can send event to event store thread.\");\n    }\n\n    /// Find the epoch following the given criteria from the collected data.\n    pub fn find_epoch(\n        &self,\n        name: &str,\n        aggregate: Aggregate,\n        direction: Direction,\n        split: &Split,\n    ) -> Option<usize> {\n        let (sender, receiver) = mpsc::sync_channel(1);\n        self.sender\n            .send(Message::FindEpoch(\n                name.to_string(),\n                aggregate,\n                direction,\n                split.clone(),\n                sender,\n            ))\n            .expect(\"Can send event to event store thread.\");\n\n        match receiver.recv() {\n            Ok(value) => value,\n            Err(err) => panic!(\"Event store thread crashed: {err:?}\"),\n        }\n    }\n\n    /// Find the metric value for the current epoch following the given criteria.\n    pub fn find_metric(\n        &self,\n        name: &str,\n        epoch: usize,\n        aggregate: Aggregate,\n        split: &Split,\n    ) -> Option<f64> {\n        let (sender, receiver) = mpsc::sync_channel(1);\n        self.sender\n            .send(Message::FindMetric(\n                name.to_string(),\n                epoch,\n                aggregate,\n                split.clone(),\n                sender,\n            ))\n            .expect(\"Can send event to event store thread.\");\n\n        match receiver.recv() {\n            Ok(value) => value,\n            Err(err) => panic!(\"Event store thread crashed: {err:?}\"),\n        }\n    }\n}\n\n#[derive(new)]\nstruct WorkerThread<S> {\n    store: S,\n    receiver: mpsc::Receiver<Message>,\n}\n\nimpl<C> WorkerThread<C>\nwhere\n    C: EventStore,\n{\n    fn run(mut self) {\n        for item in self.receiver.iter() {\n            match item {\n                Message::End => {\n                    return;\n                }\n                Message::FindEpoch(name, aggregate, direction, split, callback) => {\n                    let response = self.store.find_epoch(&name, aggregate, direction, &split);\n                    callback\n                        .send(response)\n                        .expect(\"Can send response using callback channel.\");\n                }\n                Message::FindMetric(name, epoch, aggregate, split, callback) => {\n                    let response = self.store.find_metric(&name, epoch, aggregate, &split);\n                    callback\n                        .send(response)\n                        .expect(\"Can send response using callback channel.\");\n                }\n                Message::OnEventTrain(event) => self.store.add_event(event, Split::Train),\n                Message::OnEventValid(event) => self.store.add_event(event, Split::Valid),\n                Message::OnEventTest(event, tag) => {\n                    self.store.add_event(event, Split::Test(Some(tag)))\n                }\n            }\n        }\n    }\n}\n\nenum Message {\n    OnEventTest(Event, Arc<String>),\n    OnEventTrain(Event),\n    OnEventValid(Event),\n    End,\n    FindEpoch(\n        String,\n        Aggregate,\n        Direction,\n        Split,\n        mpsc::SyncSender<Option<usize>>,\n    ),\n    FindMetric(\n        String,\n        usize,\n        Aggregate,\n        Split,\n        mpsc::SyncSender<Option<f64>>,\n    ),\n}\n\nimpl Drop for EventStoreClient {\n    fn drop(&mut self) {\n        self.sender\n            .send(Message::End)\n            .expect(\"Can send the end message to the event store thread.\");\n        let handler = self.handler.take();\n\n        if let Some(handler) = handler {\n            handler.join().expect(\"The event store thread should stop.\");\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/store/log.rs",
    "content": "use std::collections::HashMap;\n\nuse super::{Aggregate, Direction, Event, EventStore, Split, aggregate::NumericMetricsAggregate};\nuse crate::logger::MetricLogger;\n\n#[derive(Default)]\npub(crate) struct LogEventStore {\n    loggers: Vec<Box<dyn MetricLogger>>,\n    aggregate: NumericMetricsAggregate,\n    epochs: HashMap<Split, usize>,\n}\n\nimpl EventStore for LogEventStore {\n    fn add_event(&mut self, event: Event, split: Split) {\n        let epoch = *self.epochs.entry(split.clone()).or_insert(1);\n\n        match event {\n            Event::MetricsInit(definitions) => {\n                definitions.iter().for_each(|def| {\n                    self.loggers\n                        .iter_mut()\n                        .for_each(|logger| logger.log_metric_definition(def.clone()));\n                });\n            }\n            Event::MetricsUpdate(update) => {\n                self.loggers\n                    .iter_mut()\n                    .for_each(|logger| logger.log(update.clone(), epoch, &split));\n            }\n            Event::EndEpoch(summary) => {\n                self.epochs.insert(split, summary.epoch_number + 1);\n                self.loggers\n                    .iter_mut()\n                    .for_each(|logger| logger.log_epoch_summary(summary.clone()));\n            }\n        }\n    }\n\n    fn find_epoch(\n        &mut self,\n        name: &str,\n        aggregate: Aggregate,\n        direction: Direction,\n        split: &Split,\n    ) -> Option<usize> {\n        self.aggregate\n            .find_epoch(name, split, aggregate, direction, &mut self.loggers)\n    }\n\n    fn find_metric(\n        &mut self,\n        name: &str,\n        epoch: usize,\n        aggregate: Aggregate,\n        split: &Split,\n    ) -> Option<f64> {\n        self.aggregate\n            .aggregate(name, epoch, split, aggregate, &mut self.loggers)\n    }\n}\n\nimpl LogEventStore {\n    /// Register a logger for metrics.\n    pub(crate) fn register_logger<ML: MetricLogger + 'static>(&mut self, logger: ML) {\n        self.loggers.push(Box::new(logger));\n    }\n\n    /// Returns whether any loggers are registered.\n    pub(crate) fn has_loggers(&self) -> bool {\n        !self.loggers.is_empty()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/store/mod.rs",
    "content": "pub(crate) mod aggregate;\n\nmod base;\nmod client;\nmod log;\n\npub(crate) use self::log::*;\npub use base::*;\npub use client::*;\n"
  },
  {
    "path": "crates/burn-train/src/metric/top_k_acc.rs",
    "content": "use core::marker::PhantomData;\nuse std::sync::Arc;\n\nuse super::state::{FormatOptions, NumericMetricState};\nuse super::{MetricMetadata, SerializedEntry};\nuse crate::metric::{\n    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,\n};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{ElementConversion, Int, Tensor};\n\n/// The Top-K accuracy metric.\n///\n/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).\n#[derive(Default, Clone)]\npub struct TopKAccuracyMetric<B: Backend> {\n    name: Arc<String>,\n    k: usize,\n    state: NumericMetricState,\n    /// If specified, targets equal to this value will be considered padding and will not count\n    /// towards the metric\n    pad_token: Option<usize>,\n    _b: PhantomData<B>,\n}\n\n/// The [top-k accuracy metric](TopKAccuracyMetric) input type.\n#[derive(new)]\npub struct TopKAccuracyInput<B: Backend> {\n    /// The outputs (batch_size, num_classes)\n    outputs: Tensor<B, 2>,\n    /// The labels (batch_size)\n    targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> TopKAccuracyMetric<B> {\n    /// Creates the metric.\n    pub fn new(k: usize) -> Self {\n        Self {\n            name: Arc::new(format!(\"Top-K Accuracy @ TopK({})\", k)),\n            k,\n            ..Default::default()\n        }\n    }\n\n    /// Sets the pad token.\n    pub fn with_pad_token(mut self, index: usize) -> Self {\n        self.pad_token = Some(index);\n        self\n    }\n}\n\nimpl<B: Backend> Metric for TopKAccuracyMetric<B> {\n    type Input = TopKAccuracyInput<B>;\n\n    fn update(\n        &mut self,\n        input: &TopKAccuracyInput<B>,\n        _metadata: &MetricMetadata,\n    ) -> SerializedEntry {\n        let [batch_size, _n_classes] = input.outputs.dims();\n\n        let targets = input.targets.clone().to_device(&B::Device::default());\n\n        let outputs = input\n            .outputs\n            .clone()\n            .argsort_descending(1)\n            .narrow(1, 0, self.k)\n            .to_device(&B::Device::default())\n            .reshape([batch_size, self.k]);\n\n        let (targets, num_pad) = match self.pad_token {\n            Some(pad_token) => {\n                // we ignore the samples where the target is equal to the pad token\n                let mask = targets.clone().equal_elem(pad_token as i64);\n                let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();\n                (targets.clone().mask_fill(mask, -1_i64), num_pad)\n            }\n            None => (targets.clone(), 0_f64),\n        };\n\n        let accuracy = targets\n            .reshape([batch_size, 1])\n            .repeat_dim(1, self.k)\n            .equal(outputs)\n            .int()\n            .sum()\n            .into_scalar()\n            .elem::<f64>()\n            / (batch_size as f64 - num_pad);\n\n        self.state.update(\n            100.0 * accuracy,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn clear(&mut self) {\n        self.state.reset()\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for TopKAccuracyMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    #[test]\n    fn test_accuracy_without_padding() {\n        let device = Default::default();\n        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);\n        let input = TopKAccuracyInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.2, 0.8], // 2, 1\n                    [1.0, 2.0, 0.5], // 1, 0\n                    [0.4, 0.1, 0.2], // 0, 2\n                    [0.6, 0.7, 0.2], // 1, 0\n                ],\n                &device,\n            ),\n            Tensor::from_data([2, 2, 1, 1], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    #[test]\n    fn test_accuracy_with_padding() {\n        let device = Default::default();\n        let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);\n        let input = TopKAccuracyInput::new(\n            Tensor::from_data(\n                [\n                    [0.0, 0.2, 0.8, 0.0], // 2, 1\n                    [1.0, 2.0, 0.5, 0.0], // 1, 0\n                    [0.4, 0.1, 0.2, 0.0], // 0, 2\n                    [0.6, 0.7, 0.2, 0.0], // 1, 0\n                    [0.0, 0.1, 0.2, 5.0], // Predicted padding should not count\n                    [0.0, 0.1, 0.2, 0.0], // Error on padding should not count\n                    [0.6, 0.0, 0.2, 0.0], // Error on padding should not count\n                ],\n                &device,\n            ),\n            Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),\n        );\n\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    #[test]\n    fn test_parameterized_unique_name() {\n        let metric_a = TopKAccuracyMetric::<TestBackend>::new(2);\n        let metric_b = TopKAccuracyMetric::<TestBackend>::new(1);\n        let metric_c = TopKAccuracyMetric::<TestBackend>::new(2);\n\n        assert_ne!(metric_a.name(), metric_b.name());\n        assert_eq!(metric_a.name(), metric_c.name());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dice.rs",
    "content": "use crate::metric::{MetricAttributes, MetricName, SerializedEntry};\n\nuse super::super::{\n    Metric, MetricMetadata,\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::{ElementConversion, Int, s},\n};\nuse core::marker::PhantomData;\n\n/// Input type for the [DiceMetric].\n///\n/// # Type Parameters\n/// - `B`: Backend type.\n/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).\npub struct DiceInput<B: Backend, const D: usize = 4> {\n    /// Model outputs (predictions), as a tensor.\n    outputs: Tensor<B, D, Int>,\n    /// Ground truth targets, as a tensor.\n    targets: Tensor<B, D, Int>,\n}\n\nimpl<B: Backend, const D: usize> DiceInput<B, D> {\n    /// Creates a new DiceInput with the given outputs and targets.\n    ///\n    /// Inputs are expected to have the dimensions `[B, C, ...]`\n    /// where `B` is the batch size, `C` is the number of classes,\n    /// and `...` represents additional dimensions (e.g., height, width for images).\n    ///\n    /// If `C` is more than 1, the first class (index 0) is considered the background.\n    /// Additionally, one-hot encoding is the responsibility of the caller.\n    ///\n    /// # Arguments\n    /// - `outputs`: The model outputs as a tensor.\n    /// - `targets`: The ground truth targets as a tensor.\n    ///\n    /// # Returns\n    /// A new instance of `DiceInput`.\n    ///\n    ///  # Panics\n    /// - If `D` is less than 3.\n    /// - If `outputs` and `targets` do not have the same dimensions.\n    /// - If `outputs` or `targets` do not have exactly `D` dimensions.\n    /// - If `outputs` and `targets` do not have the same shape.\n    pub fn new(outputs: Tensor<B, D, Int>, targets: Tensor<B, D, Int>) -> Self {\n        assert!(D >= 3, \"DiceInput requires at least 3 dimensions.\");\n        assert!(\n            outputs.dims() == targets.dims(),\n            \"Outputs and targets must have the same dimensions. Got {:?} and {:?}\",\n            outputs.dims(),\n            targets.dims()\n        );\n        Self { outputs, targets }\n    }\n}\n\n/// Configuration for the [DiceMetric].\n#[derive(Debug, Clone, Copy)]\npub struct DiceMetricConfig {\n    /// Epsilon value to avoid division by zero.\n    pub epsilon: f64,\n    /// Whether to include the background class in the metric calculation.\n    /// The background is assumed to be the first class (index 0).\n    /// if `true`, will panic if there are fewer than 2 classes.\n    pub include_background: bool,\n}\n\nimpl Default for DiceMetricConfig {\n    fn default() -> Self {\n        Self {\n            epsilon: 1e-7,\n            include_background: false,\n        }\n    }\n}\n\n/// The Dice-Sorenson coefficient (DSC) for evaluating overlap between two binary masks.\n/// The DSC is defined as:\n/// `DSC = 2 * (|X ∩ Y|) / (|X| + |Y|)`\n/// where `X` is the model output and `Y` is the ground truth target.\n///\n///  # Type Parameters\n/// - `B`: Backend type.\n/// - `D`: Number of dimensions. Should be more than, or equal to 3 (default 4).\n#[derive(Default, Clone)]\npub struct DiceMetric<B: Backend, const D: usize = 4> {\n    name: MetricName,\n    /// Internal state for numeric metric aggregation.\n    state: NumericMetricState,\n    /// Marker for backend type.\n    _b: PhantomData<B>,\n    /// Configuration for the metric.\n    config: DiceMetricConfig,\n}\n\nimpl<B: Backend, const D: usize> DiceMetric<B, D> {\n    /// Creates a new Dice metric instance with default config.\n    pub fn new() -> Self {\n        Self::with_config(DiceMetricConfig::default())\n    }\n\n    /// Creates a new Dice metric with a custom config.\n    pub fn with_config(config: DiceMetricConfig) -> Self {\n        let name = MetricName::new(format!(\"{D}D Dice Metric\"));\n        assert!(D >= 3, \"DiceMetric requires at least 3 dimensions.\");\n        Self {\n            name,\n            config,\n            ..Default::default()\n        }\n    }\n}\n\nimpl<B: Backend, const D: usize> Metric for DiceMetric<B, D> {\n    type Input = DiceInput<B, D>;\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        // Dice coefficient: 2 * (|X ∩ Y|) / (|X| + |Y|)\n        if item.outputs.dims() != item.targets.dims() {\n            panic!(\n                \"Outputs and targets must have the same dimensions. Got {:?} and {:?}\",\n                item.outputs.dims(),\n                item.targets.dims()\n            );\n        }\n\n        let dims = item.outputs.dims();\n        let batch_size = dims[0];\n        let n_classes = dims[1];\n\n        let mut outputs = item.outputs.clone();\n        let mut targets = item.targets.clone();\n\n        if !self.config.include_background && n_classes > 1 {\n            // If not including background, we can ignore the first class\n            outputs = outputs.slice(s![.., 1..]);\n            targets = targets.slice(s![.., 1..]);\n        } else if self.config.include_background && n_classes < 2 {\n            // If including background, we need at least 2 classes\n            panic!(\"Dice metric requires at least 2 classes when including background.\");\n        }\n\n        let intersection = (outputs.clone() * targets.clone()).sum();\n        let outputs_sum = outputs.sum();\n        let targets_sum = targets.sum();\n\n        // Convert to f64\n        let intersection_val = intersection.into_scalar().elem::<f64>();\n        let outputs_sum_val = outputs_sum.into_scalar().elem::<f64>();\n        let targets_sum_val = targets_sum.into_scalar().elem::<f64>();\n\n        // Use epsilon from config\n        let epsilon = self.config.epsilon;\n        let dice =\n            (2.0 * intersection_val + epsilon) / (outputs_sum_val + targets_sum_val + epsilon);\n\n        self.state.update(\n            dice,\n            batch_size,\n            FormatOptions::new(self.name()).precision(4),\n        )\n    }\n\n    /// Clears the metric state.\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        crate::metric::NumericAttributes {\n            unit: None,\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend, const D: usize> crate::metric::Numeric for DiceMetric<B, D> {\n    fn value(&self) -> crate::metric::NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> crate::metric::NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, metric::Numeric};\n    use burn_core::tensor::{Shape, Tensor};\n\n    #[test]\n    fn test_dice_perfect_overlap() {\n        let device = Default::default();\n        let mut metric = DiceMetric::<TestBackend, 4>::new();\n        let input = DiceInput::new(\n            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),\n            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!((metric.value().current() - 1.0).abs() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_no_overlap() {\n        let device = Default::default();\n        let mut metric = DiceMetric::<TestBackend, 4>::new();\n        let input = DiceInput::new(\n            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),\n            Tensor::from_data([[[[0, 1], [0, 1]]]], &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!(metric.value().current() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_partial_overlap() {\n        let device = Default::default();\n        let mut metric = DiceMetric::<TestBackend, 4>::new();\n        let input = DiceInput::new(\n            Tensor::from_data([[[[1, 1], [0, 0]]]], &device),\n            Tensor::from_data([[[[1, 0], [1, 0]]]], &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        // intersection = 1, sum = 2+2=4, dice = 2*1/4 = 0.5\n        assert!((metric.value().current() - 0.5).abs() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_empty_masks() {\n        let device = Default::default();\n        let mut metric = DiceMetric::<TestBackend, 4>::new();\n        let input = DiceInput::new(\n            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),\n            Tensor::from_data([[[[0, 0], [0, 0]]]], &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!((metric.value().current() - 1.0).abs() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_no_background() {\n        let device = Default::default();\n        let mut metric = DiceMetric::<TestBackend, 4>::new();\n        let input = DiceInput::new(\n            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),\n            Tensor::ones(Shape::new([1, 1, 2, 2]), &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!((metric.value().current() - 1.0).abs() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_with_background() {\n        let device = Default::default();\n        let config = DiceMetricConfig {\n            epsilon: 1e-7,\n            include_background: true,\n        };\n        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);\n        let input = DiceInput::new(\n            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),\n            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!((metric.value().current() - 1.0).abs() < 1e-6);\n    }\n\n    #[test]\n    fn test_dice_ignored_background() {\n        let device = Default::default();\n        let config = DiceMetricConfig {\n            epsilon: 1e-7,\n            include_background: false,\n        };\n        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);\n        let input = DiceInput::new(\n            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),\n            Tensor::ones(Shape::new([1, 2, 2, 2]), &device),\n        );\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n        assert!((metric.value().current() - 1.0).abs() < 1e-6);\n    }\n\n    #[test]\n    #[should_panic(expected = \"DiceInput requires at least 3 dimensions.\")]\n    fn test_invalid_input_dimensions() {\n        let device = Default::default();\n        // D = 2, should panic\n        let _ = DiceInput::<TestBackend, 2>::new(\n            Tensor::from_data([[0.0, 0.0]], &device),\n            Tensor::from_data([[0.0, 0.0]], &device),\n        );\n    }\n\n    #[test]\n    #[should_panic(\n        expected = \"Outputs and targets must have the same dimensions. Got [1, 1, 2, 2] and [1, 1, 2, 3]\"\n    )]\n    fn test_mismatched_shape() {\n        let device = Default::default();\n        // shapes differ\n        let _ = DiceInput::<TestBackend, 4>::new(\n            Tensor::from_data([[[[0.0; 2]; 2]; 1]; 1], &device),\n            Tensor::from_data([[[[0.0; 3]; 2]; 1]; 1], &device),\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"Dice metric requires at least 2 classes when including background.\")]\n    fn test_include_background_panic() {\n        let device = Default::default();\n        let config = DiceMetricConfig {\n            epsilon: 1e-7,\n            include_background: true,\n        };\n        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);\n        let input = DiceInput::new(\n            Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),\n            Tensor::from_data([[[[1.0; 2]; 1]; 1]; 1], &device),\n        );\n        // n_classes = 2, should not panic\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let config = DiceMetricConfig {\n            epsilon: 1e-7,\n            include_background: true,\n        };\n        let mut metric = DiceMetric::<TestBackend, 4>::with_config(config);\n        let input = DiceInput::new(\n            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),\n            Tensor::from_data([[[[1.0; 1]; 1]; 1]; 1], &device),\n        );\n        // n_classes = 1, should panic\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dists/l2pool.rs",
    "content": "//! L2 Pooling layer for DISTS.\n//!\n//! L2 Pooling applies a Hanning window filter and computes the L2 norm\n//! across the pooling window. This is used in DISTS instead of MaxPooling.\n\nuse burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_nn::PaddingConfig2d;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\n\n/// L2 Pooling layer configuration.\n#[derive(Debug, Clone)]\npub struct L2Pool2dConfig {\n    /// Kernel size for pooling\n    pub kernel_size: usize,\n    /// Stride for pooling\n    pub stride: usize,\n    /// Padding for pooling\n    pub padding: usize,\n}\n\nimpl Default for L2Pool2dConfig {\n    fn default() -> Self {\n        Self {\n            kernel_size: 5,\n            stride: 2,\n            padding: 2,\n        }\n    }\n}\n\nimpl L2Pool2dConfig {\n    /// Create a new L2Pool2d configuration.\n    #[allow(dead_code)]\n    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {\n        Self {\n            kernel_size,\n            stride,\n            padding,\n        }\n    }\n\n    /// Initialize the L2Pool2d layer.\n    pub fn init<B: Backend>(&self, channels: usize, device: &B::Device) -> L2Pool2d<B> {\n        L2Pool2d::new(\n            channels,\n            self.kernel_size,\n            self.stride,\n            self.padding,\n            device,\n        )\n    }\n}\n\n/// L2 Pooling layer.\n///\n/// Applies a 2D Hanning window filter followed by L2 normalization.\n/// This provides smoother downsampling compared to MaxPooling.\n#[derive(Module, Debug)]\npub struct L2Pool2d<B: Backend> {\n    /// Depthwise convolution with Hanning kernel\n    conv: Conv2d<B>,\n}\n\nimpl<B: Backend> L2Pool2d<B> {\n    /// Create a new L2Pool2d layer with Hanning window kernel.\n    pub fn new(\n        channels: usize,\n        kernel_size: usize,\n        stride: usize,\n        padding: usize,\n        device: &B::Device,\n    ) -> Self {\n        // Create Hanning kernel\n        let kernel = Self::create_hanning_kernel(channels, kernel_size, device);\n\n        // Create depthwise convolution (groups = channels)\n        let mut conv = Conv2dConfig::new([channels, channels], [kernel_size, kernel_size])\n            .with_stride([stride, stride])\n            .with_padding(PaddingConfig2d::Explicit(\n                padding, padding, padding, padding,\n            ))\n            .with_groups(channels)\n            .with_bias(false)\n            .init(device);\n\n        // Set the kernel weights to Hanning window\n        conv.weight = burn::module::Param::from_tensor(kernel);\n\n        Self { conv }\n    }\n\n    /// Create a Hanning kernel for depthwise convolution.\n    /// Output shape: [channels, 1, kernel_size, kernel_size]\n    fn create_hanning_kernel<B2: Backend>(\n        channels: usize,\n        kernel_size: usize,\n        device: &B2::Device,\n    ) -> Tensor<B2, 4> {\n        // Create 1D Hanning window\n        let mut hanning_1d = Vec::with_capacity(kernel_size);\n        for i in 0..kernel_size {\n            let n = i as f32;\n            let n_minus_1 = (kernel_size - 1) as f32;\n            let value = if n_minus_1 == 0.0 {\n                1.0\n            } else {\n                0.5 * (1.0 - (2.0 * std::f32::consts::PI * n / n_minus_1).cos())\n            };\n            hanning_1d.push(value);\n        }\n\n        // Create 2D Hanning window by outer product\n        let mut hanning_2d = Vec::with_capacity(kernel_size * kernel_size);\n        let mut sum = 0.0;\n        for i in 0..kernel_size {\n            for j in 0..kernel_size {\n                let value = hanning_1d[i] * hanning_1d[j];\n                hanning_2d.push(value);\n                sum += value;\n            }\n        }\n\n        // Normalize\n        for v in hanning_2d.iter_mut() {\n            *v /= sum;\n        }\n\n        // Create tensor of shape [1, 1, kernel_size, kernel_size]\n        let kernel_single = Tensor::<B2, 1>::from_floats(hanning_2d.as_slice(), device).reshape([\n            1,\n            1,\n            kernel_size,\n            kernel_size,\n        ]);\n\n        // Expand to [channels, 1, kernel_size, kernel_size]\n        kernel_single.repeat_dim(0, channels)\n    }\n\n    /// Apply L2 pooling to the input tensor.\n    ///\n    /// # Arguments\n    ///\n    /// * `x` - Input tensor of shape `[batch, channels, height, width]`\n    ///\n    /// # Returns\n    ///\n    /// Pooled tensor with reduced spatial dimensions.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        // Square the input\n        let x_sq = x.clone().mul(x);\n\n        // Apply depthwise convolution with Hanning kernel\n        let pooled = self.conv.forward(x_sq);\n\n        // Take square root for L2 norm\n        // Add small epsilon to avoid sqrt of negative numbers due to numerical errors\n        pooled.clamp_min(1e-10).sqrt()\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dists/metric.rs",
    "content": "//! DISTS (Deep Image Structure and Texture Similarity) metric.\n//!\n//! DISTS is a full-reference image quality assessment metric that combines\n//! structure and texture similarity using deep features from VGG16.\n//!\n//! Reference: \"Image Quality Assessment: Unifying Structure and Texture Similarity\"\n//! https://arxiv.org/abs/2004.07728\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_nn::loss::Reduction;\n\nuse super::vgg16_l2pool::Vgg16L2PoolExtractor;\n\n/// Channel counts for each stage: [input, stage1, stage2, stage3, stage4, stage5]\nconst CHANNELS: [usize; 6] = [3, 64, 128, 256, 512, 512];\n\n/// Small constant for numerical stability in structure similarity.\nconst C1: f32 = 1e-6;\n\n/// Small constant for numerical stability in texture similarity.\nconst C2: f32 = 1e-6;\n\n/// ImageNet normalization constants.\nconst IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];\nconst IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];\n\n/// Image normalizer with pre-initialized mean and std tensors.\n///\n/// This struct holds the mean and std tensors for normalization,\n/// avoiding the need to create them on each forward pass.\n#[derive(Module, Debug)]\npub struct Normalizer<B: Backend> {\n    /// Mean tensor of shape [1, 3, 1, 1] for broadcasting.\n    pub mean: Tensor<B, 4>,\n    /// Std tensor of shape [1, 3, 1, 1] for broadcasting.\n    pub std: Tensor<B, 4>,\n}\n\nimpl<B: Backend> Normalizer<B> {\n    /// Create a new ImageNet normalizer.\n    pub fn imagenet(device: &B::Device) -> Self {\n        // Shape: [1, 3, 1, 1] for broadcasting over [batch, channels, height, width]\n        let mean = Tensor::from_floats(\n            [[\n                [[IMAGENET_MEAN[0]]],\n                [[IMAGENET_MEAN[1]]],\n                [[IMAGENET_MEAN[2]]],\n            ]],\n            device,\n        );\n        let std = Tensor::from_floats(\n            [[\n                [[IMAGENET_STD[0]]],\n                [[IMAGENET_STD[1]]],\n                [[IMAGENET_STD[2]]],\n            ]],\n            device,\n        );\n        Self { mean, std }\n    }\n\n    /// Normalize a tensor: (x - mean) / std\n    pub fn normalize(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        x.sub(self.mean.clone()).div(self.std.clone())\n    }\n}\n\n/// Configuration for DISTS metric.\n#[derive(Config, Debug)]\npub struct DistsConfig {\n    /// Whether to apply ImageNet normalization to input images.\n    #[config(default = true)]\n    pub normalize: bool,\n}\n\nimpl DistsConfig {\n    /// Initialize a DISTS module with default weights.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Dists<B> {\n        let total_channels: usize = CHANNELS.iter().sum();\n\n        // Initialize alpha and beta with constant value 0.1 for all channels\n        let alpha_data: Vec<f32> = (0..total_channels).map(|_| 0.1).collect();\n        let beta_data: Vec<f32> = (0..total_channels).map(|_| 0.1).collect();\n\n        let normalizer = if self.normalize {\n            Some(Normalizer::imagenet(device))\n        } else {\n            None\n        };\n\n        Dists {\n            extractor: Vgg16L2PoolExtractor::new(device),\n            alpha: Param::from_tensor(Tensor::from_floats(alpha_data.as_slice(), device)),\n            beta: Param::from_tensor(Tensor::from_floats(beta_data.as_slice(), device)),\n            normalizer,\n        }\n    }\n\n    /// Initialize a DISTS module with pretrained weights.\n    pub fn init_pretrained<B: Backend>(&self, device: &B::Device) -> Dists<B> {\n        let dists = self.init(device);\n        super::weights::load_pretrained_weights(dists)\n    }\n}\n\n/// DISTS (Deep Image Structure and Texture Similarity) metric.\n///\n/// Computes perceptual similarity between two images by combining\n/// structure similarity (based on spatial means) and texture similarity\n/// (based on variances and covariances) across VGG16 feature maps.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_train::metric::vision::{DistsConfig, Reduction};\n///\n/// let device = Default::default();\n/// let dists = DistsConfig::new().init_pretrained(&device);\n///\n/// let img1: Tensor<B, 4> = /* [batch, 3, H, W] */;\n/// let img2: Tensor<B, 4> = /* [batch, 3, H, W] */;\n///\n/// let distance = dists.forward(img1, img2, Reduction::Mean);\n/// ```\n#[derive(Module, Debug)]\n#[module(custom_display)]\npub struct Dists<B: Backend> {\n    /// VGG16 feature extractor with L2 pooling\n    pub(crate) extractor: Vgg16L2PoolExtractor<B>,\n    /// Learned weights for structure similarity (per channel)\n    pub(crate) alpha: Param<Tensor<B, 1>>,\n    /// Learned weights for texture similarity (per channel)\n    pub(crate) beta: Param<Tensor<B, 1>>,\n    /// Optional normalizer for input preprocessing\n    pub(crate) normalizer: Option<Normalizer<B>>,\n}\n\nimpl<B: Backend> ModuleDisplay for Dists<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        content\n            .add(\"backbone\", &\"VGG16-L2Pool\".to_string())\n            .add(\"normalize\", &self.normalizer.is_some().to_string())\n            .optional()\n    }\n}\n\nimpl<B: Backend> Dists<B> {\n    /// Compute DISTS distance with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - First image tensor of shape `[batch, 3, H, W]`\n    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`\n    /// * `reduction` - How to reduce the output (Mean, Sum, or Auto)\n    ///\n    /// # Returns\n    ///\n    /// Scalar tensor of shape `[1]`.\n    pub fn forward(\n        &self,\n        input: Tensor<B, 4>,\n        target: Tensor<B, 4>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let distance = self.forward_no_reduction(input, target);\n\n        match reduction {\n            Reduction::Mean | Reduction::Auto | Reduction::BatchMean => distance.mean(),\n            Reduction::Sum => distance.sum(),\n        }\n    }\n\n    /// Compute DISTS distance without reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - First image tensor of shape `[batch, 3, H, W]`\n    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`\n    ///\n    /// # Returns\n    ///\n    /// Per-sample distance tensor of shape `[batch]`.\n    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {\n        let [batch, _, _, _] = input.dims();\n\n        // Preprocess inputs\n        let (input, target) = self.preprocess(input, target);\n\n        // Extract features from both images\n        let feats_x = self.extractor.forward(input);\n        let feats_y = self.extractor.forward(target);\n\n        // Get alpha and beta weights\n        let alpha = self.alpha.val();\n        let beta = self.beta.val();\n\n        // Compute weighted sum of alpha and beta for normalization\n        let alpha_sum = alpha.clone().sum();\n        let beta_sum = beta.clone().sum();\n\n        let device = feats_x[0].device();\n\n        // Initialize accumulators\n        let mut structure_dist = Tensor::<B, 1>::zeros([batch], &device);\n        let mut texture_dist = Tensor::<B, 1>::zeros([batch], &device);\n\n        let mut channel_offset = 0;\n\n        // Compute similarity for each stage\n        for (feat_x, feat_y) in feats_x.iter().zip(feats_y.iter()) {\n            let [_b, c, _h, _w] = feat_x.dims();\n\n            // Get alpha and beta for this stage\n            let alpha_stage = alpha.clone().narrow(0, channel_offset, c);\n            let beta_stage = beta.clone().narrow(0, channel_offset, c);\n\n            // Compute structure and texture similarity for this stage\n            let (s_dist, t_dist) = self.compute_stage_similarity(\n                feat_x.clone(),\n                feat_y.clone(),\n                alpha_stage,\n                beta_stage,\n            );\n\n            structure_dist = structure_dist.add(s_dist);\n            texture_dist = texture_dist.add(t_dist);\n\n            channel_offset += c;\n        }\n\n        // Normalize by sum of weights\n        structure_dist = structure_dist.div(alpha_sum);\n        texture_dist = texture_dist.div(beta_sum);\n\n        // DISTS = 1 - (structure_similarity + texture_similarity)\n        // Since we computed distances (1 - similarity), we return the sum\n        structure_dist.add(texture_dist)\n    }\n\n    /// Compute structure and texture similarity for a single stage.\n    fn compute_stage_similarity(\n        &self,\n        feat_x: Tensor<B, 4>,\n        feat_y: Tensor<B, 4>,\n        alpha: Tensor<B, 1>,\n        beta: Tensor<B, 1>,\n    ) -> (Tensor<B, 1>, Tensor<B, 1>) {\n        let [batch, channels, height, width] = feat_x.dims();\n        let device = feat_x.device();\n\n        // Reshape to [batch, channels, H*W] for easier computation\n        let x = feat_x.reshape([batch, channels, height * width]);\n        let y = feat_y.reshape([batch, channels, height * width]);\n\n        // Compute means: [batch, channels] (squeeze after mean_dim to remove the reduced dimension)\n        let mean_x = x.clone().mean_dim(2).squeeze_dim::<2>(2);\n        let mean_y = y.clone().mean_dim(2).squeeze_dim::<2>(2);\n\n        // Compute structure similarity: (2*mean_x*mean_y + c1) / (mean_x^2 + mean_y^2 + c1)\n        let c1 = Tensor::<B, 2>::full([batch, channels], C1, &device);\n        let structure_sim = mean_x\n            .clone()\n            .mul(mean_y.clone())\n            .mul_scalar(2.0)\n            .add(c1.clone())\n            .div(\n                mean_x\n                    .clone()\n                    .mul(mean_x.clone())\n                    .add(mean_y.clone().mul(mean_y.clone()))\n                    .add(c1),\n            );\n\n        // Compute variances and covariance\n        // var_x = E[x^2] - E[x]^2, clamped at 0 for numerical stability\n        let var_x = x\n            .clone()\n            .mul(x.clone())\n            .mean_dim(2)\n            .squeeze_dim::<2>(2)\n            .sub(mean_x.clone().mul(mean_x.clone()))\n            .clamp_min(0.0);\n        let var_y = y\n            .clone()\n            .mul(y.clone())\n            .mean_dim(2)\n            .squeeze_dim::<2>(2)\n            .sub(mean_y.clone().mul(mean_y.clone()))\n            .clamp_min(0.0);\n\n        // cov_xy = E[xy] - E[x]E[y]\n        let cov_xy = x\n            .mul(y)\n            .mean_dim(2)\n            .squeeze_dim::<2>(2)\n            .sub(mean_x.clone().mul(mean_y.clone()));\n\n        // Compute texture similarity: (2*cov_xy + c2) / (var_x + var_y + c2)\n        let c2 = Tensor::<B, 2>::full([batch, channels], C2, &device);\n        let texture_sim = cov_xy\n            .mul_scalar(2.0)\n            .add(c2.clone())\n            .div(var_x.add(var_y).add(c2));\n\n        // Convert similarity to distance: 1 - similarity\n        let structure_dist = Tensor::<B, 2>::ones([batch, channels], &device).sub(structure_sim);\n        let texture_dist = Tensor::<B, 2>::ones([batch, channels], &device).sub(texture_sim);\n\n        // Apply weights: [batch, channels] * [channels] -> [batch, channels]\n        // Then sum over channels -> [batch]\n        let weighted_structure = structure_dist\n            .mul(alpha.unsqueeze_dim::<2>(0))\n            .sum_dim(1)\n            .squeeze_dim::<1>(1);\n        let weighted_texture = texture_dist\n            .mul(beta.unsqueeze_dim::<2>(0))\n            .sum_dim(1)\n            .squeeze_dim::<1>(1);\n\n        (weighted_structure, weighted_texture)\n    }\n\n    /// Preprocess input images using the configured normalizer.\n    fn preprocess(\n        &self,\n        input: Tensor<B, 4>,\n        target: Tensor<B, 4>,\n    ) -> (Tensor<B, 4>, Tensor<B, 4>) {\n        match &self.normalizer {\n            Some(normalizer) => {\n                let input = normalizer.normalize(input);\n                let target = normalizer.normalize(target);\n                (input, target)\n            }\n            None => (input, target),\n        }\n    }\n}\n\n// =============================================================================\n// Tests\n// =============================================================================\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem};\n    use burn_ndarray::NdArray;\n\n    type TestBackend = NdArray<f32>;\n    type FT = FloatElem<TestBackend>;\n    type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n\n    #[test]\n    fn test_dists_identical_images_zero_distance() {\n        let device = Default::default();\n        // Use random image instead of constant to avoid numerical edge cases\n        let image = TestTensor::<4>::random(\n            [1, 3, 64, 64],\n            burn_core::tensor::Distribution::Uniform(0.0, 1.0),\n            &device,\n        );\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n        let distance = dists.forward(image.clone(), image, Reduction::Mean);\n\n        let expected = TensorData::from([0.0]);\n        distance\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    #[test]\n    fn test_dists_different_images_nonzero_distance() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n        let distance = dists.forward(image1, image2, Reduction::Mean);\n\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() > 1e-6,\n            \"DISTS should be != 0 for different images\"\n        );\n    }\n\n    #[test]\n    fn test_dists_symmetry() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n        let distance_forward = dists.forward(image1.clone(), image2.clone(), Reduction::Mean);\n        let distance_reverse = dists.forward(image2, image1, Reduction::Mean);\n\n        distance_forward\n            .into_data()\n            .assert_approx_eq::<FT>(&distance_reverse.into_data(), Tolerance::default());\n    }\n\n    #[test]\n    fn test_dists_batch_processing() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device);\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n        let distance = dists.forward(image1, image2, Reduction::Mean);\n\n        assert_eq!(distance.dims(), [1]);\n    }\n\n    #[test]\n    fn test_dists_no_reduction() {\n        let device = Default::default();\n\n        let batch_size = 4;\n        let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device);\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n        let distance = dists.forward_no_reduction(image1, image2);\n\n        assert_eq!(distance.dims(), [batch_size]);\n    }\n\n    #[test]\n    fn display_dists() {\n        let device = Default::default();\n        let dists: Dists<TestBackend> = DistsConfig::new().init(&device);\n\n        let display_str = format!(\"{dists}\");\n        assert!(display_str.contains(\"Dists\"));\n        assert!(display_str.contains(\"VGG16-L2Pool\"));\n    }\n\n    // =========================================================================\n    // Pretrained Weights Tests (requires network)\n    // =========================================================================\n\n    /// Test DISTS pretrained weights download and loading.\n    #[test]\n    fn test_dists_pretrained() {\n        let device = Default::default();\n\n        let dists: Dists<TestBackend> = DistsConfig::new().init_pretrained(&device);\n\n        // Test with identical images - should be ~0\n        // Use random image to avoid numerical edge cases with constant images\n        let image = TestTensor::<4>::random(\n            [1, 3, 64, 64],\n            burn_core::tensor::Distribution::Uniform(0.0, 1.0),\n            &device,\n        );\n        let distance = dists.forward(image.clone(), image, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() < 1e-5,\n            \"Pretrained DISTS should be ~0 for identical images, got {}\",\n            distance_value\n        );\n\n        // Test with different images - should be positive\n        let image1 = TestTensor::<4>::random(\n            [1, 3, 64, 64],\n            burn_core::tensor::Distribution::Uniform(0.0, 0.3),\n            &device,\n        );\n        let image2 = TestTensor::<4>::random(\n            [1, 3, 64, 64],\n            burn_core::tensor::Distribution::Uniform(0.7, 1.0),\n            &device,\n        );\n        let distance = dists.forward(image1, image2, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value > 0.0,\n            \"Pretrained DISTS should be > 0 for different images\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dists/mod.rs",
    "content": "//! DISTS (Deep Image Structure and Texture Similarity) metric.\n//!\n//! This module implements DISTS, a full-reference image quality assessment metric\n//! that combines structure and texture similarity using deep features.\n//!\n//! Reference: \"Image Quality Assessment: Unifying Structure and Texture Similarity\"\n//! https://arxiv.org/abs/2004.07728\n\nmod l2pool;\nmod metric;\nmod vgg16_l2pool;\nmod weights;\n\npub use metric::{Dists, DistsConfig};\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dists/vgg16_l2pool.rs",
    "content": "//! VGG16 feature extractor with L2 Pooling for DISTS.\n//!\n//! This module implements the VGG16 backbone used in DISTS,\n//! with L2Pooling replacing MaxPooling for smoother feature extraction.\n\nuse burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::relu;\nuse burn::tensor::backend::Backend;\nuse burn_nn::PaddingConfig2d;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\n\nuse super::l2pool::{L2Pool2d, L2Pool2dConfig};\n\n/// VGG16 feature extractor with L2 Pooling for DISTS.\n///\n/// Extracts features from 5 stages of VGG16, using L2Pooling\n/// instead of MaxPooling for smoother downsampling.\n///\n/// Output channels per stage: [64, 128, 256, 512, 512]\n#[derive(Module, Debug)]\npub struct Vgg16L2PoolExtractor<B: Backend> {\n    // Stage 1: 2 conv layers, 64 channels\n    pub(crate) conv1_1: Conv2d<B>,\n    pub(crate) conv1_2: Conv2d<B>,\n    pub(crate) pool1: L2Pool2d<B>,\n\n    // Stage 2: 2 conv layers, 128 channels\n    pub(crate) conv2_1: Conv2d<B>,\n    pub(crate) conv2_2: Conv2d<B>,\n    pub(crate) pool2: L2Pool2d<B>,\n\n    // Stage 3: 3 conv layers, 256 channels\n    pub(crate) conv3_1: Conv2d<B>,\n    pub(crate) conv3_2: Conv2d<B>,\n    pub(crate) conv3_3: Conv2d<B>,\n    pub(crate) pool3: L2Pool2d<B>,\n\n    // Stage 4: 3 conv layers, 512 channels\n    pub(crate) conv4_1: Conv2d<B>,\n    pub(crate) conv4_2: Conv2d<B>,\n    pub(crate) conv4_3: Conv2d<B>,\n    pub(crate) pool4: L2Pool2d<B>,\n\n    // Stage 5: 3 conv layers, 512 channels\n    pub(crate) conv5_1: Conv2d<B>,\n    pub(crate) conv5_2: Conv2d<B>,\n    pub(crate) conv5_3: Conv2d<B>,\n}\n\nimpl<B: Backend> Vgg16L2PoolExtractor<B> {\n    /// Create a new VGG16 feature extractor with L2 Pooling.\n    pub fn new(device: &B::Device) -> Self {\n        let pool_config = L2Pool2dConfig::default();\n\n        Self {\n            // Stage 1\n            conv1_1: Conv2dConfig::new([3, 64], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv1_2: Conv2dConfig::new([64, 64], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            pool1: pool_config.init(64, device),\n\n            // Stage 2\n            conv2_1: Conv2dConfig::new([64, 128], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv2_2: Conv2dConfig::new([128, 128], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            pool2: pool_config.init(128, device),\n\n            // Stage 3\n            conv3_1: Conv2dConfig::new([128, 256], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv3_2: Conv2dConfig::new([256, 256], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv3_3: Conv2dConfig::new([256, 256], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            pool3: pool_config.init(256, device),\n\n            // Stage 4\n            conv4_1: Conv2dConfig::new([256, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv4_2: Conv2dConfig::new([512, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv4_3: Conv2dConfig::new([512, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            pool4: pool_config.init(512, device),\n\n            // Stage 5\n            conv5_1: Conv2dConfig::new([512, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv5_2: Conv2dConfig::new([512, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n            conv5_3: Conv2dConfig::new([512, 512], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .init(device),\n        }\n    }\n\n    /// Extract features from all 5 stages.\n    ///\n    /// # Arguments\n    ///\n    /// * `x` - Input tensor of shape `[batch, 3, height, width]`\n    ///\n    /// # Returns\n    ///\n    /// Vector of 6 feature tensors:\n    /// - Stage 0: Input image [batch, 3, H, W]\n    /// - Stage 1: After conv1 [batch, 64, H/2, W/2]\n    /// - Stage 2: After conv2 [batch, 128, H/4, W/4]\n    /// - Stage 3: After conv3 [batch, 256, H/8, W/8]\n    /// - Stage 4: After conv4 [batch, 512, H/16, W/16]\n    /// - Stage 5: After conv5 [batch, 512, H/32, W/32]\n    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {\n        let mut features = Vec::with_capacity(6);\n\n        // Stage 0: Input image\n        features.push(x.clone());\n\n        // Stage 1\n        let x = relu(self.conv1_1.forward(x));\n        let x = relu(self.conv1_2.forward(x));\n        features.push(x.clone());\n        let x = self.pool1.forward(x);\n\n        // Stage 2\n        let x = relu(self.conv2_1.forward(x));\n        let x = relu(self.conv2_2.forward(x));\n        features.push(x.clone());\n        let x = self.pool2.forward(x);\n\n        // Stage 3\n        let x = relu(self.conv3_1.forward(x));\n        let x = relu(self.conv3_2.forward(x));\n        let x = relu(self.conv3_3.forward(x));\n        features.push(x.clone());\n        let x = self.pool3.forward(x);\n\n        // Stage 4\n        let x = relu(self.conv4_1.forward(x));\n        let x = relu(self.conv4_2.forward(x));\n        let x = relu(self.conv4_3.forward(x));\n        features.push(x.clone());\n        let x = self.pool4.forward(x);\n\n        // Stage 5\n        let x = relu(self.conv5_1.forward(x));\n        let x = relu(self.conv5_2.forward(x));\n        let x = relu(self.conv5_3.forward(x));\n        features.push(x);\n\n        features\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/dists/weights.rs",
    "content": "//! Pretrained weights loading for DISTS.\n\nuse burn_core as burn;\n\nuse burn::tensor::backend::Backend;\nuse burn_std::network::downloader::download_file_as_bytes;\nuse burn_store::{ModuleSnapshot, PytorchStore};\nuse std::fs::{File, create_dir_all};\nuse std::io::Write;\nuse std::path::PathBuf;\n\nuse super::metric::Dists;\n\n/// URL for pretrained DISTS alpha/beta weights from the official repository.\n/// Reference: https://github.com/dingkeyan93/DISTS\nconst DISTS_WEIGHTS_URL: &str =\n    \"https://github.com/dingkeyan93/DISTS/raw/master/DISTS_pytorch/weights.pt\";\n\n/// URL for ImageNet pretrained VGG16 backbone weights from PyTorch.\nconst VGG16_IMAGENET_URL: &str = \"https://download.pytorch.org/models/vgg16-397923af.pth\";\n\n/// Get the cache directory for DISTS weights.\nfn get_cache_dir() -> PathBuf {\n    let cache_dir = dirs::cache_dir()\n        .expect(\"Could not get cache directory\")\n        .join(\"burn-dataset\")\n        .join(\"dists\");\n\n    if !cache_dir.exists() {\n        create_dir_all(&cache_dir).expect(\"Failed to create cache directory\");\n    }\n\n    cache_dir\n}\n\n/// Download file if not cached.\nfn download_if_needed(url: &str, cache_path: &PathBuf, message: &str) {\n    if !cache_path.exists() {\n        let bytes = download_file_as_bytes(url, message);\n        let mut file = File::create(cache_path).expect(\"Failed to create cache file\");\n        file.write_all(&bytes).expect(\"Failed to write weights\");\n    }\n}\n\n/// Download and load pretrained weights into a DISTS module.\n///\n/// This loads both:\n/// 1. ImageNet pretrained VGG16 backbone weights\n/// 2. DISTS trained alpha/beta weights\n///\n/// Weights are cached in the user's cache directory to avoid re-downloading.\n///\n/// # Arguments\n///\n/// * `dists` - The DISTS module to load weights into.\n///\n/// # Returns\n///\n/// The DISTS module with loaded pretrained weights.\npub fn load_pretrained_weights<B: Backend>(mut dists: Dists<B>) -> Dists<B> {\n    let cache_dir = get_cache_dir();\n\n    // Step 1: Download and load VGG16 ImageNet backbone weights\n    let vgg_cache_path = cache_dir.join(\"vgg16_backbone.pth\");\n    download_if_needed(\n        VGG16_IMAGENET_URL,\n        &vgg_cache_path,\n        \"Downloading VGG16 ImageNet weights for DISTS...\",\n    );\n\n    // Step 2: Download DISTS alpha/beta weights\n    let dists_cache_path = cache_dir.join(\"dists_weights.pt\");\n    download_if_needed(\n        DISTS_WEIGHTS_URL,\n        &dists_cache_path,\n        \"Downloading DISTS alpha/beta weights...\",\n    );\n\n    // Load VGG16 backbone weights first\n    dists = load_vgg16_backbone_weights(dists, &vgg_cache_path);\n\n    // Then load DISTS alpha/beta weights\n    dists = load_dists_weights(dists, &dists_cache_path);\n\n    dists\n}\n\n/// Load VGG16 ImageNet pretrained backbone weights.\nfn load_vgg16_backbone_weights<B: Backend>(mut dists: Dists<B>, cache_path: &PathBuf) -> Dists<B> {\n    let mut store = PytorchStore::from_file(cache_path)\n        .allow_partial(true)\n        .skip_enum_variants(true)\n        // VGG16 features.X -> extractor.convY_Z\n        .with_key_remapping(r\"^features\\.0\\.\", \"extractor.conv1_1.\")\n        .with_key_remapping(r\"^features\\.2\\.\", \"extractor.conv1_2.\")\n        .with_key_remapping(r\"^features\\.5\\.\", \"extractor.conv2_1.\")\n        .with_key_remapping(r\"^features\\.7\\.\", \"extractor.conv2_2.\")\n        .with_key_remapping(r\"^features\\.10\\.\", \"extractor.conv3_1.\")\n        .with_key_remapping(r\"^features\\.12\\.\", \"extractor.conv3_2.\")\n        .with_key_remapping(r\"^features\\.14\\.\", \"extractor.conv3_3.\")\n        .with_key_remapping(r\"^features\\.17\\.\", \"extractor.conv4_1.\")\n        .with_key_remapping(r\"^features\\.19\\.\", \"extractor.conv4_2.\")\n        .with_key_remapping(r\"^features\\.21\\.\", \"extractor.conv4_3.\")\n        .with_key_remapping(r\"^features\\.24\\.\", \"extractor.conv5_1.\")\n        .with_key_remapping(r\"^features\\.26\\.\", \"extractor.conv5_2.\")\n        .with_key_remapping(r\"^features\\.28\\.\", \"extractor.conv5_3.\");\n\n    let result = dists.load_from(&mut store);\n    if let Err(e) = result {\n        log::warn!(\"Some VGG16 backbone weights could not be loaded: {:?}\", e);\n    }\n\n    dists\n}\n\n/// Load DISTS trained alpha/beta weights.\nfn load_dists_weights<B: Backend>(mut dists: Dists<B>, cache_path: &PathBuf) -> Dists<B> {\n    let mut store = PytorchStore::from_file(cache_path)\n        .allow_partial(true)\n        .skip_enum_variants(true);\n\n    let result = dists.load_from(&mut store);\n    if let Err(e) = result {\n        log::warn!(\"Some DISTS weights could not be loaded: {:?}\", e);\n    }\n\n    dists\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/alexnet.rs",
    "content": "//! AlexNet feature extractor for LPIPS.\n\nuse burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::relu;\nuse burn::tensor::backend::Backend;\nuse burn_nn::PaddingConfig2d;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\n\n/// AlexNet feature extractor for LPIPS.\n///\n/// Extracts features from 5 layers:\n/// - conv1: 64 channels (after ReLU)\n/// - conv2: 192 channels (after ReLU)\n/// - conv3: 384 channels (after ReLU)\n/// - conv4: 256 channels (after ReLU)\n/// - conv5: 256 channels (after ReLU)\n#[derive(Module, Debug)]\npub struct AlexFeatureExtractor<B: Backend> {\n    /// Conv1: 3 -> 64, kernel 11x11, stride 4, padding 2\n    conv1: Conv2d<B>,\n    /// Conv2: 64 -> 192, kernel 5x5, stride 1, padding 2\n    conv2: Conv2d<B>,\n    /// Conv3: 192 -> 384, kernel 3x3, stride 1, padding 1\n    conv3: Conv2d<B>,\n    /// Conv4: 384 -> 256, kernel 3x3, stride 1, padding 1\n    conv4: Conv2d<B>,\n    /// Conv5: 256 -> 256, kernel 3x3, stride 1, padding 1\n    conv5: Conv2d<B>,\n}\n\nimpl<B: Backend> AlexFeatureExtractor<B> {\n    /// Create a new AlexNet feature extractor.\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            // Conv1: 3 -> 64, 11x11, stride 4, padding 2\n            conv1: Conv2dConfig::new([3, 64], [11, 11])\n                .with_stride([4, 4])\n                .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))\n                .with_bias(true)\n                .init(device),\n            // Conv2: 64 -> 192, 5x5, stride 1, padding 2\n            conv2: Conv2dConfig::new([64, 192], [5, 5])\n                .with_padding(PaddingConfig2d::Explicit(2, 2, 2, 2))\n                .with_bias(true)\n                .init(device),\n            // Conv3: 192 -> 384, 3x3, stride 1, padding 1\n            conv3: Conv2dConfig::new([192, 384], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .with_bias(true)\n                .init(device),\n            // Conv4: 384 -> 256, 3x3, stride 1, padding 1\n            conv4: Conv2dConfig::new([384, 256], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .with_bias(true)\n                .init(device),\n            // Conv5: 256 -> 256, 3x3, stride 1, padding 1\n            conv5: Conv2dConfig::new([256, 256], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .with_bias(true)\n                .init(device),\n        }\n    }\n\n    /// Extract features from 5 AlexNet layers.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {\n        let mut features = Vec::with_capacity(5);\n\n        // Slice 1: Conv1 + ReLU\n        let x = relu(self.conv1.forward(x));\n        features.push(x.clone());\n\n        // Slice 2: MaxPool + Conv2 + ReLU\n        let x = max_pool2d_alex(x);\n        let x = relu(self.conv2.forward(x));\n        features.push(x.clone());\n\n        // Slice 3: MaxPool + Conv3 + ReLU\n        let x = max_pool2d_alex(x);\n        let x = relu(self.conv3.forward(x));\n        features.push(x.clone());\n\n        // Slice 4: Conv4 + ReLU (no pooling)\n        let x = relu(self.conv4.forward(x));\n        features.push(x.clone());\n\n        // Slice 5: Conv5 + ReLU (no pooling)\n        let x = relu(self.conv5.forward(x));\n        features.push(x);\n\n        features\n    }\n}\n\n/// 3x3 max pooling with stride 2 (for AlexNet).\nfn max_pool2d_alex<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {\n    burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], false)\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/metric.rs",
    "content": "//! LPIPS (Learned Perceptual Image Patch Similarity) metric module.\n//!\n//! LPIPS measures perceptual similarity between images using deep features.\n//! Supports VGG16, AlexNet, and SqueezeNet as backbone networks.\n//!\n//! Reference: \"The Unreasonable Effectiveness of Deep Features as a Perceptual Metric\"\n//! <https://arxiv.org/abs/1801.03924>\n\nuse burn_core as burn;\n\nuse burn::config::Config;\nuse burn::module::{Content, DisplaySettings, Module, ModuleDisplay};\nuse burn::tensor::Tensor;\nuse burn::tensor::backend::Backend;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\nuse burn_nn::loss::Reduction;\n\nuse super::alexnet::AlexFeatureExtractor;\nuse super::squeezenet::SqueezeFeatureExtractor;\nuse super::vgg::VggFeatureExtractor;\n\n/// Network type for LPIPS.\n#[derive(Config, Debug, Copy, PartialEq, Eq)]\npub enum LpipsNet {\n    /// VGG16 network (default)\n    Vgg,\n    /// AlexNet network\n    Alex,\n    /// SqueezeNet network\n    Squeeze,\n}\n\n/// Configuration for [Lpips](Lpips) metric module.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_train::metric::vision::{LpipsConfig, LpipsNet};\n///\n/// // VGG (default)\n/// let lpips_vgg = LpipsConfig::new().init(&device);\n///\n/// // AlexNet\n/// let lpips_alex = LpipsConfig::new()\n///     .with_net(LpipsNet::Alex)\n///     .init(&device);\n///\n/// // SqueezeNet\n/// let lpips_squeeze = LpipsConfig::new()\n///     .with_net(LpipsNet::Squeeze)\n///     .init(&device);\n/// ```\n#[derive(Config, Debug)]\npub struct LpipsConfig {\n    /// Network type for feature extraction.\n    #[config(default = \"LpipsNet::Vgg\")]\n    pub net: LpipsNet,\n\n    /// Whether to normalize input images to [-1, 1] range.\n    /// Set to true if input is in [0, 1] range.\n    #[config(default = true)]\n    pub normalize: bool,\n}\n\nimpl LpipsConfig {\n    /// Initialize a new [Lpips](Lpips) module with pretrained weights.\n    ///\n    /// Downloads and loads official LPIPS pretrained weights from the\n    /// PerceptualSimilarity repository.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to create the module on.\n    ///\n    /// # Returns\n    ///\n    /// A new LPIPS module with pretrained weights loaded.\n    ///\n    /// # Example\n    ///\n    /// ```ignore\n    /// use burn_train::metric::vision::{LpipsConfig, LpipsNet};\n    ///\n    /// let lpips = LpipsConfig::new()\n    ///     .with_net(LpipsNet::Vgg)\n    ///     .init_pretrained(&device);\n    /// ```\n    pub fn init_pretrained<B: Backend>(&self, device: &B::Device) -> Lpips<B> {\n        let lpips = self.init(device);\n        super::weights::load_pretrained_weights(lpips, self.net)\n    }\n\n    /// Initialize a new [Lpips](Lpips) module with random weights.\n    ///\n    /// # Arguments\n    ///\n    /// * `device` - Device to create the module on.\n    ///\n    /// # Returns\n    ///\n    /// A new LPIPS module with random weights. Use `init_pretrained` for accurate results.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Lpips<B> {\n        match self.net {\n            LpipsNet::Vgg => {\n                // Channel sizes for VGG16: [64, 128, 256, 512, 512]\n                Lpips::Vgg(LpipsVgg {\n                    extractor: VggFeatureExtractor::new(device),\n                    lin0: Conv2dConfig::new([64, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin1: Conv2dConfig::new([128, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin2: Conv2dConfig::new([256, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin3: Conv2dConfig::new([512, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin4: Conv2dConfig::new([512, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    normalize: self.normalize,\n                })\n            }\n            LpipsNet::Alex => {\n                // Channel sizes for AlexNet: [64, 192, 384, 256, 256]\n                Lpips::Alex(LpipsAlex {\n                    extractor: AlexFeatureExtractor::new(device),\n                    lin0: Conv2dConfig::new([64, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin1: Conv2dConfig::new([192, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin2: Conv2dConfig::new([384, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin3: Conv2dConfig::new([256, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin4: Conv2dConfig::new([256, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    normalize: self.normalize,\n                })\n            }\n            LpipsNet::Squeeze => {\n                // Channel sizes for SqueezeNet: [64, 128, 256, 384, 384, 512, 512]\n                Lpips::Squeeze(LpipsSqueeze {\n                    extractor: SqueezeFeatureExtractor::new(device),\n                    lin0: Conv2dConfig::new([64, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin1: Conv2dConfig::new([128, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin2: Conv2dConfig::new([256, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin3: Conv2dConfig::new([384, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin4: Conv2dConfig::new([384, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin5: Conv2dConfig::new([512, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    lin6: Conv2dConfig::new([512, 1], [1, 1])\n                        .with_bias(false)\n                        .init(device),\n                    normalize: self.normalize,\n                })\n            }\n        }\n    }\n}\n\n/// LPIPS (Learned Perceptual Image Patch Similarity) metric module.\n///\n/// Computes perceptual distance between two images using deep features.\n/// Supports VGG16, AlexNet, and SqueezeNet as backbone networks.\n///\n/// # Example\n///\n/// ```ignore\n/// use burn_train::metric::vision::{LpipsConfig, LpipsNet, Reduction};\n///\n/// let device = Default::default();\n/// let lpips = LpipsConfig::new().init(&device);\n///\n/// let img1: Tensor<B, 4> = /* [batch, 3, H, W] */;\n/// let img2: Tensor<B, 4> = /* [batch, 3, H, W] */;\n///\n/// // Compute LPIPS distance\n/// let distance = lpips.forward(img1, img2, Reduction::Mean);\n/// ```\n#[derive(Module, Debug)]\n#[allow(clippy::large_enum_variant)]\n#[module(custom_display)]\npub enum Lpips<B: Backend> {\n    /// VGG16 backbone (5 feature layers)\n    Vgg(LpipsVgg<B>),\n    /// AlexNet backbone (5 feature layers)\n    Alex(LpipsAlex<B>),\n    /// SqueezeNet backbone (7 feature layers)\n    Squeeze(LpipsSqueeze<B>),\n}\n\n/// LPIPS with VGG16 backbone.\n#[derive(Module, Debug)]\npub struct LpipsVgg<B: Backend> {\n    /// VGG feature extractor\n    pub(crate) extractor: VggFeatureExtractor<B>,\n    /// Linear layers for each feature level\n    pub(crate) lin0: Conv2d<B>,\n    pub(crate) lin1: Conv2d<B>,\n    pub(crate) lin2: Conv2d<B>,\n    pub(crate) lin3: Conv2d<B>,\n    pub(crate) lin4: Conv2d<B>,\n    /// Whether to normalize input\n    pub(crate) normalize: bool,\n}\n\n/// LPIPS with AlexNet backbone.\n#[derive(Module, Debug)]\npub struct LpipsAlex<B: Backend> {\n    /// AlexNet feature extractor\n    pub(crate) extractor: AlexFeatureExtractor<B>,\n    /// Linear layers for each feature level\n    pub(crate) lin0: Conv2d<B>,\n    pub(crate) lin1: Conv2d<B>,\n    pub(crate) lin2: Conv2d<B>,\n    pub(crate) lin3: Conv2d<B>,\n    pub(crate) lin4: Conv2d<B>,\n    /// Whether to normalize input\n    pub(crate) normalize: bool,\n}\n\n/// LPIPS with SqueezeNet backbone.\n#[derive(Module, Debug)]\npub struct LpipsSqueeze<B: Backend> {\n    /// SqueezeNet feature extractor\n    pub(crate) extractor: SqueezeFeatureExtractor<B>,\n    /// Linear layers for each feature level\n    pub(crate) lin0: Conv2d<B>,\n    pub(crate) lin1: Conv2d<B>,\n    pub(crate) lin2: Conv2d<B>,\n    pub(crate) lin3: Conv2d<B>,\n    pub(crate) lin4: Conv2d<B>,\n    pub(crate) lin5: Conv2d<B>,\n    pub(crate) lin6: Conv2d<B>,\n    /// Whether to normalize input\n    pub(crate) normalize: bool,\n}\n\nimpl<B: Backend> LpipsVgg<B> {\n    /// Compute LPIPS distance without reduction using VGG backbone.\n    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {\n        // Preprocess inputs\n        let (input, target) = preprocess_inputs(input, target, self.normalize);\n\n        // Extract features from both images\n        let feats0 = self.extractor.forward(input);\n        let feats1 = self.extractor.forward(target);\n\n        // Compute distance for each layer using stack + sum\n        let layer_distances: Vec<Tensor<B, 2>> = vec![\n            compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1),\n        ];\n\n        Tensor::cat(layer_distances, 1)\n            .sum_dim(1)\n            .squeeze_dim::<1>(1)\n    }\n}\n\nimpl<B: Backend> LpipsAlex<B> {\n    /// Compute LPIPS distance without reduction using AlexNet backbone.\n    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {\n        // Preprocess inputs\n        let (input, target) = preprocess_inputs(input, target, self.normalize);\n\n        // Extract features from both images\n        let feats0 = self.extractor.forward(input);\n        let feats1 = self.extractor.forward(target);\n\n        // Compute distance for each layer using stack + sum\n        let layer_distances: Vec<Tensor<B, 2>> = vec![\n            compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1),\n        ];\n\n        Tensor::cat(layer_distances, 1)\n            .sum_dim(1)\n            .squeeze_dim::<1>(1)\n    }\n}\n\nimpl<B: Backend> LpipsSqueeze<B> {\n    /// Compute LPIPS distance without reduction using SqueezeNet backbone.\n    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {\n        // Preprocess inputs\n        let (input, target) = preprocess_inputs(input, target, self.normalize);\n\n        // Extract features from both images\n        let feats0 = self.extractor.forward(input);\n        let feats1 = self.extractor.forward(target);\n\n        // Compute distance for each layer using stack + sum (7 layers for SqueezeNet)\n        let layer_distances: Vec<Tensor<B, 2>> = vec![\n            compute_layer_distance(&feats0[0], &feats1[0], &self.lin0).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[1], &feats1[1], &self.lin1).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[2], &feats1[2], &self.lin2).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[3], &feats1[3], &self.lin3).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[4], &feats1[4], &self.lin4).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[5], &feats1[5], &self.lin5).unsqueeze_dim(1),\n            compute_layer_distance(&feats0[6], &feats1[6], &self.lin6).unsqueeze_dim(1),\n        ];\n\n        Tensor::cat(layer_distances, 1)\n            .sum_dim(1)\n            .squeeze_dim::<1>(1)\n    }\n}\n\nimpl<B: Backend> ModuleDisplay for Lpips<B> {\n    fn custom_settings(&self) -> Option<DisplaySettings> {\n        DisplaySettings::new()\n            .with_new_line_after_attribute(false)\n            .optional()\n    }\n\n    fn custom_content(&self, content: Content) -> Option<Content> {\n        let (net_name, normalize) = match self {\n            Lpips::Vgg(inner) => (\"Vgg\", inner.normalize),\n            Lpips::Alex(inner) => (\"Alex\", inner.normalize),\n            Lpips::Squeeze(inner) => (\"Squeeze\", inner.normalize),\n        };\n        content\n            .add(\"net\", &net_name.to_string())\n            .add(\"normalize\", &normalize.to_string())\n            .optional()\n    }\n}\n\nimpl<B: Backend> Lpips<B> {\n    /// Compute LPIPS distance with reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - First image tensor of shape `[batch, 3, H, W]`\n    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`\n    /// * `reduction` - How to reduce the output (Mean, Sum, or Auto)\n    ///\n    /// # Returns\n    ///\n    /// Scalar tensor of shape `[1]`.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch, 3, H, W]`\n    /// - target: `[batch, 3, H, W]`\n    /// - output: `[1]`\n    pub fn forward(\n        &self,\n        input: Tensor<B, 4>,\n        target: Tensor<B, 4>,\n        reduction: Reduction,\n    ) -> Tensor<B, 1> {\n        let distance = self.forward_no_reduction(input, target);\n\n        match reduction {\n            Reduction::Mean | Reduction::Auto | Reduction::BatchMean => distance.mean(),\n            Reduction::Sum => distance.sum(),\n        }\n    }\n\n    /// Compute LPIPS distance without reduction.\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - First image tensor of shape `[batch, 3, H, W]`\n    /// * `target` - Second image tensor of shape `[batch, 3, H, W]`\n    ///\n    /// # Returns\n    ///\n    /// Per-sample distance tensor of shape `[batch]`.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch, 3, H, W]`\n    /// - target: `[batch, 3, H, W]`\n    /// - output: `[batch]`\n    pub fn forward_no_reduction(&self, input: Tensor<B, 4>, target: Tensor<B, 4>) -> Tensor<B, 1> {\n        match self {\n            Lpips::Vgg(inner) => inner.forward_no_reduction(input, target),\n            Lpips::Alex(inner) => inner.forward_no_reduction(input, target),\n            Lpips::Squeeze(inner) => inner.forward_no_reduction(input, target),\n        }\n    }\n}\n\n// =============================================================================\n// Helper Functions\n// =============================================================================\n\n/// Normalize tensor to unit norm along channel dimension.\nfn normalize_tensor<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {\n    let norm = x.clone().mul(x.clone()).sum_dim(1).sqrt().clamp_min(1e-10);\n    x.div(norm)\n}\n\n/// Apply ImageNet normalization used by PyTorch lpips.\n/// shift = [-.030, -.088, -.188], scale = [.458, .448, .450]\n/// output = (input - shift) / scale\nfn scaling_layer<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {\n    let device = x.device();\n    let [batch, _, h, w] = x.dims();\n\n    // Create shift and scale tensors [1, 3, 1, 1] and broadcast\n    let shift = Tensor::<B, 2>::from_floats([[-0.030], [-0.088], [-0.188]], &device)\n        .reshape([1, 3, 1, 1])\n        .expand([batch, 3, h, w]);\n    let scale = Tensor::<B, 2>::from_floats([[0.458], [0.448], [0.450]], &device)\n        .reshape([1, 3, 1, 1])\n        .expand([batch, 3, h, w]);\n\n    x.sub(shift).div(scale)\n}\n\n/// Compute normalized L2 distance for a single layer.\nfn compute_layer_distance<B: Backend>(\n    feat0: &Tensor<B, 4>,\n    feat1: &Tensor<B, 4>,\n    lin: &Conv2d<B>,\n) -> Tensor<B, 1> {\n    // Normalize features (unit norm along channel dimension)\n    let feat0_norm = normalize_tensor(feat0.clone());\n    let feat1_norm = normalize_tensor(feat1.clone());\n\n    // Compute squared difference\n    let diff = feat0_norm.sub(feat1_norm);\n    let diff_sq = diff.clone().mul(diff);\n\n    // Apply linear layer (learned weights)\n    // Shape: [batch, C, H, W] -> [batch, 1, H, W]\n    let weighted = lin.forward(diff_sq);\n\n    // Spatial average: compute mean over C, H, W dimensions\n    // Shape: [batch, 1, H, W] -> [batch]\n    let [batch, c, h, w] = weighted.dims();\n\n    // Reshape to [batch, c*h*w] then take mean over last dimension\n    weighted\n        .reshape([batch, c * h * w])\n        .mean_dim(1)\n        .squeeze_dim::<1>(1)\n}\n\n/// Preprocess input images for LPIPS computation.\nfn preprocess_inputs<B: Backend>(\n    input: Tensor<B, 4>,\n    target: Tensor<B, 4>,\n    normalize: bool,\n) -> (Tensor<B, 4>, Tensor<B, 4>) {\n    // Normalize to [-1, 1] if needed\n    let (input, target) = if normalize {\n        (\n            input.mul_scalar(2.0).sub_scalar(1.0),\n            target.mul_scalar(2.0).sub_scalar(1.0),\n        )\n    } else {\n        (input, target)\n    };\n\n    // Apply ImageNet normalization (same as PyTorch lpips scaling_layer)\n    (scaling_layer(input), scaling_layer(target))\n}\n\n// =============================================================================\n// Tests\n// =============================================================================\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_core::tensor::{TensorData, Tolerance, ops::FloatElem};\n    use burn_ndarray::NdArray;\n\n    type TestBackend = NdArray<f32>;\n    type FT = FloatElem<TestBackend>;\n    type TestTensor<const D: usize> = Tensor<TestBackend, D>;\n\n    // =========================================================================\n    // Basic Functionality Tests\n    // =========================================================================\n\n    /// Identical images should have LPIPS distance of 0.\n    #[test]\n    fn test_lpips_identical_images_zero_distance() {\n        let device = Default::default();\n        let image = TestTensor::<4>::ones([1, 3, 32, 32], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n\n        // Identical images → distance = 0\n        let expected = TensorData::from([0.0]);\n        distance\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    /// Different images should have LPIPS distance != 0.\n    /// Note: With random weights, distance can be negative, so we only check != 0.\n    /// Non-negativity is tested with pretrained weights.\n    #[test]\n    fn test_lpips_different_images_nonzero_distance() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() > 1e-6,\n            \"LPIPS should be != 0 for different images\"\n        );\n    }\n\n    /// Test symmetry: LPIPS(a, b) == LPIPS(b, a).\n    #[test]\n    fn test_lpips_symmetry() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 32, 32], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n        let distance_forward = lpips.forward(image1.clone(), image2.clone(), Reduction::Mean);\n        let distance_reverse = lpips.forward(image2, image1, Reduction::Mean);\n\n        distance_forward\n            .into_data()\n            .assert_approx_eq::<FT>(&distance_reverse.into_data(), Tolerance::default());\n    }\n\n    // =========================================================================\n    // Reduction Tests\n    // =========================================================================\n\n    #[test]\n    fn test_lpips_forward_mean_reduction() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([2, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([2, 3, 32, 32], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n\n        assert_eq!(distance.dims(), [1]);\n    }\n\n    #[test]\n    fn test_lpips_forward_no_reduction() {\n        let device = Default::default();\n\n        let batch_size = 4;\n        let image1 = TestTensor::<4>::zeros([batch_size, 3, 32, 32], &device);\n        let image2 = TestTensor::<4>::ones([batch_size, 3, 32, 32], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n        let distance = lpips.forward_no_reduction(image1, image2);\n\n        assert_eq!(distance.dims(), [batch_size]);\n    }\n\n    // =========================================================================\n    // AlexNet Tests\n    // =========================================================================\n\n    /// Test AlexNet LPIPS with identical images.\n    #[test]\n    fn test_lpips_alex_identical_images_zero_distance() {\n        let device = Default::default();\n        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n\n        let expected = TensorData::from([0.0]);\n        distance\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    /// Test AlexNet LPIPS with different images produces non-zero distance.\n    #[test]\n    fn test_lpips_alex_different_images_nonzero_distance() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        // Note: With random weights, non-negativity is not guaranteed.\n        // We only check that different images produce a non-zero distance.\n        assert!(\n            distance_value.abs() > 1e-6,\n            \"LPIPS (Alex) should be != 0 for different images\"\n        );\n    }\n\n    // =========================================================================\n    // SqueezeNet Tests\n    // =========================================================================\n\n    /// Test SqueezeNet LPIPS with identical images.\n    #[test]\n    fn test_lpips_squeeze_identical_images_zero_distance() {\n        let device = Default::default();\n        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n\n        let lpips: Lpips<TestBackend> =\n            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n\n        let expected = TensorData::from([0.0]);\n        distance\n            .into_data()\n            .assert_approx_eq::<FT>(&expected, Tolerance::default());\n    }\n\n    /// Test SqueezeNet LPIPS with different images produces non-zero distance.\n    #[test]\n    fn test_lpips_squeeze_different_images_nonzero_distance() {\n        let device = Default::default();\n\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n\n        let lpips: Lpips<TestBackend> =\n            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        // Note: With random weights, non-negativity is not guaranteed.\n        // We only check that different images produce a non-zero distance.\n        assert!(\n            distance_value.abs() > 1e-6,\n            \"LPIPS (Squeeze) should be != 0 for different images\"\n        );\n    }\n\n    // =========================================================================\n    // Display Tests\n    // =========================================================================\n\n    #[test]\n    fn display_vgg() {\n        let device = Default::default();\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().init(&device);\n\n        let display_str = format!(\"{lpips}\");\n        assert!(display_str.contains(\"Lpips\"));\n        assert!(display_str.contains(\"Vgg\"));\n    }\n\n    #[test]\n    fn display_alex() {\n        let device = Default::default();\n        let lpips: Lpips<TestBackend> = LpipsConfig::new().with_net(LpipsNet::Alex).init(&device);\n\n        let display_str = format!(\"{lpips}\");\n        assert!(display_str.contains(\"Lpips\"));\n        assert!(display_str.contains(\"Alex\"));\n    }\n\n    #[test]\n    fn display_squeeze() {\n        let device = Default::default();\n        let lpips: Lpips<TestBackend> =\n            LpipsConfig::new().with_net(LpipsNet::Squeeze).init(&device);\n\n        let display_str = format!(\"{lpips}\");\n        assert!(display_str.contains(\"Lpips\"));\n        assert!(display_str.contains(\"Squeeze\"));\n    }\n\n    // =========================================================================\n    // Pretrained Weights Tests (requires network)\n    // =========================================================================\n\n    /// Test VGG pretrained weights download and loading.\n    #[test]\n    fn test_lpips_pretrained_vgg() {\n        let device = Default::default();\n\n        // This will download ~60MB of weights\n        let lpips: Lpips<TestBackend> = LpipsConfig::new()\n            .with_net(LpipsNet::Vgg)\n            .init_pretrained(&device);\n\n        // Test with identical images - should be 0\n        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() < 1e-5,\n            \"Pretrained LPIPS (VGG) should be ~0 for identical images, got {}\",\n            distance_value\n        );\n\n        // Test with different images - should be positive\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value > 0.0,\n            \"Pretrained LPIPS (VGG) should be > 0 for different images, got {}\",\n            distance_value\n        );\n    }\n\n    /// Test AlexNet pretrained weights download and loading.\n    #[test]\n    fn test_lpips_pretrained_alex() {\n        let device = Default::default();\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new()\n            .with_net(LpipsNet::Alex)\n            .init_pretrained(&device);\n\n        // Test with identical images\n        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() < 1e-5,\n            \"Pretrained LPIPS (Alex) should be ~0 for identical images, got {}\",\n            distance_value\n        );\n\n        // Test with different images\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value > 0.0,\n            \"Pretrained LPIPS (Alex) should be > 0 for different images\"\n        );\n    }\n\n    /// Test SqueezeNet pretrained weights download and loading.\n    #[test]\n    fn test_lpips_pretrained_squeeze() {\n        let device = Default::default();\n\n        let lpips: Lpips<TestBackend> = LpipsConfig::new()\n            .with_net(LpipsNet::Squeeze)\n            .init_pretrained(&device);\n\n        // Test with identical images\n        let image = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image.clone(), image, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value.abs() < 1e-5,\n            \"Pretrained LPIPS (Squeeze) should be ~0 for identical images, got {}\",\n            distance_value\n        );\n\n        // Test with different images\n        let image1 = TestTensor::<4>::zeros([1, 3, 64, 64], &device);\n        let image2 = TestTensor::<4>::ones([1, 3, 64, 64], &device);\n        let distance = lpips.forward(image1, image2, Reduction::Mean);\n        let distance_value = distance.into_data().to_vec::<f32>().unwrap()[0];\n        assert!(\n            distance_value > 0.0,\n            \"Pretrained LPIPS (Squeeze) should be > 0 for different images, got {}\",\n            distance_value\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/mod.rs",
    "content": "//! LPIPS (Learned Perceptual Image Patch Similarity) metric module.\n//!\n//! LPIPS measures perceptual similarity between images using deep features.\n//! Supports VGG16, AlexNet, and SqueezeNet as backbone networks.\n//!\n//! Reference: \"The Unreasonable Effectiveness of Deep Features as a Perceptual Metric\"\n//! <https://arxiv.org/abs/1801.03924>\n\nmod alexnet;\nmod metric;\nmod squeezenet;\nmod vgg;\nmod weights;\n\npub use metric::{Lpips, LpipsAlex, LpipsConfig, LpipsNet, LpipsSqueeze, LpipsVgg};\npub use weights::{get_backbone_weights_url, get_lpips_weights_url, load_pretrained_weights};\n\n// Re-export feature extractors for advanced use cases\npub use alexnet::AlexFeatureExtractor;\npub use squeezenet::{FireModule, SqueezeFeatureExtractor};\npub use vgg::VggFeatureExtractor;\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/squeezenet.rs",
    "content": "//! SqueezeNet feature extractor for LPIPS.\n\nuse burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::relu;\nuse burn::tensor::backend::Backend;\nuse burn_nn::PaddingConfig2d;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\n\n/// Fire module for SqueezeNet.\n///\n/// A fire module consists of:\n/// - Squeeze layer: 1x1 conv to reduce channels\n/// - Expand layers: parallel 1x1 and 3x3 convs, concatenated\n#[derive(Module, Debug)]\npub struct FireModule<B: Backend> {\n    /// Squeeze layer: 1x1 conv\n    squeeze: Conv2d<B>,\n    /// Expand 1x1 conv\n    expand1x1: Conv2d<B>,\n    /// Expand 3x3 conv\n    expand3x3: Conv2d<B>,\n}\n\nimpl<B: Backend> FireModule<B> {\n    /// Create a new Fire module.\n    pub fn new(\n        in_channels: usize,\n        squeeze_channels: usize,\n        expand1x1_channels: usize,\n        expand3x3_channels: usize,\n        device: &B::Device,\n    ) -> Self {\n        Self {\n            squeeze: Conv2dConfig::new([in_channels, squeeze_channels], [1, 1])\n                .with_bias(true)\n                .init(device),\n            expand1x1: Conv2dConfig::new([squeeze_channels, expand1x1_channels], [1, 1])\n                .with_bias(true)\n                .init(device),\n            expand3x3: Conv2dConfig::new([squeeze_channels, expand3x3_channels], [3, 3])\n                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))\n                .with_bias(true)\n                .init(device),\n        }\n    }\n\n    /// Forward pass through fire module.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {\n        let squeezed = relu(self.squeeze.forward(x));\n        let e1 = relu(self.expand1x1.forward(squeezed.clone()));\n        let e3 = relu(self.expand3x3.forward(squeezed));\n        // Concatenate along channel dimension\n        Tensor::cat(vec![e1, e3], 1)\n    }\n}\n\n/// SqueezeNet 1.1 feature extractor for LPIPS.\n///\n/// Extracts features from 7 layers:\n/// - After conv1+relu: 64 channels\n/// - After fire1+fire2: 128 channels\n/// - After fire3+fire4: 256 channels\n/// - After fire5: 384 channels\n/// - After fire6: 384 channels\n/// - After fire7: 512 channels\n/// - After fire8: 512 channels\n#[derive(Module, Debug)]\npub struct SqueezeFeatureExtractor<B: Backend> {\n    /// Conv1: 3 -> 64, kernel 3x3, stride 2\n    conv1: Conv2d<B>,\n    /// Fire1: 64 -> 128 (squeeze=16, expand=64+64)\n    fire1: FireModule<B>,\n    /// Fire2: 128 -> 128 (squeeze=16, expand=64+64)\n    fire2: FireModule<B>,\n    /// Fire3: 128 -> 256 (squeeze=32, expand=128+128)\n    fire3: FireModule<B>,\n    /// Fire4: 256 -> 256 (squeeze=32, expand=128+128)\n    fire4: FireModule<B>,\n    /// Fire5: 256 -> 384 (squeeze=48, expand=192+192)\n    fire5: FireModule<B>,\n    /// Fire6: 384 -> 384 (squeeze=48, expand=192+192)\n    fire6: FireModule<B>,\n    /// Fire7: 384 -> 512 (squeeze=64, expand=256+256)\n    fire7: FireModule<B>,\n    /// Fire8: 512 -> 512 (squeeze=64, expand=256+256)\n    fire8: FireModule<B>,\n}\n\nimpl<B: Backend> SqueezeFeatureExtractor<B> {\n    /// Create a new SqueezeNet feature extractor.\n    pub fn new(device: &B::Device) -> Self {\n        Self {\n            // Conv1: 3 -> 64, 3x3, stride 2\n            conv1: Conv2dConfig::new([3, 64], [3, 3])\n                .with_stride([2, 2])\n                .with_bias(true)\n                .init(device),\n            // Fire modules (SqueezeNet 1.1 configuration)\n            fire1: FireModule::new(64, 16, 64, 64, device), // -> 128\n            fire2: FireModule::new(128, 16, 64, 64, device), // -> 128\n            fire3: FireModule::new(128, 32, 128, 128, device), // -> 256\n            fire4: FireModule::new(256, 32, 128, 128, device), // -> 256\n            fire5: FireModule::new(256, 48, 192, 192, device), // -> 384\n            fire6: FireModule::new(384, 48, 192, 192, device), // -> 384\n            fire7: FireModule::new(384, 64, 256, 256, device), // -> 512\n            fire8: FireModule::new(512, 64, 256, 256, device), // -> 512\n        }\n    }\n\n    /// Extract features from 7 SqueezeNet layers.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {\n        let mut features = Vec::with_capacity(7);\n\n        // Slice 1: Conv1 + ReLU (64 channels)\n        let x = relu(self.conv1.forward(x));\n        features.push(x.clone());\n\n        // Slice 2: MaxPool + Fire1 + Fire2 (128 channels)\n        let x = max_pool2d_squeeze(x);\n        let x = self.fire1.forward(x);\n        let x = self.fire2.forward(x);\n        features.push(x.clone());\n\n        // Slice 3: MaxPool + Fire3 + Fire4 (256 channels)\n        let x = max_pool2d_squeeze(x);\n        let x = self.fire3.forward(x);\n        let x = self.fire4.forward(x);\n        features.push(x.clone());\n\n        // Slice 4: MaxPool + Fire5 (384 channels)\n        let x = max_pool2d_squeeze(x);\n        let x = self.fire5.forward(x);\n        features.push(x.clone());\n\n        // Slice 5: Fire6 (384 channels)\n        let x = self.fire6.forward(x);\n        features.push(x.clone());\n\n        // Slice 6: Fire7 (512 channels)\n        let x = self.fire7.forward(x);\n        features.push(x.clone());\n\n        // Slice 7: Fire8 (512 channels)\n        let x = self.fire8.forward(x);\n        features.push(x);\n\n        features\n    }\n}\n\n/// 3x3 max pooling with stride 2, ceil mode (for SqueezeNet).\nfn max_pool2d_squeeze<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {\n    burn_core::tensor::module::max_pool2d(x, [3, 3], [2, 2], [0, 0], [1, 1], true)\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/vgg.rs",
    "content": "//! VGG16 feature extractor for LPIPS.\n\nuse burn_core as burn;\n\nuse burn::module::Module;\nuse burn::tensor::Tensor;\nuse burn::tensor::activation::relu;\nuse burn::tensor::backend::Backend;\nuse burn_nn::PaddingConfig2d;\nuse burn_nn::conv::{Conv2d, Conv2dConfig};\n\n/// VGG16 feature extractor for LPIPS.\n///\n/// Extracts features from 5 layers:\n/// - conv1_2: 64 channels\n/// - conv2_2: 128 channels\n/// - conv3_3: 256 channels\n/// - conv4_3: 512 channels\n/// - conv5_3: 512 channels\n#[derive(Module, Debug)]\npub struct VggFeatureExtractor<B: Backend> {\n    // Block 1\n    conv1_1: Conv2d<B>,\n    conv1_2: Conv2d<B>,\n    // Block 2\n    conv2_1: Conv2d<B>,\n    conv2_2: Conv2d<B>,\n    // Block 3\n    conv3_1: Conv2d<B>,\n    conv3_2: Conv2d<B>,\n    conv3_3: Conv2d<B>,\n    // Block 4\n    conv4_1: Conv2d<B>,\n    conv4_2: Conv2d<B>,\n    conv4_3: Conv2d<B>,\n    // Block 5\n    conv5_1: Conv2d<B>,\n    conv5_2: Conv2d<B>,\n    conv5_3: Conv2d<B>,\n}\n\nimpl<B: Backend> VggFeatureExtractor<B> {\n    /// Create a new VGG16 feature extractor.\n    pub fn new(device: &B::Device) -> Self {\n        let conv_config = |in_ch, out_ch| {\n            Conv2dConfig::new([in_ch, out_ch], [3, 3])\n                .with_padding(PaddingConfig2d::Same)\n                .with_bias(true)\n        };\n\n        Self {\n            // Block 1: 3 -> 64\n            conv1_1: conv_config(3, 64).init(device),\n            conv1_2: conv_config(64, 64).init(device),\n            // Block 2: 64 -> 128\n            conv2_1: conv_config(64, 128).init(device),\n            conv2_2: conv_config(128, 128).init(device),\n            // Block 3: 128 -> 256\n            conv3_1: conv_config(128, 256).init(device),\n            conv3_2: conv_config(256, 256).init(device),\n            conv3_3: conv_config(256, 256).init(device),\n            // Block 4: 256 -> 512\n            conv4_1: conv_config(256, 512).init(device),\n            conv4_2: conv_config(512, 512).init(device),\n            conv4_3: conv_config(512, 512).init(device),\n            // Block 5: 512 -> 512\n            conv5_1: conv_config(512, 512).init(device),\n            conv5_2: conv_config(512, 512).init(device),\n            conv5_3: conv_config(512, 512).init(device),\n        }\n    }\n\n    /// Extract features from 5 VGG layers.\n    pub fn forward(&self, x: Tensor<B, 4>) -> Vec<Tensor<B, 4>> {\n        let mut features = Vec::with_capacity(5);\n\n        // Block 1\n        let x = relu(self.conv1_1.forward(x));\n        let x = relu(self.conv1_2.forward(x));\n        features.push(x.clone());\n        let x = max_pool2d(x);\n\n        // Block 2\n        let x = relu(self.conv2_1.forward(x));\n        let x = relu(self.conv2_2.forward(x));\n        features.push(x.clone());\n        let x = max_pool2d(x);\n\n        // Block 3\n        let x = relu(self.conv3_1.forward(x));\n        let x = relu(self.conv3_2.forward(x));\n        let x = relu(self.conv3_3.forward(x));\n        features.push(x.clone());\n        let x = max_pool2d(x);\n\n        // Block 4\n        let x = relu(self.conv4_1.forward(x));\n        let x = relu(self.conv4_2.forward(x));\n        let x = relu(self.conv4_3.forward(x));\n        features.push(x.clone());\n        let x = max_pool2d(x);\n\n        // Block 5\n        let x = relu(self.conv5_1.forward(x));\n        let x = relu(self.conv5_2.forward(x));\n        let x = relu(self.conv5_3.forward(x));\n        features.push(x);\n\n        features\n    }\n}\n\n/// 2x2 max pooling with stride 2.\nfn max_pool2d<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {\n    burn_core::tensor::module::max_pool2d(x, [2, 2], [2, 2], [0, 0], [1, 1], false)\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/lpips/weights.rs",
    "content": "//! Pretrained weights loading for LPIPS.\n\nuse burn_core as burn;\n\nuse burn::tensor::backend::Backend;\nuse burn_std::network::downloader::download_file_as_bytes;\nuse burn_store::{ModuleSnapshot, PytorchStore};\nuse std::fs::{File, create_dir_all};\nuse std::io::Write;\nuse std::path::PathBuf;\n\nuse super::metric::{Lpips, LpipsNet};\n\n/// URLs for pretrained LPIPS linear layer weights from the official repository.\n/// Reference: https://github.com/richzhang/PerceptualSimilarity\nconst LPIPS_VGG_URL: &str =\n    \"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/vgg.pth\";\nconst LPIPS_ALEX_URL: &str =\n    \"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/alex.pth\";\nconst LPIPS_SQUEEZE_URL: &str =\n    \"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/squeeze.pth\";\n\n/// URLs for ImageNet pretrained backbone weights from PyTorch.\nconst VGG16_IMAGENET_URL: &str = \"https://download.pytorch.org/models/vgg16-397923af.pth\";\nconst ALEXNET_IMAGENET_URL: &str = \"https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth\";\nconst SQUEEZENET_IMAGENET_URL: &str =\n    \"https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth\";\n\n/// Get the download URL for LPIPS linear layer weights.\npub fn get_lpips_weights_url(net: LpipsNet) -> &'static str {\n    match net {\n        LpipsNet::Vgg => LPIPS_VGG_URL,\n        LpipsNet::Alex => LPIPS_ALEX_URL,\n        LpipsNet::Squeeze => LPIPS_SQUEEZE_URL,\n    }\n}\n\n/// Get the download URL for backbone ImageNet weights.\npub fn get_backbone_weights_url(net: LpipsNet) -> &'static str {\n    match net {\n        LpipsNet::Vgg => VGG16_IMAGENET_URL,\n        LpipsNet::Alex => ALEXNET_IMAGENET_URL,\n        LpipsNet::Squeeze => SQUEEZENET_IMAGENET_URL,\n    }\n}\n\n/// Get the cache directory for LPIPS weights.\nfn get_cache_dir() -> PathBuf {\n    let cache_dir = dirs::cache_dir()\n        .expect(\"Could not get cache directory\")\n        .join(\"burn-dataset\")\n        .join(\"lpips\");\n\n    if !cache_dir.exists() {\n        create_dir_all(&cache_dir).expect(\"Failed to create cache directory\");\n    }\n\n    cache_dir\n}\n\n/// Download file if not cached and return the cache path.\nfn download_if_needed(url: &str, cache_path: &PathBuf, message: &str) {\n    if !cache_path.exists() {\n        let bytes = download_file_as_bytes(url, message);\n        let mut file = File::create(cache_path).expect(\"Failed to create cache file\");\n        file.write_all(&bytes).expect(\"Failed to write weights\");\n    }\n}\n\n/// Download and load pretrained weights into an LPIPS module.\n///\n/// This loads both:\n/// 1. ImageNet pretrained backbone weights (VGG16/AlexNet/SqueezeNet)\n/// 2. LPIPS trained linear layer weights\n///\n/// Weights are cached in the user's cache directory to avoid re-downloading.\n///\n/// # Arguments\n///\n/// * `lpips` - The LPIPS module to load weights into.\n/// * `net` - The network type (determines which weights to download).\n///\n/// # Returns\n///\n/// The LPIPS module with loaded pretrained weights.\npub fn load_pretrained_weights<B: Backend>(mut lpips: Lpips<B>, net: LpipsNet) -> Lpips<B> {\n    let cache_dir = get_cache_dir();\n\n    // Step 1: Load backbone ImageNet weights\n    let backbone_url = get_backbone_weights_url(net);\n    let backbone_cache_path = cache_dir.join(format!(\"{:?}_backbone.pth\", net).to_lowercase());\n    let backbone_message = match net {\n        LpipsNet::Vgg => \"Downloading VGG16 ImageNet weights...\",\n        LpipsNet::Alex => \"Downloading AlexNet ImageNet weights...\",\n        LpipsNet::Squeeze => \"Downloading SqueezeNet ImageNet weights...\",\n    };\n    download_if_needed(backbone_url, &backbone_cache_path, backbone_message);\n\n    // Step 2: Load LPIPS linear layer weights\n    let lpips_url = get_lpips_weights_url(net);\n    let lpips_cache_path = cache_dir.join(format!(\"{:?}_lpips.pth\", net).to_lowercase());\n    let lpips_message = match net {\n        LpipsNet::Vgg => \"Downloading LPIPS VGG weights...\",\n        LpipsNet::Alex => \"Downloading LPIPS AlexNet weights...\",\n        LpipsNet::Squeeze => \"Downloading LPIPS SqueezeNet weights...\",\n    };\n    download_if_needed(lpips_url, &lpips_cache_path, lpips_message);\n\n    // Load backbone weights first\n    lpips = load_backbone_weights(lpips, &backbone_cache_path);\n\n    // Then load LPIPS linear layer weights\n    lpips = load_lpips_weights(lpips, &lpips_cache_path);\n\n    lpips\n}\n\n/// Load ImageNet pretrained backbone weights.\nfn load_backbone_weights<B: Backend>(lpips: Lpips<B>, cache_path: &PathBuf) -> Lpips<B> {\n    // Load directly into the inner struct to avoid enum variant issues\n    match lpips {\n        Lpips::Vgg(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                // VGG16 features.X -> extractor.convY_Z\n                .with_key_remapping(r\"^features\\.0\\.\", \"extractor.conv1_1.\")\n                .with_key_remapping(r\"^features\\.2\\.\", \"extractor.conv1_2.\")\n                .with_key_remapping(r\"^features\\.5\\.\", \"extractor.conv2_1.\")\n                .with_key_remapping(r\"^features\\.7\\.\", \"extractor.conv2_2.\")\n                .with_key_remapping(r\"^features\\.10\\.\", \"extractor.conv3_1.\")\n                .with_key_remapping(r\"^features\\.12\\.\", \"extractor.conv3_2.\")\n                .with_key_remapping(r\"^features\\.14\\.\", \"extractor.conv3_3.\")\n                .with_key_remapping(r\"^features\\.17\\.\", \"extractor.conv4_1.\")\n                .with_key_remapping(r\"^features\\.19\\.\", \"extractor.conv4_2.\")\n                .with_key_remapping(r\"^features\\.21\\.\", \"extractor.conv4_3.\")\n                .with_key_remapping(r\"^features\\.24\\.\", \"extractor.conv5_1.\")\n                .with_key_remapping(r\"^features\\.26\\.\", \"extractor.conv5_2.\")\n                .with_key_remapping(r\"^features\\.28\\.\", \"extractor.conv5_3.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\"Some VGG backbone weights could not be loaded: {:?}\", e);\n            }\n            Lpips::Vgg(inner)\n        }\n        Lpips::Alex(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                // AlexNet features.X -> extractor.convY\n                .with_key_remapping(r\"^features\\.0\\.\", \"extractor.conv1.\")\n                .with_key_remapping(r\"^features\\.3\\.\", \"extractor.conv2.\")\n                .with_key_remapping(r\"^features\\.6\\.\", \"extractor.conv3.\")\n                .with_key_remapping(r\"^features\\.8\\.\", \"extractor.conv4.\")\n                .with_key_remapping(r\"^features\\.10\\.\", \"extractor.conv5.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\"Some AlexNet backbone weights could not be loaded: {:?}\", e);\n            }\n            Lpips::Alex(inner)\n        }\n        Lpips::Squeeze(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                // SqueezeNet features.X -> extractor.*\n                .with_key_remapping(r\"^features\\.0\\.\", \"extractor.conv1.\")\n                .with_key_remapping(r\"^features\\.3\\.\", \"extractor.fire1.\")\n                .with_key_remapping(r\"^features\\.4\\.\", \"extractor.fire2.\")\n                .with_key_remapping(r\"^features\\.6\\.\", \"extractor.fire3.\")\n                .with_key_remapping(r\"^features\\.7\\.\", \"extractor.fire4.\")\n                .with_key_remapping(r\"^features\\.9\\.\", \"extractor.fire5.\")\n                .with_key_remapping(r\"^features\\.10\\.\", \"extractor.fire6.\")\n                .with_key_remapping(r\"^features\\.11\\.\", \"extractor.fire7.\")\n                .with_key_remapping(r\"^features\\.12\\.\", \"extractor.fire8.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\n                    \"Some SqueezeNet backbone weights could not be loaded: {:?}\",\n                    e\n                );\n            }\n            Lpips::Squeeze(inner)\n        }\n    }\n}\n\n/// Load LPIPS trained linear layer weights.\nfn load_lpips_weights<B: Backend>(lpips: Lpips<B>, cache_path: &PathBuf) -> Lpips<B> {\n    // Load directly into the inner struct to avoid enum variant issues\n    match lpips {\n        Lpips::Vgg(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                .with_key_remapping(r\"^lin0\\.model\\.1\\.\", \"lin0.\")\n                .with_key_remapping(r\"^lin1\\.model\\.1\\.\", \"lin1.\")\n                .with_key_remapping(r\"^lin2\\.model\\.1\\.\", \"lin2.\")\n                .with_key_remapping(r\"^lin3\\.model\\.1\\.\", \"lin3.\")\n                .with_key_remapping(r\"^lin4\\.model\\.1\\.\", \"lin4.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\"Some VGG LPIPS weights could not be loaded: {:?}\", e);\n            }\n            Lpips::Vgg(inner)\n        }\n        Lpips::Alex(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                .with_key_remapping(r\"^lin0\\.model\\.1\\.\", \"lin0.\")\n                .with_key_remapping(r\"^lin1\\.model\\.1\\.\", \"lin1.\")\n                .with_key_remapping(r\"^lin2\\.model\\.1\\.\", \"lin2.\")\n                .with_key_remapping(r\"^lin3\\.model\\.1\\.\", \"lin3.\")\n                .with_key_remapping(r\"^lin4\\.model\\.1\\.\", \"lin4.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\"Some AlexNet LPIPS weights could not be loaded: {:?}\", e);\n            }\n            Lpips::Alex(inner)\n        }\n        Lpips::Squeeze(mut inner) => {\n            let mut store = PytorchStore::from_file(cache_path)\n                .allow_partial(true)\n                .with_key_remapping(r\"^lin0\\.model\\.1\\.\", \"lin0.\")\n                .with_key_remapping(r\"^lin1\\.model\\.1\\.\", \"lin1.\")\n                .with_key_remapping(r\"^lin2\\.model\\.1\\.\", \"lin2.\")\n                .with_key_remapping(r\"^lin3\\.model\\.1\\.\", \"lin3.\")\n                .with_key_remapping(r\"^lin4\\.model\\.1\\.\", \"lin4.\")\n                .with_key_remapping(r\"^lin5\\.model\\.1\\.\", \"lin5.\")\n                .with_key_remapping(r\"^lin6\\.model\\.1\\.\", \"lin6.\");\n            if let Err(e) = inner.load_from(&mut store) {\n                log::warn!(\"Some SqueezeNet LPIPS weights could not be loaded: {:?}\", e);\n            }\n            Lpips::Squeeze(inner)\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/mod.rs",
    "content": "mod dice;\nmod dists;\nmod lpips;\nmod ms_ssim;\nmod psnr;\nmod ssim;\n\npub use dice::*;\npub use dists::*;\npub use lpips::*;\npub use ms_ssim::*;\npub use psnr::*;\npub use ssim::*;\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/ms_ssim.rs",
    "content": "use crate::metric::{\n    Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry,\n    SerializedEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Int, Tensor},\n    tensor::{\n        ElementConversion,\n        module::{avg_pool2d, conv2d},\n        ops::{ConvOptions, PadMode},\n    },\n};\nuse core::marker::PhantomData;\n\n/// Input type for the [MsSsimMetric].\n///\n/// Both tensors must have shape `[N, C, H, W]`:\n/// - `N`: Batch size\n/// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.)\n/// - `H`: Height\n/// - `W`: Width\n///\n/// # Important\n/// The image dimensions must be sufficiently large to accommodate the multi-scale\n/// computation. Specifically, for the default 5 scales used by Burn, the image dimensions\n/// should be at least `kernel_size * 2^(scales-1)` (e.g., 11 × 2^4 = 11 * 16 = 176 for default kernel size).\n/// If your images are smaller, reduce the kernel size or number of scales.\n///\n/// # Example\n/// ```rust,ignore\n/// // Create input for RGB images\n/// let outputs: Tensor<B, 4> = /* tensor */;\n/// let targets: Tensor<B, 4> = /* tensor */;\n/// let input = MsSsimInput::new(outputs, targets);\n/// ```\npub struct MsSsimInput<B: Backend> {\n    /// Model outputs with shape [N, C, H, W].\n    outputs: Tensor<B, 4>,\n    /// Ground truth targets with shape [N, C, H, W].\n    targets: Tensor<B, 4>,\n}\n\nimpl<B: Backend> MsSsimInput<B> {\n    /// Creates a new MsSsimInput with the given outputs and targets.\n    ///\n    /// # Arguments\n    /// - `outputs`: The model output images with shape [N, C, H, W].\n    /// - `targets`: The ground truth images with shape [N, C, H, W].\n    ///\n    /// # Returns\n    /// A new instance of `MsSsimInput`.\n    ///\n    /// # Panics\n    /// - If `outputs` and `targets` do not have the same shape.\n    pub fn new(outputs: Tensor<B, 4>, targets: Tensor<B, 4>) -> Self {\n        assert!(\n            outputs.dims() == targets.dims(),\n            \"Shape mismatch: outputs {:?} targets {:?}\",\n            outputs.dims(),\n            targets.dims()\n        );\n        Self { outputs, targets }\n    }\n}\n\n/// Configuration for the [MsSsimMetric].\n#[derive(Debug, Clone)]\npub struct MsSsimMetricConfig {\n    /// A parameter of SSIM used to stabilize the luminance comparison.\n    /// Default is 0.01.\n    pub k1: f32,\n    /// A parameter of SSIM used to stabilize the contrast comparison.\n    /// Default is 0.03.\n    pub k2: f32,\n    /// The range of the pixel values in images which can be computed as following:\n    /// `let pixel_range = max_pixel_val - min_pixel_val;`\n    /// where `max_pixel_val` is the maximum possible pixel value and `min_pixel_val`\n    /// is the minimum possible pixel value.\n    ///\n    /// - For normalized images in range [0, 1], it should be set to `1.0 - 0.0 = 1.0`\n    /// - For normalized images in range [-1, 1], it should be set to `1.0 - (-1.0) = 2.0`\n    /// - For 8-bit images in range [0, 255], it should be set to `255.0 - 0.0 = 255.0`\n    pub pixel_range: f32,\n    /// The MS-SSIM metric involves applying convolution to the input tensors using a Gaussian kernel.\n    /// This is the kernel size of the Gaussian kernel. Default is 11.\n    pub kernel_size: usize,\n    /// The MS-SSIM metric involves applying convolution to the input tensors using a Gaussian kernel.\n    /// This is the standard deviation of the Gaussian kernel. Default is 1.5.\n    pub sigma: f32,\n    /// The number of channels in the input images (e.g., 1 for grayscale, 3 for RGB).\n    /// This is used to create the appropriate convolution kernels. Default is 3.\n    pub channels: usize,\n    /// The weights/betas for each scale in the MS-SSIM computation.\n    /// The length of this vector determines the number of scales.\n    /// Default is \\[0.0448, 0.2856, 0.3001, 0.2363, 0.1333\\] (5 scales).\n    pub betas: Vec<f32>,\n}\n\nimpl MsSsimMetricConfig {\n    /// Creates a configuration with the specified data range and default parameters.\n    ///\n    /// # Default parameters\n    /// - k1: 0.01\n    /// - k2: 0.03\n    /// - kernel_size: 11\n    /// - sigma: 1.5\n    /// - channels: 3\n    ///\n    /// # Panics\n    /// - If `pixel_range` is not positive.\n    ///\n    /// # Example\n    /// ```rust,ignore\n    /// // For normalized RGB images [0, 1]\n    /// let config1 = MsSsimMetricConfig::new(1.0);\n    ///\n    /// // For 8-bit images [0, 255]  \n    /// let config2 = MsSsimMetricConfig::new(255.0);\n    ///\n    /// // For grayscale with custom kernel\n    /// let config3 = MsSsimMetricConfig::new(1.0)\n    ///     .with_channels(1)\n    ///     .with_kernel_size(7);\n    /// ```\n    pub fn new(pixel_range: f32) -> Self {\n        assert!(pixel_range > 0.0, \"pixel_range must be positive\");\n        Self {\n            k1: 0.01,\n            k2: 0.03,\n            pixel_range,\n            kernel_size: 11,\n            sigma: 1.5,\n            channels: 3,\n            betas: vec![0.0448, 0.2856, 0.3001, 0.2363, 0.1333],\n        }\n    }\n\n    /// Sets custom values for the k1 and k2 parameters of MS-SSIM which are\n    /// used for numerical stability.\n    ///\n    /// # Default values\n    /// - k1: 0.01\n    /// - k2: 0.03\n    ///\n    /// # Panics\n    /// - If `k1` or `k2` is not positive.\n    pub fn with_k1_k2(mut self, k1: f32, k2: f32) -> Self {\n        assert!(k1 > 0.0, \"k1 must be positive\");\n        assert!(k2 > 0.0, \"k2 must be positive\");\n        self.k1 = k1;\n        self.k2 = k2;\n        self\n    }\n\n    /// Sets a custom kernel size for the Gaussian kernel used in MS-SSIM. The\n    /// kernel size must be a positive odd number.\n    ///\n    /// # Default value\n    /// - kernel_size: 11\n    ///\n    /// # Panics\n    /// - If `kernel_size` is not a positive odd number.\n    pub fn with_kernel_size(mut self, kernel_size: usize) -> Self {\n        assert!(\n            kernel_size > 0 && kernel_size % 2 == 1,\n            \"kernel_size must be positive and an odd number\"\n        );\n        self.kernel_size = kernel_size;\n        self\n    }\n\n    /// Sets a custom sigma (standard deviation) for the Gaussian kernel used in MS-SSIM.\n    ///\n    /// # Default value\n    /// - sigma: 1.5\n    ///\n    /// # Panics\n    /// - If `sigma` is not positive.\n    pub fn with_sigma(mut self, sigma: f32) -> Self {\n        assert!(sigma > 0.0, \"sigma must be a positive number\");\n        self.sigma = sigma;\n        self\n    }\n\n    /// Sets the number of channels for the input images.\n    ///\n    /// This affects the shape of the pre-computed convolution kernels.\n    /// Change this if working with grayscale (1) or multispectral images (>3).\n    ///\n    /// # Default value\n    /// - channels: 3\n    ///\n    /// # Panics\n    /// - If `channels` is 0.\n    pub fn with_channels(mut self, channels: usize) -> Self {\n        assert!(channels > 0, \"channels must be a positive number\");\n        self.channels = channels;\n        self\n    }\n\n    /// Sets custom betas for the scales. The length of the betas vector\n    /// determines the number of scales used in the MS-SSIM computation.\n    /// If you want to make different parameter settings comparable, the betas\n    /// vector should sum to 1 as per the original paper. However, note\n    /// that this is not a strict requirement.\n    ///\n    /// # Default value\n    /// - betas: `[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]` (5 scales)\n    ///\n    /// # Panics\n    /// - If `betas` is empty.\n    /// - If not all values in `betas` are positive.\n    pub fn with_betas(mut self, betas: Vec<f32>) -> Self {\n        assert!(!betas.is_empty(), \"betas vector cannot be empty\");\n\n        assert!(\n            betas.iter().all(|&b| b >= 0.0),\n            \"All beta values must be non-negative\"\n        );\n\n        self.betas = betas;\n        self\n    }\n}\n\n/// Multi-Scale Structural Similarity Index (MS-SSIM) metric for image quality assessment.\n///\n/// MS-SSIM extends the single-scale [SSIM](crate::metric::vision::SsimMetric) by computing\n/// the index at multiple resolutions (scales) and combining them using weighted averaging.\n/// This approach better correlates with human visual perception, especially for\n/// high-resolution images where fine details and texture variations are important.\n///\n/// # Algorithm Overview\n///\n/// MS-SSIM computes structural similarity across M scales (M=5 in Burn):\n///\n/// 1. **Contrast** and **Structure** components are computed at every scale\n/// 2. **Luminance** is computed only at the coarsest (last) scale\n/// 3. Between scales, images are downsampled by a factor of 2 using average pooling\n///\n/// The final metric is computed as:\n/// ```text\n/// MS-SSIM = L_M^{α_M} × ∏_{j=1}^M (C_j^{β_j} × S_j^{γ_j})\n/// ```\n///\n/// Where:\n/// - `L_M` is luminance at the last scale (M)\n/// - `C_j` is contrast at scale j: `(2σ_xσ_y + C2) / (σ_x² + σ_y² + C2)`\n/// - `S_j` is structure at scale j: `(σ_xy + C3) / (σ_xσ_y + C3)`\n/// - `α_M, β_j, γ_j` are weights from Wang et al. (\\[0.0448, 0.2856, 0.3001, 0.2363, 0.1333\\])\n///\n/// # Notes\n///\n/// - This implementation uses separable Gaussian convolution for efficiency (reduces complexity from O(K^2) to O(2K) per pixel)\n/// - Gaussian kernels are pre-computed during initialization to avoid redundant computation\n/// - The metric requires images to be large enough to survive the downsampling operations\n///\n/// # Value Range\n///\n/// MS-SSIM values typically range from 0 to 1, where:\n/// - 1.0 indicates perfect structural similarity (identical images)\n/// - 0.0 indicates no structural similarity\n/// - Values are usually positive due to the stability constants (C1, C2, C3)\n///\n/// # References\n///\n/// [Multi-scale Structural Similarity for Image Quality Assessment](https://www.cns.nyu.edu/pub/eero/wang03b.pdf)\n#[derive(Clone)]\npub struct MsSsimMetric<B: Backend> {\n    name: MetricName,\n    /// Internal state for numeric metric aggregation.\n    state: NumericMetricState,\n    /// Marker for backend type.\n    _b: PhantomData<B>,\n    /// Configuration for the metric.\n    config: MsSsimMetricConfig,\n    /// Pre-computed horizontal Gaussian kernel with shape [C, 1, 1, K]\n    horizontal_kernel: Tensor<B, 4>,\n    /// Pre-computed vertical Gaussian kernel with shape [C, 1, K, 1]\n    vertical_kernel: Tensor<B, 4>,\n}\n\nimpl<B: Backend> MsSsimMetric<B> {\n    /// Creates a new MS-SSIM metric with the given configuration.\n    ///\n    /// # Arguments\n    /// - `config`: Configuration for the metric (data range, kernel size, etc.)\n    /// - `device`: Device to place the Gaussian kernels on\n    ///\n    /// # Note\n    /// The default metric name format is \"MS-SSIM (pr={}, k={}, σ={})\"\n    /// where pr is the pixel range, k is the kernel size, and σ is the\n    /// standard deviation.\n    ///\n    /// # Example\n    /// ```ignore\n    /// let config = MsSsimMetricConfig::new(1.0).with_channels(1); // Grayscale\n    /// let metric = MsSsimMetric::<B>::new(config, &device);\n    /// ```\n    pub fn new(config: MsSsimMetricConfig, device: &B::Device) -> Self {\n        let kernel = Self::create_1d_gaussian_kernel(&config, device);\n        let size = config.kernel_size;\n\n        // Create horizontal kernel: shape [C, 1, 1, K] for depthwise conv\n        let horizontal_kernel = kernel\n            .clone()\n            .reshape([1, 1, 1, size])\n            .repeat_dim(0, config.channels);\n\n        // Create vertical kernel: shape [C, 1, K, 1] for depthwise conv\n        let vertical_kernel = kernel\n            .reshape([1, 1, size, 1])\n            .repeat_dim(0, config.channels);\n\n        Self {\n            name: MetricName::new(format!(\n                \"MS-SSIM (pr={}, k={}, σ={})\",\n                config.pixel_range, config.kernel_size, config.sigma\n            )),\n            state: NumericMetricState::default(),\n            _b: PhantomData,\n            config,\n            horizontal_kernel,\n            vertical_kernel,\n        }\n    }\n\n    /// Overrides the default metric name.\n    ///\n    /// # Example\n    /// ```ignore\n    /// let metric = MsSsimMetric::<B>::new(config, &device)\n    ///     .with_name(\"Custom MS-SSIM Name\");\n    /// ```\n    pub fn with_name(mut self, name: &str) -> Self {\n        self.name = MetricName::new(name.to_string());\n        self\n    }\n\n    /// Creates a normalized 1D Gaussian kernel as a tensor where the kernel values sum to 1.0.\n    fn create_1d_gaussian_kernel(config: &MsSsimMetricConfig, device: &B::Device) -> Tensor<B, 1> {\n        let size = config.kernel_size as i64;\n        let sigma = config.sigma;\n        let center = (size / 2) as f32;\n\n        let one_to_size_tensor = Tensor::<B, 1, Int>::arange(0..size, device).float();\n        let x_vals = one_to_size_tensor.sub_scalar(center);\n\n        // Gaussian: exp(-x² / 2σ²)\n        let x_squared = x_vals.clone().mul(x_vals);\n        let x_squared_div_2_sigma_squared = x_squared.div_scalar(2.0 * sigma * sigma);\n        let unnormalized_kernel = x_squared_div_2_sigma_squared.neg().exp();\n        let kernel_vals_sum = unnormalized_kernel.clone().sum();\n        unnormalized_kernel.div(kernel_vals_sum)\n    }\n\n    /// Applies separable Gaussian convolution using pre-computed kernels.\n    ///\n    /// Performs two 1D convolutions (horizontal then vertical) which is\n    /// computationally cheaper than a single 2D convolution.\n    ///\n    /// # Arguments\n    /// - `input`: Tensor of shape [N, C, H, W]\n    fn gaussian_separable_conv(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let padding = self.config.kernel_size / 2;\n        let h_kernel = self.horizontal_kernel.clone();\n        let v_kernel = self.vertical_kernel.clone();\n\n        // Apply reflect padding to all 4 sides of the input tensor before convolution\n        // Format: (left, right, top, bottom)\n        let padded_input = input.pad((padding, padding, padding, padding), PadMode::Reflect);\n\n        let h_conv_options = ConvOptions::new([1, 1], [0, 0], [1, 1], self.config.channels);\n        let v_conv_options = ConvOptions::new([1, 1], [0, 0], [1, 1], self.config.channels);\n\n        let input_after_h_conv = conv2d(padded_input, h_kernel, None, h_conv_options);\n        conv2d(input_after_h_conv, v_kernel, None, v_conv_options)\n    }\n}\n\nimpl<B: Backend> Metric for MsSsimMetric<B> {\n    type Input = MsSsimInput<B>;\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let dims = item.outputs.dims();\n        let scales = self.config.betas.len();\n\n        assert_eq!(\n            dims[1], self.config.channels,\n            \"Input has {} channels but metric was configured for {}\",\n            dims[1], self.config.channels\n        );\n\n        // Verify minimum size for the given number of scales\n        // After (scales - 1) downsamples, size is original / 2^(scales-1)\n        // We need kernel_size at that scale\n        let downsample_ops_num = scales.saturating_sub(1) as u32;\n        let min_size = self.config.kernel_size * (2usize.pow(downsample_ops_num));\n        let h = dims[2];\n        let w = dims[3];\n        assert!(\n            h >= min_size && w >= min_size,\n            \"Image dimensions (H={}, W={}) must be at least {} to support {} scales of MS-SSIM \\\n                with kernel_size={}. Either increase image size, reduce kernel_size, or reduce the number of scales (betas).\",\n            h,\n            w,\n            min_size,\n            scales,\n            self.config.kernel_size\n        );\n\n        let mut x = item.outputs.clone();\n        let mut y = item.targets.clone();\n        let betas = &self.config.betas;\n\n        // Compute c1 = (k1 * L)^2 and c2 = (k2 * L)^2, c3 = c2/2\n        let c1 = (self.config.k1 * self.config.pixel_range).powi(2);\n        let c2 = (self.config.k2 * self.config.pixel_range).powi(2);\n\n        // Initialize accumulator to 1 for update via multiplication\n        // Shape: [N, C]\n        let batch_size = dims[0];\n        let channels = dims[1];\n        let mut ms_ssim_tensor =\n            Tensor::<B, 2>::ones([batch_size, channels], &item.outputs.device());\n\n        for (j, beta_j) in betas.iter().enumerate() {\n            // Compute mu_x and mu_y\n            let mu_x = self.gaussian_separable_conv(x.clone());\n            let mu_y = self.gaussian_separable_conv(y.clone());\n            let square_of_mu_x = mu_x.clone() * mu_x.clone();\n            let square_of_mu_y = mu_y.clone() * mu_y.clone();\n\n            // Var(X) = E(X^2) - E(X)^2\n            let mu_of_x_squared = self.gaussian_separable_conv(x.clone() * x.clone());\n            let mu_of_y_squared = self.gaussian_separable_conv(y.clone() * y.clone());\n            let var_x = (mu_of_x_squared - square_of_mu_x.clone()).clamp_min(0.0);\n            let var_y = (mu_of_y_squared - square_of_mu_y.clone()).clamp_min(0.0);\n\n            // Cov(X, Y) = E(XY) - E(X)E(Y)\n            let mu_of_xy = self.gaussian_separable_conv(x.clone() * y.clone());\n            let cov_xy = mu_of_xy - (mu_x.clone() * mu_y.clone());\n\n            // Compute cs_map = (2σxy + C2) / (σx² + σy² + C2)\n            // This is mathematically equivalent to c(x,y) * s(x,y) when C3 = C2 / 2\n            let contrast_structure = (cov_xy * 2.0 + c2) / (var_x + var_y + c2);\n\n            // Include luminance at the last scale\n            if j == betas.len() - 1 {\n                // Compute l(x, y) = (2μxμy + C1) / (μx² + μy² + C1)\n                let luminance: Tensor<B, 4> =\n                    (2 * mu_x * mu_y + c1) / (square_of_mu_x + square_of_mu_y + c1);\n                let ssim = luminance * contrast_structure;\n                let ssim_spatial_mean = ssim.mean_dims(&[2, 3]).reshape([batch_size, channels]);\n                // Clamp to avoid negative values before raising to power (prevents NaNs)\n                let ssim_mean_clamped = ssim_spatial_mean.clamp_min(0.0);\n                ms_ssim_tensor = ms_ssim_tensor * ssim_mean_clamped.powf_scalar(*beta_j);\n            } else {\n                let contrast_structure_spatial_mean = contrast_structure\n                    .mean_dims(&[2, 3])\n                    .reshape([batch_size, channels]);\n                // Clamp to avoid negative values before raising to power (prevents NaNs)\n                let c_s_mean_clamped = contrast_structure_spatial_mean.clamp_min(0.0);\n                ms_ssim_tensor = ms_ssim_tensor * c_s_mean_clamped.powf_scalar(*beta_j);\n\n                x = avg_pool2d(x, [2, 2], [2, 2], [0, 0], false, false);\n                y = avg_pool2d(y, [2, 2], [2, 2], [0, 0], false, false);\n            }\n        }\n\n        let ms_ssim_per_image = ms_ssim_tensor.mean_dim(1);\n        let avg_ms_ssim = ms_ssim_per_image.mean().into_scalar().elem::<f64>();\n\n        self.state.update(\n            avg_ms_ssim,\n            batch_size,\n            FormatOptions::new(self.name()).precision(4),\n        )\n    }\n\n    /// Clears the metric state.\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for MsSsimMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, metric::Numeric};\n    use burn_core::tensor::Distribution;\n\n    fn test_config() -> MsSsimMetricConfig {\n        // Use small kernel and single channel for testing\n        // With kernel_size=3, we need images >= 3*16=48\n        MsSsimMetricConfig::new(1.0)\n            .with_kernel_size(3)\n            .with_sigma(1.0)\n            .with_channels(1)\n    }\n\n    #[test]\n    fn test_ms_ssim_perfect_similarity() {\n        // Identical images should give MS-SSIM = 1.0\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            [[[\n                [0.5_f32; 64]; 64  // 64x64 constant image\n            ]]],\n            &device,\n        );\n        let targets = outputs.clone();\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n        let input = MsSsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ms_ssim = metric.value().current();\n        assert!(\n            ms_ssim > 0.99,\n            \"MS-SSIM for identical images should be 1.0, got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_completely_different() {\n        // Black vs white images should give very low MS-SSIM (close to 0.0)\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 256, 256], &device);\n        let targets = Tensor::<TestBackend, 4>::ones([1, 1, 256, 256], &device);\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n        let input = MsSsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ms_ssim = metric.value().current();\n        assert!(\n            (ms_ssim - 0.3).abs() < 0.01,\n            \"MS-SSIM for black vs white should be low (around 0.3), got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_similar_images() {\n        // Small perturbation should give high MS-SSIM (close to 1.0)\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::full([1, 1, 64, 64], 0.5, &device);\n        let targets = Tensor::<TestBackend, 4>::full([1, 1, 64, 64], 0.52, &device);\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n        let input = MsSsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ms_ssim = metric.value().current();\n        assert!(\n            ms_ssim > 0.95,\n            \"MS-SSIM for very similar images should be close to 1.0, got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_batch_averaging() {\n        let device = Default::default();\n        // Batch of 2: one identical, one different\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            [\n                [[[0.5_f32; 64]; 64]], // Image 1: constant 0.5\n                [[[0.0_f32; 64]; 64]], // Image 2: constant 0.0 (black)\n            ],\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            [\n                [[[0.5_f32; 64]; 64]], // Image 1: identical\n                [[[1.0_f32; 64]; 64]], // Image 2: white (opposite)\n            ],\n            &device,\n        );\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n        let input = MsSsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ms_ssim = metric.value().current();\n        // Average of ~1.0 and ~0.292 should be around 0.64\n        assert!(\n            (ms_ssim - 0.64).abs() < 0.02,\n            \"Average MS-SSIM should be around 0.64, got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_multichannel() {\n        let device = Default::default();\n        // Test with 3 channels (RGB)\n        let config = MsSsimMetricConfig::new(1.0)\n            .with_kernel_size(3)\n            .with_sigma(1.0)\n            .with_channels(3);\n\n        let outputs = Tensor::<TestBackend, 4>::random(\n            [2, 3, 64, 64],\n            Distribution::Uniform(0.0, 1.0),\n            &device,\n        );\n        let targets = outputs.clone();\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(config, &device);\n        let input = MsSsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ms_ssim = metric.value().current();\n        assert!(\n            ms_ssim > 0.99,\n            \"MS-SSIM for identical RGB images should be 1.0, got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_running_average() {\n        let device = Default::default();\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n\n        // First update: identical (1.0)\n        let img1 = Tensor::<TestBackend, 4>::full([1, 1, 64, 64], 0.5, &device);\n        let input1 = MsSsimInput::new(img1.clone(), img1);\n        metric.update(&input1, &MetricMetadata::fake());\n\n        assert!(\n            metric.value().current() > 0.99,\n            \"First update should be approximately 1.0\"\n        );\n\n        // Second update: different (~0.29)\n        let black = Tensor::<TestBackend, 4>::zeros([1, 1, 64, 64], &device);\n        let white = Tensor::<TestBackend, 4>::ones([1, 1, 64, 64], &device);\n        let input2 = MsSsimInput::new(black, white);\n        metric.update(&input2, &MetricMetadata::fake());\n\n        let running = metric.running_value().current();\n        assert!(\n            (running - 0.64).abs() < 0.02,\n            \"Running average should be approximately 0.64, got {}\",\n            running\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_single_scale_small_image() {\n        let device = Default::default();\n        // Default 5 scales with kernel_size=11 requires a 176x176 image.\n        // With a single scale, the minimum required size drops to\n        // just 11x11 (kernel_size * 2^0).\n        let config = MsSsimMetricConfig::new(1.0)\n            .with_channels(1)\n            .with_betas(vec![1.0]); // 1 scale\n\n        let mut metric = MsSsimMetric::<TestBackend>::new(config, &device);\n\n        // Create a 16x16 image. This would normally panic with 5 scales,\n        // but should succeed with 1 scale.\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 16, 16], &device);\n        let targets = outputs.clone();\n        let input = MsSsimInput::new(outputs, targets);\n\n        // This should not panic\n        let _ = metric.update(&input, &MetricMetadata::fake());\n\n        // Identical images should still yield ~1.0\n        let ms_ssim = metric.value().current();\n        assert!(\n            ms_ssim > 0.99,\n            \"1-scale MS-SSIM for identical images should be 1.0, got {}\",\n            ms_ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_symmetry() {\n        // MS-SSIM(x, y) should equal MS-SSIM(y, x)\n        // Symmetry is one of the mathematical properties of MS-SSIM\n        let device = Default::default();\n        let config = MsSsimMetricConfig::new(1.0)\n            .with_kernel_size(3)\n            .with_sigma(1.0)\n            .with_channels(3);\n\n        let img1 = Tensor::<TestBackend, 4>::random(\n            [2, 3, 64, 64],\n            Distribution::Uniform(0.0, 1.0),\n            &device,\n        );\n        let img2 = Tensor::<TestBackend, 4>::random(\n            [2, 3, 64, 64],\n            Distribution::Uniform(0.0, 1.0),\n            &device,\n        );\n\n        let mut metric1 = MsSsimMetric::<TestBackend>::new(config.clone(), &device);\n        let input1 = MsSsimInput::new(img1.clone(), img2.clone());\n        let _entry = metric1.update(&input1, &MetricMetadata::fake());\n        let ms_ssim1 = metric1.value().current();\n\n        let mut metric2 = MsSsimMetric::<TestBackend>::new(config, &device);\n        let input2 = MsSsimInput::new(img2, img1);\n        let _entry = metric2.update(&input2, &MetricMetadata::fake());\n        let ms_ssim2 = metric2.value().current();\n\n        assert!(\n            (ms_ssim1 - ms_ssim2).abs() < 0.001,\n            \"MS-SSIM should be symmetric: MS-SSIM(x,y)={} vs MS-SSIM(y,x)={}\",\n            ms_ssim1,\n            ms_ssim2\n        );\n    }\n\n    #[test]\n    fn test_ms_ssim_clear() {\n        let device = Default::default();\n        let mut metric = MsSsimMetric::<TestBackend>::new(test_config(), &device);\n\n        let img = Tensor::<TestBackend, 4>::full([1, 1, 64, 64], 0.5, &device);\n        let input = MsSsimInput::new(img.clone(), img);\n        metric.update(&input, &MetricMetadata::fake());\n\n        assert!(metric.value().current() > 0.99);\n\n        metric.clear();\n        assert!(metric.running_value().current().is_nan());\n    }\n\n    #[test]\n    fn test_ms_ssim_custom_name() {\n        let device = Default::default();\n        let config = MsSsimMetricConfig::new(1.0);\n        let metric = MsSsimMetric::<TestBackend>::new(config, &device).with_name(\"CustomMS-SSIM\");\n        assert_eq!(metric.name().to_string(), \"CustomMS-SSIM\");\n    }\n\n    #[test]\n    fn test_ms_ssim_default_name() {\n        let device = Default::default();\n        let config = MsSsimMetricConfig::new(255.0);\n        let metric = MsSsimMetric::<TestBackend>::new(config, &device);\n        assert_eq!(metric.name().to_string(), \"MS-SSIM (pr=255, k=11, σ=1.5)\");\n    }\n\n    #[test]\n    fn test_ms_ssim_attributes() {\n        let device = Default::default();\n        let config = MsSsimMetricConfig::new(1.0);\n        let metric = MsSsimMetric::<TestBackend>::new(config, &device);\n\n        match metric.attributes() {\n            MetricAttributes::Numeric(attrs) => {\n                assert!(attrs.higher_is_better);\n                assert_eq!(attrs.unit, None);\n            }\n            _ => panic!(\"Expected numeric attributes\"),\n        }\n    }\n\n    #[test]\n    #[should_panic(expected = \"Shape mismatch\")]\n    fn test_ms_ssim_shape_mismatch() {\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 64, 64], &device);\n        let targets = Tensor::<TestBackend, 4>::zeros([1, 1, 32, 32], &device);\n        let _ = MsSsimInput::new(outputs, targets);\n    }\n\n    #[test]\n    #[should_panic(expected = \"k1 must be positive\")]\n    fn test_ms_ssim_negative_k1() {\n        let _ = MsSsimMetricConfig::new(1.0).with_k1_k2(-0.01, 0.03);\n    }\n\n    #[test]\n    #[should_panic(expected = \"k2 must be positive\")]\n    fn test_ms_ssim_negative_k2() {\n        let _ = MsSsimMetricConfig::new(1.0).with_k1_k2(0.01, -0.03);\n    }\n\n    #[test]\n    #[should_panic(expected = \"pixel_range must be positive\")]\n    fn test_ms_ssim_negative_data_range() {\n        let _ = MsSsimMetricConfig::new(-1.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"pixel_range must be positive\")]\n    fn test_ms_ssim_zero_data_range() {\n        let _ = MsSsimMetricConfig::new(0.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"kernel_size must be positive and an odd number\")]\n    fn test_ms_ssim_even_kernel_size() {\n        let _ = MsSsimMetricConfig::new(1.0).with_kernel_size(10);\n    }\n\n    #[test]\n    #[should_panic(expected = \"kernel_size must be positive and an odd number\")]\n    fn test_ms_ssim_zero_kernel_size() {\n        let _ = MsSsimMetricConfig::new(1.0).with_kernel_size(0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"sigma must be a positive number\")]\n    fn test_ms_ssim_negative_sigma() {\n        let _ = MsSsimMetricConfig::new(1.0).with_sigma(-1.5);\n    }\n\n    #[test]\n    #[should_panic(expected = \"sigma must be a positive number\")]\n    fn test_ms_ssim_zero_sigma() {\n        let _ = MsSsimMetricConfig::new(1.0).with_sigma(0.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"channels must be a positive number\")]\n    fn test_ms_ssim_zero_channels() {\n        let _ = MsSsimMetricConfig::new(1.0).with_channels(0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"betas vector cannot be empty\")]\n    fn test_ms_ssim_empty_betas() {\n        let _ = MsSsimMetricConfig::new(1.0).with_betas(vec![]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"All beta values must be non-negative\")]\n    fn test_ms_ssim_negative_betas() {\n        let _ = MsSsimMetricConfig::new(1.0).with_betas(vec![0.3, 0.3, -0.1, 0.5]);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Image dimensions\")]\n    fn test_ms_ssim_image_too_small() {\n        let device = Default::default();\n        // 3 scales with kernel_size=11 requires 44x44 minimum (11 * 2^2)\n        let config = MsSsimMetricConfig::new(1.0).with_betas(vec![0.5, 0.3, 0.2]);\n        let mut metric = MsSsimMetric::<TestBackend>::new(config, &device);\n\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 3, 32, 32], &device); // Too small (32 < 44)\n        let targets = outputs.clone();\n        let input = MsSsimInput::new(outputs, targets);\n        let _ = metric.update(&input, &MetricMetadata::fake());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/psnr.rs",
    "content": "use crate::metric::{\n    Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry,\n    SerializedEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::ElementConversion,\n};\nuse core::marker::PhantomData;\nuse std::f64::consts::LN_10;\n\n/// Input type for the [PsnrMetric].\n///\n/// Both tensors must have shape `[N, C, H, W]`:\n/// - `N`: Batch size\n/// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.)\n/// - `H`: Height\n/// - `W`: Width\npub struct PsnrInput<B: Backend> {\n    /// Model output (predictions/reconstructions) images with shape `[N, C, H, W]`.\n    outputs: Tensor<B, 4>,\n    /// Ground truth images with shape `[N, C, H, W]`.\n    targets: Tensor<B, 4>,\n}\n\nimpl<B: Backend> PsnrInput<B> {\n    /// Creates a new PsnrInput with the given outputs and targets.\n    ///\n    /// Inputs are expected to have the dimensions `[N, C, H, W]`\n    /// where `N` is the batch size, `C` is the number of channels,\n    /// `H` is the height of the image, and `W` is the width of the image.\n    ///\n    /// # Arguments\n    /// - `outputs`: The model output images with shape `[N, C, H, W]`.\n    /// - `targets`: The ground truth images with shape `[N, C, H, W]`.\n    ///\n    /// # Returns\n    /// A new instance of `PsnrInput`.\n    ///\n    /// # Panics\n    /// - If `outputs` and `targets` do not have the same shape.\n    pub fn new(outputs: Tensor<B, 4>, targets: Tensor<B, 4>) -> Self {\n        assert!(\n            outputs.dims() == targets.dims(),\n            \"Shape mismatch: outputs {:?}, targets {:?}\",\n            outputs.dims(),\n            targets.dims()\n        );\n        Self { outputs, targets }\n    }\n}\n\n/// Configuration for the [PsnrMetric].\n#[derive(Debug, Clone, Copy)]\npub struct PsnrMetricConfig {\n    /// Maximum possible pixel value.\n    /// - Use `1.0` for normalized images in range \\[0, 1\\]\n    /// - Use `255.0` for 8-bit images in range \\[0, 255\\]\n    pub max_pixel_val: f64,\n    /// Epsilon value for numerical stability when MSE is very small or zero.\n    ///\n    /// When MSE falls below this threshold, it is clamped to `epsilon`,\n    /// resulting in a maximum PSNR of approximately `10 * log10(max_pixel_val² / epsilon)` dB.\n    ///\n    /// Default is `1e-10`, which yields ~100 dB for perfect reconstruction with `max_pixel_val = 1.0`.\n    pub epsilon: f64,\n}\n\nimpl PsnrMetricConfig {\n    /// Creates a configuration with the specified maximum pixel value.\n    ///\n    /// # Example\n    /// ```ignore\n    /// // Normalized images [0, 1]\n    /// let config = PsnrMetricConfig::new(1.0);\n    ///\n    /// // 8-bit images [0, 255]  \n    /// let config = PsnrMetricConfig::new(255.0);\n    /// // Also set a custom epsilon value\n    /// let config = PsnrMetricConfig::new(255.0).with_epsilon(1e-8);\n    /// ```\n    pub fn new(max_pixel_val: f64) -> Self {\n        assert!(max_pixel_val > 0.0, \"max_pixel_val must be positive\");\n        Self {\n            max_pixel_val,\n            epsilon: 1e-10,\n        }\n    }\n\n    /// Sets a custom epsilon for numerical stability near zero MSE\n    pub fn with_epsilon(mut self, epsilon: f64) -> Self {\n        assert!(epsilon > 0.0, \"epsilon must be positive\");\n        self.epsilon = epsilon;\n        self\n    }\n}\n\n/// The peak signal-to-noise ratio (PSNR) metric for image quality assessment.\n///\n/// PSNR is commonly used to measure the quality of reconstructed images\n/// compared to the original. Higher values (in dB) indicate better quality.\n///\n/// # Formula\n/// ```text\n/// PSNR = 10 * log10(MAX^2 / MSE)\n/// ```\n/// where MAX is the maximum possible pixel value and MSE is the mean squared error.\n///\n/// # Note\n/// - PSNR is computed for each image first, and then it is averaged across all the images in the batch.\n/// - For perfect reconstruction (MSE = 0), the MSE is clamped to `epsilon` to avoid division by zero,\n///   yielding a maximum PSNR of `10 * log10(MAX^2 / epsilon)` dB.\n#[derive(Clone)]\npub struct PsnrMetric<B: Backend> {\n    name: MetricName,\n    /// Internal state for numeric metric aggregation.\n    state: NumericMetricState,\n    /// Marker for backend type.\n    _b: PhantomData<B>,\n    /// Configuration for the metric.\n    config: PsnrMetricConfig,\n}\n\nimpl<B: Backend> PsnrMetric<B> {\n    /// Creates a new PSNR metric with the given configuration.\n    ///\n    /// # Example\n    /// ```ignore\n    /// let config = PsnrMetricConfig::new(1.0);\n    /// let metric = PsnrMetric::<B>::new(config);\n    /// ```\n    pub fn new(config: PsnrMetricConfig) -> Self {\n        Self {\n            name: MetricName::new(format!(\"PSNR@{}\", config.max_pixel_val)),\n            state: NumericMetricState::default(),\n            config,\n            _b: PhantomData,\n        }\n    }\n\n    /// Overrides the default metric name which is `PSNR@{max_pixel_val}`.\n    ///\n    /// Examples names:\n    /// - `PSNR@1.0`\n    /// - `PSNR@255.0`\n    ///\n    /// Use this method to provide a custom name.\n    pub fn with_name(mut self, name: &str) -> Self {\n        self.name = MetricName::new(name.to_string());\n        self\n    }\n}\n\nimpl<B: Backend> Metric for PsnrMetric<B> {\n    type Input = PsnrInput<B>;\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let dims = item.outputs.dims();\n        let batch_size = dims[0];\n        let outputs = item.outputs.clone();\n        let targets = item.targets.clone();\n\n        // Compute per-image MSE by reducing over all dimensions except batch (dims 1, 2, 3)\n        // Resulting shape: [N, 1, 1, 1]\n        let diff = outputs.sub(targets);\n        let mse_per_image = diff.powi_scalar(2).mean_dims(&[1, 2, 3]);\n        // Flatten to shape: [N]\n        let mse_flat = mse_per_image.flatten::<1>(0, 3);\n        // Clamp MSE to avoid division by 0 in the expression (MAX^2 / MSE)\n        let mse_clamped = mse_flat.clamp_min(self.config.epsilon);\n        let max_squared = self.config.max_pixel_val * self.config.max_pixel_val;\n\n        // Compute PSNR for each image and accumulate\n        // PSNR value in dB (using the change of base formula):\n        // 10 * log10(MAX^2 / MSE) = 10 * ln(MAX^2 / MSE) / ln(10)\n        //                         = ln(MAX^2 / MSE) * (10 / ln(10))\n        let psnr_per_image = mse_clamped\n            .recip()\n            .mul_scalar(max_squared)\n            .log()\n            .mul_scalar(10.0 / LN_10);\n        let avg_psnr = psnr_per_image.mean().into_scalar().elem::<f64>();\n\n        self.state.update(\n            avg_psnr,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"dB\").precision(2),\n        )\n    }\n\n    /// Clears the metric state.\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"dB\".to_string()),\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for PsnrMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, metric::Numeric};\n    use burn_core::tensor::TensorData;\n\n    #[test]\n    fn test_psnr_perfect_reconstruction() {\n        // When outputs exactly match targets, PSNR should be very high\n        // (limited by epsilon clamping to ~100 dB with default epsilon=1e-10)\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[1.0_f32, 0.5], [0.25, 0.75]]]]),\n            &device,\n        );\n        let targets = outputs.clone();\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        // With epsilon = 1e-10 and max=1.0:\n        // PSNR = 10 * log10(1.0 / 1e-10) = 100 dB\n        let psnr = metric.value().current();\n        assert!(\n            psnr >= 99.0,\n            \"PSNR for perfect reconstruction should be ~100 dB, got {} dB\",\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_constant_error() {\n        // Constant error of 0.1 across all pixels\n        // MSE = 0.01, PSNR = 10 * log10(1.0 / 0.01) = 20 dB\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]),\n            &device,\n        );\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        assert!(\n            (psnr - 20.0).abs() < 0.01,\n            \"Expected PSNR ~20 dB, got {} dB\",\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_varying_error() {\n        // Errors: 0.1, 0.2, 0.3, 0.4 → squared: 0.01, 0.04, 0.09, 0.16\n        // MSE = 0.075, PSNR = 10 * log10(1.0 / 0.075) ≈ 11.249 dB\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.1_f32, 0.2], [0.3, 0.4]]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]),\n            &device,\n        );\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 10.0 * (1.0_f64 / 0.075).log10();\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{:.3} dB, got {} dB\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_max_pixel_255() {\n        // Test with 8-bit image range [0, 255]\n        // Error = 10 everywhere, MSE = 100\n        // PSNR = 10 * log10(255^2 / 100) ≈ 28.13 dB\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[10.0_f32, 10.0], [10.0, 10.0]]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]),\n            &device,\n        );\n\n        let config = PsnrMetricConfig::new(255.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 10.0 * (255.0_f64 * 255.0 / 100.0).log10();\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{:.3} dB, got {} dB\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_batch_averaging() {\n        // Batch of 2 images with different MSEs\n        // Image 1: error 0.1 → MSE = 0.01 → PSNR = 20 dB\n        // Image 2: error 0.01 → MSE = 0.0001 → PSNR = 40 dB\n        // Average PSNR = 30 dB\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[[0.1_f32, 0.1], [0.1, 0.1]]],\n                [[[0.01_f32, 0.01], [0.01, 0.01]]],\n            ]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[[0.0_f32, 0.0], [0.0, 0.0]]],\n                [[[0.0_f32, 0.0], [0.0, 0.0]]],\n            ]),\n            &device,\n        );\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 30.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected average PSNR ~{} dB, got {} dB\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_multichannel() {\n        // Test with 3 channels (RGB-like)\n        // All channels have constant error 0.1 → MSE = 0.01 → PSNR = 20 dB\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[\n                [[0.1_f32, 0.1], [0.1, 0.1]],\n                [[0.1_f32, 0.1], [0.1, 0.1]],\n                [[0.1_f32, 0.1], [0.1, 0.1]],\n            ]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::zeros([1, 3, 2, 2], &device);\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 20.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{} dB, got {} dB\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_running_average() {\n        // Test running average across multiple updates\n        let device = Default::default();\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n\n        // First update: error 0.1 → MSE = 0.01 → PSNR = 20 dB\n        let outputs1 = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]),\n            &device,\n        );\n        let targets1 = Tensor::<TestBackend, 4>::zeros([1, 1, 2, 2], &device);\n        let input1 = PsnrInput::new(outputs1, targets1);\n        let _entry = metric.update(&input1, &MetricMetadata::fake());\n\n        let psnr1 = metric.value().current();\n        let expected_psnr1 = 20.0;\n        assert!(\n            (psnr1 - expected_psnr1).abs() < 0.01,\n            \"First update PSNR should be ~{} dB, got {} dB\",\n            expected_psnr1,\n            psnr1\n        );\n\n        // Second update: error 0.01 → MSE = 0.0001 → PSNR = 40 dB\n        let outputs2 = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.01_f32, 0.01], [0.01, 0.01]]]]),\n            &device,\n        );\n        let targets2 = Tensor::<TestBackend, 4>::zeros([1, 1, 2, 2], &device);\n        let input2 = PsnrInput::new(outputs2, targets2);\n        let _entry = metric.update(&input2, &MetricMetadata::fake());\n\n        // Running average: (20 + 40) / 2 = 30 dB\n        let running_avg_psnr = metric.running_value().current();\n        let expected_running_avg_psnr = 30.0;\n        assert!(\n            (running_avg_psnr - expected_running_avg_psnr).abs() < 0.01,\n            \"Running average should be ~{} dB, got {} dB\",\n            expected_running_avg_psnr,\n            running_avg_psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_clear() {\n        // Error 0.1 → MSE = 0.01 → PSNR = 20 dB\n        let device = Default::default();\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::zeros([1, 1, 2, 2], &device);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 20.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{} dB, got {} dB\",\n            expected_psnr,\n            psnr\n        );\n\n        // Clear and verify reset\n        metric.clear();\n        let psnr = metric.running_value().current();\n        assert!(psnr.is_nan(), \"Expected NaN after clear, got {} dB\", psnr)\n    }\n\n    #[test]\n    fn test_psnr_custom_name() {\n        let config = PsnrMetricConfig::new(1.0);\n        let metric = PsnrMetric::<TestBackend>::new(config).with_name(\"CustomPSNR\");\n\n        assert_eq!(metric.name().to_string(), \"CustomPSNR\");\n    }\n\n    #[test]\n    fn test_psnr_custom_epsilon() {\n        let device = Default::default();\n        // With a larger epsilon, perfect reconstruction gives lower PSNR\n        let config = PsnrMetricConfig::new(1.0).with_epsilon(0.01);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.5_f32, 0.5], [0.5, 0.5]]]]),\n            &device,\n        );\n        let targets = outputs.clone();\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        // With epsilon = 0.01, PSNR = 10 * log10(1.0 / 0.01) = 20 dB\n        let psnr = metric.value().current();\n        let expected_psnr = 20.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{} dB with epsilon=0.01, got {}\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_negative_errors() {\n        // Test that negative differences (target > output) work correctly\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.0_f32, 0.0], [0.0, 0.0]]]]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[[0.1_f32, 0.1], [0.1, 0.1]]]]),\n            &device,\n        );\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        // Same MSE as positive errors (0.01), so PSNR = 20 dB\n        let psnr = metric.value().current();\n        let expected_psnr = 20.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{} dB, got {}\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_large_batch() {\n        // Test with a larger batch to verify batch dimension handling\n        let device = Default::default();\n        let batch_size = 8;\n\n        // All images have constant error 0.1 → MSE = 0.01 → PSNR = 20 dB\n        let outputs = Tensor::<TestBackend, 4>::full([batch_size, 3, 4, 4], 0.1, &device);\n        let targets = Tensor::<TestBackend, 4>::zeros([batch_size, 3, 4, 4], &device);\n\n        let config = PsnrMetricConfig::new(1.0);\n        let mut metric = PsnrMetric::<TestBackend>::new(config);\n        let input = PsnrInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let psnr = metric.value().current();\n        let expected_psnr = 20.0;\n        assert!(\n            (psnr - expected_psnr).abs() < 0.01,\n            \"Expected PSNR ~{} dB, got {}\",\n            expected_psnr,\n            psnr\n        );\n    }\n\n    #[test]\n    fn test_psnr_attributes() {\n        let config = PsnrMetricConfig::new(1.0);\n        let metric = PsnrMetric::<TestBackend>::new(config);\n        let attrs = metric.attributes();\n\n        match attrs {\n            MetricAttributes::Numeric(numeric_attrs) => {\n                assert_eq!(numeric_attrs.unit, Some(\"dB\".to_string()));\n                assert!(numeric_attrs.higher_is_better);\n            }\n            _ => panic!(\"Expected numeric attributes\"),\n        }\n    }\n\n    #[test]\n    #[should_panic(expected = \"Shape mismatch\")]\n    fn test_psnr_shape_mismatch() {\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 2, 2], &device);\n        let targets = Tensor::<TestBackend, 4>::zeros([1, 1, 3, 3], &device);\n\n        let _ = PsnrInput::new(outputs, targets);\n    }\n\n    #[test]\n    #[should_panic(expected = \"max_pixel_val must be positive\")]\n    fn test_psnr_negative_max_pixel_val() {\n        let _ = PsnrMetricConfig::new(-1.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"max_pixel_val must be positive\")]\n    fn test_psnr_zero_max_pixel_val() {\n        let _ = PsnrMetricConfig::new(0.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"epsilon must be positive\")]\n    fn test_psnr_negative_epsilon() {\n        let _ = PsnrMetricConfig::new(1.0).with_epsilon(-1e-10);\n    }\n\n    #[test]\n    #[should_panic(expected = \"epsilon must be positive\")]\n    fn test_psnr_zero_epsilon() {\n        let _ = PsnrMetricConfig::new(1.0).with_epsilon(0.0);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/vision/ssim.rs",
    "content": "use crate::metric::{\n    Metric, MetricAttributes, MetricMetadata, MetricName, Numeric, NumericAttributes, NumericEntry,\n    SerializedEntry,\n    state::{FormatOptions, NumericMetricState},\n};\nuse burn_core::{\n    prelude::{Backend, Tensor},\n    tensor::{ElementConversion, module::conv2d, ops::ConvOptions},\n};\nuse core::marker::PhantomData;\n\n/// Input type for the [SsimMetric].\n///\n/// Both tensors must have shape `[N, C, H, W]`:\n/// - `N`: Batch size\n/// - `C`: Number of channels (1 for grayscale, 3 for RGB, etc.)\n/// - `H`: Height\n/// - `W`: Width\npub struct SsimInput<B: Backend> {\n    /// Model output (predictions/reconstructions) images with shape [N, C, H, W].\n    outputs: Tensor<B, 4>,\n    /// Ground truth images with shape [N, C, H, W].\n    targets: Tensor<B, 4>,\n}\n\nimpl<B: Backend> SsimInput<B> {\n    /// Creates a new SsimInput with the given outputs and targets.\n    ///\n    /// Inputs are expected to have the dimensions `[N, C, H, W]`\n    /// where `N` is the batch size, `C` is the number of channels,\n    /// `H` is the height of the image, and `W` is the width of the image.\n    ///\n    /// # Arguments\n    /// - `outputs`: The model output images with shape [N, C, H, W].\n    /// - `targets`: The ground truth images with shape [N, C, H, W].\n    ///\n    /// # Returns\n    /// A new instance of `SsimInput`.\n    ///\n    /// # Panics\n    /// - If `outputs` and `targets` do not have the same shape.\n    pub fn new(outputs: Tensor<B, 4>, targets: Tensor<B, 4>) -> Self {\n        assert!(\n            outputs.dims() == targets.dims(),\n            \"Shape mismatch: outputs {:?}, targets {:?}\",\n            outputs.dims(),\n            targets.dims()\n        );\n        Self { outputs, targets }\n    }\n}\n\n/// Configuration for the [SsimMetric].\n#[derive(Debug, Clone, Copy)]\npub struct SsimMetricConfig {\n    /// The range of the pixel values in images which can be computed as following:\n    /// `let pixel_range = max_pixel_val - min_pixel_val;`\n    /// where `max_pixel_val` is the maximum possible pixel value and `min_pixel_val`\n    /// is the minimum possible pixel value.\n    ///\n    /// - For normalized images in range [0, 1], it should be set to `1.0 - 0.0 = 1.0`\n    /// - For normalized images in range [-1, 1], it should be set to `1.0 - (-1.0) = 2.0`\n    /// - For 8-bit images in range [0, 255], it should be set to `255.0 - 0.0 = 255.0`\n    pub pixel_range: f32,\n    /// A parameter of SSIM used to stabilize the luminance comparison.\n    /// Default is 0.01.\n    pub k1: f32,\n    /// A parameter of SSIM used to stabilize the contrast comparison.\n    /// Default is 0.03.\n    pub k2: f32,\n    /// The SSIM metric involves applying convolution to the input tensors using a Gaussian kernel.\n    /// This is the kernel size of the Gaussian kernel. Default is 11.\n    pub kernel_size: usize,\n    /// The SSIM metric involves applying convolution to the input tensors using a Gaussian kernel.\n    /// This is the standard deviation of the Gaussian kernel. Default is 1.5.\n    pub sigma: f32,\n}\n\nimpl SsimMetricConfig {\n    /// Creates a configuration with the specified data range and default parameters.\n    ///\n    /// # Default parameters\n    /// - k1: 0.01\n    /// - k2: 0.03\n    /// - kernel_size: 11\n    /// - sigma: 1.5\n    ///\n    /// # Panics\n    /// - If `pixel_range` is not positive.\n    ///\n    /// # Example\n    /// ```ignore\n    /// // Normalized images [0, 1]\n    /// let config1 = SsimMetricConfig::new(1.0);\n    ///\n    /// // 8-bit images [0, 255]  \n    /// let config2 = SsimMetricConfig::new(255.0);\n    ///\n    /// // Also set custom values for k1 and k2\n    /// let config3 = SsimMetricConfig::new(1.0).with_k1_k2(0.015, 0.025);\n    ///\n    /// // Also set a custom value for window size\n    /// config3.with_kernel_size(13);\n    /// ```\n    pub fn new(pixel_range: f32) -> Self {\n        assert!(pixel_range > 0.0, \"pixel_range must be positive\");\n        Self {\n            pixel_range: pixel_range,\n            k1: 0.01,\n            k2: 0.03,\n            kernel_size: 11,\n            sigma: 1.5,\n        }\n    }\n\n    /// Sets a custom value for the k1 and k2 parameters of SSIM which are\n    /// used for numerical stability.\n    ///\n    /// # Default values\n    /// - k1: 0.01\n    /// - k2: 0.03\n    ///\n    /// # Panics\n    /// - If `k1` or `k2` is not positive.\n    pub fn with_k1_k2(mut self, k1: f32, k2: f32) -> Self {\n        assert!(k1 > 0.0, \"k1 must be positive\");\n        assert!(k2 > 0.0, \"k2 must be positive\");\n        self.k1 = k1;\n        self.k2 = k2;\n        self\n    }\n\n    /// Sets a custom window size for the Gaussian kernel used in SSIM. The\n    /// window size must be a positive odd number.\n    ///\n    /// # Default value\n    /// - kernel_size: 11\n    ///\n    /// # Panics\n    /// - If `kernel_size` is not a positive odd number.\n    pub fn with_kernel_size(mut self, kernel_size: usize) -> Self {\n        assert!(\n            kernel_size > 0 && kernel_size % 2 == 1,\n            \"kernel_size must be positive and an odd number\"\n        );\n        self.kernel_size = kernel_size;\n        self\n    }\n\n    /// Sets a custom sigma (standard deviation) for the Gaussian kernel used in SSIM.\n    ///\n    /// # Default value\n    /// - sigma: 1.5\n    ///\n    /// # Panics\n    /// - If `sigma` is not positive.\n    pub fn with_sigma(mut self, sigma: f32) -> Self {\n        assert!(sigma > 0.0, \"sigma must be positive\");\n        self.sigma = sigma;\n        self\n    }\n}\n\n/// The SSIM (structural similarity index measure) metric for image quality assessment.\n///\n/// SSIM measures the perceived quality of images by comparing luminance,\n/// contrast, and structure. Values range from -1 to 1, where 1 indicates\n/// perfect structural similarity.\n///\n/// # Formula\n/// ```text\n/// SSIM(x, y) = (2μxμy + C1)(2σxy + C2) / (μx² + μy² + C1)(σx² + σy² + C2)\n/// ```\n///\n/// # Note\n/// - This implementation uses separable Gaussian convolution for efficiency. Instead of a\n///   single 2D convolution with a K by K kernel, it applies two 1D convolutions (horizontal\n///   then vertical). This reduces the computational complexity from O(K^2) to O(2K) per pixel.\n/// - SSIM is computed for each image first, and then it is averaged across all the images in the batch.\n#[derive(Clone)]\npub struct SsimMetric<B: Backend> {\n    name: MetricName,\n    /// Internal state for numeric metric aggregation.\n    state: NumericMetricState,\n    /// Marker for backend type.\n    _b: PhantomData<B>,\n    /// Configuration for the metric.\n    config: SsimMetricConfig,\n}\n\nimpl<B: Backend> SsimMetric<B> {\n    /// Creates a new SSIM metric with the given configuration.\n    ///\n    /// # Note\n    /// The metric name format is \"SSIM (dr={}, w={}, σ={})\"\n    /// where dr is the data range, w is the window size, sigma is the\n    /// standard deviation. For example, the metric name might be\n    /// \"SSIM (dr=1.0, w=11, σ=1.5)\".\n    ///\n    /// # Example\n    /// ```ignore\n    /// let ssim_config = SsimMetricConfig::new(1.0);\n    /// let ssim_metric = SsimMetric::<B>::new(ssim_config);\n    /// ```\n    pub fn new(config: SsimMetricConfig) -> Self {\n        Self {\n            name: MetricName::new(format!(\n                \"SSIM (dr={}, w={}, σ={})\",\n                config.pixel_range, config.kernel_size, config.sigma,\n            )),\n            state: NumericMetricState::default(),\n            config,\n            _b: PhantomData,\n        }\n    }\n\n    /// Overrides the default metric name which is \"SSIM\".\n    pub fn with_name(mut self, name: &str) -> Self {\n        self.name = MetricName::new(name.to_string());\n        self\n    }\n\n    /// Creates a 1D Gaussian kernel as a tensor.\n    ///\n    /// Returns a normalized kernel where all values sum to 1.\n    /// The returned kernel will be reshaped by the `gaussian_conv_separable`\n    /// associated function later.\n    fn create_1d_gaussian_kernel(&self) -> Vec<f32> {\n        let size = self.config.kernel_size;\n        let sigma = self.config.sigma;\n        let center = (size / 2) as f32;\n\n        let mut kernel = vec![0.0f32; size];\n        let mut sum = 0.0f32;\n\n        for (i, v) in kernel.iter_mut().enumerate() {\n            let x = i as f32 - center;\n            let value = (-(x * x) / (2.0 * sigma * sigma)).exp();\n            *v = value;\n            sum += value;\n        }\n\n        // Normalize so values sum to 1\n        for v in kernel.iter_mut() {\n            *v /= sum;\n        }\n\n        kernel\n    }\n\n    /// Applies separable convolution using two 1D Gaussian kernels.\n    ///\n    /// # Arguments\n    /// - `inputs`: Tensor of shape [N, C, H, W]\n    /// - `kernel_1d`: The 1D Gaussian kernel values\n    /// - `channels`: Number of channels for depthwise convolution.\n    fn gaussian_conv_separable(\n        &self,\n        input: Tensor<B, 4>,\n        kernel_1d: &[f32],\n        channels: usize,\n        device: &B::Device,\n    ) -> Tensor<B, 4> {\n        let size = self.config.kernel_size;\n        let padding = size / 2;\n\n        // Create horizontal kernel: shape [C, 1, 1, K]\n        let horizontal_kernel = Tensor::<B, 1>::from_floats(kernel_1d, device)\n            .reshape([1, 1, 1, size]) // [1, 1, 1, K]\n            .repeat_dim(0, channels); // [C, 1, 1, K]\n\n        let vertical_kernel = Tensor::<B, 1>::from_floats(kernel_1d, device)\n            .reshape([1, 1, size, 1]) // [1, 1, K, 1]\n            .repeat_dim(0, channels); // [C, 1, K, 1]\n\n        // Apply horizontal convolution\n        let horizontal_conv_options = ConvOptions::new([1, 1], [0, padding], [1, 1], channels);\n        let input_after_horizontal_conv =\n            conv2d(input, horizontal_kernel, None, horizontal_conv_options);\n\n        // Apply vertical convolution\n        let vertical_conv_options = ConvOptions::new([1, 1], [padding, 0], [1, 1], channels);\n        conv2d(\n            input_after_horizontal_conv,\n            vertical_kernel,\n            None,\n            vertical_conv_options,\n        )\n    }\n}\n\nimpl<B: Backend> Metric for SsimMetric<B> {\n    type Input = SsimInput<B>;\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn update(&mut self, item: &Self::Input, _metadata: &MetricMetadata) -> SerializedEntry {\n        let dims = item.outputs.dims();\n        let batch_size = dims[0];\n        let channels = dims[1];\n        let device = item.outputs.device();\n\n        let img_height = dims[2];\n        let img_width = dims[3];\n        assert!(\n            img_height >= self.config.kernel_size && img_width >= self.config.kernel_size,\n            \"Image dimensions (H={}, W={}) must be >= kernel_size ({})\",\n            img_height,\n            img_width,\n            self.config.kernel_size\n        );\n\n        // Constants in SSIM formula used for numerical stability\n        let c1 = (self.config.k1 * self.config.pixel_range).powi(2);\n        let c2 = (self.config.k2 * self.config.pixel_range).powi(2);\n\n        // Create 1D Gaussian kernel to apply separable convolutions twice (horizontally and vertically)\n        let kernel_1d = self.create_1d_gaussian_kernel();\n\n        // Compute mu_x and mu_y, their product and squares\n        let x = item.outputs.clone();\n        let y = item.targets.clone();\n        let mu_x = self.gaussian_conv_separable(x.clone(), &kernel_1d, channels, &device);\n        let mu_y = self.gaussian_conv_separable(y.clone(), &kernel_1d, channels, &device);\n        let mu_x_mu_y = mu_x.clone() * mu_y.clone();\n        let square_of_mu_x = mu_x.clone() * mu_x.clone();\n        let square_of_mu_y = mu_y.clone() * mu_y.clone();\n\n        // Compute var_x, var_y (which are the same as (sigma_x)^2 and (sigma_y)^2):\n        // Var(X) = E[X^2] - E[X]^2\n        // var_x = mu_of_x_squared - (mu_x * mu_x)\n        let mu_of_x_squared =\n            self.gaussian_conv_separable(x.clone() * x.clone(), &kernel_1d, channels, &device);\n        let mu_of_y_squared =\n            self.gaussian_conv_separable(y.clone() * y.clone(), &kernel_1d, channels, &device);\n        let var_x = (mu_of_x_squared - square_of_mu_x.clone()).clamp_min(0.0);\n        let var_y = (mu_of_y_squared - square_of_mu_y.clone()).clamp_min(0.0);\n\n        // Compute the sample covariance of x and y: sigma_xy\n        // Cov(X, Y) = E[XY] - E[X]E[Y]\n        // sigma_xy = mu_xy - (mu_x * mu_y)\n        let mu_xy = self.gaussian_conv_separable(x * y, &kernel_1d, channels, &device);\n        let sigma_xy = mu_xy - mu_x_mu_y.clone();\n\n        // Compute SSIM:\n        // SSIM(x, y) = (2μxμy + C1)(2σxy + C2) / (μx² + μy² + C1)(σx² + σy² + C2)\n        let numerator = (mu_x_mu_y.mul_scalar(2.0_f32) + c1) * (sigma_xy.mul_scalar(2.0_f32) + c2);\n        let denominator = (square_of_mu_x + square_of_mu_y + c1) * (var_x + var_y + c2);\n        let ssim_tensor = numerator / denominator;\n\n        // Average SSIM across all dimensions to get a single scalar value\n        let ssim_per_image = ssim_tensor.mean_dims(&[1, 2, 3]);\n        let avg_ssim = ssim_per_image.mean().into_scalar().elem::<f64>();\n\n        self.state.update(\n            avg_ssim,\n            batch_size,\n            FormatOptions::new(self.name()).precision(4),\n        )\n    }\n\n    /// Clears the metric state.\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: None,\n            higher_is_better: true,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for SsimMetric<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\n#[allow(clippy::manual_range_contains)]\nmod tests {\n    use super::*;\n    use crate::{TestBackend, metric::Numeric};\n    use burn_core::tensor::{Distribution, Shape, TensorData};\n\n    fn test_config() -> SsimMetricConfig {\n        SsimMetricConfig::new(1.0)\n            .with_kernel_size(3)\n            .with_sigma(1.0)\n    }\n\n    #[test]\n    fn test_ssim_perfect_similarity() {\n        // When outputs exactly match targets, SSIM should be 1.0\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[\n                [0.1_f32, 0.2, 0.3, 0.4],\n                [0.5, 0.6, 0.7, 0.8],\n                [0.2, 0.3, 0.4, 0.5],\n                [0.6, 0.7, 0.8, 0.9],\n            ]]]),\n            &device,\n        );\n        let targets = outputs.clone();\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"SSIM for identical images should be 1.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_completely_different() {\n        // Constant black vs constant white\n        // With constant images: SSIM = (2*mu_x*mu_y + C1) / (mu_x^2 + mu_y^2 + C1)\n        // For x=0, y=1 with C1=(0.01)^2=0.0001: SSIM ≈ 0.0001 / (1 + 0.00001) = 0.00009999\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);\n        let targets = Tensor::<TestBackend, 4>::ones([1, 1, 4, 4], &device);\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            ssim < 0.0001,\n            \"SSIM for black vs white images should be very low, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_similar_images() {\n        // Small perturbation should give high SSIM\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::full([1, 1, 4, 4], 0.5, &device);\n        let targets = Tensor::<TestBackend, 4>::full([1, 1, 4, 4], 0.51, &device);\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            ssim > 0.99,\n            \"SSIM for very similar images should be close to 1.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_batch_averaging() {\n        // Batch of 2 images:\n        // Image 1: identical (SSIM = 1.0)\n        // Image 2: black vs white (SSIM ≈ 0)\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[\n                    [0.5_f32, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                ]],\n                [[\n                    [0.0_f32, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0],\n                ]],\n            ]),\n            &device,\n        );\n        let targets = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([\n                [[\n                    [0.5_f32, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                    [0.5, 0.5, 0.5, 0.5],\n                ]],\n                [[\n                    [1.0_f32, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0],\n                ]],\n            ]),\n            &device,\n        );\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        // Average of ~1.0 and ~0.0 should be around 0.5\n        assert!(\n            ssim > 0.49 && ssim < 0.51,\n            \"Average SSIM should be around 0.5, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_multichannel() {\n        // Test with 3 channels (e.g., RGB)\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[\n                [\n                    [0.5_f32, 0.6, 0.7, 0.8],\n                    [0.4, 0.5, 0.6, 0.7],\n                    [0.3, 0.4, 0.5, 0.6],\n                    [0.2, 0.3, 0.4, 0.5],\n                ],\n                [\n                    [0.3_f32, 0.4, 0.5, 0.6],\n                    [0.2, 0.3, 0.4, 0.5],\n                    [0.1, 0.2, 0.3, 0.4],\n                    [0.0, 0.1, 0.2, 0.3],\n                ],\n                [\n                    [0.7_f32, 0.8, 0.9, 1.0],\n                    [0.6, 0.7, 0.8, 0.9],\n                    [0.5, 0.6, 0.7, 0.8],\n                    [0.4, 0.5, 0.6, 0.7],\n                ],\n            ]]),\n            &device,\n        );\n        let targets = outputs.clone();\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"SSIM for identical RGB images should be 1.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_symmetry() {\n        // SSIM(x, y) should equal SSIM(y, x)\n        // Symmetry is one of the mathematical properties of SSIM\n        let device = Default::default();\n        let img1 = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[\n                [0.1_f32, 0.2, 0.3, 0.4],\n                [0.5, 0.6, 0.7, 0.8],\n                [0.2, 0.3, 0.4, 0.5],\n                [0.6, 0.7, 0.8, 0.9],\n            ]]]),\n            &device,\n        );\n        let img2 = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[\n                [0.2_f32, 0.3, 0.4, 0.5],\n                [0.6, 0.7, 0.8, 0.9],\n                [0.3, 0.4, 0.5, 0.6],\n                [0.7, 0.8, 0.9, 1.0],\n            ]]]),\n            &device,\n        );\n\n        let config = test_config();\n\n        let mut metric1 = SsimMetric::<TestBackend>::new(config);\n        let input1 = SsimInput::new(img1.clone(), img2.clone());\n        let _entry = metric1.update(&input1, &MetricMetadata::fake());\n        let ssim1 = metric1.value().current();\n\n        let mut metric2 = SsimMetric::<TestBackend>::new(config);\n        let input2 = SsimInput::new(img2, img1);\n        let _entry = metric2.update(&input2, &MetricMetadata::fake());\n        let ssim2 = metric2.value().current();\n\n        assert!(\n            (ssim1 - ssim2).abs() < 0.001,\n            \"SSIM should be symmetric: SSIM(x,y)={} vs SSIM(y,x)={}\",\n            ssim1,\n            ssim2\n        );\n    }\n\n    #[test]\n    fn test_ssim_range() {\n        // SSIM values should be in [-1, 1] range\n        let device = Default::default();\n        let shape = Shape::new([1, 1, 11, 11]);\n        let distribution = Distribution::Uniform(0.0, 1.0);\n        let outputs = Tensor::<TestBackend, 4>::random(shape.clone(), distribution, &device);\n        let targets = Tensor::<TestBackend, 4>::random(shape, distribution, &device);\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            ssim >= -1.0 && ssim <= 1.0,\n            \"SSIM should be in range [-1, 1], got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_running_average() {\n        let device = Default::default();\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n\n        // First update: identical images (SSIM = 1.0)\n        let outputs1 = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[\n                [0.5_f32, 0.6, 0.7, 0.8],\n                [0.4, 0.5, 0.6, 0.7],\n                [0.3, 0.4, 0.5, 0.6],\n                [0.2, 0.3, 0.4, 0.5],\n            ]]]),\n            &device,\n        );\n        let targets1 = outputs1.clone();\n        let input1 = SsimInput::new(outputs1, targets1);\n        let _entry = metric.update(&input1, &MetricMetadata::fake());\n\n        let ssim1 = metric.value().current();\n        assert!(\n            (ssim1 - 1.0).abs() < 0.001,\n            \"First update SSIM should be ~1.0, got {}\",\n            ssim1\n        );\n\n        // Second update: very different images (SSIM close to 0)\n        let outputs2 = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);\n        let targets2 = Tensor::<TestBackend, 4>::ones([1, 1, 4, 4], &device);\n        let input2 = SsimInput::new(outputs2, targets2);\n        let _entry = metric.update(&input2, &MetricMetadata::fake());\n\n        // Running average should be around 0.5\n        let running_avg = metric.running_value().current();\n        assert!(\n            running_avg > 0.49 && running_avg < 0.51,\n            \"Running average should be around 0.5, got {}\",\n            running_avg\n        );\n    }\n\n    #[test]\n    fn test_ssim_clear() {\n        let device = Default::default();\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n\n        let outputs = Tensor::<TestBackend, 4>::from_data(\n            TensorData::from([[[\n                [0.5_f32, 0.6, 0.7, 0.8],\n                [0.4, 0.5, 0.6, 0.7],\n                [0.3, 0.4, 0.5, 0.6],\n                [0.2, 0.3, 0.4, 0.5],\n            ]]]),\n            &device,\n        );\n        let targets = outputs.clone();\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"Expected SSIM ~1.0, got {}\",\n            ssim\n        );\n\n        // Clear and verify reset\n        metric.clear();\n        let ssim = metric.running_value().current();\n        assert!(ssim.is_nan(), \"Expected NaN after clear, got {}\", ssim);\n    }\n\n    #[test]\n    fn test_ssim_custom_name() {\n        let config = SsimMetricConfig::new(1.0);\n        let metric = SsimMetric::<TestBackend>::new(config).with_name(\"CustomSSIM\");\n        assert_eq!(metric.name().to_string(), \"CustomSSIM\");\n\n        let metric = SsimMetric::<TestBackend>::new(test_config());\n        assert_eq!(metric.name().to_string(), \"SSIM (dr=1, w=3, σ=1)\");\n\n        let config = SsimMetricConfig::new(255.0);\n        let metric = SsimMetric::<TestBackend>::new(config);\n        assert_eq!(metric.name().to_string(), \"SSIM (dr=255, w=11, σ=1.5)\");\n    }\n\n    #[test]\n    fn test_ssim_pixel_range_255() {\n        // Test with 8-bit image range [0, 255]\n        let device = Default::default();\n        let shape = Shape::new([1, 1, 10, 10]);\n        let distribution = Distribution::Uniform(0.0, 255.0);\n        let outputs = Tensor::<TestBackend, 4>::random(shape.clone(), distribution, &device);\n        let targets = outputs.clone();\n\n        let config = SsimMetricConfig::new(255.0).with_kernel_size(3);\n        let mut metric = SsimMetric::<TestBackend>::new(config);\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"SSIM for identical 8-bit images should be 1.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_large_batch() {\n        let device = Default::default();\n        let shape = Shape::new([20, 3, 30, 30]);\n        let distribution = Distribution::Uniform(0.0, 1.0);\n        let outputs = Tensor::<TestBackend, 4>::random(shape, distribution, &device);\n        let targets = outputs.clone();\n\n        let mut metric = SsimMetric::<TestBackend>::new(test_config());\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"SSIM for identical batch should be 1.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_default_kernel_size() {\n        // Test with default kernel_size=11, need images >= 11x11\n        let device = Default::default();\n        let shape = Shape::new([1, 1, 1080, 1920]);\n        let distribution = Distribution::Uniform(0.0, 1.0);\n        let outputs = Tensor::<TestBackend, 4>::random(shape, distribution, &device);\n        let targets = outputs.clone();\n\n        let config = SsimMetricConfig::new(1.0); // default kernel_size=11\n        let mut metric = SsimMetric::<TestBackend>::new(config);\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n\n        let ssim = metric.value().current();\n        assert!(\n            (ssim - 1.0).abs() < 0.001,\n            \"SSIM with default window size should work and SSIM should be ~0.0, got {}\",\n            ssim\n        );\n    }\n\n    #[test]\n    fn test_ssim_attributes() {\n        let config = SsimMetricConfig::new(1.0);\n        let metric = SsimMetric::<TestBackend>::new(config);\n        let attrs = metric.attributes();\n\n        match attrs {\n            MetricAttributes::Numeric(numeric_attrs) => {\n                assert_eq!(numeric_attrs.unit, None);\n                assert!(numeric_attrs.higher_is_better);\n            }\n            _ => panic!(\"Expected numeric attributes\"),\n        }\n    }\n\n    #[test]\n    #[should_panic(expected = \"Shape mismatch\")]\n    fn test_ssim_shape_mismatch() {\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);\n        let targets = Tensor::<TestBackend, 4>::zeros([1, 1, 5, 5], &device);\n\n        let _ = SsimInput::new(outputs, targets);\n    }\n\n    #[test]\n    #[should_panic(expected = \"Image dimensions (H=4, W=4) must be >= kernel_size (11)\")]\n    fn test_ssim_image_too_small() {\n        let device = Default::default();\n        let outputs = Tensor::<TestBackend, 4>::zeros([1, 1, 4, 4], &device);\n        let targets = outputs.clone();\n\n        // Default kernel_size=11, but image is only 4x4\n        let config = SsimMetricConfig::new(1.0);\n        let mut metric = SsimMetric::<TestBackend>::new(config);\n        let input = SsimInput::new(outputs, targets);\n        let _entry = metric.update(&input, &MetricMetadata::fake());\n    }\n\n    #[test]\n    fn test_ssim_valid_k1_k2() {\n        let config = SsimMetricConfig::new(1.0).with_k1_k2(0.015, 0.035);\n        assert!(\n            config.k1 == 0.015 && config.k2 == 0.035,\n            \"Expected k1=0.015 and k2=0.035, got k1={} and k2={}\",\n            config.k1,\n            config.k2\n        );\n    }\n\n    #[test]\n    #[should_panic(expected = \"pixel_range must be positive\")]\n    fn test_ssim_negative_pixel_range() {\n        let _ = SsimMetricConfig::new(-1.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"pixel_range must be positive\")]\n    fn test_ssim_zero_pixel_range() {\n        let _ = SsimMetricConfig::new(0.0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"k1 must be positive\")]\n    fn test_ssim_negative_k1() {\n        let _ = SsimMetricConfig::new(1.0).with_k1_k2(-0.01, 0.03);\n    }\n\n    #[test]\n    #[should_panic(expected = \"k2 must be positive\")]\n    fn test_ssim_negative_k2() {\n        let _ = SsimMetricConfig::new(1.0).with_k1_k2(0.01, -0.03);\n    }\n\n    #[test]\n    #[should_panic(expected = \"kernel_size must be positive and an odd number\")]\n    fn test_ssim_even_kernel_size() {\n        let _ = SsimMetricConfig::new(1.0).with_kernel_size(10);\n    }\n\n    #[test]\n    #[should_panic(expected = \"kernel_size must be positive and an odd number\")]\n    fn test_ssim_zero_kernel_size() {\n        let _ = SsimMetricConfig::new(1.0).with_kernel_size(0);\n    }\n\n    #[test]\n    #[should_panic(expected = \"sigma must be positive\")]\n    fn test_ssim_negative_sigma() {\n        let _ = SsimMetricConfig::new(1.0).with_sigma(-1.5);\n    }\n\n    #[test]\n    #[should_panic(expected = \"sigma must be positive\")]\n    fn test_ssim_zero_sigma() {\n        let _ = SsimMetricConfig::new(1.0).with_sigma(0.0);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/metric/wer.rs",
    "content": "use super::cer::edit_distance;\nuse super::state::{FormatOptions, NumericMetricState};\nuse super::{MetricMetadata, SerializedEntry};\nuse crate::metric::{\n    Metric, MetricAttributes, MetricName, Numeric, NumericAttributes, NumericEntry,\n};\nuse burn_core::tensor::backend::Backend;\nuse burn_core::tensor::{Int, Tensor};\nuse core::marker::PhantomData;\nuse std::sync::Arc;\n\n// The edit_distance function remains the same as it calculates the Levenshtein distance\n// between two sequences. The \"units\" within the sequences will now be treated as words.\n/// The word error rate (WER) metric, similar to the CER, is defined as the edit distance (e.g. Levenshtein distance) between the predicted\n/// and reference word sequences, divided by the total number of words in the reference. Here, the \"units\" within the sequences are words.\n///\n#[derive(Clone)]\npub struct WordErrorRate<B: Backend> {\n    name: MetricName,\n    state: NumericMetricState,\n    pad_token: Option<usize>,\n    _b: PhantomData<B>,\n}\n\n/// The [word error rate metric](WordErrorRate) input type.\n#[derive(new)]\npub struct WerInput<B: Backend> {\n    /// The predicted token sequences (as a 2-D tensor of token indices).\n    pub outputs: Tensor<B, 2, Int>,\n    /// The target token sequences (as a 2-D tensor of token indices).\n    pub targets: Tensor<B, 2, Int>,\n}\nimpl<B: Backend> Default for WordErrorRate<B> {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl<B: Backend> WordErrorRate<B> {\n    /// Creates the metric.\n    pub fn new() -> Self {\n        Self {\n            name: Arc::new(\"WER\".to_string()),\n            state: NumericMetricState::default(),\n            pad_token: None,\n            _b: PhantomData,\n        }\n    }\n\n    /// Sets the pad token.\n    pub fn with_pad_token(mut self, index: usize) -> Self {\n        self.pad_token = Some(index);\n        self\n    }\n}\n\nimpl<B: Backend> Metric for WordErrorRate<B> {\n    type Input = WerInput<B>;\n\n    fn update(&mut self, input: &WerInput<B>, _metadata: &MetricMetadata) -> SerializedEntry {\n        let outputs = input.outputs.clone();\n        let targets = input.targets.clone();\n        let [batch_size, seq_len] = targets.dims();\n\n        let outputs_data = outputs\n            .to_data()\n            .to_vec::<i64>()\n            .expect(\"Failed to convert outputs to Vec\");\n        let targets_data = targets\n            .to_data()\n            .to_vec::<i64>()\n            .expect(\"Failed to convert targets to Vec\");\n\n        let pad_token = self.pad_token;\n\n        let mut total_edit_distance = 0.0;\n        let mut total_target_length = 0.0;\n\n        // Process each sequence in the batch\n        for i in 0..batch_size {\n            let start = i * seq_len;\n            let end = (i + 1) * seq_len;\n            let output_seq = &outputs_data[start..end];\n            let target_seq = &targets_data[start..end];\n\n            // Handle padding and map elements to i32.\n            // These sequences now represent \"words\" (token IDs).\n            let output_seq_no_pad = match pad_token {\n                Some(pad) => output_seq\n                    .iter()\n                    .take_while(|&&x| x != pad as i64)\n                    .map(|&x| x as i32)\n                    .collect::<Vec<_>>(),\n                None => output_seq.iter().map(|&x| x as i32).collect(),\n            };\n\n            let target_seq_no_pad = match pad_token {\n                Some(pad) => target_seq\n                    .iter()\n                    .take_while(|&&x| x != pad as i64)\n                    .map(|&x| x as i32)\n                    .collect::<Vec<_>>(),\n                None => target_seq.iter().map(|&x| x as i32).collect(),\n            };\n\n            let ed = edit_distance(&target_seq_no_pad, &output_seq_no_pad);\n            total_edit_distance += ed as f64;\n            total_target_length += target_seq_no_pad.len() as f64;\n        }\n\n        // Compute current WER value as a percentage\n        let value = if total_target_length > 0.0 {\n            100.0 * total_edit_distance / total_target_length\n        } else {\n            0.0\n        };\n\n        self.state.update(\n            value,\n            batch_size,\n            FormatOptions::new(self.name()).unit(\"%\").precision(2),\n        )\n    }\n\n    fn name(&self) -> MetricName {\n        self.name.clone()\n    }\n\n    fn clear(&mut self) {\n        self.state.reset();\n    }\n\n    fn attributes(&self) -> MetricAttributes {\n        NumericAttributes {\n            unit: Some(\"%\".to_string()),\n            higher_is_better: false,\n        }\n        .into()\n    }\n}\n\nimpl<B: Backend> Numeric for WordErrorRate<B> {\n    fn value(&self) -> NumericEntry {\n        self.state.current_value()\n    }\n\n    fn running_value(&self) -> NumericEntry {\n        self.state.running_value()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::TestBackend;\n\n    /// Perfect match => WER = 0 %.\n    #[test]\n    fn test_wer_without_padding() {\n        let device = Default::default();\n        let mut metric = WordErrorRate::<TestBackend>::new();\n\n        // Batch size = 2, sequence length = 2\n        let preds = Tensor::from_data([[1, 2], [3, 4]], &device);\n        let tgts = Tensor::from_data([[1, 2], [3, 4]], &device);\n\n        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());\n\n        assert_eq!(0.0, metric.value().current());\n    }\n\n    /// Two word edits in four target words => 50 %.\n    #[test]\n    fn test_wer_without_padding_two_errors() {\n        let device = Default::default();\n        let mut metric = WordErrorRate::<TestBackend>::new();\n\n        // One substitution in each sequence.\n        // Sequence 1: target [1, 3], pred [1, 2] -> 1 error (3 vs 2)\n        // Sequence 2: target [3, 4], pred [3, 5] -> 1 error (4 vs 5)\n        let preds = Tensor::from_data([[1, 2], [3, 5]], &device);\n        let tgts = Tensor::from_data([[1, 3], [3, 4]], &device);\n\n        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());\n\n        // Total errors = 2, Total target words = 4. WER = (2/4) * 100 = 50 %\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    /// Same scenario as above, but with right-padding (token 9) ignored.\n    #[test]\n    fn test_wer_with_padding() {\n        let device = Default::default();\n        let pad = 9_i64;\n        let mut metric = WordErrorRate::<TestBackend>::new().with_pad_token(pad as usize);\n\n        // Each row has three columns, last one is the pad token.\n        // Target sequences after removing pad: [1, 3] and [3, 4] (total length 4)\n        // Predicted sequences after removing pad: [1, 2] and [3, 5]\n        let preds = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);\n        let tgts = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);\n\n        metric.update(&WerInput::new(preds, tgts), &MetricMetadata::fake());\n        assert_eq!(50.0, metric.value().current());\n    }\n\n    /// `clear()` must reset the running statistics to NaN.\n    #[test]\n    fn test_clear_resets_state() {\n        let device = Default::default();\n        let mut metric = WordErrorRate::<TestBackend>::new();\n\n        let preds = Tensor::from_data([[1, 2]], &device);\n        let tgts = Tensor::from_data([[1, 3]], &device); // one error\n\n        metric.update(\n            &WerInput::new(preds.clone(), tgts.clone()),\n            &MetricMetadata::fake(),\n        );\n        assert!(metric.value().current() > 0.0);\n\n        metric.clear();\n        assert!(metric.value().current().is_nan());\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/base.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    LearnerSummary,\n    metric::{MetricDefinition, MetricEntry, NumericEntry},\n};\nuse burn_core::data::dataloader::Progress;\n\n/// Trait for rendering metrics.\npub trait MetricsRendererTraining: Send + Sync {\n    /// Updates the training metric state.\n    ///\n    /// # Arguments\n    ///\n    /// * `state` - The metric state.\n    fn update_train(&mut self, state: MetricState);\n\n    /// Updates the validation metric state.\n    ///\n    /// # Arguments\n    ///\n    /// * `state` - The metric state.\n    fn update_valid(&mut self, state: MetricState);\n\n    /// Renders the training progress.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The training progress.\n    fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);\n\n    /// Renders the validation progress.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The validation progress.\n    fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);\n\n    /// Callback method invoked when training ends, whether it\n    /// completed successfully or was interrupted.\n    ///\n    /// # Returns\n    ///\n    /// A result indicating whether the end-of-training actions were successful.\n    fn on_train_end(\n        &mut self,\n        summary: Option<LearnerSummary>,\n    ) -> Result<(), Box<dyn core::error::Error>> {\n        default_summary_action(summary);\n        Ok(())\n    }\n}\n\n/// A renderer that can be used for both training and evaluation.\npub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {\n    /// Keep the renderer from automatically closing, requiring manual action to close it.\n    fn manual_close(&mut self);\n    /// Register a new metric.\n    fn register_metric(&mut self, definition: MetricDefinition);\n}\n\n#[derive(Clone)]\n/// The name of an evaluation.\n///\n/// This is going to group metrics together for easier analysis.\npub struct EvaluationName {\n    pub(crate) name: Arc<String>,\n}\n\nimpl EvaluationName {\n    /// Creates a new evaluation name.\n    pub fn new<S: core::fmt::Display>(s: S) -> Self {\n        Self {\n            name: Arc::new(format!(\"{s}\")),\n        }\n    }\n\n    /// Returns the evaluation name.\n    pub fn as_str(&self) -> &str {\n        &self.name\n    }\n}\n\nimpl core::fmt::Display for EvaluationName {\n    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n        f.write_str(&self.name)\n    }\n}\n\n/// Trait for rendering metrics.\npub trait MetricsRendererEvaluation: Send + Sync {\n    /// Updates the testing metric state.\n    ///\n    /// # Arguments\n    ///\n    /// * `state` - The metric state.\n    fn update_test(&mut self, name: EvaluationName, state: MetricState);\n    /// Renders the testing progress.\n    ///\n    /// # Arguments\n    ///\n    /// * `item` - The training progress.\n    fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);\n\n    /// Callback method invoked when testing ends, whether it\n    /// completed successfully or was interrupted.\n    ///\n    /// # Returns\n    ///\n    /// A result indicating whether the end-of-testing actions were successful.\n    fn on_test_end(\n        &mut self,\n        summary: Option<LearnerSummary>,\n    ) -> Result<(), Box<dyn core::error::Error>> {\n        default_summary_action(summary);\n        Ok(())\n    }\n}\n\n/// The state of a metric.\n#[derive(Debug)]\npub enum MetricState {\n    /// A generic metric.\n    Generic(MetricEntry),\n    /// A numeric metric.\n    Numeric(MetricEntry, NumericEntry),\n}\n\n/// Training progress.\n#[derive(Debug)]\npub struct TrainingProgress {\n    /// The progress.\n    pub progress: Option<Progress>,\n\n    /// The progress of the whole training.\n    pub global_progress: Progress,\n\n    /// The iteration, if it differs from the items processed.\n    pub iteration: Option<usize>,\n}\n\n/// Evaluation progress.\n#[derive(Debug)]\npub struct EvaluationProgress {\n    /// The progress.\n    pub progress: Progress,\n\n    /// The iteration, if it is different from the processed items.\n    pub iteration: Option<usize>,\n}\n\nimpl From<&EvaluationProgress> for TrainingProgress {\n    fn from(value: &EvaluationProgress) -> Self {\n        TrainingProgress {\n            progress: None,\n            global_progress: value.progress.clone(),\n            iteration: value.iteration,\n        }\n    }\n}\n\nimpl TrainingProgress {\n    /// Creates a new empty training progress.\n    pub fn none() -> Self {\n        Self {\n            progress: None,\n            global_progress: Progress {\n                items_processed: 0,\n                items_total: 0,\n            },\n            iteration: None,\n        }\n    }\n}\n\n/// Type of progress indicators.\npub enum ProgressType {\n    /// Detailed progress.\n    Detailed {\n        /// The tag.\n        tag: String,\n        /// The progress.\n        progress: Progress,\n    },\n    /// Simple value.\n    Value {\n        /// The tag.\n        tag: String,\n        /// The value.\n        value: usize,\n    },\n}\n\nfn default_summary_action(summary: Option<LearnerSummary>) {\n    if let Some(summary) = summary {\n        println!(\"{summary}\");\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/cli.rs",
    "content": "use crate::renderer::{\n    EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,\n    MetricsRendererTraining, ProgressType, TrainingProgress,\n};\n\n/// A simple renderer for when the cli feature is not enabled.\npub struct CliMetricsRenderer;\n\n#[allow(clippy::new_without_default)]\nimpl CliMetricsRenderer {\n    /// Create a new instance.\n    pub fn new() -> Self {\n        Self {}\n    }\n}\n\nimpl MetricsRendererTraining for CliMetricsRenderer {\n    fn update_train(&mut self, _state: MetricState) {}\n\n    fn update_valid(&mut self, _state: MetricState) {}\n\n    fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {\n        println!(\"{item:?}\");\n    }\n\n    fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {\n        println!(\"{item:?}\");\n    }\n}\n\nimpl MetricsRendererEvaluation for CliMetricsRenderer {\n    fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {\n        println!(\"{item:?}\");\n    }\n\n    fn update_test(&mut self, _name: super::EvaluationName, _state: MetricState) {}\n}\n\nimpl MetricsRenderer for CliMetricsRenderer {\n    fn manual_close(&mut self) {\n        // Nothing to do.\n    }\n\n    fn register_metric(&mut self, _definition: crate::metric::MetricDefinition) {}\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/mod.rs",
    "content": "#[cfg(feature = \"tui\")]\nuse std::io::IsTerminal;\n\nmod base;\npub use base::*;\n\npub(crate) mod cli;\n\npub use cli::*;\n\n/// The tui renderer\n#[cfg(feature = \"tui\")]\npub mod tui;\nuse crate::Interrupter;\n\n/// Return the default metrics renderer.\n///\n/// This can be either:\n///   - `TuiMetricsRenderer`, when the `tui` feature is enabled and `stdout` is\n///     a terminal, or\n///   - `CliMetricsRenderer`, when the `tui` feature is not enabled, or `stdout`\n///     is not a terminal.\n#[allow(unused_variables)]\npub(crate) fn default_renderer(\n    interuptor: Interrupter,\n    checkpoint: Option<usize>,\n) -> Box<dyn MetricsRenderer> {\n    #[cfg(feature = \"tui\")]\n    if std::io::stdout().is_terminal() {\n        return Box::new(tui::TuiMetricsRendererWrapper::new(interuptor, checkpoint));\n    }\n\n    Box::new(CliMetricsRenderer::new())\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/base.rs",
    "content": "use std::sync::Arc;\n\nuse super::{\n    ControlsView, NumericMetricView, ProgressBarView, StatusView, TerminalFrame, TextMetricView,\n};\nuse ratatui::{\n    prelude::{Constraint, Direction, Layout, Rect},\n    style::Color,\n};\n\n#[derive(new)]\npub(crate) struct MetricsView<'a> {\n    metric_numeric: NumericMetricView<'a>,\n    metric_text: TextMetricView,\n    progress: ProgressBarView,\n    controls: ControlsView,\n    status: StatusView,\n}\n\nimpl MetricsView<'_> {\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        let chunks = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints([Constraint::Min(16), Constraint::Max(4)].as_ref())\n            .split(size);\n        let size_other = chunks[0];\n        let size_progress = chunks[1];\n\n        let chunks = Layout::default()\n            .direction(Direction::Horizontal)\n            .constraints([Constraint::Percentage(38), Constraint::Percentage(62)].as_ref())\n            .split(size_other);\n        let size_other = chunks[0];\n        let size_metric_numeric = chunks[1];\n\n        let chunks = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints([Constraint::Max(5), Constraint::Min(6), Constraint::Max(6)].as_ref())\n            .split(size_other);\n        let size_controls = chunks[0];\n        let size_metric_text = chunks[1];\n        let size_status = chunks[2];\n\n        self.metric_numeric.render(frame, size_metric_numeric);\n        self.metric_text.render(frame, size_metric_text);\n        self.controls.render(frame, size_controls);\n        self.progress.render(frame, size_progress);\n        self.status.render(frame, size_status);\n    }\n}\n\n#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]\npub(crate) enum TuiSplit {\n    Train,\n    Valid,\n    Test,\n}\n\n#[derive(Hash, Clone, PartialEq, Eq, PartialOrd, Ord)]\npub(crate) enum TuiGroup {\n    Default,\n    Named(Arc<String>),\n}\n\n#[derive(new, Hash, Clone, PartialEq, Eq, PartialOrd, Ord)]\npub(crate) struct TuiTag {\n    pub(crate) split: TuiSplit,\n    pub(crate) group: TuiGroup,\n}\n\nimpl core::fmt::Display for TuiTag {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match &self.group {\n            TuiGroup::Default => f.write_fmt(format_args!(\"{}\", self.split)),\n            TuiGroup::Named(group) => f.write_fmt(format_args!(\"{} - {}\", self.split, group)),\n        }\n    }\n}\nimpl core::fmt::Display for TuiGroup {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            TuiGroup::Default => f.write_str(\"\"),\n            TuiGroup::Named(group) => f.write_fmt(format_args!(\"{group} \")),\n        }\n    }\n}\n\nimpl core::fmt::Display for TuiSplit {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            TuiSplit::Train => f.write_str(\"Train\"),\n            TuiSplit::Valid => f.write_str(\"Valid\"),\n            TuiSplit::Test => f.write_str(\"Test\"),\n        }\n    }\n}\n\nimpl TuiSplit {\n    pub(crate) fn color(&self) -> Color {\n        match self {\n            TuiSplit::Train => Color::LightRed,\n            TuiSplit::Valid => Color::LightBlue,\n            TuiSplit::Test => Color::LightGreen,\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/controls.rs",
    "content": "use super::TerminalFrame;\nuse ratatui::{\n    prelude::{Alignment, Rect},\n    style::{Color, Style, Stylize},\n    text::{Line, Span},\n    widgets::{Block, Borders, Paragraph, Wrap},\n};\n\n/// Controls view.\npub(crate) struct ControlsView;\n\nimpl ControlsView {\n    /// Render the view.\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        let lines = vec![\n            vec![\n                Span::from(\" Quit          : \").yellow().bold(),\n                Span::from(\"q  \").bold(),\n                Span::from(\"  Stop the training.\").italic(),\n            ],\n            vec![\n                Span::from(\" Plots Metrics : \").yellow().bold(),\n                Span::from(\"⬅ ➡\").bold(),\n                Span::from(\"  Switch between metrics.\").italic(),\n            ],\n            vec![\n                Span::from(\" Plots Type    : \").yellow().bold(),\n                Span::from(\"⬆ ⬇\").bold(),\n                Span::from(\"  Switch between types.\").italic(),\n            ],\n        ];\n        let paragraph = Paragraph::new(lines.into_iter().map(Line::from).collect::<Vec<_>>())\n            .alignment(Alignment::Left)\n            .wrap(Wrap { trim: false })\n            .style(Style::default().fg(Color::Gray))\n            .block(\n                Block::default()\n                    .borders(Borders::ALL)\n                    .style(Style::default().fg(Color::Gray))\n                    .title_alignment(Alignment::Left)\n                    .title(\"Controls\"),\n            );\n\n        frame.render_widget(paragraph, size);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/full_history.rs",
    "content": "use super::PlotAxes;\nuse crate::{\n    metric::NumericEntry,\n    renderer::tui::{TuiSplit, TuiTag},\n};\nuse ratatui::{\n    style::{Color, Style},\n    symbols,\n    widgets::{Bar, Dataset, GraphType},\n};\nuse std::collections::BTreeMap;\n\n/// A plot that shows the full history at a reduced resolution.\npub(crate) struct FullHistoryPlot {\n    pub(crate) axes: PlotAxes,\n    points: BTreeMap<TuiTag, FullHistoryPoints>,\n    max_samples: usize,\n    max_samples_ratio: BTreeMap<TuiSplit, f64>,\n    next_x_state: usize,\n}\n\nstruct FullHistoryPoints {\n    min_x: f64,\n    max_x: f64,\n    min_y: f64,\n    max_y: f64,\n    avg_sum: f64,\n    avg_counter: f64,\n    points: Vec<(f64, f64)>,\n    max_samples: usize,\n    step_size: usize,\n}\n\nimpl FullHistoryPlot {\n    /// Create a new history plot.\n    pub(crate) fn new(max_samples: usize) -> Self {\n        Self {\n            points: BTreeMap::default(),\n            axes: PlotAxes::default(),\n            max_samples,\n            max_samples_ratio: BTreeMap::default(),\n            next_x_state: 0,\n        }\n    }\n\n    /// Update the maximum amount of sample to display for the validation points.\n    ///\n    /// This is necessary if we want the validation line to have the same point density as the\n    /// training line.\n    pub(crate) fn update_max_sample(&mut self, split: TuiSplit, ratio: f64) {\n        self.max_samples_ratio.insert(split, ratio);\n\n        self.points\n            .iter_mut()\n            .filter(|(tag, _)| tag.split == split)\n            .for_each(|(_, points)| {\n                points.max_samples = (self.max_samples as f64 * ratio) as usize;\n            });\n    }\n\n    /// Register a training data point.\n    pub(crate) fn push(&mut self, tag: TuiTag, data: NumericEntry) {\n        let x_current = self.next_x();\n        let points = match self.points.get_mut(&tag) {\n            Some(val) => val,\n            None => {\n                let max_samples = self\n                    .max_samples_ratio\n                    .get(&tag.split)\n                    .map(|ratio| (*ratio * self.max_samples as f64) as usize)\n                    .unwrap_or(self.max_samples);\n                self.points\n                    .insert(tag.clone(), FullHistoryPoints::new(max_samples));\n                self.points.get_mut(&tag).unwrap()\n            }\n        };\n\n        points.push((x_current, data));\n\n        self.update_bounds();\n    }\n\n    pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {\n        let mut datasets = Vec::with_capacity(2);\n\n        for (tag, points) in self.points.iter() {\n            datasets.push(points.dataset(format!(\"{tag}\"), tag.split.color()));\n        }\n\n        datasets\n    }\n\n    pub(crate) fn bars(&self, max: u64, bar_width: &mut usize) -> Vec<Bar<'_>> {\n        let mut bars = Vec::new();\n\n        for (tag, points) in self.points.iter() {\n            if let Some((bar, width)) = points.bar(tag, max) {\n                *bar_width = usize::max(*bar_width, width);\n                bars.push(bar);\n            }\n        }\n\n        bars\n    }\n\n    fn next_x(&mut self) -> f64 {\n        let value = self.next_x_state;\n        self.next_x_state += 1;\n        value as f64\n    }\n\n    fn update_bounds(&mut self) {\n        let (mut x_min, mut x_max) = (f64::MAX, f64::MIN);\n        let (mut y_min, mut y_max) = (f64::MAX, f64::MIN);\n\n        for points in self.points.values() {\n            x_min = f64::min(x_min, points.min_x);\n            x_max = f64::max(x_max, points.max_x);\n            y_min = f64::min(y_min, points.min_y);\n            y_max = f64::max(y_max, points.max_y);\n        }\n\n        self.axes.update_bounds((x_min, x_max), (y_min, y_max));\n    }\n}\n\nimpl FullHistoryPoints {\n    fn new(max_samples: usize) -> Self {\n        Self {\n            min_x: 0.,\n            max_x: 0.,\n            min_y: f64::MAX,\n            max_y: f64::MIN,\n            avg_sum: 0.0,\n            avg_counter: 0.0,\n            points: Vec::with_capacity(max_samples),\n            max_samples,\n            step_size: 1,\n        }\n    }\n\n    fn push(&mut self, (x, y): (f64, NumericEntry)) {\n        if !(x as usize).is_multiple_of(self.step_size) {\n            return;\n        }\n\n        let y = match y {\n            NumericEntry::Value(val) => {\n                self.avg_sum += val;\n                self.avg_counter += 1.0;\n                val\n            }\n            NumericEntry::Aggregated {\n                aggregated_value,\n                count,\n            } => {\n                self.avg_sum += aggregated_value * count as f64;\n                self.avg_counter += count as f64;\n                aggregated_value\n            }\n        };\n\n        if x > self.max_x {\n            self.max_x = x;\n        }\n        if x < self.min_x {\n            self.min_x = x;\n        }\n        if y > self.max_y {\n            self.max_y = y;\n        }\n        if y < self.min_y {\n            self.min_y = y\n        }\n\n        self.points.push((x, y));\n\n        if self.points.len() > self.max_samples {\n            self.resize();\n        }\n    }\n\n    /// We keep only half the points and we double the step size.\n    ///\n    /// This ensure that we have the same amount of points across the X axis.\n    fn resize(&mut self) {\n        let mut points = Vec::with_capacity(self.max_samples / 2);\n        let mut max_x = f64::MIN;\n        let mut max_y = f64::MIN;\n        let mut min_x = f64::MAX;\n        let mut min_y = f64::MAX;\n\n        for (i, (x, y)) in self.points.drain(0..self.points.len()).enumerate() {\n            if i % 2 == 0 {\n                if x > max_x {\n                    max_x = x;\n                }\n                if x < min_x {\n                    min_x = x;\n                }\n                if y > max_y {\n                    max_y = y;\n                }\n                if y < min_y {\n                    min_y = y;\n                }\n\n                points.push((x, y));\n            }\n        }\n\n        self.points = points;\n        self.step_size *= 2;\n\n        self.min_x = min_x;\n        self.max_x = max_x;\n        self.min_y = min_y;\n        self.max_y = max_y;\n    }\n\n    fn dataset<'a>(&'a self, name: String, color: Color) -> Dataset<'a> {\n        Dataset::default()\n            .name(name)\n            .marker(symbols::Marker::Braille)\n            .style(Style::default().fg(color).bold())\n            .graph_type(GraphType::Line)\n            .data(&self.points)\n    }\n\n    fn bar<'a>(&'a self, tag: &TuiTag, max: u64) -> Option<(Bar<'a>, usize)> {\n        if self.avg_sum == 0.0 {\n            return None;\n        }\n\n        let label = format!(\"{tag}\");\n        let width = usize::max(label.len(), 7); // 7 min width\n\n        let factor = max as f64;\n\n        let avg = self.avg_sum / self.avg_counter;\n\n        Some((\n            Bar::default()\n                .value((avg * factor) as u64)\n                .style(tag.split.color())\n                .text_value(format!(\"{:.2}\", avg))\n                .label(label),\n            width,\n        ))\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::renderer::tui::{TuiGroup, TuiSplit};\n\n    #[test]\n    fn test_points() {\n        let mut chart = FullHistoryPlot::new(10);\n        let tag_train = TuiTag::new(TuiSplit::Train, TuiGroup::Default);\n        let tag_valid = TuiTag::new(TuiSplit::Valid, TuiGroup::Default);\n        chart.update_max_sample(tag_valid.split, 0.6);\n\n        for i in 0..100 {\n            chart.push(tag_train.clone(), NumericEntry::Value(i as f64));\n        }\n        for i in 0..60 {\n            chart.push(tag_valid.clone(), NumericEntry::Value(i as f64));\n        }\n\n        let expected_train = vec![\n            (0.0, 0.0),\n            (16.0, 16.0),\n            (32.0, 32.0),\n            (48.0, 48.0),\n            (64.0, 64.0),\n            (80.0, 80.0),\n            (96.0, 96.0),\n        ];\n\n        let expected_valid = vec![(100.0, 0.0), (116.0, 16.0), (128.0, 28.0), (144.0, 44.0)];\n\n        assert_eq!(\n            chart.points.get(&tag_train).unwrap().points,\n            expected_train,\n            \"Expected train data points\"\n        );\n        assert_eq!(\n            chart.points.get(&tag_valid).unwrap().points,\n            expected_valid,\n            \"Expected valid data points\"\n        );\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/metric_numeric.rs",
    "content": "use crate::{\n    metric::{MetricName, NumericEntry},\n    renderer::{EvaluationProgress, TrainingProgress, tui::TuiTag},\n};\n\nuse super::{FullHistoryPlot, RecentHistoryPlot, TerminalFrame, TuiSplit};\nuse ratatui::{\n    crossterm::event::{Event, KeyCode, KeyEventKind},\n    prelude::{Alignment, Constraint, Direction, Layout, Rect},\n    style::{Color, Modifier, Style, Stylize},\n    text::Line,\n    widgets::{\n        Axis, BarChart, BarGroup, Block, Borders, Chart, LegendPosition, Padding, Paragraph, Tabs,\n    },\n};\nuse std::collections::BTreeMap;\n\n/// 1000 seems to be required to see some improvement.\nconst MAX_NUM_SAMPLES_RECENT: usize = 1000;\n/// 250 seems to be the right resolution when plotting all history.\n/// Otherwise, there is too much points and the lines arent't smooth enough.\nconst MAX_NUM_SAMPLES_FULL: usize = 250;\n\n/// Numeric metrics state that handles creating plots.\n#[derive(Default)]\npub(crate) struct NumericMetricsState {\n    data: BTreeMap<MetricName, (RecentHistoryPlot, FullHistoryPlot)>,\n    names: Vec<MetricName>,\n    selected: usize,\n    kind: PlotKind,\n    num_samples_train: Option<usize>,\n    num_samples_valid: Option<usize>,\n    num_samples_test: Option<usize>,\n    epoch: usize,\n}\n\n/// The kind of plot to display.\n#[derive(Default, Clone, Copy)]\npub(crate) enum PlotKind {\n    /// Display the full history of the metric with reduced resolution.\n    #[default]\n    Full,\n    /// Display only the recent history of the metric, but with more resolution.\n    Recent,\n    Summary,\n}\n\nimpl NumericMetricsState {\n    /// Register a new training value for the metric with the given name.\n    pub(crate) fn push(&mut self, tag: TuiTag, name: MetricName, data: NumericEntry) {\n        if let Some((recent, full)) = self.data.get_mut(name.as_ref()) {\n            recent.push(tag.clone(), data.current());\n            full.push(tag, data);\n        } else {\n            let mut recent = RecentHistoryPlot::new(MAX_NUM_SAMPLES_RECENT);\n            let mut full = FullHistoryPlot::new(MAX_NUM_SAMPLES_FULL);\n\n            recent.push(tag.clone(), data.current());\n            full.push(tag, data);\n\n            self.names.push(name.clone());\n            self.data.insert(name, (recent, full));\n        }\n    }\n\n    /// Update the state with the training progress.\n    pub(crate) fn update_progress_train(&mut self, progress: &TrainingProgress) {\n        self.epoch = progress.global_progress.items_processed;\n\n        if self.num_samples_train.is_some() {\n            return;\n        }\n\n        // If the training only has the notion of global progress, num_samples_train remains None.\n        self.num_samples_train = progress.progress.as_ref().map(|p| p.items_total);\n    }\n\n    /// Update the state with the validation progress.\n    pub(crate) fn update_progress_valid(&mut self, progress: &TrainingProgress) {\n        if self.num_samples_valid.is_some() {\n            return;\n        }\n\n        // If num_samples_train is None, keep the default max_samples for validation.\n        if let Some(num_sample_train) = self.num_samples_train {\n            for (_, (_recent, full)) in self.data.iter_mut() {\n                let ratio = match &progress.progress {\n                    Some(p) => p.items_total as f64 / num_sample_train as f64,\n                    None => progress.global_progress.items_total as f64 / num_sample_train as f64,\n                };\n\n                full.update_max_sample(TuiSplit::Valid, ratio);\n            }\n        }\n\n        self.epoch = progress.global_progress.items_processed;\n        self.num_samples_valid = progress.progress.as_ref().map(|p| p.items_total);\n    }\n\n    /// Update the state with the testing progress.\n    pub(crate) fn update_progress_test(&mut self, progress: &EvaluationProgress) {\n        if self.num_samples_test.is_some() {\n            return;\n        }\n\n        if let Some(num_sample_train) = self.num_samples_train {\n            for (_, (_recent, full)) in self.data.iter_mut() {\n                let ratio = progress.progress.items_total as f64 / num_sample_train as f64;\n                full.update_max_sample(TuiSplit::Test, ratio);\n            }\n        }\n\n        self.num_samples_test = Some(progress.progress.items_total);\n    }\n\n    /// Create a view to display the numeric metrics.\n    pub(crate) fn view(&self) -> NumericMetricView<'_> {\n        match self.names.is_empty() {\n            true => NumericMetricView::None,\n            false => match self.kind {\n                PlotKind::Summary => {\n                    NumericMetricView::BarPlots(&self.names, self.selected, self.bar_chart())\n                }\n                _ => NumericMetricView::LinePlots(\n                    &self.names,\n                    self.selected,\n                    self.line_chart(),\n                    self.kind,\n                ),\n            },\n        }\n    }\n\n    /// Handle the current event.\n    pub(crate) fn on_event(&mut self, event: &Event) {\n        if let Event::Key(key) = event {\n            match key.kind {\n                KeyEventKind::Release | KeyEventKind::Repeat => (),\n                #[cfg(target_os = \"windows\")] // Fix the double toggle on Windows.\n                KeyEventKind::Press => return,\n                #[cfg(not(target_os = \"windows\"))]\n                KeyEventKind::Press => (),\n            }\n            match key.code {\n                KeyCode::Right => self.next_metric(),\n                KeyCode::Left => self.previous_metric(),\n                KeyCode::Up => self.switch_kind(),\n                KeyCode::Down => self.switch_kind(),\n                _ => {}\n            }\n        }\n    }\n\n    fn switch_kind(&mut self) {\n        self.kind = match self.kind {\n            PlotKind::Full => PlotKind::Recent,\n            PlotKind::Recent => PlotKind::Summary,\n            PlotKind::Summary => PlotKind::Full,\n        };\n    }\n\n    fn next_metric(&mut self) {\n        self.selected = (self.selected + 1) % {\n            let this = &self;\n            this.data.len()\n        };\n    }\n\n    fn previous_metric(&mut self) {\n        if self.selected > 0 {\n            self.selected -= 1;\n        } else {\n            self.selected = ({\n                let this = &self;\n                this.data.len()\n            }) - 1;\n        }\n    }\n\n    fn line_chart<'a>(&'a self) -> Chart<'a> {\n        let name = self.names.get(self.selected).unwrap();\n        let (recent, full) = self.data.get(name).unwrap();\n\n        let (datasets, axes) = match self.kind {\n            PlotKind::Full => (full.datasets(), &full.axes),\n            PlotKind::Recent => (recent.datasets(), &recent.axes),\n            _ => unreachable!(),\n        };\n\n        Chart::<'a>::new(datasets)\n            .block(Block::default())\n            .x_axis(\n                Axis::default()\n                    .style(Style::default().fg(Color::DarkGray))\n                    .title(\"Iteration\")\n                    .labels(axes.labels_x.clone().into_iter().map(|s| s.bold()))\n                    .bounds(axes.bounds_x),\n            )\n            .y_axis(\n                Axis::default()\n                    .style(Style::default().fg(Color::DarkGray))\n                    .labels(axes.labels_y.clone().into_iter().map(|s| s.bold()))\n                    .bounds(axes.bounds_y),\n            )\n            .legend_position(Some(LegendPosition::Right))\n    }\n\n    fn bar_chart<'a>(&'a self) -> BarChart<'a> {\n        let name = self.names.get(self.selected).unwrap();\n        let (_recent, full) = self.data.get(name).unwrap();\n        let mut bar_width = 0;\n        let bars = full.bars(100, &mut bar_width);\n\n        let data = BarGroup::default().bars(&bars);\n        BarChart::default()\n            .block(Block::default().padding(Padding::new(2, 2, 2, 0)))\n            .bar_width(bar_width as u16)\n            .bar_gap(2)\n            .data(data)\n    }\n}\n\n#[allow(clippy::large_enum_variant)]\n#[derive(new)]\npub(crate) enum NumericMetricView<'a> {\n    LinePlots(&'a [MetricName], usize, Chart<'a>, PlotKind),\n    BarPlots(&'a [MetricName], usize, BarChart<'a>),\n    None,\n}\n\nimpl NumericMetricView<'_> {\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        match self {\n            Self::LinePlots(titles, selected, chart, kind) => {\n                let block = Block::default()\n                    .borders(Borders::ALL)\n                    .title(\"Plots\")\n                    .title_alignment(Alignment::Left);\n                let size_new = block.inner(size);\n                frame.render_widget(block, size);\n\n                let size = size_new;\n\n                let chunks = Layout::default()\n                    .direction(Direction::Vertical)\n                    .constraints(\n                        [\n                            Constraint::Length(2),\n                            Constraint::Length(1),\n                            Constraint::Min(0),\n                        ]\n                        .as_ref(),\n                    )\n                    .split(size);\n\n                let tabs = Tabs::new(\n                    titles\n                        .iter()\n                        .map(|i| Line::from(vec![i.to_string().yellow()])),\n                )\n                .select(selected)\n                .style(Style::default())\n                .highlight_style(\n                    Style::default()\n                        .add_modifier(Modifier::BOLD)\n                        .add_modifier(Modifier::UNDERLINED)\n                        .fg(Color::LightYellow),\n                );\n                let title = match kind {\n                    PlotKind::Full => \"Full History\",\n                    PlotKind::Recent => \"Recent History\",\n                    _ => unreachable!(),\n                };\n\n                let plot_type =\n                    Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);\n\n                frame.render_widget(tabs, chunks[0]);\n                frame.render_widget(plot_type, chunks[1]);\n                frame.render_widget(chart, chunks[2]);\n            }\n            Self::BarPlots(titles, selected, chart) => {\n                let block = Block::default()\n                    .borders(Borders::ALL)\n                    .title(\"Summary\")\n                    .title_alignment(Alignment::Left);\n                let size_new = block.inner(size);\n                frame.render_widget(block, size);\n\n                let size = size_new;\n\n                let chunks = Layout::default()\n                    .direction(Direction::Vertical)\n                    .constraints([\n                        Constraint::Length(2),\n                        Constraint::Length(1),\n                        Constraint::Min(0),\n                    ])\n                    .split(size);\n\n                let tabs = Tabs::new(\n                    titles\n                        .iter()\n                        .map(|i| Line::from(vec![i.to_string().yellow()])),\n                )\n                .select(selected)\n                .style(Style::default())\n                .highlight_style(\n                    Style::default()\n                        .add_modifier(Modifier::BOLD)\n                        .add_modifier(Modifier::UNDERLINED)\n                        .fg(Color::LightYellow),\n                );\n                let title = \"Summary\";\n\n                let plot_type =\n                    Paragraph::new(Line::from(title.bold())).alignment(Alignment::Center);\n\n                frame.render_widget(tabs, chunks[0]);\n                frame.render_widget(plot_type, chunks[1]);\n                frame.render_widget(chart, chunks[2]);\n            }\n            Self::None => {}\n        };\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/metric_text.rs",
    "content": "use super::TerminalFrame;\nuse crate::{\n    metric::{MetricEntry, MetricName},\n    renderer::tui::{TuiGroup, TuiSplit},\n};\nuse ratatui::{\n    prelude::{Alignment, Rect},\n    style::{Color, Style, Stylize},\n    text::{Line, Span},\n    widgets::{Block, Borders, Paragraph, Wrap},\n};\nuse std::{collections::BTreeMap, sync::Arc};\n\n#[derive(Default)]\npub(crate) struct TextMetricsState {\n    data: BTreeMap<String, MetricGroup>,\n    names: Vec<MetricName>,\n}\n\nstruct MetricGroup {\n    groups: BTreeMap<TuiGroup, MetricSplits>,\n}\n\nimpl MetricGroup {\n    fn new(group: TuiGroup, metric: MetricSplits) -> Self {\n        Self {\n            groups: BTreeMap::from_iter(Some((group, metric))),\n        }\n    }\n    fn update(&mut self, split: TuiSplit, group: TuiGroup, metric: MetricEntry) {\n        match self.groups.get_mut(&group) {\n            Some(value) => value.update(split, metric),\n            None => {\n                let value = MetricSplits::new(split, metric);\n\n                self.groups.insert(group, value);\n            }\n        }\n    }\n}\n\nstruct MetricSplits {\n    splits: BTreeMap<TuiSplit, MetricEntry>,\n}\n\nimpl MetricSplits {\n    fn new(split: TuiSplit, metric: MetricEntry) -> Self {\n        Self {\n            splits: BTreeMap::from_iter(Some((split, metric))),\n        }\n    }\n\n    fn update(&mut self, split: TuiSplit, metric: MetricEntry) {\n        self.splits.insert(split, metric);\n    }\n}\n\nimpl TextMetricsState {\n    pub(crate) fn update(\n        &mut self,\n        split: TuiSplit,\n        group: TuiGroup,\n        metric: MetricEntry,\n        name: Arc<String>,\n    ) {\n        if let Some(existing) = self.data.get_mut(name.as_ref()) {\n            existing.update(split, group, metric);\n        } else {\n            let key = name.clone();\n            let value = MetricSplits::new(split, metric);\n\n            self.names.push(key.clone());\n            self.data\n                .insert(key.to_string(), MetricGroup::new(group, value));\n        }\n    }\n    pub(crate) fn view(&self) -> TextMetricView {\n        TextMetricView::new(&self.names, &self.data)\n    }\n}\n\npub(crate) struct TextMetricView {\n    lines: Vec<Vec<Span<'static>>>,\n}\n\nimpl TextMetricView {\n    fn new(names: &[MetricName], data: &BTreeMap<String, MetricGroup>) -> Self {\n        let mut lines = Vec::with_capacity(names.len() * 4);\n\n        let start_line = |title: &str| vec![Span::from(format!(\" {title} \")).bold().yellow()];\n        let format_line = |group: &TuiGroup, split: &TuiSplit, formatted: &str| {\n            vec![\n                Span::from(format!(\" {group}{split} \")).bold(),\n                Span::from(formatted.to_string()).italic(),\n            ]\n        };\n\n        for name in names {\n            lines.push(start_line(name));\n\n            let entry = data.get(name.as_ref()).unwrap();\n\n            for (name, group) in entry.groups.iter() {\n                for (split, entry) in group.splits.iter() {\n                    lines.push(format_line(name, split, &entry.serialized_entry.formatted));\n                }\n            }\n\n            lines.push(vec![Span::from(\"\")]);\n        }\n\n        Self { lines }\n    }\n\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())\n            .alignment(Alignment::Left)\n            .wrap(Wrap { trim: false })\n            .block(Block::default().borders(Borders::ALL).title(\"Metrics\"))\n            .style(Style::default().fg(Color::Gray));\n\n        frame.render_widget(paragraph, size);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/mod.rs",
    "content": "mod base;\nmod controls;\nmod full_history;\nmod metric_numeric;\nmod metric_text;\nmod plot_utils;\nmod popup;\nmod progress;\nmod recent_history;\nmod renderer;\nmod status;\n\npub(crate) use base::*;\npub(crate) use controls::*;\npub(crate) use full_history::*;\npub(crate) use metric_numeric::*;\npub(crate) use metric_text::*;\npub(crate) use plot_utils::*;\npub(crate) use popup::*;\npub(crate) use progress::*;\npub(crate) use recent_history::*;\npub use renderer::*;\npub(crate) use status::*;\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/plot_utils.rs",
    "content": "use crate::metric::format_float;\n\nconst AXIS_TITLE_PRECISION: usize = 2;\n\n/// The data describing both X and Y axes.\npub(crate) struct PlotAxes {\n    pub(crate) labels_x: Vec<String>,\n    pub(crate) labels_y: Vec<String>,\n    pub(crate) bounds_x: [f64; 2],\n    pub(crate) bounds_y: [f64; 2],\n}\n\nimpl Default for PlotAxes {\n    fn default() -> Self {\n        Self {\n            bounds_x: [f64::MAX, f64::MIN],\n            bounds_y: [f64::MAX, f64::MIN],\n            labels_x: Vec::new(),\n            labels_y: Vec::new(),\n        }\n    }\n}\n\nimpl PlotAxes {\n    /// Update the bounds based on the min max of each X and Y axes with both train and valid data.\n    pub(crate) fn update_bounds(&mut self, (x_min, x_max): (f64, f64), (y_min, y_max): (f64, f64)) {\n        self.bounds_x = [x_min, x_max];\n        self.bounds_y = [y_min, y_max];\n\n        // We know x are integers.\n        self.labels_x = vec![format!(\"{x_min}\"), format!(\"{x_max}\")];\n        self.labels_y = vec![\n            format_float(y_min, AXIS_TITLE_PRECISION),\n            format_float(y_max, AXIS_TITLE_PRECISION),\n        ];\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/popup.rs",
    "content": "use ratatui::{\n    crossterm::event::{Event, KeyCode},\n    prelude::{Alignment, Constraint, Direction, Layout, Rect},\n    style::{Color, Modifier, Style, Stylize},\n    text::{Line, Span},\n    widgets::{Block, Borders, Paragraph, Wrap},\n};\n\nuse super::TerminalFrame;\n\n/// Popup callback function.\npub(crate) trait CallbackFn: Send + Sync {\n    /// Call the function and return if the popup state should be reset.\n    fn call(&self) -> bool;\n}\n\n/// Popup callback.\npub(crate) struct Callback {\n    title: String,\n    description: String,\n    trigger: char,\n    callback: Box<dyn CallbackFn>,\n}\n\nimpl Callback {\n    /// Create a new popup.\n    pub(crate) fn new<T, D, C>(title: T, description: D, trigger: char, callback: C) -> Self\n    where\n        T: Into<String>,\n        D: Into<String>,\n        C: CallbackFn + 'static,\n    {\n        Self {\n            title: title.into(),\n            description: description.into(),\n            trigger,\n            callback: Box::new(callback),\n        }\n    }\n}\n\n/// Popup state.\npub(crate) enum PopupState {\n    Empty,\n    Full(String, Vec<Callback>),\n}\n\nimpl PopupState {\n    /// If the popup is empty.\n    pub(crate) fn is_empty(&self) -> bool {\n        matches!(&self, PopupState::Empty)\n    }\n    /// Handle popup events.\n    pub(crate) fn on_event(&mut self, event: &Event) {\n        let mut reset = false;\n\n        match self {\n            PopupState::Empty => {}\n            PopupState::Full(_, callbacks) => {\n                for callback in callbacks.iter() {\n                    if let Event::Key(key) = event\n                        && let KeyCode::Char(key) = &key.code\n                        && &callback.trigger == key\n                        && callback.callback.call()\n                    {\n                        reset = true;\n                    }\n                }\n            }\n        };\n\n        if reset {\n            *self = Self::Empty;\n        }\n    }\n    /// Create the popup view.\n    pub(crate) fn view(&self) -> Option<PopupView<'_>> {\n        match self {\n            PopupState::Empty => None,\n            PopupState::Full(title, callbacks) => Some(PopupView::new(title, callbacks)),\n        }\n    }\n}\n\n#[derive(new)]\npub(crate) struct PopupView<'a> {\n    title: &'a String,\n    callbacks: &'a [Callback],\n}\n\nimpl<'a> PopupView<'a> {\n    /// Render the view.\n    pub(crate) fn render<'b>(&'a self, frame: &mut TerminalFrame<'b>, size: Rect) {\n        let lines = self\n            .callbacks\n            .iter()\n            .flat_map(|callback| {\n                vec![\n                    Line::from(vec![\n                        Span::from(format!(\"[{}] \", callback.trigger)).bold(),\n                        Span::from(format!(\"{} \", callback.title)).yellow().bold(),\n                    ]),\n                    Line::from(Span::from(\"\")),\n                    Line::from(Span::from(callback.description.to_string()).italic()),\n                    Line::from(Span::from(\"\")),\n                ]\n            })\n            .collect::<Vec<_>>();\n\n        let paragraph = Paragraph::new(lines)\n            .alignment(Alignment::Left)\n            .wrap(Wrap { trim: false })\n            .style(Style::default().fg(Color::Gray))\n            .block(\n                Block::default()\n                    .borders(Borders::ALL)\n                    .title_alignment(Alignment::Center)\n                    .style(Style::default().fg(Color::Gray))\n                    .title(Span::styled(\n                        self.title,\n                        Style::default().add_modifier(Modifier::BOLD),\n                    )),\n            );\n\n        let area = centered_percent(20, size, Direction::Horizontal);\n        let area = centered_percent(20, area, Direction::Vertical);\n\n        frame.render_widget(paragraph, area);\n    }\n}\n\n/// The percent represents the amount of space that will be taken by each side.\nfn centered_percent(percent: u16, size: Rect, direction: Direction) -> Rect {\n    let center = 100 - (percent * 2);\n\n    Layout::default()\n        .direction(direction)\n        .constraints([\n            Constraint::Percentage(percent),\n            Constraint::Percentage(center),\n            Constraint::Percentage(percent),\n        ])\n        .split(size)[1]\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/progress.rs",
    "content": "use super::TerminalFrame;\nuse crate::renderer::{EvaluationProgress, TrainingProgress, tui::TuiSplit};\nuse ratatui::{\n    prelude::{Alignment, Constraint, Direction, Layout, Rect},\n    style::{Color, Style, Stylize},\n    text::{Line, Span},\n    widgets::{Block, Borders, Gauge, Paragraph},\n};\nuse std::time::{Duration, Instant};\n\n/// Simple progress bar for the training.\n///\n/// We currently ignore the time taken for the validation part.\npub(crate) struct ProgressBarState {\n    progress_total: f64, // Progress for total execution.\n    progress_task: f64,  // Progress for current task.\n    split: TuiSplit,\n    starting_epoch: usize,\n    estimate: ProgressEstimate,\n}\n\nconst MINUTE: u64 = 60;\nconst HOUR: u64 = 60 * 60;\nconst DAY: u64 = 24 * 60 * 60;\n\nimpl ProgressBarState {\n    pub fn new(checkpoint: Option<usize>) -> Self {\n        Self {\n            progress_total: 0.0,\n            progress_task: 0.0,\n            split: TuiSplit::Train,\n            estimate: ProgressEstimate::new(),\n            starting_epoch: checkpoint.unwrap_or(0),\n        }\n    }\n    /// Update the training progress.\n    pub(crate) fn update_train(&mut self, progress: &TrainingProgress) {\n        self.progress_total = calculate_progress(progress, 0, 0);\n        let local_progress = progress\n            .progress\n            .as_ref()\n            .unwrap_or(&progress.global_progress);\n        self.progress_task =\n            local_progress.items_processed as f64 / local_progress.items_total as f64;\n        self.estimate.update(progress, self.starting_epoch);\n        self.split = TuiSplit::Train;\n    }\n\n    /// Update the validation progress.\n    pub(crate) fn update_valid(&mut self, progress: &TrainingProgress) {\n        // We don't use the validation for the total progress yet.\n        let local_progress = progress\n            .progress\n            .as_ref()\n            .unwrap_or(&progress.global_progress);\n        self.progress_task =\n            local_progress.items_processed as f64 / local_progress.items_total as f64;\n        self.split = TuiSplit::Valid;\n    }\n\n    /// Update the testing progress.\n    pub(crate) fn update_test(&mut self, progress: &EvaluationProgress) {\n        // We don't use the testing for the total progress yet.\n        self.progress_task =\n            progress.progress.items_processed as f64 / progress.progress.items_total as f64;\n        self.split = TuiSplit::Test;\n    }\n\n    /// Create a view for the current progress.\n    pub(crate) fn view(&self) -> ProgressBarView {\n        const NO_ETA: &str = \"---\";\n\n        let eta = match self.estimate.secs() {\n            Some(eta) => format_eta(eta),\n            None => NO_ETA.to_string(),\n        };\n        ProgressBarView::new(\n            self.progress_total,\n            self.progress_task,\n            self.split.color(),\n            eta,\n        )\n    }\n}\n\n#[derive(new)]\npub(crate) struct ProgressBarView {\n    progress: f64,\n    progress_task: f64,\n    color_task: Color,\n    eta: String,\n}\n\nimpl ProgressBarView {\n    /// Render the view.\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        let block = Block::default()\n            .borders(Borders::ALL)\n            .title(\"Progress\")\n            .title_alignment(Alignment::Left);\n        let size_new = block.inner(size);\n        frame.render_widget(block, size);\n        let size = size_new;\n\n        let chunks = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints([Constraint::Ratio(1, 2), Constraint::Ratio(1, 2)].as_ref())\n            .split(size);\n\n        let size_task = chunks[0];\n        let size_total = chunks[1];\n\n        let calculate_size = |size: Rect| {\n            Layout::default()\n                .direction(Direction::Horizontal)\n                .constraints(\n                    [\n                        Constraint::Length(1), // Empty space\n                        Constraint::Min(0),\n                        Constraint::Length(self.eta.len() as u16 + 4),\n                    ]\n                    .as_ref(),\n                )\n                .split(size)\n        };\n\n        let chunks = calculate_size(size_total);\n        let size_gauge_total = chunks[1];\n        let size_eta = chunks[2];\n        let chunks = calculate_size(size_task);\n        let size_gauge_task = chunks[1];\n\n        let progress_total = Gauge::default()\n            .gauge_style(Style::default().fg(Color::Yellow))\n            .ratio(self.progress.min(1.0));\n        let progress_task = Gauge::default()\n            .gauge_style(Style::default().fg(self.color_task))\n            .ratio(self.progress_task.min(1.0));\n\n        let eta = Paragraph::new(Line::from(vec![\n            Span::from(\" (\"),\n            Span::from(self.eta).italic(),\n            Span::from(\") \"),\n        ]));\n\n        frame.render_widget(progress_task, size_gauge_task);\n        frame.render_widget(progress_total, size_gauge_total);\n        frame.render_widget(eta, size_eta);\n    }\n}\n\nstruct ProgressEstimate {\n    started: Instant,\n    started_after_warmup: Option<Instant>,\n    warmup_num_items: usize,\n    progress: f64,\n}\n\nimpl ProgressEstimate {\n    fn new() -> Self {\n        Self {\n            started: Instant::now(),\n            started_after_warmup: None,\n            warmup_num_items: 0,\n            progress: 0.0,\n        }\n    }\n\n    fn secs(&self) -> Option<u64> {\n        let eta = self.started_after_warmup?.elapsed();\n\n        let total_estimated = (eta.as_secs() as f64) / self.progress;\n\n        if total_estimated.is_normal() {\n            let remaining = 1.0 - self.progress;\n            let eta = (total_estimated * remaining) as u64;\n            Some(eta)\n        } else {\n            None\n        }\n    }\n\n    fn update(&mut self, progress: &TrainingProgress, starting_epoch: usize) {\n        if self.started_after_warmup.is_some() {\n            self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);\n            return;\n        }\n\n        const WARMUP_NUM_ITERATION: usize = 10;\n\n        // When the training has started since 30 seconds.\n        if self.started.elapsed() > Duration::from_secs(30) {\n            self.init(progress, starting_epoch);\n            return;\n        }\n\n        // When the training has started since at least 10 seconds and completed 10 iterations.\n        if progress.iteration >= Some(WARMUP_NUM_ITERATION)\n            && self.started.elapsed() > Duration::from_secs(10)\n        {\n            self.init(progress, starting_epoch);\n        }\n    }\n\n    fn init(&mut self, progress: &TrainingProgress, starting_epoch: usize) {\n        let epoch = progress.global_progress.items_processed - starting_epoch;\n\n        self.warmup_num_items = match &progress.progress {\n            Some(local_progress) => {\n                let epoch_items = (epoch - 1) * local_progress.items_total;\n                let iteration_items = local_progress.items_processed;\n                epoch_items + iteration_items\n            }\n            None => epoch,\n        };\n\n        self.started_after_warmup = Some(Instant::now());\n        self.progress = calculate_progress(progress, starting_epoch, self.warmup_num_items);\n    }\n}\n\nfn calculate_progress(\n    progress: &TrainingProgress,\n    starting_epoch: usize,\n    ignore_num_items: usize,\n) -> f64 {\n    let epoch_total = progress.global_progress.items_total - starting_epoch;\n    let epoch = progress.global_progress.items_processed - starting_epoch;\n    match &progress.progress {\n        Some(local_progress) => {\n            let total_items = local_progress.items_total * epoch_total;\n            let epoch_items = (epoch - 1) * local_progress.items_total;\n            let iteration_items = local_progress.items_processed;\n            let num_items = epoch_items + iteration_items - ignore_num_items;\n\n            num_items as f64 / total_items as f64\n        }\n        None => epoch as f64 / epoch_total as f64,\n    }\n}\n\nfn format_eta(eta_secs: u64) -> String {\n    let seconds = eta_secs % 60;\n    let minutes = eta_secs / MINUTE % 60;\n    let hours = eta_secs / HOUR % 24;\n    let days = eta_secs / DAY;\n\n    if days > 1 {\n        format!(\"{days} days\")\n    } else if days == 1 {\n        \"1 day\".to_string()\n    } else if hours > 1 {\n        format!(\"{hours} hours\")\n    } else if hours == 1 {\n        \"1 hour\".to_string()\n    } else if minutes > 1 {\n        format!(\"{minutes} mins\")\n    } else if minutes == 1 {\n        \"1 min\".to_string()\n    } else if seconds > 1 {\n        format!(\"{seconds} secs\")\n    } else {\n        \"1 sec\".to_string()\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_core::data::dataloader::Progress;\n\n    #[test]\n    fn test_format_eta() {\n        assert_eq!(\"55 secs\", format_eta(55), \"Less than 1 minutes\");\n        assert_eq!(\"1 min\", format_eta(61), \"More than 1 minutes\");\n        assert_eq!(\"2 mins\", format_eta(2 * 61), \"More than 2 minutes\");\n        assert_eq!(\"1 hour\", format_eta(3601), \"More than 1 hour\");\n        assert_eq!(\"2 hours\", format_eta(2 * 3601), \"More than 2 hour\");\n        assert_eq!(\"1 day\", format_eta(24 * 3601), \"More than 1 day\");\n        assert_eq!(\"2 days\", format_eta(48 * 3601), \"More than 2 day\");\n    }\n\n    #[test]\n    fn calculate_progress_for_eta() {\n        let half = Progress {\n            items_processed: 5,\n            items_total: 10,\n        };\n        let global_progress = Progress {\n            items_processed: 9,\n            items_total: 10,\n        };\n        let progress = TrainingProgress {\n            progress: Some(half),\n            global_progress,\n            iteration: Some(500),\n        };\n\n        let starting_epoch = 8;\n        let progress = calculate_progress(&progress, starting_epoch, 0);\n\n        // Two epochs remaining while the first is half done.\n        assert_eq!(0.25, progress);\n    }\n\n    #[test]\n    fn calculate_progress_for_eta_with_warmup() {\n        let half = Progress {\n            items_processed: 110,\n            items_total: 1000,\n        };\n        let global_progress = Progress {\n            items_processed: 9,\n            items_total: 10,\n        };\n        let progress = TrainingProgress {\n            progress: Some(half),\n            global_progress,\n            iteration: Some(500),\n        };\n\n        let starting_epoch = 8;\n        let progress = calculate_progress(&progress, starting_epoch, 10);\n\n        // Two epochs remaining while the first is half done.\n        assert_eq!(0.05, progress);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/recent_history.rs",
    "content": "use super::PlotAxes;\nuse crate::renderer::tui::TuiTag;\nuse ratatui::{\n    style::{Color, Style},\n    symbols,\n    widgets::{Dataset, GraphType},\n};\nuse std::collections::BTreeMap;\n\nconst FACTOR_BEFORE_RESIZE: usize = 2;\n\n/// A plot that shows the recent history at full resolution.\npub(crate) struct RecentHistoryPlot {\n    pub(crate) axes: PlotAxes,\n    points: BTreeMap<TuiTag, RecentHistoryPoints>,\n    max_samples: usize,\n}\n\nstruct RecentHistoryPoints {\n    min_x: f64,\n    max_x: f64,\n    min_y: f64,\n    max_y: f64,\n    cursor: usize,\n    points: Vec<(f64, f64)>,\n    max_samples: usize,\n    factor_before_resize: usize,\n}\n\nimpl RecentHistoryPlot {\n    pub(crate) fn new(max_samples: usize) -> Self {\n        Self {\n            axes: PlotAxes::default(),\n            points: BTreeMap::default(),\n            max_samples,\n        }\n    }\n\n    pub(crate) fn push(&mut self, tag: TuiTag, data: f64) {\n        if !self.points.contains_key(&tag) {\n            self.points\n                .insert(tag.clone(), RecentHistoryPoints::new(self.max_samples));\n        }\n\n        let (x_min, x_current) = self.point_x();\n\n        for (s, entry) in self.points.iter_mut() {\n            if s == &tag {\n                entry.push((x_current, data));\n            }\n            entry.update_cursor(x_min);\n        }\n\n        self.update_bounds();\n    }\n\n    pub(crate) fn datasets(&self) -> Vec<Dataset<'_>> {\n        let mut datasets = Vec::new();\n\n        for (tag, points) in self.points.iter() {\n            datasets.push(points.dataset(format!(\"{tag}\"), tag.split.color()));\n        }\n\n        datasets\n    }\n\n    fn point_x(&mut self) -> (f64, f64) {\n        let mut x_current = f64::MIN;\n        let mut x_min = f64::MAX;\n\n        for point in self.points.values() {\n            x_current = f64::max(x_current, point.max_x);\n            x_min = f64::min(x_min, point.min_x);\n        }\n\n        if x_current - x_min >= self.max_samples as f64 {\n            x_min += 1.0;\n        }\n\n        (x_min, x_current + 1.0)\n    }\n\n    fn update_bounds(&mut self) {\n        let (mut x_min, mut x_max) = (f64::MAX, f64::MIN);\n        let (mut y_min, mut y_max) = (f64::MAX, f64::MIN);\n\n        for points in self.points.values() {\n            x_min = f64::min(x_min, points.min_x);\n            x_max = f64::max(x_max, points.max_x);\n            y_min = f64::min(y_min, points.min_y);\n            y_max = f64::max(y_max, points.max_y);\n        }\n\n        self.axes.update_bounds((x_min, x_max), (y_min, y_max));\n    }\n}\n\nimpl RecentHistoryPoints {\n    fn new(max_samples: usize) -> Self {\n        let factor_before_resize = FACTOR_BEFORE_RESIZE;\n\n        Self {\n            min_x: 0.,\n            max_x: 0.,\n            min_y: f64::MAX,\n            max_y: f64::MIN,\n            points: Vec::with_capacity(factor_before_resize * max_samples),\n            cursor: 0,\n            max_samples,\n            factor_before_resize,\n        }\n    }\n\n    fn push(&mut self, (x, y): (f64, f64)) {\n        if x > self.max_x {\n            self.max_x = x;\n        }\n        if x < self.min_x {\n            self.min_x = x;\n        }\n        if y > self.max_y {\n            self.max_y = y;\n        }\n        if y < self.min_y {\n            self.min_y = y\n        }\n        self.points.push((x, y));\n    }\n\n    fn update_cursor(&mut self, min_x: f64) {\n        if self.min_x >= min_x {\n            return;\n        }\n        self.min_x = min_x;\n\n        let mut update_y_max = false;\n        let mut update_y_min = false;\n\n        while let Some((x, y)) = self.points.get(self.cursor) {\n            if *x >= self.min_x {\n                break;\n            }\n\n            if *y == self.max_y {\n                update_y_max = true\n            }\n            if *y == self.min_y {\n                update_y_min = true;\n            }\n\n            self.cursor += 1;\n        }\n\n        if update_y_max {\n            self.max_y = self.calculate_max_y();\n        }\n\n        if update_y_min {\n            self.min_y = self.calculate_min_y();\n        }\n\n        if self.points.len() >= self.max_samples * self.factor_before_resize {\n            self.resize();\n        }\n    }\n\n    fn slice(&self) -> &[(f64, f64)] {\n        &self.points[self.cursor..self.points.len()]\n    }\n\n    fn calculate_max_y(&self) -> f64 {\n        let mut max_y = f64::MIN;\n\n        for (_x, y) in self.slice() {\n            max_y = f64::max(max_y, *y);\n        }\n\n        max_y\n    }\n\n    fn calculate_min_y(&self) -> f64 {\n        let mut min_y = f64::MAX;\n\n        for (_x, y) in self.slice() {\n            if *y < min_y {\n                min_y = *y;\n            }\n        }\n\n        min_y\n    }\n\n    fn resize(&mut self) {\n        let mut points = Vec::with_capacity(self.max_samples * self.factor_before_resize);\n\n        for i in self.cursor..self.points.len() {\n            points.push(self.points[i]);\n        }\n\n        self.points = points;\n        self.cursor = 0;\n    }\n\n    fn dataset<'a>(&'a self, name: String, color: Color) -> Dataset<'a> {\n        let data = &self.points[self.cursor..self.points.len()];\n\n        Dataset::default()\n            .name(name)\n            .marker(symbols::Marker::Braille)\n            .style(Style::default().fg(color).bold())\n            .graph_type(GraphType::Scatter)\n            .data(data)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::renderer::tui::{TuiGroup, TuiSplit};\n\n    use super::*;\n\n    #[test]\n    fn test_push_update_bounds_max_y() {\n        let mut chart = RecentHistoryPlot::new(2);\n        let tag = TuiTag::new(TuiSplit::Train, TuiGroup::Default);\n\n        chart.push(tag.clone(), 15.0);\n        chart.push(tag.clone(), 10.0);\n        chart.push(tag.clone(), 14.0);\n\n        assert_eq!(chart.axes.bounds_y[1], 15.);\n        chart.push(tag, 10.0);\n        assert_eq!(chart.axes.bounds_y[1], 14.);\n    }\n\n    #[test]\n    fn test_push_update_bounds_min_y() {\n        let mut chart = RecentHistoryPlot::new(2);\n        let tag = TuiTag::new(TuiSplit::Train, TuiGroup::Default);\n\n        chart.push(tag.clone(), 5.0);\n        chart.push(tag.clone(), 10.0);\n        chart.push(tag.clone(), 14.0);\n\n        assert_eq!(chart.axes.bounds_y[0], 5.);\n        chart.push(tag, 10.0);\n        assert_eq!(chart.axes.bounds_y[0], 10.);\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/renderer.rs",
    "content": "use crate::metric::{MetricDefinition, MetricId};\nuse crate::renderer::tui::TuiSplit;\nuse crate::renderer::{\n    EvaluationName, EvaluationProgress, MetricState, MetricsRenderer, MetricsRendererEvaluation,\n    ProgressType, TrainingProgress,\n};\nuse crate::renderer::{MetricsRendererTraining, tui::NumericMetricsState};\nuse crate::{Interrupter, LearnerSummary};\nuse ratatui::{\n    Terminal,\n    crossterm::{\n        event::{self, Event, KeyCode},\n        execute,\n        terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},\n    },\n    prelude::*,\n};\nuse std::collections::HashMap;\nuse std::panic::{set_hook, take_hook};\nuse std::sync::mpsc::{Receiver, Sender};\nuse std::sync::{Arc, Mutex, mpsc};\nuse std::thread::JoinHandle;\nuse std::{\n    error::Error,\n    io::{self, Stdout},\n    time::{Duration, Instant},\n};\n\nuse super::{\n    Callback, CallbackFn, ControlsView, MetricsView, PopupState, ProgressBarState, StatusState,\n    TextMetricsState, TuiGroup, TuiTag,\n};\n\n/// The current terminal backend.\npub(crate) type TerminalBackend = CrosstermBackend<Stdout>;\n/// The current terminal frame.\npub(crate) type TerminalFrame<'a> = ratatui::Frame<'a>;\n\ntype PanicHook = Box<dyn Fn(&std::panic::PanicHookInfo<'_>) + 'static + Sync + Send>;\n\nconst MAX_REFRESH_RATE_MILLIS: u64 = 100;\n\nenum TuiRendererEvent {\n    MetricRegistration(MetricDefinition),\n    MetricsUpdate((TuiSplit, TuiGroup, MetricState)),\n    StatusUpdateTrain((TuiSplit, TrainingProgress, Vec<ProgressType>)),\n    StatusUpdateTest((EvaluationProgress, Vec<ProgressType>)),\n    ProcessEnd {\n        summary: Option<LearnerSummary>,\n        /// Interrupter reset.\n        reset: bool,\n    },\n    ManualClose,\n    Close,\n    Persistent,\n}\n\n/// The terminal UI metrics renderer.\npub struct TuiMetricsRendererWrapper {\n    sender: mpsc::Sender<TuiRendererEvent>,\n    interrupter: Interrupter,\n    handle_join: Option<JoinHandle<()>>,\n    kill_signal: Arc<Mutex<Receiver<()>>>,\n}\n\nimpl TuiMetricsRendererWrapper {\n    /// Create a new terminal UI renderer.\n    pub fn new(interrupter: Interrupter, checkpoint: Option<usize>) -> Self {\n        let (sender, receiver) = mpsc::channel();\n        let (kill_signal_sender, kill_signal_receiver) = mpsc::channel();\n\n        let interrupter_clone = interrupter.clone();\n        let handle_join = std::thread::Builder::new()\n            .name(\"train-renderer\".into())\n            .spawn(move || {\n                let mut renderer =\n                    TuiMetricsRenderer::new(interrupter_clone, checkpoint, kill_signal_sender);\n\n                let tick_rate = Duration::from_millis(MAX_REFRESH_RATE_MILLIS);\n                loop {\n                    match receiver.try_recv() {\n                        Ok(event) => renderer.handle_event(event),\n                        Err(mpsc::TryRecvError::Empty) => (),\n                        Err(mpsc::TryRecvError::Disconnected) => {\n                            log::error!(\"Renderer thread disconnected.\");\n                            break;\n                        }\n                    }\n\n                    // Render\n                    if renderer.last_update.elapsed() >= tick_rate\n                        && let Err(err) = renderer.render()\n                    {\n                        log::error!(\"Render error: {err}\");\n                        break;\n                    }\n\n                    if (renderer.manual_close && renderer.interrupter.should_stop())\n                        || renderer.close\n                    {\n                        break;\n                    }\n                }\n            })\n            .unwrap();\n\n        Self {\n            sender,\n            interrupter,\n            handle_join: Some(handle_join),\n            kill_signal: Arc::new(Mutex::new(kill_signal_receiver)),\n        }\n    }\n\n    fn send_event(&self, event: TuiRendererEvent) {\n        if self.kill_signal.lock().unwrap().try_recv().is_ok() {\n            panic!(\"Killing training from user input.\")\n        }\n        if let Err(e) = self.sender.send(event) {\n            log::warn!(\"Failed to send TUI event: {e}\");\n        }\n    }\n\n    /// Set the renderer to persistent mode.\n    pub fn persistent(self) -> Self {\n        self.send_event(TuiRendererEvent::Persistent);\n        self\n    }\n}\n\nstruct TuiMetricsRenderer {\n    terminal: Terminal<TerminalBackend>,\n    last_update: std::time::Instant,\n    progress: ProgressBarState,\n    metric_definitions: HashMap<MetricId, MetricDefinition>,\n    metrics_numeric: NumericMetricsState,\n    metrics_text: TextMetricsState,\n    status: StatusState,\n    interrupter: Interrupter,\n    popup: PopupState,\n    previous_panic_hook: Option<Arc<PanicHook>>,\n    persistent: bool,\n    manual_close: bool,\n    close: bool,\n    summary: Option<LearnerSummary>,\n    kill_signal: Sender<()>,\n}\n\nimpl MetricsRendererEvaluation for TuiMetricsRendererWrapper {\n    fn update_test(&mut self, name: EvaluationName, state: MetricState) {\n        self.send_event(TuiRendererEvent::MetricsUpdate((\n            TuiSplit::Test,\n            TuiGroup::Named(name.name),\n            state,\n        )));\n    }\n\n    fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>) {\n        self.send_event(TuiRendererEvent::StatusUpdateTest((\n            item,\n            progress_indicators,\n        )));\n    }\n\n    fn on_test_end(&mut self, summary: Option<LearnerSummary>) -> Result<(), Box<dyn Error>> {\n        // Update the summary\n        self.send_event(TuiRendererEvent::ProcessEnd {\n            summary,\n            reset: false,\n        });\n        Ok(())\n    }\n}\n\nimpl MetricsRenderer for TuiMetricsRendererWrapper {\n    fn manual_close(&mut self) {\n        self.send_event(TuiRendererEvent::ManualClose);\n        let _ = self.handle_join.take().unwrap().join();\n    }\n\n    fn register_metric(&mut self, definition: MetricDefinition) {\n        self.send_event(TuiRendererEvent::MetricRegistration(definition));\n    }\n}\n\nimpl MetricsRendererTraining for TuiMetricsRendererWrapper {\n    fn update_train(&mut self, state: MetricState) {\n        self.send_event(TuiRendererEvent::MetricsUpdate((\n            TuiSplit::Train,\n            TuiGroup::Default,\n            state,\n        )));\n    }\n\n    fn update_valid(&mut self, state: MetricState) {\n        self.send_event(TuiRendererEvent::MetricsUpdate((\n            TuiSplit::Valid,\n            TuiGroup::Default,\n            state,\n        )));\n    }\n\n    fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {\n        self.send_event(TuiRendererEvent::StatusUpdateTrain((\n            TuiSplit::Train,\n            item,\n            progress_indicators,\n        )));\n    }\n\n    fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>) {\n        self.send_event(TuiRendererEvent::StatusUpdateTrain((\n            TuiSplit::Valid,\n            item,\n            progress_indicators,\n        )));\n    }\n\n    fn on_train_end(&mut self, summary: Option<LearnerSummary>) -> Result<(), Box<dyn Error>> {\n        // Reset for following steps.\n        self.interrupter.reset();\n        // Update the summary\n        self.send_event(TuiRendererEvent::ProcessEnd {\n            summary,\n            reset: true,\n        });\n        Ok(())\n    }\n}\n\nimpl Drop for TuiMetricsRendererWrapper {\n    fn drop(&mut self) {\n        if !std::thread::panicking() {\n            self.send_event(TuiRendererEvent::Close);\n            let _ = self.handle_join.take().unwrap().join();\n        }\n    }\n}\n\nimpl TuiMetricsRenderer {\n    fn update_metric(&mut self, split: TuiSplit, group: TuiGroup, state: MetricState) {\n        match state {\n            MetricState::Generic(entry) => {\n                let name = self\n                    .metric_definitions\n                    .get(&entry.metric_id)\n                    .unwrap()\n                    .name\n                    .clone()\n                    .into();\n                self.metrics_text.update(split, group, entry, name);\n            }\n            MetricState::Numeric(entry, value) => {\n                let name: Arc<String> = self\n                    .metric_definitions\n                    .get(&entry.metric_id)\n                    .unwrap()\n                    .name\n                    .clone()\n                    .into();\n                self.metrics_numeric\n                    .push(TuiTag::new(split, group.clone()), name.clone(), value);\n                self.metrics_text.update(split, group, entry, name);\n            }\n        };\n    }\n\n    pub fn new(\n        interrupter: Interrupter,\n        checkpoint: Option<usize>,\n        kill_signal: Sender<()>,\n    ) -> Self {\n        let mut stdout = io::stdout();\n        execute!(stdout, EnterAlternateScreen).unwrap();\n        enable_raw_mode().unwrap();\n        let terminal = Terminal::new(CrosstermBackend::new(stdout)).unwrap();\n\n        // Reset the terminal to raw mode on panic before running the panic handler\n        // This prevents that the panic message is not visible for the user.\n        let previous_panic_hook = Arc::new(take_hook());\n        set_hook(Box::new({\n            let previous_panic_hook = previous_panic_hook.clone();\n            move |panic_info| {\n                let _ = disable_raw_mode();\n                let _ = execute!(io::stdout(), LeaveAlternateScreen);\n                previous_panic_hook(panic_info);\n            }\n        }));\n\n        Self {\n            terminal,\n            last_update: Instant::now(),\n            progress: ProgressBarState::new(checkpoint),\n            metric_definitions: HashMap::default(),\n            metrics_numeric: NumericMetricsState::default(),\n            metrics_text: TextMetricsState::default(),\n            status: StatusState::default(),\n            interrupter,\n            popup: PopupState::Empty,\n            previous_panic_hook: Some(previous_panic_hook),\n            persistent: false,\n            manual_close: false,\n            close: false,\n            summary: None,\n            kill_signal,\n        }\n    }\n\n    fn handle_event(&mut self, event: TuiRendererEvent) {\n        match event {\n            TuiRendererEvent::MetricRegistration(definition) => {\n                self.metric_definitions\n                    .insert(definition.metric_id.clone(), definition);\n            }\n            TuiRendererEvent::MetricsUpdate((split, group, state)) => {\n                self.update_metric(split, group, state);\n            }\n            TuiRendererEvent::StatusUpdateTrain((split, item, status)) => match split {\n                TuiSplit::Train => {\n                    self.progress.update_train(&item);\n                    self.metrics_numeric.update_progress_train(&item);\n                    self.status.update_train(status);\n                }\n                TuiSplit::Valid => {\n                    self.progress.update_valid(&item);\n                    self.metrics_numeric.update_progress_valid(&item);\n                    self.status.update_valid(status);\n                }\n                _ => (),\n            },\n            TuiRendererEvent::StatusUpdateTest((item, status)) => {\n                self.progress.update_test(&item);\n                self.metrics_numeric.update_progress_test(&item);\n                self.status.update_test(status);\n            }\n            TuiRendererEvent::ProcessEnd { summary, reset } => {\n                match (self.summary.take(), summary) {\n                    (None, Some(summary)) => {\n                        self.summary = Some(summary);\n                    }\n                    (Some(current), Some(other)) => self.summary = Some(current.merge(other)),\n                    (_, _) => { /* nothing to update */ }\n                }\n\n                if reset {\n                    self.interrupter.reset();\n                }\n            }\n            TuiRendererEvent::ManualClose => self.manual_close = true,\n            TuiRendererEvent::Persistent => self.persistent = true,\n            TuiRendererEvent::Close => self.close = true,\n        }\n    }\n\n    fn render(&mut self) -> Result<(), Box<dyn Error>> {\n        self.draw()?;\n        self.handle_user_input()?;\n\n        self.last_update = Instant::now();\n\n        Ok(())\n    }\n\n    fn draw(&mut self) -> Result<(), Box<dyn Error>> {\n        self.terminal.draw(|frame| {\n            let size = frame.area();\n\n            match self.popup.view() {\n                Some(view) => view.render(frame, size),\n                None => {\n                    let view = MetricsView::new(\n                        self.metrics_numeric.view(),\n                        self.metrics_text.view(),\n                        self.progress.view(),\n                        ControlsView,\n                        self.status.view(),\n                    );\n\n                    view.render(frame, size);\n                }\n            };\n        })?;\n\n        Ok(())\n    }\n\n    fn handle_user_input(&mut self) -> Result<(), Box<dyn Error>> {\n        while event::poll(Duration::from_secs(0))? {\n            let event = event::read()?;\n            self.popup.on_event(&event);\n\n            if self.popup.is_empty() {\n                self.metrics_numeric.on_event(&event);\n\n                if let Event::Key(key) = event\n                    && let KeyCode::Char('q') = key.code\n                {\n                    self.popup = PopupState::Full(\n                        \"Quit\".to_string(),\n                        vec![\n                            Callback::new(\n                                \"Stop the training.\",\n                                \"Stop the training immediately. This will break from the \\\n                                     training loop, but any remaining code after the loop will be \\\n                                     executed.\",\n                                's',\n                                QuitPopupAccept(self.interrupter.clone()),\n                            ),\n                            Callback::new(\n                                \"Stop the training immediately.\",\n                                \"Kill the program. This will create a panic! which will make \\\n                                     the current training fails. Any code following the training \\\n                                     won't be executed.\",\n                                'k',\n                                KillPopupAccept(self.kill_signal.clone()),\n                            ),\n                            Callback::new(\n                                \"Cancel\",\n                                \"Cancel the action, continue the training.\",\n                                'c',\n                                PopupCancel,\n                            ),\n                        ],\n                    );\n                }\n            }\n        }\n\n        Ok(())\n    }\n\n    fn handle_post_training(&mut self) -> Result<(), Box<dyn Error>> {\n        self.popup = PopupState::Full(\n            \"Training is done\".to_string(),\n            vec![Callback::new(\n                \"Training Done\",\n                \"Press 'x' to close this popup.  Press 'q' to exit the application after the \\\n                popup is closed.\",\n                'x',\n                PopupCancel,\n            )],\n        );\n\n        self.draw().ok();\n\n        loop {\n            if let Ok(true) = event::poll(Duration::from_millis(MAX_REFRESH_RATE_MILLIS)) {\n                match event::read() {\n                    Ok(event @ Event::Key(key)) => {\n                        if self.popup.is_empty() {\n                            self.metrics_numeric.on_event(&event);\n                            if let KeyCode::Char('q') = key.code {\n                                break;\n                            }\n                        } else {\n                            self.popup.on_event(&event);\n                        }\n                        self.draw().ok();\n                    }\n\n                    Ok(Event::Resize(..)) => {\n                        self.draw().ok();\n                    }\n                    Err(err) => {\n                        eprintln!(\"Error reading event: {err}\");\n                        break;\n                    }\n                    _ => continue,\n                }\n            }\n        }\n        Ok(())\n    }\n\n    // Reset the terminal back to raw mode.\n    fn reset(&mut self) -> Result<(), Box<dyn Error>> {\n        // If previous panic hook has already been re-instated, then the terminal was already reset.\n        if self.previous_panic_hook.is_some() {\n            if self.persistent\n                && let Err(err) = self.handle_post_training()\n            {\n                eprintln!(\"Error in post-training handling: {err}\");\n            }\n\n            disable_raw_mode()?;\n            execute!(self.terminal.backend_mut(), LeaveAlternateScreen)?;\n            self.terminal.show_cursor()?;\n\n            // Reinstall the previous panic hook\n            let _ = take_hook();\n            if let Some(previous_panic_hook) =\n                Arc::into_inner(self.previous_panic_hook.take().unwrap())\n            {\n                set_hook(previous_panic_hook);\n            }\n        }\n        Ok(())\n    }\n}\n\nstruct QuitPopupAccept(Interrupter);\nstruct KillPopupAccept(Sender<()>);\nstruct PopupCancel;\n\nimpl CallbackFn for KillPopupAccept {\n    fn call(&self) -> bool {\n        self.0.send(()).unwrap();\n        panic!(\"Killing training from user input.\");\n    }\n}\n\nimpl CallbackFn for QuitPopupAccept {\n    fn call(&self) -> bool {\n        self.0.stop(Some(\"Stopping training from user input.\"));\n        true\n    }\n}\n\nimpl CallbackFn for PopupCancel {\n    fn call(&self) -> bool {\n        true\n    }\n}\n\nimpl Drop for TuiMetricsRenderer {\n    fn drop(&mut self) {\n        // Reset the terminal back to raw mode. This can be skipped during\n        // panicking because the panic hook has already reset the terminal\n        if !std::thread::panicking() {\n            self.reset().unwrap();\n\n            if let Some(summary) = &self.summary {\n                println!(\"{summary}\");\n                log::info!(\"{summary}\");\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-train/src/renderer/tui/status.rs",
    "content": "use crate::renderer::ProgressType;\n\nuse super::TerminalFrame;\nuse ratatui::{\n    prelude::{Alignment, Rect},\n    style::{Color, Style, Stylize},\n    text::{Line, Span},\n    widgets::{Block, Borders, Paragraph, Wrap},\n};\n\n/// Show the training status with various information.\npub(crate) struct StatusState {\n    progress_indicators: Vec<ProgressType>,\n    mode: Mode,\n}\n\nenum Mode {\n    Valid,\n    Train,\n    Evaluation,\n}\n\nimpl Default for StatusState {\n    fn default() -> Self {\n        Self {\n            progress_indicators: vec![],\n            mode: Mode::Train,\n        }\n    }\n}\n\nimpl StatusState {\n    /// Update the training information.\n    pub(crate) fn update_train(&mut self, progress_indicators: Vec<ProgressType>) {\n        self.progress_indicators = progress_indicators;\n        self.mode = Mode::Train;\n    }\n    /// Update the validation information.\n    pub(crate) fn update_valid(&mut self, progress_indicators: Vec<ProgressType>) {\n        self.progress_indicators = progress_indicators;\n        self.mode = Mode::Valid;\n    }\n    /// Update the testing information.\n    pub(crate) fn update_test(&mut self, progress_indicators: Vec<ProgressType>) {\n        self.progress_indicators = progress_indicators;\n        self.mode = Mode::Evaluation;\n    }\n    /// Create a view.\n    pub(crate) fn view(&self) -> StatusView {\n        StatusView::new(&self.progress_indicators, &self.mode)\n    }\n}\n\npub(crate) struct StatusView {\n    lines: Vec<Vec<Span<'static>>>,\n}\n\nimpl StatusView {\n    fn new(progress_indicators: &[ProgressType], mode: &Mode) -> Self {\n        let title = |title: &str| Span::from(format!(\" {title} \")).bold().yellow();\n        let value = |value: String| Span::from(value).italic();\n        let mode = match mode {\n            Mode::Valid => \"Validating\",\n            Mode::Train => \"Training\",\n            Mode::Evaluation => \"Evaluation\",\n        };\n\n        let width = progress_indicators\n            .iter()\n            .map(|p| match p {\n                ProgressType::Detailed { tag, .. } => tag.len(),\n                ProgressType::Value { tag, .. } => tag.len(),\n            })\n            .max()\n            .unwrap_or(4);\n\n        let mut lines = vec![vec![\n            title(&format!(\"{: <width$} :\", \"Mode\")),\n            value(mode.to_string()),\n        ]];\n\n        progress_indicators.iter().for_each(|p| match p {\n            ProgressType::Detailed { tag, progress } => lines.push(vec![\n                title(&format!(\"{: <width$} :\", tag)),\n                value(format!(\n                    \"{}/{}\",\n                    progress.items_processed, progress.items_total\n                )),\n            ]),\n            ProgressType::Value {\n                tag,\n                value: num_items,\n            } => lines.push(vec![\n                title(&format!(\"{: <width$} :\", tag)),\n                value(format!(\"{}\", num_items)),\n            ]),\n        });\n\n        Self { lines }\n    }\n\n    pub(crate) fn render(self, frame: &mut TerminalFrame<'_>, size: Rect) {\n        let paragraph = Paragraph::new(self.lines.into_iter().map(Line::from).collect::<Vec<_>>())\n            .alignment(Alignment::Left)\n            .block(Block::default().borders(Borders::ALL).title(\"Status\"))\n            .wrap(Wrap { trim: false })\n            .style(Style::default().fg(Color::Gray));\n\n        frame.render_widget(paragraph, size);\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/Cargo.toml",
    "content": "[package]\nauthors = [\n    \"nathanielsimard <nathaniel.simard.42@gmail.com>\",\n    \"wingertge <wingertge@gmail.com>\",\n]\ncategories = [\"science\"]\ndescription = \"Vision processing operations for burn tensors\"\ndocumentation = \"https://docs.rs/burn-vision\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\"]\nlicense.workspace = true\nname = \"burn-vision\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-vision\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n\n[features]\ndefault = [\"ndarray\", \"cubecl-backend\", \"fusion\", \"std\"]\nstd = [\"aligned-vec/std\"]\ntracing = [\n    \"burn-cubecl?/tracing\",\n    \"burn-fusion?/tracing\",\n    \"burn-ir/tracing\",\n    \"burn-ndarray?/tracing\",\n    \"burn-tch?/tracing\",\n    \"burn-tensor/tracing\",\n    \"cubecl/tracing\",\n]\n\ncubecl-backend = [\"cubecl\", \"burn-cubecl\"]\nfusion = [\"burn-fusion\", \"burn-cuda/fusion\", \"burn-wgpu/fusion\"]\nndarray = [\"burn-ndarray\"]\ntch = [\"burn-tch\"]\n\n# Test features\ntest-cpu = []\ntest-cuda = [\"cubecl-backend\", ]\ntest-wgpu = [\"cubecl-backend\", ]\ntest-vulkan = [\"burn-wgpu/vulkan\", \"test-wgpu\"]\ntest-metal = [\"burn-wgpu/metal\", \"test-wgpu\"]\n\n[dependencies]\naligned-vec = { version = \"0.6\", default-features = false }\nbon = { workspace = true }\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", optional = true }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\nburn-ir = { path = \"../burn-ir\", version = \"=0.21.0-pre.2\" }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\", optional = true }\nburn-tch = { path = \"../burn-tch\", version = \"=0.21.0-pre.2\", optional = true }\nburn-tensor = { path = \"../burn-tensor\", version = \"=0.21.0-pre.2\" }\nburn-tensor-testgen = { path = \"../burn-tensor-testgen\", version = \"=0.21.0-pre.2\", optional = true }\nbytemuck = { workspace = true }\ncubecl = { workspace = true, optional = true }\nderive-new = { workspace = true }\nhalf = { workspace = true }\nimage = { version = \"0.25\" }\nmacerator = { workspace = true }\nndarray = { workspace = true }\nnum-traits = { workspace = true }\npaste = { workspace = true }\nserde = { workspace = true }\n\n[dev-dependencies]\nburn-cuda = { path = \"../burn-cuda\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-ndarray = { path = \"../burn-ndarray\", version = \"=0.21.0-pre.2\" }\nburn-wgpu = { path = \"../burn-wgpu\", version = \"=0.21.0-pre.2\", default-features = false }\ncubecl = { workspace = true }\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/base.rs",
    "content": "pub trait MinMax {\n    fn min(self, other: Self) -> Self;\n    fn max(self, other: Self) -> Self;\n}\n\nmacro_rules! impl_minmax {\n    ($ty: ty) => {\n        impl MinMax for $ty {\n            fn min(self, other: Self) -> Self {\n                Ord::min(self, other)\n            }\n            fn max(self, other: Self) -> Self {\n                Ord::max(self, other)\n            }\n        }\n    };\n    ($($ty: ty),*) => {\n        $(impl_minmax!($ty);)*\n    }\n}\n\nimpl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64);\n\nimpl MinMax for f32 {\n    fn min(self, other: Self) -> Self {\n        self.min(other)\n    }\n\n    fn max(self, other: Self) -> Self {\n        self.max(other)\n    }\n}\n\nimpl MinMax for f64 {\n    fn min(self, other: Self) -> Self {\n        self.min(other)\n    }\n\n    fn max(self, other: Self) -> Self {\n        self.max(other)\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_center_line_forest_code.rs",
    "content": "no_analyze!{{\nuse centerLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<centerLabels> { match label {\n\t\tNODE_1=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_2);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_3);\n\t\t}\n\t\t\t\t}\n\t\tNODE_3=> {\n\t\tif (*img_row01.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(cl_tree_2);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\treturn Some(cl_tree_1);\n\t\t}\n\t\t\t\t}\n\t\tNODE_4=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\treturn Some(NODE_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(cl_tree_4);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\treturn Some(cl_tree_3);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_6=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_2);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\treturn Some(cl_tree_7);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_1);\n\t\t}\n\t\t\t\t}\n\t\tNODE_2=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_4);\n\t\t}\n\t\t\t\t}\n\t\tNODE_7=> {\n\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\n\t\tNODE_5=> {\n\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\n\t\tNODE_8=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_9);\n\t\t}\n\t\t\t\t}\n\t\tNODE_10=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\treturn Some(cl_tree_11);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_8);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_11);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_12);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\treturn Some(NODE_12);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_11=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_4);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_3);\n\t\t}\n\t\t\t\t}\n\t\tNODE_13=> {\n\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\t\t\t}\n\t\tNODE_9=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_11);\n\t\t}\n\t\t\t\t}\n\t\tNODE_12=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_10);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_9);\n\t\t}\n\t\t\t\t}\n\t\tNODE_14=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_4);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\treturn Some(cl_tree_10);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_15=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(cl_tree_11);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_13);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_5);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\treturn Some(cl_tree_4);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_3);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\treturn Some(cl_tree_3);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\treturn Some(cl_tree_10);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_16=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_8);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\treturn Some(NODE_2);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\treturn Some(NODE_17);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_1);\n\t\t}\n\t\t\t\t}\n\t\tNODE_18=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\n\t\tNODE_19=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\treturn Some(NODE_20);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_8);\n\t\t}\n\t\t\t\t}\n\t\tNODE_21=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(cl_tree_6);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\treturn Some(cl_tree_3);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_3);\n\t\t}\n\t\t\t\t}\n\t\tNODE_22=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(cl_tree_6);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\treturn Some(cl_tree_6);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_23=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(cl_tree_11);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\treturn Some(cl_tree_11);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_24=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\t\t\t}\n\t\tNODE_17=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_7);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(cl_tree_7);\n\t\t}\n\t\t\t\t}\n\t\tNODE_25=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_18);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_8);\n\t\t}\n\t\t\t\t}\n\t\tNODE_20=> {\n\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_26);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\n\t\tNODE_27=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\t\t\t}\n\t\tNODE_28=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_22);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_19);\n\t\t}\n\t\t\t\t}\n\t\tNODE_26=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_29=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_8);\n\t\t}\n\t\t\t\t}\n\t\tNODE_30=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_8);\n\t\t}\n\t\t\t\t}\n\t\tNODE_31=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_23);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_19);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\treturn Some(cl_tree_12);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_32=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_33);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_34);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_35);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_4);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\treturn Some(cl_tree_3);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_36=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_33);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_34);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\treturn Some(NODE_35);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\treturn Some(cl_tree_3);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_37=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_18);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\treturn Some(cl_tree_5);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\treturn Some(cl_tree_5);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(cl_tree_8);\n\t\t}\n\t\t\t\t}\n\t\tNODE_33=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\treturn Some(NODE_26);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\n\t\tNODE_38=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\treturn Some(NODE_22);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_6);\n\t\t}\n\t\t\t\t}\n\t\tNODE_39=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_10);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_10);\n\t\t}\n\t\t\t\t}\n\t\tNODE_35=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(cl_tree_4);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_4);\n\t\t}\n\t\t\t\t}\n\t\tNODE_40=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\treturn Some(NODE_23);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_11);\n\t\t}\n\t\t\t\t}\n\t\tNODE_34=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(cl_tree_5);\n\t\t}\n\t\t\t\t}\ncl_tree_0 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_0); } else { return Some(cl_break_1_0); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_14);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_6);\n\t\t\t\t}\n}\ncl_tree_1 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_1); } else { return Some(cl_break_1_1); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_15);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_6);\n\t\t\t\t}\n}\ncl_tree_2 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_2); } else { return Some(cl_break_1_2); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_10);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_8);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_3 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_3); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_29);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_12);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_29);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_21);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_4 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_4); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_27);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_25);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_12);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_24);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_25);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_21);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_5 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_5); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_30);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_12);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_30);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_5);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_4);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_3);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_6 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_6); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_31);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_28);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_7 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_4); } else { return Some(cl_break_1_7); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_10);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_15);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_16);\n\t\t\t\t}\n}\ncl_tree_8 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_3); } else { return Some(cl_break_1_8); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_27);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_37);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_12);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_24);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_37);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_21);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_9 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_5); } else { return Some(cl_break_1_9); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_9);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_12);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_14);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_16);\n\t\t\t\t}\n}\ncl_tree_10 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_6); } else { return Some(cl_break_1_10); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_40);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_36);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_39);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_14);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_38);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_36);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_2);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_17);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_11 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_7); } else { return Some(cl_break_1_11); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_31);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_31);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_13);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t\treturn Some(NODE_5);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t\treturn Some(NODE_7);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_4);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_3);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_10);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_28);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_22);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t\treturn Some(NODE_20);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_4);\n\t\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_3);\n\t\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_2);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_7);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\ncl_tree_12 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(cl_break_0_8); } else { return Some(cl_break_1_12); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_40);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(cl_tree_11);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_32);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_39);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_10);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(cl_tree_9);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_14);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_38);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\t\t\treturn Some(cl_tree_6);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_32);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_2);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_17);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\n\t\tNODE_41=> {\n\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_42);\n\t\t}\n\t\t\t\t}\n\t\tNODE_43=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\n\t\tNODE_42=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_44=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_45=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\n\t\tNODE_46=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_47=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\t\t\t}\n\t\tNODE_48=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\ncl_break_0_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_47);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_48);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_44);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_48);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_41);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_43);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_3 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_43);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_4 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_41);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_44);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_45);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_5 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_42);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_47);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_45);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_6 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_46);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_47);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_45);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_7 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_0_8 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_46);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_47);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_45);\n\t\t\t\t}\n\t\treturn None;}\n\t\tNODE_49=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_50=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\treturn Some(NODE_49);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_51=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_52);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_52=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_53=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_54);\n\t\t\t}\n\t\t\telse {\n\t\t\treturn Some(NODE_55);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_56);\n\t\t}\n\t\t\t\t}\n\t\tNODE_55=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\t\t\t}\n\t\tNODE_54=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\treturn Some(NODE_57);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_58);\n\t\t}\n\t\t\t\t}\n\t\tNODE_59=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_60);\n\t\t}\n\t\t\t\t}\n\t\tNODE_61=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_62=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_63);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_49);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_58);\n\t\t}\n\t\t\t\t}\n\t\tNODE_63=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\treturn Some(NODE_52);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_64=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\treturn Some(NODE_65);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_65=> {\n\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\treturn Some(NODE_49);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_66=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_63);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_65);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_58);\n\t\t}\n\t\t\t\t}\n\t\tNODE_67=> {\n\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\treturn Some(NODE_58);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_56);\n\t\t}\n\t\t\t\t}\n\t\tNODE_56=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_58);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_60);\n\t\t}\n\t\t\t\t}\n\t\tNODE_58=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_68=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t}\n\t\t\telse {\n\t\t\treturn Some(NODE_69);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_70=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\treturn Some(NODE_71);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_57=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\telse {\n\t\treturn Some(NODE_69);\n\t\t}\n\t\t\t\t}\n\t\tNODE_60=> {\n\t\tif (*img_row01.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\n\t\tNODE_71=> {\n\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_69=> {\n\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\ncl_break_1_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_58);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_67);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_70);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_67);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_68);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_57);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_56);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_3 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_61);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_61);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_59);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_4 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_50);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_50);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_59);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_5 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_60);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_6 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_56);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_7 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_68);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_70);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_53);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_8 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_64);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_64);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_59);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_9 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_54);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_53);\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_10 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_62);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_62);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_55);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_56);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_11 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_71);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_51);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_58);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_56);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\ncl_break_1_12 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_66);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_66);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_55);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_56);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_first_line_forest_code.rs",
    "content": "no_analyze!{{\nuse firstLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<firstLabels> { match label {\n\t\tNODE_72=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(fl_tree_1);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(fl_tree_2);\n\t\t}\n\t\t\t\t}\n\t\tNODE_73=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(fl_tree_1);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\treturn Some(fl_tree_2);\n\t\t}\n\t\t\t\t}\n\t\tNODE_74=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(fl_tree_1);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row01.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\treturn Some(fl_tree_1);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\treturn Some(fl_tree_0);\n\t\t\t}\n\t\t}\n\t\t\t\t}\nfl_tree_0 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_0); } else { return Some(fl_break_1_0); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_72);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_72);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_74);\n\t\t\t\t\t}\n\t\t\t\t}\n}\nfl_tree_1 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_1); } else { return Some(fl_break_1_1); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_73);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_73);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_74);\n\t\t\t\t\t}\n\t\t\t\t}\n}\nfl_tree_2 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(fl_break_0_2); } else { return Some(fl_break_1_2); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_73);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_72);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(fl_tree_1);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\treturn Some(fl_tree_1);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(fl_tree_2);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\treturn Some(fl_tree_2);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_74);\n\t\t\t\t\t}\n\t\t\t\t}\n}\n\t\tNODE_75=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\t\t\t}\nfl_break_0_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nfl_break_0_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nfl_break_0_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_75);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_75);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\n\t\tNODE_76=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\telse {\n\t\t\tif (*img_row01.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_77=> {\n\t\tif (*img_row01.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\t\t\t}\nfl_break_1_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_76);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nfl_break_1_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_76);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nfl_break_1_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_77);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row01.add((c) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_77);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\treturn Some(NODE_77);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\treturn Some(NODE_76);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nfl_ => {},\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_forest_labels.rs",
    "content": "/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.\nmacro_rules! no_analyze {\n    ($tokens:tt) => {\n        $tokens\n    };\n}\n\npub(crate) use no_analyze;\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum centerLabels {\n    NODE_1,\n    NODE_2,\n    NODE_3,\n    NODE_4,\n    NODE_5,\n    NODE_6,\n    NODE_7,\n    NODE_8,\n    NODE_9,\n    NODE_10,\n    NODE_11,\n    NODE_12,\n    NODE_13,\n    NODE_14,\n    NODE_15,\n    NODE_16,\n    NODE_17,\n    NODE_18,\n    NODE_19,\n    NODE_20,\n    NODE_21,\n    NODE_22,\n    NODE_23,\n    NODE_24,\n    NODE_25,\n    NODE_26,\n    NODE_27,\n    NODE_28,\n    NODE_29,\n    NODE_30,\n    NODE_31,\n    NODE_32,\n    NODE_33,\n    NODE_34,\n    NODE_35,\n    NODE_36,\n    NODE_37,\n    NODE_38,\n    NODE_39,\n    NODE_40,\n    NODE_41,\n    NODE_42,\n    NODE_43,\n    NODE_44,\n    NODE_45,\n    NODE_46,\n    NODE_47,\n    NODE_48,\n    NODE_49,\n    NODE_50,\n    NODE_51,\n    NODE_52,\n    NODE_53,\n    NODE_54,\n    NODE_55,\n    NODE_56,\n    NODE_57,\n    NODE_58,\n    NODE_59,\n    NODE_60,\n    NODE_61,\n    NODE_62,\n    NODE_63,\n    NODE_64,\n    NODE_65,\n    NODE_66,\n    NODE_67,\n    NODE_68,\n    NODE_69,\n    NODE_70,\n    NODE_71,\n    cl_tree_0,\n    cl_tree_1,\n    cl_tree_2,\n    cl_tree_3,\n    cl_tree_4,\n    cl_tree_5,\n    cl_tree_6,\n    cl_tree_7,\n    cl_tree_8,\n    cl_tree_9,\n    cl_tree_10,\n    cl_tree_11,\n    cl_tree_12,\n    cl_break_0_0,\n    cl_break_0_1,\n    cl_break_0_2,\n    cl_break_0_3,\n    cl_break_0_4,\n    cl_break_0_5,\n    cl_break_0_6,\n    cl_break_0_7,\n    cl_break_0_8,\n    cl_break_1_0,\n    cl_break_1_1,\n    cl_break_1_2,\n    cl_break_1_3,\n    cl_break_1_4,\n    cl_break_1_5,\n    cl_break_1_6,\n    cl_break_1_7,\n    cl_break_1_8,\n    cl_break_1_9,\n    cl_break_1_10,\n    cl_break_1_11,\n    cl_break_1_12,\n}\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum firstLabels {\n    NODE_72,\n    NODE_73,\n    NODE_74,\n    NODE_75,\n    NODE_76,\n    NODE_77,\n    fl_tree_0,\n    fl_tree_1,\n    fl_tree_2,\n    fl_break_0_0,\n    fl_break_0_1,\n    fl_break_0_2,\n    fl_break_1_0,\n    fl_break_1_1,\n    fl_break_1_2,\n    fl_,\n}\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum lastLabels {\n    NODE_78,\n    NODE_79,\n    NODE_80,\n    NODE_81,\n    NODE_82,\n    NODE_83,\n    NODE_84,\n    NODE_85,\n    NODE_86,\n    NODE_87,\n    NODE_88,\n    NODE_89,\n    NODE_90,\n    NODE_91,\n    NODE_92,\n    ll_tree_0,\n    ll_tree_1,\n    ll_tree_2,\n    ll_tree_3,\n    ll_tree_4,\n    ll_tree_5,\n    ll_tree_6,\n    ll_tree_7,\n    ll_break_0_0,\n    ll_break_0_1,\n    ll_break_0_2,\n    ll_break_0_3,\n    ll_break_1_0,\n    ll_break_1_1,\n    ll_break_1_2,\n    ll_break_1_3,\n    ll_break_1_4,\n    ll_break_1_5,\n    ll_break_1_6,\n    ll_break_1_7,\n    ll_,\n}\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum singleLabels {\n    NODE_93,\n    NODE_94,\n    sl_tree_0,\n    sl_tree_1,\n    sl_break_0_0,\n    sl_break_0_1,\n    sl_break_1_0,\n    sl_break_1_1,\n    sl_,\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_last_line_forest_code.rs",
    "content": "no_analyze!{{\nuse lastLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<lastLabels> { match label {\n\t\tNODE_78=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\treturn Some(ll_tree_4);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(ll_tree_4);\n\t\t}\n\t\t\t\t}\n\t\tNODE_79=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(ll_tree_6);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(ll_tree_6);\n\t\t}\n\t\t\t\t}\n\t\tNODE_80=> {\n\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\treturn Some(ll_tree_6);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(ll_tree_6);\n\t\t}\n\t\t\t\t}\n\t\tNODE_81=> {\n\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\treturn Some(NODE_82);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(ll_tree_4);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(ll_tree_3);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\treturn Some(ll_tree_2);\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_83=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(ll_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\treturn Some(ll_tree_2);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\treturn Some(ll_tree_1);\n\t\t}\n\t\t\t\t}\n\t\tNODE_84=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(ll_tree_5);\n\t\t\t}\n\t\t\telse {\n\t\t\treturn Some(NODE_81);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\treturn Some(ll_tree_1);\n\t\t}\n\t\t\t\t}\n\t\tNODE_82=> {\n\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\treturn Some(ll_tree_4);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\treturn Some(ll_tree_4);\n\t\t}\n\t\t\t\t}\n\t\tNODE_85=> {\n\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\treturn Some(ll_tree_4);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\t\treturn Some(ll_tree_4);\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c + 2) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t\treturn Some(ll_tree_4);\n\t\t}\n\t\t\t\t}\n\t\tNODE_86=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\treturn Some(ll_tree_6);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\treturn Some(ll_tree_7);\n\t\t\t\t}\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\treturn Some(ll_tree_0);\n\t\t\t}\n\t\t}\n\t\t\t\t}\nll_tree_0 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_0); } else { return Some(ll_break_1_0); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_81);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_84);\n\t\t\t\t}\n}\nll_tree_1 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_1); } else { return Some(ll_break_1_1); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_80);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\treturn Some(NODE_82);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_85);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_3);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_2);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_2);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_84);\n\t\t\t\t}\n}\nll_tree_2 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_2); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_7);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_83);\n\t\t\t\t}\n}\nll_tree_3 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_3); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_79);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_78);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_7);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_83);\n\t\t\t\t}\n}\nll_tree_4 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_4); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c + 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_7);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_5);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_82);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_3);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t\treturn Some(ll_tree_1);\n\t\t\t\t\t}\n\t\t\t\t}\n}\nll_tree_5 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_5); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_86);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_84);\n\t\t\t\t}\n}\nll_tree_6 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_3); } else { return Some(ll_break_1_6); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_86);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_80);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_82);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\treturn Some(NODE_85);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_3);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_2);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_84);\n\t\t\t\t}\n}\nll_tree_7 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(ll_break_0_2); } else { return Some(ll_break_1_7); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_79);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\treturn Some(ll_tree_6);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c + 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\tif (*img_row12.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t\t\t\treturn Some(NODE_78);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c + 2) as usize), solver);\n\t\t\t\t\t\t\t\t\treturn Some(ll_tree_4);\n\t\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\t\treturn Some(ll_tree_7);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\t\treturn Some(ll_tree_0);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_83);\n\t\t\t\t}\n}\nll_break_0_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\nll_break_0_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\nll_break_0_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\nll_break_0_3 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\n\t\tNODE_87=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\treturn Some(NODE_88);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\n\t\tNODE_88=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\t\t\t}\n\t\tNODE_89=> {\n\t\tif (*img_row12.add((c - 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t}\n\t\t\t\t}\n\t\tNODE_90=> {\n\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t}\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t}\n\t\t\t\t}\n\t\tNODE_91=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t}\n\t\t\telse {\n\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t}\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\n\t\tNODE_92=> {\n\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row12.add((c) as usize), *img_labels_row12.add((c - 2) as usize), solver);\n\t\t}\n\t\t\t\t}\nll_break_1_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_88);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_87);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_92);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_87);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_2 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_91);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_3 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\treturn Some(NODE_89);\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_91);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_4 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\t}\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_5 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\treturn Some(NODE_90);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_87);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_6 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c - 1) as usize)).to_bool() {\n\t\t\t\t\t\treturn Some(NODE_90);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\treturn Some(NODE_92);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row12.add((c - 2) as usize);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_87);\n\t\t\t\t}\n\t\treturn None;}\nll_break_1_7 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\tif (*img_row12.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t\tif (*img_row11.add((c - 2) as usize)).to_bool() {\n\t\t\t\t\t\t\t\treturn Some(NODE_89);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\telse {\n\t\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 2) as usize), *img_labels_row12.add((c) as usize), solver);\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_91);\n\t\t\t\t}\n\t\treturn None;}\nll_ => {},\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/Spaghetti_single_line_forest_code.rs",
    "content": "no_analyze!{{\nuse singleLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<singleLabels> { match label {\n\t\tNODE_93=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\treturn Some(sl_tree_1);\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\treturn Some(sl_tree_0);\n\t\t}\n\t\t\t\t}\nsl_tree_0 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_0); } else { return Some(sl_break_1_0); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\treturn Some(sl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\treturn Some(sl_tree_0);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_93);\n\t\t\t\t}\n}\nsl_tree_1 => {\nif ({c+=2; c}) >= w - 2 { if c > w - 2 { return Some(sl_break_0_1); } else { return Some(sl_break_1_1); } }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\treturn Some(sl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t\t\treturn Some(sl_tree_0);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_93);\n\t\t\t\t}\n}\nsl_break_0_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\nsl_break_0_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t}\n\t\treturn None;}\n\t\tNODE_94=> {\n\t\tif (*img_row00.add((c + 1) as usize)).to_bool() {\n\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t}\n\t\telse {\n\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t}\n\t\t\t\t}\nsl_break_1_0 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_94);\n\t\t\t\t}\n\t\treturn None;}\nsl_break_1_1 => {\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 2) as usize);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\treturn Some(NODE_94);\n\t\t\t\t}\n\t\treturn None;}\nsl_ => {},\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti/mod.rs",
    "content": "//! Spaghetti algorithm for connected component labeling\n//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,\n//! \"Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,\"\n//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.\n//!\n//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)\n//! as described in\n//!\n//! F. Bolelli, S. Allegretti, C. Grana.\n//! \"One DAG to Rule Them All.\"\n//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021\n\n#![allow(\n    unreachable_code,\n    clippy::collapsible_else_if,\n    clippy::if_same_then_else\n)]\n\nuse std::cmp::Ordering;\n\nuse burn_tensor::{Element, ElementComparison, ElementConversion, cast::ToElement};\nuse ndarray::{Array2, Axis, s};\n\n#[allow(non_snake_case)]\nmod Spaghetti_forest_labels;\npub(crate) use Spaghetti_forest_labels::*;\n\nuse crate::Connectivity;\n\nuse super::{Solver, StatsOp, max_labels};\n\npub fn process<B: Element, LabelsSolver: Solver>(\n    img_arr: Array2<B>,\n    stats: &mut impl StatsOp<Label = LabelsSolver::Label>,\n) -> Array2<LabelsSolver::Label> {\n    let (h, w) = img_arr.dim();\n    let mut img_labels_arr = Array2::<LabelsSolver::Label>::default(img_arr.raw_dim());\n\n    let img = img_arr.as_ptr();\n\n    let e_rows = h as u32 & 0xfffffffe;\n    let o_rows = h % 2 == 1;\n    let e_cols = w as u32 & 0xfffffffe;\n    let o_cols = w % 2 == 1;\n\n    let img_labels = img_labels_arr.as_mut_ptr();\n\n    let mut solver = LabelsSolver::init(max_labels(h, w, Connectivity::Eight));\n\n    let solver = &mut solver;\n\n    let w = w as i32;\n\n    // SAFETY:\n    // Generated code includes mathematically proven bounds checks, so raw pointers are a safe speed\n    // boost.\n    unsafe {\n        if h == 1 {\n            // Single line\n            let r = 0;\n            //Pointers:\n            // Row pointers for the input image\n            let img_row00 = img.add(r * w as usize);\n\n            // Row pointers for the output image\n            let img_labels_row00 = img_labels.add(r * w as usize);\n\n            let mut c = -2i32;\n            let entry = singleLabels::sl_tree_0;\n\n            include!(\"Spaghetti_single_line_forest_code.rs\");\n        } else {\n            // More than one line\n\n            // First couple of lines\n            {\n                let r = 0;\n                //Pointers:\n                // Row pointers for the input image\n                let img_row00 = img.add(r * w as usize);\n                let img_row01 = img.add((r + 1) * w as usize);\n\n                // Row pointers for the output image\n                let img_labels_row00 = img_labels.add(r * w as usize);\n\n                let mut c = -2i32;\n                let entry = firstLabels::fl_tree_0;\n\n                include!(\"Spaghetti_first_line_forest_code.rs\");\n            }\n\n            // Every other line but the last one if image has an odd number of rows\n            for r in (2..e_rows as usize).step_by(2) {\n                //Pointers:\n                // Row pointers for the input image\n                let img_row00 = img.add(r * w as usize);\n                let img_row12 = img.add((r - 2) * w as usize);\n                let img_row11 = img.add((r - 1) * w as usize);\n                let img_row01 = img.add((r + 1) * w as usize);\n\n                // Row pointers for the output image\n                let img_labels_row00 = img_labels.add(r * w as usize);\n                let img_labels_row12 = img_labels.add((r - 2) * w as usize);\n\n                let mut c = -2;\n                let entry = centerLabels::cl_tree_0;\n\n                include!(\"Spaghetti_center_line_forest_code.rs\");\n            }\n\n            if o_rows {\n                let r = h - 1;\n                //Pointers:\n                // Row pointers for the input image\n                let img_row00 = img.add(r * w as usize);\n                let img_row12 = img.add((r - 2) * w as usize);\n                let img_row11 = img.add((r - 1) * w as usize);\n\n                // Row pointers for the output image\n                let img_labels_row00 = img_labels.add(r * w as usize);\n                let img_labels_row12 = img_labels.add((r - 2) * w as usize);\n\n                let mut c = -2;\n                let entry = lastLabels::ll_tree_0;\n\n                include!(\"Spaghetti_last_line_forest_code.rs\");\n            }\n        }\n    }\n\n    let n_labels = solver.flatten();\n    stats.init(n_labels.to_usize());\n\n    let img = img_arr;\n    let mut img_labels = img_labels_arr;\n\n    for r in (0..e_rows as usize).step_by(2) {\n        //Pointers:\n        // Row pointers for the input image\n        let img_row00 = img.index_axis(Axis(0), r);\n        let img_row01 = img.index_axis(Axis(0), r + 1);\n\n        // Row pointers for the output image\n        let (mut img_labels_row00, mut img_labels_row01) =\n            img_labels.multi_slice_mut((s![r, ..], s![r + 1, ..]));\n\n        for c in (0..e_cols as usize).step_by(2) {\n            let mut i_label = img_labels_row00[c];\n            if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {\n                i_label = solver.get_label(i_label);\n                if img_row00[c].to_u8() > 0 {\n                    img_labels_row00[c] = i_label;\n                    stats.update(r, c, i_label);\n                } else {\n                    img_labels_row00[c] = 0.elem();\n                    stats.update(r, c, 0.elem());\n                }\n                if img_row00[c + 1].to_u8() > 0 {\n                    img_labels_row00[c + 1] = i_label;\n                    stats.update(r, c + 1, i_label);\n                } else {\n                    img_labels_row00[c + 1] = 0.elem();\n                    stats.update(r, c + 1, 0.elem());\n                }\n                if img_row01[c].to_u8() > 0 {\n                    img_labels_row01[c] = i_label;\n                    stats.update(r + 1, c, i_label);\n                } else {\n                    img_labels_row01[c] = 0.elem();\n                    stats.update(r + 1, c, 0.elem());\n                }\n                if img_row01[c + 1].to_u8() > 0 {\n                    img_labels_row01[c + 1] = i_label;\n                    stats.update(r + 1, c + 1, i_label);\n                } else {\n                    img_labels_row01[c + 1] = 0.elem();\n                    stats.update(r + 1, c + 1, 0.elem());\n                }\n            } else {\n                img_labels_row00[c] = 0.elem();\n                stats.update(r, c, 0.elem());\n                img_labels_row00[c + 1] = 0.elem();\n                stats.update(r, c + 1, 0.elem());\n                img_labels_row01[c] = 0.elem();\n                stats.update(r + 1, c, 0.elem());\n                img_labels_row01[c + 1] = 0.elem();\n                stats.update(r + 1, c + 1, 0.elem());\n            }\n        }\n        if o_cols {\n            let c = e_cols as usize;\n            let mut i_label = img_labels_row00[c];\n            if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {\n                i_label = solver.get_label(i_label);\n                if img_row00[c].to_u8() > 0 {\n                    img_labels_row00[c] = i_label;\n                    stats.update(r, c, i_label);\n                } else {\n                    img_labels_row00[c] = 0.elem();\n                    stats.update(r, c, 0.elem());\n                }\n                if img_row01[c].to_u8() > 0 {\n                    img_labels_row01[c] = i_label;\n                    stats.update(r + 1, c, i_label);\n                } else {\n                    img_labels_row01[c] = 0.elem();\n                    stats.update(r + 1, c, 0.elem());\n                }\n            } else {\n                img_labels_row00[c] = 0.elem();\n                stats.update(r, c, 0.elem());\n                img_labels_row01[c] = 0.elem();\n                stats.update(r + 1, c, 0.elem());\n            }\n        }\n    }\n\n    if o_rows {\n        let r = e_rows as usize;\n\n        // Row pointers for the input image\n        let img_row00 = img.index_axis(Axis(0), r);\n\n        // Row pointers for the output image\n        let mut img_labels_row00 = img_labels.slice_mut(s![r, ..]);\n\n        for c in (0..e_cols as usize).step_by(2) {\n            let mut i_label = img_labels_row00[c];\n            if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {\n                i_label = solver.get_label(i_label);\n                if img_row00[c].to_u8() > 0 {\n                    img_labels_row00[c] = i_label;\n                    stats.update(r, c, i_label);\n                } else {\n                    img_labels_row00[c] = 0.elem();\n                    stats.update(r, c, 0.elem());\n                }\n                if img_row00[c + 1].to_u8() > 0 {\n                    img_labels_row00[c + 1] = i_label;\n                    stats.update(r, c + 1, i_label);\n                } else {\n                    img_labels_row00[c + 1] = 0.elem();\n                    stats.update(r, c + 1, 0.elem());\n                }\n            } else {\n                img_labels_row00[c] = 0.elem();\n                stats.update(r, c, 0.elem());\n                img_labels_row00[c + 1] = 0.elem();\n                stats.update(r, c + 1, 0.elem());\n            }\n        }\n        if o_cols {\n            let c = e_cols as usize;\n            let mut i_label = img_labels_row00[c];\n            if matches!(i_label.cmp(&0.elem()), Ordering::Greater) {\n                i_label = solver.get_label(i_label);\n                if img_row00[c].to_u8() > 0 {\n                    img_labels_row00[c] = i_label;\n                    stats.update(r, c, i_label);\n                } else {\n                    img_labels_row00[c] = 0.elem();\n                    stats.update(r, c, 0.elem());\n                }\n            } else {\n                img_labels_row00[c] = 0.elem();\n                stats.update(r, c, i_label);\n            }\n        }\n    }\n\n    stats.finish();\n    img_labels\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_center_line_forest_code.rs",
    "content": "no_analyze!{{\nuse centerLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<centerLabels> { match label {\ncl_tree_0 => {\nif ({c+=1; c} >= w) { return None; }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row11.add((c) as usize);\n\t\t\t\t\t\treturn Some(cl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\t\treturn Some(cl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\treturn Some(cl_tree_0);\n\t\t\t\t}\n}\ncl_tree_1 => {\nif ({c+=1; c} >= w) { return None; }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\tif (*img_row11.add((c) as usize)).to_bool() {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = LabelsSolver::merge(*img_labels_row00.add((c - 1) as usize), *img_labels_row11.add((c) as usize), solver);\n\t\t\t\t\t\treturn Some(cl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t\telse {\n\t\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);\n\t\t\t\t\t\treturn Some(cl_tree_1);\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\treturn Some(cl_tree_0);\n\t\t\t\t}\n}\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_first_line_forest_code.rs",
    "content": "no_analyze!{{\nuse firstLabels::*;let mut label = entry;\nwhile let Some(next) = (|label| -> Option<firstLabels> { match label {\nfl_tree_0 => {\nif ({c+=1; c} >= w) { return None; }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = solver.new_label();\n\t\t\t\t\treturn Some(fl_tree_1);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\treturn Some(fl_tree_0);\n\t\t\t\t}\n}\nfl_tree_1 => {\nif ({c+=1; c} >= w) { return None; }\n\t\t\t\tif (*img_row00.add((c) as usize)).to_bool() {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = *img_labels_row00.add((c - 1) as usize);\n\t\t\t\t\treturn Some(fl_tree_1);\n\t\t\t\t}\n\t\t\t\telse {\n\t\t\t\t\t*img_labels_row00.add(c as usize) = 0.elem();\n\t\t\t\t\treturn Some(fl_tree_0);\n\t\t\t\t}\n}\nfl_ => {},\n    }; None})(label)\n{\nlabel = next;\n}\n}}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/Spaghetti4C_forest_labels.rs",
    "content": "/// Workaround for rust-analyzer bug that causes invalid errors on the `include!`.\nmacro_rules! no_analyze {\n    ($tokens:tt) => {\n        $tokens\n    };\n}\n\npub(crate) use no_analyze;\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum centerLabels {\n    cl_tree_0,\n    cl_tree_1,\n}\n\n#[allow(non_snake_case, non_camel_case_types, unused)]\npub enum firstLabels {\n    fl_tree_0,\n    fl_tree_1,\n    fl_,\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components/spaghetti_4c/mod.rs",
    "content": "//! Spaghetti algorithm for connected component labeling, modified for 4-connectivity using the\n//! 4-connected Rosenfeld mask.\n//! F. Bolelli, S. Allegretti, L. Baraldi, and C. Grana,\n//! \"Spaghetti Labeling: Directed Acyclic Graphs for Block-Based Bonnected Components Labeling,\"\n//! IEEE Transactions on Image Processing, vol. 29, no. 1, pp. 1999-2012, 2019.\n//!\n//! Decision forests are generated using a modified [GRAPHGEN](https://github.com/wingertge/GRAPHGEN)\n//! as described in\n//!\n//! F. Bolelli, S. Allegretti, C. Grana.\n//! \"One DAG to Rule Them All.\"\n//! IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021\n\n#![allow(unreachable_code)]\n\nuse burn_tensor::{Element, ElementConversion, cast::ToElement};\nuse ndarray::Array2;\n\nuse crate::Connectivity;\n\nuse super::{Solver, StatsOp, max_labels};\n\n#[allow(non_snake_case)]\nmod Spaghetti4C_forest_labels;\npub(crate) use Spaghetti4C_forest_labels::*;\n\npub fn process<B: Element, LabelsSolver: Solver>(\n    img: Vec<B>,\n    h: usize,\n    w: usize,\n    stats: &mut impl StatsOp<Label = LabelsSolver::Label>,\n) -> Array2<LabelsSolver::Label> {\n    let img = img.as_ptr();\n\n    let mut img_labels: Vec<LabelsSolver::Label> = vec![0.elem(); h * w];\n\n    // A quick and dirty upper bound for the maximum number of labels.\n    // Following formula comes from the fact that a 2x2 block in 4-connectivity case\n    // can never have more than 2 new labels and 1 label for background.\n    // Worst case image example pattern:\n    // 1 0 1 0 1...\n    // 0 1 0 1 0...\n    // 1 0 1 0 1...\n    // ............\n    let max_labels = max_labels(h, w, Connectivity::Four);\n\n    let mut solver = LabelsSolver::init(max_labels);\n    let solver = &mut solver;\n\n    let w = w as i32;\n    // SAFETY:\n    // This code is generated from constraints and includes manual bounds checks, so unchecked pointer\n    // indexes are always safe.\n    unsafe {\n        // First row\n        {\n            let r = 0;\n            //Pointers:\n            // Row pointers for the input image\n            let img_row00 = img.add(r * w as usize);\n\n            // Row pointers for the output image\n            let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);\n            let mut c = -1i32;\n\n            let entry = firstLabels::fl_tree_0;\n\n            include!(\"Spaghetti4C_first_line_forest_code.rs\");\n        }\n\n        for r in 1..h {\n            //Pointers:\n            // Row pointers for the input image\n            let img_row00 = img.add(r * w as usize);\n            let img_row11 = img.add((r - 1) * w as usize);\n\n            // Row pointers for the output image\n            let img_labels_row00 = img_labels.as_mut_ptr().add(r * w as usize);\n            let img_labels_row11 = img_labels.as_mut_ptr().add((r - 1) * w as usize);\n            let mut c = -1i32;\n\n            let entry = centerLabels::cl_tree_0;\n\n            include!(\"Spaghetti4C_center_line_forest_code.rs\");\n        }\n    }\n\n    let n_labels = solver.flatten();\n    stats.init(n_labels.to_usize());\n\n    // SAFETY: This is always valid\n    let mut img_labels = unsafe { Array2::from_shape_vec_unchecked((h, w as usize), img_labels) };\n\n    img_labels.indexed_iter_mut().for_each(|((r, c), label)| {\n        *label = solver.get_label(*label);\n        stats.update(r, c, *label);\n    });\n\n    stats.finish();\n\n    img_labels\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/connected_components.rs",
    "content": "use std::{cmp::Ordering, marker::PhantomData};\n\nuse alloc::vec::Vec;\nuse burn_tensor::{\n    Bool, DType, Element, ElementConversion, ElementOrdered, Int, Shape, Tensor, TensorData,\n    backend::Backend,\n    ops::{BoolTensor, IntTensor},\n};\nuse ndarray::Array2;\n\nuse crate::{ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity};\n\nmod spaghetti;\nmod spaghetti_4c;\n\n/// Dispatches connected components based on `B::IntElem::dtype()`, binding a concrete\n/// integer type to enable generic instantiations without extra trait bounds (after removing\n/// `ElementComparison` from `Element`).\nmacro_rules! dispatch_int_dtype {\n    (|$ty:ident| $body:expr) => {\n        match B::IntElem::dtype() {\n            DType::I64 => {\n                type $ty = i64;\n                $body\n            }\n            DType::I32 => {\n                type $ty = i32;\n                $body\n            }\n            DType::I16 => {\n                type $ty = i16;\n                $body\n            }\n            DType::I8 => {\n                type $ty = i8;\n                $body\n            }\n            DType::U64 => {\n                type $ty = u64;\n                $body\n            }\n            DType::U32 => {\n                type $ty = u32;\n                $body\n            }\n            DType::U16 => {\n                type $ty = u16;\n                $body\n            }\n            DType::U8 => {\n                type $ty = u8;\n                $body\n            }\n            _ => unreachable!(\"Unsupported dtype\"),\n        }\n    };\n}\n\npub fn connected_components<B: Backend>(\n    img: BoolTensor<B>,\n    connectivity: Connectivity,\n) -> IntTensor<B> {\n    dispatch_int_dtype!(|I| run::<B, I, NoOp<_>>(img, connectivity, NoOp::default).0)\n}\n\npub fn connected_components_with_stats<B: Backend>(\n    img: BoolTensor<B>,\n    connectivity: Connectivity,\n    _options: ConnectedStatsOptions,\n) -> (IntTensor<B>, ConnectedStatsPrimitive<B>) {\n    let device = B::bool_device(&img);\n\n    dispatch_int_dtype!(|I| {\n        let (labels, stats) =\n            run::<B, I, ConnectedStatsOp<I>>(img, connectivity, ConnectedStatsOp::default);\n        let stats = finalize_stats(&device, stats);\n        (labels, stats)\n    })\n}\n\nfn run<B: Backend, I: ElementOrdered, Stats: StatsOp<Label = I>>(\n    img: BoolTensor<B>,\n    connectivity: Connectivity,\n    stats: impl Fn() -> Stats,\n) -> (IntTensor<B>, Stats) {\n    let device = B::bool_device(&img);\n    let img = Tensor::<B, 2, Bool>::from_primitive(img);\n    let [height, width] = img.shape().dims();\n    let img = img.into_data();\n    let img = img.into_vec::<B::BoolElem>().unwrap();\n\n    let mut stats = stats();\n\n    let out = match connectivity {\n        Connectivity::Four => {\n            spaghetti_4c::process::<B::BoolElem, UnionFind<_>>(img, height, width, &mut stats)\n        }\n        Connectivity::Eight => {\n            // SAFETY: This is validated by `TensorData`\n            let img = unsafe { Array2::from_shape_vec_unchecked((height, width), img) };\n            spaghetti::process::<B::BoolElem, UnionFind<_>>(img, &mut stats)\n        }\n    };\n\n    let (data, _) = out.into_raw_vec_and_offset();\n    let data = TensorData::new(data, Shape::new([height, width]));\n    let labels = Tensor::<B, 2, Int>::from_data(data, &device).into_primitive();\n    (labels, stats)\n}\n\npub trait Solver {\n    type Label: ElementOrdered;\n\n    fn init(max_labels: usize) -> Self;\n    /// Hack to get around mutable borrow limitations on methods\n    fn merge(label_1: Self::Label, label_2: Self::Label, solver: &mut Self) -> Self::Label;\n    fn new_label(&mut self) -> Self::Label;\n    fn flatten(&mut self) -> Self::Label;\n    fn get_label(&self, i_label: Self::Label) -> Self::Label;\n}\n\npub(crate) struct UnionFind<I: Element> {\n    labels: Vec<I>,\n}\n\nimpl<I: ElementOrdered> Solver for UnionFind<I> {\n    type Label = I;\n\n    fn init(max_labels: usize) -> Self {\n        let mut labels = Vec::with_capacity(max_labels);\n        labels.push(0.elem());\n        Self { labels }\n    }\n\n    fn merge(mut label_1: I, mut label_2: I, solver: &mut Self) -> I {\n        use Ordering::Less;\n\n        while matches!(solver.labels[label_1.to_usize()].cmp(&label_1), Less) {\n            label_1 = solver.labels[label_1.to_usize()];\n        }\n\n        while matches!(solver.labels[label_2.to_usize()].cmp(&label_2), Less) {\n            label_2 = solver.labels[label_2.to_usize()];\n        }\n\n        if matches!(label_1.cmp(&label_2), Less) {\n            solver.labels[label_2.to_usize()] = label_1;\n            label_1\n        } else {\n            solver.labels[label_1.to_usize()] = label_2;\n            label_2\n        }\n    }\n\n    fn new_label(&mut self) -> I {\n        let len = I::from_elem(self.labels.len());\n        self.labels.push(len);\n        len\n    }\n\n    fn flatten(&mut self) -> I {\n        let mut k = 1;\n        for i in 1..self.labels.len() {\n            if matches!(self.labels[i].cmp(&I::from_elem(i)), Ordering::Less) {\n                self.labels[i] = self.labels[self.labels[i].to_usize()];\n            } else {\n                self.labels[i] = k.elem();\n                k += 1;\n            }\n        }\n        k.elem()\n    }\n\n    fn get_label(&self, i_label: I) -> I {\n        self.labels[i_label.to_usize()]\n    }\n}\n\npub trait StatsOp {\n    type Label;\n\n    fn init(&mut self, num_labels: usize);\n    fn update(&mut self, row: usize, column: usize, label: Self::Label);\n    fn finish(&mut self);\n}\n\n#[derive(Default)]\nstruct NoOp<I: Element> {\n    _i: PhantomData<I>,\n}\n\nimpl<I: Element> StatsOp for NoOp<I> {\n    type Label = I; // placeholder still required\n\n    fn init(&mut self, _num_labels: usize) {}\n\n    fn update(&mut self, _row: usize, _column: usize, _label: Self::Label) {}\n\n    fn finish(&mut self) {}\n}\n\n#[derive(Default, Debug)]\nstruct ConnectedStatsOp<I: Element> {\n    pub area: Vec<I>,\n    pub left: Vec<I>,\n    pub top: Vec<I>,\n    pub right: Vec<I>,\n    pub bottom: Vec<I>,\n}\n\nimpl<I: Element> StatsOp for ConnectedStatsOp<I> {\n    type Label = I;\n\n    fn init(&mut self, num_labels: usize) {\n        self.area = vec![0.elem(); num_labels];\n        self.left = vec![I::MAX; num_labels];\n        self.top = vec![I::MAX; num_labels];\n        self.right = vec![0.elem(); num_labels];\n        self.bottom = vec![0.elem(); num_labels];\n    }\n\n    fn update(&mut self, row: usize, column: usize, label: I) {\n        let l = label.to_usize();\n        unsafe {\n            *self.area.get_unchecked_mut(l) =\n                I::from_elem((*self.area.get_unchecked(l)).to_usize() + 1);\n            *self.left.get_unchecked_mut(l) =\n                I::from_elem((*self.left.get_unchecked(l)).to_usize().min(column));\n            *self.top.get_unchecked_mut(l) =\n                I::from_elem((*self.top.get_unchecked(l)).to_usize().min(row));\n            *self.right.get_unchecked_mut(l) =\n                I::from_elem((*self.right.get_unchecked(l)).to_usize().max(column));\n            *self.bottom.get_unchecked_mut(l) =\n                I::from_elem((*self.bottom.get_unchecked(l)).to_usize().max(row));\n        }\n    }\n\n    fn finish(&mut self) {\n        // Background shouldn't have stats\n        self.area[0] = 0.elem();\n        self.left[0] = 0.elem();\n        self.right[0] = 0.elem();\n        self.top[0] = 0.elem();\n        self.bottom[0] = 0.elem();\n    }\n}\n\nfn finalize_stats<B: Backend, I: Element>(\n    device: &B::Device,\n    stats: ConnectedStatsOp<I>,\n) -> ConnectedStatsPrimitive<B> {\n    let labels = stats.area.len();\n\n    let into_prim = |data: Vec<I>| {\n        let data = TensorData::new(data, Shape::new([labels]));\n        Tensor::<B, 1, Int>::from_data(data, device).into_primitive()\n    };\n\n    let max_label = {\n        let data = TensorData::new(vec![I::from_elem(labels - 1)], Shape::new([1]));\n        Tensor::<B, 1, Int>::from_data(data, device).into_primitive()\n    };\n\n    ConnectedStatsPrimitive {\n        area: into_prim(stats.area),\n        left: into_prim(stats.left),\n        top: into_prim(stats.top),\n        right: into_prim(stats.right),\n        bottom: into_prim(stats.bottom),\n        max_label,\n    }\n}\n\npub fn max_labels(h: usize, w: usize, conn: Connectivity) -> usize {\n    match conn {\n        Connectivity::Four => (h * w).div_ceil(2) + 1,\n        Connectivity::Eight => h.div_ceil(2) * w.div_ceil(2) + 1,\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/mod.rs",
    "content": "mod base;\nmod connected_components;\nmod morphology;\nmod nms;\nmod ops;\n\npub use base::*;\npub use connected_components::*;\npub use morphology::*;\npub use nms::*;\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/morphology/filter.rs",
    "content": "use core::slice;\nuse std::{marker::PhantomData, ptr::null};\n\nuse burn_tensor::Element;\nuse macerator::{\n    Scalar, Simd, VOrd, Vector, vload, vload_low, vload_unaligned, vstore, vstore_low,\n    vstore_unaligned,\n};\n\nuse crate::{Point, Size, backends::cpu::MinMax};\n\npub trait MorphOperator<T> {\n    fn apply(a: T, b: T) -> T;\n}\n\npub trait VecMorphOperator<T: Scalar> {\n    fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T>;\n}\n\npub struct MinOp;\npub struct MaxOp;\n\nimpl<T: MinMax> MorphOperator<T> for MinOp {\n    fn apply(a: T, b: T) -> T {\n        MinMax::min(a, b)\n    }\n}\n\nimpl<T: VOrd> VecMorphOperator<T> for MinOp {\n    fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {\n        T::vmin(a, b)\n    }\n}\n\nimpl<T: MinMax> MorphOperator<T> for MaxOp {\n    fn apply(a: T, b: T) -> T {\n        MinMax::max(a, b)\n    }\n}\n\nimpl<T: VOrd> VecMorphOperator<T> for MaxOp {\n    fn apply<S: Simd>(a: Vector<S, T>, b: Vector<S, T>) -> Vector<S, T> {\n        T::vmax(a, b)\n    }\n}\n\npub struct MorphRowFilter<T: Scalar, S: MorphOperator<T>, Vec: VecRow<T>> {\n    pub ksize: usize,\n    pub anchor: usize,\n    vec: Vec,\n    _t: PhantomData<T>,\n    _scalar: PhantomData<S>,\n}\n\nimpl<T: Scalar, SOp: MorphOperator<T>, Vec: VecRow<T>> MorphRowFilter<T, SOp, Vec> {\n    pub fn new(ksize: usize, anchor: usize) -> Self {\n        let vec = Vec::new(ksize, anchor);\n        Self {\n            ksize,\n            anchor,\n            vec,\n            _t: PhantomData,\n            _scalar: PhantomData,\n        }\n    }\n\n    pub fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) {\n        let k_size = self.ksize * ch;\n\n        if k_size == ch {\n            let width = width * ch;\n            dst[..width].copy_from_slice(&src[..width]);\n            return;\n        }\n\n        let i0 = self.vec.apply::<S>(src, dst, width, ch);\n        let width = width * ch;\n\n        for k in 0..ch {\n            let mut last_i = i0;\n            for i in (i0..width.saturating_sub(ch * 2)).step_by(ch * 2) {\n                let mut m = src[k + i + ch];\n                let mut last_j = ch * 2;\n                for j in (ch * 2..k_size).step_by(ch) {\n                    m = SOp::apply(m, src[k + i + j]);\n                    last_j = j + ch;\n                }\n                dst[k + i] = SOp::apply(m, src[k + i]);\n                dst[k + i + ch] = SOp::apply(m, src[k + i + last_j]);\n                last_i = i + ch * 2;\n            }\n\n            for i in (last_i..width).step_by(ch) {\n                let mut m = src[k + i];\n                for j in (ch..k_size).step_by(ch) {\n                    m = SOp::apply(m, src[k + i + j]);\n                }\n                dst[k + i] = m;\n            }\n        }\n    }\n}\n\npub struct MorphRowVec<T: Scalar, Op: VecMorphOperator<T>> {\n    k_size: usize,\n    _t: PhantomData<T>,\n    _op: PhantomData<Op>,\n}\n\npub trait VecRow<T: Scalar> {\n    fn new(ksize: usize, anchor: usize) -> Self;\n    fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, channels: usize) -> usize;\n}\n\nimpl<T: Scalar, Op: VecMorphOperator<T>> VecRow<T> for MorphRowVec<T, Op> {\n    fn apply<S: Simd>(&self, src: &[T], dst: &mut [T], width: usize, ch: usize) -> usize {\n        let src = src.as_ptr();\n        let dst = dst.as_mut_ptr();\n        let k_size = self.k_size * ch;\n        let width = (width * ch) as isize;\n        let lanes = T::lanes::<S>();\n\n        // Safety: everything here is unsafe. Test thoroughly.\n        unsafe {\n            let mut x = 0;\n            while x as isize <= width - 4 * lanes as isize {\n                let mut s0 = vload(src.add(x));\n                let mut s1 = vload(src.add(x + lanes));\n                let mut s2 = vload(src.add(x + 2 * lanes));\n                let mut s3 = vload(src.add(x + 3 * lanes));\n                for k in (ch..k_size).step_by(ch) {\n                    let x = x + k;\n                    s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x)));\n                    s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + lanes)));\n                    s2 = Op::apply::<S>(s2, vload_unaligned(src.add(x + 2 * lanes)));\n                    s3 = Op::apply::<S>(s3, vload_unaligned(src.add(x + 3 * lanes)));\n                }\n                vstore(dst.add(x), s0);\n                vstore(dst.add(x + lanes), s1);\n                vstore(dst.add(x + 2 * lanes), s2);\n                vstore(dst.add(x + 3 * lanes), s3);\n                x += 4 * lanes;\n            }\n            if x as isize <= width - 2 * lanes as isize {\n                let mut s0 = vload(src.add(x));\n                let mut s1 = vload(src.add(x + lanes));\n                for k in (ch..k_size).step_by(ch) {\n                    s0 = Op::apply::<S>(s0, vload_unaligned(src.add(x + k)));\n                    s1 = Op::apply::<S>(s1, vload_unaligned(src.add(x + k + lanes)));\n                }\n                vstore(dst.add(x), s0);\n                vstore(dst.add(x + lanes), s1);\n                x += 2 * lanes;\n            }\n            if x as isize <= width - lanes as isize {\n                let mut s = vload(src.add(x));\n                for k in (ch..k_size).step_by(ch) {\n                    s = Op::apply::<S>(s, vload_unaligned(src.add(x + k)));\n                }\n                vstore(dst.add(x), s);\n                x += lanes;\n            }\n            if x as isize <= width - lanes as isize / 2 {\n                let mut s = vload_low(src.add(x));\n                for k in (ch..k_size).step_by(ch) {\n                    s = Op::apply::<S>(s, vload_low(src.add(x + k)));\n                }\n                vstore_low(dst.add(x), s);\n                x += lanes / 2;\n            }\n            x - x % ch\n        }\n    }\n\n    fn new(k_size: usize, _anchor: usize) -> Self {\n        Self {\n            k_size,\n            _t: PhantomData,\n            _op: PhantomData,\n        }\n    }\n}\n\npub trait VecColumn<T: Scalar> {\n    fn new(ksize: usize, anchor: usize) -> Self;\n    fn apply<S: Simd>(\n        &self,\n\n        src: &[*const T],\n        dst: &mut [T],\n        dst_step: usize,\n        height: usize,\n        width: usize,\n    ) -> usize;\n}\n\npub struct MorphColumnVec<T: Scalar, Op: VecMorphOperator<T>> {\n    k_size: usize,\n    _t: PhantomData<T>,\n    _op: PhantomData<Op>,\n}\n\nimpl<T: VOrd, Op: VecMorphOperator<T>> VecColumn<T> for MorphColumnVec<T, Op> {\n    fn new(k_size: usize, _anchor: usize) -> Self {\n        Self {\n            k_size,\n            _t: PhantomData,\n            _op: PhantomData,\n        }\n    }\n\n    fn apply<S: Simd>(\n        &self,\n\n        src: &[*const T],\n        dst: &mut [T],\n        dst_step: usize,\n        mut count: usize,\n        width: usize,\n    ) -> usize {\n        let ksize = self.k_size;\n        let width = width as isize;\n        let mut dst = dst.as_mut_ptr();\n        let lanes = T::lanes::<S>();\n        let mut y = 0;\n        let mut x = 0;\n\n        // Safety: everything here is unsafe. Test thoroughly.\n        unsafe {\n            while count > 1 && ksize > 1 {\n                x = 0;\n                while x as isize <= width - 4 * lanes as isize {\n                    let sptr = src[y + 1].add(x);\n                    let mut s0 = vload(sptr);\n                    let mut s1 = vload(sptr.add(lanes));\n                    let mut s2 = vload(sptr.add(2 * lanes));\n                    let mut s3 = vload(sptr.add(3 * lanes));\n\n                    for k in 2..ksize {\n                        let sptr = src[y + k].add(x);\n                        s0 = Op::apply::<S>(s0, vload(sptr));\n                        s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));\n                        s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));\n                        s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));\n                    }\n\n                    // Row 1\n                    {\n                        let sptr = src[y].add(x);\n                        let s0 = Op::apply(s0, vload(sptr));\n                        let s1 = Op::apply(s1, vload(sptr.add(lanes)));\n                        let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));\n                        let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));\n                        vstore_unaligned(dst.add(x), s0);\n                        vstore_unaligned(dst.add(x + lanes), s1);\n                        vstore_unaligned(dst.add(x + 2 * lanes), s2);\n                        vstore_unaligned(dst.add(x + 3 * lanes), s3);\n                    }\n\n                    // Row 2\n                    {\n                        let sptr = src[y + ksize].add(x);\n                        let s0 = Op::apply(s0, vload(sptr));\n                        let s1 = Op::apply(s1, vload(sptr.add(lanes)));\n                        let s2 = Op::apply(s2, vload(sptr.add(2 * lanes)));\n                        let s3 = Op::apply(s3, vload(sptr.add(3 * lanes)));\n                        vstore_unaligned(dst.add(dst_step + x), s0);\n                        vstore_unaligned(dst.add(dst_step + x + lanes), s1);\n                        vstore_unaligned(dst.add(dst_step + x + 2 * lanes), s2);\n                        vstore_unaligned(dst.add(dst_step + x + 3 * lanes), s3);\n                    }\n                    x += 4 * lanes;\n                }\n                if x as isize <= width - 2 * lanes as isize {\n                    let sptr = src[y + 1].add(x);\n                    let mut s0 = vload(sptr);\n                    let mut s1 = vload(sptr.add(lanes));\n\n                    for k in 2..ksize {\n                        let sptr = src[y + k].add(x);\n                        s0 = Op::apply::<S>(s0, vload(sptr));\n                        s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));\n                    }\n\n                    // Row 1\n                    {\n                        let sptr = src[y].add(x);\n                        let s0 = Op::apply(s0, vload(sptr));\n                        let s1 = Op::apply(s1, vload(sptr.add(lanes)));\n                        vstore_unaligned(dst.add(x), s0);\n                        vstore_unaligned(dst.add(x + lanes), s1);\n                    }\n\n                    // Row 2\n                    {\n                        let sptr = src[y + ksize].add(x);\n                        let s0 = Op::apply(s0, vload(sptr));\n                        let s1 = Op::apply(s1, vload(sptr.add(lanes)));\n                        vstore_unaligned(dst.add(dst_step + x), s0);\n                        vstore_unaligned(dst.add(dst_step + x + lanes), s1);\n                    }\n                    x += 2 * lanes;\n                }\n                if x as isize <= width - lanes as isize {\n                    let mut s0 = vload(src[y + 1].add(x));\n                    for k in 2..ksize {\n                        s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));\n                    }\n                    // Row 1\n                    {\n                        let sptr = src[y].add(x);\n                        vstore_unaligned(dst.add(x), Op::apply(s0, vload(sptr)));\n                    }\n\n                    // Row 2\n                    {\n                        let sptr = src[y + ksize].add(x);\n                        let s0 = Op::apply(s0, vload(sptr));\n                        vstore_unaligned(dst.add(dst_step + x), s0);\n                    }\n                    x += lanes;\n                }\n                if x as isize <= width - lanes as isize / 2 {\n                    let mut s0 = vload_low(src[y + 1].add(x));\n                    for k in 2..ksize {\n                        s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));\n                    }\n                    // Row 1\n                    {\n                        let sptr = src[y].add(x);\n                        let s0 = Op::apply(s0, vload_low(sptr));\n                        vstore_low(dst.add(x), s0);\n                    }\n\n                    // Row 2\n                    {\n                        let sptr = src[y + ksize].add(x);\n                        let s0 = Op::apply(s0, vload_low(sptr));\n                        vstore_low(dst.add(dst_step + x), s0);\n                    }\n                    x += lanes / 2;\n                }\n\n                count -= 2;\n                dst = dst.add(dst_step * 2);\n                y += 2;\n            }\n\n            while count > 0 {\n                x = 0;\n                while x as isize <= width - 4 * lanes as isize {\n                    let sptr = src[y].add(x);\n                    let mut s0 = vload(sptr);\n                    let mut s1 = vload(sptr.add(lanes));\n                    let mut s2 = vload(sptr.add(2 * lanes));\n                    let mut s3 = vload(sptr.add(3 * lanes));\n\n                    for k in 1..ksize {\n                        let sptr = src[y + k].add(x);\n                        s0 = Op::apply::<S>(s0, vload(sptr));\n                        s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));\n                        s2 = Op::apply::<S>(s2, vload(sptr.add(2 * lanes)));\n                        s3 = Op::apply::<S>(s3, vload(sptr.add(3 * lanes)));\n                    }\n\n                    vstore_unaligned(dst.add(x), s0);\n                    vstore_unaligned(dst.add(x + lanes), s1);\n                    vstore_unaligned(dst.add(x + 2 * lanes), s2);\n                    vstore_unaligned(dst.add(x + 3 * lanes), s3);\n\n                    x += 4 * lanes;\n                }\n                if x as isize <= width - 2 * lanes as isize {\n                    let sptr = src[y].add(x);\n                    let mut s0 = vload(sptr);\n                    let mut s1 = vload(sptr.add(lanes));\n\n                    for k in 1..ksize {\n                        let sptr = src[y + k].add(x);\n                        s0 = Op::apply::<S>(s0, vload(sptr));\n                        s1 = Op::apply::<S>(s1, vload(sptr.add(lanes)));\n                    }\n\n                    vstore_unaligned(dst.add(x), s0);\n                    vstore_unaligned(dst.add(x + lanes), s1);\n                    x += 2 * lanes;\n                }\n                if x as isize <= width - lanes as isize {\n                    let mut s0 = vload(src[y].add(x));\n\n                    for k in 1..ksize {\n                        s0 = Op::apply::<S>(s0, vload(src[y + k].add(x)));\n                    }\n\n                    vstore_unaligned(dst.add(x), s0);\n                    x += lanes;\n                }\n                if x as isize <= width - lanes as isize / 2 {\n                    let mut s0 = vload_low(src[y].add(x));\n\n                    for k in 1..ksize {\n                        s0 = Op::apply::<S>(s0, vload_low(src[y + k].add(x)));\n                    }\n\n                    vstore_low(dst.add(x), s0);\n                    x += lanes / 2;\n                }\n\n                count -= 1;\n                dst = dst.add(dst_step);\n                y += 1;\n            }\n        }\n        x\n    }\n}\n\npub struct MorphColumnFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> {\n    pub ksize: usize,\n    pub anchor: usize,\n    vec: VecOp,\n    _t: PhantomData<T>,\n    _op: PhantomData<Op>,\n}\n\nimpl<T: Scalar, Op: MorphOperator<T>, VecOp: VecColumn<T>> MorphColumnFilter<T, Op, VecOp> {\n    pub fn new(ksize: usize, anchor: usize) -> Self {\n        let vec = VecOp::new(ksize, anchor);\n        Self {\n            ksize,\n            anchor,\n            vec,\n            _t: PhantomData,\n            _op: PhantomData,\n        }\n    }\n\n    pub fn apply<S: Simd>(\n        &self,\n\n        src: &[*const T],\n        dst: &mut [T],\n        dst_step: usize,\n        mut count: usize,\n        width: usize,\n    ) {\n        let ksize = self.ksize;\n        let x0 = self.vec.apply::<S>(src, dst, dst_step, count, width);\n        let width = width as isize;\n\n        let mut d = 0;\n        let mut x;\n        let mut y = 0;\n\n        let slice = |row: *const T| unsafe { slice::from_raw_parts(row, width as usize) };\n\n        while ksize > 1 && count > 1 {\n            x = x0;\n\n            while x as isize <= width - 4 {\n                let row = slice(src[y + 1]);\n                let mut s0 = row[x];\n                let mut s1 = row[x + 1];\n                let mut s2 = row[x + 2];\n                let mut s3 = row[x + 3];\n\n                for k in 2..ksize {\n                    let row = slice(src[y + k]);\n                    s0 = Op::apply(s0, row[x]);\n                    s1 = Op::apply(s1, row[x + 1]);\n                    s2 = Op::apply(s2, row[x + 2]);\n                    s3 = Op::apply(s3, row[x + 3]);\n                }\n\n                let row = slice(src[y]);\n                dst[d + x] = Op::apply(s0, row[x]);\n                dst[d + x + 1] = Op::apply(s1, row[x + 1]);\n                dst[d + x + 2] = Op::apply(s2, row[x + 2]);\n                dst[d + x + 3] = Op::apply(s3, row[x + 3]);\n\n                let row = slice(src[y + ksize]);\n                dst[d + dst_step + x] = Op::apply(s0, row[x]);\n                dst[d + dst_step + x + 1] = Op::apply(s1, row[x + 1]);\n                dst[d + dst_step + x + 2] = Op::apply(s2, row[x + 2]);\n                dst[d + dst_step + x + 3] = Op::apply(s3, row[x + 3]);\n\n                x += 4;\n            }\n            while (x as isize) < width {\n                let mut s0 = slice(src[y + 1])[x];\n                for k in 2..ksize {\n                    s0 = Op::apply(s0, slice(src[y + k])[x]);\n                }\n                dst[d + x] = Op::apply(s0, slice(src[y])[x]);\n                dst[d + dst_step + x] = Op::apply(s0, slice(src[y + ksize])[x]);\n\n                x += 1;\n            }\n\n            count -= 2;\n            d += 2 * dst_step;\n            y += 2;\n        }\n\n        while count > 0 {\n            x = x0;\n\n            while x as isize <= width - 4 {\n                let row = slice(src[y]);\n                let mut s0 = row[x];\n                let mut s1 = row[x + 1];\n                let mut s2 = row[x + 2];\n                let mut s3 = row[x + 3];\n\n                for k in 1..ksize {\n                    let row = slice(src[y + k]);\n                    s0 = Op::apply(s0, row[x]);\n                    s1 = Op::apply(s1, row[x + 1]);\n                    s2 = Op::apply(s2, row[x + 2]);\n                    s3 = Op::apply(s3, row[x + 3]);\n                }\n\n                dst[d + x] = s0;\n                dst[d + x + 1] = s1;\n                dst[d + x + 2] = s2;\n                dst[d + x + 3] = s3;\n\n                x += 4;\n            }\n            while (x as isize) < width {\n                let mut s0 = slice(src[y])[x];\n                for k in 1..ksize {\n                    s0 = Op::apply(s0, slice(src[y + k])[x]);\n                }\n\n                dst[d + x] = s0;\n\n                x += 1;\n            }\n\n            count -= 1;\n            d += dst_step;\n            y += 1;\n        }\n    }\n}\n\npub trait VecFilter<T: Scalar> {\n    fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize;\n}\n\npub struct MorphVec<T: Scalar, Op: VecMorphOperator<T>>(PhantomData<(T, Op)>);\n\nimpl<T: Scalar, Op: VecMorphOperator<T>> VecFilter<T> for MorphVec<T, Op> {\n    fn apply<S: Simd>(src: &[*const T], nz: usize, dst: &mut [T], width: usize) -> usize {\n        let dst = dst.as_mut_ptr();\n        let mut i = 0;\n        let lanes = T::lanes::<S>();\n        let width = width as isize;\n\n        // Safety: everything here is unsafe. Test thoroughly.\n        unsafe {\n            while i as isize <= width - 4 * lanes as isize {\n                let sptr = src[0].add(i);\n                let mut s0 = vload_unaligned(sptr);\n                let mut s1 = vload_unaligned(sptr.add(lanes));\n                let mut s2 = vload_unaligned(sptr.add(2 * lanes));\n                let mut s3 = vload_unaligned(sptr.add(3 * lanes));\n                for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {\n                    s0 = Op::apply::<S>(s0, vload_unaligned(sptr));\n                    s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));\n                    s2 = Op::apply::<S>(s2, vload_unaligned(sptr.add(2 * lanes)));\n                    s3 = Op::apply::<S>(s3, vload_unaligned(sptr.add(3 * lanes)));\n                }\n                vstore_unaligned(dst.add(i), s0);\n                vstore_unaligned(dst.add(i + lanes), s1);\n                vstore_unaligned(dst.add(i + 2 * lanes), s2);\n                vstore_unaligned(dst.add(i + 3 * lanes), s3);\n                i += 4 * lanes;\n            }\n            if i as isize <= width - 2 * lanes as isize {\n                let sptr = src[0].add(i);\n                let mut s0 = vload_unaligned(sptr);\n                let mut s1 = vload_unaligned(sptr.add(lanes));\n                for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {\n                    s0 = Op::apply::<S>(s0, vload_unaligned(sptr));\n                    s1 = Op::apply::<S>(s1, vload_unaligned(sptr.add(lanes)));\n                }\n                vstore_unaligned(dst.add(i), s0);\n                vstore_unaligned(dst.add(i + lanes), s1);\n                i += 2 * lanes;\n            }\n            if i as isize <= width - lanes as isize {\n                let mut s0 = vload_unaligned(src[0].add(i));\n                for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {\n                    s0 = Op::apply::<S>(s0, vload_unaligned(sptr));\n                }\n                vstore_unaligned(dst.add(i), s0);\n                i += lanes;\n            }\n            if i as isize <= width - lanes as isize / 2 {\n                let mut s = vload_low(src[0].add(i));\n                for sptr in src[1..nz].iter().map(|sptr| sptr.add(i)) {\n                    s = Op::apply::<S>(s, vload_low(sptr));\n                }\n                vstore_low(dst.add(i), s);\n                i += lanes / 2;\n            }\n        }\n        i\n    }\n}\n\npub struct MorphFilter<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> {\n    pub ksize: Size,\n    pub anchor: Point,\n    coords: Vec<Point>,\n    ptrs: Vec<*const T>,\n    _op: PhantomData<(Op, VecOp)>,\n}\n\nimpl<T: Scalar, Op: MorphOperator<T>, VecOp: VecFilter<T>> MorphFilter<T, Op, VecOp> {\n    pub fn new<B: Element>(kernel: &[B], ksize: Size, anchor: Point) -> Self {\n        let coords = process_2d_kernel(kernel, ksize);\n        let ptrs = vec![null(); coords.len()];\n\n        Self {\n            ksize,\n            anchor,\n            coords,\n            ptrs,\n            _op: PhantomData,\n        }\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    pub fn apply<S: Simd>(\n        &mut self,\n\n        src: &[*const T],\n        dst: &mut [T],\n        dst_step: usize,\n        mut count: usize,\n        width: usize,\n        ch: usize,\n    ) {\n        let nz = self.coords.len();\n        let width = (width * ch) as isize;\n        let pt = &self.coords;\n        let kp = &mut self.ptrs;\n\n        let mut dst_off = 0;\n        let mut src_off = 0;\n\n        let slice = |ptr: *const T| unsafe { slice::from_raw_parts(ptr, width as usize) };\n\n        unsafe {\n            while count > 0 {\n                for k in 0..nz {\n                    kp[k] = src[src_off + pt[k].y].add(pt[k].x * ch);\n                }\n\n                let mut i = VecOp::apply::<S>(kp, nz, &mut dst[dst_off..], width as usize);\n                while i as isize <= width - 4 {\n                    let sptr = slice(kp[0].add(i));\n                    let mut s0 = sptr[0];\n                    let mut s1 = sptr[1];\n                    let mut s2 = sptr[2];\n                    let mut s3 = sptr[3];\n\n                    for sptr in kp[1..nz].iter().map(|sptr| slice(sptr.add(i))) {\n                        s0 = Op::apply(s0, sptr[0]);\n                        s1 = Op::apply(s1, sptr[1]);\n                        s2 = Op::apply(s2, sptr[2]);\n                        s3 = Op::apply(s3, sptr[3]);\n                    }\n\n                    dst[dst_off + i] = s0;\n                    dst[dst_off + i + 1] = s1;\n                    dst[dst_off + i + 2] = s2;\n                    dst[dst_off + i + 3] = s3;\n                    i += 4;\n                }\n                for i in i..width as usize {\n                    let mut s0 = *kp[0].add(i);\n                    for v in kp[1..nz].iter().map(|sptr| *sptr.add(i)) {\n                        s0 = Op::apply(s0, v);\n                    }\n                    dst[dst_off + i] = s0;\n                }\n\n                count -= 1;\n                dst_off += dst_step;\n                src_off += 1;\n            }\n        }\n    }\n}\n\nfn process_2d_kernel<B: Element>(kernel: &[B], ksize: Size) -> Vec<Point> {\n    let Size { width, height } = ksize;\n\n    let mut nz = kernel.iter().filter(|it| it.to_bool()).count();\n    if nz == 0 {\n        nz = 1;\n    }\n\n    let mut coords = vec![Point::new(0, 0); nz];\n    let mut k = 0;\n\n    for y in 0..height {\n        let krow = &kernel[y * width..];\n        for (x, _) in krow[..width].iter().enumerate().filter(|it| it.1.to_bool()) {\n            coords[k] = Point::new(x, y);\n            k += 1;\n        }\n    }\n\n    coords\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/morphology/filter_engine.rs",
    "content": "use std::{fmt::Debug, ptr::null_mut};\n\nuse burn_tensor::Shape;\nuse bytemuck::{Zeroable, cast_slice, cast_slice_mut};\nuse macerator::{Simd, VOrd, Vector};\n\nuse crate::{BorderType, Point, Size};\n\nuse super::filter::{\n    MorphColumnFilter, MorphColumnVec, MorphFilter, MorphOperator, MorphRowFilter, MorphRowVec,\n    MorphVec, VecMorphOperator,\n};\n\npub type RowFilter<T, Op> = MorphRowFilter<T, Op, MorphRowVec<T, Op>>;\npub type ColFilter<T, Op> = MorphColumnFilter<T, Op, MorphColumnVec<T, Op>>;\npub type Filter2D<T, Op> = MorphFilter<T, Op, MorphVec<T, Op>>;\n\npub enum Filter<T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {\n    Separable {\n        row_filter: RowFilter<T, Op>,\n        col_filter: ColFilter<T, Op>,\n    },\n    Fallback(Filter2D<T, Op>),\n}\n\npub struct FilterEngine<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> {\n    /// Vector aligned ring buffer to serve as intermediate, since image isn't always aligned\n    ring_buf: Vec<Vector<S, T>>,\n    /// Vector aligned row buffer to serve as intermediate, since image isn't always aligned\n    src_row: Vec<Vector<S, T>>,\n    const_border_value: Vec<T>,\n    const_border_row: Vec<Vector<S, T>>,\n    border_table: Vec<usize>,\n    /// Pointers to each row offset in the ring buffer\n    rows: Vec<*const T>,\n\n    filter: Filter<T, Op>,\n\n    ksize: Size,\n    anchor: Point,\n    dx1: usize,\n    dx2: usize,\n    row_count: usize,\n    dst_y: usize,\n    start_y: usize,\n    start_y_0: usize,\n    end_y: usize,\n\n    max_width: usize,\n    buf_step: usize,\n    width: usize,\n    height: usize,\n    border_type: BorderType,\n}\n\nimpl<S: Simd, T: VOrd, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {\n    fn resize_ring_buf(&mut self, size: usize) {\n        let actual = size.div_ceil(T::lanes::<S>());\n        self.ring_buf.resize(actual, Zeroable::zeroed());\n    }\n    fn resize_src_row(&mut self, size: usize) {\n        let actual = size.div_ceil(T::lanes::<S>());\n        self.src_row.resize(actual, Zeroable::zeroed());\n    }\n    fn is_separable(&self) -> bool {\n        matches!(self.filter, Filter::Separable { .. })\n    }\n}\n\nimpl<S: Simd, T: VOrd + Debug, Op: MorphOperator<T> + VecMorphOperator<T>> FilterEngine<S, T, Op> {\n    pub fn new(\n        filter: Filter<T, Op>,\n        border_type: BorderType,\n        border_value: &[T],\n        ch: usize,\n    ) -> Self {\n        let (ksize, anchor) = match &filter {\n            Filter::Separable {\n                row_filter,\n                col_filter,\n            } => {\n                let ksize = Size::new(row_filter.ksize, col_filter.ksize);\n                let anchor = Point::new(row_filter.anchor, col_filter.anchor);\n                (ksize, anchor)\n            }\n            Filter::Fallback(f) => (f.ksize, f.anchor),\n        };\n\n        let mut border_table = Vec::new();\n        let border_length = (ksize.width - 1).max(1);\n        let mut const_border_value = Vec::new();\n        if matches!(border_type, BorderType::Constant) {\n            const_border_value = vec![Zeroable::zeroed(); border_length * ch];\n            for elem in cast_slice_mut::<_, T>(&mut const_border_value).chunks_exact_mut(ch) {\n                elem.copy_from_slice(border_value);\n            }\n        } else {\n            border_table = vec![0; border_length * ch];\n        }\n\n        Self {\n            ring_buf: Default::default(),\n            src_row: Default::default(),\n            rows: Default::default(),\n            border_type,\n            const_border_row: Default::default(),\n            const_border_value,\n            border_table,\n            ksize,\n            anchor,\n            filter,\n            max_width: 0,\n            buf_step: 0,\n            dx1: 0,\n            dx2: 0,\n            row_count: 0,\n            dst_y: 0,\n            start_y: 0,\n            start_y_0: 0,\n            end_y: 0,\n            width: 0,\n            height: 0,\n        }\n    }\n\n    pub fn apply(&mut self, tensor: &mut [T], src_shape: Shape) {\n        let [_, w, ch] = src_shape.dims();\n        let src_step = w * ch;\n        self.start(src_shape);\n        let y = self.start_y;\n        self.proceed(\n            &mut tensor[y * src_step..],\n            src_step,\n            self.end_y - self.start_y,\n            ch,\n        );\n    }\n\n    pub fn start(&mut self, shape: Shape) -> usize {\n        let [height, width, ch] = shape.dims();\n\n        let max_buf_rows = (self.ksize.height + 3)\n            .max(self.anchor.y)\n            .max((self.ksize.height - self.anchor.y - 1) * 2 + 1);\n        let k_offs = if !self.is_separable() {\n            self.ksize.width - 1\n        } else {\n            0\n        };\n        let is_sep = self.is_separable();\n\n        if self.max_width < width || max_buf_rows != self.rows.len() {\n            self.rows.resize(max_buf_rows, null_mut());\n            self.max_width = self.max_width.max(width);\n            self.resize_src_row((self.max_width + self.ksize.width - 1) * ch);\n\n            if matches!(self.border_type, BorderType::Constant) {\n                self.const_border_row.resize(\n                    ((self.max_width + self.ksize.width - 1) * ch).div_ceil(T::lanes::<S>()),\n                    Zeroable::zeroed(),\n                );\n                let mut n = self.const_border_value.len();\n                let n1 = (self.max_width + self.ksize.width - 1) * ch;\n                let const_val = &self.const_border_value;\n                let dst = cast_slice_mut(&mut self.const_border_row);\n                let t_dst = if is_sep {\n                    cast_slice_mut::<_, T>(&mut self.src_row)\n                } else {\n                    alias_slice_mut(dst)\n                };\n\n                for i in (0..n1).step_by(n) {\n                    n = n.min(n1 - i);\n                    t_dst[i..i + n].copy_from_slice(&const_val[..n]);\n                }\n\n                if let Filter::Separable { row_filter, .. } = &self.filter {\n                    row_filter.apply::<S>(cast_slice(&self.src_row), dst, self.max_width, ch);\n                }\n            }\n\n            let max_buf_step =\n                (self.max_width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;\n\n            self.resize_ring_buf(max_buf_step * self.rows.len());\n        }\n\n        let const_val = &self.const_border_value;\n\n        self.buf_step = (width + k_offs).next_multiple_of(align_of::<Vector<S, T>>()) * ch;\n\n        self.dx1 = self.anchor.x;\n        self.dx2 = self.ksize.width - self.anchor.x - 1;\n\n        if self.dx1 > 0 || self.dx2 > 0 {\n            if matches!(self.border_type, BorderType::Constant) {\n                let nr = if self.is_separable() {\n                    1\n                } else {\n                    self.rows.len()\n                };\n                for i in 0..nr {\n                    let dst = if self.is_separable() {\n                        cast_slice_mut::<_, T>(&mut self.src_row)\n                    } else {\n                        &mut cast_slice_mut::<_, T>(&mut self.ring_buf)[self.buf_step * i..]\n                    };\n                    memcpy(dst, const_val, self.dx1 * ch);\n                    let right = (width + self.ksize.width - 1 - self.dx2) * ch;\n                    memcpy(&mut dst[right..], const_val, self.dx2 * ch);\n                }\n            } else {\n                for i in 0..self.dx1 as isize {\n                    let p0 = border_interpolate(i - self.dx1 as isize, width, self.border_type);\n                    let p0 = p0 as usize * ch;\n                    for j in 0..ch {\n                        self.border_table[i as usize * ch + j] = p0 + j;\n                    }\n                }\n                for i in 0..self.dx2 {\n                    let p0 = border_interpolate((width + i) as isize, width, self.border_type)\n                        as usize\n                        * ch;\n                    for j in 0..ch {\n                        self.border_table[(i + self.dx1) * ch + j] = p0 + j;\n                    }\n                }\n            }\n        }\n\n        self.row_count = 0;\n        self.dst_y = 0;\n        self.start_y = 0;\n        self.start_y_0 = 0;\n        self.end_y = height;\n        self.width = width;\n        self.height = height;\n\n        self.start_y\n    }\n\n    #[allow(clippy::too_many_arguments)]\n    pub fn proceed(\n        &mut self,\n\n        src: &mut [T],\n        src_step: usize,\n        mut count: usize,\n        ch: usize,\n    ) -> usize {\n        let buf_rows = self.rows.len();\n        let kheight = self.ksize.height;\n        let kwidth = self.ksize.width;\n        let ay = self.anchor.y as isize;\n        let dx1 = self.dx1;\n        let dx2 = self.dx2;\n        let width1 = self.width + kwidth - 1;\n        let btab = &self.border_table;\n        let make_border = (dx1 > 0 || dx2 > 0) && !matches!(self.border_type, BorderType::Constant);\n        let is_sep = self.is_separable();\n\n        count = count.min(self.remaining_input_rows());\n        let mut dst_off = 0;\n        let mut src_off = 0;\n        let mut dy = 0;\n        let mut i;\n        let brows = &mut self.rows;\n\n        let src_row = cast_slice_mut::<_, T>(&mut self.src_row);\n        let ring_buf = cast_slice_mut::<_, T>(&mut self.ring_buf);\n\n        loop {\n            let dcount = buf_rows as isize - ay - self.start_y as isize - self.row_count as isize;\n            let mut dcount = if dcount > 0 {\n                dcount as usize\n            } else {\n                buf_rows + 1 - kheight\n            };\n            dcount = dcount.min(count);\n            count -= dcount;\n\n            while dcount > 0 {\n                let bi = (self.start_y - self.start_y_0 + self.row_count) % buf_rows;\n                let brow = &mut ring_buf[bi * self.buf_step..];\n                let row = if is_sep {\n                    &mut src_row[..]\n                } else {\n                    alias_slice_mut(brow)\n                };\n\n                if self.row_count + 1 > buf_rows {\n                    self.row_count -= 1;\n                    self.start_y += 1;\n                }\n                self.row_count += 1;\n\n                memcpy(\n                    &mut row[dx1 * ch..],\n                    &src[src_off..],\n                    (width1 - dx2 - dx1) * ch,\n                );\n\n                if make_border {\n                    for i in 0..dx1 * ch {\n                        row[i] = src[src_off + btab[i]];\n                    }\n                    for i in 0..dx2 * ch {\n                        row[i + (width1 - dx2) * ch] = src[src_off + btab[i + dx1 * ch]];\n                    }\n                }\n\n                if let Filter::Separable { row_filter, .. } = &self.filter {\n                    row_filter.apply::<S>(row, brow, self.width, ch);\n                }\n\n                dcount -= 1;\n                src_off += src_step;\n            }\n\n            let max_i = buf_rows.min(self.height - (self.dst_y + dy) + (kheight - 1));\n            i = 0;\n            while i < max_i {\n                let src_y = border_interpolate(\n                    (self.dst_y + dy + i) as isize - ay,\n                    self.height,\n                    self.border_type,\n                );\n                if src_y < 0 {\n                    brows[i] = self.const_border_row.as_ptr() as _;\n                } else {\n                    if src_y as usize >= self.start_y + self.row_count {\n                        break;\n                    }\n                    let bi = (src_y as usize - self.start_y_0) % buf_rows;\n                    brows[i] = unsafe { ring_buf.as_ptr().add(bi * self.buf_step) };\n                }\n\n                i += 1;\n            }\n            if i < kheight {\n                break;\n            }\n            i -= kheight - 1;\n            match &mut self.filter {\n                Filter::Separable { col_filter, .. } => {\n                    col_filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width * ch)\n                }\n                Filter::Fallback(filter) => {\n                    filter.apply::<S>(brows, &mut src[dst_off..], src_step, i, self.width, ch)\n                }\n            }\n\n            dst_off += src_step * i;\n            dy += i;\n        }\n\n        self.dst_y += dy;\n        dy\n    }\n\n    fn remaining_input_rows(&self) -> usize {\n        self.end_y - self.start_y - self.row_count\n    }\n}\n\n#[track_caller]\nfn memcpy<T: Copy>(to: &mut [T], from: &[T], len: usize) {\n    to[..len].copy_from_slice(&from[..len]);\n}\n\n/// Unsafely alias slice. Needed for the conditional slice targets that depend on the filter. The\n/// same slice shouldn't be used multiple times at once\nfn alias_slice_mut<'b, T>(slice: &mut [T]) -> &'b mut [T] {\n    let ptr = slice.as_mut_ptr();\n    let len = slice.len();\n    unsafe { core::slice::from_raw_parts_mut(ptr, len) }\n}\n\nfn border_interpolate(mut p: isize, len: usize, btype: BorderType) -> isize {\n    let len = len as isize;\n    if p < len && p >= 0 {\n        return p;\n    }\n    match btype {\n        BorderType::Constant => -1,\n        BorderType::Replicate if p < 0 => 0,\n        BorderType::Replicate => len - 1,\n        BorderType::Reflect | BorderType::Reflect101 => {\n            let delta = matches!(btype, BorderType::Reflect101) as isize;\n            if len == 1 {\n                return 0;\n            }\n            loop {\n                if p < 0 {\n                    p = -p - 1 + delta;\n                } else {\n                    p = len - 1 - (p - len) - delta;\n                }\n                if p < len && p >= 0 {\n                    break;\n                }\n            }\n            p\n        }\n        BorderType::Wrap => {\n            if p < 0 {\n                p -= ((p - len + 1) / len) * len;\n            }\n            if p >= len {\n                p %= len;\n            }\n            p\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/morphology/mod.rs",
    "content": "use std::fmt::Debug;\n\nuse burn_tensor::{\n    BasicOps, Bool, BoolStore, DType, Element, Shape, Tensor, TensorData, backend::Backend,\n    cast::ToElement, ops::BoolTensor,\n};\nuse filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator};\nuse filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter};\nuse macerator::{Simd, VOrd};\n\nuse crate::{BorderType, MorphOptions, Point, Size};\n\nuse super::MinMax;\n\nmod filter;\nmod filter_engine;\n\n/// A morphology operation.\n/// TODO: Implement composite ops\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub enum MorphOp {\n    Erode,\n    Dilate,\n}\n\npub enum MorphKernel<B: Element> {\n    Rect {\n        size: Size,\n        anchor: Point,\n    },\n    Other {\n        kernel: Vec<B>,\n        size: Size,\n        anchor: Point,\n    },\n}\n\npub fn morph<B: Backend, K: BasicOps<B>>(\n    input: Tensor<B, 3, K>,\n    kernel: BoolTensor<B>,\n    op: MorphOp,\n    opts: MorphOptions<B, K>,\n) -> Tensor<B, 3, K> {\n    let device = input.device();\n\n    let kernel = Tensor::<B, 2, Bool>::new(kernel);\n    let kshape = kernel.shape().dims();\n    let [kh, kw] = kshape;\n\n    let kernel = kernel.into_data().into_vec::<B::BoolElem>().unwrap();\n    let is_rect = kernel.iter().all(|it| it.to_bool());\n    let anchor = opts.anchor.unwrap_or(Point::new(kw / 2, kh / 2));\n    let iter = opts.iterations;\n    let btype = opts.border_type;\n    let bvalue = opts.border_value.map(|it| it.into_data());\n\n    let size = Size::new(kw, kh);\n    let kernel = if is_rect {\n        MorphKernel::Rect { size, anchor }\n    } else {\n        MorphKernel::Other {\n            kernel,\n            size,\n            anchor,\n        }\n    };\n\n    let shape = input.shape();\n    let data = input.into_data();\n    match data.dtype {\n        DType::F64 => {\n            morph_typed::<B, K, f64>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::F32 | DType::Flex32 => {\n            morph_typed::<B, K, f32>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::F16 | DType::BF16 => morph_typed::<B, K, f32>(\n            data.convert::<f32>(),\n            shape,\n            kernel,\n            op,\n            iter,\n            btype,\n            bvalue,\n            &device,\n        ),\n        DType::I64 => {\n            morph_typed::<B, K, i64>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::I32 => {\n            morph_typed::<B, K, i32>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::I16 => {\n            morph_typed::<B, K, i16>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::I8 => morph_typed::<B, K, i8>(data, shape, kernel, op, iter, btype, bvalue, &device),\n        DType::U64 => {\n            morph_typed::<B, K, u64>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::U32 | DType::Bool(BoolStore::U32) => {\n            morph_typed::<B, K, u32>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::U16 => {\n            morph_typed::<B, K, u16>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::U8 | DType::Bool(BoolStore::U8) => {\n            morph_typed::<B, K, u8>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::Bool(BoolStore::Native) => {\n            morph_bool::<B, K>(data, shape, kernel, op, iter, btype, bvalue, &device)\n        }\n        DType::QFloat(_) => unimplemented!(),\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\nfn morph_typed<B: Backend, K: BasicOps<B>, T: VOrd + MinMax + Element>(\n    mut input: TensorData,\n    shape: Shape,\n    kernel: MorphKernel<B::BoolElem>,\n    op: MorphOp,\n    iter: usize,\n    btype: BorderType,\n    bvalue: Option<TensorData>,\n    device: &B::Device,\n) -> Tensor<B, 3, K> {\n    let data = input.as_mut_slice::<T>().unwrap();\n    let bvalue = border_value(btype, bvalue, op, &shape);\n    run_morph(data, shape, kernel, op, iter, btype, &bvalue);\n    Tensor::from_data(input, device)\n}\n\n#[allow(clippy::too_many_arguments)]\nfn morph_bool<B: Backend, K: BasicOps<B>>(\n    mut input: TensorData,\n    shape: Shape,\n    kernel: MorphKernel<B::BoolElem>,\n    op: MorphOp,\n    iter: usize,\n    btype: BorderType,\n    bvalue: Option<TensorData>,\n    device: &B::Device,\n) -> Tensor<B, 3, K> {\n    let data = input.as_mut_slice::<bool>().unwrap();\n    // SAFETY: Morph can't produce invalid boolean values\n    let data = unsafe { core::mem::transmute::<&mut [bool], &mut [u8]>(data) };\n    let bvalue = border_value(btype, bvalue, op, &shape);\n    run_morph(data, shape.clone(), kernel, op, iter, btype, &bvalue);\n    Tensor::from_data(input, device)\n}\n\nfn border_value<T: Element>(\n    btype: BorderType,\n    bvalue: Option<TensorData>,\n    op: MorphOp,\n    shape: &Shape,\n) -> Vec<T> {\n    let [_, _, ch] = shape.dims();\n    match (btype, bvalue) {\n        (BorderType::Constant, Some(value)) => value.convert::<T>().into_vec().unwrap(),\n        (BorderType::Constant, None) => match op {\n            MorphOp::Erode => vec![T::MAX; ch],\n            MorphOp::Dilate => vec![T::MIN; ch],\n        },\n        _ => vec![],\n    }\n}\n\nfn run_morph<T: VOrd + MinMax + Element, B: Element>(\n    input: &mut [T],\n    shape: Shape,\n    kernel: MorphKernel<B>,\n    op: MorphOp,\n    iter: usize,\n    btype: BorderType,\n    bvalue: &[T],\n) {\n    match op {\n        MorphOp::Erode => {\n            let filter = filter::<T, MinOp, B>(kernel);\n            dispatch_morph(input, shape, filter, btype, bvalue, iter);\n        }\n        MorphOp::Dilate => {\n            let filter = filter::<T, MaxOp, B>(kernel);\n            dispatch_morph(input, shape, filter, btype, bvalue, iter);\n        }\n    };\n}\n\nfn filter<T: VOrd + MinMax, Op: MorphOperator<T> + VecMorphOperator<T>, B: Element>(\n    kernel: MorphKernel<B>,\n) -> Filter<T, Op> {\n    match kernel {\n        MorphKernel::Rect { size, anchor } => {\n            let row_filter = RowFilter::new(size.width, anchor.x);\n            let col_filter = ColFilter::new(size.height, anchor.y);\n            Filter::Separable {\n                row_filter,\n                col_filter,\n            }\n        }\n        MorphKernel::Other {\n            kernel,\n            size,\n            anchor,\n        } => {\n            let filter = Filter2D::new(&kernel, size, anchor);\n            Filter::Fallback(filter)\n        }\n    }\n}\n\n#[inline(always)]\n#[allow(clippy::too_many_arguments)]\n#[macerator::with_simd]\nfn dispatch_morph<\n    'a,\n    S: Simd,\n    T: VOrd + MinMax + Debug,\n    Op: MorphOperator<T> + VecMorphOperator<T>,\n>(\n    buffer: &'a mut [T],\n    buffer_shape: Shape,\n    filter: filter_engine::Filter<T, Op>,\n    border_type: BorderType,\n    border_value: &'a [T],\n    iterations: usize,\n) where\n    'a: 'a,\n{\n    let [_, _, ch] = buffer_shape.dims();\n    let mut engine = FilterEngine::<S, _, _>::new(filter, border_type, border_value, ch);\n    engine.apply(buffer, buffer_shape.clone());\n    for _ in 1..iterations {\n        engine.apply(buffer, buffer_shape.clone());\n    }\n}\n\n/// Shape of the structuring element\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub enum KernelShape {\n    /// Rectangular kernel\n    Rect,\n    /// Cross shaped kernel\n    Cross,\n    /// Ellipse shaped kernel\n    Ellipse,\n}\n\n/// Create a structuring element tensor for use with morphology ops\npub fn create_structuring_element<B: Backend>(\n    shape: KernelShape,\n    ksize: Size,\n    anchor: Option<Point>,\n    device: &B::Device,\n) -> Tensor<B, 2, Bool> {\n    fn create_kernel(shape: KernelShape, ksize: Size, anchor: Option<Point>) -> Vec<bool> {\n        let anchor = anchor.unwrap_or(Point::new(ksize.width / 2, ksize.height / 2));\n        let mut r = 0;\n        let mut c = 0;\n        let mut inv_r2 = 0.0;\n\n        if (ksize.width == 1 && ksize.height == 1) || shape == KernelShape::Rect {\n            return vec![true; ksize.height * ksize.width];\n        }\n\n        if shape == KernelShape::Ellipse {\n            r = ksize.height / 2;\n            c = ksize.width / 2;\n            inv_r2 = if r > 0 { 1.0 / (r * r) as f64 } else { 0.0 }\n        }\n\n        let mut elem = vec![false; ksize.height * ksize.width];\n\n        for i in 0..ksize.height {\n            let mut j1 = 0;\n            let mut j2 = 0;\n            if shape == KernelShape::Cross && i == anchor.y {\n                j2 = ksize.width;\n            } else if shape == KernelShape::Cross {\n                j1 = anchor.x;\n                j2 = j1 + 1;\n            } else {\n                let dy = i as isize - r as isize;\n                if dy.abs() <= r as isize {\n                    let dx = (c as f64 * ((r * r - (dy * dy) as usize) as f64 * inv_r2).sqrt())\n                        .round() as isize;\n                    j1 = (c as isize - dx).max(0) as usize;\n                    j2 = (c + dx as usize + 1).min(ksize.width);\n                }\n            }\n\n            for j in j1..j2 {\n                elem[i * ksize.width + j] = true;\n            }\n        }\n        elem\n    }\n\n    let elem = create_kernel(shape, ksize, anchor);\n\n    let data = TensorData::new(elem, [ksize.height, ksize.width]);\n    Tensor::from_data(data, device)\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/nms.rs",
    "content": "use crate::NmsOptions;\nuse aligned_vec::{AVec, ConstAlign};\nuse alloc::vec::Vec;\nuse burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};\nuse macerator::{Scalar, Simd, Vector, vload};\n\n/// Perform NMS on CPU using SIMD acceleration.\n///\n/// This implementation:\n/// 1. Sorts boxes by score (descending)\n/// 2. Iteratively selects the highest-scoring non-suppressed box\n/// 3. Suppresses all boxes with IoU > threshold using SIMD\npub fn nms<B: Backend>(\n    boxes: Tensor<B, 2>,\n    scores: Tensor<B, 1>,\n    options: NmsOptions,\n) -> Tensor<B, 1, Int> {\n    let device = boxes.device();\n    let [n_boxes, _] = boxes.shape().dims();\n    if n_boxes == 0 {\n        return Tensor::<B, 1, Int>::empty([0], &device);\n    }\n\n    // Get raw data\n    let boxes_data = boxes.to_data();\n    let boxes_vec: Vec<f32> = boxes_data.to_vec().unwrap();\n\n    let scores_data = scores.to_data();\n    let scores_vec: Vec<f32> = scores_data.to_vec().unwrap();\n\n    let keep = nms_vec(boxes_vec, scores_vec, options);\n    let n_kept = keep.len();\n    let indices_data = TensorData::new(keep, Shape::new([n_kept]));\n    Tensor::<B, 1, Int>::from_data(indices_data, &device)\n}\n\n/// Perform NMS on CPU using SIMD acceleration.\nfn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {\n    let n_boxes = scores_vec.len();\n\n    if n_boxes == 0 {\n        return vec![];\n    }\n\n    // Filter by score threshold first\n    let mut filtered_indices = Vec::with_capacity(n_boxes);\n    for (i, &score) in scores_vec.iter().enumerate() {\n        if score >= options.score_threshold {\n            filtered_indices.push(i); // original index\n        }\n    }\n\n    let n_filtered = filtered_indices.len();\n    if n_filtered == 0 {\n        return vec![];\n    }\n\n    // Sort by score descending\n    filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));\n\n    const ALIGN: usize = 64;\n    const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>(); // 16\n    let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;\n    let mut buf: AVec<f32, ConstAlign<64>> = AVec::with_capacity(ALIGN, stride * 5);\n    buf.resize(stride * 5, 0.0);\n\n    let (x1s, rest) = buf.split_at_mut(stride);\n    let (y1s, rest) = rest.split_at_mut(stride);\n    let (x2s, rest) = rest.split_at_mut(stride);\n    let (y2s, areas) = rest.split_at_mut(stride);\n\n    // Convert filtered boxes to SoA format\n    for (j, &orig_idx) in filtered_indices.iter().enumerate() {\n        let x1 = boxes_vec[orig_idx * 4];\n        let y1 = boxes_vec[orig_idx * 4 + 1];\n        let x2 = boxes_vec[orig_idx * 4 + 2];\n        let y2 = boxes_vec[orig_idx * 4 + 3];\n        x1s[j] = x1;\n        y1s[j] = y1;\n        x2s[j] = x2;\n        y2s[j] = y2;\n        areas[j] = (x2 - x1) * (y2 - y1);\n    }\n\n    // Apply NMS with SIMD dispatch\n    let mut suppressed = vec![false; stride];\n    let mut keep = Vec::new();\n\n    for i in 0..n_filtered {\n        if suppressed[i] {\n            continue;\n        }\n\n        // Optimization to reduce inner loop comparisons\n        suppressed[i] = true;\n        keep.push(filtered_indices[i] as i32); // original index\n\n        if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {\n            break;\n        }\n\n        // Suppress overlapping boxes using SIMD\n        suppress_overlapping(\n            x1s[i],\n            y1s[i],\n            x2s[i],\n            y2s[i],\n            areas[i],\n            x1s,\n            y1s,\n            x2s,\n            y2s,\n            areas,\n            &mut suppressed,\n            stride,\n            options.iou_threshold,\n        );\n    }\n\n    keep\n}\n\n/// SIMD-accelerated suppression of overlapping boxes.\n#[allow(clippy::too_many_arguments)]\n#[inline(always)]\n#[macerator::with_simd]\nfn suppress_overlapping<'a, S: Simd>(\n    ref_x1: f32,\n    ref_y1: f32,\n    ref_x2: f32,\n    ref_y2: f32,\n    ref_area: f32,\n    x1s: &'a [f32],\n    y1s: &'a [f32],\n    x2s: &'a [f32],\n    y2s: &'a [f32],\n    areas: &'a [f32],\n    suppressed: &'a mut [bool],\n    n_boxes: usize, // stride, always multiple of lanes\n    threshold: f32,\n) where\n    'a: 'a,\n{\n    let lanes = f32::lanes::<S>();\n\n    // Splat reference values\n    let ref_x1_v: Vector<S, f32> = ref_x1.splat();\n    let ref_y1_v: Vector<S, f32> = ref_y1.splat();\n    let ref_x2_v: Vector<S, f32> = ref_x2.splat();\n    let ref_y2_v: Vector<S, f32> = ref_y2.splat();\n    let ref_area_v: Vector<S, f32> = ref_area.splat();\n    let thresh_v: Vector<S, f32> = threshold.splat();\n    let zero_v: Vector<S, f32> = 0.0f32.splat();\n\n    let mut i = 0;\n\n    let mut mask_buf = core::mem::MaybeUninit::<[bool; 16]>::uninit();\n    // Process lanes boxes at a time with SIMD\n    while i + lanes <= n_boxes {\n        // Skip if all boxes in this chunk are already suppressed\n        let all_suppressed = unsafe {\n            match lanes {\n                4 => *(suppressed.as_ptr().add(i) as *const u32) == 0x01010101,\n                8 => *(suppressed.as_ptr().add(i) as *const u64) == 0x0101010101010101,\n                16 => {\n                    *(suppressed.as_ptr().add(i) as *const u128)\n                        == 0x01010101010101010101010101010101\n                }\n                _ => unreachable!(),\n            }\n        };\n\n        if !all_suppressed {\n            let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(i)) };\n            let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(i)) };\n            let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(i)) };\n            let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(i)) };\n            let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(i)) };\n\n            // Compute intersection coordinates\n            let xx1 = ref_x1_v.max(x1_v);\n            let yy1 = ref_y1_v.max(y1_v);\n            let xx2 = ref_x2_v.min(x2_v);\n            let yy2 = ref_y2_v.min(y2_v);\n\n            // Compute intersection area (clamp to 0 for non-overlapping)\n            let w = (xx2 - xx1).max(zero_v);\n            let h = (yy2 - yy1).max(zero_v);\n            let inter = w * h;\n\n            // Compute IoU\n            let union = ref_area_v + area_v - inter;\n            let iou = inter / union;\n\n            // Get suppression mask (IoU > threshold)\n            let suppress_mask = iou.gt(thresh_v);\n\n            // Extract mask to bool array and apply to suppressed\n            // SAFETY: mask_store_as_bool writes exactly `lanes` bools, we only read 0..lanes\n            unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr().cast(), suppress_mask) };\n            let mask_buf = unsafe { mask_buf.assume_init() };\n\n            for k in 0..lanes {\n                if mask_buf[k] {\n                    suppressed[i + k] = true;\n                }\n            }\n        }\n\n        i += lanes;\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cpu/ops.rs",
    "content": "#[cfg(feature = \"ndarray\")]\nmod ndarray {\n    use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};\n    use burn_ndarray::{\n        FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensor, QuantElement, SharedArray,\n    };\n\n    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BoolVisionOps\n        for NdArray<E, I, Q>\n    where\n        NdArrayTensor: From<SharedArray<E>>,\n        NdArrayTensor: From<SharedArray<I>>,\n    {\n    }\n    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntVisionOps\n        for NdArray<E, I, Q>\n    where\n        NdArrayTensor: From<SharedArray<E>>,\n        NdArrayTensor: From<SharedArray<I>>,\n    {\n    }\n    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatVisionOps\n        for NdArray<E, I, Q>\n    where\n        NdArrayTensor: From<SharedArray<E>>,\n        NdArrayTensor: From<SharedArray<I>>,\n    {\n    }\n    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> QVisionOps for NdArray<E, I, Q>\n    where\n        NdArrayTensor: From<SharedArray<E>>,\n        NdArrayTensor: From<SharedArray<I>>,\n    {\n    }\n    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> VisionBackend\n        for NdArray<E, I, Q>\n    where\n        NdArrayTensor: From<SharedArray<E>>,\n        NdArrayTensor: From<SharedArray<I>>,\n    {\n    }\n}\n\n#[cfg(feature = \"tch\")]\nmod tch {\n    use crate::{BoolVisionOps, FloatVisionOps, IntVisionOps, QVisionOps, VisionBackend};\n    use burn_tch::{LibTorch, TchElement};\n\n    impl<E: TchElement, Q: burn_tch::QuantElement> BoolVisionOps for LibTorch<E, Q> {}\n    impl<E: TchElement, Q: burn_tch::QuantElement> IntVisionOps for LibTorch<E, Q> {}\n    impl<E: TchElement, Q: burn_tch::QuantElement> FloatVisionOps for LibTorch<E, Q> {}\n    impl<E: TchElement, Q: burn_tch::QuantElement> QVisionOps for LibTorch<E, Q> {}\n    impl<E: TchElement, Q: burn_tch::QuantElement> VisionBackend for LibTorch<E, Q> {}\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs",
    "content": "//! Hardware Accelerated 4-connected, adapted from\n//! A. Hennequin, L. Lacassagne, L. Cabaret, Q. Meunier,\n//! \"A new Direct Connected Component Labeling and Analysis Algorithms for GPUs\",\n//! DASIP, 2018\n\nuse crate::{\n    ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity,\n    backends::cube::connected_components::stats_from_opts,\n};\nuse burn_cubecl::{\n    BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement, kernel,\n    ops::{into_data_sync, numeric::zeros_client},\n    tensor::CubeTensor,\n};\nuse burn_tensor::{Shape, TensorMetadata, cast::ToElement, ops::IntTensorOps};\nuse cubecl::{features::Plane, prelude::*};\n\nuse super::prefix_sum::prefix_sum;\n\nconst BLOCK_H: usize = 4;\n\n#[cube]\nfn merge<I: Int>(labels: &Tensor<Atomic<I>>, label_1: u32, label_2: u32) {\n    let mut label_1 = label_1 as usize;\n    let mut label_2 = label_2 as usize;\n\n    while label_1 != label_2 && (label_1 != usize::cast_from(labels[label_1].load()) - 1) {\n        label_1 = usize::cast_from(labels[label_1].load()) - 1;\n    }\n    while label_1 != label_2 && (label_2 != usize::cast_from(labels[label_2].load()) - 1) {\n        label_2 = usize::cast_from(labels[label_2].load()) - 1;\n    }\n    while label_1 != label_2 {\n        #[allow(clippy::manual_swap)]\n        if label_1 < label_2 {\n            let tmp = label_1;\n            label_1 = label_2;\n            label_2 = tmp;\n        }\n        let label_3 = usize::cast_from(labels[label_1].fetch_min(I::cast_from(label_2 + 1))) - 1;\n        if label_1 == label_3 {\n            label_1 = label_2;\n        } else {\n            label_1 = label_3;\n        }\n    }\n}\n\n#[cube]\nfn start_distance(pixels: u32, tx: u32) -> u32 {\n    (!(pixels << (32 - tx))).leading_zeros()\n}\n\n#[cube]\nfn end_distance(pixels: u32, tx: u32) -> u32 {\n    (!(pixels >> (tx + 1))).find_first_set()\n}\n\n#[cube]\n#[allow(unconditional_panic, reason = \"clippy thinks PLANE_DIM is always 2\")]\nfn ballot_dyn(y: u32, pred: bool) -> u32 {\n    let index = y % (PLANE_DIM / 32);\n    plane_ballot(pred)[index as usize]\n}\n\n#[cube(launch_unchecked)]\nfn strip_labeling<I: Int, BT: CubePrimitive>(\n    img: &Tensor<BT>,\n    labels: &Tensor<Atomic<I>>,\n    #[comptime] connectivity: Connectivity,\n) {\n    let mut shared_pixels = SharedMemory::<u32>::new(BLOCK_H);\n\n    let y = ABSOLUTE_POS_Y;\n    let rows = labels.shape(0) as u32;\n    let cols = labels.shape(1) as u32;\n\n    if y >= rows {\n        terminate!();\n    }\n\n    let img_stride = img.stride(0) as u32;\n    let labels_stride = labels.stride(0) as u32;\n\n    let img_line_base = y * img_stride + UNIT_POS_X;\n    let labels_line_base = y * labels_stride + UNIT_POS_X;\n\n    let mut distance_y = 0u32;\n    let mut distance_y_1 = 0;\n\n    for i in range_stepped(0, img.shape(1) as u32, PLANE_DIM) {\n        let x = UNIT_POS_X + i;\n\n        if x < cols {\n            let mut mask = 0xffffffffu32;\n            let involved_cols = cols - i;\n            if involved_cols < 32 {\n                mask >>= 32 - involved_cols;\n            }\n\n            let img_index = img_line_base + i;\n            let labels_index = labels_line_base + i;\n\n            let p_y = bool::cast_from(img[img_index as usize]);\n\n            let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask;\n            let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X);\n\n            if p_y && s_dist_y == 0 {\n                labels[labels_index as usize].store(I::cast_from(\n                    labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1,\n                ));\n            }\n\n            // Only needed pre-Volta, but we can't check that at present\n            sync_cube();\n\n            if UNIT_POS_X == 0 {\n                shared_pixels[UNIT_POS_Y as usize] = pixels_y;\n            }\n\n            sync_cube();\n\n            // Requires if and not select, because `select` may execute the then branch even if the\n            // condition is false (on non-CUDA backends), which can lead to OOB reads.\n            let pixels_y_1 = if UNIT_POS_Y > 0 {\n                shared_pixels[(UNIT_POS_Y - 1) as usize]\n            } else {\n                0u32.runtime()\n            };\n\n            let p_y_1 = (pixels_y_1 >> UNIT_POS_X) & 1 != 0;\n            let mut s_dist_y_1 = start_distance(pixels_y_1, UNIT_POS_X);\n\n            if UNIT_POS_X == 0 {\n                s_dist_y = distance_y;\n                s_dist_y_1 = distance_y_1;\n            }\n\n            match connectivity {\n                Connectivity::Four => {\n                    if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {\n                        let label_1 = labels_index - s_dist_y;\n                        let label_2 = labels_index - s_dist_y_1 - labels_stride;\n                        merge(labels, label_1, label_2);\n                    }\n                }\n                Connectivity::Eight => {\n                    let pixels_y_shifted = (pixels_y << 1) | (distance_y > 0) as u32;\n                    let pixels_y_1_shifted = (pixels_y_1 << 1) | (distance_y_1 > 0) as u32;\n\n                    if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {\n                        let label_1 = labels_index - s_dist_y;\n                        let label_2 = labels_index - s_dist_y_1 - labels_stride;\n                        merge(labels, label_1, label_2);\n                    } else if p_y && s_dist_y == 0 && (pixels_y_1_shifted >> UNIT_POS_X) & 1 != 0 {\n                        let s_dist_y_1_prev = select(\n                            UNIT_POS_X == 0,\n                            distance_y_1 - 1,\n                            start_distance(pixels_y_1, UNIT_POS_X - 1),\n                        );\n                        let label_1 = labels_index;\n                        let label_2 = labels_index - labels_stride - 1 - s_dist_y_1_prev;\n                        merge(labels, label_1, label_2);\n                    } else if p_y_1 && s_dist_y_1 == 0 && (pixels_y_shifted >> UNIT_POS_X) & 1 != 0\n                    {\n                        let s_dist_y_prev = select(\n                            UNIT_POS_X == 0,\n                            distance_y - 1,\n                            start_distance(pixels_y, UNIT_POS_X - 1),\n                        );\n                        let label_1 = labels_index - 1 - s_dist_y_prev;\n                        let label_2 = labels_index - labels_stride;\n                        merge(labels, label_1, label_2);\n                    }\n                }\n            }\n\n            if p_y && p_y_1 && (s_dist_y == 0 || s_dist_y_1 == 0) {\n                let label_1 = labels_index - s_dist_y;\n                let label_2 = labels_index - s_dist_y_1 - labels_stride;\n                merge(labels, label_1, label_2);\n            }\n\n            let mut d = start_distance(pixels_y_1, 32);\n            distance_y_1 = d + select(d == 32, distance_y_1, 0);\n            d = start_distance(pixels_y, 32);\n            distance_y = d + select(d == 32, distance_y, 0);\n        }\n    }\n}\n\n#[cube(launch_unchecked)]\nfn strip_merge<I: Int, BT: CubePrimitive>(\n    img: &Tensor<BT>,\n    labels: &Tensor<Atomic<I>>,\n    #[comptime] connectivity: Connectivity,\n) {\n    let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;\n    let y = (CUBE_POS_Y + 1) * BLOCK_H as u32;\n    let x = plane_start_x + UNIT_POS_X;\n\n    let img_step = img.stride(0) as u32;\n    let labels_step = labels.stride(0) as u32;\n    let cols = img.shape(1) as u32;\n\n    if y < labels.shape(0) as u32 && x < labels.shape(1) as u32 {\n        let mut mask = 0xffffffffu32;\n        if cols - plane_start_x < 32 {\n            mask >>= 32 - (cols - plane_start_x);\n        }\n\n        let img_index = y * img_step + x;\n        let labels_index = y * labels_step + x;\n\n        let img_index_up = img_index - img_step;\n        let labels_index_up = labels_index - labels_step;\n\n        let p = bool::cast_from(img[img_index as usize]);\n        let p_up = bool::cast_from(img[img_index_up as usize]);\n\n        let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;\n        let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;\n\n        match connectivity {\n            Connectivity::Four => {\n                if p && p_up {\n                    let s_dist = start_distance(pixels, UNIT_POS_X);\n                    let s_dist_up = start_distance(pixels_up, UNIT_POS_X);\n                    if s_dist == 0 || s_dist_up == 0 {\n                        merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);\n                    }\n                }\n            }\n            Connectivity::Eight => {\n                let mut last_dist_vec = SharedMemory::<u32>::new(32usize);\n                let mut last_dist_up_vec = SharedMemory::<u32>::new(32usize);\n\n                let s_dist = start_distance(pixels, UNIT_POS_X);\n                let s_dist_up = start_distance(pixels_up, UNIT_POS_X);\n\n                if UNIT_POS_PLANE == PLANE_DIM - 1 {\n                    last_dist_vec[UNIT_POS_Z as usize] = start_distance(pixels, 32);\n                    last_dist_up_vec[UNIT_POS_Z as usize] = start_distance(pixels_up, 32);\n                }\n\n                sync_cube();\n\n                if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {\n                    let last_dist = if UNIT_POS_Z > 0 {\n                        last_dist_vec[(UNIT_POS_Z - 1) as usize]\n                    } else {\n                        0u32.runtime()\n                    };\n                    let last_dist_up = if UNIT_POS_Z > 0 {\n                        last_dist_up_vec[(UNIT_POS_Z - 1) as usize]\n                    } else {\n                        0u32.runtime()\n                    };\n\n                    let p_prev =\n                        select(UNIT_POS_X > 0, (pixels >> (UNIT_POS_X - 1)) & 1, last_dist) != 0;\n                    let p_up_prev = select(\n                        UNIT_POS_X > 0,\n                        (pixels_up >> (UNIT_POS_X - 1)) & 1,\n                        last_dist_up,\n                    ) != 0;\n\n                    if p && p_up {\n                        let s_dist = start_distance(pixels, UNIT_POS_X);\n                        let s_dist_up = start_distance(pixels_up, UNIT_POS_X);\n                        if s_dist == 0 || s_dist_up == 0 {\n                            merge(labels, labels_index - s_dist, labels_index_up - s_dist_up);\n                        }\n                    } else if p && p_up_prev && s_dist == 0 {\n                        let s_dist_up_prev = select(\n                            UNIT_POS_X == 0,\n                            last_dist_up - 1,\n                            start_distance(pixels_up, UNIT_POS_X - 1),\n                        );\n                        merge(labels, labels_index, labels_index_up - 1 - s_dist_up_prev);\n                    } else if p_prev && p_up && s_dist_up == 0 {\n                        let s_dist_prev = select(\n                            UNIT_POS_X == 0,\n                            last_dist - 1,\n                            start_distance(pixels, UNIT_POS_X - 1),\n                        );\n                        merge(labels, labels_index - 1 - s_dist_prev, labels_index_up);\n                    }\n                }\n            }\n        }\n    }\n}\n\n#[cube(launch_unchecked)]\nfn relabeling<I: Int, BT: CubePrimitive>(img: &Tensor<BT>, labels: &mut Tensor<I>) {\n    let plane_start_x = CUBE_POS_X * CUBE_DIM_X;\n    let y = ABSOLUTE_POS_Y;\n    let x = plane_start_x + UNIT_POS_X;\n\n    let cols = labels.shape(1) as u32;\n    let rows = labels.shape(0) as u32;\n    let img_step = img.stride(0) as u32;\n    let labels_step = labels.stride(0) as u32;\n\n    if x < cols && y < rows {\n        let mut mask = 0xffffffffu32;\n        if cols - plane_start_x < 32 {\n            mask >>= 32 - (cols - plane_start_x);\n        }\n\n        let img_index = y * img_step + x;\n        let labels_index = y * labels_step + x;\n\n        let p = bool::cast_from(img[img_index as usize]);\n        let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;\n        let s_dist = start_distance(pixels, UNIT_POS_X);\n        let mut label = 0u32;\n\n        if p && s_dist == 0 {\n            label = u32::cast_from(labels[labels_index as usize]) - 1;\n            while label != u32::cast_from(labels[label as usize]) - 1 {\n                label = u32::cast_from(labels[label as usize]) - 1;\n            }\n        }\n\n        label = plane_shuffle(label, UNIT_POS_X - s_dist);\n\n        if p {\n            labels[labels_index as usize] = I::cast_from(label + 1);\n        }\n    }\n}\n\n#[cube(launch_unchecked)]\nfn analysis<I: Int, BT: CubePrimitive>(\n    img: &Tensor<BT>,\n    labels: &mut Tensor<I>,\n    area: &mut Tensor<Atomic<I>>,\n    top: &mut Tensor<Atomic<I>>,\n    left: &mut Tensor<Atomic<I>>,\n    right: &mut Tensor<Atomic<I>>,\n    bottom: &mut Tensor<Atomic<I>>,\n    max_label: &mut Tensor<Atomic<I>>,\n    #[comptime] opts: ConnectedStatsOptions,\n) {\n    let y = ABSOLUTE_POS_Y;\n    let x = ABSOLUTE_POS_X;\n\n    let cols = labels.shape(1) as u32;\n    let rows = labels.shape(0) as u32;\n    let img_step = img.stride(0) as u32;\n    let labels_step = labels.stride(0) as u32;\n\n    if x < cols && y < rows {\n        let mut mask = 0xffffffffu32;\n        if cols - CUBE_POS_X * CUBE_DIM_X < 32 {\n            mask >>= 32 - (cols - CUBE_POS_X * CUBE_DIM_X);\n        }\n\n        let img_index = y * img_step + x;\n        let labels_index = y * labels_step + x;\n\n        let p = bool::cast_from(img[img_index as usize]);\n        let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;\n        let s_dist = start_distance(pixels, UNIT_POS_X);\n        let count = end_distance(pixels, UNIT_POS_X);\n        let max_x = x + count - 1;\n\n        let mut label = 0u32;\n\n        if p && s_dist == 0 {\n            label = u32::cast_from(labels[labels_index as usize]) - 1;\n            while label != u32::cast_from(labels[label as usize]) - 1 {\n                label = u32::cast_from(labels[label as usize]) - 1;\n            }\n            label += 1;\n\n            area[label as usize].fetch_add(I::cast_from(count));\n\n            if opts.bounds_enabled {\n                left[label as usize].fetch_min(I::cast_from(x));\n                top[label as usize].fetch_min(I::cast_from(y));\n                right[label as usize].fetch_max(I::cast_from(max_x));\n                bottom[label as usize].fetch_max(I::cast_from(y));\n            }\n            if comptime!(opts.max_label_enabled || opts.compact_labels) {\n                max_label[0].fetch_max(I::cast_from(label));\n            }\n        }\n\n        label = plane_shuffle(label, UNIT_POS_X - s_dist);\n\n        if p {\n            labels[labels_index as usize] = I::cast_from(label);\n        }\n    }\n}\n\n#[cube(launch_unchecked)]\nfn compact_labels<I: Int>(\n    labels: &mut Tensor<I>,\n    remap: &Tensor<I>,\n    max_label: &Tensor<Atomic<I>>,\n) {\n    let x = ABSOLUTE_POS_X;\n    let y = ABSOLUTE_POS_Y;\n\n    let labels_pos = y * labels.stride(0) as u32 + x;\n\n    if labels_pos as usize >= labels.len() {\n        terminate!();\n    }\n\n    let label = u32::cast_from(labels[labels_pos as usize]);\n    if label != 0 {\n        let new_label = remap[label as usize];\n        labels[labels_pos as usize] = new_label;\n        max_label[0].fetch_max(new_label);\n    }\n}\n\n#[cube(launch_unchecked)]\nfn compact_stats<I: Int>(\n    area: &Tensor<I>,\n    area_new: &mut Tensor<I>,\n    top: &Tensor<I>,\n    top_new: &mut Tensor<I>,\n    left: &Tensor<I>,\n    left_new: &mut Tensor<I>,\n    right: &Tensor<I>,\n    right_new: &mut Tensor<I>,\n    bottom: &Tensor<I>,\n    bottom_new: &mut Tensor<I>,\n    remap: &Tensor<I>,\n) {\n    let label = ABSOLUTE_POS_X;\n    if label as usize >= remap.len() {\n        terminate!();\n    }\n\n    let area = area[label as usize];\n    if area == I::new(0) {\n        terminate!();\n    }\n    let new_label = u32::cast_from(remap[label as usize]);\n\n    area_new[new_label as usize] = area;\n    // This should be gated but there's a problem with the Eq bound only being implemented for tuples\n    // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.\n    top_new[new_label as usize] = top[label as usize];\n    left_new[new_label as usize] = left[label as usize];\n    right_new[new_label as usize] = right[label as usize];\n    bottom_new[new_label as usize] = bottom[label as usize];\n}\n\n#[allow(clippy::type_complexity)]\npub fn hardware_accelerated<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement>(\n    img: CubeTensor<R>,\n    stats_opt: ConnectedStatsOptions,\n    connectivity: Connectivity,\n) -> Result<\n    (\n        CubeTensor<R>,\n        ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>,\n    ),\n    String,\n> {\n    let client = img.client.clone();\n    let device = img.device.clone();\n\n    if !client.properties().features.plane.contains(Plane::Ops) {\n        return Err(\"Requires plane instructions\".into());\n    }\n\n    let props = &client.properties().hardware;\n\n    if props.plane_size_min == 32 && props.plane_size_max == 32 {\n        return Err(\"Requires plane size of at least 32\".into());\n    }\n\n    // Somehow the kernel doesn't work on AMD and Apple Silicon.\n    //\n    // The check invalidates those, but probably not for the right reason.\n    if props.plane_size_max != 32 {\n        return Err(\"Requires plane size of at least 32\".into());\n    }\n\n    let [rows, cols] = img.meta.shape().dims();\n\n    let labels = zeros_client::<R>(client.clone(), device.clone(), img.shape(), I::dtype());\n\n    // Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32.\n    // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp\n    // size at compile time. `REQUIRE_FULL_SUBGROUPS` or subgroup size controls are not supported\n    // in wgpu.\n    let warp_size = 32;\n    let cube_dim = CubeDim::new_2d(warp_size, BLOCK_H as u32);\n    let cube_count = CubeCount::new_2d(1, (rows as u32).div_ceil(cube_dim.y));\n\n    unsafe {\n        strip_labeling::launch_unchecked::<I, BT, R>(\n            &client,\n            cube_count,\n            cube_dim,\n            img.clone().into_tensor_arg(),\n            labels.clone().into_tensor_arg(),\n            connectivity,\n        )\n    };\n\n    let horizontal_warps = Ord::min((cols as u32).div_ceil(warp_size), 32);\n    let cube_dim_merge = CubeDim::new_3d(warp_size, 1, horizontal_warps);\n    let cube_count = CubeCount::new_2d(\n        Ord::max((cols as u32 + warp_size * 30 - 1) / (warp_size * 31), 1),\n        (rows as u32 - 1) / BLOCK_H as u32,\n    );\n\n    unsafe {\n        strip_merge::launch_unchecked::<I, BT, R>(\n            &client,\n            cube_count,\n            cube_dim_merge,\n            img.clone().into_tensor_arg(),\n            labels.clone().into_tensor_arg(),\n            connectivity,\n        )\n    };\n\n    let cube_count = CubeCount::new_2d(\n        (cols as u32).div_ceil(cube_dim.x),\n        (rows as u32).div_ceil(cube_dim.y),\n    );\n\n    let mut stats = stats_from_opts(labels.clone(), stats_opt);\n\n    if stats_opt == ConnectedStatsOptions::none() {\n        unsafe {\n            relabeling::launch_unchecked::<I, BT, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                img.into_tensor_arg(),\n                labels.clone().into_tensor_arg(),\n            )\n        };\n    } else {\n        unsafe {\n            analysis::launch_unchecked::<I, BT, R>(\n                &client,\n                cube_count,\n                cube_dim,\n                img.clone().into_tensor_arg(),\n                labels.clone().into_tensor_arg(),\n                stats.area.clone().into_tensor_arg(),\n                stats.top.clone().into_tensor_arg(),\n                stats.left.clone().into_tensor_arg(),\n                stats.right.clone().into_tensor_arg(),\n                stats.bottom.clone().into_tensor_arg(),\n                stats.max_label.clone().into_tensor_arg(),\n                stats_opt,\n            )\n        };\n        if stats_opt.compact_labels {\n            let max_label = CubeBackend::<R, F, I, BT>::int_max(stats.max_label);\n            let max_label = into_data_sync::<R>(max_label);\n            let max_label = ToElement::to_usize(&max_label.as_slice::<I>().unwrap()[0]);\n            let sliced = kernel::slice::<R>(\n                stats.area.clone(),\n                #[allow(clippy::single_range_in_vec_init)]\n                &[0..(max_label + 1).next_multiple_of(4)],\n            );\n            let relabel = prefix_sum::<R, I>(sliced);\n\n            let cube_dim = CubeDim::new_2d(32, 8);\n            let cube_count = CubeCount::new_2d(\n                (cols as u32).div_ceil(cube_dim.x),\n                (rows as u32).div_ceil(cube_dim.y),\n            );\n            stats.max_label =\n                zeros_client::<R>(client.clone(), device.clone(), Shape::new([1]), I::dtype());\n            unsafe {\n                compact_labels::launch_unchecked::<I, R>(\n                    &client,\n                    cube_count,\n                    cube_dim,\n                    labels.clone().into_tensor_arg(),\n                    relabel.clone().into_tensor_arg(),\n                    stats.max_label.clone().into_tensor_arg(),\n                )\n            };\n\n            let cube_dim = CubeDim::new_1d(256);\n            let cube_count = CubeCount::new_1d((rows * cols).div_ceil(256) as u32);\n            unsafe {\n                compact_stats::launch_unchecked::<I, R>(\n                    &client,\n                    cube_count,\n                    cube_dim,\n                    stats.area.copy().into_tensor_arg(),\n                    stats.area.clone().into_tensor_arg(),\n                    stats.top.copy().into_tensor_arg(),\n                    stats.top.clone().into_tensor_arg(),\n                    stats.left.copy().into_tensor_arg(),\n                    stats.left.clone().into_tensor_arg(),\n                    stats.right.copy().into_tensor_arg(),\n                    stats.right.clone().into_tensor_arg(),\n                    stats.bottom.copy().into_tensor_arg(),\n                    stats.bottom.clone().into_tensor_arg(),\n                    relabel.into_tensor_arg(),\n                )\n            };\n        }\n    }\n\n    Ok((labels, stats))\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cube/connected_components/mod.rs",
    "content": "mod hardware_accelerated;\n\n/// Should eventually make this a full op, but the kernel is too specialized on ints and plane ops\n/// to really use it in a general case. Needs more work to use as a normal tensor method.\nmod prefix_sum;\n\nuse burn_cubecl::{\n    BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement,\n    ops::numeric::{full_client, zeros_client},\n    tensor::CubeTensor,\n};\nuse burn_tensor::Shape;\npub use hardware_accelerated::*;\n\nuse crate::{ConnectedStatsOptions, ConnectedStatsPrimitive};\n\npub(crate) fn stats_from_opts<R, F, I, BT>(\n    l: CubeTensor<R>,\n    opts: ConnectedStatsOptions,\n) -> ConnectedStatsPrimitive<CubeBackend<R, F, I, BT>>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    let [height, width] = l.meta.shape().dims();\n    let shape = Shape::new([height * width]);\n    let zeros = || {\n        zeros_client::<R>(\n            l.client.clone(),\n            l.device.clone(),\n            shape.clone(),\n            I::dtype(),\n        )\n    };\n    let max = I::max_value();\n    let max = || full_client::<R, I>(l.client.clone(), shape.clone(), l.device.clone(), max);\n    let dummy = || {\n        CubeTensor::new_contiguous(\n            l.client.clone(),\n            l.device.clone(),\n            shape.clone(),\n            l.handle.clone(),\n            l.dtype,\n        )\n    };\n    ConnectedStatsPrimitive {\n        area: (opts != ConnectedStatsOptions::none())\n            .then(zeros)\n            .unwrap_or_else(dummy),\n        left: opts.bounds_enabled.then(max).unwrap_or_else(dummy),\n        top: opts.bounds_enabled.then(max).unwrap_or_else(dummy),\n        right: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),\n        bottom: opts.bounds_enabled.then(zeros).unwrap_or_else(dummy),\n        max_label: zeros_client::<R>(\n            l.client.clone(),\n            l.device.clone(),\n            Shape::new([1]),\n            I::dtype(),\n        ),\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs",
    "content": "use burn_tensor::{Shape, TensorMetadata};\nuse cubecl::prelude::*;\n\nuse burn_cubecl::{\n    CubeRuntime, IntElement,\n    ops::{\n        numeric::{empty_device, zeros_client},\n        reshape,\n    },\n    tensor::CubeTensor,\n};\n\nconst CUBE_SIZE: usize = 256;\nconst MIN_SUBGROUP_SIZE: usize = 4;\nconst MAX_REDUCE_SIZE: usize = CUBE_SIZE / MIN_SUBGROUP_SIZE;\n\nconst PART_SIZE: usize = 4096;\n\n#[cube(launch_unchecked)]\nfn prefix_sum_kernel<I: Int, N: Size>(\n    scan_in: &Tensor<Vector<I, N>>,\n    scan_out: &mut Tensor<Vector<I, N>>,\n    scan_bump: &Tensor<Atomic<I>>,\n    reduction: &Tensor<Atomic<I>>,\n    cube_count_x: usize,\n) {\n    let mut broadcast = SharedMemory::<I>::new(1usize);\n    let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);\n    let batch = CUBE_POS_Z as usize;\n    let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.vector_size());\n    let nums_per_cube = CUBE_SIZE * line_spt;\n    let v_last = comptime!(scan_in.vector_size() - 1);\n\n    //acquire partition index\n    if UNIT_POS_X == 0 {\n        broadcast[0] = scan_bump[batch].fetch_add(I::new(1));\n    }\n    sync_cube();\n    let part_id = usize::cast_from(broadcast[0]);\n\n    let plane_id = UNIT_POS_X / PLANE_DIM;\n    let dev_offs = part_id * nums_per_cube;\n    let plane_offs = UNIT_POS_X as usize * line_spt;\n\n    // Exit if full plane is out of bounds\n    if dev_offs + plane_offs >= scan_in.shape(1) {\n        terminate!();\n    }\n\n    let zero = I::new(0);\n\n    let flag_reduction = I::new(1);\n    let flag_inclusive = I::new(2);\n    let flag_mask = I::new(3);\n\n    let red_offs = batch * reduction.stride(0);\n    let scan_offs = batch * scan_in.stride(0);\n\n    let mut t_scan = Array::<Vector<I, N>>::new(line_spt);\n    {\n        let mut i = dev_offs + plane_offs + UNIT_POS_PLANE as usize;\n\n        if part_id < cube_count_x - 1 {\n            for k in 0..line_spt {\n                // Manually fuse not_equal and cast\n                let mut scan =\n                    Vector::cast_from(scan_in[i + scan_offs].not_equal(Vector::new(zero)));\n                #[unroll]\n                for v in 1..scan_in.vector_size() {\n                    let prev = scan[v - 1];\n                    scan[v] += prev;\n                }\n                t_scan[k] = scan;\n                i += PLANE_DIM as usize;\n            }\n        }\n\n        if part_id == cube_count_x - 1 {\n            for k in 0..line_spt {\n                if i < scan_in.shape(1) {\n                    // Manually fuse not_equal and cast\n                    let mut scan =\n                        Vector::cast_from(scan_in[i + scan_offs].not_equal(Vector::new(zero)));\n                    #[unroll]\n                    for v in 1..scan_in.vector_size() {\n                        let prev = scan[v - 1];\n                        scan[v] += prev;\n                    }\n                    t_scan[k] = scan;\n                }\n                i += PLANE_DIM as usize;\n            }\n        }\n\n        let mut prev = zero;\n        let plane_mask = PLANE_DIM - 1;\n        let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;\n        for k in 0..line_spt {\n            let t = plane_shuffle(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);\n            t_scan[k] += Vector::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);\n            prev += plane_broadcast(t, 0u32);\n        }\n\n        if UNIT_POS_PLANE == 0 {\n            reduce[plane_id as usize] = prev;\n        }\n    }\n    sync_cube();\n\n    //Non-divergent subgroup agnostic inclusive scan across subgroup reductions\n    let lane_log = count_trailing_zeros(PLANE_DIM);\n    let spine_size = CUBE_DIM >> lane_log;\n    {\n        let mut offset_0 = 0;\n        let mut offset_1 = 0;\n        let aligned_size =\n            1 << ((count_trailing_zeros(spine_size) + lane_log + 1) / lane_log * lane_log);\n        let mut j = PLANE_DIM;\n        while j <= aligned_size {\n            let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0;\n            let pred_0 = i_0 < spine_size;\n            let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0 as usize], zero));\n            if pred_0 {\n                reduce[i_0 as usize] = t_0;\n            }\n            sync_cube();\n\n            if j != PLANE_DIM {\n                let rshift = j >> lane_log;\n                let i_1 = UNIT_POS_X + rshift;\n                if (i_1 & (j - 1)) >= rshift {\n                    let pred_1 = i_1 < spine_size;\n                    let t_1 = select(\n                        pred_1,\n                        reduce[(((i_1 >> offset_1) << offset_1) - 1) as usize],\n                        zero,\n                    );\n                    if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 {\n                        reduce[i_1 as usize] += t_1;\n                    }\n                }\n            } else {\n                offset_0 += 1;\n            }\n            offset_1 += lane_log;\n\n            j <<= lane_log;\n        }\n    }\n    sync_cube();\n\n    //Device broadcast\n    if UNIT_POS_X == 0 {\n        reduction[part_id + red_offs].store(\n            (reduce[(spine_size - 1) as usize] << I::new(2))\n                | select(part_id != 0, flag_reduction, flag_inclusive),\n        )\n    }\n\n    //Lookback, single thread\n    if part_id != 0 {\n        if UNIT_POS_X == 0 {\n            let mut lookback_id = part_id - 1;\n            let mut prev_reduction = zero;\n            loop {\n                let flag_payload = reduction[lookback_id + red_offs].load();\n                if (flag_payload & flag_mask) == flag_inclusive {\n                    prev_reduction += flag_payload >> I::new(2);\n                    reduction[part_id + red_offs].store(\n                        ((prev_reduction + reduce[(spine_size - 1) as usize]) << I::new(2))\n                            | flag_inclusive,\n                    );\n                    broadcast[0] = prev_reduction;\n                    break;\n                }\n\n                if (flag_payload & flag_mask) == flag_reduction {\n                    prev_reduction += flag_payload >> I::new(2);\n                    lookback_id -= 1;\n                }\n            }\n        }\n        sync_cube();\n    }\n\n    {\n        let prev = if plane_id != 0 {\n            reduce[(plane_id - 1) as usize]\n        } else {\n            zero\n        };\n        let prev = Vector::cast_from(broadcast[0] + prev);\n        let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt as u32;\n        let dev_offset = part_id * nums_per_cube;\n        let mut i = s_offset as usize + dev_offset;\n\n        if part_id < cube_count_x - 1 {\n            for k in 0..line_spt {\n                scan_out[i + scan_offs] = t_scan[k] + prev;\n                i += PLANE_DIM as usize;\n            }\n        }\n\n        if part_id == cube_count_x - 1 {\n            for k in 0..line_spt {\n                if i < scan_out.shape(1) {\n                    scan_out[i + scan_offs] = t_scan[k] + prev;\n                }\n                i += PLANE_DIM as usize;\n            }\n        }\n    }\n}\n\n#[cube]\nfn count_trailing_zeros(num: u32) -> u32 {\n    u32::find_first_set(num) - 1\n}\n\n/// Compute the prefix sum of a tensor\npub fn prefix_sum<R: CubeRuntime, I: IntElement>(input: CubeTensor<R>) -> CubeTensor<R> {\n    let client = input.client.clone();\n    let device = input.device.clone();\n    let num_elems = input.meta.num_elements();\n    let numbers = *input.meta.shape().last().unwrap();\n    let batches = num_elems / numbers;\n\n    let input = reshape(input, Shape::new([batches, numbers]));\n    let out = empty_device::<R, I>(client.clone(), device.clone(), input.shape());\n\n    let cubes = numbers.div_ceil(PART_SIZE);\n    let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32);\n    let cube_count = CubeCount::new_3d(cubes as u32, 1, batches as u32);\n\n    let bump = zeros_client::<R>(\n        client.clone(),\n        device.clone(),\n        Shape::new([batches]),\n        I::dtype(),\n    );\n    let reduction = zeros_client::<R>(\n        client.clone(),\n        device.clone(),\n        Shape::new([batches, cubes]),\n        I::dtype(),\n    );\n\n    unsafe {\n        prefix_sum_kernel::launch_unchecked::<I, R>(\n            &out.client,\n            cube_count,\n            cube_dim,\n            4,\n            input.into_tensor_arg(),\n            out.clone().into_tensor_arg(),\n            bump.into_tensor_arg(),\n            reduction.into_tensor_arg(),\n            cubes,\n        )\n    };\n\n    out\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cube/mod.rs",
    "content": "mod connected_components;\nmod ops;\n"
  },
  {
    "path": "crates/burn-vision/src/backends/cube/ops.rs",
    "content": "use crate::{\n    BoolVisionOps, ConnectedStatsOptions, ConnectedStatsPrimitive, Connectivity, FloatVisionOps,\n    IntVisionOps, QVisionOps, VisionBackend, backends::cpu,\n};\nuse burn_cubecl::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};\n\nuse burn_tensor::{\n    Element,\n    ops::{BoolTensor, IntTensor},\n};\n\nuse super::connected_components::hardware_accelerated;\n\nimpl<R, F, I, BT> BoolVisionOps for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {\n        hardware_accelerated::<R, F, I, BT>(\n            img.clone(),\n            ConnectedStatsOptions::none(),\n            connectivity,\n        )\n        .map(|it| it.0)\n        .unwrap_or_else(|_| cpu::connected_components::<Self>(img, connectivity))\n    }\n\n    fn connected_components_with_stats(\n        img: BoolTensor<Self>,\n        connectivity: Connectivity,\n        opts: ConnectedStatsOptions,\n    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {\n        hardware_accelerated::<R, F, I, BT>(img.clone(), opts, connectivity).unwrap_or_else(|_| {\n            cpu::connected_components_with_stats::<Self>(img, connectivity, opts)\n        })\n    }\n}\n\nimpl<R, F, I, BT> IntVisionOps for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n}\nimpl<R, F, I, BT> FloatVisionOps for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n}\nimpl<R, F, I, BT> QVisionOps for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n}\nimpl<R, F, I, BT> VisionBackend for CubeBackend<R, F, I, BT>\nwhere\n    R: CubeRuntime,\n    F: FloatElement,\n    I: IntElement,\n    BT: BoolElement,\n{\n}\n\n#[cfg(feature = \"fusion\")]\nmod fusion {\n    use super::*;\n    use burn_fusion::{\n        Fusion, FusionBackend, FusionRuntime,\n        stream::{Operation, OperationStreams},\n    };\n    use burn_ir::{CustomOpIr, HandleContainer, OperationIr, OperationOutput, TensorIr};\n    use burn_tensor::Shape;\n\n    impl<B: FusionBackend + BoolVisionOps> BoolVisionOps for Fusion<B> {\n        fn connected_components(img: BoolTensor<Self>, conn: Connectivity) -> IntTensor<Self> {\n            let height = img.shape[0];\n            let width = img.shape[1];\n            let client = img.client.clone();\n\n            #[derive(derive_new::new, Clone, Debug)]\n            struct ConnComp<B> {\n                desc: CustomOpIr,\n                conn: Connectivity,\n                _b: core::marker::PhantomData<B>,\n            }\n\n            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnComp<B1> {\n                fn execute(\n                    &self,\n                    handles: &mut HandleContainer<\n                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,\n                    >,\n                ) {\n                    let ([img], [labels]) = self.desc.as_fixed();\n                    let input = handles.get_bool_tensor::<B1>(img);\n                    let output = B1::connected_components(input, self.conn);\n\n                    handles.register_int_tensor::<B1>(&labels.id, output);\n                }\n            }\n\n            let streams = OperationStreams::with_inputs([&img]);\n            let out = TensorIr::uninit(\n                client.create_empty_handle(),\n                Shape::new([height, width]),\n                B::IntElem::dtype(),\n            );\n\n            let desc = CustomOpIr::new(\"connected_components\", &[img.into_ir()], &[out]);\n            client\n                .register(\n                    streams,\n                    OperationIr::Custom(desc.clone()),\n                    ConnComp::<B>::new(desc, conn),\n                )\n                .output()\n        }\n\n        fn connected_components_with_stats(\n            img: BoolTensor<Self>,\n            conn: Connectivity,\n            opts: ConnectedStatsOptions,\n        ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {\n            let height = img.shape[0];\n            let width = img.shape[1];\n            let client = img.client.clone();\n\n            #[derive(derive_new::new, Clone, Debug)]\n            struct ConnCompStats<B> {\n                desc: CustomOpIr,\n                conn: Connectivity,\n                opts: ConnectedStatsOptions,\n                _b: core::marker::PhantomData<B>,\n            }\n\n            impl<B1: FusionBackend + BoolVisionOps> Operation<B1::FusionRuntime> for ConnCompStats<B1> {\n                fn execute(\n                    &self,\n                    handles: &mut HandleContainer<\n                        <B1::FusionRuntime as FusionRuntime>::FusionHandle,\n                    >,\n                ) {\n                    let ([img], [labels, area, left, top, right, bottom, max_label]) =\n                        self.desc.as_fixed();\n                    let input = handles.get_bool_tensor::<B1>(img);\n                    let (output, stats) =\n                        B1::connected_components_with_stats(input, self.conn, self.opts);\n\n                    handles.register_int_tensor::<B1>(&labels.id, output);\n                    handles.register_int_tensor::<B1>(&area.id, stats.area);\n                    handles.register_int_tensor::<B1>(&left.id, stats.left);\n                    handles.register_int_tensor::<B1>(&top.id, stats.top);\n                    handles.register_int_tensor::<B1>(&right.id, stats.right);\n                    handles.register_int_tensor::<B1>(&bottom.id, stats.bottom);\n                    handles.register_int_tensor::<B1>(&max_label.id, stats.max_label);\n                }\n            }\n\n            let dtype = B::IntElem::dtype();\n            let shape = Shape::new([height, width]);\n            let shape_flat = shape.clone().flatten();\n            let streams = OperationStreams::with_inputs([&img]);\n            let out = TensorIr::uninit(client.create_empty_handle(), shape.clone(), dtype);\n            let area = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);\n            let left = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);\n            let top = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);\n            let right = TensorIr::uninit(client.create_empty_handle(), shape_flat.clone(), dtype);\n            let bottom = TensorIr::uninit(client.create_empty_handle(), shape_flat, dtype);\n            let max_label = TensorIr::uninit(client.create_empty_handle(), [1].into(), dtype);\n\n            let desc = CustomOpIr::new(\n                \"connected_components\",\n                &[img.into_ir()],\n                &[out, area, left, top, right, bottom, max_label],\n            );\n            let [out, area, left, top, right, bottom, max_label] = client\n                .register(\n                    streams,\n                    OperationIr::Custom(desc.clone()),\n                    ConnCompStats::<B>::new(desc, conn, opts),\n                )\n                .try_into()\n                .unwrap();\n\n            let stats = ConnectedStatsPrimitive {\n                area,\n                left,\n                top,\n                right,\n                bottom,\n                max_label,\n            };\n            (out, stats)\n        }\n    }\n    impl<B: FusionBackend + IntVisionOps> IntVisionOps for Fusion<B> {}\n    impl<B: FusionBackend + FloatVisionOps> FloatVisionOps for Fusion<B> {}\n    impl<B: FusionBackend + QVisionOps> QVisionOps for Fusion<B> {}\n    impl<B: FusionBackend + VisionBackend> VisionBackend for Fusion<B> {}\n}\n"
  },
  {
    "path": "crates/burn-vision/src/backends/mod.rs",
    "content": "pub(crate) mod cpu;\n#[cfg(feature = \"cubecl-backend\")]\nmod cube;\n\npub use cpu::{KernelShape, create_structuring_element};\n"
  },
  {
    "path": "crates/burn-vision/src/base.rs",
    "content": "use derive_new::new;\n\n/// 2D size used for vision ops.\n#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub struct Size {\n    /// Width of the element\n    pub width: usize,\n    /// Height of the element\n    pub height: usize,\n}\n\n/// 2D Point used for vision ops. Coordinates start at the top left.\n#[derive(new, Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub struct Point {\n    /// X (horizontal) coordinate\n    pub x: usize,\n    /// Y (vertical) coordinate\n    pub y: usize,\n}\n"
  },
  {
    "path": "crates/burn-vision/src/lib.rs",
    "content": "//! Vision ops for burn, with GPU acceleration where possible.\n//!\n//! # Operations\n//! Operation names are based on `opencv` wherever applicable.\n//!\n//! Currently implemented are:\n//! - `connected_components`\n//! - `connected_components_with_stats`\n//! - `nms` (Non-Maximum Suppression)\n//!\n\n#![warn(missing_docs)]\n\nextern crate alloc;\n\n/// Backend implementations for JIT and CPU\npub mod backends;\nmod base;\nmod ops;\nmod tensor;\nmod transform;\n\npub use base::*;\npub use ops::*;\npub use tensor::*;\npub use transform::*;\n\n/// Module for vision/image utilities\npub mod utils;\n\npub use backends::{KernelShape, create_structuring_element};\n"
  },
  {
    "path": "crates/burn-vision/src/ops/base.rs",
    "content": "use crate::{\n    Point,\n    backends::cpu::{self, MorphOp, morph},\n};\nuse bon::Builder;\nuse burn_tensor::{\n    Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,\n    backend::Backend,\n    ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},\n};\n\n/// Connected components connectivity\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub enum Connectivity {\n    /// Four-connected (only connected in cardinal directions)\n    Four,\n    /// Eight-connected (connected if any of the surrounding 8 pixels are in the foreground)\n    Eight,\n}\n\n/// Which stats should be enabled for `connected_components_with_stats`.\n/// Currently only used by the GPU implementation to save on atomic operations for unneeded stats.\n///\n/// Disabled stats are aliased to the labels tensor\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]\npub struct ConnectedStatsOptions {\n    /// Whether to enable bounding boxes\n    pub bounds_enabled: bool,\n    /// Whether to enable the max label\n    pub max_label_enabled: bool,\n    /// Whether labels must be contiguous starting at 1\n    pub compact_labels: bool,\n}\n\n/// Options for morphology ops\n#[derive(Clone, Debug, Builder)]\npub struct MorphOptions<B: Backend, K: TensorKind<B>> {\n    /// Anchor position within the kernel. Defaults to the center.\n    pub anchor: Option<Point>,\n    /// Number of iterations to apply\n    #[builder(default = 1)]\n    pub iterations: usize,\n    /// Border type. Default: constant based on operation\n    #[builder(default)]\n    pub border_type: BorderType,\n    /// Value of each channel for constant border type\n    pub border_value: Option<Tensor<B, 1, K>>,\n}\n\nimpl<B: Backend, K: TensorKind<B>> Default for MorphOptions<B, K> {\n    fn default() -> Self {\n        Self {\n            anchor: Default::default(),\n            iterations: 1,\n            border_type: Default::default(),\n            border_value: Default::default(),\n        }\n    }\n}\n\n/// Morphology border type\n#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]\npub enum BorderType {\n    /// Constant border with per-channel value. If no value is provided, the value is picked based\n    /// on the morph op.\n    #[default]\n    Constant,\n    /// Replicate first/last element\n    Replicate,\n    /// Reflect start/end elements\n    Reflect,\n    /// Reflect start/end elements, ignoring the first/last element\n    Reflect101,\n    /// Not supported for erode/dilate\n    Wrap,\n}\n\n/// Stats collected by the connected components analysis\n///\n/// Disabled analyses may be aliased to labels\n#[derive(Clone, Debug)]\npub struct ConnectedStats<B: Backend> {\n    /// Total area of each component\n    pub area: Tensor<B, 1, Int>,\n    /// Topmost y coordinate in the component\n    pub top: Tensor<B, 1, Int>,\n    /// Leftmost x coordinate in the component\n    pub left: Tensor<B, 1, Int>,\n    /// Rightmost x coordinate in the component\n    pub right: Tensor<B, 1, Int>,\n    /// Bottommost y coordinate in the component\n    pub bottom: Tensor<B, 1, Int>,\n    /// Scalar tensor of the max label\n    pub max_label: Tensor<B, 1, Int>,\n}\n\n/// Primitive version of [`ConnectedStats`], to be returned by the backend\npub struct ConnectedStatsPrimitive<B: Backend> {\n    /// Total area of each component\n    pub area: IntTensor<B>,\n    /// Leftmost x coordinate in the component\n    pub left: IntTensor<B>,\n    /// Topmost y coordinate in the component\n    pub top: IntTensor<B>,\n    /// Rightmost x coordinate in the component\n    pub right: IntTensor<B>,\n    /// Bottommost y coordinate in the component\n    pub bottom: IntTensor<B>,\n    /// Scalar tensor of the max label\n    pub max_label: IntTensor<B>,\n}\n\nimpl<B: Backend> From<ConnectedStatsPrimitive<B>> for ConnectedStats<B> {\n    fn from(value: ConnectedStatsPrimitive<B>) -> Self {\n        ConnectedStats {\n            area: Tensor::from_primitive(value.area),\n            top: Tensor::from_primitive(value.top),\n            left: Tensor::from_primitive(value.left),\n            right: Tensor::from_primitive(value.right),\n            bottom: Tensor::from_primitive(value.bottom),\n            max_label: Tensor::from_primitive(value.max_label),\n        }\n    }\n}\n\nimpl<B: Backend> ConnectedStats<B> {\n    /// Convert a connected stats into the corresponding primitive\n    pub fn into_primitive(self) -> ConnectedStatsPrimitive<B> {\n        ConnectedStatsPrimitive {\n            area: self.area.into_primitive(),\n            top: self.top.into_primitive(),\n            left: self.left.into_primitive(),\n            right: self.right.into_primitive(),\n            bottom: self.bottom.into_primitive(),\n            max_label: self.max_label.into_primitive(),\n        }\n    }\n}\n\nimpl Default for ConnectedStatsOptions {\n    fn default() -> Self {\n        Self::all()\n    }\n}\n\nimpl ConnectedStatsOptions {\n    /// Don't collect any stats\n    pub fn none() -> Self {\n        Self {\n            bounds_enabled: false,\n            max_label_enabled: false,\n            compact_labels: false,\n        }\n    }\n\n    /// Collect all stats\n    pub fn all() -> Self {\n        Self {\n            bounds_enabled: true,\n            max_label_enabled: true,\n            compact_labels: true,\n        }\n    }\n}\n\n/// Non-Maximum Suppression options.\n#[derive(Clone, Copy, Debug)]\npub struct NmsOptions {\n    /// IoU threshold for suppression (default: 0.5).\n    /// Boxes with IoU > threshold with a higher-scoring box are suppressed.\n    pub iou_threshold: f32,\n    /// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).\n    /// Boxes with score < score_threshold are discarded.\n    pub score_threshold: f32,\n    /// Maximum number of boxes to keep (0 = unlimited).\n    pub max_output_boxes: usize,\n}\n\nimpl Default for NmsOptions {\n    fn default() -> Self {\n        Self {\n            iou_threshold: 0.5,\n            score_threshold: 0.0,\n            max_output_boxes: 0,\n        }\n    }\n}\n\n/// Vision capable backend, implemented by each backend\npub trait VisionBackend:\n    BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend\n{\n}\n\n/// Vision ops on bool tensors\npub trait BoolVisionOps: Backend {\n    /// Computes the connected components labeled image of boolean image with 4 or 8 way\n    /// connectivity - returns a tensor of the component label of each pixel.\n    ///\n    /// `img`- The boolean image tensor in the format [batches, height, width]\n    fn connected_components(img: BoolTensor<Self>, connectivity: Connectivity) -> IntTensor<Self> {\n        cpu::connected_components::<Self>(img, connectivity)\n    }\n\n    /// Computes the connected components labeled image of boolean image with 4 or 8 way\n    /// connectivity and collects statistics on each component - returns a tensor of the component\n    /// label of each pixel, along with stats collected for each component.\n    ///\n    /// `img`- The boolean image tensor in the format [batches, height, width]\n    fn connected_components_with_stats(\n        img: BoolTensor<Self>,\n        connectivity: Connectivity,\n        opts: ConnectedStatsOptions,\n    ) -> (IntTensor<Self>, ConnectedStatsPrimitive<Self>) {\n        cpu::connected_components_with_stats(img, connectivity, opts)\n    }\n\n    /// Erodes an input tensor with the specified kernel.\n    fn bool_erode(\n        input: BoolTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Bool>,\n    ) -> BoolTensor<Self> {\n        let input = Tensor::<Self, 3, Bool>::from_primitive(input);\n        morph(input, kernel, MorphOp::Erode, opts).into_primitive()\n    }\n\n    /// Dilates an input tensor with the specified kernel.\n    fn bool_dilate(\n        input: BoolTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Bool>,\n    ) -> BoolTensor<Self> {\n        let input = Tensor::<Self, 3, Bool>::from_primitive(input);\n        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()\n    }\n}\n\n/// Vision ops on int tensors\npub trait IntVisionOps: Backend {\n    /// Erodes an input tensor with the specified kernel.\n    fn int_erode(\n        input: IntTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Int>,\n    ) -> IntTensor<Self> {\n        let input = Tensor::<Self, 3, Int>::from_primitive(input);\n        morph(input, kernel, MorphOp::Erode, opts).into_primitive()\n    }\n\n    /// Dilates an input tensor with the specified kernel.\n    fn int_dilate(\n        input: IntTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Int>,\n    ) -> IntTensor<Self> {\n        let input = Tensor::<Self, 3, Int>::from_primitive(input);\n        morph(input, kernel, MorphOp::Dilate, opts).into_primitive()\n    }\n}\n\n/// Vision ops on float tensors\npub trait FloatVisionOps: Backend {\n    /// Erodes an input tensor with the specified kernel.\n    fn float_erode(\n        input: FloatTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Float>,\n    ) -> FloatTensor<Self> {\n        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));\n\n        morph(input, kernel, MorphOp::Erode, opts)\n            .into_primitive()\n            .tensor()\n    }\n\n    /// Dilates an input tensor with the specified kernel.\n    fn float_dilate(\n        input: FloatTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Float>,\n    ) -> FloatTensor<Self> {\n        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::Float(input));\n        morph(input, kernel, MorphOp::Dilate, opts)\n            .into_primitive()\n            .tensor()\n    }\n\n    /// Perform Non-Maximum Suppression on bounding boxes.\n    ///\n    /// Returns indices of kept boxes after suppressing overlapping detections.\n    /// Boxes are processed in descending score order; a box suppresses all\n    /// lower-scoring boxes with IoU > threshold.\n    ///\n    /// # Arguments\n    /// * `boxes` - Bounding boxes as \\[N, 4\\] tensor in (x1, y1, x2, y2) format\n    /// * `scores` - Confidence scores as \\[N\\] tensor\n    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)\n    ///\n    /// # Returns\n    /// Indices of kept boxes as \\[M\\] tensor where M <= N\n    fn nms(\n        boxes: FloatTensor<Self>,\n        scores: FloatTensor<Self>,\n        options: NmsOptions,\n    ) -> IntTensor<Self> {\n        let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));\n        let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));\n        cpu::nms::<Self>(boxes, scores, options).into_primitive()\n    }\n}\n\n/// Vision ops on quantized float tensors\npub trait QVisionOps: Backend {\n    /// Erodes an input tensor with the specified kernel.\n    fn q_erode(\n        input: QuantizedTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Float>,\n    ) -> QuantizedTensor<Self> {\n        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));\n        match morph(input, kernel, MorphOp::Erode, opts).into_primitive() {\n            TensorPrimitive::QFloat(tensor) => tensor,\n            _ => unreachable!(),\n        }\n    }\n\n    /// Dilates an input tensor with the specified kernel.\n    fn q_dilate(\n        input: QuantizedTensor<Self>,\n        kernel: BoolTensor<Self>,\n        opts: MorphOptions<Self, Float>,\n    ) -> QuantizedTensor<Self> {\n        let input = Tensor::<Self, 3>::from_primitive(TensorPrimitive::QFloat(input));\n        match morph(input, kernel, MorphOp::Dilate, opts).into_primitive() {\n            TensorPrimitive::QFloat(tensor) => tensor,\n            _ => unreachable!(),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/ops/mod.rs",
    "content": "mod base;\n\npub use base::*;\n"
  },
  {
    "path": "crates/burn-vision/src/tensor.rs",
    "content": "use burn_tensor::{\n    BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::Backend,\n    ops::BoolTensor,\n};\n\nuse crate::{\n    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,\n    VisionBackend,\n};\n\n/// Connected components tensor extensions\npub trait ConnectedComponents<B: Backend> {\n    /// Computes the connected components labeled image of boolean image with 4 or 8 way\n    /// connectivity - returns a tensor of the component label of each pixel.\n    ///\n    /// `img`- The boolean image tensor in the format [batches, height, width]\n    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int>;\n\n    /// Computes the connected components labeled image of boolean image with 4 or 8 way\n    /// connectivity and collects statistics on each component - returns a tensor of the component\n    /// label of each pixel, along with stats collected for each component.\n    ///\n    /// `img`- The boolean image tensor in the format [batches, height, width]\n    fn connected_components_with_stats(\n        self,\n        connectivity: Connectivity,\n        options: ConnectedStatsOptions,\n    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>);\n}\n\n/// Morphology tensor operations\npub trait Morphology<B: Backend, K: TensorKind<B>> {\n    /// Erodes this tensor using the specified kernel.\n    /// Assumes NHWC layout.\n    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;\n    /// Dilates this tensor using the specified kernel.\n    /// Assumes NHWC layout.\n    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self;\n}\n\n/// Morphology tensor operations\npub trait MorphologyKind<B: Backend>: BasicOps<B> {\n    /// Erodes this tensor using the specified kernel\n    fn erode(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive;\n    /// Dilates this tensor using the specified kernel\n    fn dilate(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive;\n}\n\n/// Non-maximum suppression tensor operations\npub trait Nms<B: Backend> {\n    /// Perform Non-Maximum Suppression on this tensor of bounding boxes.\n    ///\n    /// Returns indices of kept boxes after suppressing overlapping detections.\n    /// Boxes are processed in descending score order; a box suppresses all\n    /// lower-scoring boxes with IoU > threshold.\n    ///\n    /// # Arguments\n    /// * `self` - Bounding boxes as \\[N, 4\\] tensor in (x1, y1, x2, y2) format\n    /// * `scores` - Confidence scores as \\[N\\] tensor\n    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)\n    ///\n    /// # Returns\n    /// Indices of kept boxes as \\[M\\] tensor where M <= N\n    fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;\n}\n\nimpl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {\n    fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {\n        Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))\n    }\n\n    fn connected_components_with_stats(\n        self,\n        connectivity: Connectivity,\n        options: ConnectedStatsOptions,\n    ) -> (Tensor<B, 2, Int>, ConnectedStats<B>) {\n        let (labels, stats) =\n            B::connected_components_with_stats(self.into_primitive(), connectivity, options);\n        (Tensor::from_primitive(labels), stats.into())\n    }\n}\n\nimpl<B: VisionBackend, K: MorphologyKind<B>> Morphology<B, K> for Tensor<B, 3, K> {\n    fn erode(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {\n        Tensor::new(K::erode(\n            self.into_primitive(),\n            kernel.into_primitive(),\n            opts,\n        ))\n    }\n\n    fn dilate(self, kernel: Tensor<B, 2, Bool>, opts: MorphOptions<B, K>) -> Self {\n        Tensor::new(K::dilate(\n            self.into_primitive(),\n            kernel.into_primitive(),\n            opts,\n        ))\n    }\n}\n\nimpl<B: VisionBackend> MorphologyKind<B> for Float {\n    fn erode(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_erode(tensor, kernel, opts))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_erode(tensor, kernel, opts))\n            }\n        }\n    }\n\n    fn dilate(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        match tensor {\n            TensorPrimitive::Float(tensor) => {\n                TensorPrimitive::Float(B::float_dilate(tensor, kernel, opts))\n            }\n            TensorPrimitive::QFloat(tensor) => {\n                TensorPrimitive::QFloat(B::q_dilate(tensor, kernel, opts))\n            }\n        }\n    }\n}\n\nimpl<B: VisionBackend> MorphologyKind<B> for Int {\n    fn erode(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        B::int_erode(tensor, kernel, opts)\n    }\n\n    fn dilate(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        B::int_dilate(tensor, kernel, opts)\n    }\n}\n\nimpl<B: VisionBackend> MorphologyKind<B> for Bool {\n    fn erode(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        B::bool_erode(tensor, kernel, opts)\n    }\n\n    fn dilate(\n        tensor: Self::Primitive,\n        kernel: BoolTensor<B>,\n        opts: MorphOptions<B, Self>,\n    ) -> Self::Primitive {\n        B::bool_dilate(tensor, kernel, opts)\n    }\n}\n\nimpl<B: VisionBackend> Nms<B> for Tensor<B, 2> {\n    fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {\n        match (self.into_primitive(), scores.into_primitive()) {\n            (TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {\n                Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))\n            }\n            _ => todo!(\"Quantized inputs are not yet supported\"),\n        }\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/tests/mod.rs",
    "content": "use std::path::PathBuf;\n\nuse burn_tensor::{Shape, Tensor, TensorData, backend::Backend};\nuse image::{DynamicImage, ImageBuffer, Luma, Rgb};\n\nmod connected_components;\nmod morphology;\n\n#[macro_export]\nmacro_rules! testgen_all {\n    () => {\n        use burn_tensor::{Bool, Float, Int};\n\n        pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;\n        pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;\n        pub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;\n\n        pub mod vision {\n            pub use super::*;\n\n            pub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;\n\n            burn_vision::testgen_connected_components!();\n            burn_vision::testgen_morphology!();\n        }\n    };\n}\n"
  },
  {
    "path": "crates/burn-vision/src/transform/mod.rs",
    "content": "mod transform2d;\n\npub use transform2d::*;\n"
  },
  {
    "path": "crates/burn-vision/src/transform/transform2d.rs",
    "content": "use burn_tensor::{\n    Tensor,\n    backend::Backend,\n    grid::affine_grid_2d,\n    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},\n};\n\n/// 2D point transformation\n///\n/// Useful for resampling: rotating, scaling, translating, etc image tensors\npub struct Transform2D {\n    // 2x3 transformation matrix, to be used with column vectors:\n    // T(x) = Ax\n    transform: [[f32; 3]; 2],\n}\n\nimpl Transform2D {\n    /// Transforms an image\n    ///\n    /// * `img` - Images tensor with shape (batch_size, channels, height, width)\n    ///\n    /// # Returns\n    ///\n    /// A tensor with the same as the input\n    pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {\n        let [batch_size, channels, height, width] = img.shape().dims();\n        let transform = Tensor::<B, 2>::from(self.transform);\n        let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);\n        let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);\n\n        let options = GridSampleOptions::new(InterpolateMode::Bilinear)\n            .with_padding_mode(GridSamplePaddingMode::Border)\n            .with_align_corners(true);\n        img.grid_sample_2d(grid, options)\n    }\n\n    /// Makes a 2d transformation composed of other transformations\n    pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {\n        let mut result = Self::identity();\n        for t in transforms.into_iter() {\n            result = result.mul(t);\n        }\n        result\n    }\n\n    /// Multiply two affine transforms represented as 2x3 matrices\n    fn mul(self, other: Transform2D) -> Transform2D {\n        let mut result = [[0.0f32; 3]; 2];\n\n        // Row 0\n        result[0][0] = self.transform[0][0] * other.transform[0][0]\n            + self.transform[0][1] * other.transform[1][0];\n        result[0][1] = self.transform[0][0] * other.transform[0][1]\n            + self.transform[0][1] * other.transform[1][1];\n        result[0][2] = self.transform[0][0] * other.transform[0][2]\n            + self.transform[0][1] * other.transform[1][2]\n            + self.transform[0][2];\n\n        // Row 1\n        result[1][0] = self.transform[1][0] * other.transform[0][0]\n            + self.transform[1][1] * other.transform[1][0];\n        result[1][1] = self.transform[1][0] * other.transform[0][1]\n            + self.transform[1][1] * other.transform[1][1];\n        result[1][2] = self.transform[1][0] * other.transform[0][2]\n            + self.transform[1][1] * other.transform[1][2]\n            + self.transform[1][2];\n\n        Transform2D { transform: result }\n    }\n\n    /// Makes an identity transform (x = Ax)\n    pub fn identity() -> Self {\n        Self {\n            transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],\n        }\n    }\n\n    /// Makes a [`Transform2D`] for rotating a tensor\n    ///\n    /// * `theta` - In radians, the rotation\n    /// * `cx` - Center of rotation, x\n    /// * `cy` - Center of rotation, y\n    pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {\n        let cos_theta = theta.cos();\n        let sin_theta = theta.sin();\n\n        let transform = [\n            [cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],\n            [sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],\n        ];\n\n        Self { transform }\n    }\n\n    /// Makes a [`Transform2D`] for scaling an image tensor\n    ///\n    /// * `sx` - Scale factor in the x direction\n    /// * `sy` - Scale factor in the y direction\n    /// * `cx` - Center of scaling, x\n    /// * `cy` - Center of scaling, y\n    pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {\n        let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];\n\n        Self { transform }\n    }\n\n    /// Makes a [`Transform2D`] for translating an image tensor\n    ///\n    /// * `tx` - Translation in the x direction\n    /// * `ty` - Translation in the y direction\n    pub fn translation(tx: f32, ty: f32) -> Self {\n        let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];\n\n        Self { transform }\n    }\n\n    /// Applies a general shear transformation around the image center,\n    /// combining both X and Y shear.\n    ///\n    /// # Arguments\n    /// * `shx` - Shear factor along the X-axis.\n    /// * `shy` - Shear factor along the Y-axis.\n    /// * `cx`, `cy` - Coordinates of the image center.\n    ///\n    /// # Returns\n    /// * `Self` with a combined shear transform matrix.\n    pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {\n        let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];\n\n        Self { transform }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_ndarray::NdArray;\n    use burn_tensor::Tolerance;\n    type B = NdArray;\n\n    #[test]\n    fn transform_identity_translation() {\n        let t = Transform2D::translation(0.0, 0.0);\n        let image_original = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);\n        let image_transformed = t.transform(image_original.clone());\n        image_original\n            .to_data()\n            .assert_approx_eq(&image_transformed.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_translation() {\n        let t = Transform2D::translation(1., 1.);\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        // This result would change if the padding method is different\n        let image_expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_rotation_90_degrees() {\n        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        let image_expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_rotation_around_corner() {\n        let cx = 1.;\n        let cy = -1.;\n        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        // This result would change if the padding method is different\n        let image_expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_scale() {\n        let cx = 0.0;\n        let cy = 0.0;\n        let t = Transform2D::scale(0.5, 0.5, cx, cy);\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        let image_expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_scale_around_corner() {\n        let cx = 1.;\n        let cy = -1.;\n        let t = Transform2D::scale(0.5, 0.5, cx, cy);\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        let image_expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n\n    #[test]\n    fn transform_combined() {\n        let t1 = Transform2D::translation(0.2, -0.5);\n        let t2 = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);\n        let t = Transform2D::composed([t1, t2]);\n\n        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);\n        // This result would change if the padding method is different\n        let image_expected =\n            Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);\n        let image = t.transform(image);\n        image_expected\n            .to_data()\n            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/src/utils/mod.rs",
    "content": "mod save;\n\npub use save::*;\n"
  },
  {
    "path": "crates/burn-vision/src/utils/save.rs",
    "content": "//! Utilities for saving tensors as images\n\nuse burn_tensor::{ElementConversion, Tensor, backend::Backend};\nuse image::{Rgb, RgbImage};\nuse std::fs;\nuse std::path::Path;\n\n/// How to save a tensor as an image\npub struct TensorDisplayOptions {\n    /// How should the dimensions be interpreted\n    pub dim_order: ImageDimOrder,\n    /// What colors should be used\n    pub color_opts: ColorDisplayOpts,\n    /// How to handle batches\n    pub batch_opts: Option<BatchDisplayOpts>,\n    /// Output image width\n    pub width_out: usize,\n    /// Output image height\n    pub height_out: usize,\n}\n\n/// How to interpret dimensions for image tensors\npub enum ImageDimOrder {\n    /// dims: (height, width)\n    Hw,\n    /// dims: (channels, height, width)\n    Chw,\n    /// dims: (height, width, channels)\n    Hwc,\n    /// dims: (batch_size, height, width)\n    Nhw,\n    /// dims: (batch_size, channels, height, width)\n    Nchw,\n    /// dims: (batch_size, height, width, channels)\n    Nhwc,\n}\n\n/// How to translate tensor values to colors\npub enum ColorDisplayOpts {\n    /// The values in each channel are respectively assigned to an RGB channel\n    Rgb,\n    /// The channel value is mapped between two colors\n    Monochrome {\n        /// Color assigned to the minimum value\n        min: [f32; 3],\n        /// Color assigned to the maximum value\n        max: [f32; 3],\n    },\n}\n\n/// How to handle multi-batch tensors\n#[derive(Clone, Copy, PartialEq, Eq)]\npub enum BatchDisplayOpts {\n    /// Each item is placed consecutively in the image\n    Tiled,\n    /// Each item is aggregated\n    Aggregated,\n}\n\n/// Save a tensor of a batch of images as an image\n///\n/// * `tensor` - Image batch with shape (N, height, width)\n/// * `opts` - Options for how to draw the tensor\n/// * `path` - The file path to use\npub fn save_tensor_as_image<B: Backend, const D: usize, P: AsRef<std::ffi::OsStr>>(\n    tensor: Tensor<B, D>,\n    opts: TensorDisplayOptions,\n    path: P,\n) -> Result<(), Box<dyn std::error::Error>> {\n    // Output file\n    let path = Path::new(&path);\n    if let Some(parent) = path.parent() {\n        fs::create_dir_all(parent)?;\n    }\n\n    let tensor = normalize(tensor);\n\n    // convert to (N,C,H,W) format\n    let tensor: Tensor<B, 4> = match opts.dim_order {\n        ImageDimOrder::Hw => {\n            let [h, w] = tensor.shape().dims();\n            tensor.reshape([1, 1, h, w])\n        }\n        ImageDimOrder::Chw => {\n            let [c, h, w] = tensor.shape().dims();\n            tensor.reshape([1, c, h, w])\n        }\n        ImageDimOrder::Hwc => {\n            let [h, w, c] = tensor.shape().dims();\n            tensor.swap_dims(0, 2).swap_dims(1, 2).reshape([1, c, h, w])\n        }\n        ImageDimOrder::Nhw => {\n            let [n, h, w] = tensor.shape().dims();\n            tensor.reshape([n, 1, h, w])\n        }\n        ImageDimOrder::Nchw => tensor.reshape([0, 0, 0, 0]),\n        ImageDimOrder::Nhwc => tensor.swap_dims(1, 3).swap_dims(2, 3).reshape([0, 0, 0, 0]),\n    };\n\n    let data = tensor.to_data();\n    let shape = data.shape.clone();\n    let (batch, channels, src_height, src_width) = (shape[0], shape[1], shape[2], shape[3]);\n\n    let mut img = if let Some(batch_opts) = &opts.batch_opts\n        && BatchDisplayOpts::Tiled == *batch_opts\n    {\n        RgbImage::new(opts.width_out as u32, (opts.height_out * batch) as u32)\n    } else {\n        RgbImage::new(opts.width_out as u32, opts.height_out as u32)\n    };\n\n    let data_vec = data.to_vec::<f32>().unwrap();\n\n    let mut channel_vals = vec![0 as f32; channels]; // value for each channel in a given pixel\n    for n in 0..batch {\n        for x in 0..opts.width_out {\n            for y in 0..opts.height_out {\n                let i = ((x as f32) / (opts.width_out as f32) * (src_width as f32))\n                    .floor()\n                    .clamp(0.0, src_width as f32) as usize;\n                let j = ((y as f32) / (opts.height_out as f32) * (src_height as f32))\n                    .floor()\n                    .clamp(0.0, src_height as f32) as usize;\n\n                for c in 0..channels {\n                    channel_vals[c] =\n                        data_vec[i + (j + (n * channels + c) * src_height) * src_width];\n                }\n\n                let (x, y) = if let Some(batch_opts) = opts.batch_opts\n                    && BatchDisplayOpts::Tiled == batch_opts\n                {\n                    let batch_x = 0;\n                    let batch_y = n as u32 * opts.height_out as u32;\n                    (x as u32 + batch_x, y as u32 + batch_y)\n                } else {\n                    (x as u32, y as u32)\n                };\n\n                let mut pixel = [0 as f32; 3];\n                match opts.color_opts {\n                    ColorDisplayOpts::Rgb => match channels {\n                        1 => {\n                            pixel[0] = channel_vals[0];\n                            pixel[1] = 0.0;\n                            pixel[2] = 0.0;\n                        }\n                        2 => {\n                            pixel[0] = channel_vals[0];\n                            pixel[1] = channel_vals[1];\n                            pixel[2] = 0.0;\n                        }\n                        3 => {\n                            pixel[0] = channel_vals[0];\n                            pixel[1] = channel_vals[1];\n                            pixel[2] = channel_vals[2];\n                        }\n                        _ => unimplemented!(\"More than 3 channels not supported ({channels})\"),\n                    },\n                    ColorDisplayOpts::Monochrome { min, max } => {\n                        let val: f32 = channel_vals.iter().sum();\n                        pixel[0] = min[0] * (1.0 - val) + max[0] * val;\n                        pixel[1] = min[1] * (1.0 - val) + max[1] * val;\n                        pixel[2] = min[2] * (1.0 - val) + max[2] * val;\n                    }\n                }\n\n                let pixel = [\n                    (pixel[0] * 255.0) as u8,\n                    (pixel[1] * 255.0) as u8,\n                    (pixel[2] * 255.0) as u8,\n                ];\n                img.put_pixel(x, y, Rgb(pixel));\n            }\n        }\n    }\n\n    img.save(path)?;\n    Ok(())\n}\n\n/// Normalize values in 2D tensor from 0 to 1\nfn normalize<B: Backend, const D: usize>(tensor: Tensor<B, D>) -> Tensor<B, D> {\n    let min = tensor.clone().min().into_scalar().elem::<f32>();\n    let max = tensor.clone().max().into_scalar().elem::<f32>();\n    let range = if max - min == 0.0 { 1.0 } else { max - min };\n\n    tensor\n        .sub_scalar(min.elem::<f32>())\n        .div_scalar(range.elem::<f32>())\n}\n"
  },
  {
    "path": "crates/burn-vision/tests/common/mod.rs",
    "content": "use std::path::PathBuf;\n\nuse burn_tensor::{Shape, Tensor, TensorData, backend::Backend};\nuse image::{DynamicImage, ImageBuffer, Luma, Rgb};\n\nuse burn_tensor::{Bool, Int};\n\n#[cfg(all(\n    any(feature = \"test-cpu\", feature = \"ndarray\"),\n    not(any(feature = \"test-wgpu\", feature = \"test-cuda\"))\n))]\npub type TestBackend = burn_ndarray::NdArray<f32, i32>;\n\n#[cfg(all(test, feature = \"test-wgpu\"))]\npub type TestBackend = burn_wgpu::Wgpu;\n\n#[cfg(all(test, feature = \"test-cuda\"))]\npub type TestBackend = burn_cuda::Cuda;\n\n#[allow(unused)]\npub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;\npub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;\n#[allow(unused)]\npub type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;\n\n#[allow(unused)]\npub type IntType = <TestBackend as burn_tensor::backend::Backend>::IntElem;\n\n#[allow(missing_docs)]\n#[macro_export]\nmacro_rules! as_type {\n    ($ty:ident: [$($elem:tt),*]) => {\n        [$($crate::as_type![$ty: $elem]),*]\n    };\n    ($ty:ident: [$($elem:tt,)*]) => {\n        [$($crate::as_type![$ty: $elem]),*]\n    };\n    ($ty:ident: $elem:expr) => {\n        {\n            use cubecl::prelude::*;\n\n            $ty::new($elem)\n        }\n    };\n}\n\n#[allow(unused)]\npub fn test_image<B: Backend>(name: &str, device: &B::Device, luma: bool) -> Tensor<B, 3> {\n    let file = PathBuf::from(\"tests/images\").join(name);\n    let image = image::open(file).unwrap();\n    if luma {\n        let image = image.to_luma32f();\n        let h = image.height() as usize;\n        let w = image.width() as usize;\n        let data = TensorData::new(image.into_vec(), Shape::new([h, w, 1]));\n        Tensor::from_data(data, device)\n    } else {\n        let image = image.to_rgb32f();\n        let h = image.height() as usize;\n        let w = image.width() as usize;\n        let data = TensorData::new(image.into_vec(), Shape::new([h, w, 3]));\n        Tensor::from_data(data, device)\n    }\n}\n\n#[allow(unused)]\npub fn save_test_image<B: Backend>(name: &str, tensor: Tensor<B, 3>, luma: bool) {\n    let file = PathBuf::from(\"tests/images\").join(name);\n    let [h, w, _] = tensor.shape().dims();\n    let data = tensor\n        .into_data()\n        .convert::<f32>()\n        .into_vec::<f32>()\n        .unwrap();\n    if luma {\n        let image = ImageBuffer::<Luma<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();\n        DynamicImage::from(image).to_luma8().save(file).unwrap();\n    } else {\n        let image = ImageBuffer::<Rgb<f32>, _>::from_raw(w as u32, h as u32, data).unwrap();\n        DynamicImage::from(image).to_rgb8().save(file).unwrap();\n    }\n}\n"
  },
  {
    "path": "crates/burn-vision/tests/connected_components.rs",
    "content": "#![allow(clippy::single_range_in_vec_init)]\n\nuse std::collections::HashMap;\n\nuse burn_tensor::TensorData;\nuse burn_vision::{ConnectedComponents, ConnectedStatsOptions, Connectivity};\n\nmod common;\nuse common::*;\n\nfn space_invader() -> [[IntType; 14]; 9] {\n    as_type!(IntType: [\n        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n        [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],\n        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n        [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],\n        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n        [1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1],\n        [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1],\n        [1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1],\n        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n    ])\n}\n\n#[test]\nfn should_support_8_connectivity() {\n    let tensor = TestTensorBool::<2>::from(space_invader());\n\n    let output = tensor.connected_components(Connectivity::Eight);\n    let expected = space_invader(); // All pixels are in the same group for 8-connected\n    let expected = TestTensorInt::<2>::from(expected);\n\n    normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);\n}\n\n#[test]\nfn should_support_8_connectivity_with_stats() {\n    let tensor = TestTensorBool::<2>::from(space_invader());\n\n    let (output, stats) =\n        tensor.connected_components_with_stats(Connectivity::Eight, ConnectedStatsOptions::all());\n    let expected = space_invader(); // All pixels are in the same group for 8-connected\n    let expected = TestTensorInt::<2>::from(expected);\n\n    let (area, left, top, right, bottom) = (\n        stats.area.slice([1..2]).into_data(),\n        stats.left.slice([1..2]).into_data(),\n        stats.top.slice([1..2]).into_data(),\n        stats.right.slice([1..2]).into_data(),\n        stats.bottom.slice([1..2]).into_data(),\n    );\n\n    output.into_data().assert_eq(&expected.into_data(), false);\n\n    area.assert_eq(&TensorData::from([58]), false);\n    left.assert_eq(&TensorData::from([0]), false);\n    top.assert_eq(&TensorData::from([0]), false);\n    right.assert_eq(&TensorData::from([13]), false);\n    bottom.assert_eq(&TensorData::from([8]), false);\n    stats\n        .max_label\n        .into_data()\n        .assert_eq(&TensorData::from([1]), false);\n}\n\n#[test]\nfn should_support_4_connectivity() {\n    let tensor = TestTensorBool::<2>::from(space_invader());\n\n    let output = tensor.connected_components(Connectivity::Four);\n    let expected = as_type!(IntType: [\n        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],\n        [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],\n        [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],\n        [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],\n        [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],\n        [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],\n        [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],\n        [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],\n        [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],\n    ]);\n    let expected = TestTensorInt::<2>::from(expected);\n\n    normalize_labels(output.into_data()).assert_eq(&expected.into_data(), false);\n}\n\n#[test]\nfn should_support_4_connectivity_with_stats() {\n    let tensor = TestTensorBool::<2>::from(space_invader());\n\n    let (output, stats) =\n        tensor.connected_components_with_stats(Connectivity::Four, ConnectedStatsOptions::all());\n    let expected = as_type!(IntType: [\n        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0],\n        [0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0],\n        [0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0],\n        [0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0],\n        [0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0],\n        [4, 0, 0, 3, 3, 0, 0, 0, 0, 3, 3, 0, 0, 5],\n        [4, 4, 0, 0, 3, 3, 3, 3, 3, 3, 0, 0, 5, 5],\n        [4, 4, 0, 3, 3, 3, 0, 0, 3, 3, 3, 0, 5, 5],\n        [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0],\n    ]);\n    let expected = TestTensorInt::<2>::from(expected);\n\n    // Slice off background and limit to compacted labels\n    let (area, left, top, right, bottom) = (\n        stats.area.slice([1..6]).into_data(),\n        stats.left.slice([1..6]).into_data(),\n        stats.top.slice([1..6]).into_data(),\n        stats.right.slice([1..6]).into_data(),\n        stats.bottom.slice([1..6]).into_data(),\n    );\n\n    output.into_data().assert_eq(&expected.into_data(), false);\n\n    area.assert_eq(&TensorData::from([1, 1, 46, 5, 5]), false);\n    left.assert_eq(&TensorData::from([3, 10, 1, 0, 12]), false);\n    top.assert_eq(&TensorData::from([0, 0, 1, 5, 5]), false);\n    right.assert_eq(&TensorData::from([3, 10, 12, 1, 13]), false);\n    bottom.assert_eq(&TensorData::from([0, 0, 8, 7, 7]), false);\n    stats\n        .max_label\n        .into_data()\n        .assert_eq(&TensorData::from([5]), false);\n}\n\n/// Normalize labels to sequential since actual labels aren't required to be contiguous and\n/// different algorithms can return different numbers even if correct\nfn normalize_labels(mut labels: TensorData) -> TensorData {\n    let mut next_label = 0;\n    let mut mappings = HashMap::<i32, i32>::default();\n    let data = labels.as_mut_slice::<i32>().unwrap();\n    for label in data {\n        if *label != 0 {\n            let relabel = mappings.entry(*label).or_insert_with(|| {\n                next_label += 1;\n                next_label\n            });\n            *label = *relabel;\n        }\n    }\n    labels\n}\n"
  },
  {
    "path": "crates/burn-vision/tests/morphology.rs",
    "content": "use burn_tensor::{Tolerance, ops::FloatElem};\nuse burn_vision::{\n    BorderType, KernelShape, MorphOptions, Morphology, Point, Size, create_structuring_element,\n};\ntype FT = FloatElem<TestBackend>;\n\nmod common;\nuse common::*;\n\n#[test]\nfn should_support_dilate_luma() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_5x5_Rect.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_luma_cross() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_5x5_Cross.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_luma_ellipse() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Ellipse,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_5x5_Ellipse.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_luma_non_square_rect() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(3, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_3x5_Rect.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_luma_non_square_cross() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(3, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_3x5_Cross.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_rect() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(3, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_2_3x5_Rect.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_cross() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(3, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_2_3x5_Cross.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_reflect_rect() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Reflect)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_reflect_cross() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Reflect)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_reflect101_rect() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Reflect101)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Rect_BORDER_REFLECT101.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_reflect101_cross() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Reflect101)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Cross_BORDER_REFLECT101.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_replicate_rect() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Replicate)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Rect_BORDER_REPLICATE.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_border_replicate_cross() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(7, 7),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .border_type(BorderType::Replicate)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_7x7_Cross_BORDER_REPLICATE.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_anchor_rect() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(5, 7),\n        Some(Point::new(1, 2)),\n        &Default::default(),\n    );\n\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder().anchor(Point::new(2, 1)).build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_5x7_Rect_ANCHOR.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_rgb_anchor_cross() {\n    let tensor = test_image(\"morphology/Base_2.png\", &Default::default(), false);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 7),\n        Some(Point::new(1, 2)),\n        &Default::default(),\n    );\n\n    // With default border, bottom left pixel is undefined with this particular kernel and anchor\n    // Use replicate instead for comparability\n    let output = tensor.dilate(\n        kernel,\n        MorphOptions::builder()\n            .anchor(Point::new(2, 1))\n            .border_type(BorderType::Replicate)\n            .build(),\n    );\n    let expected = test_image(\n        \"morphology/Dilate_2_5x7_Cross_ANCHOR_BORDER_REPLICATE.png\",\n        &Default::default(),\n        false,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_boolean_rect() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true).greater_elem(0);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    // With default border, bottom left pixel is undefined with this particular kernel and anchor\n    // Use replicate instead for comparability\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_5x5_Rect.png\",\n        &Default::default(),\n        true,\n    )\n    .greater_elem(0);\n    let expected = TestTensorBool::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_boolean_cross() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true).greater_elem(0);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    // With default border, bottom left pixel is undefined with this particular kernel and anchor\n    // Use replicate instead for comparability\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Dilate_1_5x5_Cross.png\",\n        &Default::default(),\n        true,\n    )\n    .greater_elem(0);\n    let expected = TestTensorBool::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_int_rect() {\n    let tensor = (test_image(\"morphology/Base_1.png\", &Default::default(), true) * 255.0).int();\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    // With default border, bottom left pixel is undefined with this particular kernel and anchor\n    // Use replicate instead for comparability\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = (test_image(\n        \"morphology/Dilate_1_5x5_Rect.png\",\n        &Default::default(),\n        true,\n    ) * 255.0)\n        .int();\n    let expected = TestTensorInt::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_dilate_int_cross() {\n    let tensor = (test_image(\"morphology/Base_1.png\", &Default::default(), true) * 255.0).int();\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    // With default border, bottom left pixel is undefined with this particular kernel and anchor\n    // Use replicate instead for comparability\n    let output = tensor.dilate(kernel, MorphOptions::default());\n    let expected = (test_image(\n        \"morphology/Dilate_1_5x5_Cross.png\",\n        &Default::default(),\n        true,\n    ) * 255.0)\n        .int();\n    let expected = TestTensorInt::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_erode_luma() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = TestTensorBool::<2>::from([\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n    ]);\n\n    let output = tensor.erode(kernel, MorphOptions::default());\n    let expected = test_image(\"morphology/Erode_1_5x5_Rect.png\", &Default::default(), true);\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_erode_luma_cross() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.erode(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Erode_1_5x5_Cross.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn should_support_erode_luma_ellipse() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Ellipse,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n\n    let output = tensor.erode(kernel, MorphOptions::default());\n    let expected = test_image(\n        \"morphology/Erode_1_5x5_Ellipse.png\",\n        &Default::default(),\n        true,\n    );\n    let expected = TestTensor::<3>::from(expected);\n\n    output\n        .into_data()\n        .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::absolute(1e-6));\n}\n\n#[test]\nfn create_structuring_element_should_match_manual_rect() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Rect,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n    let kernel_manual = TestTensorBool::<2>::from([\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n    ]);\n\n    let output = tensor.clone().dilate(kernel, MorphOptions::default());\n    let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());\n\n    output\n        .into_data()\n        .assert_eq(&output_manual.into_data(), false);\n}\n\n#[test]\nfn create_structuring_element_should_match_manual_cross() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Cross,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n    let kernel_manual = TestTensorBool::<2>::from([\n        [false, false, true, false, false],\n        [false, false, true, false, false],\n        [true, true, true, true, true],\n        [false, false, true, false, false],\n        [false, false, true, false, false],\n    ]);\n\n    let output = tensor.clone().dilate(kernel, MorphOptions::default());\n    let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());\n\n    output\n        .into_data()\n        .assert_eq(&output_manual.into_data(), false);\n}\n#[test]\nfn create_structuring_element_should_match_manual_ellipse() {\n    let tensor = test_image(\"morphology/Base_1.png\", &Default::default(), true);\n    let kernel = create_structuring_element::<TestBackend>(\n        KernelShape::Ellipse,\n        Size::new(5, 5),\n        None,\n        &Default::default(),\n    );\n    let kernel_manual = TestTensorBool::<2>::from([\n        [false, false, true, false, false],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [true, true, true, true, true],\n        [false, false, true, false, false],\n    ]);\n\n    let output = tensor.clone().dilate(kernel, MorphOptions::default());\n    let output_manual = tensor.dilate(kernel_manual, MorphOptions::default());\n\n    output\n        .into_data()\n        .assert_eq(&output_manual.into_data(), false);\n}\n"
  },
  {
    "path": "crates/burn-vision/tests/nms.rs",
    "content": "use burn_vision::{Nms, NmsOptions};\n\nmod common;\nuse common::*;\n\n#[test]\nfn should_suppress_non_maximum() {\n    let boxes = TestTensor::<2>::from([\n        [0, 0, 100, 100],\n        [0, 1, 100, 100],\n        [0, 101, 200, 200],\n        [0, 100, 200, 200],\n        [0, 170, 300, 300],\n    ]);\n    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);\n    let options = NmsOptions {\n        iou_threshold: 0.5,\n        score_threshold: 0.0,\n        max_output_boxes: 0,\n    };\n\n    let output = boxes.nms(scores, options);\n\n    let expected = TestTensorInt::<1>::from([4, 2, 1]);\n    output.into_data().assert_eq(&expected.into_data(), true);\n}\n\n#[test]\nfn should_apply_score_threshold() {\n    let boxes = TestTensor::<2>::from([\n        [0, 0, 100, 100],\n        [0, 1, 100, 100],\n        [0, 101, 200, 200],\n        [0, 100, 200, 200],\n        [0, 170, 300, 300],\n    ]);\n    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);\n    let options = NmsOptions {\n        iou_threshold: 0.5,\n        score_threshold: 0.3,\n        max_output_boxes: 0,\n    };\n\n    let output = boxes.nms(scores, options);\n\n    let expected = TestTensorInt::<1>::from([4, 2]);\n    output.into_data().assert_eq(&expected.into_data(), true);\n}\n\n#[test]\nfn should_apply_iou_threshold() {\n    let boxes = TestTensor::<2>::from([\n        [0, 0, 100, 100],\n        [0, 1, 100, 100],\n        [0, 101, 200, 200],\n        [0, 100, 200, 200],\n        [0, 170, 300, 300],\n    ]);\n    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);\n    let options = NmsOptions {\n        iou_threshold: 0.1,\n        score_threshold: 0.0,\n        max_output_boxes: 0,\n    };\n\n    let output = boxes.nms(scores, options);\n\n    let expected = TestTensorInt::<1>::from([4, 1]);\n    output.into_data().assert_eq(&expected.into_data(), true);\n}\n\n#[test]\nfn should_apply_max_output_boxes() {\n    let boxes = TestTensor::<2>::from([\n        [0, 0, 100, 100],\n        [0, 1, 100, 100],\n        [0, 101, 200, 200],\n        [0, 100, 200, 200],\n        [0, 170, 300, 300],\n    ]);\n    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);\n    let options = NmsOptions {\n        iou_threshold: 0.5,\n        score_threshold: 0.0,\n        max_output_boxes: 1,\n    };\n\n    let output = boxes.nms(scores, options);\n\n    let expected = TestTensorInt::<1>::from([4]);\n    output.into_data().assert_eq(&expected.into_data(), true);\n}\n"
  },
  {
    "path": "crates/burn-wgpu/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\ncategories = [\"science\"]\ndescription = \"WGPU backend for the Burn framework\"\ndocumentation = \"https://docs.rs/burn-wgpu\"\nedition.workspace = true\nkeywords = [\"deep-learning\", \"machine-learning\", \"gpu\", \"wgpu\", \"webgpu\"]\nlicense.workspace = true\nname = \"burn-wgpu\"\nreadme.workspace = true\nrepository = \"https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu\"\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"std\", \"autotune\", \"fusion\", \"burn-cubecl/default\", \"cubecl/default\"]\ndoc = [\"burn-cubecl/doc\"]\nstd = [\"burn-cubecl/std\", \"cubecl/std\"]\ntracing = [\n    \"cubecl/tracing\",\n    \"burn-backend/tracing\",\n    \"burn-fusion?/tracing\",\n    \"burn-cubecl/tracing\",\n]\n\nautotune = [\"burn-cubecl/autotune\"]\nautotune-checks = [\"burn-cubecl/autotune-checks\"]\nexclusive-memory-only = [\"cubecl/exclusive-memory-only\"]\nfusion = [\"burn-fusion\", \"burn-cubecl/fusion\"]\ntemplate = [\"burn-cubecl/template\", \"cubecl/template\"]\n\n# Backends\nmetal = [\"cubecl-msl\"]\nvulkan = [\"cubecl-spirv\"]\nwebgpu = [\"cubecl-wgsl\"]\n\n# Compilers\ncubecl-msl = [\"cubecl/wgpu-msl\"]\ncubecl-spirv = [\"cubecl/wgpu-spirv\"]\ncubecl-wgsl = []\n\n[dependencies]\ncubecl = { workspace = true, features = [\"wgpu\"] }\n\nburn-cubecl = { path = \"../burn-cubecl\", version = \"=0.21.0-pre.2\", default-features = false }\nburn-fusion = { path = \"../burn-fusion\", version = \"=0.21.0-pre.2\", optional = true }\nburn-backend = { path = \"../burn-backend\", version = \"=0.21.0-pre.2\", default-features = false, features = [\n    \"cubecl-wgpu\",\n] }\n\n\n[package.metadata.docs.rs]\nfeatures = [\"default\"]\nrustdoc-args = [\"--cfg\", \"docsrs\"]\n"
  },
  {
    "path": "crates/burn-wgpu/README.md",
    "content": "# Burn WGPU Backend\n\n[Burn](https://github.com/tracel-ai/burn) WGPU backend\n\n[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)\n[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-wgpu/blob/master/README.md)\n\nThis crate provides a WGPU backend for [Burn](https://github.com/tracel-ai/burn) using the\n[wgpu](https://github.com/gfx-rs/wgpu).\n\nThe backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.\n\n## Usage Example\n\n```rust\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use burn_autodiff::Autodiff;\n    use burn_wgpu::{Wgpu, WgpuDevice};\n    use mnist::training;\n\n    pub fn run() {\n        let device = WgpuDevice::default();\n        training::run::<Autodiff<Wgpu<f32, i32>>>(device);\n    }\n}\n```\n\n> ⚠️ **Warning**  \n> When using one of the `wgpu` backends, you may encounter compilation errors related to recursive type evaluation. This is due to complex type nesting within the `wgpu` dependency chain.  \n> To resolve this issue, add the following line at the top of your `main.rs` or `lib.rs` file:\n> ```rust\n> #![recursion_limit = \"256\"]\n> ```\n> The default recursion limit (128) is often just below the required depth (typically 130-150) due to deeply nested associated types and trait bounds.\n\n\n## Configuration\n\nYou can set `BURN_WGPU_MAX_TASKS` to a positive integer that determines how many computing tasks are\nsubmitted in batches to the graphics API.\n\n## Alternative SPIR-V backend\n\nWhen targeting Vulkan, the `spirv` feature flag can be enabled to enable the SPIR-V compiler\nbackend, which performs significantly better than WGSL. This is especially true for matrix\nmultiplication, where SPIR-V can make use of TensorCores and run at `f16` precision. This isn't\ncurrently supported by WGSL. The compiler can also be selected at runtime by setting the\ncorresponding generic parameter to either `SpirV` or `Wgsl`.\n\n## Platform Support\n\n| Option    | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |\n| :-------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |\n| Metal     | No  | Yes |  No   |  Yes  |   No    |   No    | Yes |  No  |\n| Vulkan    | Yes | Yes |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |\n| OpenGL    | No  | Yes |  Yes  |  Yes  |   Yes   |   Yes   | Yes |  No  |\n| WebGpu    | No  | Yes |  No   |  No   |   No    |   No    | No  | Yes  |\n| Dx11/Dx12 | No  | Yes |  No   |  No   |   Yes   |   No    | No  |  No  |\n"
  },
  {
    "path": "crates/burn-wgpu/src/lib.rs",
    "content": "#![cfg_attr(docsrs, feature(doc_cfg))]\n\nextern crate alloc;\n\n#[cfg(feature = \"template\")]\npub use burn_cubecl::{\n    kernel::{KernelMetadata, into_contiguous},\n    kernel_source,\n    template::{KernelSource, SourceKernel, SourceTemplate, build_info},\n};\n\npub use burn_cubecl::{BoolElement, FloatElement, IntElement};\npub use burn_cubecl::{CubeBackend, tensor::CubeTensor};\npub use cubecl::CubeDim;\npub use cubecl::flex32;\n\npub use cubecl::wgpu::{\n    AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,\n    WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,\n};\n// Vulkan and WebGpu would have conflicting type names\npub mod graphics {\n    pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};\n}\n\n#[cfg(feature = \"cubecl-wgsl\")]\npub use cubecl::wgpu::WgslCompiler;\n#[cfg(feature = \"cubecl-spirv\")]\npub use cubecl::wgpu::vulkan::VkSpirvCompiler;\n\n#[cfg(feature = \"fusion\")]\n/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.\n///\n/// This backend can target multiple graphics APIs, including:\n///   - [Vulkan][crate::graphics::Vulkan] on Linux, Windows, and Android.\n///   - [OpenGL](crate::graphics::OpenGl) on Linux, Windows, and Android.\n///   - [DirectX 12](crate::graphics::Dx12) on Windows.\n///   - [Metal][crate::graphics::Metal] on Apple hardware.\n///   - [WebGPU](crate::graphics::WebGpu) on supported browsers and `wasm` runtimes.\n///\n/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,\n/// you have to manually initialize the runtime. For example:\n///\n/// ```rust, ignore\n/// fn custom_init() {\n///     let device = Default::default();\n///     burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(\n///         &device,\n///         Default::default(),\n///     );\n/// }\n/// ```\n/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.\n/// It's also possible to use an existing wgpu device, by using `init_device`.\n///\n/// # Notes\n///\n/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor\n/// operations for improved performance.\n///\n/// You can disable the `fusion` feature flag to remove that functionality, which might be\n/// necessary on `wasm` for now.\npub type Wgpu<F = f32, I = i32, B = u32> =\n    burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>>;\n\n#[cfg(not(feature = \"fusion\"))]\n/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.\n///\n/// This backend can target multiple graphics APIs, including:\n///   - [Vulkan] on Linux, Windows, and Android.\n///   - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.\n///   - [DirectX 12](crate::Dx12) on Windows.\n///   - [Metal] on Apple hardware.\n///   - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.\n///\n/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,\n/// you have to manually initialize the runtime. For example:\n///\n/// ```rust, ignore\n/// fn custom_init() {\n///     let device = Default::default();\n///     burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(\n///         &device,\n///         Default::default(),\n///     );\n/// }\n/// ```\n/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.\n/// It's also possible to use an existing wgpu device, by using `init_device`.\n///\n/// # Notes\n///\n/// This version of the wgpu backend doesn't use [burn_fusion] to compile and optimize streams of tensor\n/// operations.\n///\n/// You can enable the `fusion` feature flag to add that functionality, which might improve\n/// performance.\npub type Wgpu<F = f32, I = i32, B = u32> = CubeBackend<cubecl::wgpu::WgpuRuntime, F, I, B>;\n\n#[cfg(feature = \"vulkan\")]\n/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.\npub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;\n\n#[cfg(feature = \"webgpu\")]\n/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.\npub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B>;\n\n#[cfg(feature = \"metal\")]\n/// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.\npub type Metal<F = f32, I = i32, B = u8> = Wgpu<F, I, B>;\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use burn_backend::{Backend, BoolStore, DType, QTensorPrimitive};\n\n    #[test]\n    fn should_support_dtypes() {\n        type B = Wgpu;\n        let device = Default::default();\n\n        assert!(B::supports_dtype(&device, DType::F32));\n        assert!(B::supports_dtype(&device, DType::I64));\n        assert!(B::supports_dtype(&device, DType::I32));\n        assert!(B::supports_dtype(&device, DType::U64));\n        assert!(B::supports_dtype(&device, DType::U32));\n        assert!(B::supports_dtype(\n            &device,\n            DType::QFloat(CubeTensor::<WgpuRuntime>::default_scheme())\n        ));\n        // Registered as supported type but we don't actually use it?\n        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));\n\n        #[cfg(feature = \"vulkan\")]\n        {\n            assert!(B::supports_dtype(&device, DType::F16));\n            assert!(B::supports_dtype(&device, DType::I16));\n            assert!(B::supports_dtype(&device, DType::I8));\n            assert!(B::supports_dtype(&device, DType::U16));\n            assert!(B::supports_dtype(&device, DType::U8));\n\n            assert!(!B::supports_dtype(&device, DType::F64));\n            assert!(!B::supports_dtype(&device, DType::Flex32));\n            // Not supported for any arithmetics, but buffer, conversion and possibly matmul (hw dependent)\n            assert!(!B::supports_dtype(&device, DType::BF16));\n        }\n\n        #[cfg(feature = \"metal\")]\n        {\n            assert!(B::supports_dtype(&device, DType::F16));\n            assert!(B::supports_dtype(&device, DType::I16));\n            assert!(B::supports_dtype(&device, DType::I8));\n            assert!(B::supports_dtype(&device, DType::U16));\n            assert!(B::supports_dtype(&device, DType::U8));\n\n            assert!(!B::supports_dtype(&device, DType::F64));\n            assert!(!B::supports_dtype(&device, DType::BF16));\n            assert!(!B::supports_dtype(&device, DType::Flex32));\n        }\n\n        // On macOS without the `metal` feature, wgpu still uses Metal at runtime,\n        // which doesn't support F64 or BF16.\n        #[cfg(all(not(any(feature = \"vulkan\", feature = \"metal\")), target_os = \"macos\"))]\n        {\n            assert!(B::supports_dtype(&device, DType::Flex32));\n            assert!(B::supports_dtype(&device, DType::F16));\n\n            assert!(!B::supports_dtype(&device, DType::F64));\n            assert!(!B::supports_dtype(&device, DType::BF16));\n            assert!(!B::supports_dtype(&device, DType::I16));\n            assert!(!B::supports_dtype(&device, DType::I8));\n            assert!(!B::supports_dtype(&device, DType::U16));\n            assert!(!B::supports_dtype(&device, DType::U8));\n        }\n\n        #[cfg(not(any(feature = \"vulkan\", feature = \"metal\", target_os = \"macos\")))]\n        {\n            assert!(B::supports_dtype(&device, DType::F64));\n            assert!(B::supports_dtype(&device, DType::Flex32));\n            assert!(B::supports_dtype(&device, DType::F16));\n\n            assert!(!B::supports_dtype(&device, DType::BF16));\n            assert!(!B::supports_dtype(&device, DType::I16));\n            assert!(!B::supports_dtype(&device, DType::I8));\n            assert!(!B::supports_dtype(&device, DType::U16));\n            assert!(!B::supports_dtype(&device, DType::U8));\n        }\n    }\n}\n"
  },
  {
    "path": "deny.toml",
    "content": "# If 1 or more target triples (and optionally, target_features) are specified,\n# only the specified targets will be checked when running `cargo deny check`.\n# This means, if a particular package is only ever used as a target specific\n# dependency, such as, for example, the `nix` crate only being used via the\n# `target_family = \"unix\"` configuration, that only having windows targets in\n# this list would mean the nix crate, as well as any of its exclusive\n# dependencies not shared by any other crates, would be ignored, as the target\n# list here is effectively saying which targets you are building for.\n[graph]\ntargets = [\n    { triple = \"x86_64-unknown-linux-gnu\" },\n    { triple = \"aarch64-unknown-linux-gnu\" },\n    { triple = \"x86_64-unknown-linux-musl\" },\n    { triple = \"aarch64-apple-darwin\" },\n    { triple = \"x86_64-apple-darwin\" },\n    { triple = \"x86_64-pc-windows-msvc\" },\n]\n\n[advisories]\n# A list of advisory IDs to ignore. Note that ignored advisories will still\n# output a note when they are encountered.\nignore = [\n    #\"RUSTSEC-0000-0000\",\n]\n\n[bans]\n# Lint level for when multiple versions of the same crate are detected\nmultiple-versions = \"warn\"\n# Lint level for when a crate version requirement is `*`\nwildcards = \"allow\"\n# The graph highlighting used when creating dotgraphs for crates\n# with multiple versions\n# * lowest-version - The path to the lowest versioned duplicate is highlighted\n# * simplest-path - The path to the version with the fewest edges is highlighted\n# * all - Both lowest-version and simplest-path are used\nhighlight = \"all\"\n# The default lint level for `default` features for crates that are members of\n# the workspace that is being checked. This can be overridden by allowing/denying\n# `default` on a crate-by-crate basis if desired.\nworkspace-default-features = \"allow\"\n# The default lint level for `default` features for external crates that are not\n# members of the workspace. This can be overridden by allowing/denying `default`\n# on a crate-by-crate basis if desired.\nexternal-default-features = \"allow\"\n# Certain crates/versions that will be skipped when doing duplicate detection.\nskip = [\n    #{ name = \"crate\", version = \"=0.1.0\" },\n]\n# Similarly to `skip` allows you to skip certain crates during duplicate\n# detection. Unlike skip, it also includes the entire tree of transitive\n# dependencies starting at the specified crate, up to a certain depth, which is\n# by default infinite.\nskip-tree = [\n    #{ name = \"crate\", version = \"=0.1.0\", depth = 20 },\n]\n\n[sources]\n# Lint level for what to happen when a crate from a crate registry that is not\n# in the allow list is encountered\nunknown-registry = \"deny\"\n# Lint level for what to happen when a crate from a git repository that is not\n# in the allow list is encountered\nunknown-git = \"deny\"\n\n[licenses]\n# The confidence threshold for detecting a license from license text.\n# The higher the value, the more closely the license text must be to the\n# canonical license text of a valid SPDX license file.\n# [possible values: any between 0.0 and 1.0].\nconfidence-threshold = 0.60\n# List of explicitly allowed licenses\n# See https://spdx.org/licenses/ for list of possible licenses\n# [possible values: any SPDX 3.11 short identifier (+ optional exception)].\nallow = [\n    \"Apache-2.0 WITH LLVM-exception\",\n    \"Apache-2.0\",\n    \"BSD-3-Clause\",\n    \"BSD-2-Clause\",\n    \"BSL-1.0\", # in NOTICES.md\n    \"CC0-1.0\",\n    \"ISC\",\n    \"MIT\",\n    \"MPL-2.0\",\n    \"OpenSSL\",\n    \"Unicode-DFS-2016\",\n    \"Unicode-3.0\",\n    \"Unlicense\",\n    \"Zlib\",\n]\n# Allow 1 or more licenses on a per-crate basis, so that particular licenses\n# aren't accepted for every possible crate as with the normal allow list\nexceptions = [\n    # Each entry is the crate and version constraint, and its specific allow\n    # list\n    #{ allow = [\"license_name\"], name = \"crate\", version = \"*\" },\n]\n"
  },
  {
    "path": "docs/katex-header.html",
    "content": "<!--Follows the instructions at-->\n<!--https://docs.rs/rustdoc-katex-demo/0.1.5/rustdoc_katex_demo/-->\n\n<link rel=\"stylesheet\" href=\"https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/katex.min.css\" integrity=\"sha384-5TcZemv2l/9On385z///+d7MSYlvIEw9FuZTIdZ14vJLqWphw7e7ZPuOiCHJcFCP\" crossorigin=\"anonymous\">\n<script defer src=\"https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/katex.min.js\" integrity=\"sha384-cMkvdD8LoxVzGF/RPUKAcvmm49FQ0oxwDF3BGKtDXcEc+T1b2N+teh/OJfpU0jr6\" crossorigin=\"anonymous\"></script>\n<script defer src=\"https://cdn.jsdelivr.net/npm/katex@0.16.22/dist/contrib/auto-render.min.js\" integrity=\"sha384-hCXGrW6PitJEwbkoStFjeJxv+fSOOQKOPbJxSfM6G5sWZjAyWhXiTIIAmQqnlLlh\" crossorigin=\"anonymous\"></script>\n<script>\n    document.addEventListener(\"DOMContentLoaded\", function() {\n        renderMathInElement(document.body, {\n          // customised options\n          // • auto-render specific keys, e.g.:\n          delimiters: [\n              {left: '$$', right: '$$', display: true},\n              {left: '$', right: '$', display: false},\n              {left: '\\\\(', right: '\\\\)', display: false},\n              {left: '\\\\[', right: '\\\\]', display: true}\n          ],\n          // • rendering keys, e.g.:\n          throwOnError : false\n        });\n    });\n</script>\n"
  },
  {
    "path": "examples/custom-csv-dataset/.gitignore",
    "content": "# Ignore downloaded csv file\n*.csv"
  },
  {
    "path": "examples/custom-csv-dataset/Cargo.toml",
    "content": "[package]\nauthors = [\"laggui <lagrange.guillaume.1@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-csv-dataset\"\ndescription = \"Example implementation for loading a custom CSV dataset from disk\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/dataset\"]\ndataframe = [\"dep:burn-dataset\", \"polars\"]\n\n[dependencies]\nburn = {path = \"../../crates/burn\"}\nburn-dataset = { path = \"../../crates/burn-dataset\", features = [\"dataframe\"], optional = true }\n\n# File download\nreqwest = {workspace = true, features = [\"blocking\"]}\n\n# CSV parsing\ncsv = {workspace = true}\nserde = {workspace = true, features = [\"std\", \"derive\"]}\n\n# Dataframe support (optional)\npolars = {workspace = true, optional = true, features = [\"csv\", \"temporal\"]}\n\n[[example]]\nname = \"dataframe-dataset\"\nrequired-features = [\"dataframe\"]"
  },
  {
    "path": "examples/custom-csv-dataset/README.md",
    "content": "# Custom CSV Dataset\n\nThis example demonstrates two ways to load a CSV dataset and implement the `Dataset` trait. For this example, we use the [diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) (original [source](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html)), which contains 442 patient records.\n\n## InMemDataset\n\nThe [custom-csv-dataset](src/dataset.rs) example uses [`InMemDataset::from_csv(path)`](src/dataset.rs#L80) to read the csv dataset file into a vector (in-memory) of [`DiabetesPatient`](src/dataset.rs#L13) records (struct) with the help of `serde`.\n\n### Example Usage\n\n```sh\ncargo run --example custom-csv-dataset\n```\n\n## DataframeDataset (Polars)\n\nThe [dataframe-dataset](src/dataframe_dataset.rs) example demonstrates using [`DataframeDataset`](src/dataframe_dataset.rs#L61) with [Polars](https://www.pola.rs/) as the backend. This approach is well-suited for efficient data manipulation and analysis of larger datasets.\n\nThe same diabetes dataset is loaded into a Polars DataFrame, which is then wrapped by `DataframeDataset` to implement the `Dataset` trait.\n\n### Example Usage\n\n```sh\ncargo run --example dataframe-dataset --features dataframe\n```"
  },
  {
    "path": "examples/custom-csv-dataset/examples/custom-csv-dataset.rs",
    "content": "use burn::data::dataset::Dataset;\nuse custom_csv_dataset::dataset::DiabetesDataset;\n\nfn main() {\n    let dataset = DiabetesDataset::new().expect(\"Could not load diabetes dataset\");\n\n    println!(\"Dataset loaded with {} rows\", dataset.len());\n\n    // Display first and last elements\n    let item = dataset.get(0).unwrap();\n    println!(\"First item:\\n{item:?}\");\n\n    let item = dataset.get(441).unwrap();\n    println!(\"Last item:\\n{item:?}\");\n}\n"
  },
  {
    "path": "examples/custom-csv-dataset/examples/dataframe-dataset.rs",
    "content": "use burn::data::dataset::Dataset;\nuse custom_csv_dataset::dataframe_dataset::DiabetesDataframeDataset;\n\nfn main() {\n    let dataset = DiabetesDataframeDataset::new()\n        .expect(\"Could not load diabetes dataset with DataframeDataset\");\n\n    println!(\n        \"Dataset loaded with {} rows using DataframeDataset\",\n        dataset.len()\n    );\n\n    // Display first and last elements\n    let item = dataset.get(0).unwrap();\n    println!(\"First item:\\n{item:?}\");\n\n    let item = dataset.get(441).unwrap();\n    println!(\"Last item:\\n{item:?}\");\n}\n"
  },
  {
    "path": "examples/custom-csv-dataset/src/dataframe_dataset.rs",
    "content": "use crate::{diabetes_patient::DiabetesPatient, utils::download_csv_if_missing};\nuse burn_dataset::{DataframeDataset, Dataset};\nuse polars::prelude::*;\n/// Diabetes dataset using Polars DataframeDataset as the backend.\npub struct DiabetesDataframeDataset {\n    dataset: DataframeDataset<DiabetesPatient>,\n}\n\nimpl DiabetesDataframeDataset {\n    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {\n        // Download dataset csv file\n        let path = download_csv_if_missing();\n\n        // Column definitions: (name, schema_type for parsing, cast_type for final output)\n        const COLS: &[(&str, DataType, DataType)] = &[\n            (\"AGE\", DataType::Int64, DataType::Int8),\n            (\"SEX\", DataType::Int64, DataType::Int8),\n            (\"BMI\", DataType::Float64, DataType::Float32),\n            (\"BP\", DataType::Float64, DataType::Float32),\n            (\"S1\", DataType::Int64, DataType::Int16),\n            (\"S2\", DataType::Float64, DataType::Float32),\n            (\"S3\", DataType::Float64, DataType::Float32),\n            (\"S4\", DataType::Float64, DataType::Float32),\n            (\"S5\", DataType::Float64, DataType::Float32),\n            (\"S6\", DataType::Int64, DataType::Int8),\n            (\"Y\", DataType::Int64, DataType::Int16),\n        ];\n\n        // Build Schema\n        let schema = Schema::from_iter(\n            COLS.iter()\n                .map(|(name, schema_type, _)| Field::new((*name).into(), schema_type.clone())),\n        );\n\n        let mut df = LazyCsvReader::new(PlPath::new(path.to_str().unwrap()))\n            .with_has_header(true)\n            .with_separator(b'\\t')\n            .with_schema(Some(Arc::new(schema)))\n            .finish()?\n            .collect()?;\n\n        // cast columns\n        for (col, _, cast_type) in COLS {\n            df.with_column(df.column(col)?.cast(cast_type)?.clone())?;\n        }\n\n        let dataset = DataframeDataset::new(df)?;\n\n        Ok(Self { dataset })\n    }\n}\n\nimpl Default for DiabetesDataframeDataset {\n    fn default() -> Self {\n        Self::new().expect(\"Could not load diabetes dataset\")\n    }\n}\n\n// Implement the `Dataset` trait which requires `get` and `len`\nimpl Dataset<DiabetesPatient> for DiabetesDataframeDataset {\n    fn get(&self, index: usize) -> Option<DiabetesPatient> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n"
  },
  {
    "path": "examples/custom-csv-dataset/src/dataset.rs",
    "content": "use crate::{diabetes_patient::DiabetesPatient, utils::download_csv_if_missing};\nuse burn::data::dataset::{Dataset, InMemDataset};\n\n/// Diabetes patients dataset, also used in [scikit-learn](https://scikit-learn.org/stable/).\n/// See [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset).\n///\n/// The data is parsed from a single csv file (tab as the delimiter).\n/// The dataset contains 10 baseline variables (age, sex, body mass index, average blood pressure and\n/// 6 blood serum measurements for a total of 442 diabetes patients.\n/// For each patient, the response of interest, a quantitative measure of disease progression one year\n/// after baseline, was collected. This represents the target variable.\npub struct DiabetesDataset {\n    dataset: InMemDataset<DiabetesPatient>,\n}\n\nimpl DiabetesDataset {\n    pub fn new() -> Result<Self, std::io::Error> {\n        // Download dataset csv file\n        let path = download_csv_if_missing();\n\n        // Build dataset from csv with tab ('\\t') delimiter\n        let mut rdr = csv::ReaderBuilder::new();\n        let rdr = rdr.delimiter(b'\\t');\n\n        let dataset = InMemDataset::from_csv(path, rdr).unwrap();\n\n        let dataset = Self { dataset };\n\n        Ok(dataset)\n    }\n}\n\n// Implement the `Dataset` trait which requires `get` and `len`\nimpl Dataset<DiabetesPatient> for DiabetesDataset {\n    fn get(&self, index: usize) -> Option<DiabetesPatient> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n"
  },
  {
    "path": "examples/custom-csv-dataset/src/diabetes_patient.rs",
    "content": "use serde::{Deserialize, Serialize};\n\n/// Diabetes patient record.\n/// For each field, we manually specify the expected header name for serde as all names\n/// are capitalized and some field names are not very informative.\n#[derive(Deserialize, Serialize, Debug, Clone)]\npub struct DiabetesPatient {\n    /// Age in years\n    #[serde(rename = \"AGE\")]\n    pub age: i8,\n\n    /// Sex categorical label\n    #[serde(rename = \"SEX\")]\n    pub sex: i8,\n\n    /// Body mass index\n    #[serde(rename = \"BMI\")]\n    pub bmi: f32,\n\n    /// Average blood pressure\n    #[serde(rename = \"BP\")]\n    pub bp: f32,\n\n    /// S1: total serum cholesterol\n    #[serde(rename = \"S1\")]\n    pub tc: i16,\n\n    /// S2: low-density lipoproteins\n    #[serde(rename = \"S2\")]\n    pub ldl: f32,\n\n    /// S3: high-density lipoproteins\n    #[serde(rename = \"S3\")]\n    pub hdl: f32,\n\n    /// S4: total cholesterol\n    #[serde(rename = \"S4\")]\n    pub tch: f32,\n\n    /// S5: possibly log of serum triglycerides level\n    #[serde(rename = \"S5\")]\n    pub ltg: f32,\n\n    /// S6: blood sugar level\n    #[serde(rename = \"S6\")]\n    pub glu: i8,\n\n    /// Y: quantitative measure of disease progression one year after baseline\n    #[serde(rename = \"Y\")]\n    pub response: i16,\n}\n"
  },
  {
    "path": "examples/custom-csv-dataset/src/lib.rs",
    "content": "#[cfg(feature = \"dataframe\")]\npub mod dataframe_dataset;\npub mod dataset;\npub mod diabetes_patient;\npub mod utils;\n"
  },
  {
    "path": "examples/custom-csv-dataset/src/utils.rs",
    "content": "use std::{\n    fs::File,\n    io::copy,\n    path::{Path, PathBuf},\n};\n\n/// Download the CSV file from its original source on the web.\n/// Panics if the download cannot be completed or the content of the file cannot be written to disk.\npub fn download_csv_if_missing() -> PathBuf {\n    // Point file to current example directory\n    let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap();\n    let file_name = example_dir.join(\"diabetes.csv\");\n\n    if file_name.exists() {\n        println!(\"File already downloaded at {file_name:?}\");\n    } else {\n        // Get file from web\n        println!(\"Downloading file to {file_name:?}\");\n        let url = \"https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt\";\n        let mut response = reqwest::blocking::get(url).unwrap();\n\n        // Create file to write the downloaded content to\n        let mut file = File::create(&file_name).unwrap();\n\n        // Copy the downloaded contents\n        copy(&mut response, &mut file).unwrap();\n    };\n\n    file_name\n}\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-cubecl-kernel\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[dependencies]\nburn = { path = \"../../crates/burn\", default-features = false, features = [\n    \"autodiff\",\n    \"wgpu\",\n    \"autotune\",\n    \"template\",\n] }\nburn-cubecl = { path = \"../../crates/burn-cubecl\" }\ncubecl = { workspace = true, features = [\"wgpu\"] }\n\n# Serialization\nlog = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n\n# Wgpu internal dependencies\nbytemuck = { workspace = true }\nderive-new = { workspace = true }\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/examples/custom-cubecl-kernel.rs",
    "content": "use burn::{\n    backend::wgpu::WgpuRuntime,\n    tensor::{Distribution, Tensor, Tolerance},\n};\nuse custom_cubecl_kernel::{\n    AutodiffBackend, Backend, matmul_add_relu_custom, matmul_add_relu_reference,\n};\n\nfn inference<B: Backend>(device: &B::Device) {\n    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device);\n    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device);\n    let bias = Tensor::random([32, 32, 32], Distribution::Default, device);\n\n    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())\n        .into_data()\n        .convert::<f32>();\n    let custom = matmul_add_relu_custom(lhs, rhs, bias)\n        .into_data()\n        .convert::<f32>();\n\n    reference.assert_approx_eq::<f32>(&custom, Tolerance::default());\n\n    println!(\"Both reference and the custom fused kernel have the same output\");\n}\n\nfn autodiff<B: AutodiffBackend>(device: &B::Device) {\n    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device).require_grad();\n    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();\n    let bias = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();\n\n    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone());\n\n    let mut gradients = reference.backward();\n\n    let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap();\n    let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap();\n    let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap();\n\n    let lhs = lhs.detach();\n    let rhs = rhs.detach();\n    let bias = bias.detach();\n\n    let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone());\n\n    let mut gradients = custom.backward();\n\n    let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap();\n    let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap();\n    let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap();\n\n    lhs_grad_ref\n        .into_data()\n        .convert::<B::FloatElem>()\n        .assert_approx_eq::<f32>(\n            &lhs_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same lhs gradient\");\n\n    rhs_grad_ref\n        .into_data()\n        .convert::<f32>()\n        .assert_approx_eq::<f32>(\n            &rhs_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same rhs gradient\");\n\n    bias_grad_ref\n        .into_data()\n        .convert::<f32>()\n        .assert_approx_eq::<f32>(\n            &bias_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same bias gradient\");\n}\n\nfn main() {\n    type MyBackend = burn::backend::wgpu::CubeBackend<WgpuRuntime, f32, i32, u32>;\n    type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;\n    let device = Default::default();\n    inference::<MyBackend>(&device);\n    autodiff::<MyAutodiffBackend>(&device);\n}\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/src/backward.rs",
    "content": "use crate::FloatTensor;\n\nuse super::{AutodiffBackend, Backend};\nuse burn::{\n    backend::autodiff::{\n        Autodiff, NodeId,\n        checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},\n        grads::Gradients,\n        ops::{Backward, Ops, OpsKind, broadcast_shape},\n    },\n    tensor::{Shape, TensorMetadata},\n};\nuse burn_cubecl::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};\n\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> AutodiffBackend\n    for Autodiff<CubeBackend<R, F, I, BT>>\n{\n}\n\n// Implement our custom backend trait for any backend that also implements our custom backend trait.\nimpl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Create our zero-sized type that will implement the Backward trait.\n        #[derive(Debug)]\n        struct FusedMatmulAddReluBackward;\n\n        // Implement the backward trait for the given backend B, the node gradient\n        // with three other gradients to calculate (lhs, rhs, and bias).\n        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {\n            // Our state that we must build during the forward pass to compute the backward pass.\n            //\n            // Note that we could improve the performance further by only keeping the state of\n            // tensors that are tracked, improving memory management, but for simplicity, we avoid\n            // that part.\n            type State = (NodeId, NodeId, FloatTensor<B>, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                // Get the nodes of each variable.\n                let [node_lhs, node_rhs, node_bias] = ops.parents;\n                // Fetch the gradient for the current node.\n                let grad = grads.consume::<B>(&ops.node);\n\n                // Set our state.\n                let (lhs_state, rhs_state, output, shape_bias) = ops.state;\n                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);\n                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);\n\n                // Fetch shapes of our tensor to support broadcasting.\n                let shape_lhs = lhs.shape();\n                let shape_rhs = rhs.shape();\n\n                // Compute the gradient of the output using the already existing `relu_backward`\n                // function in the basic Burn backend trait.\n                let grad_output = B::relu_backward(output, grad);\n\n                // Compute the lhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_lhs = broadcast_shape::<B>(\n                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),\n                    &shape_lhs,\n                );\n                // Compute the rhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_rhs = broadcast_shape::<B>(\n                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),\n                    &shape_rhs,\n                );\n                // The add derivative is only 1, so we just need to support broadcasting to\n                // compute the bias gradient.\n                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);\n\n                // Register the gradient for each variable based on whether they are marked as\n                // `tracked`.\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, grad_bias);\n                }\n                if let Some(node) = node_lhs {\n                    grads.register::<B>(node.id, grad_lhs);\n                }\n                if let Some(node) = node_rhs {\n                    grads.register::<B>(node.id, grad_rhs);\n                }\n            }\n        }\n\n        // Prepare a stateful operation with each variable node and corresponding graph.\n        //\n        // Each node can be fetched with `ops.parents` in the same order as defined here.\n        match FusedMatmulAddReluBackward\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])\n            // Marks the operation as compute bound, meaning it will save its\n            // state instead of recomputing itself during checkpointing\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                // When at least one node is tracked, we should register our backward step.\n\n                // The state consists of what will be needed for this operation's backward pass.\n                // Since we need the parents' outputs, we must checkpoint their ids to retrieve\n                // their node output at the beginning of the backward pass. We can also save\n                // utility data such as the bias shape. If we also need this operation's output,\n                // we can either save it in the state or recompute it.\n                // during the backward pass. Here we choose to save it in the state because it's a\n                // compute bound operation.\n                let lhs_state = prep.checkpoint(&lhs);\n                let rhs_state = prep.checkpoint(&rhs);\n                let bias_shape = bias.primitive.shape();\n\n                let output = B::fused_matmul_add_relu(\n                    lhs.primitive.clone(),\n                    rhs.primitive.clone(),\n                    bias.primitive,\n                );\n\n                let state = (lhs_state, rhs_state, output.clone(), bias_shape);\n\n                prep.finish(state, output)\n            }\n            OpsKind::UnTracked(prep) => {\n                // When no node is tracked, we can just compute the original operation without\n                // keeping any state.\n                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);\n                prep.finish(output)\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/src/forward.rs",
    "content": "use crate::{FloatTensor, kernel::fused_matmul_add_relu_kernel};\n\nuse super::Backend;\nuse burn::tensor::Shape;\nuse burn_cubecl::{\n    CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement,\n    kernel::into_contiguous, tensor::CubeTensor,\n};\nuse cubecl::{CubeCount, CubeDim};\n\n/// Implement our custom backend trait for the generic `CubeBackend`.\nimpl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> Backend\n    for CubeBackend<R, F, I, BT>\n{\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Define cube dim, hardcoded for simplicity.\n        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };\n\n        lhs.assert_is_on_same_device(&rhs);\n        lhs.assert_is_on_same_device(&bias);\n\n        // For simplicity, make sure each tensor is continuous.\n        let lhs = into_contiguous(lhs);\n        let rhs = into_contiguous(rhs);\n        let bias = into_contiguous(bias);\n\n        // Get the matmul relevant shapes.\n        let ndims = lhs.meta.num_dims();\n        let num_rows = lhs.meta.shape()[ndims - 2];\n        let num_cols = rhs.meta.shape()[ndims - 1];\n\n        // Compute shape of output, while tracking number of batches.\n        let mut num_batches = 1;\n        let mut shape_out = vec![0; ndims];\n        for i in shape_out.clone().into_iter().take(ndims - 2) {\n            shape_out[i] = usize::max(lhs.meta.shape()[i], rhs.meta.shape()[i]);\n            num_batches *= shape_out[i];\n        }\n        shape_out[ndims - 2] = num_rows;\n        shape_out[ndims - 1] = num_cols;\n        let shape_out = Shape::from(shape_out);\n\n        // Create a buffer for the output tensor.\n        let buffer = lhs\n            .client\n            .empty(shape_out.num_elements() * core::mem::size_of::<F>());\n\n        // Create the output tensor primitive.\n        let output = CubeTensor::new_contiguous(\n            lhs.client.clone(),\n            lhs.device.clone(),\n            shape_out,\n            buffer,\n            F::dtype(),\n        );\n\n        // Declare the wgsl workgroup with the number of cubes in x, y and z.\n        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;\n        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;\n        let cube_count =\n            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);\n\n        // Execute lazily the kernel with the launch information and the given buffers. For\n        // simplicity, no vectorization is performed\n        fused_matmul_add_relu_kernel::launch::<F, R>(\n            &output.client,\n            cube_count,\n            cube_dim,\n            lhs.into_tensor_arg(),\n            rhs.into_tensor_arg(),\n            bias.into_tensor_arg(),\n            output.clone().into_tensor_arg(),\n        );\n\n        // Return the output tensor.\n        output\n    }\n}\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/src/kernel.rs",
    "content": "use cubecl::{cube, prelude::*};\n\n/// Declare a custom kernel that gets compiled to `wgpu`/`CUDA`\n#[cube(launch)]\npub fn fused_matmul_add_relu_kernel<F: Float>(\n    lhs: &Tensor<F>,\n    rhs: &Tensor<F>,\n    bias: &Tensor<F>,\n    output: &mut Tensor<F>,\n) {\n    let row = ABSOLUTE_POS_X as usize;\n    let col = ABSOLUTE_POS_Y as usize;\n    let batch = ABSOLUTE_POS_Z as usize;\n\n    let n_rows = output.shape(output.rank() - 2);\n    let n_cols = output.shape(output.rank() - 1);\n    let dim_k = rhs.shape(rhs.rank() - 1);\n\n    if row >= n_rows || col >= n_cols {\n        terminate!();\n    }\n\n    let offset_output = batch * n_rows * n_cols;\n    let mut offset_lhs = 0;\n    let mut offset_rhs = 0;\n\n    let batch_dims = output.rank() - 2;\n    for dim in 0..batch_dims {\n        offset_lhs += offset_output / output.stride(dim) % lhs.shape(dim) * lhs.stride(dim);\n        offset_rhs += offset_output / output.stride(dim) % rhs.shape(dim) * rhs.stride(dim);\n    }\n\n    let mut sum = F::new(0.0);\n    for k in 0..dim_k {\n        let lhs_index = row * dim_k + k;\n        let rhs_index = k * n_cols + col;\n\n        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];\n    }\n\n    let out_index = row * n_cols + col;\n    let index = offset_output + out_index;\n\n    output[index] = F::max(sum + bias[index], F::new(0.0));\n}\n"
  },
  {
    "path": "examples/custom-cubecl-kernel/src/lib.rs",
    "content": "mod backward;\nmod forward;\nmod kernel;\n\nuse burn::tensor::{Tensor, TensorPrimitive, activation, ops::FloatTensor};\n\n/// We create our own Backend trait that extends the Burn backend trait.\npub trait Backend: burn::tensor::backend::Backend {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self>;\n}\n\n/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.\npub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}\n\n/// We define our custom implementation using the added function on our custom backend.\npub fn matmul_add_relu_custom<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let output = B::fused_matmul_add_relu(\n        lhs.into_primitive().tensor(),\n        rhs.into_primitive().tensor(),\n        bias.into_primitive().tensor(),\n    );\n\n    Tensor::from_primitive(TensorPrimitive::Float(output))\n}\n\n/// We define a reference implementation using basic tensor operations.\npub fn matmul_add_relu_reference<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let x = lhs.matmul(rhs) + bias;\n\n    activation::relu(x)\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/.gitignore",
    "content": "# Ignore downloaded dataset\ncifar10/\n*.txt\n*.png"
  },
  {
    "path": "examples/custom-image-dataset/Cargo.toml",
    "content": "[package]\nauthors = [\"laggui <lagrange.guillaume.1@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-image-dataset\"\ndescription = \"Example implementation for loading a custom image dataset from disk\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/std\", \"burn/tui\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\", \"burn/fusion\"]\nmetal = [\"burn/metal\", \"burn/fusion\"]\n\n[dependencies]\n# Disable autotune default for now (convolutions not optimized)\nburn = { path = \"../../crates/burn\", features = [\n    \"train\",\n    \"vision\",\n    \"network\",\n], default-features = false }\n\n\n# File download\nflate2 = { workspace = true }\ntar = \"0.4.44\"\n"
  },
  {
    "path": "examples/custom-image-dataset/README.md",
    "content": "# Training on a Custom Image Dataset\n\nIn this example, a [simple CNN](src/model.rs) model is trained from scratch on the\n[CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) by leveraging the\n`ImageFolderDataset` struct to retrieve images from a folder structure on disk.\n\nSince the original source is in binary format, the data is downloaded from a\n[fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44) in a\nfolder structure with `.png` images.\n\n```\ncifar10\n├── labels.txt\n├── test\n│   ├── airplane\n│   ├── automobile\n│   ├── bird\n│   ├── cat\n│   ├── deer\n│   ├── dog\n│   ├── frog\n│   ├── horse\n│   ├── ship\n│   └── truck\n└── train\n    ├── airplane\n    ├── automobile\n    ├── bird\n    ├── cat\n    ├── deer\n    ├── dog\n    ├── frog\n    ├── horse\n    ├── ship\n    └── truck\n```\n\nTo load the training and test dataset splits, it is as simple as providing the root path to both\nfolders\n\n```rust\nlet train_ds = ImageFolderDataset::new_classification(\"/path/to/cifar10/train\").unwrap();\nlet test_ds = ImageFolderDataset::new_classification(\"/path/to/cifar10/test\").unwrap();\n```\n\nas is done in [`CIFAR10Loader`](src/dataset.rs) for this example.\n\n## Example Usage\n\nThe CNN model and training recipe used in this example are fairly simple since the objective is to\ndemonstrate how to load a custom image classification dataset from disk. Nonetheless, it still\nachieves 70-80% accuracy on the test set after just 30 epochs.\n\nRun it with the Torch GPU backend:\n\n```sh\nexport TORCH_CUDA_VERSION=cu128\ncargo run --example custom-image-dataset --release --features tch-gpu\n```\n\nRun it with our WGPU backend:\n\n```sh\ncargo run --example custom-image-dataset --release --features wgpu\n```\n\nRun it with our Metal backend:\n\n```sh\ncargo run --example custom-image-dataset --release --features metal\n```\n"
  },
  {
    "path": "examples/custom-image-dataset/examples/custom-image-dataset.rs",
    "content": "#![recursion_limit = \"256\"]\n\nuse burn::optim::{SgdConfig, momentum::MomentumConfig};\nuse custom_image_dataset::training::TrainingConfig;\n\n// Import only when backend features are enabled\n#[cfg(any(feature = \"tch-gpu\", feature = \"wgpu\", feature = \"metal\"))]\nuse {burn::backend::Autodiff, custom_image_dataset::training::train};\n\n/// Creates a training configuration with SGD optimizer and momentum.\nfn create_config() -> TrainingConfig {\n    TrainingConfig::new(SgdConfig::new().with_momentum(Some(MomentumConfig {\n        momentum: 0.9,\n        dampening: 0.,\n        nesterov: false,\n    })))\n}\n\nfn main() {\n    #[allow(unused_variables)]\n    let config = create_config();\n\n    #[cfg(feature = \"tch-gpu\")]\n    {\n        use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        train::<Autodiff<LibTorch>>(config, device);\n    }\n\n    #[cfg(feature = \"wgpu\")]\n    {\n        use burn::backend::wgpu::{Wgpu, WgpuDevice};\n        train::<Autodiff<Wgpu>>(config, WgpuDevice::default());\n    }\n\n    #[cfg(feature = \"metal\")]\n    {\n        // Note: Metal backend may have shader compilation issues on Intel Macs with AMD GPUs\n        // If you encounter errors, use WGPU backend as an alternative\n        use burn::backend::wgpu::{Metal, WgpuDevice};\n        train::<Autodiff<Metal>>(config, WgpuDevice::default());\n    }\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/src/data.rs",
    "content": "use burn::{\n    data::{\n        dataloader::batcher::Batcher,\n        dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},\n    },\n    prelude::*,\n};\n\n// CIFAR-10 mean and std values\nconst MEAN: [f32; 3] = [0.4914, 0.48216, 0.44653];\nconst STD: [f32; 3] = [0.24703, 0.24349, 0.26159];\n\n/// Normalizer for the CIFAR-10 dataset.\n#[derive(Clone)]\npub struct Normalizer<B: Backend> {\n    pub mean: Tensor<B, 4>,\n    pub std: Tensor<B, 4>,\n}\n\nimpl<B: Backend> Normalizer<B> {\n    /// Creates a new normalizer.\n    pub fn new(device: &Device<B>) -> Self {\n        let mean = Tensor::<B, 1>::from_floats(MEAN, device).reshape([1, 3, 1, 1]);\n        let std = Tensor::<B, 1>::from_floats(STD, device).reshape([1, 3, 1, 1]);\n        Self { mean, std }\n    }\n\n    /// Normalizes the input image according to the CIFAR-10 dataset.\n    ///\n    /// The input image should be in the range [0, 1].\n    /// The output image will be in the range [-1, 1].\n    ///\n    /// The normalization is done according to the following formula:\n    /// `input = (input - mean) / std`\n    pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        (input - self.mean.clone()) / self.std.clone()\n    }\n\n    /// Returns a new normalizer on the given device.\n    pub fn to_device(&self, device: &B::Device) -> Self {\n        Self {\n            mean: self.mean.clone().to_device(device),\n            std: self.std.clone().to_device(device),\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct ClassificationBatcher<B: Backend> {\n    normalizer: Normalizer<B>,\n}\n\n#[derive(Clone, Debug)]\npub struct ClassificationBatch<B: Backend> {\n    pub images: Tensor<B, 4>,\n    pub targets: Tensor<B, 1, Int>,\n    pub images_path: Vec<String>,\n}\n\nimpl<B: Backend> ClassificationBatcher<B> {\n    pub fn new(device: B::Device) -> Self {\n        Self {\n            normalizer: Normalizer::<B>::new(&device),\n        }\n    }\n}\n\nimpl<B: Backend> Batcher<B, ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> {\n    fn batch(&self, items: Vec<ImageDatasetItem>, device: &B::Device) -> ClassificationBatch<B> {\n        fn image_as_vec_u8(item: ImageDatasetItem) -> Vec<u8> {\n            // Convert Vec<PixelDepth> to Vec<u8> (we know that CIFAR images are u8)\n            item.image\n                .into_iter()\n                .map(|p: PixelDepth| -> u8 { p.try_into().unwrap() })\n                .collect::<Vec<u8>>()\n        }\n\n        let targets = items\n            .iter()\n            .map(|item| {\n                // Expect class label (int) as target\n                if let Annotation::Label(y) = item.annotation {\n                    Tensor::<B, 1, Int>::from_data(\n                        TensorData::from([(y as i64).elem::<B::IntElem>()]),\n                        device,\n                    )\n                } else {\n                    panic!(\"Invalid target type\")\n                }\n            })\n            .collect();\n\n        // Original sample path\n        let images_path: Vec<String> = items.iter().map(|item| item.image_path.clone()).collect();\n\n        let images = items\n            .into_iter()\n            .map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([32, 32, 3])))\n            .map(|data| {\n                Tensor::<B, 3>::from_data(data.convert::<B::FloatElem>(), device)\n                    // permute(2, 0, 1)\n                    .swap_dims(2, 1) // [H, C, W]\n                    .swap_dims(1, 0) // [C, H, W]\n            })\n            .map(|tensor| tensor / 255) // normalize between [0, 1]\n            .collect();\n\n        let images = Tensor::stack(images, 0);\n        let targets = Tensor::cat(targets, 0);\n\n        let images = self.normalizer.to_device(device).normalize(images);\n\n        ClassificationBatch {\n            images,\n            targets,\n            images_path,\n        }\n    }\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/src/dataset.rs",
    "content": "use flate2::read::GzDecoder;\nuse std::path::{Path, PathBuf};\nuse tar::Archive;\n\nuse burn::data::{dataset::vision::ImageFolderDataset, network::downloader};\n\n/// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).\n/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE).\nconst URL: &str = \"https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz\";\n\n/// The [CIFAR-10](https://www.cs.toronto.edu/%7Ekriz/cifar.html) dataset consists of 60,000 32x32\n/// colour images, with 6,000 images per class. There are 50,000 training images and 10,000 test\n/// images.\n///\n/// The data is downloaded from the web from the [fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44).\npub trait CIFAR10Loader {\n    fn cifar10_train() -> Self;\n    fn cifar10_test() -> Self;\n}\n\nimpl CIFAR10Loader for ImageFolderDataset {\n    /// Creates a new CIFAR10 train dataset.\n    fn cifar10_train() -> Self {\n        let root = download();\n\n        Self::new_classification(root.join(\"train\")).unwrap()\n    }\n\n    /// Creates a new CIFAR10 test dataset.\n    fn cifar10_test() -> Self {\n        let root = download();\n\n        Self::new_classification(root.join(\"test\")).unwrap()\n    }\n}\n\n/// Download the CIFAR10 dataset from the web to the current example directory.\nfn download() -> PathBuf {\n    // Point to current example directory\n    let example_dir = Path::new(env!(\"CARGO_MANIFEST_DIR\"));\n    let cifar_dir = example_dir.join(\"cifar10\");\n\n    // Check for already downloaded content\n    let labels_file = cifar_dir.join(\"labels.txt\");\n    if !labels_file.exists() {\n        // Download gzip file\n        let bytes = downloader::download_file_as_bytes(URL, \"cifar10.tgz\");\n\n        // Decode gzip file content and unpack archive\n        let gz_buffer = GzDecoder::new(&bytes[..]);\n        let mut archive = Archive::new(gz_buffer);\n        archive.unpack(example_dir).unwrap();\n    }\n\n    cifar_dir\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/src/inference.rs",
    "content": "use burn::{\n    data::{\n        dataloader::batcher::Batcher,\n        dataset::vision::{Annotation, ImageDatasetItem},\n    },\n    module::Module,\n    record::{CompactRecorder, Recorder},\n    tensor::backend::Backend,\n};\n\nuse crate::{data::ClassificationBatcher, model::Cnn};\n\nconst NUM_CLASSES: u8 = 10;\n\npub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: ImageDatasetItem) {\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model should exist\");\n\n    let model: Cnn<B> = Cnn::new(NUM_CLASSES.into(), &device).load_record(record);\n\n    let mut label = 0;\n    if let Annotation::Label(category) = item.annotation {\n        label = category;\n    };\n    let batcher = ClassificationBatcher::new(device);\n    let batch = batcher.batch(vec![item]);\n    let output = model.forward(batch.images);\n    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();\n    println!(\"Predicted {predicted} Expected {label:?}\");\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/src/lib.rs",
    "content": "pub mod data;\npub mod dataset;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/custom-image-dataset/src/model.rs",
    "content": "use burn::{\n    nn::{\n        Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,\n        conv::{Conv2d, Conv2dConfig},\n        pool::{MaxPool2d, MaxPool2dConfig},\n    },\n    prelude::*,\n};\n\n/// Basic convolutional neural network with VGG-style blocks.\n//\n//       VGG block\n// ┌────────────────────┐\n// │      3x3 conv      │\n// │          ↓         │\n// │     activation     │\n// │          ↓         │\n// │      3x3 conv      │\n// │          ↓         │\n// │     activation     │\n// │          ↓         │\n// │       maxpool      │\n// └────────────────────┘\n#[derive(Module, Debug)]\npub struct Cnn<B: Backend> {\n    activation: Relu,\n    dropout: Dropout,\n    pool: MaxPool2d,\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n    conv3: Conv2d<B>,\n    conv4: Conv2d<B>,\n    conv5: Conv2d<B>,\n    conv6: Conv2d<B>,\n    fc1: Linear<B>,\n    fc2: Linear<B>,\n}\n\nimpl<B: Backend> Cnn<B> {\n    pub fn new(num_classes: usize, device: &Device<B>) -> Self {\n        let conv1 = Conv2dConfig::new([3, 32], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n        let conv2 = Conv2dConfig::new([32, 32], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n\n        let conv3 = Conv2dConfig::new([32, 64], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n        let conv4 = Conv2dConfig::new([64, 64], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n\n        let conv5 = Conv2dConfig::new([64, 128], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n        let conv6 = Conv2dConfig::new([128, 128], [3, 3])\n            .with_padding(PaddingConfig2d::Same)\n            .init(device);\n\n        let pool = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();\n\n        let fc1 = LinearConfig::new(2048, 128).init(device);\n        let fc2 = LinearConfig::new(128, num_classes).init(device);\n\n        let dropout = DropoutConfig::new(0.3).init();\n\n        Self {\n            activation: Relu::new(),\n            dropout,\n            pool,\n            conv1,\n            conv2,\n            conv3,\n            conv4,\n            conv5,\n            conv6,\n            fc1,\n            fc2,\n        }\n    }\n\n    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {\n        let x = self.conv1.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.conv2.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.pool.forward(x);\n        let x = self.dropout.forward(x);\n\n        let x = self.conv3.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.conv4.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.pool.forward(x);\n        let x = self.dropout.forward(x);\n\n        let x = self.conv5.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.conv6.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.pool.forward(x);\n        let x = self.dropout.forward(x);\n\n        let x = x.flatten(1, 3);\n\n        let x = self.fc1.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.fc2.forward(x)\n    }\n}\n"
  },
  {
    "path": "examples/custom-image-dataset/src/training.rs",
    "content": "use std::time::Instant;\n\nuse crate::{\n    data::{ClassificationBatch, ClassificationBatcher},\n    dataset::CIFAR10Loader,\n    model::Cnn,\n};\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset},\n    nn::loss::CrossEntropyLossConfig,\n    optim::SgdConfig,\n    prelude::*,\n    record::CompactRecorder,\n    tensor::backend::AutodiffBackend,\n    train::{\n        ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep,\n        metric::{AccuracyMetric, LossMetric},\n    },\n};\n\nconst NUM_CLASSES: u8 = 10;\nconst ARTIFACT_DIR: &str = \"/tmp/custom-image-dataset\";\n\nimpl<B: Backend> Cnn<B> {\n    pub fn forward_classification(\n        &self,\n        images: Tensor<B, 4>,\n        targets: Tensor<B, 1, Int>,\n    ) -> ClassificationOutput<B> {\n        let output = self.forward(images);\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output.device())\n            .forward(output.clone(), targets.clone());\n\n        ClassificationOutput::new(loss, output, targets)\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for Cnn<B> {\n    type Input = ClassificationBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: ClassificationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_classification(batch.images, batch.targets);\n\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for Cnn<B> {\n    type Input = ClassificationBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: ClassificationBatch<B>) -> ClassificationOutput<B> {\n        self.forward_classification(batch.images, batch.targets)\n    }\n}\n\n#[derive(Config, Debug)]\npub struct TrainingConfig {\n    pub optimizer: SgdConfig,\n    #[config(default = 30)]\n    pub num_epochs: usize,\n    #[config(default = 128)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 0.02)]\n    pub learning_rate: f64,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {\n    create_artifact_dir(ARTIFACT_DIR);\n\n    config\n        .save(format!(\"{ARTIFACT_DIR}/config.json\"))\n        .expect(\"Config should be saved successfully\");\n\n    B::seed(&device, config.seed);\n\n    // Dataloaders\n    let batcher_train = ClassificationBatcher::<B>::new(device.clone());\n    let batcher_valid = ClassificationBatcher::<B::InnerBackend>::new(device.clone());\n\n    let dataloader_train = DataLoaderBuilder::new(batcher_train)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(ImageFolderDataset::cifar10_train());\n\n    // NOTE: we use the CIFAR-10 test set as validation for demonstration purposes\n    let dataloader_test = DataLoaderBuilder::new(batcher_valid)\n        .batch_size(config.batch_size)\n        .num_workers(config.num_workers)\n        .build(ImageFolderDataset::cifar10_test());\n\n    // Learner config\n    let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_test)\n        .metric_train_numeric(AccuracyMetric::new())\n        .metric_valid_numeric(AccuracyMetric::new())\n        .metric_train_numeric(LossMetric::new())\n        .metric_valid_numeric(LossMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let model = Cnn::new(NUM_CLASSES.into(), &device);\n\n    // Training\n    let now = Instant::now();\n    let result = training.launch(Learner::new(\n        model,\n        config.optimizer.init(),\n        config.learning_rate,\n    ));\n    let elapsed = now.elapsed().as_secs();\n    println!(\"Training completed in {}m{}s\", (elapsed / 60), elapsed % 60);\n\n    result\n        .model\n        .save_file(format!(\"{ARTIFACT_DIR}/model\"), &CompactRecorder::new())\n        .expect(\"Trained model should be saved successfully\");\n}\n"
  },
  {
    "path": "examples/custom-learning-strategy/Cargo.toml",
    "content": "[package]\nname = \"custom-learning-strategy\"\nedition.workspace = true\nlicense.workspace = true\nversion.workspace = true\npublish = false\n\n[lints]\nworkspace = true\n\n[dependencies]\nburn = {path = \"../../crates/burn\", default-features = false, features=[\"autodiff\", \"webgpu\", \"vision\"]}\nguide = {path = \"../guide\"}\nderive-new = { workspace = true }\nlog = { workspace = true }"
  },
  {
    "path": "examples/custom-learning-strategy/examples/custom-learning-strategy.rs",
    "content": "use burn::backend::{Autodiff, WebGpu};\n\nfn main() {\n    custom_learning_strategy::training::run::<Autodiff<WebGpu>>(Default::default());\n}\n"
  },
  {
    "path": "examples/custom-learning-strategy/src/lib.rs",
    "content": "pub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/custom-learning-strategy/src/model.rs",
    "content": "use burn::{\n    nn::{\n        Dropout, DropoutConfig, Linear, LinearConfig, Relu,\n        conv::{Conv2d, Conv2dConfig},\n        loss::CrossEntropyLossConfig,\n        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},\n    },\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n    train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep},\n};\nuse guide::data::MnistBatch;\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n    pool: AdaptiveAvgPool2d,\n    dropout: Dropout,\n    linear1: Linear<B>,\n    linear2: Linear<B>,\n    activation: Relu,\n}\n\n#[derive(Config, Debug)]\npub struct ModelConfig {\n    num_classes: usize,\n    hidden_size: usize,\n    #[config(default = \"0.5\")]\n    dropout: f64,\n}\n\nimpl ModelConfig {\n    /// Returns the initialized model.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {\n        Model {\n            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),\n            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),\n            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),\n            activation: Relu::new(),\n            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),\n            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n        }\n    }\n}\n\nimpl<B: Backend> Model<B> {\n    /// # Shapes\n    ///   - Images [batch_size, height, width]\n    ///   - Output [batch_size, class_prob]\n    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = images.dims();\n\n        // Create a channel.\n        let x = images.reshape([batch_size, 1, height, width]);\n\n        let x = self.conv1.forward(x); // [batch_size, 8, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.conv2.forward(x); // [batch_size, 16, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]\n        let x = x.reshape([batch_size, 16 * 8 * 8]);\n        let x = self.linear1.forward(x);\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        self.linear2.forward(x) // [batch_size, num_classes]\n    }\n\n    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {\n        let targets = item.targets;\n        let output = self.forward(item.images);\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output.device())\n            .forward(output.clone(), targets.clone());\n\n        ClassificationOutput {\n            loss,\n            output,\n            targets,\n        }\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_classification(item);\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {\n        self.forward_classification(batch)\n    }\n}\n"
  },
  {
    "path": "examples/custom-learning-strategy/src/training.rs",
    "content": "use crate::model::ModelConfig;\nuse burn::data::dataloader::Progress;\nuse burn::record::NoStdTrainingRecorder;\nuse burn::tensor::backend::DeviceOps;\nuse burn::train::{\n    EventProcessorTraining, Learner, LearningComponentsTypes, SupervisedLearningStrategy,\n    SupervisedTraining, SupervisedTrainingEventProcessor, TrainLoader, TrainingBackend,\n    TrainingComponents, TrainingModel, ValidLoader,\n};\nuse burn::{\n    data::{\n        dataloader::DataLoaderBuilder,\n        dataset::{transform::PartialDataset, vision::MnistDataset},\n    },\n    lr_scheduler::{\n        composed::ComposedLrSchedulerConfig, cosine::CosineAnnealingLrSchedulerConfig,\n        linear::LinearLrSchedulerConfig,\n    },\n    module::AutodiffModule,\n    optim::AdamConfig,\n    prelude::*,\n    record::CompactRecorder,\n    tensor::{Device, backend::AutodiffBackend},\n    train::{\n        InferenceStep, LearnerEvent, MetricEarlyStoppingStrategy, StoppingCondition, TrainingItem,\n        metric::{\n            AccuracyMetric, LossMetric,\n            store::{Aggregate, Direction, Split},\n        },\n    },\n};\nuse guide::data::MnistBatcher;\nuse std::{marker::PhantomData, sync::Arc};\n\nstatic ARTIFACT_DIR: &str = \"/tmp/burn-example-mnist\";\n\n#[derive(Config, Debug)]\npub struct MnistTrainingConfig {\n    #[config(default = 5)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1e-4)]\n    pub lr: f64,\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    create_artifact_dir(ARTIFACT_DIR);\n    // Config\n    let config_model = ModelConfig::new(10, 1024);\n    let config_optimizer = AdamConfig::new();\n    let config = MnistTrainingConfig::new(config_model, config_optimizer);\n\n    B::seed(&device, config.seed);\n\n    let model = config.model.init::<B>(&device);\n\n    let dataset_train_original = Arc::new(MnistDataset::train());\n    let dataset_train = PartialDataset::new(dataset_train_original.clone(), 0, 15_000);\n    let dataset_valid = PartialDataset::new(dataset_train_original.clone(), 15_000, 17_000);\n\n    let lr_scheduler = ComposedLrSchedulerConfig::new()\n        .cosine(CosineAnnealingLrSchedulerConfig::new(1.0, 2000))\n        // Warmup\n        .linear(LinearLrSchedulerConfig::new(1e-8, 1.0, 2000))\n        .linear(LinearLrSchedulerConfig::new(1e-2, 1e-6, 10000));\n    let early_stopping = MetricEarlyStoppingStrategy::new(\n        &LossMetric::<B>::new(),\n        Aggregate::Mean,\n        Direction::Lowest,\n        Split::Valid,\n        StoppingCondition::NoImprovementSince { n_epochs: 5 },\n    );\n\n    let dataloader_train = DataLoaderBuilder::new(MnistBatcher::default())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(dataset_train);\n\n    let dataloader_valid = DataLoaderBuilder::new(MnistBatcher::default())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(dataset_valid);\n\n    let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)\n        .metrics((AccuracyMetric::new(), LossMetric::new()))\n        .with_file_checkpointer(CompactRecorder::new())\n        .early_stopping(early_stopping)\n        .num_epochs(config.num_epochs)\n        .summary()\n        .with_training_strategy(burn::train::TrainingStrategy::Custom(Arc::new(\n            MyCustomLearningStrategy::new(device),\n        )));\n\n    let result = training.launch(Learner::new(\n        model,\n        config.optimizer.init(),\n        lr_scheduler.init().unwrap(),\n    ));\n\n    result\n        .model\n        .save_file(\n            format!(\"{ARTIFACT_DIR}/model\"),\n            &NoStdTrainingRecorder::new(),\n        )\n        .expect(\"Failed to save trained model\");\n}\n\nstruct MyCustomLearningStrategy<LC: LearningComponentsTypes> {\n    device: Device<TrainingBackend<LC>>,\n    _p: PhantomData<LC>,\n}\n\nimpl<LC: LearningComponentsTypes> MyCustomLearningStrategy<LC> {\n    pub fn new(device: Device<TrainingBackend<LC>>) -> Self {\n        Self {\n            device,\n            _p: PhantomData,\n        }\n    }\n}\n\nimpl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyCustomLearningStrategy<LC> {\n    fn fit(\n        &self,\n        training_components: TrainingComponents<LC>,\n        mut learner: Learner<LC>,\n        dataloader_train: TrainLoader<LC>,\n        dataloader_valid: ValidLoader<LC>,\n        starting_epoch: usize,\n    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {\n        let dataloader_train = dataloader_train.to_device(&self.device);\n        let dataloader_valid = dataloader_valid.to_device(self.device.inner());\n        learner.fork(&self.device);\n        let mut event_processor = training_components.event_processor;\n        let mut checkpointer = training_components.checkpointer;\n        let interrupter = training_components.interrupter;\n        let num_epochs = training_components.num_epochs;\n\n        for epoch in starting_epoch..num_epochs + 1 {\n            // Iterate over our training and validation loop for X epochs.\n            log::info!(\"Executing training step for epoch {}\", epoch,);\n\n            // Single device / dataloader\n            let mut iterator = dataloader_train.iter();\n            let mut iteration = 0;\n\n            while let Some(item) = iterator.next() {\n                iteration += 1;\n                learner.lr_step();\n                log::info!(\"Iteration {iteration} of my custom learning strategy\");\n\n                let progress = iterator.progress();\n                let item = learner.train_step(item);\n                learner.optimizer_step(item.grads);\n\n                let item = TrainingItem::new(\n                    item.item,\n                    progress,\n                    Progress::new(epoch, num_epochs),\n                    Some(iteration),\n                    Some(learner.lr_current()),\n                );\n\n                event_processor.process_train(LearnerEvent::ProcessedItem(item));\n\n                if interrupter.should_stop() {\n                    let reason = interrupter\n                        .get_message()\n                        .unwrap_or(String::from(\"Reason unknown\"));\n                    log::info!(\"Training interrupted: {reason}\");\n                    break;\n                }\n            }\n            event_processor.process_train(LearnerEvent::EndEpoch(epoch));\n\n            let model_valid = learner.model().valid();\n\n            let mut iterator = dataloader_valid.iter();\n            let mut iteration = 0;\n\n            while let Some(item) = iterator.next() {\n                let progress = iterator.progress();\n                iteration += 1;\n\n                let item = model_valid.step(item);\n                let item = TrainingItem::new(\n                    item,\n                    progress,\n                    Progress::new(epoch, num_epochs),\n                    Some(iteration),\n                    None,\n                );\n\n                event_processor.process_valid(LearnerEvent::ProcessedItem(item));\n            }\n            event_processor.process_valid(LearnerEvent::EndEpoch(epoch));\n\n            if let Some(checkpointer) = &mut checkpointer {\n                checkpointer.checkpoint(&learner, epoch, &training_components.event_store);\n            }\n        }\n\n        (learner.model(), event_processor)\n    }\n}\n"
  },
  {
    "path": "examples/custom-renderer/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\", \"Ankitects Pty Ltd\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-renderer\"\ndescription = \"Example of how to render training progress outside of the tui\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[dependencies]\nburn = {path = \"../../crates/burn\", features=[\"autodiff\", \"wgpu\", \"train\", \"dataset\", \"vision\"], default-features=false}\nguide = {path = \"../guide\", default-features=false}\n\n# Serialization\nlog = {workspace = true}\nserde = {workspace = true, features = [\"std\", \"derive\"]}\n\n# Wgpu internal dependencies\nderive-new = { workspace = true }\nbytemuck = { workspace = true }\n"
  },
  {
    "path": "examples/custom-renderer/examples/custom-renderer.rs",
    "content": "use burn::backend::{Autodiff, WebGpu, wgpu::WgpuDevice};\n\nfn main() {\n    custom_renderer::run::<Autodiff<WebGpu>>(WgpuDevice::default());\n}\n"
  },
  {
    "path": "examples/custom-renderer/src/lib.rs",
    "content": "use burn::{\n    config::Config,\n    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n    optim::AdamConfig,\n    tensor::backend::AutodiffBackend,\n    train::{\n        Learner, SupervisedTraining,\n        renderer::{\n            EvaluationName, EvaluationProgress, MetricState, MetricsRenderer,\n            MetricsRendererEvaluation, MetricsRendererTraining, ProgressType, TrainingProgress,\n        },\n    },\n};\nuse guide::{data::MnistBatcher, model::ModelConfig};\n\n#[derive(Config, Debug)]\npub struct MnistTrainingConfig {\n    #[config(default = 10)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1e-4)]\n    pub lr: f64,\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n}\n\nstruct CustomRenderer {}\n\nimpl MetricsRendererTraining for CustomRenderer {\n    fn update_train(&mut self, _state: MetricState) {}\n\n    fn update_valid(&mut self, _state: MetricState) {}\n\n    fn render_train(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {\n        dbg!(item);\n    }\n\n    fn render_valid(&mut self, item: TrainingProgress, _progress_indicators: Vec<ProgressType>) {\n        dbg!(item);\n    }\n}\n\nimpl MetricsRenderer for CustomRenderer {\n    fn manual_close(&mut self) {\n        // Nothing to do.\n    }\n\n    fn register_metric(&mut self, _definition: burn::train::metric::MetricDefinition) {}\n}\n\nimpl MetricsRendererEvaluation for CustomRenderer {\n    fn update_test(&mut self, _name: EvaluationName, _state: MetricState) {}\n\n    fn render_test(&mut self, item: EvaluationProgress, _progress_indicators: Vec<ProgressType>) {\n        dbg!(item);\n    }\n}\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    // Create the configuration.\n    let config_model = ModelConfig::new(10, 1024);\n    let config_optimizer = AdamConfig::new();\n    let config = MnistTrainingConfig::new(config_model, config_optimizer);\n\n    B::seed(&device, config.seed);\n\n    // Create the model and optimizer.\n    let model = config.model.init::<B>(&device);\n    let optim = config.optimizer.init();\n\n    // Create the batcher.\n    let batcher = MnistBatcher::default();\n\n    // Create the dataloaders.\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::test());\n\n    // artifact dir does not need to be provided when log_to_file is false\n    let training = SupervisedTraining::new(\"\", dataloader_train, dataloader_test)\n        .num_epochs(config.num_epochs)\n        .renderer(CustomRenderer {})\n        .with_application_logger(None);\n    // can be used to interrupt training\n    let _interrupter = training.interrupter();\n\n    let _model_trained = training.launch(Learner::new(model, optim, config.lr));\n}\n"
  },
  {
    "path": "examples/custom-training-loop/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-training-loop\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[dependencies]\nburn = {path = \"../../crates/burn\", features=[\"autodiff\", \"webgpu\", \"vision\"]}\nguide = {path = \"../guide\"}\n\n# Serialization\nlog = {workspace = true}\nserde = {workspace = true, features = [\"std\", \"derive\"]}\n\n# Wgpu internal dependencies\nderive-new = { workspace = true }\nbytemuck = { workspace = true }\n"
  },
  {
    "path": "examples/custom-training-loop/examples/custom-training-loop.rs",
    "content": "use burn::backend::{Autodiff, WebGpu};\n\nfn main() {\n    custom_training_loop::run::<Autodiff<WebGpu>>(Default::default());\n}\n"
  },
  {
    "path": "examples/custom-training-loop/src/lib.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n    module::AutodiffModule,\n    nn::loss::CrossEntropyLoss,\n    optim::{AdamConfig, GradientsParams, Optimizer},\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n};\nuse guide::{\n    data::{MnistBatch, MnistBatcher},\n    model::{Model, ModelConfig},\n};\n\n#[derive(Config, Debug)]\npub struct MnistTrainingConfig {\n    #[config(default = 10)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1e-4)]\n    pub lr: f64,\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n}\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    // Create the configuration.\n    let config_model = ModelConfig::new(10, 1024);\n    let config_optimizer = AdamConfig::new();\n    let config = MnistTrainingConfig::new(config_model, config_optimizer);\n\n    B::seed(&device, config.seed);\n\n    // Create the model and optimizer.\n    let mut model = config.model.init::<B>(&device);\n    let mut optim = config.optimizer.init();\n\n    // Create the batcher.\n    let batcher = MnistBatcher::default();\n\n    // Create the dataloaders.\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::test());\n\n    // Iterate over our training and validation loop for X epochs.\n    for epoch in 1..config.num_epochs + 1 {\n        // Implement our training loop.\n        for (iteration, batch) in dataloader_train.iter().enumerate() {\n            let output = model.forward(batch.images);\n            let loss = CrossEntropyLoss::new(None, &output.device())\n                .forward(output.clone(), batch.targets.clone());\n            let accuracy = accuracy(output, batch.targets);\n\n            println!(\n                \"[Train - Epoch {} - Iteration {}] Loss {:.3} | Accuracy {:.3} %\",\n                epoch,\n                iteration,\n                loss.clone().into_scalar(),\n                accuracy,\n            );\n\n            // Gradients for the current backward pass\n            let grads = loss.backward();\n            // Gradients linked to each parameter of the model.\n            let grads = GradientsParams::from_grads(grads, &model);\n            // Update the model using the optimizer.\n            model = optim.step(config.lr, model, grads);\n        }\n\n        // Get the model without autodiff.\n        let model_valid = model.valid();\n\n        // Implement our validation loop.\n        for (iteration, batch) in dataloader_test.iter().enumerate() {\n            let output = model_valid.forward(batch.images);\n            let loss = CrossEntropyLoss::new(None, &output.device())\n                .forward(output.clone(), batch.targets.clone());\n            let accuracy = accuracy(output, batch.targets);\n\n            println!(\n                \"[Valid - Epoch {} - Iteration {}] Loss {} | Accuracy {}\",\n                epoch,\n                iteration,\n                loss.clone().into_scalar(),\n                accuracy,\n            );\n        }\n    }\n}\n\n/// Create out own accuracy metric calculation.\nfn accuracy<B: Backend>(output: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> f32 {\n    let predictions = output.argmax(1).squeeze_dim(1);\n    let num_predictions: usize = targets.dims().iter().product();\n    let num_corrects = predictions.equal(targets).int().sum().into_scalar();\n\n    num_corrects.elem::<f32>() / num_predictions as f32 * 100.0\n}\n\n#[allow(dead_code)]\nstruct Learner1<B, O>\nwhere\n    B: AutodiffBackend,\n{\n    model: Model<B>,\n    optim: O,\n}\n\n#[allow(dead_code)]\nstruct Learner2<M, O> {\n    model: M,\n    optim: O,\n}\n\n#[allow(dead_code)]\nstruct Learner3<B, M, O> {\n    model: M,\n    optim: O,\n    _b: PhantomData<B>,\n}\n\n#[allow(dead_code)]\nimpl<B, O> Learner1<B, O>\nwhere\n    B: AutodiffBackend,\n    O: Optimizer<Model<B>, B>,\n{\n    pub fn step1(&mut self, _batch: MnistBatch<B>) {\n        //\n    }\n}\n\n#[allow(dead_code)]\nimpl<B, O> Learner2<Model<B>, O>\nwhere\n    B: AutodiffBackend,\n    O: Optimizer<Model<B>, B>,\n{\n    pub fn step2(&mut self, _batch: MnistBatch<B>) {\n        //\n    }\n}\n\n#[allow(dead_code)]\nimpl<M, O> Learner2<M, O> {\n    pub fn step3<B>(&mut self, _batch: MnistBatch<B>)\n    where\n        B: AutodiffBackend,\n        M: AutodiffModule<B>,\n        O: Optimizer<M, B>,\n    {\n        //\n    }\n}\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"custom-wgpu-kernel\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[dependencies]\nburn = { path = \"../../crates/burn\", default-features = false, features = [\n    \"autodiff\",\n    \"wgpu\",\n    \"autotune\",\n    \"template\",\n] }\ncubecl = { workspace = true, features = [\"wgpu\"] }\n\n# Serialization\nlog = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n\n# Wgpu internal dependencies\nderive-new = { workspace = true }\nbytemuck = { workspace = true }\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs",
    "content": "use burn::{\n    backend::wgpu::WgpuRuntime,\n    tensor::{Distribution, Tensor, Tolerance},\n};\nuse custom_wgpu_kernel::{\n    AutodiffBackend, Backend, matmul_add_relu_custom, matmul_add_relu_reference,\n};\n\nfn inference<B: Backend>(device: &B::Device) {\n    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device);\n    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device);\n    let bias = Tensor::random([32, 32, 32], Distribution::Default, device);\n\n    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone())\n        .into_data()\n        .convert::<f32>();\n    let custom = matmul_add_relu_custom(lhs, rhs, bias)\n        .into_data()\n        .convert::<f32>();\n\n    reference.assert_approx_eq::<f32>(&custom, Tolerance::default());\n\n    println!(\"Both reference and the custom fused kernel have the same output\");\n}\n\nfn autodiff<B: AutodiffBackend>(device: &B::Device) {\n    let lhs = Tensor::<B, 3>::random([1, 32, 32], Distribution::Default, device).require_grad();\n    let rhs = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();\n    let bias = Tensor::random([32, 32, 32], Distribution::Default, device).require_grad();\n\n    let reference = matmul_add_relu_reference(lhs.clone(), rhs.clone(), bias.clone());\n\n    let mut gradients = reference.backward();\n\n    let lhs_grad_ref = lhs.grad_remove(&mut gradients).unwrap();\n    let rhs_grad_ref = rhs.grad_remove(&mut gradients).unwrap();\n    let bias_grad_ref = bias.grad_remove(&mut gradients).unwrap();\n\n    let lhs = lhs.detach();\n    let rhs = rhs.detach();\n    let bias = bias.detach();\n\n    let custom = matmul_add_relu_custom(lhs.clone(), rhs.clone(), bias.clone());\n\n    let mut gradients = custom.backward();\n\n    let lhs_grad_custom = lhs.grad_remove(&mut gradients).unwrap();\n    let rhs_grad_custom = rhs.grad_remove(&mut gradients).unwrap();\n    let bias_grad_custom = bias.grad_remove(&mut gradients).unwrap();\n\n    lhs_grad_ref\n        .into_data()\n        .convert::<B::FloatElem>()\n        .assert_approx_eq::<f32>(\n            &lhs_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same lhs gradient\");\n\n    rhs_grad_ref\n        .into_data()\n        .convert::<f32>()\n        .assert_approx_eq::<f32>(\n            &rhs_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same rhs gradient\");\n\n    bias_grad_ref\n        .into_data()\n        .convert::<f32>()\n        .assert_approx_eq::<f32>(\n            &bias_grad_custom.into_data().convert::<B::FloatElem>(),\n            Tolerance::default(),\n        );\n\n    println!(\"Both reference and the custom fused kernel have the same bias gradient\");\n}\n\nfn main() {\n    type MyBackend = burn::backend::wgpu::CubeBackend<WgpuRuntime, f32, i32, u32>;\n    type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;\n    let device = Default::default();\n    inference::<MyBackend>(&device);\n    autodiff::<MyAutodiffBackend>(&device);\n}\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/src/backward.rs",
    "content": "use crate::FloatTensor;\n\nuse super::{AutodiffBackend, Backend};\nuse burn::{\n    backend::{\n        autodiff::{\n            Autodiff, NodeId,\n            checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},\n            grads::Gradients,\n            ops::{Backward, Ops, OpsKind, broadcast_shape},\n        },\n        wgpu::{BoolElement, CubeBackend, FloatElement, IntElement, WgpuRuntime},\n    },\n    tensor::{Shape, TensorMetadata},\n};\n\nimpl<F: FloatElement, I: IntElement, BT: BoolElement> AutodiffBackend\n    for Autodiff<CubeBackend<WgpuRuntime, F, I, BT>>\n{\n}\n\n// Implement our custom backend trait for any backend that also implements our custom backend trait.\n//\n// Note that we could implement the backend trait only for the Wgpu backend instead of any backend that\n// also implements our own API. This would allow us to call any function only implemented for Wgpu\n// and potentially call a custom kernel crafted only for this task.\nimpl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Create our zero-sized type that will implement the Backward trait.\n        #[derive(Debug)]\n        struct FusedMatmulAddReluBackward;\n\n        // Implement the backward trait for the given backend B, the node gradient\n        // with three other gradients to calculate (lhs, rhs, and bias).\n        impl<B: Backend> Backward<B, 3> for FusedMatmulAddReluBackward {\n            // Our state that we must build during the forward pass to compute the backward pass.\n            //\n            // Note that we could improve the performance further by only keeping the state of\n            // tensors that are tracked, improving memory management, but for simplicity, we avoid\n            // that part.\n            type State = (NodeId, NodeId, FloatTensor<B>, Shape);\n\n            fn backward(\n                self,\n                ops: Ops<Self::State, 3>,\n                grads: &mut Gradients,\n                checkpointer: &mut Checkpointer,\n            ) {\n                // Get the nodes of each variable.\n                let [node_lhs, node_rhs, node_bias] = ops.parents;\n                // Fetch the gradient for the current node.\n                let grad = grads.consume::<B>(&ops.node);\n\n                // Set our state.\n                let (lhs_state, rhs_state, output, shape_bias) = ops.state;\n                let lhs: FloatTensor<B> = checkpointer.retrieve_node_output(lhs_state);\n                let rhs: FloatTensor<B> = checkpointer.retrieve_node_output(rhs_state);\n\n                // Fetch shapes of our tensor to support broadcasting.\n                let shape_lhs = lhs.shape();\n                let shape_rhs = rhs.shape();\n\n                // Compute the gradient of the output using the already existing `relu_backward`\n                // function in the basic Burn backend trait.\n                let grad_output = B::relu_backward(output, grad);\n\n                // Compute the lhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_lhs = broadcast_shape::<B>(\n                    B::float_matmul(grad_output.clone(), B::float_transpose(rhs)),\n                    &shape_lhs,\n                );\n                // Compute the rhs gradient, which is the derivative of matmul with support for\n                // broadcasting.\n                let grad_rhs = broadcast_shape::<B>(\n                    B::float_matmul(B::float_transpose(lhs), grad_output.clone()),\n                    &shape_rhs,\n                );\n                // The add derivative is only 1, so we just need to support broadcasting to\n                // compute the bias gradient.\n                let grad_bias = broadcast_shape::<B>(grad_output, &shape_bias);\n\n                // Register the gradient for each variable based on whether they are marked as\n                // `tracked`.\n                if let Some(node) = node_bias {\n                    grads.register::<B>(node.id, grad_bias);\n                }\n                if let Some(node) = node_lhs {\n                    grads.register::<B>(node.id, grad_lhs);\n                }\n                if let Some(node) = node_rhs {\n                    grads.register::<B>(node.id, grad_rhs);\n                }\n            }\n        }\n\n        // Prepare a stateful operation with each variable node and corresponding graph.\n        //\n        // Each node can be fetched with `ops.parents` in the same order as defined here.\n        match FusedMatmulAddReluBackward\n            .prepare::<C>([lhs.node.clone(), rhs.node.clone(), bias.node.clone()])\n            // Marks the operation as compute bound, meaning it will save its\n            // state instead of recomputing itself during checkpointing\n            .compute_bound()\n            .stateful()\n        {\n            OpsKind::Tracked(mut prep) => {\n                // When at least one node is tracked, we should register our backward step.\n\n                // The state consists of what will be needed for this operation's backward pass.\n                // Since we need the parents' outputs, we must checkpoint their ids to retrieve their node\n                // output at the beginning of the backward. We can also save utility data such as the bias shape\n                // If we also need this operation's output, we can either save it in the state or recompute it\n                // during the backward pass. Here we choose to save it in the state because it's a compute bound operation.\n                let lhs_state = prep.checkpoint(&lhs);\n                let rhs_state = prep.checkpoint(&rhs);\n                let bias_shape = bias.primitive.shape();\n\n                let output = B::fused_matmul_add_relu(\n                    lhs.primitive.clone(),\n                    rhs.primitive.clone(),\n                    bias.primitive,\n                );\n\n                let state = (lhs_state, rhs_state, output.clone(), bias_shape);\n\n                prep.finish(state, output)\n            }\n            OpsKind::UnTracked(prep) => {\n                // When no node is tracked, we can just compute the original operation without\n                // keeping any state.\n                let output = B::fused_matmul_add_relu(lhs.primitive, rhs.primitive, bias.primitive);\n                prep.finish(output)\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/src/forward.rs",
    "content": "use crate::FloatTensor;\n\nuse super::Backend;\nuse burn::{\n    backend::wgpu::{\n        BoolElement, CubeBackend, CubeTensor, FloatElement, IntElement, KernelSource, SourceKernel,\n        SourceTemplate, WgpuRuntime, build_info, into_contiguous, kernel_source,\n    },\n    tensor::Shape,\n};\nuse cubecl::{CubeCount, CubeDim, prelude::KernelId, server::KernelArguments};\nuse derive_new::new;\nuse std::marker::PhantomData;\n\n// Source the kernel written in WGSL.\nkernel_source!(FusedMatmulAddReluRaw, \"./kernel.wgsl\");\n\n// Define our kernel type with cube information.\n#[derive(new, Debug)]\nstruct FusedMatmulAddRelu<E: FloatElement> {\n    cube_dim: CubeDim,\n    _elem: PhantomData<E>,\n}\n\n// Implement the dynamic kernel trait for our kernel type.\nimpl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {\n    fn source(&self) -> SourceTemplate {\n        // Extend our raw kernel with cube size information using the\n        // `SourceTemplate` trait.\n        FusedMatmulAddReluRaw::new()\n            .source()\n            .register(\"workgroup_size_x\", self.cube_dim.x.to_string())\n            .register(\"workgroup_size_y\", self.cube_dim.y.to_string())\n            .register(\"elem\", E::type_name())\n            .register(\"int\", \"i32\")\n    }\n\n    fn id(&self) -> KernelId {\n        KernelId::new::<Self>().info(self.cube_dim)\n    }\n}\n\n/// Implement our custom backend trait for the existing backend `WgpuBackend`.\nimpl<F: FloatElement, I: IntElement, BT: BoolElement> Backend\n    for CubeBackend<WgpuRuntime, F, I, BT>\n{\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self> {\n        // Define cube dim, hardcoded for simplicity.\n        let cube_dim = CubeDim { x: 16, y: 16, z: 1 };\n\n        lhs.assert_is_on_same_device(&rhs);\n        lhs.assert_is_on_same_device(&bias);\n\n        // For simplicity, make sure each tensor is continuous.\n        let lhs = into_contiguous(lhs);\n        let rhs = into_contiguous(rhs);\n        let bias = into_contiguous(bias);\n\n        // Get the matmul relevant shapes.\n        let ndims = lhs.meta.shape().num_dims();\n        let num_rows = lhs.meta.shape()[ndims - 2];\n        let num_cols = rhs.meta.shape()[ndims - 1];\n\n        // Compute shape of output, while tracking number of batches.\n        let mut num_batches = 1;\n        let mut shape_out = vec![0; ndims];\n        for i in shape_out.clone().into_iter().take(ndims - 2) {\n            shape_out[i] = usize::max(lhs.meta.shape()[i], rhs.meta.shape()[i]);\n            num_batches *= shape_out[i];\n        }\n        shape_out[ndims - 2] = num_rows;\n        shape_out[ndims - 1] = num_cols;\n        let shape_out = Shape::from(shape_out);\n\n        // Create a buffer for the output tensor.\n        let buffer = lhs\n            .client\n            .empty(shape_out.num_elements() * core::mem::size_of::<F>());\n\n        // Create the output tensor primitive.\n        let output = CubeTensor::new_contiguous(\n            lhs.client.clone(),\n            lhs.device.clone(),\n            shape_out,\n            buffer,\n            F::dtype(),\n        );\n\n        // Create the kernel.\n        let kernel = FusedMatmulAddRelu::<F>::new(cube_dim);\n\n        // Build info buffer with tensor information needed by the kernel, such as shapes and strides.\n        let info = build_info::<_, F>(&[&lhs, &rhs, &output]);\n        let info_handle = lhs.client.create_from_slice(bytemuck::cast_slice(&info));\n\n        // Declare the wgsl workgroup with the number of cubes in x, y and z.\n        let cubes_needed_in_x = f32::ceil(num_rows as f32 / cube_dim.x as f32) as u32;\n        let cubes_needed_in_y = f32::ceil(num_cols as f32 / cube_dim.y as f32) as u32;\n        let cube_count =\n            CubeCount::Static(cubes_needed_in_x, cubes_needed_in_y, num_batches as u32);\n\n        // Execute lazily the kernel with the launch information and the given buffers.\n        lhs.client.launch(\n            Box::new(SourceKernel::new(kernel, cube_dim)),\n            cube_count,\n            KernelArguments::new().with_buffers(vec![\n                lhs.handle.binding(),\n                rhs.handle.binding(),\n                bias.handle.binding(),\n                output.handle.clone().binding(),\n                info_handle.binding(),\n            ]),\n        );\n\n        // Return the output tensor.\n        output\n    }\n}\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/src/kernel.wgsl",
    "content": "@group(0)\n@binding(0)\nvar<storage, read_write> lhs: array<{{ elem }}>;\n\n@group(0)\n@binding(1)\nvar<storage, read_write> rhs: array<{{ elem }}>;\n\n@group(0)\n@binding(2)\nvar<storage, read_write> bias: array<{{ elem }}>;\n\n@group(0)\n@binding(3)\nvar<storage, read_write> output: array<{{ elem }}>;\n\n@group(0)\n@binding(4)\nvar<storage, read_write> info: array<u32>;\n\nconst BLOCK_SIZE = {{ workgroup_size_x }}u;\n\n@compute\n@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)\nfn main(\n    @builtin(global_invocation_id) global_id: vec3<u32>,\n    @builtin(local_invocation_index) local_idx: u32,\n    @builtin(workgroup_id) workgroup_id: vec3<u32>,\n) {\n    // Indices\n    let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);\n    let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);\n    let batch = global_id.z;\n\n    // Basic information\n    let dim = info[0];\n    let n_rows = info[6u * dim - 1u];\n    let n_cols = info[6u * dim];\n    let K = info[5u * dim - 1u];\n\n    // Returns if outside the output dimension\n    if row >= n_rows || col >= n_cols {\n        return;\n    }\n\n    // Calculate the corresponding offsets with support for broadcasting.\n    let offset_output = batch * n_rows * n_cols;\n    var offset_lhs: u32 = 0u;\n    var offset_rhs: u32 = 0u;\n\n    let batch_dims = dim - 2u;\n    for (var b: u32 = 1u; b <= batch_dims; b++) {\n        let stride_lhs = info[b];\n        let stride_rhs = info[b + dim];\n        let stride_output = info[b + 2u * dim];\n        let shape_lhs = info[b + 3u * dim];\n        let shape_rhs = info[b + 4u * dim];\n\n        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;\n        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;\n    }\n\n    // Basic matmul implementation\n    var sum = 0.0;\n    for (var k: u32 = 0u; k < K; k++) {\n        let lhs_index = row * K + k;\n        let rhs_index = k * n_cols + col;\n\n        sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];\n    }\n\n    let output_index = row * n_cols + col;\n    let index = offset_output + output_index;\n\n    output[index] = max(sum + bias[index], 0.0);\n}\n"
  },
  {
    "path": "examples/custom-wgpu-kernel/src/lib.rs",
    "content": "mod backward;\nmod forward;\n\nuse burn::tensor::{Tensor, TensorPrimitive, activation, ops::FloatTensor};\n\n/// We create our own Backend trait that extends the Burn backend trait.\npub trait Backend: burn::tensor::backend::Backend {\n    fn fused_matmul_add_relu(\n        lhs: FloatTensor<Self>,\n        rhs: FloatTensor<Self>,\n        bias: FloatTensor<Self>,\n    ) -> FloatTensor<Self>;\n}\n\n/// We create our own AutodiffBackend trait that extends the Burn autodiff backend trait.\npub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}\n\n/// We define our custom implementation using the added function on our custom backend.\npub fn matmul_add_relu_custom<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let output = B::fused_matmul_add_relu(\n        lhs.into_primitive().tensor(),\n        rhs.into_primitive().tensor(),\n        bias.into_primitive().tensor(),\n    );\n\n    Tensor::from_primitive(TensorPrimitive::Float(output))\n}\n\n/// We define a reference implementation using basic tensor operations.\npub fn matmul_add_relu_reference<B: Backend>(\n    lhs: Tensor<B, 3>,\n    rhs: Tensor<B, 3>,\n    bias: Tensor<B, 3>,\n) -> Tensor<B, 3> {\n    let x = lhs.matmul(rhs) + bias;\n\n    activation::relu(x)\n}\n"
  },
  {
    "path": "examples/dop_timer/Cargo.toml",
    "content": "[package]\nname = \"dop_timer\"\ndescription = \"Distributed operation timer utility.\"\nedition.workspace = true\nlicense.workspace = true\nreadme.workspace = true\nversion.workspace = true\npublish = false\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"cuda\"]\nndarray = [\"burn/ndarray\"]\ncuda = [\"burn/cuda\"]\nwgpu = [\"burn/wgpu\"]\nmetal = [\"burn/metal\"]\n\n[dependencies]\nclap = { workspace = true, features = [\"derive\"] }\nburn = { path = \"../../crates/burn\", version = \"=0.21.0-pre.2\", features = [\n    \"collective\",\n    \"tracing\",\n] }\n\nlog = { workspace = true }\nrand = { workspace = true }\n\ntokio = { workspace = true, features = [\"full\", \"tracing\"] }\ntracing = { workspace = true }\n# todo: tracing-log\ntracing-subscriber = { workspace = true, features = [\"env-filter\", \"json\"] }\ntracing-opentelemetry = { workspace = true }\nopentelemetry = { workspace = true, features = [\"trace\"] }\nopentelemetry_sdk = { workspace = true }\nopentelemetry-aws = { workspace = true }\nopentelemetry-otlp = { workspace = true, features = [\"tracing\", \"grpc-tonic\"] }\nopentelemetry-stdout = \"0.31.0\"\n"
  },
  {
    "path": "examples/dop_timer/README.md",
    "content": "# dop_timer\n\nThis binary exists to time the behavior of distributed (local, global) collective operations.\n\nThis binary uses the `gRPC` OTEL exporter to send traces to an OTEL Collector on port `4317`.\n\n## Example\n\n1. Setup an OTEL Collector\n\nThere are many ways to do this; one of the simplest is to use the `jaegertracing/all-in-one:latest` docker image:\n\n```bash\n$ docker run -e OTEL_TRACES_SAMPLER=always_off -e COLLECTOR_OTLP_ENABLED=true -p 16686:16686 -p 4317-4318:4317-4318 -p 14250:14250 -p 14268:14268 -p 14269:14269 jaegertracing/all-in-one:latest\n```\n\nThen navigate to `localhost:16686` to view traces.\n\n2. Run the binary, with the OTEL Collector endpoint as an argument:\n\n```bash\n$ cargo run -p dop_timer --features cuda -- --tracing otel\n```"
  },
  {
    "path": "examples/dop_timer/src/event_utils.rs",
    "content": "/// Simply instrumented event; to verify event reporting.\n#[tracing::instrument(level = \"trace\")]\npub(crate) fn example_instrumented_event() {\n    tracing::info!(\"test event\");\n\n    let span = tracing::info_span!(\"test_span\");\n    let _guard = span.enter();\n    tracing::info!(\"inside span\");\n}\n"
  },
  {
    "path": "examples/dop_timer/src/main.rs",
    "content": "use crate::event_utils::example_instrumented_event;\nuse crate::workers::WorkerHandle;\nuse burn::Tensor;\nuse burn::collective::{AllReduceStrategy, CollectiveConfig, ReduceOperation};\nuse burn::prelude::{Backend, DeviceOps};\nuse burn::tensor::Shape;\nuse burn::tensor::backend::DeviceId;\nuse clap::{Parser, ValueEnum};\nuse opentelemetry::trace::TracerProvider;\nuse opentelemetry_sdk::Resource;\nuse opentelemetry_sdk::propagation::TraceContextPropagator;\nuse std::error::Error;\nuse std::iter::repeat_with;\nuse std::sync::mpsc::Receiver;\nuse tracing_subscriber::EnvFilter;\nuse tracing_subscriber::fmt::format::FmtSpan;\nuse tracing_subscriber::layer::SubscriberExt;\nuse tracing_subscriber::util::SubscriberInitExt;\n\nmod parsers;\nuse parsers::*;\nmod event_utils;\nmod workers;\n\nstatic APP_NAME: &str = \"dop_timer\";\n\n#[derive(Debug, Clone, ValueEnum)]\npub enum ConsoleFormat {\n    Text,\n    Json,\n}\n\n#[derive(Debug, Clone, ValueEnum)]\npub enum TracingMode {\n    /// Print to stderr.\n    Console,\n\n    /// Export to OTEL via gRPC.\n    Otel,\n}\n\n/// Timing tool for measuring the performance of collective operations.\n///\n/// Currently only supports `all_reduce`.\n#[derive(Parser, Debug)]\npub struct Args {\n    /// Suppress verbose output.\n    #[arg(long, action = clap::ArgAction::Set)]\n    pub quiet: Option<bool>,\n\n    /// Enable tracing.\n    #[arg(long, value_enum)]\n    pub tracing: Option<TracingMode>,\n\n    /// Number of timing runs to perform.\n    #[arg(long, default_value = \"3\")]\n    pub timing_runs: usize,\n\n    /// Output format for console tracing.\n    #[arg(long, value_enum, default_value = \"text\")]\n    pub console_tracing_format: ConsoleFormat,\n\n    // TODO: sub-commands.\n    /// Shape of the tensor to reduce.\n    #[arg(long, value_parser = parse_array4, default_value = \"2, 3, 256, 256\")]\n    pub shape: [usize; 4],\n\n    /// All-reduce strategy.\n    #[arg(long, value_parser = parse_all_reduce_strategy, default_value = \"tree:2\")]\n    pub strategy: AllReduceStrategy,\n\n    /// Number of workers per device.\n    #[arg(long, default_value = \"1\")]\n    pub workers_per_device: usize,\n\n    /// Reduce operation.\n    #[arg(long, value_parser = parse_reduce_operation, default_value = \"sum\")]\n    pub op: ReduceOperation,\n}\n\nimpl Args {\n    pub fn quiet(&self) -> bool {\n        self.quiet.unwrap_or(false)\n    }\n\n    pub fn verbose(&self) -> bool {\n        !self.quiet()\n    }\n}\n\n#[tokio::main(flavor = \"multi_thread\")]\nasync fn main() -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    let args = Args::parse();\n    if args.verbose() {\n        println!(\"{:?}\", args);\n    }\n\n    let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(\"info\"));\n\n    let tracing_provider = match &args.tracing {\n        None => None,\n        Some(TracingMode::Console) => {\n            let subscriber = tracing_subscriber::fmt()\n                .with_env_filter(env_filter)\n                .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT);\n\n            let subscriber = subscriber.with_writer(std::io::stderr);\n\n            match &args.console_tracing_format {\n                ConsoleFormat::Text => subscriber.try_init()?,\n                ConsoleFormat::Json => subscriber.json().try_init()?,\n            };\n\n            None\n        }\n        Some(TracingMode::Otel) => {\n            let exporter = opentelemetry_otlp::SpanExporter::builder()\n                .with_tonic()\n                .build()?;\n\n            let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()\n                .with_batch_exporter(exporter)\n                .with_sampler(opentelemetry_sdk::trace::Sampler::AlwaysOn)\n                .with_resource(Resource::builder().with_service_name(APP_NAME).build())\n                .build();\n\n            opentelemetry::global::set_tracer_provider(provider.clone());\n\n            let tracer = provider.tracer(APP_NAME);\n\n            let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);\n\n            tracing_subscriber::registry()\n                .with(env_filter)\n                .with(telemetry)\n                .try_init()?;\n\n            opentelemetry::global::set_text_map_propagator(TraceContextPropagator::new());\n\n            Some(provider)\n        }\n    };\n\n    example_instrumented_event();\n\n    run(&args)?;\n\n    if let Some(provider) = tracing_provider {\n        if args.verbose() {\n            println!(\"> main: shutting down tracing\");\n        }\n        provider.shutdown()?;\n    }\n\n    Ok(())\n}\n\n#[cfg(feature = \"cuda\")]\nfn run(args: &Args) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    run_backend::<burn::backend::Cuda>(args)\n}\n\n#[cfg(feature = \"metal\")]\nfn run(args: &Args) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    run_backend::<burn::backend::Metal>(args)\n}\n\n#[cfg(feature = \"wgpu\")]\nfn run(args: &Args) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    run_backend::<burn::backend::Wgpu>(args)\n}\n\n#[cfg(feature = \"ndarray\")]\nfn run(args: &Args) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    run_backend::<burn::backend::ndarray>(args)\n}\n\n#[tracing::instrument(level = \"trace\", skip(args))]\nfn run_backend<B: Backend>(args: &Args) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {\n    let type_id = 0;\n    let device_count = B::Device::device_count(type_id);\n\n    let devices = (0..device_count)\n        .map(|idx| B::Device::from_id(DeviceId::new(type_id, idx as u32)))\n        .collect::<Vec<_>>();\n\n    // Duplicate the devices to force a heterogeneous setup.\n    let devices = devices\n        .iter()\n        .flat_map(|x| repeat_with(|| x.clone()).take(args.workers_per_device))\n        .collect::<Vec<_>>();\n\n    let config = CollectiveConfig::default()\n        .with_num_devices(devices.len())\n        .with_local_all_reduce_strategy(args.strategy);\n    if args.verbose() {\n        println!(\"> run: config:\\n{:#?}\", config);\n    }\n\n    let handles: Vec<WorkerHandle<B>> = devices\n        .iter()\n        .enumerate()\n        .map(|(idx, device)| WorkerHandle::new(idx, device, config.clone()))\n        .collect();\n\n    if args.verbose() {\n        println!(\"> run: registering workers\");\n    }\n    handles\n        .iter()\n        .map(|h| h.register())\n        // Introduce a sequence point (wait for all workers to start)\n        .collect::<Vec<_>>()\n        // Wait on all results.\n        .into_iter()\n        .try_for_each(|rx: Receiver<()>| rx.recv())?;\n\n    let shape: Shape = args.shape.into();\n\n    let expected_cell: f32 = {\n        let count = handles.len() as f32;\n        let sum = (count * (count + 1.0)) / 2.0;\n        if args.op == ReduceOperation::Mean {\n            sum / count\n        } else {\n            sum\n        }\n    };\n    let expected =\n        Tensor::<B, 4>::full(shape.clone(), expected_cell, &B::Device::default()).to_data();\n\n    for r in 0..args.timing_runs {\n        println!(\"Run {}/{}\", r + 1, args.timing_runs);\n        if args.verbose() {\n            println!(\"> run: setting up device tensors: {:?}\", shape);\n        }\n\n        let tensors = handles\n            .iter()\n            .enumerate()\n            .map(|(idx, h)| {\n                let full_value = idx as f32 + 1.0;\n                Tensor::<B, 4>::full(shape.clone(), full_value, h.device())\n            })\n            .collect::<Vec<_>>();\n\n        if args.verbose() {\n            println!(\"> run: running all_reduce\");\n        }\n        let reduced: Vec<Tensor<B, 4>> = handles\n            .iter()\n            .zip(tensors.into_iter())\n            .map(|(h, t)| h.all_reduce(args.op, t))\n            // Introduce a sequence point.\n            .collect::<Vec<_>>()\n            // Wait on all results.\n            .into_iter()\n            .map(|rx| rx.recv())\n            .collect::<Result<Vec<_>, _>>()?;\n\n        if args.verbose() {\n            println!(\"> run: verifying result\");\n        }\n        for t in reduced {\n            t.into_data().assert_eq(&expected, true);\n        }\n    }\n\n    Ok(())\n}\n"
  },
  {
    "path": "examples/dop_timer/src/parsers.rs",
    "content": "use burn::collective::{AllReduceStrategy, ReduceOperation};\nuse std::num::ParseIntError;\n\npub(crate) fn parse_array4(s: &str) -> Result<[usize; 4], String> {\n    let parts: Result<Vec<_>, _> = s.split(',').map(|p| p.trim().parse()).collect();\n    let parts = parts.map_err(|e: ParseIntError| e.to_string())?;\n    parts\n        .try_into()\n        .map_err(|v: Vec<_>| format!(\"expected 4 values, got {}\", v.len()))\n}\n\npub(crate) fn parse_all_reduce_strategy(s: &str) -> Result<AllReduceStrategy, String> {\n    let s = s.trim();\n    if let Some(depth) = s.strip_prefix(\"tree:\") {\n        let depth = depth.parse::<usize>().map_err(|e| e.to_string())?;\n        Ok(AllReduceStrategy::Tree(depth as u32))\n    } else if s.eq(\"centralized\") {\n        Ok(AllReduceStrategy::Centralized)\n    } else if s.eq(\"ring\") {\n        Ok(AllReduceStrategy::Ring)\n    } else {\n        Err(format!(\"unknown strategy: {}\", s))\n    }\n}\n\npub(crate) fn parse_reduce_operation(s: &str) -> Result<ReduceOperation, String> {\n    let s = s.trim();\n    if s.eq(\"sum\") {\n        Ok(ReduceOperation::Sum)\n    } else if s.eq(\"mean\") {\n        Ok(ReduceOperation::Mean)\n    } else {\n        Err(format!(\"unknown reduce operation: {}\", s))\n    }\n}\n"
  },
  {
    "path": "examples/dop_timer/src/workers.rs",
    "content": "use crate::workers::WorkRequest::{AllReduceRequest, RegisterRequest};\nuse burn::Tensor;\nuse burn::collective::{CollectiveConfig, PeerId, ReduceOperation, all_reduce};\nuse burn::prelude::Backend;\nuse burn::tensor::TensorPrimitive;\nuse std::sync::mpsc::Receiver;\n\npub enum WorkRequest<B: Backend> {\n    RegisterRequest {\n        tx: std::sync::mpsc::SyncSender<()>,\n    },\n    AllReduceRequest {\n        tensor: Tensor<B, 4>,\n        op: ReduceOperation,\n        tx: std::sync::mpsc::SyncSender<Tensor<B, 4>>,\n    },\n}\n\nstruct Worker<B: Backend> {\n    index: usize,\n    id: PeerId,\n    device: B::Device,\n    config: CollectiveConfig,\n}\n\nimpl<B: Backend> Worker<B> {\n    pub fn new(index: usize, device: B::Device, config: CollectiveConfig) -> Self {\n        let device = device.clone();\n        let id = index.into();\n        Self {\n            index,\n            id,\n            device,\n            config,\n        }\n    }\n\n    #[tracing::instrument(level = \"trace\", skip(self, tensor))]\n    pub fn dispatch_all_reduce<const R: usize>(\n        &mut self,\n        tensor: Tensor<B, R>,\n        op: ReduceOperation,\n    ) -> Tensor<B, R> {\n        log::debug!(\"w={}: dispatch_all_reduce start\", self.index);\n        let tensor = Tensor::from_primitive(TensorPrimitive::Float(\n            all_reduce::<B>(self.id, tensor.into_primitive().tensor(), op).unwrap(),\n        ));\n        log::debug!(\"w={}: dispatch_all_reduce end\", self.index);\n        tensor\n    }\n\n    pub fn run(&mut self, rx: Receiver<WorkRequest<B>>) {\n        println!(\"worker {} started\", self.index);\n        while let Ok(command) = rx.recv() {\n            use burn::collective::register;\n            match command {\n                RegisterRequest { tx } => {\n                    register::<B>(self.id, self.device.clone(), self.config.clone()).unwrap();\n                    tx.send(()).unwrap();\n                }\n                AllReduceRequest { tensor, op, tx } => {\n                    assert_eq!(&tensor.device(), &self.device);\n                    let tensor = self.dispatch_all_reduce(tensor, op);\n                    tx.send(tensor).unwrap();\n                }\n            }\n        }\n    }\n}\n\npub struct WorkerHandle<B: Backend> {\n    device: B::Device,\n    tx: std::sync::mpsc::SyncSender<WorkRequest<B>>,\n    phantom: std::marker::PhantomData<B>,\n}\n\nimpl<B: Backend> WorkerHandle<B> {\n    #[tracing::instrument(level = \"trace\", skip(config))]\n    pub fn new(index: usize, device: &B::Device, config: CollectiveConfig) -> Self {\n        let mut worker: Worker<B> = Worker::new(index, device.clone(), config.clone());\n\n        let (tx, rx) = std::sync::mpsc::sync_channel(1);\n        std::thread::spawn(move || worker.run(rx));\n        Self {\n            device: device.clone(),\n            tx,\n            phantom: Default::default(),\n        }\n    }\n\n    pub fn register(&self) -> Receiver<()> {\n        let (tx, rx) = std::sync::mpsc::sync_channel(1);\n        self.tx.send(WorkRequest::RegisterRequest { tx }).unwrap();\n        rx\n    }\n\n    pub fn device(&self) -> &B::Device {\n        &self.device\n    }\n\n    pub fn all_reduce(&self, op: ReduceOperation, tensor: Tensor<B, 4>) -> Receiver<Tensor<B, 4>> {\n        let (tx, rx) = std::sync::mpsc::sync_channel(1);\n        self.tx\n            .send(WorkRequest::AllReduceRequest { tensor, op, tx })\n            .unwrap();\n        rx\n    }\n}\n"
  },
  {
    "path": "examples/dqn-agent/Cargo.toml",
    "content": "[package]\nname = \"dqn-agent\"\nedition.workspace = true\nlicense.workspace = true\nreadme.workspace = true\nversion.workspace = true\n\n[features]\ndefault = [\"burn/tui\"]\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"burn/ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"burn/ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"burn/ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nremote = [\"burn/remote\"]\nwgpu = [\"burn/wgpu\", \"burn/default\"]\nmetal = [\"burn/metal\", \"burn/default\"]\ncuda = [\"burn/cuda\"]\nvulkan = [\"burn/vulkan\", \"burn/default\"]\nrocm = [\"burn/rocm\", \"burn/default\"]\n\n[dependencies]\n# Disable autotune default for now.\nburn = { path = \"../../crates/burn\", features = [\n    \"train\",\n    \"metrics\",\n    \"std\",\n    \"rl\",\n    # \"fusion\",\n    \"ndarray\",\n    # \"autotune\",\n], default-features = false }\n# Just for this example.\ngym-rs = { version = \"0.3.1\", branch = \"main\", git = \"https://github.com/MathisWellmann/gym-rs\" }\nrand.workspace = true\nderive-new = { workspace = true }\n\n[lints]\nworkspace = true\n"
  },
  {
    "path": "examples/dqn-agent/examples/dqn-agent.rs",
    "content": "#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::{\n        Autodiff,\n        ndarray::{NdArray, NdArrayDevice},\n    };\n    use dqn_agent::training;\n\n    pub fn run() {\n        let device = NdArrayDevice::Cpu;\n        training::run::<Autodiff<NdArray>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n    use dqn_agent::training;\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        training::run::<Autodiff<LibTorch>>(device);\n    }\n}\n\n#[cfg(any(feature = \"wgpu\", feature = \"metal\", feature = \"vulkan\"))]\nmod wgpu {\n    use burn::backend::{\n        Autodiff,\n        wgpu::{Wgpu, WgpuDevice},\n    };\n    use dqn_agent::training;\n\n    pub fn run() {\n        let device = WgpuDevice::default();\n        training::run::<Autodiff<Wgpu>>(device);\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use burn::backend::{Autodiff, Cuda};\n    use dqn_agent::training;\n\n    pub fn run() {\n        let device = Default::default();\n        training::run::<Autodiff<Cuda>>(device);\n    }\n}\n\n#[cfg(feature = \"rocm\")]\nmod rocm {\n    use burn::backend::{Autodiff, Rocm};\n    use dqn_agent::training;\n\n    pub fn run() {\n        let device = Default::default();\n        training::run::<Autodiff<Rocm>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n    use dqn_agent::training;\n\n    pub fn run() {\n        let device = LibTorchDevice::Cpu;\n        training::run::<Autodiff<LibTorch>>(device);\n    }\n}\n\n#[cfg(feature = \"remote\")]\nmod remote {\n    use burn::backend::{Autodiff, RemoteBackend};\n    use dqn_agent::training;\n\n    pub fn run() {\n        training::run::<Autodiff<RemoteBackend>>(Default::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(any(feature = \"wgpu\", feature = \"metal\", feature = \"vulkan\"))]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n    #[cfg(feature = \"rocm\")]\n    rocm::run();\n    #[cfg(feature = \"remote\")]\n    remote::run();\n}\n"
  },
  {
    "path": "examples/dqn-agent/src/agent.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn::backend::NdArray;\nuse burn::module::Module;\nuse burn::record::Record;\nuse burn::rl::{\n    Batchable, LearnerTransitionBatch, Policy, PolicyLearner, PolicyState, RLTrainOutput,\n    SliceAccess,\n};\nuse burn::tensor::{Int, Transaction};\nuse burn::tensor::activation::softmax;\nuse burn::train::ItemLazy;\nuse burn::train::metric::{Adaptor, LossInput};\nuse burn::{\n    Tensor,\n    config::Config,\n    module::AutodiffModule,\n    nn::{self, loss::MseLoss},\n    optim::{GradientsParams, Optimizer},\n    prelude::Backend,\n    tensor::backend::AutodiffBackend,\n};\nuse rand::distr::Distribution;\nuse rand::distr::weighted::WeightedIndex;\nuse rand::rng;\n\nuse crate::utils::{\n    EpsilonGreedyPolicy, EpsilonGreedyPolicyState, create_lin_layers, soft_update_linear,\n};\n\npub trait DiscreteActionModel<B: Backend>: Module<B> {\n    type Input: Clone + Send + Batchable;\n\n    fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2>;\n}\n\n#[derive(Config, Debug)]\npub struct MlpNetConfig {\n    /// The number of layers.\n    #[config(default = 3)]\n    pub num_layers: usize,\n    /// The dropout rate.\n    #[config(default = 0.)]\n    pub dropout: f64,\n    /// The input dimension.\n    #[config(default = 4)]\n    pub d_input: usize,\n    /// The output dimension.\n    #[config(default = 2)]\n    pub d_output: usize,\n    /// The size of hidden layers.\n    #[config(default = 256)]\n    pub d_hidden: usize,\n}\n\n/// Multilayer Perceptron Network.\n#[derive(Module, Debug)]\npub struct MlpNet<B: Backend> {\n    pub linears: Vec<nn::Linear<B>>,\n    pub dropout: nn::Dropout,\n    pub activation: nn::Relu,\n}\n\nimpl<B: Backend> MlpNet<B> {\n    /// Create the module from the given configuration.\n    pub fn new(config: &MlpNetConfig, device: &B::Device) -> Self {\n        Self {\n            linears: create_lin_layers(\n                config.num_layers,\n                config.d_input,\n                config.d_hidden,\n                config.d_output,\n                device,\n            ),\n            dropout: nn::DropoutConfig::new(config.dropout).init(),\n            activation: nn::Relu::new(),\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct ObservationTensor<B: Backend, const D: usize> {\n    pub state: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> Batchable for ObservationTensor<B, D> {\n    fn batch(value: Vec<Self>) -> Self {\n        let tensors = value.iter().map(|v| v.state.clone()).collect();\n        Self {\n            state: Tensor::cat(tensors, 0),\n        }\n    }\n\n    fn unbatch(self) -> Vec<Self> {\n        self.state\n            .split(1, 0)\n            .iter()\n            .map(|s| ObservationTensor { state: s.clone() })\n            .collect()\n    }\n}\n\nimpl<B: Backend> SliceAccess<B> for ObservationTensor<B, 2> {\n    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {\n        let feature_dim = sample.state.dims()[1];\n        Self {\n            state: Tensor::zeros([capacity, feature_dim], device),\n        }\n    }\n\n    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {\n        Self {\n            state: Tensor::select(self.state, dim, indices),\n        }\n    }\n\n    fn slice_assign_inplace(&mut self, index: usize, value: Self) {\n        self.state\n            .inplace(|t| t.slice_assign(index..index + 1, value.state));\n    }\n}\n\nimpl<B: Backend> DiscreteActionModel<B> for MlpNet<B> {\n    type Input = ObservationTensor<B, 2>;\n\n    /// Applies the forward pass on the input tensor.\n    ///\n    /// # Shapes\n    ///\n    /// - input: `[batch_size, d_input]`\n    /// - output: `[batch_size, d_output]`\n    fn forward(&self, input: Self::Input) -> DiscreteLogitsTensor<B, 2> {\n        let mut x = input.state;\n\n        for (i, linear) in self.linears.iter().enumerate() {\n            x = linear.forward(x);\n            x = self.dropout.forward(x);\n            if i < self.linears.len() - 1 {\n                x = self.activation.forward(x);\n            }\n        }\n\n        DiscreteLogitsTensor { logits: x }\n    }\n}\n\n#[derive(Config, Debug)]\npub struct DqnAgentConfig {\n    /// Discount factor (How to value long-term vs short-term rewards)\n    #[config(default = 0.99)]\n    pub gamma: f64,\n    /// The learning rate\n    #[config(default = 3e-4)]\n    pub learning_rate: f64,\n    /// The soft update rate of the target network\n    #[config(default = 0.005)]\n    pub tau: f64,\n    /// Initial value of epsilon (Probability to choose a random action)\n    #[config(default = 0.9)]\n    pub epsilon_start: f64,\n    /// Final value of epsilon (Probability to choose a random action)\n    #[config(default = 0.01)]\n    pub epsilon_end: f64,\n    /// The exponential rate at which the epsilon value decays. Higher = slower decay\n    #[config(default = 2500.0)]\n    pub epsilon_decay: f64,\n}\n\npub trait TargetModel<B: Backend> {\n    fn soft_update(&self, that: &Self, tau: f64) -> Self;\n}\n\nimpl<B: Backend> TargetModel<B> for MlpNet<B> {\n    fn soft_update(&self, that: &Self, tau: f64) -> Self {\n        let mut linears = Vec::with_capacity(self.linears.len());\n        for i in 0..self.linears.len() {\n            let layer = soft_update_linear(self.linears[i].clone(), &that.linears[i].clone(), tau);\n            linears.insert(i, layer);\n        }\n        Self {\n            linears,\n            dropout: self.dropout.clone(),\n            activation: self.activation.clone(),\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct DqnState<B: Backend, M: DiscreteActionModel<B>> {\n    model: M,\n    _backend: PhantomData<B>,\n}\n\nimpl<B: Backend, M: DiscreteActionModel<B>> PolicyState<B> for DqnState<B, M> {\n    type Record = M::Record;\n\n    fn into_record(self) -> Self::Record {\n        self.model.clone().into_record()\n    }\n\n    fn load_record(&self, record: Self::Record) -> Self {\n        Self {\n            model: self.model.clone().load_record(record),\n            _backend: PhantomData,\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct DQN<B: Backend, M: DiscreteActionModel<B>> {\n    model: M,\n    _backend: PhantomData<B>,\n}\n\nimpl<B: Backend, M: DiscreteActionModel<B>> DQN<B, M> {\n    pub fn new(policy: M) -> Self {\n        Self {\n            model: policy,\n            _backend: PhantomData,\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct DiscreteLogitsTensor<B: Backend, const D: usize> {\n    pub logits: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> Batchable for DiscreteLogitsTensor<B, D> {\n    fn batch(value: Vec<Self>) -> Self {\n        let tensors = value.iter().map(|v| v.logits.clone()).collect();\n        Self {\n            logits: Tensor::cat(tensors, 0),\n        }\n    }\n\n    fn unbatch(self) -> Vec<Self> {\n        self.logits\n            .split(1, 0)\n            .iter()\n            .map(|l| DiscreteLogitsTensor { logits: l.clone() })\n            .collect()\n    }\n}\n\n#[derive(Clone)]\npub struct DiscreteActionTensor<B: Backend, const D: usize> {\n    pub actions: Tensor<B, D>,\n}\n\nimpl<B: Backend, const D: usize> Batchable for DiscreteActionTensor<B, D> {\n    fn batch(value: Vec<Self>) -> Self {\n        let tensors = value.iter().map(|v| v.actions.clone()).collect();\n        Self {\n            actions: Tensor::cat(tensors, 0),\n        }\n    }\n\n    fn unbatch(self) -> Vec<Self> {\n        self.actions\n            .split(1, 0)\n            .iter()\n            .map(|a| DiscreteActionTensor { actions: a.clone() })\n            .collect()\n    }\n}\n\nimpl<B: Backend> SliceAccess<B> for DiscreteActionTensor<B, 2> {\n    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {\n        let feature_dim = sample.actions.dims()[1];\n        Self {\n            actions: Tensor::zeros([capacity, feature_dim], device),\n        }\n    }\n\n    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {\n        Self {\n            actions: Tensor::select(self.actions, dim, indices),\n        }\n    }\n\n    fn slice_assign_inplace(&mut self, index: usize, value: Self) {\n        self.actions\n            .inplace(|t| t.slice_assign(index..index + 1, value.actions));\n    }\n}\n\nimpl<B: Backend, M: DiscreteActionModel<B>> Policy<B> for DQN<B, M> {\n    type Observation = M::Input;\n    type ActionDistribution = DiscreteLogitsTensor<B, 2>;\n    type Action = DiscreteActionTensor<B, 2>;\n\n    type ActionContext = ();\n    type PolicyState = DqnState<B, M>;\n\n    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {\n        self.model.forward(states)\n    }\n\n    fn action(\n        &mut self,\n        states: Self::Observation,\n        deterministic: bool,\n    ) -> (Self::Action, Vec<Self::ActionContext>) {\n        let logits = self.forward(states).logits;\n        if deterministic {\n            let output = DiscreteActionTensor {\n                actions: logits.argmax(1).float(),\n            };\n            return (output, vec![]);\n        }\n\n        let mut actions = vec![];\n        let probs = softmax(logits, 1);\n        let probs = probs.split(1, 0);\n        let mut rng = rng();\n        for p in probs {\n            let dist = WeightedIndex::new(p.to_data().to_vec::<f32>().unwrap()).unwrap();\n            let action = dist.sample(&mut rng);\n            actions.push(Tensor::<B, 1>::from_floats([action], &p.device()));\n        }\n\n        let output = DiscreteActionTensor {\n            actions: Tensor::stack(actions, 1),\n        };\n        (output, vec![])\n    }\n\n    fn update(&mut self, update: Self::PolicyState) {\n        self.model = update.model;\n    }\n\n    fn state(&self) -> Self::PolicyState {\n        DqnState {\n            model: self.model.clone(),\n            _backend: PhantomData,\n        }\n    }\n\n    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {\n        let state = self.state().load_record(record);\n        Self {\n            model: state.model,\n            _backend: PhantomData,\n        }\n    }\n}\n\n#[derive(Record)]\npub struct DqnLearningRecord<B: AutodiffBackend, M: AutodiffModule<B>, O: Optimizer<M, B>> {\n    policy_model: M::Record,\n    target_model: M::Record,\n    optimizer: O::Record,\n}\n\n#[derive(Clone)]\npub struct DqnLearningAgent<B, M, O>\nwhere\n    B: AutodiffBackend,\n    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,\n    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,\n    O: Optimizer<M, B> + 'static,\n{\n    policy_model: M,\n    target_model: M,\n    agent: EpsilonGreedyPolicy<B, DQN<B, M>>,\n    optimizer: O,\n    config: DqnAgentConfig,\n}\n\nimpl<B, M, O> DqnLearningAgent<B, M, O>\nwhere\n    B: AutodiffBackend,\n    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,\n    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,\n    O: Optimizer<M, B> + 'static,\n{\n    pub fn new(model: M, optimizer: O, config: DqnAgentConfig) -> Self {\n        let agent = EpsilonGreedyPolicy::new(\n            DQN::new(model.clone()),\n            config.epsilon_start,\n            config.epsilon_end,\n            config.epsilon_decay,\n        );\n        Self {\n            policy_model: model.clone(),\n            target_model: model,\n            agent,\n            optimizer,\n            config,\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct SimpleTrainOutput<B: Backend> {\n    pub policy_model_loss: Tensor<B, 1>,\n}\n\nimpl<B: Backend> ItemLazy for SimpleTrainOutput<B> {\n    type ItemSync = SimpleTrainOutput<NdArray>;\n\n    fn sync(self) -> Self::ItemSync {\n        let [loss] = Transaction::default()\n            .register(self.policy_model_loss)\n            .execute()\n            .try_into()\n            .expect(\"Correct amount of tensor data\");\n\n        let device = &Default::default();\n\n        SimpleTrainOutput {\n            policy_model_loss: Tensor::from_data(loss, device),\n        }\n    }\n}\n\nimpl<B: Backend> Adaptor<LossInput<B>> for SimpleTrainOutput<B> {\n    fn adapt(&self) -> LossInput<B> {\n        LossInput::new(self.policy_model_loss.clone())\n    }\n}\n\nimpl<B, M, O> PolicyLearner<B> for DqnLearningAgent<B, M, O>\nwhere\n    B: AutodiffBackend,\n    M: DiscreteActionModel<B> + AutodiffModule<B> + TargetModel<B> + 'static,\n    M::Input: Clone,\n    M::InnerModule: DiscreteActionModel<B::InnerBackend> + TargetModel<B::InnerBackend>,\n    O: Optimizer<M, B> + 'static,\n{\n    type TrainContext = SimpleTrainOutput<B>;\n    type InnerPolicy = EpsilonGreedyPolicy<B, DQN<B, M>>;\n    type Record = DqnLearningRecord<B, M, O>;\n\n    fn train(\n        &mut self,\n        input: LearnerTransitionBatch<B, Self::InnerPolicy>,\n    ) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState> {\n        let states_batch = input.states;\n        let next_states_batch = input.next_states;\n        let actions_batch = input.actions.actions;\n        let rewards_batch = input.rewards;\n        let dones_batch = input.dones;\n\n        // Optimize\n        let logits = self.policy_model.forward(states_batch).logits;\n        let state_action_values = logits.gather(1, actions_batch.int());\n\n        let next_state_values = self.target_model.forward(next_states_batch.clone());\n        let next_state_values = next_state_values.logits.max_dim(1).squeeze::<1>();\n\n        let not_done_batch = Tensor::ones_like(&dones_batch) - dones_batch;\n        let expected_state_action_values = (next_state_values * not_done_batch.squeeze())\n            .mul_scalar(self.config.gamma)\n            + rewards_batch.squeeze();\n        let expected_state_action_values = expected_state_action_values.unsqueeze_dim::<2>(1);\n\n        let loss = MseLoss::new().forward(\n            state_action_values,\n            expected_state_action_values,\n            nn::loss::Reduction::Mean,\n        );\n        let gradients = loss.backward();\n        let gradient_params = GradientsParams::from_grads(gradients, &self.policy_model);\n        self.policy_model = self.optimizer.step(\n            self.config.learning_rate,\n            self.policy_model.clone(),\n            gradient_params,\n        );\n        self.target_model = self\n            .target_model\n            .soft_update(&self.policy_model, self.config.tau);\n        let policy_update = EpsilonGreedyPolicyState::new(\n            DqnState {\n                model: self.policy_model.clone(),\n                _backend: PhantomData,\n            },\n            self.agent.state().step,\n        );\n        self.agent.update(policy_update.clone());\n        RLTrainOutput {\n            policy: policy_update,\n            item: SimpleTrainOutput {\n                policy_model_loss: loss,\n            },\n        }\n    }\n\n    fn policy(&self) -> Self::InnerPolicy {\n        self.agent.clone()\n    }\n\n    fn update_policy(&mut self, update: Self::InnerPolicy) {\n        self.agent = update;\n    }\n\n    fn record(&self) -> Self::Record {\n        DqnLearningRecord {\n            policy_model: self.policy_model.clone().into_record(),\n            target_model: self.target_model.clone().into_record(),\n            optimizer: self.optimizer.to_record(),\n        }\n    }\n\n    fn load_record(self, record: Self::Record) -> Self {\n        let policy_model = self.policy_model.load_record(record.policy_model);\n        let target_model = self.target_model.load_record(record.target_model);\n        let optimizer = self.optimizer.load_record(record.optimizer);\n        Self {\n            policy_model,\n            target_model,\n            agent: self.agent,\n            optimizer,\n            config: self.config,\n        }\n    }\n}\n"
  },
  {
    "path": "examples/dqn-agent/src/env.rs",
    "content": "use burn::rl::{Environment, StepResult};\nuse burn::{\n    Tensor,\n    prelude::{Backend, ToElement},\n};\nuse gym_rs::{\n    core::Env,\n    envs::classical_control::cartpole::{CartPoleEnv, CartPoleObservation},\n};\n\nuse crate::agent::{DiscreteActionTensor, ObservationTensor};\n\n#[derive(Clone)]\npub struct CartPoleAction {\n    action: usize,\n}\n\nimpl<B: Backend> From<DiscreteActionTensor<B, 2>> for CartPoleAction {\n    fn from(value: DiscreteActionTensor<B, 2>) -> Self {\n        Self {\n            action: value.actions.int().into_scalar().to_usize(),\n        }\n    }\n}\n\nimpl<B: Backend> From<CartPoleAction> for DiscreteActionTensor<B, 2> {\n    fn from(value: CartPoleAction) -> Self {\n        DiscreteActionTensor {\n            actions: Tensor::<B, 1>::from_data([value.action], &Default::default()).unsqueeze(),\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct CartPoleState {\n    pub state: [f64; 4],\n}\n\nimpl From<CartPoleObservation> for CartPoleState {\n    fn from(observation: CartPoleObservation) -> Self {\n        let vec = Vec::<f64>::from(observation);\n        Self {\n            state: [vec[0], vec[1], vec[2], vec[3]],\n        }\n    }\n}\nimpl<B: Backend> From<CartPoleState> for ObservationTensor<B, 2> {\n    fn from(val: CartPoleState) -> Self {\n        ObservationTensor {\n            state: Tensor::<B, 1>::from_floats(val.state, &Default::default()).unsqueeze(),\n        }\n    }\n}\n\n#[derive(Clone)]\npub struct CartPoleWrapper {\n    gym_env: CartPoleEnv,\n    step_index: usize,\n}\n\nimpl Default for CartPoleWrapper {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl CartPoleWrapper {\n    pub fn new() -> Self {\n        Self {\n            gym_env: CartPoleEnv::new(gym_rs::utils::renderer::RenderMode::None),\n            step_index: 0,\n        }\n    }\n}\n\nimpl Environment for CartPoleWrapper {\n    type State = CartPoleState;\n    type Action = CartPoleAction;\n\n    const MAX_STEPS: usize = 500;\n\n    fn state(&self) -> Self::State {\n        CartPoleState::from(self.gym_env.state)\n    }\n\n    fn step(&mut self, action: Self::Action) -> StepResult<Self::State> {\n        let action_reward = self.gym_env.step(action.action);\n        self.step_index += 1;\n        StepResult {\n            next_state: CartPoleState::from(action_reward.observation),\n            reward: action_reward.reward.into_inner(),\n            done: action_reward.done,\n            truncated: self.step_index >= Self::MAX_STEPS,\n        }\n    }\n\n    fn reset(&mut self) {\n        self.gym_env.reset(None, false, None);\n        self.step_index = 0;\n    }\n}\n"
  },
  {
    "path": "examples/dqn-agent/src/lib.rs",
    "content": "pub mod agent;\npub mod env;\npub mod training;\npub mod utils;\n"
  },
  {
    "path": "examples/dqn-agent/src/training.rs",
    "content": "use burn::{\n    grad_clipping::GradientClippingConfig,\n    optim::AdamWConfig,\n    record::CompactRecorder,\n    tensor::backend::AutodiffBackend,\n    train::{\n        OffPolicyConfig, RLTraining,\n        metric::{CumulativeRewardMetric, EpisodeLengthMetric, ExplorationRateMetric, LossMetric},\n    },\n};\n\nuse crate::{\n    agent::{DqnAgentConfig, DqnLearningAgent, MlpNet, MlpNetConfig},\n    env::CartPoleWrapper,\n};\n\nstatic ARTIFACT_DIR: &str = \"/tmp/burn-example-dqn-agent\";\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    let dqn_config = DqnAgentConfig {\n        gamma: 0.99,\n        learning_rate: 3e-4,\n        tau: 0.005,\n        epsilon_start: 0.99,\n        epsilon_end: 0.05,\n        epsilon_decay: 6000.0,\n    };\n    let model_config = MlpNetConfig {\n        num_layers: 3,\n        dropout: 0.0,\n        d_input: 4,\n        d_output: 2,\n        d_hidden: 64,\n    };\n    let learning_config = OffPolicyConfig {\n        num_envs: 8,\n        autobatch_size: 8,\n        replay_buffer_size: 50_000,\n        train_interval: 8,\n        eval_interval: 4_000,\n        eval_episodes: 5,\n        train_batch_size: 128,\n        train_steps: 4,\n        warmup_steps: 0,\n    };\n\n    let policy_model = MlpNet::<B>::new(&model_config, &device);\n    let optimizer = AdamWConfig::new()\n        .with_grad_clipping(Some(GradientClippingConfig::Value(100.0)))\n        .init();\n    let agent = DqnLearningAgent::new(policy_model, optimizer, dqn_config);\n    let learner = RLTraining::new(ARTIFACT_DIR, CartPoleWrapper::new)\n        .metrics_train((LossMetric::new(),))\n        .metrics_agent((ExplorationRateMetric::new(),))\n        .metrics_episode((EpisodeLengthMetric::new(), CumulativeRewardMetric::new()))\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_steps(40_000)\n        .with_learning_strategy(burn::train::RLStrategies::OffPolicyStrategy(\n            learning_config,\n        ))\n        .summary();\n\n    let _result = learner.launch(agent);\n}\n"
  },
  {
    "path": "examples/dqn-agent/src/utils.rs",
    "content": "use std::marker::PhantomData;\n\nuse burn::{\n    Tensor,\n    module::{Param, ParamId},\n    nn::{self, Linear},\n    prelude::Backend,\n    record::Record,\n    rl::{Policy, PolicyState},\n    tensor::Device,\n    train::{\n        ItemLazy,\n        metric::{Adaptor, ExplorationRateInput},\n    },\n};\nuse derive_new::new;\nuse rand::{random, random_range};\n\nuse crate::agent::{DiscreteActionTensor, DiscreteLogitsTensor};\n\npub fn create_lin_layers<B: Backend>(\n    num_layers: usize,\n    d_input: usize,\n    d_hidden: usize,\n    d_output: usize,\n    device: &Device<B>,\n) -> Vec<Linear<B>> {\n    let mut linears = Vec::with_capacity(num_layers);\n\n    if num_layers == 1 {\n        linears.push(nn::LinearConfig::new(d_input, d_output).init(device));\n        return linears;\n    }\n    for i in 0..num_layers {\n        if i == 0 {\n            linears.push(nn::LinearConfig::new(d_input, d_hidden).init(device));\n        } else if i == num_layers - 1 {\n            linears.push(nn::LinearConfig::new(d_hidden, d_output).init(device));\n        } else {\n            linears.push(nn::LinearConfig::new(d_hidden, d_hidden).init(device));\n        }\n    }\n    linears\n}\n\npub fn soft_update_linear<B: Backend>(this: Linear<B>, that: &Linear<B>, tau: f64) -> Linear<B> {\n    let weight = soft_update_tensor(&this.weight, &that.weight, tau);\n    let bias = match (&this.bias, &that.bias) {\n        (Some(this_bias), Some(that_bias)) => Some(soft_update_tensor(this_bias, that_bias, tau)),\n        _ => None,\n    };\n\n    Linear::<B> { weight, bias }\n}\n\nfn soft_update_tensor<const N: usize, B: Backend>(\n    this: &Param<Tensor<B, N>>,\n    that: &Param<Tensor<B, N>>,\n    tau: f64,\n) -> Param<Tensor<B, N>> {\n    let that_weight = that.val();\n    let this_weight = this.val();\n    let new_weight = this_weight * (1.0 - tau) + that_weight * tau;\n\n    Param::initialized(ParamId::new(), new_weight)\n}\n\n#[derive(Clone)]\npub struct EpsilonGreedyPolicyOutput {\n    pub epsilon: f64,\n}\n\nimpl ItemLazy for EpsilonGreedyPolicyOutput {\n    type ItemSync = EpsilonGreedyPolicyOutput;\n\n    fn sync(self) -> Self::ItemSync {\n        self\n    }\n}\n\nimpl Adaptor<ExplorationRateInput> for EpsilonGreedyPolicyOutput {\n    fn adapt(&self) -> ExplorationRateInput {\n        ExplorationRateInput::new(self.epsilon)\n    }\n}\n\n#[derive(Record)]\npub struct EpsilonGreedyPolicyRecord<B: Backend, P: Policy<B>> {\n    pub inner_state: <P::PolicyState as PolicyState<B>>::Record,\n    pub step: usize,\n}\n\n#[derive(Clone, new)]\npub struct EpsilonGreedyPolicyState<B: Backend, P: Policy<B>> {\n    pub inner_state: P::PolicyState,\n    pub step: usize,\n}\n\nimpl<B: Backend, P: Policy<B>> PolicyState<B> for EpsilonGreedyPolicyState<B, P> {\n    type Record = EpsilonGreedyPolicyRecord<B, P>;\n\n    fn into_record(self) -> Self::Record {\n        EpsilonGreedyPolicyRecord {\n            inner_state: self.inner_state.into_record(),\n            step: self.step,\n        }\n    }\n\n    fn load_record(&self, record: Self::Record) -> Self {\n        let inner_state = self.inner_state.load_record(record.inner_state);\n        Self {\n            inner_state,\n            step: record.step,\n        }\n    }\n}\n\n#[derive(Clone, Debug)]\npub struct EpsilonGreedyPolicy<B: Backend, P: Policy<B>> {\n    inner_policy: P,\n    eps_start: f64,\n    eps_end: f64,\n    eps_decay: f64,\n    step: usize,\n    _backend: PhantomData<B>,\n}\n\nimpl<B: Backend, P: Policy<B>> EpsilonGreedyPolicy<B, P> {\n    pub fn new(inner_policy: P, eps_start: f64, eps_end: f64, eps_decay: f64) -> Self {\n        Self {\n            inner_policy,\n            eps_start,\n            eps_end,\n            eps_decay,\n            step: 0,\n            _backend: PhantomData,\n        }\n    }\n\n    fn get_threshold(&self) -> f64 {\n        self.eps_end\n            + (self.eps_start - self.eps_end) * f64::exp(-(self.step as f64) / self.eps_decay)\n    }\n\n    fn step(&mut self) -> f64 {\n        let thresh = self.get_threshold();\n        self.step += 1;\n        thresh\n    }\n}\n\nimpl<B, P> Policy<B> for EpsilonGreedyPolicy<B, P>\nwhere\n    B: Backend,\n    P: Policy<\n            B,\n            ActionDistribution = DiscreteLogitsTensor<B, 2>,\n            Action = DiscreteActionTensor<B, 2>,\n        >,\n{\n    type ActionContext = EpsilonGreedyPolicyOutput;\n    type PolicyState = EpsilonGreedyPolicyState<B, P>;\n\n    type Observation = P::Observation;\n    type ActionDistribution = DiscreteLogitsTensor<B, 2>;\n    type Action = DiscreteActionTensor<B, 2>;\n\n    fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {\n        self.inner_policy.forward(states)\n    }\n\n    fn action(\n        &mut self,\n        states: Self::Observation,\n        deterministic: bool,\n    ) -> (Self::Action, Vec<Self::ActionContext>) {\n        let logits = self.inner_policy.forward(states).logits;\n        let greedy_actions = logits.argmax(1);\n        let greedy_actions = greedy_actions.split(1, 0);\n\n        let mut contexts = vec![];\n        let mut actions = vec![];\n        for a in greedy_actions {\n            let threshold = self.step();\n            let threshold = if deterministic { 0.0 } else { threshold };\n            contexts.push(EpsilonGreedyPolicyOutput { epsilon: threshold });\n            if random::<f64>() > threshold {\n                actions.push(a.clone().float());\n            } else {\n                actions.push(\n                    Tensor::<B, 1>::from_floats([random_range(0..2)], &a.device()).unsqueeze(),\n                );\n            }\n        }\n\n        let output = Tensor::cat(actions, 0);\n        (DiscreteActionTensor { actions: output }, contexts)\n    }\n\n    fn update(&mut self, update: Self::PolicyState) {\n        // Note : updating an epsilon greedy policy doesn't change the step.\n        self.inner_policy.update(update.inner_state);\n    }\n\n    fn state(&self) -> Self::PolicyState {\n        EpsilonGreedyPolicyState {\n            inner_state: self.inner_policy.state(),\n            step: self.step,\n        }\n    }\n\n    fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {\n        let state = self.state().load_record(record);\n        let inner_policy = self\n            .inner_policy\n            .load_record(state.inner_state.into_record());\n        EpsilonGreedyPolicy {\n            inner_policy,\n            eps_start: self.eps_start,\n            eps_end: self.eps_end,\n            eps_decay: self.eps_decay,\n            step: state.step,\n            _backend: PhantomData,\n        }\n    }\n}\n"
  },
  {
    "path": "examples/guide/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"guide\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/default\", \"burn/tui\"]\n# Opt-in for macOS users with the Vulkan SDK installed (enabled by default on other platforms)\nvulkan = [\"burn/vulkan\"]\n\n[dependencies]\nburn = { path = \"../../crates/burn\", features = [\n    \"webgpu\",\n    \"train\",\n    \"vision\",\n], default-features = false }\n\n[target.'cfg(not(target_os = \"macos\"))'.dependencies]\nburn = { path = \"../../crates/burn\", features = [\"vulkan\"] }\n\n# Serialization\nlog = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n"
  },
  {
    "path": "examples/guide/README.md",
    "content": "# Basic Workflow: From Training to Inference\n\nThis example corresponds to the [book's guide](https://burn.dev/books/burn/basic-workflow/).\n\n## Example Usage\n\n\n### Training\n\n```sh\ncargo run --bin train --release\n```\n\n### Inference\n\n```sh\ncargo run --bin infer --release\n```\n\n### Print the model\n\n```sh\ncargo run --bin print --release\n```\n"
  },
  {
    "path": "examples/guide/examples/guide.rs",
    "content": "//\n// Note: If you are following the Burn Book guide this file can be ignored.\n//\n// This example file is added only for convenience and consistency so that\n// the guide example can be executed like any other examples using:\n//\n//     cargo run --release --example guide\n//\nuse std::process::Command;\n\nfn main() {\n    Command::new(\"cargo\")\n        .args([\"run\", \"--bin\", \"train\", \"--release\"])\n        .status()\n        .expect(\"guide example should run\");\n}\n"
  },
  {
    "path": "examples/guide/src/bin/infer.rs",
    "content": "#![recursion_limit = \"131\"]\nuse burn::{backend::WebGpu, data::dataset::Dataset};\nuse guide::inference;\n\nfn main() {\n    type MyBackend = WebGpu<f32, i32>;\n\n    let device = burn::backend::wgpu::WgpuDevice::default();\n\n    // All the training artifacts are saved in this directory\n    let artifact_dir = \"/tmp/guide\";\n\n    // Infer the model\n    inference::infer::<MyBackend>(\n        artifact_dir,\n        device,\n        burn::data::dataset::vision::MnistDataset::test()\n            .get(42)\n            .unwrap(),\n    );\n}\n"
  },
  {
    "path": "examples/guide/src/bin/print.rs",
    "content": "use burn::backend::WebGpu;\nuse guide::model::ModelConfig;\n\nfn main() {\n    type MyBackend = WebGpu<f32, i32>;\n\n    let device = Default::default();\n    let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);\n\n    println!(\"{model}\");\n}\n"
  },
  {
    "path": "examples/guide/src/bin/train.rs",
    "content": "#![recursion_limit = \"131\"]\nuse burn::{\n    backend::{Autodiff, WebGpu},\n    data::dataset::Dataset,\n    optim::AdamConfig,\n};\nuse guide::{\n    inference,\n    model::ModelConfig,\n    training::{self, TrainingConfig},\n};\n\nfn main() {\n    type MyBackend = WebGpu<f32, i32>;\n    type MyAutodiffBackend = Autodiff<MyBackend>;\n\n    // Create a default Wgpu device\n    let device = burn::backend::wgpu::WgpuDevice::default();\n\n    // All the training artifacts will be saved in this directory\n    let artifact_dir = \"target/guide\";\n\n    // Train the model\n    training::train::<MyAutodiffBackend>(\n        artifact_dir,\n        TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),\n        device.clone(),\n    );\n\n    // Infer the model\n    inference::infer::<MyBackend>(\n        artifact_dir,\n        device,\n        burn::data::dataset::vision::MnistDataset::test()\n            .get(42)\n            .unwrap(),\n    );\n}\n"
  },
  {
    "path": "examples/guide/src/data.rs",
    "content": "use burn::{\n    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n    prelude::*,\n};\n\n#[derive(Clone, Default)]\npub struct MnistBatcher {}\n\n#[derive(Clone, Debug)]\npub struct MnistBatch<B: Backend> {\n    pub images: Tensor<B, 3>,\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {\n    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {\n        let images = items\n            .iter()\n            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())\n            .map(|data| Tensor::<B, 2>::from_data(data, device))\n            .map(|tensor| tensor.reshape([1, 28, 28]))\n            // Normalize: scale between [0,1] and make the mean=0 and std=1\n            // values mean=0.1307,std=0.3081 are from the PyTorch MNIST example\n            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122\n            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)\n            .collect();\n\n        let targets = items\n            .iter()\n            .map(|item| {\n                Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)\n            })\n            .collect();\n\n        let images = Tensor::cat(images, 0);\n        let targets = Tensor::cat(targets, 0);\n\n        MnistBatch { images, targets }\n    }\n}\n"
  },
  {
    "path": "examples/guide/src/inference.rs",
    "content": "use crate::{data::MnistBatcher, training::TrainingConfig};\nuse burn::{\n    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n    prelude::*,\n    record::{CompactRecorder, Recorder},\n};\n\npub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {\n    let config = TrainingConfig::load(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should exist for the model; run train first\");\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model should exist; run train first\");\n\n    let model = config.model.init::<B>(&device).load_record(record);\n\n    let label = item.label;\n    let batcher = MnistBatcher::default();\n    let batch = batcher.batch(vec![item], &device);\n    let output = model.forward(batch.images);\n    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();\n\n    println!(\"Predicted {predicted} Expected {label}\");\n}\n"
  },
  {
    "path": "examples/guide/src/lib.rs",
    "content": "//\n// Note: If you are following the Burn Book guide this file can be ignored.\n//\n// This lib.rs file is added only for convenience so that the code in this\n// guide can be reused.\n//\npub mod data;\npub mod inference;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/guide/src/model.rs",
    "content": "use burn::{\n    nn::{\n        Dropout, DropoutConfig, Linear, LinearConfig, Relu,\n        conv::{Conv2d, Conv2dConfig},\n        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},\n    },\n    prelude::*,\n};\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n    pool: AdaptiveAvgPool2d,\n    dropout: Dropout,\n    linear1: Linear<B>,\n    linear2: Linear<B>,\n    activation: Relu,\n}\n\n#[derive(Config, Debug)]\npub struct ModelConfig {\n    num_classes: usize,\n    hidden_size: usize,\n    #[config(default = \"0.5\")]\n    dropout: f64,\n}\n\nimpl ModelConfig {\n    /// Returns the initialized model.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {\n        Model {\n            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),\n            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),\n            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),\n            activation: Relu::new(),\n            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),\n            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n        }\n    }\n}\n\nimpl<B: Backend> Model<B> {\n    /// # Shapes\n    ///   - Images [batch_size, height, width]\n    ///   - Output [batch_size, class_prob]\n    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = images.dims();\n\n        // Create a channel.\n        let x = images.reshape([batch_size, 1, height, width]);\n\n        let x = self.conv1.forward(x); // [batch_size, 8, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.conv2.forward(x); // [batch_size, 16, _, _]\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]\n        let x = x.reshape([batch_size, 16 * 8 * 8]);\n        let x = self.linear1.forward(x);\n        let x = self.dropout.forward(x);\n        let x = self.activation.forward(x);\n\n        self.linear2.forward(x) // [batch_size, num_classes]\n    }\n}\n"
  },
  {
    "path": "examples/guide/src/training.rs",
    "content": "use crate::{\n    data::{MnistBatch, MnistBatcher},\n    model::{Model, ModelConfig},\n};\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n    nn::loss::CrossEntropyLossConfig,\n    optim::AdamConfig,\n    prelude::*,\n    record::CompactRecorder,\n    tensor::backend::AutodiffBackend,\n    train::{\n        ClassificationOutput, InferenceStep, Learner, SupervisedTraining, TrainOutput, TrainStep,\n        metric::{AccuracyMetric, LossMetric},\n    },\n};\n\nimpl<B: Backend> Model<B> {\n    pub fn forward_classification(\n        &self,\n        images: Tensor<B, 3>,\n        targets: Tensor<B, 1, Int>,\n    ) -> ClassificationOutput<B> {\n        let output = self.forward(images);\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output.device())\n            .forward(output.clone(), targets.clone());\n\n        ClassificationOutput::new(loss, output, targets)\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_classification(batch.images, batch.targets);\n\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {\n        self.forward_classification(batch.images, batch.targets)\n    }\n}\n\n#[derive(Config, Debug)]\npub struct TrainingConfig {\n    pub model: ModelConfig,\n    pub optimizer: AdamConfig,\n    #[config(default = 10)]\n    pub num_epochs: usize,\n    #[config(default = 64)]\n    pub batch_size: usize,\n    #[config(default = 4)]\n    pub num_workers: usize,\n    #[config(default = 42)]\n    pub seed: u64,\n    #[config(default = 1.0e-4)]\n    pub learning_rate: f64,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {\n    create_artifact_dir(artifact_dir);\n    config\n        .save(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should be saved successfully\");\n\n    B::seed(&device, config.seed);\n\n    let batcher = MnistBatcher::default();\n\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::test());\n\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metrics((AccuracyMetric::new(), LossMetric::new()))\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let model = config.model.init::<B>(&device);\n    let result = training.launch(Learner::new(\n        model,\n        config.optimizer.init(),\n        config.learning_rate,\n    ));\n\n    result\n        .model\n        .save_file(format!(\"{artifact_dir}/model\"), &CompactRecorder::new())\n        .expect(\"Trained model should be saved successfully\");\n}\n"
  },
  {
    "path": "examples/import-model-weights/Cargo.toml",
    "content": "[package]\nauthors = [\"Dilshod Tadjibaev (@antimora)\"]\nedition.workspace = true\nlicense = \"MIT OR Apache-2.0\"\nname = \"import-model-weights\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[dependencies]\n\nburn = { path = \"../../crates/burn\", features = [\n    \"ndarray\",\n    \"dataset\",\n    \"vision\",\n] }\n\nburn-store = { path = \"../../crates/burn-store\", features = [\n    \"std\",\n    \"pytorch\",\n    \"safetensors\",\n    \"burnpack\",\n], default-features = false }\n"
  },
  {
    "path": "examples/import-model-weights/README.md",
    "content": "# Import Model Weights\n\nThis crate provides examples for importing model weights from different formats to Burn.\n\n## Examples\n\n### PyTorch\n\nImports weights from a PyTorch `.pt` file using `burn-store`.\n\n```bash\ncargo run --bin pytorch -- <image_index>\n```\n\nExample:\n```bash\ncargo run --bin pytorch -- 15\n\nLoading PyTorch model weights from file: weights/mnist.pt\nImage index: 15\nSuccess!\nPredicted: 5\nActual: 5\nSee the image online, click the link below:\nhttps://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=15\n```\n\n### Safetensors\n\nImports weights from a Safetensors file using `burn-store`.\n\n```bash\ncargo run --bin safetensors -- <image_index>\n```\n\nExample:\n```bash\ncargo run --bin safetensors -- 42\n\nLoading Safetensors model weights from file: weights/mnist.safetensors\nImage index: 42\nSuccess!\nPredicted: 4\nActual: 4\nSee the image online, click the link below:\nhttps://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=42\n```\n\n### Convert\n\nConverts between different weight formats (PyTorch or Safetensors) to Burn's native Burnpack format.\n\n```bash\ncargo run --bin convert -- <format> <output_directory>\n```\n\nWhere:\n- `<format>`: Either `pytorch` or `safetensors`\n- `<output_directory>`: Path to save the converted model file\n\nExample with PyTorch:\n```bash\ncargo run --bin convert -- pytorch /tmp/burn-convert\n\nLoading PyTorch weights from 'weights/mnist.pt'...\nSaving model to '/tmp/burn-convert/mnist.bpk'...\nModel successfully saved to '/tmp/burn-convert/mnist.bpk'.\n```\n\nExample with Safetensors:\n```bash\ncargo run --bin convert -- safetensors /tmp/burn-convert\n\nLoading Safetensors weights from 'weights/mnist.safetensors'...\nSaving model to '/tmp/burn-convert/mnist.bpk'...\nModel successfully saved to '/tmp/burn-convert/mnist.bpk'.\n```\n\n### Burnpack\n\nDemonstrates loading and using a model from Burn's native Burnpack format.\n\n```bash\ncargo run --bin burnpack -- <image_index> <model_path>\n```\n\nWhere:\n- `<image_index>`: Index of the MNIST test image to classify\n- `<model_path>`: Path to the model file (without extension)\n\nExample:\n```bash\ncargo run --bin burnpack -- 35 /tmp/burn-convert/mnist\n\nLoading model weights from file: /tmp/burn-convert/mnist.bpk\nImage index: 35\nSuccess!\nPredicted: 2\nActual: 2\nSee the image online, click the link below:\nhttps://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=35\n```\n\n## Workflow\n\nA typical workflow using these examples:\n\n1. Start with pre-trained weights in either PyTorch or Safetensors format\n2. Use the `convert` example to convert to Burn's native Burnpack format\n3. Load and use the converted model with the `burnpack` example\n"
  },
  {
    "path": "examples/import-model-weights/src/bin/burnpack.rs",
    "content": "use std::env;\nuse std::path::Path;\n\nuse burn::backend::NdArray;\n\nuse burn_store::{BurnpackStore, ModuleSnapshot};\nuse import_model_weights::{Model, infer};\n\ntype B = NdArray<f32>;\n\npub fn main() {\n    let args: Vec<String> = env::args().collect();\n\n    if args.len() < 3 {\n        eprintln!(\"Usage: {} <image_index> <model_path>\", args[0]);\n        std::process::exit(1);\n    }\n\n    let model_path_str = &args[2];\n    let model_path = Path::new(model_path_str);\n    println!(\n        \"Loading model weights from file: {}.bpk\",\n        model_path.display()\n    );\n\n    // Initialize a model with default weights\n    let device = Default::default();\n    let mut model: Model<B> = Model::init(&device);\n\n    // Load the model from the Burnpack file\n    let mut store = BurnpackStore::from_file(model_path);\n    model\n        .load_from(&mut store)\n        .expect(\"Failed to load model from Burnpack file\");\n\n    // Infer using the loaded model\n    infer(model);\n}\n"
  },
  {
    "path": "examples/import-model-weights/src/bin/convert.rs",
    "content": "use std::{env, path::Path, process};\n\nuse burn::backend::NdArray;\nuse burn_store::{\n    BurnpackStore, ModuleSnapshot, PyTorchToBurnAdapter, PytorchStore, SafetensorsStore,\n};\nuse import_model_weights::Model;\n\n// Path constants\nconst PYTORCH_WEIGHTS_PATH: &str = \"weights/mnist.pt\";\nconst SAFETENSORS_WEIGHTS_PATH: &str = \"weights/mnist.safetensors\";\nconst MODEL_OUTPUT_NAME: &str = \"mnist\";\n\n// Basic backend type (not used for computation).\ntype B = NdArray<f32>;\n\npub fn main() {\n    let args: Vec<String> = env::args().collect();\n\n    // Check argument count\n    if args.len() < 3 {\n        eprintln!(\n            \"Usage: {} <pytorch|safetensors> <output_directory>\",\n            args[0]\n        );\n        process::exit(1);\n    }\n\n    // Get weight format and output directory from arguments\n    let weight_format = args[1].as_str();\n    let output_directory = Path::new(&args[2]);\n\n    // Use the default device (CPU)\n    let device = Default::default();\n\n    // Initialize a model with default weights\n    let mut model: Model<B> = Model::init(&device);\n\n    // Load the model weights based on the specified format\n    match weight_format {\n        \"pytorch\" => {\n            println!(\"Loading PyTorch weights from '{PYTORCH_WEIGHTS_PATH}'...\");\n            let mut store = PytorchStore::from_file(PYTORCH_WEIGHTS_PATH);\n            model.load_from(&mut store).unwrap_or_else(|e| {\n                panic!(\"Failed to load PyTorch model weights from '{PYTORCH_WEIGHTS_PATH}': {e}\")\n            });\n        }\n        \"safetensors\" => {\n            println!(\"Loading Safetensors weights from '{SAFETENSORS_WEIGHTS_PATH}'...\");\n            let mut store = SafetensorsStore::from_file(SAFETENSORS_WEIGHTS_PATH)\n                .with_from_adapter(PyTorchToBurnAdapter);\n            model.load_from(&mut store).unwrap_or_else(|e| {\n                panic!(\n                    \"Failed to load Safetensors model weights from '{SAFETENSORS_WEIGHTS_PATH}': {e}\"\n                )\n            });\n        }\n        _ => {\n            eprintln!(\n                \"Error: Unsupported weight format '{weight_format}'. Please use 'pytorch' or 'safetensors'.\"\n            );\n            process::exit(1);\n        }\n    };\n\n    // Define the output path for the Burn model file\n    let output_file_path = output_directory.join(MODEL_OUTPUT_NAME);\n\n    println!(\"Saving model to '{}.bpk'...\", output_file_path.display());\n\n    // Save the model using BurnpackStore\n    let mut store = BurnpackStore::from_file(&output_file_path).overwrite(true);\n    model.save_into(&mut store).unwrap_or_else(|e| {\n        panic!(\n            \"Failed to save model to '{}.bpk': {e}\",\n            output_file_path.display()\n        )\n    });\n\n    println!(\n        \"Model successfully saved to '{}.bpk'.\",\n        output_file_path.display()\n    );\n}\n"
  },
  {
    "path": "examples/import-model-weights/src/bin/pytorch.rs",
    "content": "use burn::backend::NdArray;\n\nuse burn_store::{ModuleSnapshot, PytorchStore};\n\nuse import_model_weights::{Model, infer};\n\ntype B = NdArray<f32>;\n\nconst WEIGHTS_FILE: &str = \"weights/mnist.pt\";\n\npub fn main() {\n    println!(\"Loading PyTorch model weights from file: {WEIGHTS_FILE}\");\n\n    // Initialize a model with default weights\n    let device = Default::default();\n    let mut model: Model<B> = Model::init(&device);\n\n    // Load PyTorch weights into the model\n    let mut store = PytorchStore::from_file(WEIGHTS_FILE);\n    model\n        .load_from(&mut store)\n        .expect(\"Failed to load PyTorch model weights\");\n\n    // Infer using the loaded model\n    infer(model);\n}\n"
  },
  {
    "path": "examples/import-model-weights/src/bin/safetensors.rs",
    "content": "use burn::backend::NdArray;\n\nuse burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};\n\nuse import_model_weights::{Model, infer};\n\ntype B = NdArray<f32>;\n\nconst WEIGHTS_FILE: &str = \"weights/mnist.safetensors\";\n\npub fn main() {\n    println!(\"Loading Safetensors model weights from file: {WEIGHTS_FILE}\");\n\n    // Initialize a model with default weights\n    let device = Default::default();\n    let mut model: Model<B> = Model::init(&device);\n\n    // Load Safetensors weights into the model (using PyTorch adapter since weights were exported from PyTorch)\n    let mut store =\n        SafetensorsStore::from_file(WEIGHTS_FILE).with_from_adapter(PyTorchToBurnAdapter);\n    model\n        .load_from(&mut store)\n        .expect(\"Failed to load Safetensors model weights\");\n\n    // Infer using the loaded model\n    infer(model);\n}\n"
  },
  {
    "path": "examples/import-model-weights/src/inference.rs",
    "content": "use burn::prelude::*;\n\nuse std::env::args;\n\nuse burn::data::dataloader::Dataset;\nuse burn::data::dataset::vision::MnistDataset;\n\nuse crate::model::Model;\n\nconst IMAGE_INX: usize = 42; // <- Change this to test a different image\n\npub fn infer<B: Backend>(model: Model<B>) {\n    // Get image index argument (first) from command line\n\n    let image_index = if let Some(image_index) = args().nth(1) {\n        println!(\"Image index: {image_index}\");\n        image_index\n            .parse::<usize>()\n            .expect(\"Failed to parse image index\")\n    } else {\n        println!(\"No image index provided; Using default image index: {IMAGE_INX}\");\n        IMAGE_INX\n    };\n\n    assert!(image_index < 10000, \"Image index must be less than 10000\");\n\n    // Get device from the model\n    let device = model.devices().into_iter().next().unwrap_or_default();\n\n    // Load the MNIST dataset and get an item\n    let dataset = MnistDataset::test();\n    let item = dataset.get(image_index).unwrap();\n\n    // Create a tensor from the image data\n    let image_data = item.image.iter().copied().flatten().collect::<Vec<f32>>();\n    let mut input =\n        Tensor::<B, 1>::from_floats(image_data.as_slice(), &device).reshape([1, 1, 28, 28]);\n\n    // Normalize the input\n    input = ((input / 255) - 0.1307) / 0.3081;\n\n    // Run the model on the input\n    let output = model.forward(input);\n\n    // Get the index of the maximum value\n    let arg_max: u8 = output.argmax(1).into_scalar().elem();\n\n    // Check if the index matches the label\n    assert!(arg_max == item.label);\n\n    println!(\"Success!\");\n    println!(\"Predicted: {arg_max}\");\n    println!(\"Actual: {}\", item.label);\n    println!(\"See the image online, click the link below:\");\n    println!(\"https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row={image_index}\");\n}\n"
  },
  {
    "path": "examples/import-model-weights/src/lib.rs",
    "content": "pub mod inference;\npub mod model;\n\npub use inference::infer;\npub use model::Model;\n"
  },
  {
    "path": "examples/import-model-weights/src/model.rs",
    "content": "use burn::{\n    nn::{\n        BatchNorm, BatchNormConfig, Linear, LinearConfig,\n        conv::{Conv2d, Conv2dConfig},\n    },\n    prelude::*,\n    tensor::activation::{log_softmax, relu},\n};\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: Conv2d<B>,\n    conv2: Conv2d<B>,\n    conv3: Conv2d<B>,\n    norm1: BatchNorm<B>,\n    fc1: Linear<B>,\n    fc2: Linear<B>,\n    norm2: BatchNorm<B>,\n}\n\nimpl<B: Backend> Model<B> {\n    pub fn init(device: &B::Device) -> Self {\n        let conv1 = Conv2dConfig::new([1, 8], [3, 3]).init(device);\n        let conv2 = Conv2dConfig::new([8, 16], [3, 3]).init(device);\n        let conv3 = Conv2dConfig::new([16, 24], [3, 3]).init(device);\n        let norm1 = BatchNormConfig::new(24).init(device);\n        let fc1 = LinearConfig::new(11616, 32).init(device);\n        let fc2 = LinearConfig::new(32, 10).init(device);\n        let norm2 = BatchNormConfig::new(10).init(device);\n\n        Self {\n            conv1,\n            conv2,\n            conv3,\n            norm1,\n            fc1,\n            fc2,\n            norm2,\n        }\n    }\n\n    pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {\n        let conv1_out1 = self.conv1.forward(input1);\n        let relu1_out1 = relu(conv1_out1);\n        let conv2_out1 = self.conv2.forward(relu1_out1);\n        let relu2_out1 = relu(conv2_out1);\n        let conv3_out1 = self.conv3.forward(relu2_out1);\n        let relu3_out1 = relu(conv3_out1);\n        let norm1_out1 = self.norm1.forward(relu3_out1);\n        let flatten1_out1 = norm1_out1.flatten(1, 3);\n        let fc1_out1 = self.fc1.forward(flatten1_out1);\n        let relu4_out1 = relu(fc1_out1);\n        let fc2_out1 = self.fc2.forward(relu4_out1);\n        let norm2_out1 = self.norm2.forward(fc2_out1);\n        log_softmax(norm2_out1, 1)\n    }\n}\n"
  },
  {
    "path": "examples/import-model-weights/weights/mnist_train_export.py",
    "content": "#!/usr/bin/env python3\n\n# Originally copied and modified from: https://github.com/pytorch/examples/blob/main/mnist/main.py\n# under the following license:  BSD-3-Clause license\n\nfrom __future__ import print_function\nimport argparse\nfrom safetensors.torch import save_file\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torchvision import datasets, transforms\nfrom torch.optim.lr_scheduler import StepLR\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(1, 8, 3)\n        self.conv2 = nn.Conv2d(8, 16, 3)\n        self.conv3 = nn.Conv2d(16, 24, 3)\n        self.norm1 = nn.BatchNorm2d(24)\n        self.dropout1 = nn.Dropout(0.3)\n        self.fc1 = nn.Linear(24 * 22 * 22, 32)\n        self.fc2 = nn.Linear(32, 10)\n        self.norm2 = nn.BatchNorm1d(10)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = F.relu(x)\n        x = self.conv2(x)\n        x = F.relu(x)\n        x = self.conv3(x)\n        x = F.relu(x)\n        x = self.norm1(x)\n        x = torch.flatten(x, 1)\n        x = self.fc1(x)\n        x = F.relu(x)\n        x = self.dropout1(x)\n        x = self.fc2(x)\n        x = self.norm2(x)\n        output = F.log_softmax(x, dim=1)\n        return output\n\n\ndef train(args, model, device, train_loader, optimizer, epoch):\n    model.train()\n    for batch_idx, (data, target) in enumerate(train_loader):\n        data, target = data.to(device), target.to(device)\n        optimizer.zero_grad()\n        output = model(data)\n        loss = F.nll_loss(output, target)\n        loss.backward()\n        optimizer.step()\n        if batch_idx % args.log_interval == 0:\n            print(\n                \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n                    epoch,\n                    batch_idx * len(data),\n                    len(train_loader.dataset),\n                    100.0 * batch_idx / len(train_loader),\n                    loss.item(),\n                )\n            )\n            if args.dry_run:\n                break\n\n\ndef test(model, device, test_loader):\n    model.eval()\n    test_loss = 0\n    correct = 0\n    with torch.no_grad():\n        for data, target in test_loader:\n            data, target = data.to(device), target.to(device)\n            output = model(data)\n            # sum up batch loss\n            test_loss += F.nll_loss(output, target, reduction=\"sum\").item()\n            # get the index of the max log-probability\n            pred = output.argmax(dim=1, keepdim=True)\n            correct += pred.eq(target.view_as(pred)).sum().item()\n\n    test_loss /= len(test_loader.dataset)\n\n    print(\n        \"\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n\".format(\n            test_loss,\n            correct,\n            len(test_loader.dataset),\n            100.0 * correct / len(test_loader.dataset),\n        )\n    )\n\n\ndef main():\n    # Training settings\n    parser = argparse.ArgumentParser(description=\"PyTorch MNIST Example\")\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=64,\n        metavar=\"N\",\n        help=\"input batch size for training (default: 64)\",\n    )\n    parser.add_argument(\n        \"--test-batch-size\",\n        type=int,\n        default=1000,\n        metavar=\"N\",\n        help=\"input batch size for testing (default: 1000)\",\n    )\n    parser.add_argument(\n        \"--epochs\",\n        type=int,\n        default=8,\n        metavar=\"N\",\n        help=\"number of epochs to train (default: 14)\",\n    )\n    parser.add_argument(\n        \"--lr\",\n        type=float,\n        default=1.0,\n        metavar=\"LR\",\n        help=\"learning rate (default: 1.0)\",\n    )\n    parser.add_argument(\n        \"--gamma\",\n        type=float,\n        default=0.7,\n        metavar=\"M\",\n        help=\"Learning rate step gamma (default: 0.7)\",\n    )\n    parser.add_argument(\n        \"--no-cuda\", action=\"store_true\", default=False, help=\"disables CUDA training\"\n    )\n    parser.add_argument(\n        \"--no-mps\",\n        action=\"store_true\",\n        default=False,\n        help=\"disables macOS GPU training\",\n    )\n    parser.add_argument(\n        \"--dry-run\",\n        action=\"store_true\",\n        default=False,\n        help=\"quickly check a single pass\",\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\"\n    )\n    parser.add_argument(\n        \"--log-interval\",\n        type=int,\n        default=10,\n        metavar=\"N\",\n        help=\"how many batches to wait before logging training status\",\n    )\n    parser.add_argument(\n        \"--save-model\",\n        action=\"store_true\",\n        default=True,\n        help=\"For Saving the current Model\",\n    )\n    parser.add_argument(\n        \"--export-onnx\",\n        action=\"store_true\",\n        default=False,\n        help=\"For Saving the current Model in ONNX format\",\n    )\n    args = parser.parse_args()\n    use_cuda = not args.no_cuda and torch.cuda.is_available()\n    use_mps = not args.no_mps and torch.backends.mps.is_available()\n\n    torch.manual_seed(args.seed)\n\n    if use_cuda:\n        device = torch.device(\"cuda\")\n    elif use_mps:\n        device = torch.device(\"mps\")\n        print(\"using MPS\")\n    else:\n        device = torch.device(\"cpu\")\n\n    train_kwargs = {\"batch_size\": args.batch_size}\n    test_kwargs = {\"batch_size\": args.test_batch_size}\n    if use_cuda:\n        cuda_kwargs = {\"num_workers\": 1, \"pin_memory\": True, \"shuffle\": True}\n        train_kwargs.update(cuda_kwargs)\n        test_kwargs.update(cuda_kwargs)\n\n    transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n    )\n    dataset1 = datasets.MNIST(\n        \"/tmp/mnist-data\", train=True, download=True, transform=transform\n    )\n    dataset2 = datasets.MNIST(\"/tmp/mnist-data\", train=False, transform=transform)\n    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)\n    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)\n\n    model = Net().to(device)\n    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)\n\n    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)\n    for epoch in range(1, args.epochs + 1):\n        train(args, model, device, train_loader, optimizer, epoch)\n        test(model, device, test_loader)\n        scheduler.step()\n\n    if args.save_model:\n        torch.save(model.state_dict(), \"mnist.pt\")\n        save_file(model.state_dict(), \"mnist.safetensors\")\n\n    if args.export_onnx:\n        dummy_input = torch.randn(1, 1, 28, 28, device=device)\n        torch.onnx.export(\n            model, dummy_input, \"mnist.onnx\", verbose=True, opset_version=16\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/mnist/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"mnist\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/default\", \"burn/tui\"]\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\"]\nmetal = [\"burn/metal\"]\ncuda = [\"burn/cuda\"]\nvulkan = [\"burn/vulkan\"]\nrocm = [\"burn/rocm\"]\n\n[dependencies]\nburn = { path = \"../../crates/burn\", features = [\n    \"train\",\n    \"vision\",\n    \"metrics\",\n    \"fusion\",\n    \"dispatch\",\n    \"ndarray\",\n], default-features = false }\n\n# Serialization\nlog = { workspace = true }\nrand.workspace = true\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n"
  },
  {
    "path": "examples/mnist/README.md",
    "content": "# MNIST\n\nThe example is showing you how to:\n\n- Define your own custom module (MLP).\n- Create the data pipeline from a raw dataset to a batched multi-threaded fast DataLoader.\n- Configure a learner to display and log metrics as well as to keep training checkpoints.\n\nThe example can be run like so:\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n# Use the --release flag to really speed up training.\necho \"Using ndarray backend\"\ncargo run --example mnist --release --features ndarray                # CPU NdArray Backend - f32 - single thread\ncargo run --example mnist --release --features ndarray-blas-openblas  # CPU NdArray Backend - f32 - blas with openblas\ncargo run --example mnist --release --features ndarray-blas-netlib    # CPU NdArray Backend - f32 - blas with netlib\necho \"Using tch backend\"\nexport TORCH_CUDA_VERSION=cu128                                       # Set the cuda version\ncargo run --example mnist --release --features tch-gpu                # GPU Tch Backend - f32\ncargo run --example mnist --release --features tch-cpu                # CPU Tch Backend - f32\necho \"Using vulkan backend\"\ncargo run --example mnist --release --features vulkan\n```\n"
  },
  {
    "path": "examples/mnist/cubecl.toml",
    "content": "[profiling]\nlogger = { log = \"info\", level = \"disabled\" }\n\n[autotune]\nlevel = \"balanced\"\ncache = \"target\"\nlogger = { file = \"/tmp/autotune.log\", level = \"disabled\" }\n\n[compilation]\nlogger = { level = \"disabled\" }\ncache = \"target\"\n\n[memory]\nlogger = { level = \"disabled\", file = \"/tmp/memory.log\" }\npersistent_memory = \"enabled\"\n"
  },
  {
    "path": "examples/mnist/examples/mnist.rs",
    "content": "#![recursion_limit = \"256\"]\n\nuse burn::{Dispatch, DispatchDevice};\nuse mnist::training;\n\n#[cfg(feature = \"cuda\")]\nuse burn::backend::cuda::CudaDevice;\n#[cfg(feature = \"tch-gpu\")]\nuse burn::backend::libtorch::LibTorchDevice;\n#[cfg(feature = \"ndarray\")]\nuse burn::backend::ndarray::NdArrayDevice;\n#[cfg(feature = \"rocm\")]\nuse burn::backend::rocm::RocmDevice;\n#[cfg(any(feature = \"wgpu\", feature = \"metal\", feature = \"vulkan\"))]\nuse burn::backend::wgpu::WgpuDevice;\n\n#[allow(unreachable_code)]\nfn select_device() -> DispatchDevice {\n    #[cfg(feature = \"ndarray\")]\n    return NdArrayDevice::Cpu.into();\n\n    #[cfg(all(feature = \"tch-gpu\", not(target_os = \"macos\")))]\n    return LibTorchDevice::Cuda(0).into();\n\n    #[cfg(all(feature = \"tch-gpu\", target_os = \"macos\"))]\n    return LibTorchDevice::Mps.into();\n\n    #[cfg(feature = \"tch-cpu\")]\n    return LibTorchDevice::Cpu;\n\n    #[cfg(any(feature = \"wgpu\", feature = \"metal\", feature = \"vulkan\"))]\n    return WgpuDevice::default().into();\n\n    #[cfg(feature = \"cuda\")]\n    return CudaDevice::default().into();\n\n    #[cfg(feature = \"rocm\")]\n    return RocmDevice::default().into();\n\n    unreachable!(\"At least one backend will be selected.\")\n}\n\nfn main() {\n    let device = select_device();\n    training::run::<Dispatch>(DispatchDevice::autodiff(device));\n}\n"
  },
  {
    "path": "examples/mnist/src/data.rs",
    "content": "use std::{f32::consts::FRAC_PI_4, fmt::Display};\n\nuse burn::{\n    backend::NdArray,\n    data::{\n        dataloader::batcher::Batcher,\n        dataset::{transform::Mapper, vision::MnistItem},\n    },\n    prelude::*,\n    vision::Transform2D,\n};\nuse rand::RngExt;\n\n#[derive(Clone, Debug, Default)]\npub struct MnistBatcher {}\n\n#[derive(Clone, Debug)]\npub struct MnistBatch<B: Backend> {\n    pub images: Tensor<B, 3>,\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Batcher<B, MnistItemPrepared, MnistBatch<B>> for MnistBatcher {\n    fn batch(&self, items: Vec<MnistItemPrepared>, device: &B::Device) -> MnistBatch<B> {\n        let images = items.iter().map(|item| item.image.clone()).collect();\n\n        let targets = items\n            .iter()\n            .map(|item| {\n                Tensor::<NdArray, 1, Int>::from_data(\n                    TensorData::from([(item.label as i64).elem::<<NdArray as Backend>::IntElem>()]),\n                    &Default::default(),\n                )\n            })\n            .collect();\n\n        let images = Tensor::cat(images, 0);\n        let images = Tensor::from_data(images.into_data(), device);\n\n        let targets = Tensor::cat(targets, 0);\n        let targets = Tensor::from_data(targets.into_data(), device);\n\n        MnistBatch { images, targets }\n    }\n}\n\n#[derive(Clone, Debug, Copy)]\npub enum Transform {\n    Translate,\n    Shear,\n    Scale,\n    Rotation,\n}\n\nimpl Display for Transform {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Transform::Translate => f.write_str(\"Tr\"),\n            Transform::Shear => f.write_str(\"Sr\"),\n            Transform::Scale => f.write_str(\"Sc\"),\n            Transform::Rotation => f.write_str(\"Rot\"),\n        }\n    }\n}\n\n#[derive(Default)]\npub struct MnistMapper {\n    transforms: Vec<Transform>,\n}\n\nimpl MnistMapper {\n    pub fn transform(mut self, transforms: &[Transform]) -> Self {\n        for t in transforms {\n            self.transforms.push(*t);\n        }\n        self\n    }\n    pub fn translate(mut self) -> Self {\n        self.transforms.push(Transform::Translate);\n        self\n    }\n    pub fn shear(mut self) -> Self {\n        self.transforms.push(Transform::Shear);\n        self\n    }\n    pub fn scale(mut self) -> Self {\n        self.transforms.push(Transform::Scale);\n        self\n    }\n    pub fn rotation(mut self) -> Self {\n        self.transforms.push(Transform::Rotation);\n        self\n    }\n}\n\nimpl Mapper<MnistItem, MnistItemPrepared> for MnistMapper {\n    fn map(&self, item: &MnistItem) -> MnistItemPrepared {\n        prepare_image(&self.transforms, item.clone())\n    }\n}\n\n#[derive(Clone, Debug)]\npub struct MnistItemPrepared {\n    image: Tensor<NdArray, 3>,\n    label: u8,\n}\n\nfn prepare_image(transforms: &[Transform], item: MnistItem) -> MnistItemPrepared {\n    let data = TensorData::from(item.image);\n    let tensor = Tensor::<NdArray, 2>::from_data(data.convert::<f32>(), &Default::default());\n    let tensor = tensor.reshape([1, 28, 28]);\n\n    // normalize: make between [0,1] and make the mean =  0 and std = 1\n    // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example\n    // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122\n    let tensor = ((tensor / 255) - 0.1307) / 0.3081;\n    let tensor = if !transforms.is_empty() {\n        mangle_image_batch(transforms, tensor)\n    } else {\n        tensor\n    };\n\n    MnistItemPrepared {\n        image: tensor,\n        label: item.label,\n    }\n}\n\n/// Mange the image by applying small random transformations to augment the dataset.\n///\n/// * `images` - The images with shape [batch size, height, width]\n///\n/// ## Return\n///\n/// The transformed images tensor with shape [batch size, height, width]\nfn mangle_image_batch<B: Backend>(transforms: &[Transform], images: Tensor<B, 3>) -> Tensor<B, 3> {\n    let mut rng = rand::rng();\n\n    let transforms = transforms.iter().map(|transform| match transform {\n        Transform::Translate => {\n            Transform2D::translation(rng.random_range(-0.2..0.2), rng.random_range(-0.2..0.2))\n        }\n        Transform::Shear => Transform2D::shear(\n            rng.random_range(-0.6..0.6),\n            rng.random_range(-0.6..0.6),\n            0.0,\n            0.0,\n        ),\n        Transform::Scale => Transform2D::scale(\n            rng.random_range(0.6..1.5),\n            rng.random_range(0.6..1.5),\n            0.0,\n            0.0,\n        ),\n        Transform::Rotation => {\n            Transform2D::rotation(rng.random_range(-FRAC_PI_4..FRAC_PI_4), 0.0, 0.0)\n        }\n    });\n\n    Transform2D::composed(transforms)\n        .transform(images.unsqueeze_dim::<4>(1))\n        .squeeze_dims::<3>(&[1])\n}\n"
  },
  {
    "path": "examples/mnist/src/lib.rs",
    "content": "pub mod data;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/mnist/src/model.rs",
    "content": "use crate::data::MnistBatch;\nuse burn::{\n    nn::{\n        BatchNorm, PaddingConfig2d,\n        loss::CrossEntropyLossConfig,\n        pool::{MaxPool2d, MaxPool2dConfig},\n    },\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n    train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep},\n};\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: ConvBlock<B>,\n    conv2: ConvBlock<B>,\n    dropout: nn::Dropout,\n    fc1: nn::Linear<B>,\n    fc2: nn::Linear<B>,\n    fc3: nn::Linear<B>,\n    activation: nn::Gelu,\n}\n\nimpl<B: Backend> Default for Model<B> {\n    fn default() -> Self {\n        let device = B::Device::default();\n        Self::new(&device)\n    }\n}\n\nconst NUM_CLASSES: usize = 10;\n\nimpl<B: Backend> Model<B> {\n    pub fn new(device: &B::Device) -> Self {\n        let conv1 = ConvBlock::new([1, 64], [3, 3], device, true); // out: max_pool -> [Batch,32,13,13]\n        let conv2 = ConvBlock::new([64, 64], [3, 3], device, true); // out: max_pool -> [Batch,64,5,5]\n        let hidden_size = 64 * 5 * 5;\n        let fc1 = nn::LinearConfig::new(hidden_size, 128).init(device);\n        let fc2 = nn::LinearConfig::new(128, 128).init(device);\n        let fc3 = nn::LinearConfig::new(128, NUM_CLASSES).init(device);\n\n        let dropout = nn::DropoutConfig::new(0.25).init();\n\n        Self {\n            conv1,\n            conv2,\n            dropout,\n            fc1,\n            fc2,\n            fc3,\n            activation: nn::Gelu::new(),\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = input.dims();\n\n        let x = input.reshape([batch_size, 1, height, width]).detach();\n        let x = self.conv1.forward(x);\n        let x = self.conv2.forward(x);\n\n        let [batch_size, channels, height, width] = x.dims();\n        let x = x.reshape([batch_size, channels * height * width]);\n\n        let x = self.fc1.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n\n        let x = self.fc2.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.fc3.forward(x)\n    }\n\n    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {\n        let targets = item.targets;\n        let output = self.forward(item.images);\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output.device())\n            .forward(output.clone(), targets.clone());\n\n        ClassificationOutput {\n            loss,\n            output,\n            targets,\n        }\n    }\n}\n\n#[derive(Module, Debug)]\npub struct ConvBlock<B: Backend> {\n    conv: nn::conv::Conv2d<B>,\n    norm: BatchNorm<B>,\n    pool: Option<MaxPool2d>,\n    activation: nn::Relu,\n}\n\nimpl<B: Backend> ConvBlock<B> {\n    pub fn new(\n        channels: [usize; 2],\n        kernel_size: [usize; 2],\n        device: &B::Device,\n        pool: bool,\n    ) -> Self {\n        let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)\n            .with_padding(PaddingConfig2d::Valid)\n            .init(device);\n        let norm = nn::BatchNormConfig::new(channels[1]).init(device);\n        let pool = if pool {\n            Some(MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init())\n        } else {\n            None\n        };\n\n        Self {\n            conv,\n            norm,\n            pool,\n            activation: nn::Relu::new(),\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv.forward(input);\n        let x = self.norm.forward(x);\n        let x = self.activation.forward(x);\n\n        if let Some(pool) = &self.pool {\n            pool.forward(x)\n        } else {\n            x\n        }\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_classification(item);\n\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for Model<B> {\n    type Input = MnistBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {\n        self.forward_classification(item)\n    }\n}\n"
  },
  {
    "path": "examples/mnist/src/training.rs",
    "content": "use std::sync::Arc;\n\nuse crate::{\n    data::{MnistBatcher, MnistItemPrepared, MnistMapper, Transform},\n    model::Model,\n};\n\nuse burn::{\n    data::{\n        dataloader::DataLoaderBuilder,\n        dataset::{\n            Dataset,\n            transform::{ComposedDataset, MapperDataset, PartialDataset, SamplerDataset},\n            vision::{MnistDataset, MnistItem},\n        },\n    },\n    lr_scheduler::{\n        composed::ComposedLrSchedulerConfig, cosine::CosineAnnealingLrSchedulerConfig,\n        linear::LinearLrSchedulerConfig,\n    },\n    prelude::*,\n    record::{CompactRecorder, NoStdTrainingRecorder},\n    tensor::backend::AutodiffBackend,\n    train::{\n        EvaluatorBuilder, Learner, MetricEarlyStoppingStrategy, StoppingCondition,\n        metric::{\n            AccuracyMetric, LearningRateMetric, LossMetric,\n            store::{Aggregate, Direction, Split},\n        },\n        renderer::MetricsRenderer,\n    },\n};\nuse burn::{optim::AdamWConfig, train::SupervisedTraining};\n\nstatic ARTIFACT_DIR: &str = \"/tmp/burn-example-mnist\";\n\n#[derive(Config, Debug)]\npub struct MnistTrainingConfig {\n    #[config(default = 5)]\n    pub num_epochs: usize,\n\n    #[config(default = 256)]\n    pub batch_size: usize,\n\n    #[config(default = 8)]\n    pub num_workers: usize,\n\n    #[config(default = 42)]\n    pub seed: u64,\n\n    pub optimizer: AdamWConfig,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn run<B: AutodiffBackend>(device: B::Device) {\n    create_artifact_dir(ARTIFACT_DIR);\n    // Config\n    let config_optimizer = AdamWConfig::new()\n        .with_cautious_weight_decay(true)\n        .with_weight_decay(5e-5);\n\n    let config = MnistTrainingConfig::new(config_optimizer);\n    B::seed(&device, config.seed);\n\n    let model = Model::<B>::new(&device);\n\n    let dataset_train_original = Arc::new(MnistDataset::train());\n    let dataset_train_plain = PartialDataset::new(dataset_train_original.clone(), 0, 55_000);\n    let dataset_valid_plain = PartialDataset::new(dataset_train_original.clone(), 55_000, 60_000);\n\n    let ident_trains = generate_idents(Some(10000));\n    let ident_valid = generate_idents(None);\n    let dataset_train = DatasetIdent::compose(ident_trains, dataset_train_plain);\n    let dataset_valid = DatasetIdent::compose(ident_valid, dataset_valid_plain);\n\n    let dataloader_train = DataLoaderBuilder::new(MnistBatcher::default())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(dataset_train);\n    let dataloader_valid = DataLoaderBuilder::new(MnistBatcher::default())\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(dataset_valid);\n    let lr_scheduler = ComposedLrSchedulerConfig::new()\n        .cosine(CosineAnnealingLrSchedulerConfig::new(1.0, 2000))\n        // Warmup\n        .linear(LinearLrSchedulerConfig::new(1e-8, 1.0, 2000))\n        .linear(LinearLrSchedulerConfig::new(1e-2, 1e-6, 10000));\n\n    let training = SupervisedTraining::new(ARTIFACT_DIR, dataloader_train, dataloader_valid)\n        .metrics((AccuracyMetric::new(), LossMetric::new()))\n        .metric_train_numeric(LearningRateMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .early_stopping(MetricEarlyStoppingStrategy::new(\n            &LossMetric::<B>::new(),\n            Aggregate::Mean,\n            Direction::Lowest,\n            Split::Valid,\n            StoppingCondition::NoImprovementSince { n_epochs: 5 },\n        ))\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let result = training.launch(Learner::new(\n        model,\n        config.optimizer.init(),\n        lr_scheduler.init().unwrap(),\n    ));\n\n    let dataset_test_plain = Arc::new(MnistDataset::test());\n    let mut renderer = result.renderer;\n\n    let idents_tests = generate_idents(None);\n\n    for (ident, _) in idents_tests {\n        let name = ident.to_string();\n        renderer = evaluate::<B::InnerBackend>(\n            name.as_str(),\n            ident,\n            result.model.clone(),\n            renderer,\n            dataset_test_plain.clone(),\n            config.batch_size,\n        );\n    }\n\n    result\n        .model\n        .save_file(\n            format!(\"{ARTIFACT_DIR}/model\"),\n            &NoStdTrainingRecorder::new(),\n        )\n        .expect(\"Failed to save trained model\");\n\n    config\n        .save(format!(\"{ARTIFACT_DIR}/config.json\").as_str())\n        .unwrap();\n\n    renderer.manual_close();\n}\n\nfn evaluate<B: Backend>(\n    name: &str,\n    ident: DatasetIdent,\n    model: Model<B>,\n    renderer: Box<dyn MetricsRenderer>,\n    dataset: impl Dataset<MnistItem> + 'static,\n    batch_size: usize,\n) -> Box<dyn MetricsRenderer> {\n    let batcher = MnistBatcher::default();\n    let dataset_test = DatasetIdent::prepare(ident, dataset);\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(batch_size)\n        .num_workers(2)\n        .build(dataset_test);\n\n    let evaluator = EvaluatorBuilder::new(ARTIFACT_DIR)\n        .renderer(renderer)\n        .metrics((AccuracyMetric::new(), LossMetric::new()))\n        .summary()\n        .build(model);\n\n    evaluator.eval(name, dataloader_test)\n}\n\nenum DatasetIdent {\n    Plain,\n    Transformed(Vec<Transform>),\n    All,\n}\n\nimpl core::fmt::Display for DatasetIdent {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            DatasetIdent::Plain => f.write_str(\"Plain\")?,\n            DatasetIdent::Transformed(items) => {\n                for i in 0..items.len() {\n                    f.write_fmt(format_args!(\"{}\", items[i]))?;\n                    if i < items.len() - 1 {\n                        f.write_str(\" \")?;\n                    }\n                }\n            }\n            DatasetIdent::All => f.write_str(\"All\")?,\n        };\n\n        Ok(())\n    }\n}\n\nimpl DatasetIdent {\n    pub fn many(transforms: Vec<Transform>) -> Self {\n        Self::Transformed(transforms)\n    }\n\n    pub fn prepare(self, dataset: impl Dataset<MnistItem>) -> impl Dataset<MnistItemPrepared> {\n        let items = match self {\n            DatasetIdent::Plain => Vec::new(),\n            DatasetIdent::All => {\n                vec![\n                    Transform::Translate,\n                    Transform::Shear,\n                    Transform::Scale,\n                    Transform::Rotation,\n                ]\n            }\n            DatasetIdent::Transformed(items) => items.clone(),\n        };\n        MapperDataset::new(dataset, MnistMapper::default().transform(&items))\n    }\n\n    pub fn compose(\n        idents: Vec<(Self, Option<usize>)>,\n        dataset: PartialDataset<Arc<MnistDataset>, MnistItem>,\n    ) -> impl Dataset<MnistItemPrepared> {\n        let datasets = idents\n            .into_iter()\n            .map(|(ident, size)| match size {\n                Some(size) => {\n                    SamplerDataset::with_replacement(ident.prepare(dataset.clone()), size)\n                }\n                None => {\n                    let dataset = ident.prepare(dataset.clone());\n                    let size = dataset.len();\n                    SamplerDataset::without_replacement(dataset, size)\n                }\n            })\n            .collect();\n        ComposedDataset::new(datasets)\n    }\n}\n\nfn generate_idents(num_samples_base: Option<usize>) -> Vec<(DatasetIdent, Option<usize>)> {\n    let mut current = Vec::new();\n    let mut idents = Vec::new();\n\n    for shear in [None, Some(Transform::Shear)] {\n        for scale in [None, Some(Transform::Scale)] {\n            for rotation in [None, Some(Transform::Rotation)] {\n                for translate in [None, Some(Transform::Translate)] {\n                    if let Some(tr) = shear {\n                        current.push(tr);\n                    }\n                    if let Some(tr) = scale {\n                        current.push(tr);\n                    }\n                    if let Some(tr) = rotation {\n                        current.push(tr);\n                    }\n                    if let Some(tr) = translate {\n                        current.push(tr);\n                    }\n\n                    let num_samples = num_samples_base.map(|val| val * current.len());\n\n                    if current.len() == 4 {\n                        idents.push((DatasetIdent::All, num_samples));\n                    } else if current.is_empty() {\n                        idents.push((DatasetIdent::Plain, num_samples));\n                    } else {\n                        idents.push((DatasetIdent::many(current.clone()), num_samples));\n                    }\n\n                    current.clear();\n                }\n            }\n        }\n    }\n\n    idents\n}\n"
  },
  {
    "path": "examples/mnist-inference-web/Cargo.toml",
    "content": "[package]\nauthors = [\"Dilshod Tadjibaev (@antimora)\"]\nedition.workspace = true\nlicense = \"MIT OR Apache-2.0\"\nname = \"mnist-inference-web\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[lib]\ncrate-type = [\"cdylib\"]\n\n[features]\ndefault = [\"ndarray\"]\n\nndarray = [\"burn/ndarray\"]\nwgpu = [\"burn/wgpu\"]\n\n[dependencies]\nburn = { path = \"../../crates/burn\", default-features = false }\nserde = { workspace = true }\nconsole_error_panic_hook = { workspace = true }\n\n# Wasm dependencies\nwasm-bindgen = \"0.2\"\nwasm-bindgen-futures = \"0.4\"\njs-sys = \"0.3\"\n"
  },
  {
    "path": "examples/mnist-inference-web/README.md",
    "content": "# MNIST Inference on Web\n\n[![Live Demo](https://img.shields.io/badge/live-demo-brightgreen)](https://burn.dev/demo)\n\nThis crate demonstrates how to run an MNIST-trained model in the browser for inference.\n\n## Running\n\n1. Build\n\n   ```shell\n   ./build-for-web.sh {backend}\n   ```\n\n   The backend can either be `ndarray` or `wgpu`. Note that `wgpu` only works for browsers with support for WebGPU.\n\n2. Run the server\n\n   ```shell\n   ./run-server.sh\n   ```\n\n3. Open the [`http://localhost:8000/`](http://localhost:8000/) in the browser.\n\n## Design\n\nThe inference components of `burn` with the `ndarray` backend can be built with `#![no_std]`. This\nmakes it possible to build and run the model with the `wasm32-unknown-unknown` target without a\nspecial system library, such as [WASI](https://wasi.dev/). (See [Cargo.toml](./Cargo.toml) on how to\ninclude burn dependencies without `std`).\n\nFor this demo, we use trained parameters (`model.bin`) and model (`model.rs`) from the\n[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).\n\nThe inference API for JavaScript is exposed with the help of\n[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.\n\nJavaScript (`index.js`) is used to transform hand-drawn digits to a format that the inference API\naccepts. The transformation includes image cropping, scaling down, and converting it to grayscale\nvalues.\n\n## Model\n\nLayers:\n\n1. Input Image (28,28, 1ch)\n2. `Conv2d`(3x3, 64ch), `BatchNorm2d`, `Gelu`, `MaxPool`(2x2)\n3. `Conv2d`(3x3, 16ch), `BatchNorm2d`, `Gelu`, `MaxPool`(2x2)\n4. `Linear`(1600, 128), `Relu`\n4. `Linear`(128, 128), `Relu`\n5. `Linear`(128, 10)\n6. Softmax Output\n\nThe total number of parameters is 260,810.\n\nThe model is trained with 18 epochs and the final test accuracy is 95.83%.\n\nRandom transformations are used for data augmentation.\n\nThe training and hyper parameter information in can be found in\n[`burn` MNIST example](https://github.com/tracel-ai/burn/tree/main/examples/mnist).\n\n## Comparison\n\nThe main differentiating factor of this example's approach (compiling rust model into wasm) and\nother popular tools, such as [TensorFlow.js](https://www.tensorflow.org/js),\n[ONNX Runtime JS](https://onnxruntime.ai/docs/tutorials/web/) and\n[TVM Web](https://github.com/apache/tvm/tree/main/web) is the absence of runtime code. The rust\ncompiler optimizes and includes only used `burn` routines. 1,509,747 bytes out of Wasm's 1,866,491\nbyte file is the model's parameters. The rest of 356,744 bytes contain all the code (including\n`burn`'s `nn` components, the data deserialization library, and math operations).\n\n## Future Improvements\n\nThere are several planned enhancements in place:\n\n- [#1271](https://github.com/rust-ndarray/ndarray/issues/1271) -\n  [WASM SIMD](https://github.com/WebAssembly/simd/blob/master/proposals/simd/SIMD.md) support in\n  NDArray that can speed up computation on CPU.\n\n## Acknowledgements\n\nTwo online MNIST demos inspired and helped build this demo:\n[MNIST Draw](https://mco-mnist-draw-rwpxka3zaa-ue.a.run.app/) by Marc (@mco-gh) and\n[MNIST Web Demo](https://ufal.mff.cuni.cz/~straka/courses/npfl129/2223/demos/mnist_web.html) (no\ncode was copied but helped tremendously with an implementation approach).\n\n## Resources\n\n1. [Rust 🦀 and WebAssembly](https://rustwasm.github.io/docs/book/)\n2. [wasm-bindgen](https://rustwasm.github.io/wasm-bindgen/)\n"
  },
  {
    "path": "examples/mnist-inference-web/build-for-web.sh",
    "content": "#!/usr/bin/env bash\n\n# Add wasm32 target for compiler.\nrustup target add wasm32-unknown-unknown\n\nif ! command -v wasm-pack &> /dev/null\nthen\n    echo \"wasm-pack could not be found. Installing ...\"\n    cargo install wasm-pack\nfi\n\n# Set optimization flags\nRUSTFLAGS=\"-C embed-bitcode=yes -C codegen-units=1 -C opt-level=3 --cfg web_sys_unstable_apis\"\n\n# Run wasm pack tool to build JS wrapper files and copy wasm to pkg directory.\nmkdir -p pkg\nwasm-pack build --out-dir pkg --release --target web --no-typescript --no-default-features --features $1\n\n"
  },
  {
    "path": "examples/mnist-inference-web/index.html",
    "content": "<!-- This demo is part of Burn project: https://github.com/tracel-ai/burn\n\n    Released under a dual license: \n    https://github.com/tracel-ai/burn/blob/main/LICENSE-MIT\n\n    https://github.com/tracel-ai/burn/blob/main/LICENSE-APACHE\n-->\n<!DOCTYPE html>\n<html>\n  <head>\n    <meta charset=\"utf-8\" />\n    <title>Burn MNIST Inference Web Demo</title>\n\n    <script\n      src=\"https://cdn.jsdelivr.net/npm/fabric@5.3.0/dist/fabric.min.js\"\n      integrity=\"sha256-SPjwkVvrUS/H/htIwO6wdd0IA8eQ79/XXNAH+cPuoso=\"\n      crossorigin=\"anonymous\"\n    ></script>\n\n    <script\n      src=\"https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js\"\n      integrity=\"sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc=\"\n      crossorigin=\"anonymous\"\n    ></script>\n\n    <script\n      src=\"https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js\"\n      integrity=\"sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg=\"\n      crossorigin=\"anonymous\"\n    ></script>\n\n    <link\n      rel=\"stylesheet\"\n      href=\"https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css\"\n      integrity=\"sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls=\"\n      crossorigin=\"anonymous\"\n    />\n\n    <style>\n      h1 {\n        padding: 15px;\n      }\n      th,\n      td {\n        padding: 5px;\n        text-align: center;\n        vertical-align: middle;\n      }\n    </style>\n  </head>\n  <body>\n    <h1>Burn MNIST Inference Demo</h1>\n\n    <table>\n      <tr>\n        <th>Draw a digit here</th>\n        <th>Cropped and scaled</th>\n        <th>Probability result</th>\n      </tr>\n      <tr>\n        <td>\n          <canvas id=\"main-canvas\" width=\"300\" height=\"300\" style=\"border: 1px solid #aaa\"></canvas>\n        </td>\n        <td>\n          <canvas\n            id=\"scaled-canvas\"\n            width=\"28\"\n            height=\"28\"\n            style=\"border: 1px solid #aaa; width: 100px; height: 100px\"\n          ></canvas>\n          <canvas id=\"crop-canvas\" width=\"28\" height=\"28\" style=\"display: none\"></canvas>\n        </td>\n        <td>\n          <canvas id=\"chart\" style=\"border: 1px solid #aaa; width: 600px; height: 300px\"></canvas>\n        </td>\n      </tr>\n      <tr>\n        <td><button id=\"clear\">Clear</button></td>\n        <td></td>\n        <td></td>\n      </tr>\n    </table>\n\n    <div></div>\n\n    <script type=\"module\">\n      import { $, cropScaleGetImageData, toFixed, chartConfigBuilder } from \"./index.js\";\n\n      import { default as wasm, Mnist } from \"./pkg/mnist_inference_web.js\";\n\n      const chart = chartConfigBuilder($(\"chart\"));\n\n      const mainCanvasEl = $(\"main-canvas\");\n      const scaledCanvasEl = $(\"scaled-canvas\");\n      const cropEl = $(\"crop-canvas\");\n      const mainContext = mainCanvasEl.getContext(\"2d\", { willReadFrequently: true });\n      const cropContext = cropEl.getContext(\"2d\", { willReadFrequently: true });\n      const scaledContext = scaledCanvasEl.getContext(\"2d\", { willReadFrequently: true });\n\n      const fabricCanvas = new fabric.Canvas(mainCanvasEl, {\n        isDrawingMode: true,\n      });\n\n      const backgroundColor = \"rgba(255, 255, 255, 255)\"; // White with solid alpha\n      fabricCanvas.freeDrawingBrush.width = 25;\n      fabricCanvas.backgroundColor = backgroundColor;\n\n      $(\"clear\").onclick = function () {\n        fabricCanvas.clear();\n        fabricCanvas.backgroundColor = backgroundColor;\n        fabricCanvas.renderAll();\n        mainContext.clearRect(0, 0, mainCanvasEl.width, mainCanvasEl.height);\n        scaledContext.clearRect(0, 0, scaledCanvasEl.width, scaledCanvasEl.height);\n\n        chart.data.datasets[0].data = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];\n        chart.update();\n      };\n\n      let timeoutId;\n      let isDrawing = false;\n      let isTimeOutSet = false;\n\n      wasm().then((module) => {\n        const mnist = new Mnist();\n\n        async function fireOffInference() {\n          clearTimeout(timeoutId);\n          timeoutId = setTimeout(async () => {\n            isTimeOutSet = true;\n            fabricCanvas.freeDrawingBrush._finalizeAndAddPath();\n            const data = cropScaleGetImageData(mainContext, cropContext, scaledContext);\n            const output = await mnist.inference(data);\n            chart.data.datasets[0].data = output;\n            chart.update();\n            isTimeOutSet = false;\n          }, 50);\n          isTimeOutSet = true;\n        }\n\n        fabricCanvas.on(\"mouse:down\", function (event) {\n          isDrawing = true;\n        });\n        fabricCanvas.on(\"mouse:up\", async function (event) {\n          isDrawing = false;\n          await fireOffInference();\n        });\n\n        fabricCanvas.on(\"mouse:move\", async function (event) {\n          if (isDrawing && isTimeOutSet == false) {\n            await fireOffInference();\n          }\n        });\n      });\n    </script>\n  </body>\n</html>\n"
  },
  {
    "path": "examples/mnist-inference-web/index.js",
    "content": "/**\n * \n * This demo is part of Burn project: https://github.com/tracel-ai/burn\n * \n * Released under a dual license: \n * https://github.com/tracel-ai/burn/blob/main/LICENSE-MIT\n * https://github.com/tracel-ai/burn/blob/main/LICENSE-APACHE\n * \n */\n\n/**\n * Auto crops the image, scales to 28x28 pixel image, and returns as grayscale image.\n * @param {object} mainContext - The 2d context of the source canvas.\n * @param {object} cropContext - The 2d context of an intermediate hidden canvas.\n * @param {object} scaledContext - The 2d context of the destination 28x28 canvas.\n */\nexport function cropScaleGetImageData(mainContext, cropContext, scaledContext) {\n\n    const cropEl = cropContext.canvas;\n\n    // Get the auto-cropped image data and put into the intermediate/hidden canvas\n    cropContext.fillStyle = \"rgba(255, 255, 255, 255)\"; // white non-transparent color\n    cropContext.fillRect(0, 0, cropEl.width, cropEl.height);\n    cropContext.save();\n    const [w, h, croppedImage] = cropImageFromCanvas(mainContext);\n    cropEl.width = Math.max(w, h) * 1.2;\n    cropEl.height = Math.max(w, h) * 1.2;\n    const leftPadding = (cropEl.width - w) / 2;\n    const topPadding = (cropEl.height - h) / 2;\n    cropContext.putImageData(croppedImage, leftPadding, topPadding);\n\n    // Copy image data to scale 28x28 canvas\n    scaledContext.save();\n    scaledContext.clearRect(0, 0, scaledContext.canvas.height, scaledContext.canvas.width);\n    scaledContext.fillStyle = \"rgba(255, 255, 255, 255)\"; // white non-transparent color\n    scaledContext.fillRect(0, 0, cropEl.width, cropEl.height);\n    scaledContext.scale(28.0 / cropContext.canvas.width, 28.0 / cropContext.canvas.height);\n    scaledContext.drawImage(cropEl, 0, 0);\n\n    // Extract image data and convert into single value (greyscale) array\n    const data = rgba2gray(scaledContext.getImageData(0, 0, 28, 28).data);\n    scaledContext.restore();\n\n    return data;\n}\n\n/**\n * Converts RGBA image data from canvas to grayscale (0 is white & 255 is black).\n * @param {int[]} - Image data.\n */\nexport function rgba2gray(data) {\n    let converted = new Float32Array(data.length / 4);\n\n    // Data is stored as [r0,g0,b0,a0, ... r[n],g[n],b[n],a[n]] where n is number of pixels.\n    for (let i = 0; i < data.length; i += 4) {\n        let r = 255 - data[i];     // red\n        let g = 255 - data[i + 1]; // green\n        let b = 255 - data[i + 2]; // blue\n        let a = 255 - data[i + 3]; // alpha\n\n        // Use RGB grayscale coefficients (https://imagej.nih.gov/ij/docs/menus/image.html)\n        let y = 0.299 * r + 0.587 * g + 0.114 * b;\n        converted[i / 4] = y; // 4 times fewer data points but the same number of pixels.\n    }\n    return converted;\n}\n\n/**\n * Auto crops a canvas images and returns its image data.\n * @param {object} ctx - canvas 2d context.\n * src: https://stackoverflow.com/a/22267731\n */\nexport function cropImageFromCanvas(ctx) {\n    let canvas = ctx.canvas,\n        w = canvas.width,\n        h = canvas.height,\n        pix = { x: [], y: [] },\n        imageData = ctx.getImageData(0, 0, canvas.width, canvas.height),\n        x,\n        y,\n        index;\n    for (y = 0; y < h; y++) {\n        for (x = 0; x < w; x++) {\n            index = (y * w + x) * 4;\n\n            let r = imageData.data[index];\n            let g = imageData.data[index + 1];\n            let b = imageData.data[index + 2];\n            // On some browsers the canvas has a grey border which prevents cropping if we do min != 255\n            if (Math.min(r, g, b) < 240) {\n                pix.x.push(x);\n                pix.y.push(y);\n            }\n        }\n    }\n    pix.x.sort(function (a, b) {\n        return a - b;\n    });\n    pix.y.sort(function (a, b) {\n        return a - b;\n    });\n    let n = pix.x.length - 1;\n    w = 1 + pix.x[n] - pix.x[0];\n    h = 1 + pix.y[n] - pix.y[0];\n    return [w, h, ctx.getImageData(pix.x[0], pix.y[0], w, h, { willReadFrequently: true })];\n}\n\n/**\n * Truncates number to a given decimal position\n * @param {number} num - Number to truncate.\n * @param {number} fixed - Decimal positions.\n * src: https://stackoverflow.com/a/11818658\n */\nexport function toFixed(num, fixed) {\n    const re = new RegExp('^-?\\\\d+(?:\\.\\\\d{0,' + (fixed || -1) + '})?');\n    return num.toString().match(re)[0];\n}\n\n/**\n * Looks up element by an id.\n * @param {string} - Element id.\n */\nexport function $(id) {\n    return document.getElementById(id);\n}\n\n/**\n * Helper function that builds a chart using Chart.js library.\n * @param {object} chartEl - Chart canvas element.\n * \n * NOTE: Assumes chart.js is loaded into the global.\n */\nexport function chartConfigBuilder(chartEl) {\n    Chart.register(ChartDataLabels);\n    return new Chart(chartEl, {\n        plugins: [ChartDataLabels],\n        type: \"bar\",\n        data: {\n            labels: [\"0\", \"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\"],\n            datasets: [\n                {\n                    data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                    borderWidth: 0,\n                    fill: true,\n                    backgroundColor: \"#247ABF\",\n                },\n            ],\n        },\n        options: {\n            responsive: false,\n            maintainAspectRatio: false,\n            animation: true,\n            plugins: {\n                legend: {\n                    display: false,\n                },\n                tooltip: {\n                    enabled: true,\n                },\n                datalabels: {\n                    color: \"white\",\n                    formatter: function (value, context) {\n                        return toFixed(value, 2);\n                    },\n                },\n            },\n            scales: {\n                y: {\n                    beginAtZero: true,\n                    max: 1.0,\n                },\n            },\n        },\n    });\n}"
  },
  {
    "path": "examples/mnist-inference-web/run-server.sh",
    "content": "#!/usr/bin/env bash\n\n# Opening index.html file directly by a browser does not work because of\n# the security restrictions by the browser. Viewing the HTML file will fail with \n# this error message:\n\n# ```\n# Access to script at\n#  'file:///Users/user/Projects/burn-mac/examples/mnist-inference-web/pkg/mnist_inference_web.js' \n# from origin 'null' has been blocked by CORS policy: \n# Cross origin requests are only supported for protocol schemes: \n# http, data, isolated-app, chrome-extension, chrome, https, chrome-untrusted.\n# ```\n#  So that's why running a local HTTP server is needed. \n\nif ! command -v python3 &> /dev/null\nthen\n    echo \"python3 could not be found. Running server requires python3.\"\n    exit\nfi\n\necho \"Running local python HTTP server on port 8000 ...\"\npython3 -m http.server 8000\n"
  },
  {
    "path": "examples/mnist-inference-web/src/lib.rs",
    "content": "#![cfg_attr(not(test), no_std)]\n\npub mod model;\npub mod state;\npub mod web;\n\nextern crate alloc;\n"
  },
  {
    "path": "examples/mnist-inference-web/src/model.rs",
    "content": "use burn::{\n    nn::{\n        BatchNorm, PaddingConfig2d,\n        pool::{MaxPool2d, MaxPool2dConfig},\n    },\n    prelude::*,\n};\n\n#[derive(Module, Debug)]\npub struct Model<B: Backend> {\n    conv1: ConvBlock<B>,\n    conv2: ConvBlock<B>,\n    dropout: nn::Dropout,\n    fc1: nn::Linear<B>,\n    fc2: nn::Linear<B>,\n    fc3: nn::Linear<B>,\n    activation: nn::Gelu,\n}\n\nconst NUM_CLASSES: usize = 10;\n\nimpl<B: Backend> Model<B> {\n    pub fn new(device: &B::Device) -> Self {\n        let conv1 = ConvBlock::new([1, 64], [3, 3], device, true); // out: max_pool -> [Batch,32,13,13]\n        let conv2 = ConvBlock::new([64, 64], [3, 3], device, true); // out: max_pool -> [Batch,64,5,5]\n        let hidden_size = 64 * 5 * 5;\n        let fc1 = nn::LinearConfig::new(hidden_size, 128).init(device);\n        let fc2 = nn::LinearConfig::new(128, 128).init(device);\n        let fc3 = nn::LinearConfig::new(128, NUM_CLASSES).init(device);\n\n        let dropout = nn::DropoutConfig::new(0.25).init();\n\n        Self {\n            conv1,\n            conv2,\n            dropout,\n            fc1,\n            fc2,\n            fc3,\n            activation: nn::Gelu::new(),\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {\n        let [batch_size, height, width] = input.dims();\n\n        let x = input.reshape([batch_size, 1, height, width]).detach();\n        let x = self.conv1.forward(x);\n        let x = self.conv2.forward(x);\n\n        let [batch_size, channels, height, width] = x.dims();\n        let x = x.reshape([batch_size, channels * height * width]);\n\n        let x = self.fc1.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n        let x = self.fc2.forward(x);\n        let x = self.activation.forward(x);\n        let x = self.dropout.forward(x);\n\n        self.fc3.forward(x)\n    }\n}\n\n#[derive(Module, Debug)]\npub struct ConvBlock<B: Backend> {\n    conv: nn::conv::Conv2d<B>,\n    norm: BatchNorm<B>,\n    pool: Option<MaxPool2d>,\n    activation: nn::Relu,\n}\n\nimpl<B: Backend> ConvBlock<B> {\n    pub fn new(\n        channels: [usize; 2],\n        kernel_size: [usize; 2],\n        device: &B::Device,\n        pool: bool,\n    ) -> Self {\n        let conv = nn::conv::Conv2dConfig::new(channels, kernel_size)\n            .with_padding(PaddingConfig2d::Valid)\n            .init(device);\n        let norm = nn::BatchNormConfig::new(channels[1]).init(device);\n        let pool = if pool {\n            Some(MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init())\n        } else {\n            None\n        };\n\n        Self {\n            conv,\n            norm,\n            pool,\n            activation: nn::Relu::new(),\n        }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {\n        let x = self.conv.forward(input);\n        let x = self.norm.forward(x);\n        let x = self.activation.forward(x);\n\n        if let Some(pool) = &self.pool {\n            pool.forward(x)\n        } else {\n            x\n        }\n    }\n}\n"
  },
  {
    "path": "examples/mnist-inference-web/src/state.rs",
    "content": "use crate::model::Model;\nuse burn::{\n    module::Module,\n    record::{BinBytesRecorder, FullPrecisionSettings, Recorder},\n};\n\n#[cfg(feature = \"wgpu\")]\nuse burn::backend::wgpu::{Wgpu, WgpuDevice, graphics::AutoGraphicsApi, init_setup_async};\n\n#[cfg(feature = \"wgpu\")]\npub type Backend = Wgpu<f32, i32>;\n\n#[cfg(all(feature = \"ndarray\", not(feature = \"wgpu\")))]\npub type Backend = burn::backend::ndarray::NdArray<f32>;\n\nstatic STATE_ENCODED: &[u8] = include_bytes!(\"../model.bin\");\n\n/// Builds and loads trained parameters into the model.\npub async fn build_and_load_model() -> Model<Backend> {\n    #[cfg(feature = \"wgpu\")]\n    init_setup_async::<AutoGraphicsApi>(&WgpuDevice::default(), Default::default()).await;\n\n    let model: Model<Backend> = Model::new(&Default::default());\n    let record = BinBytesRecorder::<FullPrecisionSettings, &'static [u8]>::default()\n        .load(STATE_ENCODED, &Default::default())\n        .expect(\"Failed to decode state\");\n\n    model.load_record(record)\n}\n"
  },
  {
    "path": "examples/mnist-inference-web/src/web.rs",
    "content": "#![allow(clippy::new_without_default)]\n\nuse alloc::string::String;\nuse js_sys::Array;\n\n#[cfg(target_family = \"wasm\")]\nuse wasm_bindgen::prelude::*;\n\nuse crate::model::Model;\nuse crate::state::{Backend, build_and_load_model};\n\nuse burn::tensor::Tensor;\n\n#[cfg_attr(target_family = \"wasm\", wasm_bindgen(start))]\npub fn start() {\n    console_error_panic_hook::set_once();\n}\n\n/// Mnist structure that corresponds to JavaScript class.\n/// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html)\n#[cfg_attr(target_family = \"wasm\", wasm_bindgen)]\npub struct Mnist {\n    model: Option<Model<Backend>>,\n}\n\n#[cfg_attr(target_family = \"wasm\", wasm_bindgen)]\nimpl Mnist {\n    /// Constructor called by JavaScripts with the new keyword.\n    #[cfg_attr(target_family = \"wasm\", wasm_bindgen(constructor))]\n    pub fn new() -> Self {\n        console_error_panic_hook::set_once();\n        Self { model: None }\n    }\n\n    /// Returns the inference results.\n    ///\n    /// This method is called from JavaScript via generated wrapper code by wasm-bindgen.\n    ///\n    /// # Arguments\n    ///\n    /// * `input` - A f32 slice of input 28x28 image\n    ///\n    /// See bindgen support types for passing and returning arrays:\n    /// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html)\n    /// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html)\n    ///\n    pub async fn inference(&mut self, input: &[f32]) -> Result<Array, String> {\n        if self.model.is_none() {\n            self.model = Some(build_and_load_model().await);\n        }\n\n        let model = self.model.as_ref().unwrap();\n\n        let device = Default::default();\n        // Reshape from the 1D array to 3d tensor [batch, height, width]\n        let input = Tensor::<Backend, 1>::from_floats(input, &device).reshape([1, 28, 28]);\n\n        // Normalize input: make between [0,1] and make the mean=0 and std=1\n        // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example\n        // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122\n\n        let input = ((input / 255) - 0.1307) / 0.3081;\n\n        // Run the tensor input through the model\n        let output: Tensor<Backend, 2> = model.forward(input);\n\n        // Convert the model output into probability distribution using softmax formula\n        let output = burn::tensor::activation::softmax(output, 1);\n\n        // Flatten output tensor with [1, 10] shape into boxed slice of [f32]\n        let output = output.into_data_async().await.unwrap();\n\n        let array = Array::new();\n        for value in output.iter::<f32>() {\n            array.push(&value.into());\n        }\n\n        Ok(array)\n    }\n}\n"
  },
  {
    "path": "examples/modern-lstm/Cargo.toml",
    "content": "[package]\nedition.workspace = true\nname = \"modern-lstm\"\nversion = \"0.5.0\"\n\n[lints]\nworkspace = true\n\n[features]\ncuda = [\"burn/cuda\"]\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"burn/ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"burn/ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"burn/ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\"]\n\n[dependencies]\nburn = { path = \"../../crates/burn\", features = [\"train\"] }\n\n# Random number generator\nrand = { workspace = true, features = [\"thread_rng\"] }\nrand_distr = { workspace = true }\n\n# Serialization\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n\n# Organise the results in dataframe\nplanus = { workspace = true }\npolars = { workspace = true }\n"
  },
  {
    "path": "examples/modern-lstm/README.md",
    "content": "# Advanced LSTM Implementation with Burn\n\nA more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined\nweight matrices for the input and hidden states, based on the\n[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch).\n\n`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM\nvariants differ by `bidirectional` and `num_layers` settings：\n\n- LSTM: `num_layers = 1` and `bidirectional = false`\n- Stacked LSTM: `num_layers > 1` and `bidirectional = false`\n- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true`\n- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true`\n\nThis implementation is complementary to Burn's official LSTM, users can choose either one depends on\nthe project's specific needs.\n\n## Usage\n\n## Training\n\n```sh\n# Cuda backend\ncargo run --example lstm-train --release --features cuda\n\n# Wgpu backend\ncargo run --example lstm-train --release --features wgpu\n\n# Tch GPU backend\nexport TORCH_CUDA_VERSION=cu128 # Set the cuda version\ncargo run --example lstm-train --release --features tch-gpu\n\n# Tch CPU backend\ncargo run --example lstm-train --release --features tch-cpu\n\n# NdArray backend (CPU)\ncargo run --example lstm-train --release --features ndarray\ncargo run --example lstm-train --release --features ndarray-blas-openblas\ncargo run --example lstm-train --release --features ndarray-blas-netlib\n```\n\n### Inference\n\n```sh\ncargo run --example lstm-infer --release --features cuda\n```\n"
  },
  {
    "path": "examples/modern-lstm/examples/lstm-infer.rs",
    "content": "use burn::tensor::backend::Backend;\n\npub fn launch<B: Backend>(device: B::Device) {\n    modern_lstm::inference::infer::<B>(\"/tmp/modern-lstm\", device);\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::ndarray::{NdArray, NdArrayDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<NdArray>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<LibTorch>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<LibTorch>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::launch;\n    use burn::backend::wgpu::Wgpu;\n\n    pub fn run() {\n        launch::<Wgpu>(Default::default());\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::launch;\n    use burn::backend::Cuda;\n\n    pub fn run() {\n        launch::<Cuda>(Default::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n}\n"
  },
  {
    "path": "examples/modern-lstm/examples/lstm-train.rs",
    "content": "use burn::{\n    grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend,\n};\nuse modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig};\n\npub fn launch<B: AutodiffBackend>(device: B::Device) {\n    let config = TrainingConfig::new(\n        LstmNetworkConfig::new(),\n        // Gradient clipping via optimizer config\n        AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))),\n    );\n\n    modern_lstm::training::train::<B>(\"/tmp/modern-lstm\", config, device);\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::{\n        Autodiff,\n        ndarray::{NdArray, NdArrayDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<Autodiff<LibTorch>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::launch;\n    use burn::backend::{Autodiff, wgpu::Wgpu};\n\n    pub fn run() {\n        launch::<Autodiff<Wgpu>>(Default::default());\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::launch;\n    use burn::backend::{Autodiff, Cuda, cuda::CudaDevice};\n\n    pub fn run() {\n        launch::<Autodiff<Cuda>>(CudaDevice::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n}\n"
  },
  {
    "path": "examples/modern-lstm/src/dataset.rs",
    "content": "use burn::{\n    data::{\n        dataloader::batcher::Batcher,\n        dataset::{Dataset, InMemDataset},\n    },\n    prelude::*,\n};\nuse rand::RngExt;\nuse rand_distr::{Distribution, Normal};\nuse serde::{Deserialize, Serialize};\n\n// Dataset parameters\npub const NUM_SEQUENCES: usize = 1000;\npub const SEQ_LENGTH: usize = 10;\npub const NOISE_LEVEL: f32 = 0.1;\npub const RANDOM_SEED: u64 = 5;\n\n// Generate a sequence where each number is the sum of previous two numbers plus noise\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct SequenceDatasetItem {\n    pub sequence: Vec<f32>,\n    pub target: f32,\n}\n\nimpl SequenceDatasetItem {\n    pub fn new(seq_length: usize, noise_level: f32) -> Self {\n        // Start with two random numbers between 0 and 1\n        let mut seq = vec![rand::rng().random(), rand::rng().random()];\n\n        // Generate sequence\n        for _i in 0..seq_length {\n            // Next number is sum of previous two plus noise\n            let normal = Normal::new(0.0, noise_level).unwrap();\n            let next_val =\n                seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::rng());\n            seq.push(next_val);\n        }\n\n        Self {\n            // Convert to sequence and target\n            sequence: seq[0..seq.len() - 1].to_vec(), // All but last\n            target: seq[seq.len() - 1],               // Last value\n        }\n    }\n}\n\n// Custom Dataset for Sequence Data\npub struct SequenceDataset {\n    dataset: InMemDataset<SequenceDatasetItem>,\n}\n\nimpl SequenceDataset {\n    pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self {\n        let mut items = vec![];\n        for _i in 0..num_sequences {\n            items.push(SequenceDatasetItem::new(seq_length, noise_level));\n        }\n        let dataset = InMemDataset::new(items);\n\n        Self { dataset }\n    }\n}\n\nimpl Dataset<SequenceDatasetItem> for SequenceDataset {\n    fn get(&self, index: usize) -> Option<SequenceDatasetItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\n#[derive(Clone, Debug, Default)]\npub struct SequenceBatcher {}\n\n#[derive(Clone, Debug)]\npub struct SequenceBatch<B: Backend> {\n    pub sequences: Tensor<B, 3>, // [batch_size, seq_length, input_size]\n    pub targets: Tensor<B, 2>,   // [batch_size, 1]\n}\n\nimpl<B: Backend> Batcher<B, SequenceDatasetItem, SequenceBatch<B>> for SequenceBatcher {\n    fn batch(&self, items: Vec<SequenceDatasetItem>, device: &B::Device) -> SequenceBatch<B> {\n        let mut sequences: Vec<Tensor<B, 2>> = Vec::new();\n\n        for item in items.iter() {\n            let seq_tensor = Tensor::<B, 1>::from_floats(item.sequence.as_slice(), device);\n            // Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations\n            sequences.push(seq_tensor.unsqueeze_dims(&[-1]));\n        }\n        let sequences = Tensor::stack(sequences, 0);\n\n        let targets = items\n            .iter()\n            .map(|item| Tensor::<B, 1>::from_floats([item.target], device))\n            .collect();\n        let targets = Tensor::stack(targets, 0);\n\n        SequenceBatch { sequences, targets }\n    }\n}\n"
  },
  {
    "path": "examples/modern-lstm/src/inference.rs",
    "content": "use crate::{\n    dataset::{\n        NOISE_LEVEL, NUM_SEQUENCES, SEQ_LENGTH, SequenceBatcher, SequenceDataset,\n        SequenceDatasetItem,\n    },\n    model::LstmNetwork,\n    training::TrainingConfig,\n};\nuse burn::{\n    data::{dataloader::batcher::Batcher, dataset::Dataset},\n    prelude::*,\n    record::{CompactRecorder, Recorder},\n};\nuse polars::prelude::*;\n\npub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {\n    // Loading model\n    let config = TrainingConfig::load(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should exist for the model; run train first\");\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model should exist; run train first\");\n\n    let model: LstmNetwork<B> = config.model.init(&device).load_record(record);\n\n    let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL);\n    let items: Vec<SequenceDatasetItem> = dataset.iter().collect();\n\n    let batcher = SequenceBatcher::default();\n    // Put all items in one batch\n    let batch = batcher.batch(items, &device);\n    let predicted = model.forward(batch.sequences, None);\n    let targets = batch.targets;\n\n    let predicted = predicted.squeeze_dim::<1>(1).into_data();\n    let expected = targets.squeeze_dim::<1>(1).into_data();\n\n    // Display the predicted vs expected values\n    let results = df![\n        \"predicted\" => &predicted.to_vec::<f32>().unwrap(),\n        \"expected\" => &expected.to_vec::<f32>().unwrap(),\n    ]\n    .unwrap();\n    println!(\"{}\", &results.head(Some(10)));\n}\n"
  },
  {
    "path": "examples/modern-lstm/src/lib.rs",
    "content": "pub mod dataset;\npub mod inference;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/modern-lstm/src/model.rs",
    "content": "use burn::{\n    nn::{\n        Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig,\n        LstmState, Sigmoid, Tanh,\n    },\n    prelude::*,\n};\n\n/// LSTM Cell implementation with layer normalization.\n///\n/// Mathematical formulation of LSTM:\n/// f_t = σ(W_f · [h_{t-1}, x_t] + b_f)      # Forget gate\n/// i_t = σ(W_i · [h_{t-1}, x_t] + b_i]      # Input gate\n/// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g]   # Candidate cell state\n/// o_t = σ(W_o · [h_{t-1}, x_t] + b_o)      # Output gate\n///\n/// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t            # New cell state\n/// h_t = o_t ⊙ tanh(c_t)                       # New hidden state\n///\n/// where:\n/// - σ is the sigmoid function\n/// - ⊙ is the element-wise multiplication\n/// - [h_{t-1}, x_t] represents concatenation\n\n#[derive(Module, Debug)]\npub struct LstmCell<B: Backend> {\n    pub hidden_size: usize,\n    // Combined weight matrices for efficiency\n    // weight_ih layer uses combined weights for [i_t, f_t, g_t, o_t] for input x_t\n    // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1}\n    pub weight_ih: Linear<B>,\n    pub weight_hh: Linear<B>,\n    // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM.\n    pub norm_x: LayerNorm<B>, // Normalize gate pre-activations\n    pub norm_h: LayerNorm<B>, // Normalize hidden state\n    pub norm_c: LayerNorm<B>, // Normalize cell state\n    pub dropout: Dropout,\n}\n\n/// Configuration to create a Lstm module using the init function.\n#[derive(Config, Debug)]\npub struct LstmCellConfig {\n    // The size of the input features\n    pub input_size: usize,\n    // The size of the hidden state\n    pub hidden_size: usize,\n    // The number of hidden layers\n    pub dropout: f64,\n}\n\nimpl LstmCellConfig {\n    // Initialize parameters using best practices:\n    // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn)\n    // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training\n    #[allow(clippy::single_range_in_vec_init)]\n    pub fn init<B: Backend>(&self, device: &B::Device) -> LstmCell<B> {\n        let initializer = Initializer::XavierNormal { gain: 1.0 };\n        let init_bias = Tensor::<B, 1>::ones([self.hidden_size], device);\n\n        let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size)\n            .with_initializer(initializer.clone())\n            .init(device);\n        // Set forget gate bias to 1.0 (helps with learning long sequences)\n        let bias = weight_ih\n            .bias\n            .clone()\n            .unwrap()\n            .val()\n            .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone());\n        weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias));\n\n        let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size)\n            .with_initializer(initializer)\n            .init(device);\n        let bias = weight_hh\n            .bias\n            .clone()\n            .unwrap()\n            .val()\n            .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias);\n        weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias));\n\n        LstmCell {\n            hidden_size: self.hidden_size,\n            weight_ih,\n            weight_hh,\n            norm_x: LayerNormConfig::new(4 * self.hidden_size).init(device),\n            norm_h: LayerNormConfig::new(self.hidden_size).init(device),\n            norm_c: LayerNormConfig::new(self.hidden_size).init(device),\n            dropout: DropoutConfig::new(self.dropout).init(),\n        }\n    }\n}\n\nimpl<B: Backend> LstmCell<B> {\n    /// Forward pass of LSTM cell.\n    /// Args:\n    ///     x: Input tensor of shape (batch_size, input_size)\n    ///     state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size)\n    /// Returns:\n    ///  Tuple of (h_t, c_t) representing new hidden and cell states\n    pub fn forward(&self, x: Tensor<B, 2>, state: LstmState<B, 2>) -> LstmState<B, 2> {\n        let (h_prev, c_prev) = (state.hidden, state.cell);\n\n        // Combined matrix multiplication for all gates\n        // Shape: (batch_size, 4 * hidden_size)\n        let gates_x = self.weight_ih.forward(x); // Transform input\n        let gates_h = self.weight_hh.forward(h_prev); // Transform previous hidden state\n\n        // Apply layer normalization\n        let gates_x = self.norm_x.forward(gates_x);\n        // Combined gate pre-activations\n        let gates = gates_x + gates_h;\n\n        // Split into individual gates\n        // Each gate shape: (batch_size, hidden_size)\n        let gates = gates.chunk(4, 1);\n        let i_gate = gates[0].clone();\n        let f_gate = gates[1].clone();\n        let g_gate = gates[2].clone();\n        let o_gate = gates[3].clone();\n\n        // Apply gate non-linearities\n        let i_t = Sigmoid::new().forward(i_gate);\n        let f_t = Sigmoid::new().forward(f_gate);\n        let g_t = Tanh::new().forward(g_gate);\n        let o_t = Sigmoid::new().forward(o_gate);\n\n        // Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t\n        let c_t = f_t * c_prev + i_t * g_t;\n        let c_t = self.norm_c.forward(c_t);\n\n        // Update cell state: h_t = o_t ⊙ tanh(c_t)\n        let h_t = o_t * Tanh::new().forward(c_t.clone());\n        let h_t = self.norm_h.forward(h_t);\n\n        let h_t = self.dropout.forward(h_t);\n\n        LstmState::new(h_t, c_t)\n    }\n\n    // Initialize cell state and hidden state if provided or with zeros\n    pub fn init_state(&self, batch_size: usize, device: &B::Device) -> LstmState<B, 2> {\n        let cell = Tensor::zeros([batch_size, self.hidden_size], device);\n        let hidden = Tensor::zeros([batch_size, self.hidden_size], device);\n\n        LstmState::new(cell, hidden)\n    }\n}\n\n/// Stacked LSTM implementation supporting multiple layers\n/// Each layer processes the output of the previous layer\n#[derive(Module, Debug)]\npub struct StackedLstm<B: Backend> {\n    pub layers: Vec<LstmCell<B>>,\n}\n\n#[derive(Config, Debug)]\npub struct StackedLstmConfig {\n    pub input_size: usize,\n    pub hidden_size: usize,\n    pub num_layers: usize,\n    pub dropout: f64,\n}\n\nimpl StackedLstmConfig {\n    pub fn init<B: Backend>(&self, device: &B::Device) -> StackedLstm<B> {\n        let mut layers: Vec<LstmCell<B>> = vec![];\n        // Create list of LSTM cells, one for each layer\n        for i in 0..self.num_layers {\n            if i == 0 {\n                if i < self.num_layers - 1 {\n                    layers.push(\n                        LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout)\n                            .init(device),\n                    );\n                } else {\n                    // No dropout on last layer\n                    layers.push(\n                        LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device),\n                    );\n                }\n            } else if i < self.num_layers - 1 {\n                layers.push(\n                    LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout)\n                        .init(device),\n                );\n            } else {\n                // No dropout on last layer\n                layers.push(\n                    LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device),\n                );\n            }\n        }\n        StackedLstm { layers }\n    }\n}\n\nimpl<B: Backend> StackedLstm<B> {\n    /// Process input sequence through stacked LSTM layers.\n    ///\n    /// Args:\n    ///     x: Input tensor of shape (batch_size, seq_length, input_size)\n    ///     states: Optional initial states for each layer\n    ///\n    /// Returns:\n    ///     Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size)\n    ///     and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size)\n    pub fn forward(\n        &self,\n        x: Tensor<B, 3>,\n        states: Option<Vec<LstmState<B, 2>>>,\n    ) -> (Tensor<B, 3>, Vec<LstmState<B, 2>>) {\n        let [batch_size, seq_length, _] = x.dims();\n        let device = x.device();\n\n        let mut states = match states {\n            None => {\n                let mut temp: Vec<LstmState<B, 2>> = vec![];\n                for layer in self.layers.iter() {\n                    temp.push(layer.init_state(batch_size, &device));\n                }\n                temp\n            }\n            _ => states.unwrap(),\n        };\n\n        let mut layer_outputs = vec![];\n        for t in 0..seq_length {\n            let mut input_t = x.clone().slice(s![.., t..t + 1, ..]).squeeze_dim::<2>(1);\n            for (i, lstm_cell) in self.layers.iter().enumerate() {\n                let mut state: LstmState<B, 2> =\n                    LstmState::new(states[i].cell.clone(), states[i].hidden.clone());\n                state = lstm_cell.forward(input_t, state);\n                input_t = state.hidden.clone();\n                states[i] = state;\n            }\n            layer_outputs.push(input_t);\n        }\n\n        // Stack output along sequence dimension\n        let output = Tensor::stack(layer_outputs, 1);\n\n        (output, states)\n    }\n}\n\n/// Complete LSTM network with bidirectional support.\n///\n/// In bidirectional mode:\n/// - Forward LSTM processes sequence from left to right\n/// - Backward LSTM processes sequence from right to left\n/// - Outputs are concatenated for final prediction\n#[derive(Module, Debug)]\npub struct LstmNetwork<B: Backend> {\n    // Forward direction LSTM\n    pub stacked_lstm: StackedLstm<B>,\n    // Optional backward direction LSTM for bidirectional processing\n    pub reverse_lstm: Option<StackedLstm<B>>,\n    pub dropout: Dropout,\n    pub fc: Linear<B>,\n}\n\n#[derive(Config, Debug)]\npub struct LstmNetworkConfig {\n    #[config(default = 1)]\n    pub input_size: usize, // Single feature (number sequence)\n    #[config(default = 32)]\n    pub hidden_size: usize, // Size of LSTM hidden state\n    #[config(default = 2)]\n    pub num_layers: usize, // Number of LSTM layers\n    #[config(default = 1)]\n    pub output_size: usize, // Predict one number\n    #[config(default = 0.1)]\n    pub dropout: f64,\n    #[config(default = true)]\n    pub bidirectional: bool, // Use bidirectional LSTM\n}\n\nimpl LstmNetworkConfig {\n    pub fn init<B: Backend>(&self, device: &B::Device) -> LstmNetwork<B> {\n        // Forward direction LSTM\n        let stacked_lstm = StackedLstmConfig::new(\n            self.input_size,\n            self.hidden_size,\n            self.num_layers,\n            self.dropout,\n        )\n        .init(device);\n\n        // Optional backward direction LSTM for bidirectional processing\n        let (reverse_lstm, hidden_size) = if self.bidirectional {\n            let lstm = StackedLstmConfig::new(\n                self.input_size,\n                self.hidden_size,\n                self.num_layers,\n                self.dropout,\n            )\n            .init(device);\n            (Some(lstm), 2 * self.hidden_size)\n        } else {\n            (None, self.hidden_size)\n        };\n\n        let fc = LinearConfig::new(hidden_size, self.output_size).init(device);\n        let dropout = DropoutConfig::new(self.dropout).init();\n\n        LstmNetwork {\n            stacked_lstm,\n            reverse_lstm,\n            dropout,\n            fc,\n        }\n    }\n}\n\nimpl<B: Backend> LstmNetwork<B> {\n    /// Forward pass of the network.\n    ///\n    /// For bidirectional processing:\n    /// 1. Process sequence normally with forward LSTM\n    /// 2. Process reversed sequence with backward LSTM\n    /// 3. Concatenate both outputs\n    /// 4. Apply final linear transformation\n    ///\n    /// Args:\n    ///     x: Input tensor of shape (batch_size, seq_length, input_size)\n    ///     states: Optional initial states\n    ///\n    /// Returns:\n    ///     Output tensor of shape (batch_size, output_size)\n    pub fn forward(&self, x: Tensor<B, 3>, states: Option<Vec<LstmState<B, 2>>>) -> Tensor<B, 2> {\n        let seq_length = x.dims()[1];\n        // Forward direction\n        let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states);\n\n        output = match &self.reverse_lstm {\n            Some(reverse_lstm) => {\n                //Process sequence in reverse direction\n                let (mut reverse_output, _states) = reverse_lstm.forward(x.flip([1]), None);\n                // Flip back to align with forward sequence\n                reverse_output = reverse_output.flip([1]);\n                // Concatenate forward and backward outputs along the feature dimension\n                output = Tensor::cat(vec![output, reverse_output], 2);\n                output\n            }\n            None => output,\n        };\n\n        // Apply dropout before final layer\n        output = self.dropout.forward(output);\n        // Use final timestep output for prediction\n        self.fc.forward(\n            output\n                .slice(s![.., seq_length - 1..seq_length, ..])\n                .squeeze_dim::<2>(1),\n        )\n    }\n}\n"
  },
  {
    "path": "examples/modern-lstm/src/training.rs",
    "content": "use crate::dataset::{\n    NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH, SequenceBatcher, SequenceDataset,\n};\nuse crate::model::{LstmNetwork, LstmNetworkConfig};\nuse burn::{\n    data::dataloader::DataLoaderBuilder,\n    module::AutodiffModule,\n    nn::loss::{MseLoss, Reduction::Mean},\n    optim::{AdamConfig, GradientsParams, Optimizer},\n    prelude::*,\n    record::CompactRecorder,\n    tensor::backend::AutodiffBackend,\n};\n\n#[derive(Config, Debug)]\npub struct TrainingConfig {\n    pub model: LstmNetworkConfig,\n    pub optimizer: AdamConfig,\n\n    #[config(default = 30)]\n    pub num_epochs: usize,\n    #[config(default = 32)]\n    pub batch_size: usize,\n    #[config(default = 2)]\n    pub num_workers: usize,\n    #[config(default = 1e-3)]\n    pub lr: f64,\n}\n\n// Create the directory to save the model and model config\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {\n    create_artifact_dir(artifact_dir);\n\n    // Save training config\n    config\n        .save(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should be saved successfully\");\n    B::seed(&device, RANDOM_SEED);\n\n    // Create the model and optimizer\n    let mut model = config.model.init::<B>(&device);\n    let mut optim = config.optimizer.init::<B, LstmNetwork<B>>();\n\n    // Create the batcher\n    let batcher = SequenceBatcher::default();\n\n    // Create the dataloaders\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .shuffle(RANDOM_SEED)\n        .num_workers(config.num_workers)\n        .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL));\n\n    let dataloader_valid = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .shuffle(RANDOM_SEED)\n        .num_workers(config.num_workers)\n        // 20% size of training\n        .build(SequenceDataset::new(\n            NUM_SEQUENCES / 5,\n            SEQ_LENGTH,\n            NOISE_LEVEL,\n        ));\n\n    let train_num_items = dataloader_train.num_items();\n    let valid_num_items = dataloader_valid.num_items();\n\n    println!(\"Starting training...\");\n    // Iterate over our training for X epochs\n    for epoch in 1..config.num_epochs + 1 {\n        // Initialize the training and validation metrics at the start of each epoch\n        let mut train_losses = vec![];\n        let mut train_loss = 0.0;\n        let mut valid_losses = vec![];\n        let mut valid_loss = 0.0;\n\n        // Implement our training loop\n        for batch in dataloader_train.iter() {\n            let output = model.forward(batch.sequences, None);\n            let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);\n            train_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;\n\n            // Gradients for the current backward pass\n            let grads = loss.backward();\n            // Gradients linked to each parameter of the model\n            let grads = GradientsParams::from_grads(grads, &model);\n            // Update the model using the optimizer\n            model = optim.step(config.lr, model, grads);\n        }\n\n        // The averaged train loss per epoch\n        let avg_train_loss = train_loss / train_num_items as f32;\n        train_losses.push(avg_train_loss);\n\n        // Get the model without autodiff\n        let valid_model = model.valid();\n\n        // Implement our validation loop\n        for batch in dataloader_valid.iter() {\n            let output = valid_model.forward(batch.sequences, None);\n            let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean);\n            valid_loss += loss.clone().into_scalar().elem::<f32>() * batch.targets.dims()[0] as f32;\n        }\n        // The averaged train loss per epoch\n        let avg_valid_loss = valid_loss / valid_num_items as f32;\n        valid_losses.push(avg_valid_loss);\n\n        // Display the averaged training and validation metrics every 10 epochs\n        if (epoch + 1) % 5 == 0 {\n            println!(\n                \"Epoch {}/{}, Avg Loss {:.4}, Avg Val Loss: {:.4}\",\n                epoch + 1,\n                config.num_epochs,\n                avg_train_loss,\n                avg_valid_loss,\n            );\n        }\n    }\n\n    // Save the trained model\n    model\n        .save_file(format!(\"{artifact_dir}/model\"), &CompactRecorder::new())\n        .expect(\"Trained model should be saved successfully\");\n}\n"
  },
  {
    "path": "examples/multi-gpus/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"multi-gpus\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = []\nf16 = []\nflex32 = []\ntch-gpu = [\"burn/tch\"]\ncuda = [\"burn/cuda\"]\nrocm = [\"burn/rocm\"]\n\n[dependencies]\n# Burn\nburn = { path = \"../../crates/burn\", features = [\n    \"autotune\",\n    \"fusion\",\n    \"collective\",\n    \"train\",\n    \"std\",\n], default-features = false }\ntext-classification = { path = \"../text-classification\" }\n"
  },
  {
    "path": "examples/multi-gpus/examples/multi-gpus.rs",
    "content": "fn main() {\n    #[cfg(feature = \"cuda\")]\n    multi_gpus::run::<burn::backend::Cuda>();\n    #[cfg(feature = \"rocm\")]\n    multi_gpus::run::<burn::backend::Rocm>();\n    #[cfg(feature = \"tch-gpu\")]\n    multi_gpus::run::<burn::backend::LibTorch>();\n}\n"
  },
  {
    "path": "examples/multi-gpus/src/lib.rs",
    "content": "use burn::{\n    backend::Autodiff,\n    collective::{self, CollectiveConfig, PeerId, ReduceOperation},\n    data::{dataloader::DataLoaderBuilder, dataset::transform::PartialDataset},\n    nn::transformer::TransformerEncoderConfig,\n    optim::{GradientsParams, Optimizer, SgdConfig},\n    prelude::*,\n    tensor::{\n        TensorPrimitive,\n        backend::{AutodiffBackend, DeviceId},\n    },\n};\nuse std::{\n    sync::{Arc, mpsc::SyncSender},\n    time::Instant,\n};\nuse text_classification::{\n    AgNewsDataset, TextClassificationDataset,\n    data::{TextClassificationBatcher, Tokenizer},\n    model::TextClassificationModel,\n};\n\npub fn run<B: Backend>() {\n    let type_id = 0;\n    let num_devices = B::Device::device_count(type_id);\n\n    let devices = (0..num_devices)\n        .map(|i| B::Device::from_id(DeviceId::new(type_id, i as u32)))\n        .collect();\n\n    run_with::<B>(devices);\n}\n\nfn run_with<B: Backend>(devices: Vec<B::Device>) {\n    for strategy in [\n        collective::AllReduceStrategy::Tree(2),\n        collective::AllReduceStrategy::Ring,\n        collective::AllReduceStrategy::Centralized,\n    ] {\n        println!(\"[Gradient Update - {strategy:?}] starting ...\");\n        let start = Instant::now();\n        task_grad_all_reduce::<Autodiff<B>>(devices.clone(), strategy);\n        println!(\n            \"[Gradient Update - {strategy:?}] took {:?}\",\n            start.elapsed()\n        );\n    }\n    for strategy in [\n        collective::AllReduceStrategy::Centralized,\n        collective::AllReduceStrategy::Ring,\n        collective::AllReduceStrategy::Tree(2),\n    ] {\n        println!(\"[All Reduce - {strategy:?}] starting ...\");\n        let start = Instant::now();\n        task_all_reduce::<B>(devices.clone(), 420, strategy);\n        println!(\"[All Reduce - {strategy:?}] took {:?}\", start.elapsed());\n    }\n    task_naive_aggregation::<B>(devices.clone(), 100);\n}\n\nfn task_naive_aggregation<B: Backend>(mut devices: Vec<B::Device>, num_iterations: usize) {\n    let aggregation_device = devices.pop().unwrap();\n\n    let shape = [8, 4096, 4096];\n\n    let (sender, receiver) = std::sync::mpsc::sync_channel(devices.len());\n\n    fn compute<B: Backend>(input: Tensor<B, 3>) -> Tensor<B, 3> {\n        let log = input.clone() + 1.0;\n        input.matmul(log)\n    }\n\n    let mut handles = devices\n        .into_iter()\n        .map(|device| {\n            let sender = sender.clone();\n            std::thread::spawn(move || {\n                let input =\n                    Tensor::<B, 3>::random(shape, burn::tensor::Distribution::Default, &device);\n\n                for _ in 0..num_iterations {\n                    let new = compute(input.clone());\n                    sender.send(new.clone()).unwrap();\n                }\n            })\n        })\n        .collect::<Vec<_>>();\n\n    handles.push(std::thread::spawn(move || {\n        let mut input = Tensor::<B, 3>::random(\n            shape,\n            burn::tensor::Distribution::Default,\n            &aggregation_device,\n        );\n\n        while let Ok(tensor) = receiver.recv() {\n            let main = tensor.to_device(&aggregation_device);\n            let value = main.clone().sum().into_scalar().elem::<f32>();\n            input = input + main / 2;\n            println!(\"{value:?}\");\n            assert_ne!(value, 0.0);\n        }\n    }));\n\n    for handle in handles {\n        handle.join().unwrap();\n    }\n}\n\nfn task_all_reduce<B: Backend>(\n    devices: Vec<B::Device>,\n    num_iterations: usize,\n    strategy: collective::AllReduceStrategy,\n) {\n    let num_devices = devices.len();\n    let batch = 32;\n    let shape_signal = [batch, 2048, 2048];\n    let shape_weights = [1, 2048, 2048];\n\n    fn compute<B: Backend>(weights: Tensor<B, 3>, signal: Tensor<B, 3>) -> Tensor<B, 3> {\n        weights.matmul(signal)\n    }\n\n    let handles = devices\n        .into_iter()\n        .enumerate()\n        .map(|(id, device)| {\n            std::thread::spawn(move || {\n                let mut weights = Tensor::<B, 3>::random(\n                    shape_weights,\n                    burn::tensor::Distribution::Default,\n                    &device,\n                ) - 0.5;\n\n                let id = PeerId::from(id);\n                let config = CollectiveConfig::default()\n                    .with_num_devices(num_devices)\n                    .with_local_all_reduce_strategy(strategy);\n\n                collective::register::<B>(id, device.clone(), config).unwrap();\n\n                for i in 0..num_iterations {\n                    let signal = Tensor::<B, 3>::random(\n                        shape_signal,\n                        burn::tensor::Distribution::Default,\n                        &device,\n                    ) - 0.5;\n                    let signal = compute(weights, signal);\n                    let weights_update = signal.mean_dim(0);\n\n                    let result = collective::all_reduce::<B>(\n                        id,\n                        weights_update.into_primitive().tensor(),\n                        ReduceOperation::Mean,\n                    )\n                    .unwrap();\n                    weights = Tensor::from_primitive(TensorPrimitive::Float(result));\n                    let val = weights.clone().sum().into_scalar().elem::<f32>();\n                    if id == PeerId::from(0) {\n                        println!(\"Iter {i} => {val}\");\n                    }\n                }\n                collective::finish_collective::<B>(id).unwrap();\n            })\n        })\n        .collect::<Vec<_>>();\n\n    for handle in handles {\n        handle.join().unwrap();\n    }\n}\n\nfn task_grad_all_reduce<B: AutodiffBackend>(\n    devices: Vec<B::Device>,\n    strategy: collective::AllReduceStrategy,\n) {\n    let num_devices = devices.len();\n    let seq_length = nn::attention::SeqLengthOption::Fixed(512);\n    let batch_size = 32;\n    let config = TransformerEncoderConfig::new(256, 1024, 8, 4);\n\n    let dataset = text_classification::AgNewsDataset::train();\n    let tokenizer = Arc::new(text_classification::data::BertCasedTokenizer::default());\n    let model_config = text_classification::model::TextClassificationModelConfig::new(\n        config,\n        AgNewsDataset::num_classes(),\n        tokenizer.vocab_size(),\n        seq_length,\n    );\n    let datasets = PartialDataset::split(dataset, devices.len());\n    let model_main = model_config.init(&devices[0]);\n\n    let handles = devices\n        .into_iter()\n        .zip(datasets)\n        .enumerate()\n        .map(|(id, (device, dataset))| {\n            let model_main = model_main.clone();\n            let tokenizer = tokenizer.clone();\n\n            std::thread::spawn(move || {\n                println!(\"[{id}] Running on device {device:?}\");\n                let mut model = model_main.fork(&device);\n                let batcher = TextClassificationBatcher::new(tokenizer, seq_length);\n                let dataloader_train = DataLoaderBuilder::new(batcher)\n                    .batch_size(batch_size)\n                    .set_device(device.clone())\n                    .build(dataset);\n\n                let syncher = GradSyncer::start::<B>(\n                    CollectiveConfig::default()\n                        .with_num_devices(num_devices)\n                        .with_local_all_reduce_strategy(strategy),\n                    device.clone(),\n                    PeerId::from(id),\n                );\n\n                let mut optim = SgdConfig::new().init::<B, TextClassificationModel<B>>();\n\n                for (i, batch) in dataloader_train.iter().enumerate() {\n                    let output = model.forward(batch);\n                    let loss: Tensor<B, 1> = output.loss.clone();\n\n                    let grads = loss.backward();\n                    let loss = loss.into_scalar().elem::<f32>();\n\n                    let grads = GradientsParams::from_grads(grads, &model);\n                    let grads = syncher.sync(grads);\n\n                    if let Some(grads) = grads {\n                        model = optim.step(1.0e-5, model, grads);\n                    }\n\n                    println!(\"[{id}] Iter {i} => {loss}\");\n                }\n            })\n        })\n        .collect::<Vec<_>>();\n\n    for handle in handles {\n        handle.join().unwrap();\n    }\n}\n\nstruct GradSyncer {\n    sender: SyncSender<Message>,\n}\n\nstruct Message {\n    callback: SyncSender<Option<GradientsParams>>,\n    grads: GradientsParams,\n}\n\nimpl GradSyncer {\n    fn start<B: AutodiffBackend>(config: CollectiveConfig, device: Device<B>, id: PeerId) -> Self {\n        let (sender, receiver) = std::sync::mpsc::sync_channel::<Message>(8);\n\n        std::thread::spawn(move || {\n            println!(\"[{id}] Register collective operation {config:?}\");\n            collective::register::<B::InnerBackend>(id, device, config).unwrap();\n            let num_stages = 4;\n            let mut buffers: Vec<GradientsParams> = Vec::new();\n\n            while let Ok(msg) = receiver.recv() {\n                let grads = msg\n                    .grads\n                    .all_reduce::<B::InnerBackend>(id, ReduceOperation::Mean)\n                    .unwrap();\n\n                buffers.push(grads);\n\n                let result = if buffers.len() >= num_stages {\n                    Some(buffers.remove(0))\n                } else {\n                    None\n                };\n\n                msg.callback.send(result).unwrap();\n            }\n            collective::finish_collective::<B::InnerBackend>(id).unwrap();\n        });\n\n        Self { sender }\n    }\n\n    fn sync(&self, grads: GradientsParams) -> Option<GradientsParams> {\n        let (sender, receiver) = std::sync::mpsc::sync_channel(1);\n        let msg = Message {\n            callback: sender,\n            grads,\n        };\n        self.sender.send(msg).unwrap();\n\n        receiver.recv().unwrap()\n    }\n}\n"
  },
  {
    "path": "examples/notebook/README.md",
    "content": "# Jupyter Notebook Examples with Burn\n\nThis directory includes Jupyter Notebook examples showcasing the usage of the Burn deep learning\nframework in Rust through\n[Evcxr Jupyter](https://github.com/evcxr/evcxr/blob/main/evcxr_jupyter/README.md). The examples are\nsystematically organized based on the specific Burn features they illustrate.\n\n## Viewing Options\n\nYou can explore the examples in different ways:\n\n- **Notebook Viewer:** If you prefer not to set up the entire crate package, you can view the\n  examples in a notebook viewer or run them to see images and other media outputs.\n\n- **Visual Studio Code (vscode):** If you're using vscode, you already have access to a built-in\n  notebook viewer, enabling you to open and interact with the notebook files directly.\n\nFor other editors, you can utilize the [Jupyter Notebook Viewer](https://nbviewer.jupyter.org/).\n\n## Getting Started with Rust and Evcxr\n\nTo execute the Rust code within the notebooks, you must install the Evcxr kernel. Here's how to get\nstarted:\n\n### Install Evcxr Kernel\n\n1. **Build Evcxr Kernel:** Install the required package with the following command:\n\n   ```shell\n   cargo install evcxr_jupyter\n   ```\n\n2. **Install and Register the Kernel to Jupyter:**\n   ```shell\n   evcxr_jupyter --install\n   ```\n\n### Open and Run Notebooks\n\nOnce the kernel is installed, you can open the notebook files in your preferred editor and run the\ncode. Ensure that the kernel is set to `Rust` within the notebook for proper execution.\n\n## Additional Reading Resources\n\n- [Notebook Special Commands for Evcxr](https://github.com/evcxr/evcxr/blob/main/COMMON.md): Learn\n  about the unique commands and functionalities offered by Evcxr for a more efficient workflow with\n  Jupyter Notebooks.\n"
  },
  {
    "path": "examples/notebook/autodiff.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Autodifferentiation and Gradient Descent in Burn\\n\",\n    \"\\n\",\n    \"This notebook demonstrates how to use automatic differentiation in Burn to compute gradients and implement gradient descent.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"// Dependency declarations\\n\",\n    \":dep burn = {path = \\\"../../crates/burn\\\"}\\n\",\n    \":dep burn-ndarray = {path = \\\"../../crates/burn-ndarray\\\"}\\n\",\n    \":dep burn-autodiff = {path = \\\"../../crates/burn-autodiff\\\"}\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"// Import packages\\n\",\n    \"use burn::prelude::*;\\n\",\n    \"use burn_autodiff::Autodiff;\\n\",\n    \"use burn_ndarray::NdArray;\\n\",\n    \"\\n\",\n    \"// Type alias: Autodiff<NdArray> enables automatic differentiation\\n\",\n    \"type B = Autodiff<NdArray<f32>>;\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 1. Understanding require_grad()\\n\",\n    \"\\n\",\n    \"In Burn, tensors can be marked for gradient tracking using `.require_grad()`. This tells the framework to track operations on this tensor so gradients can be computed later.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Regular tensor x: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.0, 3.0, 4.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Tensor y with require_grad: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.0, 3.0, 4.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"result = sum(y * 2) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[20.0],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"\\n\",\n    \"// Create a regular tensor - no gradient tracking\\n\",\n    \"let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device);\\n\",\n    \"println!(\\\"Regular tensor x: {}\\\", x);\\n\",\n    \"\\n\",\n    \"// Create a tensor that requires gradient computation\\n\",\n    \"let y: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\\n\",\n    \"println!(\\\"Tensor y with require_grad: {}\\\", y);\\n\",\n    \"\\n\",\n    \"// Now let's do some operations on y\\n\",\n    \"let z = y.clone() * 2.0;\\n\",\n    \"let result = z.sum();\\n\",\n    \"println!(\\\"result = sum(y * 2) = {}\\\", result);\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. Computing Gradients with backward()\\n\",\n    \"\\n\",\n    \"The `.backward()` method computes the gradients of all tensors that have `require_grad()` set. It returns a gradients object that holds the computed gradients.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"y = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.0, 3.0, 4.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"dy/dx = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[2.0, 2.0, 2.0, 2.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Example: y = [1, 2, 3, 4]\\n\",\n    \"// z = y * 2 = [2, 4, 6, 8]\\n\",\n    \"// result = sum(z) = 20\\n\",\n    \"//\\n\",\n    \"// d(result)/d(y) = d(result)/dz * dz/dy = 1 * 2 = [2, 2, 2, 2]\\n\",\n    \"\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"let y: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\\n\",\n    \"let z = y.clone() * 2.0;\\n\",\n    \"let result = z.sum();\\n\",\n    \"\\n\",\n    \"// Compute gradients\\n\",\n    \"let grads = result.backward();\\n\",\n    \"\\n\",\n    \"// Get gradient for y\\n\",\n    \"let y_grad = y.grad(&grads).unwrap();\\n\",\n    \"println!(\\\"y = {}\\\", y);\\n\",\n    \"println!(\\\"d(result)/dy = {}\\\", y_grad);\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 3. More Complex Example: Quadratic Function\\n\",\n    \"Let's compute the gradient of a more complex function: f(x) = x²\\n\",\n    \"\\n\",\n    \"The derivative is: f'(x) = 2x\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"x = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.0, 3.0, 4.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"x^2 = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 4.0, 9.0, 16.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"d(x^2)/dx = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[2.0, 4.0, 6.0, 8.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Expected: [2, 4, 6, 8]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// f(x) = x^2\\n\",\n    \"// f'(x) = 2x\\n\",\n    \"\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"let x: Tensor<B, 1> = Tensor::from_floats([1.0, 2.0, 3.0, 4.0], &device).require_grad();\\n\",\n    \"let y = x.clone().powf_scalar(2.0);\\n\",\n    \"let result = y.clone().sum();\\n\",\n    \"\\n\",\n    \"let grads = result.backward();\\n\",\n    \"let x_grad = x.grad(&grads).unwrap();\\n\",\n    \"\\n\",\n    \"println!(\\\"x = {}\\\", x);\\n\",\n    \"println!(\\\"x^2 = {}\\\", y);\\n\",\n    \"println!(\\\"d(x^2)/dx = {}\\\", x_grad);\\n\",\n    \"\\n\",\n    \"// Verify: d(x^2)/dx should be [2, 4, 6, 8]\\n\",\n    \"println!(\\\"Expected: [2, 4, 6, 8]\\\");\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 4. Chain Rule Example\\n\",\n    \"\\n\",\n    \"Let's verify the chain rule: f(g(x))' = f'(g(x)) * g'(x)\\n\",\n    \"\\n\",\n    \"Example: y = sin(x²), we want dy/dx\\n\",\n    \"\\n\",\n    \"Let u = x², y = sin(u)\\n\",\n    \"dy/du = cos(u), du/dx = 2x\\n\",\n    \"dy/dx = cos(x²) * 2x\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"x = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0, 2.0, 3.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"y = sin(x^2) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 0.84147096, -0.7568025, 0.4121185],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"dy/dx = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0806046, -2.6145744, -5.4667816],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Expected (cos(x^2) * 2x): Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0806046, -2.6145744, -5.4667816],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"autodiff<ndarray>\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// y = sin(x^2)\\n\",\n    \"// dy/dx = cos(x^2) * 2x\\n\",\n    \"\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"let x: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device).require_grad();\\n\",\n    \"\\n\",\n    \"// Forward pass\\n\",\n    \"let x_squared = x.clone().powf_scalar(2.0);\\n\",\n    \"let y = x_squared.sin();\\n\",\n    \"let result = y.clone().sum();\\n\",\n    \"\\n\",\n    \"// Backward pass\\n\",\n    \"let grads = result.backward();\\n\",\n    \"let x_grad = x.grad(&grads).unwrap();\\n\",\n    \"\\n\",\n    \"println!(\\\"x = {}\\\", x);\\n\",\n    \"println!(\\\"y = sin(x^2) = {}\\\", y);\\n\",\n    \"println!(\\\"dy/dx = {}\\\", x_grad);\\n\",\n    \"\\n\",\n    \"// Verify manually: cos(x^2) * 2x\\n\",\n    \"let expected_grad = x.clone().powf_scalar(2.0).cos() * (x.clone() * 2.0);\\n\",\n    \"println!(\\\"Expected (cos(x^2) * 2x): {}\\\", expected_grad);\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 5. Gradient Descent from Scratch\\n\",\n    \"\\n\",\n    \"Now let's implement the classic gradient descent algorithm to find the minimum of a function.\\n\",\n    \"\\n\",\n    \"We'll minimize: f(x) = (x - 3)²\\n\",\n    \"\\n\",\n    \"The minimum is at x = 3, where f(x) = 0\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Starting gradient descent to minimize (x - 3)^2\\n\",\n      \"Expected minimum: x = 3\\n\",\n      \"---\\n\",\n      \"Iteration 0: x = 0.0000, loss = 9.0000\\n\",\n      \"Iteration 1: x = 0.6000, loss = 5.7600\\n\",\n      \"Iteration 2: x = 1.0800, loss = 3.6864\\n\",\n      \"Iteration 3: x = 1.4640, loss = 2.3593\\n\",\n      \"Iteration 4: x = 1.7712, loss = 1.5099\\n\",\n      \"Iteration 5: x = 2.0170, loss = 0.9664\\n\",\n      \"Iteration 6: x = 2.2136, loss = 0.6185\\n\",\n      \"Iteration 7: x = 2.3709, loss = 0.3958\\n\",\n      \"Iteration 8: x = 2.4967, loss = 0.2533\\n\",\n      \"Iteration 9: x = 2.5973, loss = 0.1621\\n\",\n      \"Iteration 10: x = 2.6779, loss = 0.1038\\n\",\n      \"Iteration 11: x = 2.7423, loss = 0.0664\\n\",\n      \"Iteration 12: x = 2.7938, loss = 0.0425\\n\",\n      \"Iteration 13: x = 2.8351, loss = 0.0272\\n\",\n      \"Iteration 14: x = 2.8681, loss = 0.0174\\n\",\n      \"Iteration 15: x = 2.8944, loss = 0.0111\\n\",\n      \"Iteration 16: x = 2.9156, loss = 0.0071\\n\",\n      \"Iteration 17: x = 2.9324, loss = 0.0046\\n\",\n      \"Iteration 18: x = 2.9460, loss = 0.0029\\n\",\n      \"Iteration 19: x = 2.9568, loss = 0.0019\\n\",\n      \"---\\n\",\n      \"Final x = 2.9654\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Target: minimize f(x) = (x - 3)^2\\n\",\n    \"// This has minimum at x = 3\\n\",\n    \"\\n\",\n    \"fn loss<B: Backend>(x: &Tensor<B, 1>) -> Tensor<B, 1> {\\n\",\n    \"    // f(x) = (x - 3)^2\\n\",\n    \"    (x.clone() - 3.0).powf_scalar(2.0)\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"// Start from x = 0\\n\",\n    \"let mut x_val: f32 = 0.0;\\n\",\n    \"\\n\",\n    \"let learning_rate: f32 = 0.1;\\n\",\n    \"\\n\",\n    \"println!(\\\"Starting gradient descent to minimize (x - 3)^2\\\");\\n\",\n    \"println!(\\\"Expected minimum: x = 3\\\");\\n\",\n    \"println!(\\\"---\\\");\\n\",\n    \"\\n\",\n    \"for i in 0..20 {\\n\",\n    \"    // Create a new tensor with current x value and require gradients\\n\",\n    \"    let x = Tensor::<B, 1>::from_floats([x_val], &device).require_grad();\\n\",\n    \"    \\n\",\n    \"    // Forward pass\\n\",\n    \"    let loss_value = loss(&x);\\n\",\n    \"    \\n\",\n    \"    // Get loss as f32 for printing\\n\",\n    \"    let loss_scalar: f32 = loss_value.clone().into_scalar().elem::<f32>();\\n\",\n    \"    \\n\",\n    \"    println!(\\\"Iteration {}: x = {:.4}, loss = {:.4}\\\", i, x_val, loss_scalar);\\n\",\n    \"\\n\",\n    \"    // Backward pass\\n\",\n    \"    let grads = loss_value.backward();\\n\",\n    \"    let grad = x.grad(&grads).unwrap();\\n\",\n    \"    \\n\",\n    \"    // Update: x = x - learning_rate * gradient\\n\",\n    \"    let grad_val: f32 = grad.into_scalar().elem::<f32>();\\n\",\n    \"    x_val = x_val - grad_val * learning_rate;\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"println!(\\\"---\\\");\\n\",\n    \"println!(\\\"Final x = {:.4}\\\", x_val);\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 6. Linear Regression with Gradient Descent\\n\",\n    \"\\n\",\n    \"Let's use gradient descent to fit a simple linear regression model: y = wx + b\\n\",\n    \"\\n\",\n    \"We'll generate synthetic data where the true relationship is y = 2x + 1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Generated 100 data points\\n\",\n      \"True relationship: y = 2x + 1\\n\",\n      \"First 5 x values: [0.0, 0.1, 0.2, 0.3, 0.4]\\n\",\n      \"First 5 y values: [0.87993187, 0.98804677, 1.5366085, 1.7324162, 1.653858]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"use burn::tensor::{Distribution, TensorData};\\n\",\n    \"\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"// Generate synthetic data: y = 2x + 1 + noise\\n\",\n    \"let num_samples = 100;\\n\",\n    \"let x_data = TensorData::new((0..num_samples).map(|i| i as f32 / 10.0).collect(), [num_samples, 1]);\\n\",\n    \"// Generate noise using Burn's random tensor\\n\",\n    \"let noise = Tensor::<B, 2>::random([num_samples, 1], Distribution::Uniform(-0.25, 0.25), &device);\\n\",\n    \"\\n\",\n    \"let x = Tensor::<B, 2>::from(x_data);\\n\",\n    \"let y: Tensor<B, 2> = 2 * x.clone() + 1 + noise;\\n\",\n    \"\\n\",\n    \"println!(\\\"Generated {} data points\\\", num_samples);\\n\",\n    \"println!(\\\"True relationship: y = 2x + 1\\\");\\n\",\n    \"println!(\\\"First 5 x values: {}\\\", x.clone().slice([0..5, 0..1]));\\n\",\n    \"println!(\\\"First 5 y values: {}\\\", y.clone().slice([0..5, 0..1]));\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 31,\n   \"metadata\": {\n    \"vscode\": {\n     \"languageId\": \"rust\"\n    }\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Training linear regression with gradient descent...\\n\",\n      \"Initial w = 0.5000, b = 0.5000\\n\",\n      \"Epoch   0: loss = 81.7705, w = 1.5358, b = 0.6586\\n\",\n      \"Epoch  20: loss = 0.0365, w = 2.0384, b = 0.7594\\n\",\n      \"Epoch  40: loss = 0.0341, w = 2.0351, b = 0.7810\\n\",\n      \"Epoch  60: loss = 0.0321, w = 2.0322, b = 0.8006\\n\",\n      \"Epoch  80: loss = 0.0305, w = 2.0295, b = 0.8184\\n\",\n      \"---\\n\",\n      \"Final: w = 2.0272, b = 0.8336\\n\",\n      \"True: w = 2.0, b = 1.0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Initialize weights randomly\\n\",\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"let mut w_val: f32 = 0.5; // Start with reasonable initial values\\n\",\n    \"let mut b_val: f32 = 0.5;\\n\",\n    \"\\n\",\n    \"let learning_rate: f32 = 0.01;\\n\",\n    \"let num_epochs = 100;\\n\",\n    \"\\n\",\n    \"println!(\\\"Training linear regression with gradient descent...\\\");\\n\",\n    \"println!(\\\"Initial w = {:.4}, b = {:.4}\\\", w_val, b_val);\\n\",\n    \"\\n\",\n    \"for epoch in 0..num_epochs {\\n\",\n    \"    // Create tensors with current parameter values\\n\",\n    \"    let w = Tensor::<B, 2>::from_floats([[w_val]], &device).require_grad();\\n\",\n    \"    let b = Tensor::<B, 2>::from_floats([[b_val]], &device).require_grad();\\n\",\n    \"    \\n\",\n    \"    // Forward pass: y_pred = w * x + b\\n\",\n    \"    let y_pred = x.clone().matmul(w.clone()) + b.clone();\\n\",\n    \"    \\n\",\n    \"    // Compute loss: MSE = (1/n) * sum((y_pred - y)^2)\\n\",\n    \"    let loss = (y_pred.clone() - y.clone()).powf_scalar(2.0).mean();\\n\",\n    \"    \\n\",\n    \"    // Backward pass\\n\",\n    \"    let grads = loss.backward();\\n\",\n    \"    let w_grad = w.grad(&grads).unwrap();\\n\",\n    \"    let b_grad = b.grad(&grads).unwrap();\\n\",\n    \"    \\n\",\n    \"    // Update weights\\n\",\n    \"    let w_grad_val: f32 = w_grad.into_scalar().elem::<f32>();\\n\",\n    \"    let b_grad_val: f32 = b_grad.into_scalar().elem::<f32>();\\n\",\n    \"    w_val = w_val - w_grad_val * learning_rate;\\n\",\n    \"    b_val = b_val - b_grad_val * learning_rate;\\n\",\n    \"    \\n\",\n    \"    if epoch % 20 == 0 {\\n\",\n    \"        let loss_val: f32 = loss.clone().into_scalar().elem::<f32>();\\n\",\n    \"        println!(\\\"Epoch {:3}: loss = {:.4}, w = {:.4}, b = {:.4}\\\", epoch, loss_val, w_val, b_val);\\n\",\n    \"    }\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"println!(\\\"---\\\");\\n\",\n    \"println!(\\\"Final: w = {:.4}, b = {:.4}\\\", w_val, b_val);\\n\",\n    \"println!(\\\"True: w = 2.0, b = 1.0\\\");\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Summary\\n\",\n    \"\\n\",\n    \"In this notebook, we covered:\\n\",\n    \"\\n\",\n    \"- **require_grad()**: Mark tensors for gradient tracking\\n\",\n    \"- **backward()**: Compute gradients automatically using reverse-mode autodiff\\n\",\n    \"- **grad()**: Retrieve computed gradients\\n\",\n    \"- **Gradient Descent**: Implemented from scratch to minimize a quadratic function\\n\",\n    \"- **Linear Regression**: Used gradient descent to fit a linear model to data\\n\",\n    \"\\n\",\n    \"These concepts are the foundation of neural network training in Burn!\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Rust\",\n   \"language\": \"rust\",\n   \"name\": \"rust\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"rust\",\n   \"file_extension\": \".rs\",\n   \"mimetype\": \"text/rust\",\n   \"name\": \"Rust\",\n   \"pygment_lexer\": \"rust\",\n   \"version\": \"\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/notebook/basic-tensor-op.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Tensor Operations in Burn\\n\",\n    \"\\n\",\n    \"This notebook demonstrates basic tensor operations in Burn, a deep learning framework written in Rust.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"\\n\",\n    \"// Dependency declarations for the notebook.\\n\",\n    \"// The syntax is similar to Cargo.toml. Just prefix with :dep\\n\",\n    \"\\n\",\n    \":dep burn = {path = \\\"../../crates/burn\\\"}\\n\",\n    \":dep burn-ndarray = {path = \\\"../../crates/burn-ndarray\\\"}\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"// Import packages\\n\",\n    \"use burn::prelude::*;\\n\",\n    \"use burn_ndarray::NdArray;\\n\",\n    \"\\n\",\n    \"// Type alias for the backend (using CPU/NdArray)\\n\",\n    \"type B = NdArray<f32>;\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 1. Tensor Creation\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Empty tensor shape: Shape { dims: [2, 3, 4] }\\n\",\n      \"Zeros tensor: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[0.0, 0.0, 0.0],\\n\",\n      \" [0.0, 0.0, 0.0],\\n\",\n      \" [0.0, 0.0, 0.0]],\\n\",\n      \"  shape:  [3, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Ones tensor: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 1.0, 1.0, 1.0],\\n\",\n      \" [1.0, 1.0, 1.0, 1.0]],\\n\",\n      \"  shape:  [2, 4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Full tensor (7.0): Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[7.0, 7.0, 7.0],\\n\",\n      \" [7.0, 7.0, 7.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let device = <B as Backend>::Device::default();\\n\",\n    \"\\n\",\n    \"// Create an empty tensor (uninitialized values)\\n\",\n    \"let empty: Tensor<B, 3> = Tensor::empty([2, 3, 4], &device);\\n\",\n    \"println!(\\\"Empty tensor shape: {:?}\\\", empty.shape());\\n\",\n    \"\\n\",\n    \"// Create a tensor filled with zeros\\n\",\n    \"let zeros: Tensor<B, 2> = Tensor::zeros([3, 3], &device);\\n\",\n    \"println!(\\\"Zeros tensor: {}\\\", zeros);\\n\",\n    \"\\n\",\n    \"// Create a tensor filled with ones\\n\",\n    \"let ones: Tensor<B, 2> = Tensor::ones([2, 4], &device);\\n\",\n    \"println!(\\\"Ones tensor: {}\\\", ones);\\n\",\n    \"\\n\",\n    \"// Create a tensor filled with a specific value\\n\",\n    \"let full: Tensor<B, 2> = Tensor::full([2, 3], 7.0, &device);\\n\",\n    \"println!(\\\"Full tensor (7.0): {}\\\", full);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"From slice:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0, 3.0],\\n\",\n      \" [4.0, 5.0, 6.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Random tensor: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.32371014, 0.41100568, 0.94457513, 0.8408601, 0.42262083],\\n\",\n      \"  shape:  [5],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Normal distribution: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[-0.22402725, 1.8367178, -1.1049407, -0.6302627, 1.1106112],\\n\",\n      \"  shape:  [5],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Uniform [0, 10): Tensor {\\n\",\n      \"  data:\\n\",\n      \"[8.110331, 7.335061, 9.858947, 6.0834813, 3.6619747],\\n\",\n      \"  shape:  [5],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Create a tensor from a slice of values\\n\",\n    \"let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];\\n\",\n    \"let from_slice = Tensor::<B, 1>::from_floats(data, &device).reshape([2, 3]);\\n\",\n    \"println!(\\\"From slice:\\\\n{}\\\", from_slice);\\n\",\n    \"\\n\",\n    \"// Create a random tensor\\n\",\n    \"use burn::tensor::Distribution;\\n\",\n    \"let random: Tensor<B, 1> = Tensor::random([5], Distribution::Default, &device);\\n\",\n    \"println!(\\\"Random tensor: {}\\\", random);\\n\",\n    \"\\n\",\n    \"// Create a tensor with normal distribution\\n\",\n    \"let normal: Tensor<B, 1> = Tensor::random([5], Distribution::Normal(0.0, 1.0), &device);\\n\",\n    \"println!(\\\"Normal distribution: {}\\\", normal);\\n\",\n    \"\\n\",\n    \"// Create a tensor with uniform distribution in range [0, 10)\\n\",\n    \"let uniform: Tensor<B, 1> = Tensor::random([5], Distribution::Uniform(0.0, 10.0), &device);\\n\",\n    \"println!(\\\"Uniform [0, 10): {}\\\", uniform);\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. Shape Operations\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Original (2x3):\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0, 3.0],\\n\",\n      \" [4.0, 5.0, 6.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Reshaped (1x2x3): Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[[1.0, 2.0, 3.0],\\n\",\n      \"  [4.0, 5.0, 6.0]]],\\n\",\n      \"  shape:  [1, 2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Flattened: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],\\n\",\n      \"  shape:  [6],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Reshape tensor - change the dimensions without changing the data\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\\n\",\n    \"println!(\\\"Original (2x3):\\\\n{}\\\", tensor);\\n\",\n    \"\\n\",\n    \"let reshaped: Tensor<B, 3> = tensor.clone().reshape([1, 2, 3]);\\n\",\n    \"println!(\\\"Reshaped (1x2x3): {}\\\", reshaped);\\n\",\n    \"\\n\",\n    \"// Flatten - reshape to 1D\\n\",\n    \"let flat: Tensor<B, 1> = tensor.flatten(0, 1);\\n\",\n    \"println!(\\\"Flattened: {}\\\", flat);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Original:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0],\\n\",\n      \" [3.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Transposed:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 3.0],\\n\",\n      \" [2.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Using .t():\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 3.0],\\n\",\n      \" [2.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Transpose - swap dimensions\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\\n\",\n    \"println!(\\\"Original:\\\\n{}\\\", tensor);\\n\",\n    \"\\n\",\n    \"let transposed = tensor.clone().transpose();\\n\",\n    \"println!(\\\"Transposed:\\\\n{}\\\", transposed);\\n\",\n    \"\\n\",\n    \"// Also .t() works for 2D tensors\\n\",\n    \"let t = tensor.t();\\n\",\n    \"println!(\\\"Using .t():\\\\n{}\\\", t);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 20,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Before squeeze [1,1,2]: shape = Shape { dims: [1, 1, 2] }\\n\",\n      \"After squeeze: shape = Shape { dims: [2] }\\n\",\n      \"Before unsqueeze [2,2]: shape = Shape { dims: [2, 2] }\\n\",\n      \"After unsqueeze: shape = Shape { dims: [1, 2, 2] }\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Squeeze - remove dimensions of size 1\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0], &device).reshape([1, 1, 2]);\\n\",\n    \"println!(\\\"Before squeeze [1,1,2]: shape = {:?}\\\", tensor.shape());\\n\",\n    \"\\n\",\n    \"let squeezed = tensor.squeeze::<1>();\\n\",\n    \"println!(\\\"After squeeze: shape = {:?}\\\", squeezed.shape());\\n\",\n    \"\\n\",\n    \"// Unsqueeze - add a dimension of size 1 at specified position\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\\n\",\n    \"println!(\\\"Before unsqueeze [2,2]: shape = {:?}\\\", tensor.shape());\\n\",\n    \"\\n\",\n    \"let unsqueezed = tensor.unsqueeze::<3>();\\n\",\n    \"println!(\\\"After unsqueeze: shape = {:?}\\\", unsqueezed.shape());\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 3. Indexing and Slicing\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 22,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Original tensor:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0, 3.0, 4.0],\\n\",\n      \" [5.0, 6.0, 7.0, 8.0],\\n\",\n      \" [9.0, 10.0, 11.0, 12.0]],\\n\",\n      \"  shape:  [3, 4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Create a tensor for indexing examples\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats(\\n\",\n    \"    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],\\n\",\n    \"&device\\n\",\n    \").reshape([3, 4]);\\n\",\n    \"println!(\\\"Original tensor:\\\\n{}\\\", tensor);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Sliced [1..3, 1..4]:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[6.0, 7.0, 8.0],\\n\",\n      \" [10.0, 11.0, 12.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Row 1: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[5.0, 6.0, 7.0, 8.0]],\\n\",\n      \"  shape:  [1, 4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Column 2: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[3.0],\\n\",\n      \" [7.0],\\n\",\n      \" [11.0]],\\n\",\n      \"  shape:  [3, 1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Slice tensor - select a portion using ranges\\n\",\n    \"// Get rows 1-2 (index 1 to end), columns 1-3 (index 1 to 3)\\n\",\n    \"let sliced = tensor.clone().slice([1..3, 1..4]);\\n\",\n    \"println!(\\\"Sliced [1..3, 1..4]:\\\\n{}\\\", sliced);\\n\",\n    \"\\n\",\n    \"// Get single row\\n\",\n    \"let row = tensor.clone().slice([1..2, 0..4]);\\n\",\n    \"println!(\\\"Row 1: {}\\\", row);\\n\",\n    \"\\n\",\n    \"// Get single column\\n\",\n    \"let col = tensor.slice([0..3, 2..3]);\\n\",\n    \"println!(\\\"Column 2: {}\\\", col);\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 4. Basic Math Operations\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"a = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0],\\n\",\n      \" [3.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[5.0, 6.0],\\n\",\n      \" [7.0, 8.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a + b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[6.0, 8.0],\\n\",\n      \" [10.0, 12.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a - b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[-4.0, -4.0],\\n\",\n      \" [-4.0, -4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a * b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[5.0, 12.0],\\n\",\n      \" [21.0, 32.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a / b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[0.2, 0.33333334],\\n\",\n      \" [0.42857143, 0.5]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\\n\",\n    \"let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);\\n\",\n    \"\\n\",\n    \"println!(\\\"a = {}\\\", a);\\n\",\n    \"println!(\\\"b = {}\\\", b);\\n\",\n    \"\\n\",\n    \"// Addition\\n\",\n    \"let c = a.clone() + b.clone();\\n\",\n    \"println!(\\\"a + b = {}\\\", c);\\n\",\n    \"\\n\",\n    \"// Subtraction\\n\",\n    \"let c = a.clone() - b.clone();\\n\",\n    \"println!(\\\"a - b = {}\\\", c);\\n\",\n    \"\\n\",\n    \"// Multiplication (element-wise)\\n\",\n    \"let c = a.clone() * b.clone();\\n\",\n    \"println!(\\\"a * b = {}\\\", c);\\n\",\n    \"\\n\",\n    \"// Division (element-wise)\\n\",\n    \"let c = a.clone() / b.clone();\\n\",\n    \"println!(\\\"a / b = {}\\\", c);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 28,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"a = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0],\\n\",\n      \" [3.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a + 10 = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[11.0, 12.0],\\n\",\n      \" [13.0, 14.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a * 2 = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[2.0, 4.0],\\n\",\n      \" [6.0, 8.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Scalar operations\\n\",\n    \"let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\\n\",\n    \"\\n\",\n    \"println!(\\\"a = {}\\\", a);\\n\",\n    \"\\n\",\n    \"// Add scalar\\n\",\n    \"let c = a.clone() + 10.0;\\n\",\n    \"println!(\\\"a + 10 = {}\\\", c);\\n\",\n    \"\\n\",\n    \"// Multiply scalar\\n\",\n    \"let c = a.clone() * 2.0;\\n\",\n    \"println!(\\\"a * 2 = {}\\\", c);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 30,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"a = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0],\\n\",\n      \" [3.0, 4.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[5.0, 6.0],\\n\",\n      \" [7.0, 8.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a @ b (matmul) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[19.0, 22.0],\\n\",\n      \" [43.0, 50.0]],\\n\",\n      \"  shape:  [2, 2],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Matrix multiplication\\n\",\n    \"let a = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &device).reshape([2, 2]);\\n\",\n    \"let b = Tensor::<B, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &device).reshape([2, 2]);\\n\",\n    \"\\n\",\n    \"println!(\\\"a = {}\\\", a);\\n\",\n    \"println!(\\\"b = {}\\\", b);\\n\",\n    \"\\n\",\n    \"let result = a.matmul(b);\\n\",\n    \"println!(\\\"a @ b (matmul) = {}\\\", result);\\n\",\n    \"\\n\",\n    \"// Verify (rows of a · columns of b): row1 [1,2] · col1 [5,7] = 1*5+2*7 = 19, row1 [1,2] · col2 [6,8] = 1*6+2*8 = 22\\n\",\n    \"//                                      row2 [3,4] · col1 [5,7] = 3*5+4*7 = 43, row2 [3,4] · col2 [6,8] = 3*6+4*8 = 50\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 5. Element-wise Math Functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"a = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0, 2.0],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"exp(a) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 2.7182817, 7.389056],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"log(a + 1) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 0.6931472, 1.0986123],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a.powf(2) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0, 4.0],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a.powf(0.5) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0, 1.4142135],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let a: Tensor<B, 1> = Tensor::from_floats([0.0, 1.0, 2.0], &device);\\n\",\n    \"\\n\",\n    \"println!(\\\"a = {}\\\", a);\\n\",\n    \"\\n\",\n    \"// Exponential\\n\",\n    \"println!(\\\"exp(a) = {}\\\", a.clone().exp());\\n\",\n    \"\\n\",\n    \"// Natural logarithm\\n\",\n    \"println!(\\\"log(a + 1) = {}\\\", (a.clone() + 1.0).log());\\n\",\n    \"\\n\",\n    \"// Power\\n\",\n    \"println!(\\\"a.powf(2) = {}\\\", a.clone().powf_scalar(2.0));\\n\",\n    \"println!(\\\"a.powf(0.5) = {}\\\", a.clone().powf_scalar(0.5));\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 33,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"angles = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 0.7853982, 1.5707964],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"sin(angles) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 0.70710677, 1.0],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"cos(angles) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 0.70710677, -4.371139e-8],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"tan(angles) = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[0.0, 1.0, -22877332.0],\\n\",\n      \"  shape:  [3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Trigonometric functions\\n\",\n    \"let angles: Tensor<B, 1> = Tensor::from_floats([0.0, std::f32::consts::PI / 4.0, std::f32::consts::PI / 2.0], &device);\\n\",\n    \"\\n\",\n    \"println!(\\\"angles = {}\\\", angles);\\n\",\n    \"println!(\\\"sin(angles) = {}\\\", angles.clone().sin());\\n\",\n    \"println!(\\\"cos(angles) = {}\\\", angles.clone().cos());\\n\",\n    \"println!(\\\"tan(angles) = {}\\\", angles.clone().tan());\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 6. Reduction Operations\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 35,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tensor:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0, 3.0],\\n\",\n      \" [4.0, 5.0, 6.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Sum: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[21.0],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Mean: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[3.5],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Product: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[720.0],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Max: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[6.0],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Min: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0],\\n\",\n      \"  shape:  [1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\\n\",\n    \"println!(\\\"Tensor:\\\\n{}\\\", tensor);\\n\",\n    \"\\n\",\n    \"// Sum all elements\\n\",\n    \"println!(\\\"Sum: {}\\\", tensor.clone().sum());\\n\",\n    \"\\n\",\n    \"// Mean of all elements\\n\",\n    \"println!(\\\"Mean: {}\\\", tensor.clone().mean());\\n\",\n    \"\\n\",\n    \"// Product of all elements\\n\",\n    \"println!(\\\"Product: {}\\\", tensor.clone().prod());\\n\",\n    \"\\n\",\n    \"// Maximum and minimum\\n\",\n    \"println!(\\\"Max: {}\\\", tensor.clone().max());\\n\",\n    \"println!(\\\"Min: {}\\\", tensor.clone().min());\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 37,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Tensor:\\n\",\n      \"Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[1.0, 2.0, 3.0],\\n\",\n      \" [4.0, 5.0, 6.0]],\\n\",\n      \"  shape:  [2, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Sum dim 0: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[5.0, 7.0, 9.0]],\\n\",\n      \"  shape:  [1, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Sum dim 1: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[6.0],\\n\",\n      \" [15.0]],\\n\",\n      \"  shape:  [2, 1],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Mean dim 0: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[[2.5, 3.5, 4.5]],\\n\",\n      \"  shape:  [1, 3],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Reduce along specific dimensions\\n\",\n    \"let tensor = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &device).reshape([2, 3]);\\n\",\n    \"println!(\\\"Tensor:\\\\n{}\\\", tensor);\\n\",\n    \"\\n\",\n    \"// Sum along dimension 0 (columns)\\n\",\n    \"println!(\\\"Sum dim 0: {}\\\", tensor.clone().sum_dim(0));\\n\",\n    \"\\n\",\n    \"// Sum along dimension 1 (rows)\\n\",\n    \"println!(\\\"Sum dim 1: {}\\\", tensor.clone().sum_dim(1));\\n\",\n    \"\\n\",\n    \"// Mean along dimension 0\\n\",\n    \"println!(\\\"Mean dim 0: {}\\\", tensor.clone().mean_dim(0));\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 7. Comparison and Selection\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 42,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"a = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 5.0, 3.0, 8.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"b = Tensor {\\n\",\n      \"  data:\\n\",\n      \"[4.0, 2.0, 6.0, 7.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"a > b: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[false, true, false, true],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Bool\\\",\\n\",\n      \"  dtype:  \\\"bool\\\",\\n\",\n      \"}\\n\",\n      \"a < b: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[true, false, true, false],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Bool\\\",\\n\",\n      \"  dtype:  \\\"bool\\\",\\n\",\n      \"}\\n\",\n      \"a == b: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[false, false, false, false],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Bool\\\",\\n\",\n      \"  dtype:  \\\"bool\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);\\n\",\n    \"let b: Tensor<B, 1> = Tensor::from_floats([4.0, 2.0, 6.0, 7.0], &device);\\n\",\n    \"\\n\",\n    \"println!(\\\"a = {}\\\", a);\\n\",\n    \"println!(\\\"b = {}\\\", b);\\n\",\n    \"\\n\",\n    \"// Element-wise comparison returns a boolean tensor\\n\",\n    \"let greater = a.clone().greater(b.clone());\\n\",\n    \"println!(\\\"a > b: {}\\\", greater);\\n\",\n    \"\\n\",\n    \"let less = a.clone().lower(b.clone());\\n\",\n    \"println!(\\\"a < b: {}\\\", less);\\n\",\n    \"\\n\",\n    \"let equal = a.clone().equal(b.clone());\\n\",\n    \"println!(\\\"a == b: {}\\\", equal);\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 41,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Original: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 5.0, 3.0, 8.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Where > 4, replace with 0: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, 0.0, 3.0, 0.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\",\n      \"Where > 4, replace with -1: Tensor {\\n\",\n      \"  data:\\n\",\n      \"[1.0, -1.0, 3.0, -1.0],\\n\",\n      \"  shape:  [4],\\n\",\n      \"  device:  Cpu,\\n\",\n      \"  backend:  \\\"ndarray\\\",\\n\",\n      \"  kind:  \\\"Float\\\",\\n\",\n      \"  dtype:  \\\"f32\\\",\\n\",\n      \"}\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"// Conditional selection\\n\",\n    \"let a: Tensor<B, 1> = Tensor::from_floats([1.0, 5.0, 3.0, 8.0], &device);\\n\",\n    \"\\n\",\n    \"// mask_where: where condition is true, use replacement value, else keep original value\\n\",\n    \"let condition = a.clone().greater_elem(4.0);\\n\",\n    \"let result = a.clone().mask_where(condition, Tensor::zeros([4], &device));\\n\",\n    \"println!(\\\"Original: {}\\\", a);\\n\",\n    \"println!(\\\"Where > 4, replace with 0: {}\\\", result);\\n\",\n    \"\\n\",\n    \"// mask_fill: simpler - just replace values matching condition\\n\",\n    \"let result = a.clone().mask_fill(a.clone().greater_elem(4.0), -1.0);\\n\",\n    \"println!(\\\"Where > 4, replace with -1: {}\\\", result);\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Summary\\n\",\n    \"\\n\",\n    \"In this notebook, we covered:\\n\",\n    \"- **Tensor Creation**: empty, zeros, ones, full, from_floats, random\\n\",\n    \"- **Shape Operations**: reshape, transpose, flatten, squeeze, unsqueeze\\n\",\n    \"- **Indexing and Slicing**: slice operation with ranges\\n\",\n    \"- **Math Operations**: add, sub, mul, div, matmul\\n\",\n    \"- **Element-wise Functions**: exp, log, powf_scalar, sin, cos, tan\\n\",\n    \"- **Reduction Operations**: sum, mean, prod, max, min\\n\",\n    \"- **Comparison**: greater, lower, equal, mask_where, mask_fill\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Rust\",\n   \"language\": \"rust\",\n   \"name\": \"rust\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": \"rust\",\n   \"file_extension\": \".rs\",\n   \"mimetype\": \"text/rust\",\n   \"name\": \"rust\",\n   \"pygment_lexer\": \"rust\",\n   \"version\": \"\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/server/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"server\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"webgpu\"]\ncuda = [\"burn/cuda\"]\nwebgpu = [\"burn/webgpu\"]\nvulkan = [\"burn/vulkan\"]\nndarray = [\"burn/ndarray\"]\n\n[dependencies]\ncfg-if = { workspace = true }\nburn = { path = \"../../crates/burn\", version = \"=0.21.0-pre.2\", features = [\"server\"] }\ncubecl = { workspace = true }\n"
  },
  {
    "path": "examples/server/cubecl.toml",
    "content": "[profiling]\nlogger = { log = \"info\", level = \"disabled\" }\n\n[autotune]\nlogger = { log = \"info\", level = \"disabled\" }\n# logger = { log = \"info\", level = \"full\" }\n\n[compilation]\nlogger = { log = \"info\", level = \"disabled\" }\n# logger = { log = \"info\", level = \"full\" }\ncache = \"target\"\n"
  },
  {
    "path": "examples/server/examples/server.rs",
    "content": "fn main() {\n    server::start();\n}\n"
  },
  {
    "path": "examples/server/src/lib.rs",
    "content": "#![recursion_limit = \"141\"]\n\npub fn start() {\n    let port = std::env::var(\"REMOTE_BACKEND_PORT\")\n        .map(|port| match port.parse::<u16>() {\n            Ok(val) => val,\n            Err(err) => panic!(\"Invalid port, got {port} with error {err}\"),\n        })\n        .unwrap_or(3000);\n\n    cfg_if::cfg_if! {\n        if #[cfg(feature = \"ndarray\")]{\n            burn::server::start_websocket::<burn::backend::NdArray>(Default::default(), port);\n        } else if #[cfg(feature = \"cuda\")]{\n            burn::server::start_websocket::<burn::backend::Cuda>(Default::default(), port);\n        } else if #[cfg(feature = \"webgpu\")] {\n            burn::server::start_websocket::<burn::backend::WebGpu>(Default::default(), port);\n        } else if #[cfg(feature = \"vulkan\")] {\n            burn::server::start_websocket::<burn::backend::Vulkan>(Default::default(), port);\n        } else {\n            panic!(\"No backend selected, can't start server on port {port}\");\n        }\n    }\n}\n"
  },
  {
    "path": "examples/simple-regression/Cargo.toml",
    "content": "[package]\nauthors = [\"aasheeshsingh <aasheeshdtu@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"simple-regression\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/dataset\", \"burn/sqlite-bundled\"]\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"burn/ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"burn/ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"burn/ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\"]\nremote = [\"burn/remote\"]\n\n[dependencies]\nburn = {path = \"../../crates/burn\", features=[\"train\"]}\n\n# Serialization\nlog = {workspace = true}\nserde = {workspace = true, features = [\"std\", \"derive\"]}\n\n# Displaying results\ntextplots = \"0.8.7\"\nrgb = \"0.8.52\""
  },
  {
    "path": "examples/simple-regression/README.md",
    "content": "# Regression\n\nThe example shows you how to:\n\n- Define a custom dataset for regression problems. We implement the\n  [California Housing Dataset](https://huggingface.co/datasets/gvlassis/california_housing) from\n  HuggingFace hub. The dataset is also available as part of toy regression datasets in\n  sklearn[datasets](https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset).\n- Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature\n  scaling.\n- Define a Simple NN model for regression using Burn Modules.\n\n> **Note**  \n> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)\n> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)\n> installed on your computer.\n\nThe example can be run like so:\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n# Use the --release flag to really speed up training.\necho \"Using ndarray backend\"\ncargo run --example regression --release --features ndarray                # CPU NdArray Backend - f32 - single thread\ncargo run --example regression --release --features ndarray-blas-openblas  # CPU NdArray Backend - f32 - blas with openblas\ncargo run --example regression --release --features ndarray-blas-netlib    # CPU NdArray Backend - f32 - blas with netlib\necho \"Using tch backend\"\nexport TORCH_CUDA_VERSION=cu128                                            # Set the cuda version\ncargo run --example regression --release --features tch-gpu                # GPU Tch Backend - f32\ncargo run --example regression --release --features tch-cpu                # CPU Tch Backend - f32\necho \"Using wgpu backend\"\ncargo run --example regression --release --features wgpu\n```\n"
  },
  {
    "path": "examples/simple-regression/examples/regression.rs",
    "content": "use burn::{backend::Autodiff, tensor::backend::Backend};\nuse simple_regression::{inference, training};\n\nstatic ARTIFACT_DIR: &str = \"/tmp/burn-example-regression\";\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::ndarray::{NdArray, NdArrayDevice};\n\n    pub fn run() {\n        let device = NdArrayDevice::Cpu;\n        super::run::<NdArray>(device.clone());\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        super::run::<LibTorch>(device);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use burn::backend::wgpu::{Wgpu, WgpuDevice};\n\n    pub fn run() {\n        let device = WgpuDevice::default();\n        super::run::<Wgpu>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n    use simple_regression::training;\n    pub fn run() {\n        let device = LibTorchDevice::Cpu;\n        super::run::<LibTorch>(device);\n    }\n}\n\n#[cfg(feature = \"remote\")]\nmod remote {\n    use burn::backend::{RemoteBackend, remote::RemoteDevice};\n\n    pub fn run() {\n        let device = RemoteDevice::default();\n        super::run::<RemoteBackend>(device);\n    }\n}\n\n/// Train a regression model and predict results on a number of samples.\npub fn run<B: Backend>(device: B::Device) {\n    training::run::<Autodiff<B>>(ARTIFACT_DIR, device.clone());\n    inference::infer::<B>(ARTIFACT_DIR, device)\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"remote\")]\n    remote::run();\n}\n"
  },
  {
    "path": "examples/simple-regression/src/dataset.rs",
    "content": "use burn::{\n    data::{\n        dataloader::batcher::Batcher,\n        dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset},\n    },\n    prelude::*,\n};\n\npub const NUM_FEATURES: usize = 8;\n\n// Pre-computed statistics for the housing dataset features\nconst FEATURES_MIN: [f32; NUM_FEATURES] = [0.4999, 1., 0.8461, 0.375, 3., 0.6923, 32.54, -124.35];\nconst FEATURES_MAX: [f32; NUM_FEATURES] = [\n    15., 52., 141.9091, 34.0667, 35682., 1243.3333, 41.95, -114.31,\n];\n\n#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]\npub struct HousingDistrictItem {\n    /// Median income\n    #[serde(rename = \"MedInc\")]\n    pub median_income: f32,\n\n    /// Median house age\n    #[serde(rename = \"HouseAge\")]\n    pub house_age: f32,\n\n    /// Average number of rooms per household\n    #[serde(rename = \"AveRooms\")]\n    pub avg_rooms: f32,\n\n    /// Average number of bedrooms per household\n    #[serde(rename = \"AveBedrms\")]\n    pub avg_bedrooms: f32,\n\n    /// Block group population\n    #[serde(rename = \"Population\")]\n    pub population: f32,\n\n    /// Average number of household members\n    #[serde(rename = \"AveOccup\")]\n    pub avg_occupancy: f32,\n\n    /// Block group latitude\n    #[serde(rename = \"Latitude\")]\n    pub latitude: f32,\n\n    /// Block group longitude\n    #[serde(rename = \"Longitude\")]\n    pub longitude: f32,\n\n    /// Median house value (in 100 000$)\n    #[serde(rename = \"MedHouseVal\")]\n    pub median_house_value: f32,\n}\n\npub struct HousingDataset {\n    dataset: SqliteDataset<HousingDistrictItem>,\n}\n\nimpl Dataset<HousingDistrictItem> for HousingDataset {\n    fn get(&self, index: usize) -> Option<HousingDistrictItem> {\n        self.dataset.get(index)\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\nimpl HousingDataset {\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    pub fn validation() -> Self {\n        Self::new(\"validation\")\n    }\n\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    pub fn new(split: &str) -> Self {\n        let dataset: SqliteDataset<HousingDistrictItem> =\n            HuggingfaceDatasetLoader::new(\"gvlassis/california_housing\")\n                .dataset(split)\n                .unwrap();\n\n        Self { dataset }\n    }\n}\n\n/// Normalizer for the housing dataset.\n#[derive(Clone, Debug)]\npub struct Normalizer<B: Backend> {\n    pub min: Tensor<B, 2>,\n    pub max: Tensor<B, 2>,\n}\n\nimpl<B: Backend> Normalizer<B> {\n    /// Creates a new normalizer.\n    pub fn new(device: &B::Device, min: &[f32], max: &[f32]) -> Self {\n        let min = Tensor::<B, 1>::from_floats(min, device).unsqueeze();\n        let max = Tensor::<B, 1>::from_floats(max, device).unsqueeze();\n        Self { min, max }\n    }\n\n    /// Normalizes the input image according to the housing dataset min/max.\n    pub fn normalize(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        (input - self.min.clone()) / (self.max.clone() - self.min.clone())\n    }\n\n    /// Returns a new normalizer on the given device.\n    pub fn to_device(&self, device: &B::Device) -> Self {\n        Self {\n            min: self.min.clone().to_device(device),\n            max: self.max.clone().to_device(device),\n        }\n    }\n}\n\n#[derive(Clone, Debug)]\npub struct HousingBatcher<B: Backend> {\n    normalizer: Normalizer<B>,\n}\n\n#[derive(Clone, Debug)]\npub struct HousingBatch<B: Backend> {\n    pub inputs: Tensor<B, 2>,\n    pub targets: Tensor<B, 1>,\n}\n\nimpl<B: Backend> HousingBatcher<B> {\n    pub fn new(device: B::Device) -> Self {\n        Self {\n            normalizer: Normalizer::new(&device, &FEATURES_MIN, &FEATURES_MAX),\n        }\n    }\n}\n\nimpl<B: Backend> Batcher<B, HousingDistrictItem, HousingBatch<B>> for HousingBatcher<B> {\n    fn batch(&self, items: Vec<HousingDistrictItem>, device: &B::Device) -> HousingBatch<B> {\n        let mut inputs: Vec<Tensor<B, 2>> = Vec::new();\n\n        for item in items.iter() {\n            let input_tensor = Tensor::<B, 1>::from_floats(\n                [\n                    item.median_income,\n                    item.house_age,\n                    item.avg_rooms,\n                    item.avg_bedrooms,\n                    item.population,\n                    item.avg_occupancy,\n                    item.latitude,\n                    item.longitude,\n                ],\n                device,\n            );\n\n            inputs.push(input_tensor.unsqueeze());\n        }\n\n        let inputs = Tensor::cat(inputs, 0);\n        let inputs = self.normalizer.to_device(device).normalize(inputs);\n\n        let targets = items\n            .iter()\n            .map(|item| Tensor::<B, 1>::from_floats([item.median_house_value], device))\n            .collect();\n\n        let targets = Tensor::cat(targets, 0);\n\n        HousingBatch { inputs, targets }\n    }\n}\n"
  },
  {
    "path": "examples/simple-regression/src/inference.rs",
    "content": "use burn::{\n    data::{dataloader::batcher::Batcher, dataset::Dataset},\n    module::Module,\n    record::{NoStdTrainingRecorder, Recorder},\n    tensor::backend::Backend,\n};\nuse rgb::RGB8;\nuse textplots::{Chart, ColorPlot, Shape};\n\nuse crate::{\n    dataset::{HousingBatcher, HousingDataset, HousingDistrictItem},\n    model::{RegressionModelConfig, RegressionModelRecord},\n};\n\npub fn infer<B: Backend>(artifact_dir: &str, device: B::Device) {\n    let record: RegressionModelRecord<B> = NoStdTrainingRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model should exist; run train first\");\n\n    let model = RegressionModelConfig::new()\n        .init(&device)\n        .load_record(record);\n\n    // Use a sample of 1000 items from the test split\n    let dataset = HousingDataset::test();\n    let items: Vec<HousingDistrictItem> = dataset.iter().take(1000).collect();\n\n    let batcher = HousingBatcher::new(device.clone());\n    let batch = batcher.batch(items.clone(), &device);\n    let predicted = model.forward(batch.inputs);\n    let targets = batch.targets;\n\n    // Display the predicted vs expected values\n    let predicted = predicted.squeeze_dim::<1>(1).into_data();\n    let expected = targets.into_data();\n\n    let points = predicted\n        .iter::<f32>()\n        .zip(expected.iter::<f32>())\n        .collect::<Vec<_>>();\n\n    println!(\"Predicted vs. Expected Median House Value (in 100,000$)\");\n    Chart::new_with_y_range(120, 60, 0., 5., 0., 5.)\n        .linecolorplot(\n            &Shape::Points(&points),\n            RGB8 {\n                r: 255,\n                g: 85,\n                b: 85,\n            },\n        )\n        .display();\n\n    // Print a single numeric value as an example\n    println!(\"Predicted {} Expected {}\", points[0].0, points[0].1);\n}\n"
  },
  {
    "path": "examples/simple-regression/src/lib.rs",
    "content": "pub mod dataset;\npub mod inference;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/simple-regression/src/model.rs",
    "content": "use crate::dataset::{HousingBatch, NUM_FEATURES};\nuse burn::{\n    nn::{\n        Linear, LinearConfig, Relu,\n        loss::{MseLoss, Reduction::Mean},\n    },\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n    train::{InferenceStep, RegressionOutput, TrainOutput, TrainStep},\n};\n\n#[derive(Module, Debug)]\npub struct RegressionModel<B: Backend> {\n    input_layer: Linear<B>,\n    output_layer: Linear<B>,\n    activation: Relu,\n}\n\n#[derive(Config, Debug)]\npub struct RegressionModelConfig {\n    #[config(default = 64)]\n    pub hidden_size: usize,\n}\n\nimpl RegressionModelConfig {\n    pub fn init<B: Backend>(&self, device: &B::Device) -> RegressionModel<B> {\n        let input_layer = LinearConfig::new(NUM_FEATURES, self.hidden_size)\n            .with_bias(true)\n            .init(device);\n        let output_layer = LinearConfig::new(self.hidden_size, 1)\n            .with_bias(true)\n            .init(device);\n\n        RegressionModel {\n            input_layer,\n            output_layer,\n            activation: Relu::new(),\n        }\n    }\n}\n\nimpl<B: Backend> RegressionModel<B> {\n    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        let x = self.input_layer.forward(input);\n        let x = self.activation.forward(x);\n        self.output_layer.forward(x)\n    }\n\n    pub fn forward_step(&self, item: HousingBatch<B>) -> RegressionOutput<B> {\n        let targets: Tensor<B, 2> = item.targets.unsqueeze_dim(1);\n        let output: Tensor<B, 2> = self.forward(item.inputs);\n\n        let loss = MseLoss::new().forward(output.clone(), targets.clone(), Mean);\n\n        RegressionOutput {\n            loss,\n            output,\n            targets,\n        }\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for RegressionModel<B> {\n    type Input = HousingBatch<B>;\n    type Output = RegressionOutput<B>;\n\n    fn step(&self, item: HousingBatch<B>) -> TrainOutput<RegressionOutput<B>> {\n        let item = self.forward_step(item);\n\n        TrainOutput::new(self, item.loss.backward(), item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for RegressionModel<B> {\n    type Input = HousingBatch<B>;\n    type Output = RegressionOutput<B>;\n\n    fn step(&self, item: HousingBatch<B>) -> RegressionOutput<B> {\n        self.forward_step(item)\n    }\n}\n"
  },
  {
    "path": "examples/simple-regression/src/training.rs",
    "content": "use crate::dataset::{HousingBatcher, HousingDataset};\nuse crate::model::RegressionModelConfig;\nuse burn::optim::AdamConfig;\nuse burn::train::{Learner, SupervisedTraining};\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::Dataset},\n    prelude::*,\n    record::{CompactRecorder, NoStdTrainingRecorder},\n    tensor::backend::AutodiffBackend,\n    train::metric::LossMetric,\n};\n\n#[derive(Config, Debug)]\npub struct ExpConfig {\n    #[config(default = 100)]\n    pub num_epochs: usize,\n\n    #[config(default = 2)]\n    pub num_workers: usize,\n\n    #[config(default = 1337)]\n    pub seed: u64,\n\n    pub optimizer: AdamConfig,\n\n    #[config(default = 256)]\n    pub batch_size: usize,\n}\n\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts before to get an accurate learner summary\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\npub fn run<B: AutodiffBackend>(artifact_dir: &str, device: B::Device) {\n    create_artifact_dir(artifact_dir);\n\n    // Config\n    let optimizer = AdamConfig::new();\n    let config = ExpConfig::new(optimizer);\n    let model = RegressionModelConfig::new().init(&device);\n    B::seed(&device, config.seed);\n\n    // Define train/valid datasets and dataloaders\n    let train_dataset = HousingDataset::train();\n    let valid_dataset = HousingDataset::validation();\n\n    println!(\"Train Dataset Size: {}\", train_dataset.len());\n    println!(\"Valid Dataset Size: {}\", valid_dataset.len());\n\n    let batcher_train = HousingBatcher::<B>::new(device.clone());\n\n    let batcher_test = HousingBatcher::<B::InnerBackend>::new(device.clone());\n\n    let dataloader_train = DataLoaderBuilder::new(batcher_train)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(train_dataset);\n\n    let dataloader_test = DataLoaderBuilder::new(batcher_test)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(valid_dataset);\n\n    // Model\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metric_train_numeric(LossMetric::new())\n        .metric_valid_numeric(LossMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let result = training.launch(Learner::new(model, config.optimizer.init(), 1e-3));\n\n    config\n        .save(format!(\"{artifact_dir}/config.json\").as_str())\n        .unwrap();\n\n    result\n        .model\n        .save_file(\n            format!(\"{artifact_dir}/model\"),\n            &NoStdTrainingRecorder::new(),\n        )\n        .expect(\"Failed to save trained model\");\n}\n"
  },
  {
    "path": "examples/text-classification/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"text-classification\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = []\nf16 = []\nflex32 = []\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"burn/ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"burn/ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"burn/ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\"]\nvulkan = [\"burn/vulkan\"]\nremote = [\"burn/remote\"]\ncuda = [\"burn/cuda\"]\nrocm = [\"burn/rocm\"]\nmetal = [\"burn/metal\"]\nddp = [\"burn/collective\"]\n\n[dependencies]\n# Burn\nburn = { path = \"../../crates/burn\", features = [\n    \"train\",\n    \"tui\",\n    \"sqlite-bundled\",\n    \"metrics\",\n    \"ndarray\",\n    \"autotune\",\n    # \"fusion\",\n    \"std\",\n], default-features = false }\nlog = { workspace = true }\n\n# Tokenizer\ntokenizers = { version = \"0.22.2\", default-features = false, features = [\n    \"onig\",\n    \"http\",\n] }\n\n# Utils\nderive-new = { workspace = true }\nserde = { workspace = true, features = [\"std\", \"derive\"] }\n"
  },
  {
    "path": "examples/text-classification/README.md",
    "content": "# Text Classification\n\nThis project provides an example implementation for training and inferencing text classification\nmodels on AG News and DbPedia datasets using the Rust-based Burn Deep Learning Library.\n\n> **Note**  \n> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)\n> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)\n> installed on your computer.\n\n## Dataset Details\n\n- AG News: The AG News dataset is a collection of news articles from more than 2000 news sources.\n  This library helps you load and process this dataset, categorizing articles into four classes:\n  \"World\", \"Sports\", \"Business\", and \"Technology\".\n\n- DbPedia: The DbPedia dataset is a large multi-class text classification dataset extracted from\n  Wikipedia. This library helps you load and process this dataset, categorizing articles into 14\n  classes including \"Company\", \"Educational Institution\", \"Artist\", among others.\n\n# Usage\n\n## Torch GPU backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n# Use the f16 feature if your CUDA device supports FP16 (half precision) operations. May not work well on every device.\n\nexport TORCH_CUDA_VERSION=cu128  # Set the cuda version (CUDA users)\n\n# AG News\ncargo run --example ag-news-train --release --features tch-gpu   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features tch-gpu   # Run inference on the ag news dataset\n\n# DbPedia\ncargo run --example db-pedia-train --release --features tch-gpu  # Train on the db pedia dataset\ncargo run --example db-pedia-infer --release --features tch-gpu  # Run inference db pedia dataset\n```\n\n## Torch CPU backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n\n# AG News\ncargo run --example ag-news-train --release --features tch-cpu   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features tch-cpu   # Run inference on the ag news dataset\n\n# DbPedia\ncargo run --example db-pedia-train --release --features tch-cpu  # Train on the db pedia dataset\ncargo run --example db-pedia-infer --release --features tch-cpu  # Run inference db pedia dataset\n```\n\n## ndarray backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n\n# Replace ndarray by ndarray-blas-netlib, ndarray-blas-openblas or ndarray-blas-accelerate for different matmul techniques\n\n# AG News\ncargo run --example ag-news-train --release --features ndarray   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features ndarray   # Run inference on the ag news dataset\n\n# DbPedia\ncargo run --example db-pedia-train --release --features ndarray  # Train on the db pedia dataset\ncargo run --example db-pedia-infer --release --features ndarray  # Run inference db pedia dataset\n```\n\n## WGPU backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n\n# AG News\ncargo run --example ag-news-train --release --features wgpu   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features wgpu   # Run inference on the ag news dataset\n\n# DbPedia\ncargo run --example db-pedia-train --release --features wgpu  # Train on the db pedia dataset\ncargo run --example db-pedia-infer --release --features wgpu  # Run inference db pedia dataset\n```\n\n## CUDA backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n# Add the f16 feature to run in f16. \n\n# AG News\ncargo run --example ag-news-train --release --features cuda   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features cuda   # Run inference on the ag news dataset\n```\n\n## Metal backend\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\n# Add the f16 feature to run in f16. \n\n# AG News\ncargo run --example ag-news-train --release --features metal   # Train on the ag news dataset\ncargo run --example ag-news-infer --release --features metal   # Run inference on the ag news dataset\n```\n"
  },
  {
    "path": "examples/text-classification/cubecl.toml",
    "content": "[profiling]\nlogger = { log = \"info\", level = \"disabled\" }\n\n[autotune]\nlevel = \"balanced\"\ncache = \"target\"\nlogger = { info = true, level = \"full\" }\n\n[compilation]\nlogger = { level = \"disabled\" }\ncache = \"target\"\n\n[memory]\nlogger = { level = \"disabled\", file = \"/tmp/memory.log\" }\npersistent_memory = \"enabled\"\n\n[streaming]\nmax_streams = 8\n"
  },
  {
    "path": "examples/text-classification/examples/ag-news-infer.rs",
    "content": "#![recursion_limit = \"256\"]\n\nuse burn::tensor::backend::Backend;\nuse text_classification::AgNewsDataset;\n\n#[cfg(not(feature = \"f16\"))]\n#[allow(dead_code)]\ntype ElemType = f32;\n#[cfg(feature = \"f16\")]\ntype ElemType = burn::tensor::f16;\n\npub fn launch<B: Backend>(device: B::Device) {\n    text_classification::inference::infer::<B, AgNewsDataset>(\n        device,\n        \"/tmp/text-classification-ag-news\",\n        // Samples from the test dataset, but you are free to test with your own text.\n        vec![\n            \"Jays power up to take finale Contrary to popular belief, the power never really \\\n             snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \\\n             took some extra time for the batting orders to provide some extra wattage.\"\n                .to_string(),\n            \"Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \\\n             man to death and 14 others to prison terms for a series of attacks and terrorist \\\n             plots in 2002, including the bombing of a French oil tanker.\"\n                .to_string(),\n            \"IBM puts grids to work at U.S. Open IBM will put a collection of its On \\\n             Demand-related products and technologies to this test next week at the U.S. Open \\\n             tennis championships, implementing a grid-based infrastructure capable of running \\\n             multiple workloads including two not associated with the tournament.\"\n                .to_string(),\n        ],\n    );\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::ndarray::{NdArray, NdArrayDevice};\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use crate::{ElemType, launch};\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<LibTorch<ElemType>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use crate::{ElemType, launch};\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    pub fn run() {\n        launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::{ElemType, launch};\n    use burn::backend::wgpu::{Wgpu, WgpuDevice};\n\n    pub fn run() {\n        launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());\n    }\n}\n\n#[cfg(feature = \"metal\")]\nmod metal {\n    use crate::{ElemType, launch};\n    use burn::backend::metal::{Metal, MetalDevice};\n\n    pub fn run() {\n        launch::<Metal<ElemType, i32>>(MetalDevice::default());\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::{ElemType, launch};\n    use burn::backend::{Cuda, cuda::CudaDevice};\n\n    pub fn run() {\n        launch::<Cuda<ElemType, i32>>(CudaDevice::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n}\n"
  },
  {
    "path": "examples/text-classification/examples/ag-news-train.rs",
    "content": "#![recursion_limit = \"256\"]\n\nuse burn::{\n    nn::transformer::TransformerEncoderConfig,\n    optim::{AdamConfig, decay::WeightDecayConfig},\n    prelude::*,\n    tensor::backend::{AutodiffBackend, DeviceId},\n};\n\nuse text_classification::{AgNewsDataset, training::ExperimentConfig};\n\n#[cfg(not(any(feature = \"f16\", feature = \"flex32\")))]\n#[allow(unused)]\ntype ElemType = f32;\n#[cfg(feature = \"f16\")]\ntype ElemType = burn::tensor::f16;\n#[cfg(feature = \"flex32\")]\ntype ElemType = burn::tensor::flex32;\n\npub fn launch_multi<B: AutodiffBackend>() {\n    let type_id = 0;\n    let num_devices = B::Device::device_count(type_id);\n\n    let devices = (0..num_devices)\n        .map(|i| B::Device::from_id(DeviceId::new(type_id, i as u32)))\n        .collect();\n\n    launch::<B>(devices)\n}\n\npub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {\n    let config = ExperimentConfig::new(\n        TransformerEncoderConfig::new(256, 1024, 8, 4)\n            .with_norm_first(true)\n            .with_quiet_softmax(true),\n        AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),\n    );\n\n    text_classification::training::train::<B, AgNewsDataset>(\n        devices,\n        AgNewsDataset::train(),\n        AgNewsDataset::test(),\n        config,\n        \"/tmp/text-classification-ag-news\",\n    );\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::{\n        Autodiff,\n        ndarray::{NdArray, NdArrayDevice},\n    };\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use crate::{ElemType, launch};\n    use burn::backend::autodiff::checkpoint::strategy::BalancedCheckpointing;\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::{ElemType, launch};\n    use burn::backend::{Autodiff, Wgpu};\n\n    pub fn run() {\n        launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![Default::default()]);\n    }\n}\n\n#[cfg(feature = \"vulkan\")]\nmod vulkan {\n    use crate::{ElemType, launch};\n    use burn::backend::{Autodiff, Vulkan, autodiff::checkpoint::strategy::BalancedCheckpointing};\n\n    pub fn run() {\n        type B = Autodiff<Vulkan<ElemType, i32>, BalancedCheckpointing>;\n        launch::<B>(vec![Default::default()]);\n    }\n}\n\n#[cfg(feature = \"metal\")]\nmod metal {\n    use crate::{ElemType, launch};\n    use burn::backend::{Autodiff, Metal};\n\n    pub fn run() {\n        launch::<Autodiff<Metal<ElemType, i32>>>(vec![Default::default()]);\n    }\n}\n\n#[cfg(feature = \"remote\")]\nmod remote {\n    use crate::{ElemType, launch};\n    use burn::backend::{Autodiff, RemoteBackend};\n\n    pub fn run() {\n        launch::<Autodiff<RemoteBackend>>(vec![Default::default()]);\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::{ElemType, launch_multi};\n    use burn::backend::{Autodiff, Cuda, autodiff::checkpoint::strategy::BalancedCheckpointing};\n\n    pub fn run() {\n        launch_multi::<Autodiff<Cuda<ElemType, i32>, BalancedCheckpointing>>();\n    }\n}\n\n#[cfg(feature = \"rocm\")]\nmod rocm {\n    use crate::{ElemType, launch};\n    use burn::backend::{Autodiff, Rocm, autodiff::checkpoint::strategy::BalancedCheckpointing};\n\n    pub fn run() {\n        launch::<Autodiff<Rocm<ElemType, i32>, BalancedCheckpointing>>(vec![Default::default()]);\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n    #[cfg(feature = \"rocm\")]\n    rocm::run();\n    #[cfg(feature = \"remote\")]\n    remote::run();\n    #[cfg(feature = \"vulkan\")]\n    vulkan::run();\n    #[cfg(feature = \"metal\")]\n    metal::run();\n}\n"
  },
  {
    "path": "examples/text-classification/examples/db-pedia-infer.rs",
    "content": "use text_classification::DbPediaDataset;\n\nuse burn::tensor::backend::Backend;\n\n#[cfg(not(feature = \"f16\"))]\n#[allow(dead_code)]\ntype ElemType = f32;\n#[cfg(feature = \"f16\")]\ntype ElemType = burn::tensor::f16;\n\npub fn launch<B: Backend>(device: B::Device) {\n    text_classification::inference::infer::<B, DbPediaDataset>(\n        device,\n        \"/tmp/text-classification-db-pedia\",\n        // Samples from the test dataset, but you are free to test with your own text.\n        vec![\n            \" Magnus Eriksson is a Swedish former footballer who played as a forward.\".to_string(),\n            \"Crossbeam Systems is headquartered in Boxborough Massachusetts and has offices in \\\n             Europe Latin America and Asia Pacific. Crossbeam Systems was acquired by Blue Coat \\\n             Systems in December 2012 and the Crossbeam brand has been fully absorbed into Blue \\\n             Coat.\"\n                .to_string(),\n            \" Zia is the sequel to the award-winning Island of the Blue Dolphins by Scott O'Dell. \\\n             It was published in 1976 sixteen years after the publication of the first novel.\"\n                .to_string(),\n        ],\n    );\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::ndarray::{NdArray, NdArrayDevice};\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<LibTorch<ElemType>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::tch::{LibTorch, LibTorchDevice};\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use burn::backend::wgpu::{Wgpu, WgpuDevice};\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n}\n"
  },
  {
    "path": "examples/text-classification/examples/db-pedia-train.rs",
    "content": "use burn::{\n    nn::transformer::TransformerEncoderConfig,\n    optim::{AdamConfig, decay::WeightDecayConfig},\n    tensor::backend::AutodiffBackend,\n};\n\nuse text_classification::{DbPediaDataset, training::ExperimentConfig};\n\n#[cfg(not(feature = \"f16\"))]\n#[allow(dead_code)]\ntype ElemType = f32;\n#[cfg(feature = \"f16\")]\ntype ElemType = burn::tensor::f16;\n\npub fn launch<B: AutodiffBackend>(devices: Vec<B::Device>) {\n    let config = ExperimentConfig::new(\n        TransformerEncoderConfig::new(256, 1024, 8, 4).with_norm_first(true),\n        AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))),\n    );\n\n    text_classification::training::train::<B, DbPediaDataset>(\n        devices,\n        DbPediaDataset::train(),\n        DbPediaDataset::test(),\n        config,\n        \"/tmp/text-classification-db-pedia\",\n    );\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use crate::{ElemType, launch};\n    use burn::backend::{\n        Autodiff,\n        ndarray::{NdArray, NdArrayDevice},\n    };\n\n    pub fn run() {\n        launch::<Autodiff<NdArray<ElemType>>>(vec![NdArrayDevice::Cpu]);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<Autodiff<LibTorch<ElemType>>>(vec![device]);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<Autodiff<LibTorch<ElemType>>>(vec![LibTorchDevice::Cpu]);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use burn::backend::{\n        Autodiff,\n        wgpu::{Wgpu, WgpuDevice},\n    };\n\n    use crate::{ElemType, launch};\n\n    pub fn run() {\n        launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n}\n"
  },
  {
    "path": "examples/text-classification/src/data/batcher.rs",
    "content": "// The module defines two structs TextClassificationTrainingBatch and TextClassificationInferenceBatch\n// to handle batches of data during training and inference respectively. The TextClassificationBatcher\n// struct is implemented for creating these batches. It is parameterized on the type B: Backend to\n// support different computation backends (e.g., CPU, CUDA).\n\n// Two implementations of the Batcher trait are provided for TextClassificationBatcher, one for creating\n// training batches and one for creating inference batches. In each implementation, the batch function is\n// defined to convert a vector of items into a batch. For training, the items are instances of\n// TextClassificationItem and include both the text and the corresponding label.\n// For inference, the items are simply strings without labels. The function tokenizes the text,\n// generates a padding mask, and returns a batch object.\n\nuse super::{dataset::TextClassificationItem, tokenizer::Tokenizer};\nuse burn::{\n    data::dataloader::batcher::Batcher,\n    nn::attention::{SeqLengthOption, generate_padding_mask},\n    prelude::*,\n};\nuse std::sync::Arc;\n\n/// Struct for batching text classification items\n#[derive(Clone, new)]\npub struct TextClassificationBatcher {\n    tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs\n    seq_length: SeqLengthOption,   // Sequence length option for tokenized text\n}\n\n/// Struct for training batch in text classification task\n#[derive(Debug, Clone, new)]\npub struct TextClassificationTrainingBatch<B: Backend> {\n    pub tokens: Tensor<B, 2, Int>,    // Tokenized text\n    pub labels: Tensor<B, 1, Int>,    // Labels of the text\n    pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text\n}\n\n/// Struct for inference batch in text classification task\n#[derive(Debug, Clone, new)]\npub struct TextClassificationInferenceBatch<B: Backend> {\n    pub tokens: Tensor<B, 2, Int>,    // Tokenized text\n    pub mask_pad: Tensor<B, 2, Bool>, // Padding mask for the tokenized text\n}\n\n/// Implement Batcher trait for TextClassificationBatcher struct for training\nimpl<B: Backend> Batcher<B, TextClassificationItem, TextClassificationTrainingBatch<B>>\n    for TextClassificationBatcher\n{\n    /// Batches a vector of text classification items into a training batch\n    fn batch(\n        &self,\n        items: Vec<TextClassificationItem>,\n        device: &B::Device,\n    ) -> TextClassificationTrainingBatch<B> {\n        let mut tokens_list = Vec::with_capacity(items.len());\n        let mut labels_list = Vec::with_capacity(items.len());\n\n        // Tokenize text and create label tensor for each item\n        for item in items {\n            tokens_list.push(self.tokenizer.encode(&item.text));\n            labels_list.push(Tensor::from_data(\n                TensorData::from([(item.label as i64).elem::<B::IntElem>()]),\n                device,\n            ));\n        }\n\n        // Generate padding mask for tokenized text\n        let mask = generate_padding_mask(\n            self.tokenizer.pad_token(),\n            tokens_list,\n            self.seq_length,\n            device,\n        );\n\n        // Create and return training batch\n        TextClassificationTrainingBatch {\n            tokens: mask.tensor,\n            labels: Tensor::cat(labels_list, 0),\n            mask_pad: mask.mask,\n        }\n    }\n}\n\n/// Implement Batcher trait for TextClassificationBatcher struct for inference\nimpl<B: Backend> Batcher<B, String, TextClassificationInferenceBatch<B>>\n    for TextClassificationBatcher\n{\n    /// Batches a vector of strings into an inference batch\n    fn batch(&self, items: Vec<String>, device: &B::Device) -> TextClassificationInferenceBatch<B> {\n        let mut tokens_list = Vec::with_capacity(items.len());\n\n        // Tokenize each string\n        for item in items {\n            tokens_list.push(self.tokenizer.encode(&item));\n        }\n\n        // Generate padding mask for tokenized text\n        let mask = generate_padding_mask(\n            self.tokenizer.pad_token(),\n            tokens_list,\n            self.seq_length,\n            device,\n        );\n\n        // Create and return inference batch\n        TextClassificationInferenceBatch {\n            tokens: mask.tensor.to_device(device),\n            mask_pad: mask.mask.to_device(device),\n        }\n    }\n}\n"
  },
  {
    "path": "examples/text-classification/src/data/dataset.rs",
    "content": "// The AgNewsDataset and DbPediaDataset structs are examples of specific text\n// classification datasets.  Each dataset struct has a field for the underlying\n// SQLite dataset and implements methods for accessing and processing the data.\n// Each dataset is also provided with specific information about its classes via\n// the TextClassificationDataset trait. These implementations are designed to be used\n// with a machine learning framework for tasks such as training a text classification model.\n\nuse burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader};\n\n// Define a struct for text classification items\n#[derive(new, Clone, Debug)]\npub struct TextClassificationItem {\n    pub text: String, // The text for classification\n    pub label: usize, // The label of the text (classification category)\n}\n\n// Trait for text classification datasets\npub trait TextClassificationDataset: Dataset<TextClassificationItem> {\n    fn num_classes() -> usize; // Returns the number of unique classes in the dataset\n    fn class_name(label: usize) -> String; // Returns the name of the class given its label\n}\n\n// Struct for items in the AG News dataset\n#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]\npub struct AgNewsItem {\n    pub text: String, // The text for classification\n    pub label: usize, // The label of the text (classification category)\n}\n\n// Struct for the AG News dataset\npub struct AgNewsDataset {\n    dataset: SqliteDataset<AgNewsItem>, // Underlying SQLite dataset\n}\n\n// Implement the Dataset trait for the AG News dataset\nimpl Dataset<TextClassificationItem> for AgNewsDataset {\n    /// Returns a specific item from the dataset\n    fn get(&self, index: usize) -> Option<TextClassificationItem> {\n        self.dataset\n            .get(index)\n            .map(|item| TextClassificationItem::new(item.text, item.label)) // Map AgNewsItems to TextClassificationItems\n    }\n\n    /// Returns the length of the dataset\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\n// Implement methods for constructing the AG News dataset\nimpl AgNewsDataset {\n    /// Returns the training portion of the dataset\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    /// Returns the testing portion of the dataset\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    /// Constructs the dataset from a split (either \"train\" or \"test\")\n    pub fn new(split: &str) -> Self {\n        let dataset: SqliteDataset<AgNewsItem> = HuggingfaceDatasetLoader::new(\"ag_news\")\n            .dataset(split)\n            .unwrap();\n        Self { dataset }\n    }\n}\n\n/// Implements the TextClassificationDataset trait for the AG News dataset\nimpl TextClassificationDataset for AgNewsDataset {\n    /// Returns the number of unique classes in the dataset\n    fn num_classes() -> usize {\n        4\n    }\n\n    /// Returns the name of a class given its label\n    fn class_name(label: usize) -> String {\n        match label {\n            0 => \"World\",\n            1 => \"Sports\",\n            2 => \"Business\",\n            3 => \"Technology\",\n            _ => panic!(\"invalid class\"),\n        }\n        .to_string()\n    }\n}\n\n/// Struct for items in the DbPedia dataset\n#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]\npub struct DbPediaItem {\n    pub title: String,   // The title of the item\n    pub content: String, // The content of the item\n    pub label: usize,    // The label of the item (classification category)\n}\n\n/// Struct for the DbPedia dataset\npub struct DbPediaDataset {\n    dataset: SqliteDataset<DbPediaItem>, // Underlying SQLite dataset\n}\n\n/// Implements the Dataset trait for the DbPedia dataset\nimpl Dataset<TextClassificationItem> for DbPediaDataset {\n    /// Returns a specific item from the dataset\n    fn get(&self, index: usize) -> Option<TextClassificationItem> {\n        self.dataset.get(index).map(|item| {\n            TextClassificationItem::new(\n                format!(\"Title: {} - Content: {}\", item.title, item.content),\n                item.label,\n            )\n        })\n    }\n\n    /// Returns the length of the dataset\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\n/// Implement methods for constructing the DbPedia dataset\nimpl DbPediaDataset {\n    /// Returns the training portion of the dataset\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    /// Returns the testing portion of the dataset\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n\n    /// Constructs the dataset from a split (either \"train\" or \"test\")\n    pub fn new(split: &str) -> Self {\n        let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new(\"dbpedia_14\")\n            .dataset(split)\n            .unwrap();\n        Self { dataset }\n    }\n}\n\n/// Implement the TextClassificationDataset trait for the DbPedia dataset\nimpl TextClassificationDataset for DbPediaDataset {\n    /// Returns the number of unique classes in the dataset\n    fn num_classes() -> usize {\n        14\n    }\n\n    /// Returns the name of a class given its label\n    fn class_name(label: usize) -> String {\n        match label {\n            0 => \"Company\",\n            1 => \"EducationalInstitution\",\n            2 => \"Artist\",\n            3 => \"Athlete\",\n            4 => \"OfficeHolder\",\n            5 => \"MeanOfTransportation\",\n            6 => \"Building\",\n            7 => \"NaturalPlace\",\n            8 => \"Village\",\n            9 => \"Animal\",\n            10 => \"Plant\",\n            11 => \"Album\",\n            12 => \"Film\",\n            13 => \"WrittenWork\",\n            _ => panic!(\"invalid class\"),\n        }\n        .to_string()\n    }\n}\n"
  },
  {
    "path": "examples/text-classification/src/data/mod.rs",
    "content": "mod batcher;\nmod dataset;\nmod tokenizer;\n\npub use batcher::*;\npub use dataset::*;\npub use tokenizer::*;\n"
  },
  {
    "path": "examples/text-classification/src/data/tokenizer.rs",
    "content": "// This module defines a trait `Tokenizer` that represents a common interface for all tokenizer\n// types used in the text classification library. A specific implementation of this trait,\n// `BertCasedTokenizer`, uses the BERT cased tokenization strategy provided by the `tokenizers` library.\n\n// This trait represents the common interface for all tokenizer types.\n// The `Send + Sync` bounds are necessary for allowing these operations\n// to work across thread boundaries.\n#[allow(dead_code)]\npub trait Tokenizer: Send + Sync {\n    /// Converts a text string into a sequence of tokens.\n    fn encode(&self, value: &str) -> Vec<usize>;\n\n    /// Converts a sequence of tokens back into a text string.\n    fn decode(&self, tokens: &[usize]) -> String;\n\n    /// Gets the size of the tokenizer's vocabulary.\n    fn vocab_size(&self) -> usize;\n\n    /// Gets the token used for padding sequences to a consistent length.\n    fn pad_token(&self) -> usize;\n\n    /// Gets the string representation of the padding token.\n    /// The default implementation uses `decode` on the padding token.\n    fn pad_token_value(&self) -> String {\n        self.decode(&[self.pad_token()])\n    }\n}\n\n/// Struct represents a specific tokenizer using the BERT cased tokenization strategy.\npub struct BertCasedTokenizer {\n    // The underlying tokenizer from the `tokenizers` library.\n    tokenizer: tokenizers::Tokenizer,\n}\n\n// Default implementation for creating a new BertCasedTokenizer.\n// This uses a pretrained BERT cased tokenizer model.\nimpl Default for BertCasedTokenizer {\n    fn default() -> Self {\n        Self {\n            tokenizer: tokenizers::Tokenizer::from_pretrained(\"bert-base-cased\", None).unwrap(),\n        }\n    }\n}\n\n// Implementation of the Tokenizer trait for BertCasedTokenizer.\nimpl Tokenizer for BertCasedTokenizer {\n    // Convert a text string into a sequence of tokens using the BERT cased tokenization strategy.\n    fn encode(&self, value: &str) -> Vec<usize> {\n        let tokens = self.tokenizer.encode(value, true).unwrap();\n        tokens.get_ids().iter().map(|t| *t as usize).collect()\n    }\n\n    /// Converts a sequence of tokens back into a text string.\n    fn decode(&self, tokens: &[usize]) -> String {\n        let tokens = tokens.iter().map(|t| *t as u32).collect::<Vec<u32>>();\n        self.tokenizer.decode(&tokens, false).unwrap()\n    }\n\n    /// Gets the size of the BERT cased tokenizer's vocabulary.\n    fn vocab_size(&self) -> usize {\n        self.tokenizer.get_vocab_size(true)\n    }\n\n    /// Gets the token used for padding sequences to a consistent length.\n    fn pad_token(&self) -> usize {\n        self.tokenizer.token_to_id(\"[PAD]\").unwrap() as usize\n    }\n}\n"
  },
  {
    "path": "examples/text-classification/src/inference.rs",
    "content": "// This module defines the inference process for a text classification model.\n// It loads a model and its configuration from a directory, and uses a tokenizer\n// and a batcher to prepare the input data. The model is then used to make predictions\n// on the input samples, and the results are printed out for each sample.\n// Import required modules and types\n\nuse crate::{\n    data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer},\n    model::TextClassificationModelConfig,\n    training::ExperimentConfig,\n};\nuse burn::{\n    data::dataloader::batcher::Batcher,\n    prelude::*,\n    record::{CompactRecorder, Recorder},\n};\nuse std::sync::Arc;\n\n// Define inference function\npub fn infer<B: Backend, D: TextClassificationDataset + 'static>(\n    device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)\n    artifact_dir: &str, // Directory containing model and config files\n    samples: Vec<String>, // Text samples for inference\n) {\n    // Load experiment configuration\n    let config = ExperimentConfig::load(format!(\"{artifact_dir}/config.json\").as_str())\n        .expect(\"Config file present\");\n\n    // Initialize tokenizer\n    let tokenizer = Arc::new(BertCasedTokenizer::default());\n\n    // Get number of classes from dataset\n    let n_classes = D::num_classes();\n\n    // Initialize batcher for batching samples\n    let batcher = Arc::new(TextClassificationBatcher::new(\n        tokenizer.clone(),\n        config.seq_length,\n    ));\n\n    // Load pre-trained model weights\n    println!(\"Loading weights ...\");\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/model\").into(), &device)\n        .expect(\"Trained model weights tb\");\n\n    // Create model using loaded weights\n    println!(\"Creating model ...\");\n    let model = TextClassificationModelConfig::new(\n        config.transformer,\n        n_classes,\n        tokenizer.vocab_size(),\n        config.seq_length,\n    )\n    .init::<B>(&device)\n    .load_record(record); // Initialize model with loaded weights\n\n    // Run inference on the given text samples\n    println!(\"Running inference ...\");\n    let item = batcher.batch(samples.clone(), &device); // Batch samples using the batcher\n    let predictions = model.infer(item); // Get model predictions\n\n    // Print out predictions for each sample\n    for (i, text) in samples.into_iter().enumerate() {\n        #[allow(clippy::single_range_in_vec_init)]\n        let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample\n        let logits = prediction.to_data(); // Convert prediction tensor to data\n        let class_index = prediction.argmax(1).squeeze_dim::<1>(1).into_scalar(); // Get class index with the highest value\n        let class = D::class_name(class_index.elem::<i32>() as usize); // Get class name\n\n        // Print sample text, predicted logits and predicted class\n        println!(\n            \"\\n=== Item {i} ===\\n- Text: {text}\\n- Logits: {logits}\\n- Prediction: \\\n             {class}\\n================\"\n        );\n    }\n}\n"
  },
  {
    "path": "examples/text-classification/src/lib.rs",
    "content": "#[macro_use]\nextern crate derive_new;\n\npub mod data;\npub mod inference;\npub mod model;\npub mod training;\n\npub use data::{AgNewsDataset, DbPediaDataset, TextClassificationDataset};\n"
  },
  {
    "path": "examples/text-classification/src/model.rs",
    "content": "// This is a basic text classification model implemented in Rust using the Burn framework.\n// It uses a Transformer as the base model and applies Linear and Embedding layers.\n// The model is then trained using Cross-Entropy loss. It contains methods for model initialization\n// (both with and without pre-trained weights), forward pass, inference, training, and validation.\n\nuse crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch};\nuse burn::{\n    nn::{\n        Embedding, EmbeddingConfig, Linear, LinearConfig,\n        attention::SeqLengthOption,\n        loss::CrossEntropyLossConfig,\n        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},\n    },\n    prelude::*,\n    tensor::{activation::softmax, backend::AutodiffBackend},\n    train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep},\n};\n\n// Define the model configuration\n#[derive(Config, Debug)]\npub struct TextClassificationModelConfig {\n    transformer: TransformerEncoderConfig,\n    n_classes: usize,\n    vocab_size: usize,\n    seq_length: SeqLengthOption,\n}\n\n// Define the model structure\n#[derive(Module, Debug)]\npub struct TextClassificationModel<B: Backend> {\n    transformer: TransformerEncoder<B>,\n    embedding_token: Embedding<B>,\n    embedding_pos: Embedding<B>,\n    output: Linear<B>,\n    n_classes: usize,\n}\n\n// Define functions for model initialization\nimpl TextClassificationModelConfig {\n    /// Initializes a model with default weights\n    pub fn init<B: Backend>(&self, device: &B::Device) -> TextClassificationModel<B> {\n        let output = LinearConfig::new(self.transformer.d_model, self.n_classes).init(device);\n        let transformer = self.transformer.init(device);\n        let embedding_token =\n            EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device);\n        let max_seq_length = match self.seq_length {\n            SeqLengthOption::Fixed(max) | SeqLengthOption::Max(max) => max,\n            SeqLengthOption::NoMax => panic!(\n                \"Text classification requires a max sequence length because of the embedding strategy.\"\n            ),\n        };\n        let embedding_pos =\n            EmbeddingConfig::new(max_seq_length, self.transformer.d_model).init(device);\n\n        TextClassificationModel {\n            transformer,\n            embedding_token,\n            embedding_pos,\n            output,\n            n_classes: self.n_classes,\n        }\n    }\n}\n\n/// Define model behavior\nimpl<B: Backend> TextClassificationModel<B> {\n    // Defines forward pass for training\n    pub fn forward(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {\n        // Get batch and sequence length, and the device\n        let [batch_size, seq_length] = item.tokens.dims();\n        let device = &self.embedding_token.devices()[0];\n\n        // Move tensors to the correct device\n        let tokens = item.tokens.to_device(device);\n        let labels = item.labels.to_device(device);\n        let mask_pad = item.mask_pad.to_device(device);\n\n        // Calculate token and position embeddings, and combine them\n        let index_positions = Tensor::arange(0..seq_length as i64, device)\n            .reshape([1, seq_length])\n            .repeat_dim(0, batch_size);\n        let embedding_positions = self.embedding_pos.forward(index_positions);\n        let embedding_tokens = self.embedding_token.forward(tokens);\n        let embedding = (embedding_positions + embedding_tokens) / 2;\n\n        // Perform transformer encoding, calculate output and loss\n        let encoded = self\n            .transformer\n            .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));\n        let output = self.output.forward(encoded);\n\n        let output_classification = output\n            .slice([0..batch_size, 0..1])\n            .reshape([batch_size, self.n_classes]);\n\n        let loss = CrossEntropyLossConfig::new()\n            .init(&output_classification.device())\n            .forward(output_classification.clone(), labels.clone());\n\n        // Return the output and loss\n        ClassificationOutput {\n            loss,\n            output: output_classification,\n            targets: labels,\n        }\n    }\n\n    /// Defines forward pass for inference\n    pub fn infer(&self, item: TextClassificationInferenceBatch<B>) -> Tensor<B, 2> {\n        // Get batch and sequence length, and the device\n        let [batch_size, seq_length] = item.tokens.dims();\n        let device = &self.embedding_token.devices()[0];\n\n        // Move tensors to the correct device\n        let tokens = item.tokens.to_device(device);\n        let mask_pad = item.mask_pad.to_device(device);\n\n        // Calculate token and position embeddings, and combine them\n        let index_positions = Tensor::arange(0..seq_length as i64, device)\n            .reshape([1, seq_length])\n            .repeat_dim(0, batch_size);\n        let embedding_positions = self.embedding_pos.forward(index_positions);\n        let embedding_tokens = self.embedding_token.forward(tokens);\n        let embedding = (embedding_positions + embedding_tokens) / 2;\n\n        // Perform transformer encoding, calculate output and apply softmax for prediction\n        let encoded = self\n            .transformer\n            .forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));\n        let output = self.output.forward(encoded);\n        let output = output\n            .slice([0..batch_size, 0..1])\n            .reshape([batch_size, self.n_classes]);\n\n        softmax(output, 1)\n    }\n}\n\n/// Define training step\nimpl<B: AutodiffBackend> TrainStep for TextClassificationModel<B> {\n    type Input = TextClassificationTrainingBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(\n        &self,\n        item: TextClassificationTrainingBatch<B>,\n    ) -> TrainOutput<ClassificationOutput<B>> {\n        // Run forward pass, calculate gradients and return them along with the output\n        let item = self.forward(item);\n        let grads = item.loss.backward();\n\n        TrainOutput::new(self, grads, item)\n    }\n}\n\n/// Define validation step\nimpl<B: Backend> InferenceStep for TextClassificationModel<B> {\n    type Input = TextClassificationTrainingBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: TextClassificationTrainingBatch<B>) -> ClassificationOutput<B> {\n        // Run forward pass and return the output\n        self.forward(item)\n    }\n}\n"
  },
  {
    "path": "examples/text-classification/src/training.rs",
    "content": "// This module trains a text classification model using the provided training and testing datasets,\n// as well as the provided configuration. It first initializes a tokenizer and batchers for the datasets,\n// then initializes the model and data loaders for the datasets. The function then initializes\n// an optimizer and a learning rate scheduler, and uses them along with the model and datasets\n// to build a learner, which is used to train the model. The trained model and the configuration are\n// then saved to the specified directory.\n\nuse crate::{\n    data::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationDataset, Tokenizer},\n    model::TextClassificationModelConfig,\n};\n#[cfg(feature = \"ddp\")]\nuse burn::collective::{AllReduceStrategy, CollectiveConfig};\nuse burn::train::{Learner, SupervisedTraining};\n#[cfg(not(feature = \"ddp\"))]\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::transform::SamplerDataset},\n    lr_scheduler::noam::NoamLrSchedulerConfig,\n    nn::{attention::SeqLengthOption, transformer::TransformerEncoderConfig},\n    optim::AdamConfig,\n    prelude::*,\n    record::{CompactRecorder, Recorder},\n    tensor::backend::AutodiffBackend,\n    train::{\n        MultiDeviceOptim,\n        metric::{\n            AccuracyMetric, CudaMetric, IterationSpeedMetric, LearningRateMetric, LossMetric,\n        },\n    },\n};\nuse std::sync::Arc;\n\n// Define configuration struct for the experiment\n#[derive(Config, Debug)]\npub struct ExperimentConfig {\n    pub transformer: TransformerEncoderConfig,\n    pub optimizer: AdamConfig,\n    #[config(default = \"SeqLengthOption::Fixed(256)\")]\n    pub seq_length: SeqLengthOption,\n    #[config(default = 16)]\n    pub batch_size: usize,\n    #[config(default = 5)]\n    pub num_epochs: usize,\n}\n\n// Define train function\npub fn train<B: AutodiffBackend, D: TextClassificationDataset + 'static>(\n    devices: Vec<B::Device>, // Device on which to perform computation (e.g., CPU or CUDA device)\n    dataset_train: D,        // Training dataset\n    dataset_test: D,         // Testing dataset\n    config: ExperimentConfig, // Experiment configuration\n    artifact_dir: &str,      // Directory to save model and config files\n) {\n    // Initialize tokenizer\n    let tokenizer = Arc::new(BertCasedTokenizer::default());\n\n    // Initialize batcher\n    let batcher = TextClassificationBatcher::new(tokenizer.clone(), config.seq_length);\n\n    // Initialize model\n    let model = TextClassificationModelConfig::new(\n        config.transformer.clone(),\n        D::num_classes(),\n        tokenizer.vocab_size(),\n        config.seq_length,\n    )\n    .init::<B>(&devices[0]);\n\n    // Initialize data loaders for training and testing data\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .num_workers(1)\n        .build(SamplerDataset::new(dataset_train, 50_000));\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .num_workers(1)\n        .build(SamplerDataset::new(dataset_test, 5_000));\n\n    // Initialize optimizer\n    let optim = config.optimizer.init();\n\n    // Initialize learning rate scheduler\n    let lr_scheduler = NoamLrSchedulerConfig::new(1e-2)\n        .with_warmup_steps(1000)\n        .with_model_size(config.transformer.d_model)\n        .init()\n        .unwrap();\n\n    // Initialize learner\n    #[cfg(not(feature = \"ddp\"))]\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metric_train(CudaMetric::new())\n        .metric_valid(CudaMetric::new())\n        .metric_train(IterationSpeedMetric::new())\n        .metric_train_numeric(LossMetric::new())\n        .metric_valid_numeric(LossMetric::new())\n        .metric_train_numeric(AccuracyMetric::new())\n        .metric_valid_numeric(AccuracyMetric::new())\n        .metric_train_numeric(LearningRateMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .num_epochs(config.num_epochs)\n        .summary()\n        .with_training_strategy(burn::train::TrainingStrategy::MultiDevice(\n            devices,\n            MultiDeviceOptim::OptimSharded,\n        ));\n\n    #[cfg(feature = \"ddp\")]\n    let collective_config =\n        CollectiveConfig::default().with_local_all_reduce_strategy(AllReduceStrategy::Tree(2));\n    #[cfg(feature = \"ddp\")]\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metric_train(CudaMetric::new())\n        .metric_valid(CudaMetric::new())\n        .metric_train(IterationSpeedMetric::new())\n        .metric_train_numeric(LossMetric::new())\n        .metric_valid_numeric(LossMetric::new())\n        .metric_train_numeric(AccuracyMetric::new())\n        .metric_valid_numeric(AccuracyMetric::new())\n        .metric_train_numeric(LearningRateMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .with_training_strategy(burn::train::ddp(devices, collective_config))\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    // Train the model\n    let result = training.launch(Learner::new(model, optim, lr_scheduler));\n\n    // Save the configuration and the trained model\n    config.save(format!(\"{artifact_dir}/config.json\")).unwrap();\n    CompactRecorder::new()\n        .record(\n            result.model.into_record(),\n            format!(\"{artifact_dir}/model\").into(),\n        )\n        .unwrap();\n}\n"
  },
  {
    "path": "examples/text-generation/Cargo.toml",
    "content": "[package]\nauthors = [\"nathanielsimard <nathaniel.simard.42@gmail.com>\"]\nedition.workspace = true\nlicense.workspace = true\nname = \"text-generation\"\npublish = false\nversion.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\ndefault = [\"burn/dataset\", \"burn/sqlite-bundled\"]\nf16 = []\n\n[dependencies]\n# Burn\nburn = {path = \"../../crates/burn\", features=[\"train\", \"tch\"]}\n\n# Tokenizer\ntokenizers = {version = \"0.22.2\", default-features = false, features = [\n  \"onig\",\n  \"http\",\n]}\n\n# Utils\nderive-new = {workspace = true}\nlog = {workspace = true}\nserde = {workspace = true, features = [\"std\", \"derive\"]}\n"
  },
  {
    "path": "examples/text-generation/README.md",
    "content": "# Text Generation\n\n> **Note**  \n> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)\n> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)\n> installed on your computer.\n\nThe example can be run like so:\n\n## CUDA users\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\nexport TORCH_CUDA_VERSION=cu128\ncargo run --example text-generation --release\n```\n\n## Mac users\n\n```bash\ngit clone https://github.com/tracel-ai/burn.git\ncd burn\n\n# Use the --release flag to really speed up training.\ncargo run --example text-generation --release\n```\n"
  },
  {
    "path": "examples/text-generation/examples/text-generation.rs",
    "content": "use burn::optim::decay::WeightDecayConfig;\nuse text_generation::{DbPediaDataset, training::ExperimentConfig};\n\n#[cfg(feature = \"f16\")]\ntype Elem = burn::tensor::f16;\n#[cfg(not(feature = \"f16\"))]\ntype Elem = f32;\n\ntype Backend = burn::backend::Autodiff<burn::backend::LibTorch<Elem>>;\n\nfn main() {\n    let config = ExperimentConfig::new(\n        burn::nn::transformer::TransformerEncoderConfig::new(384, 1536, 12, 6)\n            .with_norm_first(true),\n        burn::optim::AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-6))),\n    );\n\n    text_generation::training::train::<Backend, DbPediaDataset>(\n        if cfg!(target_os = \"macos\") {\n            burn::tensor::Device::<Backend>::Mps\n        } else {\n            burn::tensor::Device::<Backend>::Cuda(0)\n        },\n        DbPediaDataset::train(),\n        DbPediaDataset::test(),\n        config,\n        \"/tmp/text-generation\",\n    );\n}\n"
  },
  {
    "path": "examples/text-generation/src/data/batcher.rs",
    "content": "use super::{dataset::TextGenerationItem, tokenizer::Tokenizer};\nuse burn::{data::dataloader::batcher::Batcher, nn::attention::generate_padding_mask, prelude::*};\nuse std::sync::Arc;\n\n#[derive(Clone, new)]\npub struct TextGenerationBatcher {\n    tokenizer: Arc<dyn Tokenizer>,\n    max_seq_length: usize,\n}\n\n#[derive(Debug, Clone, new)]\npub struct TextGenerationBatch<B: Backend> {\n    pub tokens: Tensor<B, 2, Int>,\n    pub mask_pad: Tensor<B, 2, Bool>,\n}\n\n#[derive(Debug, Clone, new)]\npub struct TrainingTextGenerationBatch<B: Backend> {\n    pub tokens_inputs: Tensor<B, 2, Int>,\n    pub targets: Tensor<B, 2, Int>,\n    pub mask_pad: Tensor<B, 2, Bool>,\n}\n\nimpl<B: Backend> Batcher<B, TextGenerationItem, TextGenerationBatch<B>> for TextGenerationBatcher {\n    fn batch(&self, items: Vec<TextGenerationItem>, device: &B::Device) -> TextGenerationBatch<B> {\n        let mut tokens_list = Vec::with_capacity(items.len());\n\n        for item in items {\n            tokens_list.push(self.tokenizer.encode(&item.text, true));\n        }\n\n        let mask = generate_padding_mask(\n            self.tokenizer.pad_token(),\n            tokens_list,\n            Some(self.max_seq_length),\n            device,\n        );\n\n        TextGenerationBatch {\n            tokens: mask.tensor,\n            mask_pad: mask.mask,\n        }\n    }\n}\n\nimpl<B: Backend> Batcher<B, TextGenerationItem, TrainingTextGenerationBatch<B>>\n    for TextGenerationBatcher\n{\n    fn batch(\n        &self,\n        items: Vec<TextGenerationItem>,\n        device: &B::Device,\n    ) -> TrainingTextGenerationBatch<B> {\n        let item: TextGenerationBatch<B> = self.batch(items, device);\n        let [batch_size, seq_length] = item.tokens.dims();\n\n        let inputs = item\n            .tokens\n            .clone()\n            .slice([0..batch_size, 0..seq_length - 1]);\n        let targets = item.tokens.slice([0..batch_size, 1..seq_length]);\n        let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]);\n\n        TrainingTextGenerationBatch::new(inputs, targets, mask_pad)\n    }\n}\n"
  },
  {
    "path": "examples/text-generation/src/data/dataset.rs",
    "content": "use burn::data::dataset::{Dataset, SqliteDataset, source::huggingface::HuggingfaceDatasetLoader};\n\n#[derive(new, Clone, Debug)]\npub struct TextGenerationItem {\n    pub text: String,\n}\n\n#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]\npub struct DbPediaItem {\n    pub content: String,\n}\n\npub struct DbPediaDataset {\n    dataset: SqliteDataset<DbPediaItem>,\n}\n\nimpl Dataset<TextGenerationItem> for DbPediaDataset {\n    fn get(&self, index: usize) -> Option<TextGenerationItem> {\n        self.dataset\n            .get(index)\n            .map(|item| TextGenerationItem::new(item.content))\n    }\n\n    fn len(&self) -> usize {\n        self.dataset.len()\n    }\n}\n\nimpl DbPediaDataset {\n    pub fn train() -> Self {\n        Self::new(\"train\")\n    }\n\n    pub fn test() -> Self {\n        Self::new(\"test\")\n    }\n    pub fn new(split: &str) -> Self {\n        let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new(\"dbpedia_14\")\n            .dataset(split)\n            .unwrap();\n        Self { dataset }\n    }\n}\n"
  },
  {
    "path": "examples/text-generation/src/data/mod.rs",
    "content": "mod batcher;\nmod dataset;\nmod tokenizer;\n\npub use batcher::*;\npub use dataset::*;\npub use tokenizer::*;\n"
  },
  {
    "path": "examples/text-generation/src/data/tokenizer.rs",
    "content": "#[allow(dead_code)]\npub trait Tokenizer: Send + Sync {\n    fn encode(&self, value: &str, special_tokens: bool) -> Vec<usize>;\n    fn decode(&self, tokens: &[usize]) -> String;\n    fn vocab_size(&self) -> usize;\n    fn pad_token(&self) -> usize;\n    fn start_token(&self) -> usize;\n    fn end_token(&self) -> usize;\n    fn pad_token_value(&self) -> String {\n        self.decode(&[self.pad_token()])\n    }\n    fn start_token_value(&self) -> String {\n        self.decode(&[self.start_token()])\n    }\n    fn end_token_value(&self) -> String {\n        self.decode(&[self.end_token()])\n    }\n}\n\npub struct Gpt2Tokenizer {\n    tokenizer: tokenizers::Tokenizer,\n}\n\nimpl Default for Gpt2Tokenizer {\n    fn default() -> Self {\n        let mut tokenizer = tokenizers::Tokenizer::from_pretrained(\"gpt2\", None).unwrap();\n        tokenizer.add_special_tokens(&[\n            tokenizers::AddedToken::from(\"[START]\", true),\n            tokenizers::AddedToken::from(\"[END]\", true),\n            tokenizers::AddedToken::from(\"[PAD]\", true),\n        ]);\n\n        Self { tokenizer }\n    }\n}\n\nimpl Tokenizer for Gpt2Tokenizer {\n    fn encode(&self, value: &str, special_tokens: bool) -> Vec<usize> {\n        let text = match special_tokens {\n            true => \"[START]\".to_owned() + value + \"[END]\",\n            false => value.to_string(),\n        };\n        let tokens = self.tokenizer.encode(text, true).unwrap();\n        tokens.get_ids().iter().map(|t| *t as usize).collect()\n    }\n\n    fn decode(&self, tokens: &[usize]) -> String {\n        let tokens = tokens.iter().map(|t| *t as u32).collect::<Vec<u32>>();\n        self.tokenizer.decode(&tokens, false).unwrap()\n    }\n\n    fn vocab_size(&self) -> usize {\n        self.tokenizer.get_vocab_size(true)\n    }\n\n    fn pad_token(&self) -> usize {\n        self.tokenizer.token_to_id(\"[PAD]\").unwrap() as usize\n    }\n\n    fn start_token(&self) -> usize {\n        self.tokenizer.token_to_id(\"[START]\").unwrap() as usize\n    }\n\n    fn end_token(&self) -> usize {\n        self.tokenizer.token_to_id(\"[END]\").unwrap() as usize\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_encode_decode() {\n        let tokenizer = Gpt2Tokenizer::default();\n        let text = \"A sentence\";\n\n        let tokens = tokenizer.encode(text, false);\n        let decoded = tokenizer.decode(&tokens);\n\n        assert_eq!(decoded, text);\n    }\n\n    #[test]\n    fn test_add_start_end_token() {\n        let tokenizer = Gpt2Tokenizer::default();\n        let text = \"A sentence\";\n\n        let tokens_without = tokenizer.encode(text, false);\n        let tokens_with = tokenizer.encode(text, true);\n\n        assert_eq!(tokens_with.len() - 2, tokens_without.len());\n    }\n}\n"
  },
  {
    "path": "examples/text-generation/src/lib.rs",
    "content": "#[macro_use]\nextern crate derive_new;\n\nmod data;\nmod model;\n\npub mod training;\npub use data::DbPediaDataset;\n"
  },
  {
    "path": "examples/text-generation/src/model.rs",
    "content": "use crate::data::TrainingTextGenerationBatch;\nuse burn::{\n    nn::{\n        Embedding, EmbeddingConfig, Linear, LinearConfig,\n        attention::generate_autoregressive_mask,\n        loss::CrossEntropyLossConfig,\n        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},\n    },\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n    train::{ClassificationOutput, InferenceStep, TrainOutput, TrainStep},\n};\n\n#[derive(Config, Debug)]\npub struct TextGenerationModelConfig {\n    transformer: TransformerEncoderConfig,\n    vocab_size: usize,\n    pad_token: usize,\n    max_seq_length: usize,\n}\n\n#[derive(Module, Debug)]\npub struct TextGenerationModel<B: Backend> {\n    transformer: TransformerEncoder<B>,\n    embedding_token: Embedding<B>,\n    embedding_pos: Embedding<B>,\n    output: Linear<B>,\n    vocab_size: usize,\n    pad_token: usize,\n    max_seq_length: usize,\n}\n\nimpl TextGenerationModelConfig {\n    pub fn init<B: Backend>(&self, device: &B::Device) -> TextGenerationModel<B> {\n        let output = LinearConfig::new(self.transformer.d_model, self.vocab_size).init(device);\n        let transformer = self.transformer.init(device);\n        let embedding_token =\n            EmbeddingConfig::new(self.vocab_size, self.transformer.d_model).init(device);\n        let embedding_pos =\n            EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init(device);\n\n        TextGenerationModel {\n            transformer,\n            embedding_token,\n            embedding_pos,\n            output,\n            vocab_size: self.vocab_size,\n            pad_token: self.pad_token,\n            max_seq_length: self.max_seq_length,\n        }\n    }\n}\nimpl<B: Backend> TextGenerationModel<B> {\n    pub fn forward_training(\n        &self,\n        item: TrainingTextGenerationBatch<B>,\n    ) -> ClassificationOutput<B> {\n        let [batch_size, seq_length] = item.tokens_inputs.dims();\n        let device = &self.devices()[0];\n\n        let inputs = item.tokens_inputs.to_device(device);\n        let targets = item.targets.to_device(device);\n        let mask_pad = item.mask_pad.to_device(device);\n\n        let index_positions = Tensor::arange(0..seq_length as i64, device)\n            .reshape([1, seq_length])\n            .repeat_dim(0, batch_size);\n\n        let embedding_positions = self.embedding_pos.forward(index_positions);\n        let embedding_tokens = self.embedding_token.forward(inputs);\n        let embedding = (embedding_positions + embedding_tokens) / 2;\n\n        let mask_attn = generate_autoregressive_mask::<B>(batch_size, seq_length, device);\n        let encoded = self.transformer.forward(\n            TransformerEncoderInput::new(embedding)\n                .mask_pad(mask_pad)\n                .mask_attn(mask_attn),\n        );\n\n        let output = self.output.forward(encoded);\n        let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]);\n        let targets_flatten = targets.reshape([batch_size * seq_length]);\n\n        let loss = CrossEntropyLossConfig::new()\n            .with_pad_tokens(Some(vec![self.pad_token]))\n            .init(&output_flatten.device());\n        let loss = loss.forward(output_flatten.clone(), targets_flatten.clone());\n\n        ClassificationOutput {\n            loss,\n            output: output_flatten,\n            targets: targets_flatten,\n        }\n    }\n}\n\nimpl<B: AutodiffBackend> TrainStep for TextGenerationModel<B> {\n    type Input = TrainingTextGenerationBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: TrainingTextGenerationBatch<B>) -> TrainOutput<ClassificationOutput<B>> {\n        let item = self.forward_training(item);\n        let grads = item.loss.backward();\n\n        TrainOutput::new(self, grads, item)\n    }\n}\n\nimpl<B: Backend> InferenceStep for TextGenerationModel<B> {\n    type Input = TrainingTextGenerationBatch<B>;\n    type Output = ClassificationOutput<B>;\n\n    fn step(&self, item: TrainingTextGenerationBatch<B>) -> ClassificationOutput<B> {\n        self.forward_training(item)\n    }\n}\n"
  },
  {
    "path": "examples/text-generation/src/training.rs",
    "content": "use crate::{\n    data::{Gpt2Tokenizer, TextGenerationBatcher, TextGenerationItem, Tokenizer},\n    model::TextGenerationModelConfig,\n};\nuse burn::{\n    data::{\n        dataloader::DataLoaderBuilder,\n        dataset::{Dataset, transform::SamplerDataset},\n    },\n    lr_scheduler::noam::NoamLrSchedulerConfig,\n    nn::transformer::TransformerEncoderConfig,\n    optim::AdamConfig,\n    prelude::*,\n    record::{CompactRecorder, DefaultRecorder, Recorder},\n    tensor::backend::AutodiffBackend,\n    train::{\n        Learner, SupervisedTraining,\n        metric::{AccuracyMetric, CudaMetric, LearningRateMetric, LossMetric, PerplexityMetric},\n    },\n};\nuse std::sync::Arc;\n\n#[derive(Config, Debug)]\npub struct ExperimentConfig {\n    transformer: TransformerEncoderConfig,\n    optimizer: AdamConfig,\n    #[config(default = 512)]\n    max_seq_length: usize,\n    #[config(default = 6)]\n    batch_size: usize,\n    #[config(default = 50)]\n    num_epochs: usize,\n}\n\npub fn train<B: AutodiffBackend, D: Dataset<TextGenerationItem> + 'static>(\n    device: B::Device,\n    dataset_train: D,\n    dataset_test: D,\n    config: ExperimentConfig,\n    artifact_dir: &str,\n) {\n    let tokenizer = Arc::new(Gpt2Tokenizer::default());\n    let batcher = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);\n\n    let model = TextGenerationModelConfig::new(\n        config.transformer.clone(),\n        tokenizer.vocab_size(),\n        tokenizer.pad_token(),\n        config.max_seq_length,\n    )\n    .init::<B>(&device);\n\n    let dataloader_train = DataLoaderBuilder::new(batcher.clone())\n        .batch_size(config.batch_size)\n        .num_workers(4)\n        .build(SamplerDataset::new(dataset_train, 10_000));\n\n    let dataloader_test = DataLoaderBuilder::new(batcher)\n        .batch_size(config.batch_size)\n        .num_workers(4)\n        .build(SamplerDataset::new(dataset_test, 1000));\n\n    let accum = 6; // Effective batch size = 6 * 6 = 32.\n    let optim = config.optimizer.init();\n    let lr_scheduler = NoamLrSchedulerConfig::new(0.01 / accum as f64)\n        .with_warmup_steps(6000)\n        .with_model_size(config.transformer.d_model)\n        .init()\n        .unwrap();\n\n    let training = SupervisedTraining::new(artifact_dir, dataloader_train, dataloader_test)\n        .metric_train(CudaMetric::new())\n        .metric_valid(CudaMetric::new())\n        .metric_train_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))\n        .metric_valid_numeric(AccuracyMetric::new().with_pad_token(tokenizer.pad_token()))\n        .metric_train_numeric(PerplexityMetric::new().with_pad_token(tokenizer.pad_token()))\n        .metric_valid_numeric(PerplexityMetric::new().with_pad_token(tokenizer.pad_token()))\n        .metric_train(LossMetric::new())\n        .metric_valid(LossMetric::new())\n        .metric_train_numeric(LearningRateMetric::new())\n        .with_file_checkpointer(CompactRecorder::new())\n        .grads_accumulation(accum)\n        .num_epochs(config.num_epochs)\n        .summary();\n\n    let result = training.launch(Learner::new(model, optim, lr_scheduler));\n\n    config.save(format!(\"{artifact_dir}/config.json\")).unwrap();\n\n    DefaultRecorder::new()\n        .record(\n            result.model.into_record(),\n            format!(\"{artifact_dir}/model\").into(),\n        )\n        .unwrap();\n}\n"
  },
  {
    "path": "examples/wgan/Cargo.toml",
    "content": "[package]\nname = \"wgan\"\nversion = \"0.5.0\"\nedition.workspace = true\n\n[lints]\nworkspace = true\n\n[features]\nndarray = [\"burn/ndarray\"]\nndarray-blas-accelerate = [\"burn/ndarray\", \"burn/accelerate\"]\nndarray-blas-netlib = [\"burn/ndarray\", \"burn/blas-netlib\"]\nndarray-blas-openblas = [\"burn/ndarray\", \"burn/openblas\"]\ntch-cpu = [\"burn/tch\"]\ntch-gpu = [\"burn/tch\"]\nwgpu = [\"burn/wgpu\"]\ncuda = [\"burn/cuda\"]\n\n[dependencies]\nburn = { path = \"../../crates/burn\", features=[\"train\", \"vision\"] }\nimage = { workspace = true }\n"
  },
  {
    "path": "examples/wgan/README.md",
    "content": "# Wasserstein Generative Adversarial Network\n\nA burn implementation of an example WGAN model to generate MNIST digits inspired by\n[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html).\nPlease note that better performance maybe gained by adopting a convolution layer in\n[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch).\n\n## Usage\n\n## Training\n\n```sh\n# Cuda backend\ncargo run --example wgan-mnist --release --features cuda\n\n# Wgpu backend\ncargo run --example wgan-mnist --release --features wgpu\n\n# Tch GPU backend\nexport TORCH_CUDA_VERSION=cu128 # Set the cuda version\ncargo run --example wgan-mnist --release --features tch-gpu\n\n# Tch CPU backend\ncargo run --example wgan-mnist --release --features tch-cpu\n\n# NdArray backend (CPU)\ncargo run --example wgan-mnist --release --features ndarray                # f32 - single thread\ncargo run --example wgan-mnist --release --features ndarray-blas-openblas  # f32 - blas with openblas\ncargo run --example wgan-mnist --release --features ndarray-blas-netlib    # f32 - blas with netlib\n```\n\n### Generating\n\nTo generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend.\n\n```sh\ncargo run --example wgan-generate --release --features cuda\n```\n"
  },
  {
    "path": "examples/wgan/examples/wgan-generate.rs",
    "content": "use burn::tensor::backend::Backend;\n\npub fn launch<B: Backend>(device: B::Device) {\n    wgan::infer::generate::<B>(\"/tmp/wgan-mnist\", device);\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::ndarray::{NdArray, NdArrayDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<NdArray>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<LibTorch>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::libtorch::{LibTorch, LibTorchDevice};\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<LibTorch>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::launch;\n    use burn::backend::wgpu::Wgpu;\n\n    pub fn run() {\n        launch::<Wgpu>(Default::default());\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::launch;\n    use burn::backend::Cuda;\n\n    pub fn run() {\n        launch::<Cuda>(Default::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n}\n"
  },
  {
    "path": "examples/wgan/examples/wgan-mnist.rs",
    "content": "use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend};\n\nuse wgan::{model::ModelConfig, training::TrainingConfig};\n\npub fn launch<B: AutodiffBackend>(device: B::Device) {\n    let config = TrainingConfig::new(\n        ModelConfig::new(),\n        RmsPropConfig::new()\n            .with_alpha(0.99)\n            .with_momentum(0.0)\n            .with_epsilon(0.00000008)\n            .with_weight_decay(None)\n            .with_centered(false),\n    );\n\n    wgan::training::train::<B>(\"/tmp/wgan-mnist\", config, device);\n}\n\n#[cfg(any(\n    feature = \"ndarray\",\n    feature = \"ndarray-blas-netlib\",\n    feature = \"ndarray-blas-openblas\",\n    feature = \"ndarray-blas-accelerate\",\n))]\nmod ndarray {\n    use burn::backend::{\n        Autodiff,\n        ndarray::{NdArray, NdArrayDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"tch-gpu\")]\nmod tch_gpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        #[cfg(not(target_os = \"macos\"))]\n        let device = LibTorchDevice::Cuda(0);\n        #[cfg(target_os = \"macos\")]\n        let device = LibTorchDevice::Mps;\n\n        launch::<Autodiff<LibTorch>>(device);\n    }\n}\n\n#[cfg(feature = \"tch-cpu\")]\nmod tch_cpu {\n    use burn::backend::{\n        Autodiff,\n        libtorch::{LibTorch, LibTorchDevice},\n    };\n\n    use crate::launch;\n\n    pub fn run() {\n        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);\n    }\n}\n\n#[cfg(feature = \"wgpu\")]\nmod wgpu {\n    use crate::launch;\n    use burn::backend::{Autodiff, wgpu::Wgpu};\n\n    pub fn run() {\n        launch::<Autodiff<Wgpu>>(Default::default());\n    }\n}\n\n#[cfg(feature = \"cuda\")]\nmod cuda {\n    use crate::launch;\n    use burn::backend::{Autodiff, Cuda, cuda::CudaDevice};\n\n    pub fn run() {\n        launch::<Autodiff<Cuda>>(CudaDevice::default());\n    }\n}\n\nfn main() {\n    #[cfg(any(\n        feature = \"ndarray\",\n        feature = \"ndarray-blas-netlib\",\n        feature = \"ndarray-blas-openblas\",\n        feature = \"ndarray-blas-accelerate\",\n    ))]\n    ndarray::run();\n    #[cfg(feature = \"tch-gpu\")]\n    tch_gpu::run();\n    #[cfg(feature = \"tch-cpu\")]\n    tch_cpu::run();\n    #[cfg(feature = \"wgpu\")]\n    wgpu::run();\n    #[cfg(feature = \"cuda\")]\n    cuda::run();\n}\n"
  },
  {
    "path": "examples/wgan/src/dataset.rs",
    "content": "use burn::{\n    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},\n    prelude::*,\n};\n\n#[derive(Clone, Debug, Default)]\npub struct MnistBatcher {}\n\n#[derive(Clone, Debug)]\npub struct MnistBatch<B: Backend> {\n    pub images: Tensor<B, 4>,\n    pub targets: Tensor<B, 1, Int>,\n}\n\nimpl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {\n    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {\n        let images = items\n            .iter()\n            .map(|item| TensorData::from(item.image))\n            .map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), device))\n            .map(|tensor| tensor.reshape([1, 28, 28]))\n            // Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example\n            .map(|tensor| ((tensor / 255) - 0.5) / 0.5)\n            .collect();\n\n        let targets = items\n            .iter()\n            .map(|item| {\n                Tensor::<B, 1, Int>::from_data(\n                    TensorData::from([(item.label as i64).elem::<B::IntElem>()]),\n                    device,\n                )\n            })\n            .collect();\n\n        let images = Tensor::stack(images, 0);\n        let targets = Tensor::cat(targets, 0);\n\n        MnistBatch { images, targets }\n    }\n}\n"
  },
  {
    "path": "examples/wgan/src/infer.rs",
    "content": "use crate::training::{TrainingConfig, save_image};\nuse burn::{\n    prelude::*,\n    record::{CompactRecorder, Recorder},\n    tensor::Distribution,\n};\n\npub fn generate<B: Backend>(artifact_dir: &str, device: B::Device) {\n    // Loading model\n    let config = TrainingConfig::load(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should exist for the model; run train first\");\n    let record = CompactRecorder::new()\n        .load(format!(\"{artifact_dir}/generator\").into(), &device)\n        .expect(\"Trained model should exist; run train first\");\n    let (mut generator, _) = config.model.init::<B>(&device);\n    generator = generator.load_record(record);\n\n    // Get a batch of noise\n    let noise = Tensor::<B, 2>::random(\n        [config.batch_size, config.model.latent_dim],\n        Distribution::Normal(0.0, 1.0),\n        &device,\n    );\n    let fake_images = generator.forward(noise); // [batch_size, channesl*height*width]\n    let fake_images = fake_images.reshape([\n        config.batch_size,\n        config.model.channels,\n        config.model.image_size,\n        config.model.image_size,\n    ]);\n    // [B, C, H, W] to [B, H, C, W] to [B, H, W, C]\n    let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25);\n    // Normalize the images. The Rgb32 images should be in range 0.0-1.0\n    let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1]))\n        / (fake_images.clone().max().reshape([1, 1, 1, 1])\n            - fake_images.clone().min().reshape([1, 1, 1, 1]));\n    // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer, refer to pytorch save_image source\n    let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0);\n    // Save images in artifact directory\n    save_image::<B, _>(fake_images, 5, format!(\"{artifact_dir}/fake_image.png\")).unwrap();\n}\n"
  },
  {
    "path": "examples/wgan/src/lib.rs",
    "content": "pub mod dataset;\npub mod infer;\npub mod model;\npub mod training;\n"
  },
  {
    "path": "examples/wgan/src/model.rs",
    "content": "use burn::{\n    module::{Module, ModuleMapper, Param},\n    prelude::*,\n    tensor::backend::AutodiffBackend,\n};\n\n/// Layer block of generator model\n#[derive(Module, Debug)]\npub struct LayerBlock<B: Backend> {\n    fc: nn::Linear<B>,\n    bn: nn::BatchNorm<B>,\n    leakyrelu: nn::LeakyRelu,\n}\n\nimpl<B: Backend> LayerBlock<B> {\n    pub fn new(input: usize, output: usize, device: &B::Device) -> Self {\n        let fc = nn::LinearConfig::new(input, output)\n            .with_bias(true)\n            .init(device);\n        let bn: nn::BatchNorm<B> = nn::BatchNormConfig::new(output)\n            .with_epsilon(0.8)\n            .init(device);\n        let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();\n\n        Self { fc, bn, leakyrelu }\n    }\n\n    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {\n        let output = self.fc.forward(input); // output: [Batch, x]\n        let output = self.bn.forward(output); // output: [Batch, x]\n\n        self.leakyrelu.forward(output) // output: [Batch, x]\n    }\n}\n\n/// Generator model\n#[derive(Module, Debug)]\npub struct Generator<B: Backend> {\n    layer1: LayerBlock<B>,\n    layer2: LayerBlock<B>,\n    layer3: LayerBlock<B>,\n    layer4: LayerBlock<B>,\n    fc: nn::Linear<B>,\n    tanh: nn::Tanh,\n}\n\nimpl<B: Backend> Generator<B> {\n    /// Applies the forward pass on the input tensor by specified order\n    pub fn forward(&self, noise: Tensor<B, 2>) -> Tensor<B, 2> {\n        let output = self.layer1.forward(noise);\n        let output = self.layer2.forward(output);\n        let output = self.layer3.forward(output);\n        let output = self.layer4.forward(output);\n        let output = self.fc.forward(output);\n\n        self.tanh.forward(output) // [batch_size, channels*height*width]\n    }\n}\n\n/// Discriminator model\n#[derive(Module, Debug)]\npub struct Discriminator<B: Backend> {\n    fc1: nn::Linear<B>,\n    leakyrelu1: nn::LeakyRelu,\n    fc2: nn::Linear<B>,\n    leakyrelu2: nn::LeakyRelu,\n    fc3: nn::Linear<B>,\n}\n\nimpl<B: Backend> Discriminator<B> {\n    /// Applies the forward pass on the input tensor by specified order.\n    /// The input image shape is [batch, channels, height, width]\n    pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 2> {\n        // Full connection for each batch\n        let output = images.flatten(1, 3); // output: [batch, channels*height*width]\n        let output = self.fc1.forward(output); // output: [batch, 512]\n        let output = self.leakyrelu1.forward(output); // output: [batch, 512]\n        let output = self.fc2.forward(output); // output: [batch, 256]\n        let output = self.leakyrelu2.forward(output); // output: [batch, 256]\n\n        self.fc3.forward(output) // output: [batch, 1]\n    }\n}\n\n// Use model config to construct a generative and adversarial model\n#[derive(Config, Debug)]\npub struct ModelConfig {\n    /// Dimensionality of the latent space\n    #[config(default = 100)]\n    pub latent_dim: usize,\n    #[config(default = 28)]\n    pub image_size: usize,\n    #[config(default = 1)]\n    pub channels: usize,\n}\n\nimpl ModelConfig {\n    /// Initialize the generator and discriminator models based on the config.\n    pub fn init<B: Backend>(&self, device: &B::Device) -> (Generator<B>, Discriminator<B>) {\n        // Construct the initialized generator\n        let layer1 = LayerBlock::new(self.latent_dim, 128, device);\n        let layer2 = LayerBlock::new(128, 256, device);\n        let layer3 = LayerBlock::new(256, 512, device);\n        let layer4 = LayerBlock::new(512, 1024, device);\n        let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size)\n            .with_bias(true)\n            .init(device);\n\n        let generator = Generator {\n            layer1,\n            layer2,\n            layer3,\n            layer4,\n            fc,\n            tanh: nn::Tanh::new(),\n        };\n\n        // Construct the initialized discriminator\n        let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512)\n            .init(device);\n        let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();\n        let fc2 = nn::LinearConfig::new(512, 256).init(device);\n        let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();\n        let fc3 = nn::LinearConfig::new(256, 1).init(device);\n\n        let discriminator = Discriminator {\n            fc1,\n            leakyrelu1,\n            fc2,\n            leakyrelu2,\n            fc3,\n        };\n\n        (generator, discriminator)\n    }\n}\n\n/// Clip module mapper to clip all module parameters between a range of values\n#[derive(Module, Clone, Debug)]\npub struct Clip {\n    pub min: f32,\n    pub max: f32,\n}\n\nimpl<B: AutodiffBackend> ModuleMapper<B> for Clip {\n    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {\n        let (id, tensor, mapper) = param.consume();\n        let is_require_grad = tensor.is_require_grad();\n\n        let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max));\n\n        if is_require_grad {\n            tensor = tensor.require_grad();\n        }\n        Param::from_mapped_value(id, tensor, mapper)\n    }\n}\n"
  },
  {
    "path": "examples/wgan/src/training.rs",
    "content": "use crate::dataset::MnistBatcher;\nuse crate::model::{Clip, ModelConfig};\nuse burn::optim::{GradientsParams, Optimizer, RmsPropConfig};\nuse burn::{\n    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},\n    prelude::*,\n    record::CompactRecorder,\n    tensor::{Distribution, backend::AutodiffBackend},\n};\nuse image::{Rgb32FImage, RgbImage, buffer::ConvertBuffer, error::ImageResult};\nuse std::path::Path;\n\n#[derive(Config, Debug)]\npub struct TrainingConfig {\n    pub model: ModelConfig,\n    pub optimizer: RmsPropConfig,\n\n    #[config(default = 200)]\n    pub num_epochs: usize,\n    #[config(default = 512)]\n    pub batch_size: usize,\n    #[config(default = 8)]\n    pub num_workers: usize,\n    #[config(default = 5)]\n    pub seed: u64,\n    #[config(default = 3e-4)]\n    pub lr: f64,\n\n    /// Number of training steps for discriminator before generator is trained per iteration\n    #[config(default = 5)]\n    pub num_critic: usize,\n    /// Lower and upper clip value for disc. weights\n    #[config(default = 0.01)]\n    pub clip_value: f32,\n    /// Save a sample of images every `sample_interval` epochs\n    #[config(default = 10)]\n    pub sample_interval: usize,\n}\n\n// Create the directory to save the model and model config\nfn create_artifact_dir(artifact_dir: &str) {\n    // Remove existing artifacts\n    std::fs::remove_dir_all(artifact_dir).ok();\n    std::fs::create_dir_all(artifact_dir).ok();\n}\n\n/// Save the generated images\n// The images format is [B, H, W, C]\npub fn save_image<B: Backend, Q: AsRef<Path>>(\n    images: Tensor<B, 4>,\n    nrow: u32,\n    path: Q,\n) -> ImageResult<()> {\n    let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32;\n\n    let width = images.dims()[2] as u32;\n    let height = images.dims()[1] as u32;\n\n    // Supports both 1 and 3 channels image\n    let channels = match images.dims()[3] {\n        1 => 3,\n        3 => 1,\n        _ => panic!(\"Wrong channels number\"),\n    };\n\n    let mut imgbuf = RgbImage::new(nrow * width, ncol * height);\n    // Write images into a nrow*ncol grid layout\n    for row in 0..nrow {\n        for col in 0..ncol {\n            let image: Tensor<B, 3> = images\n                .clone()\n                .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize)\n                .squeeze_dim(0);\n            // The Rgb32 should be in range 0.0-1.0\n            let image = image.into_data().iter::<f32>().collect::<Vec<f32>>();\n            // Supports both 1 and 3 channels image\n            let image = image\n                .into_iter()\n                .flat_map(|n| std::iter::repeat_n(n, channels))\n                .collect();\n\n            let image = Rgb32FImage::from_vec(width, height, image).unwrap();\n            let image: RgbImage = image.convert();\n            for (x, y, pixel) in image.enumerate_pixels() {\n                imgbuf.put_pixel(row * width + x, col * height + y, *pixel);\n            }\n        }\n    }\n    imgbuf.save(path)\n}\n\npub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {\n    create_artifact_dir(artifact_dir);\n\n    // Create the Clip module mapper\n    let mut clip = Clip {\n        min: -config.clip_value,\n        max: config.clip_value,\n    };\n\n    // Save training config\n    config\n        .save(format!(\"{artifact_dir}/config.json\"))\n        .expect(\"Config should be saved successfully\");\n    B::seed(&device, config.seed);\n\n    // Create the model and optimizer\n    let (mut generator, mut discriminator) = config.model.init::<B>(&device);\n    let mut optimizer_g = config.optimizer.init();\n    let mut optimizer_d = config.optimizer.init();\n\n    // Create the dataset batcher\n    let batcher_train = MnistBatcher::default();\n\n    // Create the dataloaders\n    let dataloader_train = DataLoaderBuilder::new(batcher_train)\n        .batch_size(config.batch_size)\n        .shuffle(config.seed)\n        .num_workers(config.num_workers)\n        .build(MnistDataset::train());\n\n    // Iterate over our training for X epochs\n    for epoch in 0..config.num_epochs {\n        // Implement our training loop\n        for (iteration, batch) in dataloader_train.iter().enumerate() {\n            // Generate a batch of fake images from noise (standarded normal distribution)\n            let noise = Tensor::<B, 2>::random(\n                [config.batch_size, config.model.latent_dim],\n                Distribution::Normal(0.0, 1.0),\n                &device,\n            );\n            // datach: do not update generator, only discriminator is updated\n            let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width]\n            let fake_images = fake_images.reshape([\n                config.batch_size,\n                config.model.channels,\n                config.model.image_size,\n                config.model.image_size,\n            ]);\n            // Adversarial loss\n            let loss_d = -discriminator.forward(batch.images).mean()\n                + discriminator.forward(fake_images.clone()).mean();\n\n            // Gradients for the current backward pass\n            let grads = loss_d.backward();\n            // Gradients linked to each parameter of the discriminator\n            let grads = GradientsParams::from_grads(grads, &discriminator);\n            // Update the discriminator using the optimizer\n            discriminator = optimizer_d.step(config.lr, discriminator, grads);\n            // Clip parameters (weights) of discriminator\n            discriminator = discriminator.map(&mut clip);\n\n            // Train the generator every num_critic iterations\n            if iteration % config.num_critic == 0 {\n                // Generate a batch of images again without detaching\n                let critic_fake_images = generator.forward(noise.clone());\n                let critic_fake_images = critic_fake_images.reshape([\n                    config.batch_size,\n                    config.model.channels,\n                    config.model.image_size,\n                    config.model.image_size,\n                ]);\n                // Adversarial loss. Minimize it to make the fake images as truth\n                let loss_g = -discriminator.forward(critic_fake_images).mean();\n\n                let grads = loss_g.backward();\n                let grads = GradientsParams::from_grads(grads, &generator);\n                generator = optimizer_g.step(config.lr, generator, grads);\n\n                // Print the progression\n                let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32)\n                    .ceil() as usize;\n                println!(\n                    \"[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]\",\n                    epoch + 1,\n                    config.num_epochs,\n                    iteration,\n                    batch_num,\n                    loss_d.into_scalar(),\n                    loss_g.into_scalar()\n                );\n            }\n            //  If at save interval => save the first 25 generated images\n            if epoch % config.sample_interval == 0 && iteration == 0 {\n                // [B, C, H, W] to [B, H, C, W] to [B, H, W, C]\n                let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25);\n                // Normalize the images. The Rgb32 images should be in range 0.0-1.0\n                let fake_images = (fake_images.clone()\n                    - fake_images.clone().min().reshape([1, 1, 1, 1]))\n                    / (fake_images.clone().max().reshape([1, 1, 1, 1])\n                        - fake_images.clone().min().reshape([1, 1, 1, 1]));\n                // Add 0.5/255.0 to the images, refer to pytorch save_image source\n                let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0);\n                // Save images in artifact directory\n                let path = format!(\"{artifact_dir}/image-{epoch}.png\");\n                save_image::<B, _>(fake_images, 5, path).unwrap();\n            }\n        }\n    }\n\n    // Save the trained models\n    generator\n        .save_file(format!(\"{artifact_dir}/generator\"), &CompactRecorder::new())\n        .expect(\"Generator should be saved successfully\");\n    discriminator\n        .save_file(\n            format!(\"{artifact_dir}/discriminator\"),\n            &CompactRecorder::new(),\n        )\n        .expect(\"Discriminator should be saved successfully\");\n}\n"
  },
  {
    "path": "rustfmt.toml",
    "content": "max_width = 100\n\n# uncomment and run `cargo +nightly fmt --all` to find and fix lines that are too long (and therefore break autoformatting)\n# error_on_line_overflow = true\n# format_strings = true\n"
  },
  {
    "path": "xtask/Cargo.toml",
    "content": "[package]\nname = \"xtask\"\nversion = \"4.10.0\"\nedition.workspace = true\nlicense = \"MIT OR Apache-2.0\"\n\n# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html\n\n[lints]\nworkspace = true\n\n[dependencies]\nlog = { workspace = true }\nstrum = { workspace = true }\ntracel-xtask = { workspace = true }\n\n[dev-dependencies]\nrstest = { workspace = true }\n"
  },
  {
    "path": "xtask/src/commands/books.rs",
    "content": "use std::path::Path;\n\nuse tracel_xtask::prelude::*;\n\n#[derive(clap::Args)]\npub struct BooksArgs {\n    #[command(subcommand)]\n    book: BookKind,\n}\n\n#[derive(clap::Subcommand)]\npub(crate) enum BookKind {\n    ///  Burn Book, a.k.a. the guide, made for the Burn users.\n    Burn(BookKindArgs),\n    /// Contributor book, made for people willing to get all the technical understanding and advice to contribute actively to the project.\n    Contributor(BookKindArgs),\n}\n\n#[derive(clap::Args)]\npub(crate) struct BookKindArgs {\n    #[command(subcommand)]\n    command: BookSubCommand,\n}\n\n#[derive(clap::Subcommand, strum::Display)]\npub(crate) enum BookSubCommand {\n    /// Build the book\n    Build,\n    /// Open the book on the specified port or random port and rebuild it automatically upon changes\n    Open(OpenArgs),\n}\n\n#[derive(clap::Args)]\npub(crate) struct OpenArgs {\n    /// Specify the port to open the book on (defaults to a random port if not specified)\n    #[clap(long, default_value_t = random_port())]\n    port: u16,\n}\n\n/// Book information\npub(crate) struct Book {\n    name: &'static str,\n    path: &'static Path,\n}\n\nimpl BooksArgs {\n    pub(crate) fn parse(&self) -> anyhow::Result<()> {\n        Book::run(&self.book)\n    }\n}\n\nimpl Book {\n    const BURN_BOOK_NAME: &'static str = \"Burn Book\";\n    const BURN_BOOK_PATH: &'static str = \"./burn-book\";\n\n    const CONTRIBUTOR_BOOK_NAME: &'static str = \"Contributor Book\";\n    const CONTRIBUTOR_BOOK_PATH: &'static str = \"./contributor-book\";\n\n    pub(crate) fn run(book_arg: &BookKind) -> anyhow::Result<()> {\n        let (book, command) = match book_arg {\n            BookKind::Burn(args) => (\n                Self {\n                    name: Self::BURN_BOOK_NAME,\n                    path: Path::new(Self::BURN_BOOK_PATH),\n                },\n                &args.command,\n            ),\n            BookKind::Contributor(args) => (\n                Self {\n                    name: Self::CONTRIBUTOR_BOOK_NAME,\n                    path: Path::new(Self::CONTRIBUTOR_BOOK_PATH),\n                },\n                &args.command,\n            ),\n        };\n        book.execute(command)\n    }\n\n    fn execute(&self, command: &BookSubCommand) -> anyhow::Result<()> {\n        ensure_cargo_crate_is_installed(\"mdbook\", None, None, false)?;\n        group!(\"{}: {}\", self.name, command);\n        match command {\n            BookSubCommand::Build => self.build(),\n            BookSubCommand::Open(args) => self.open(args),\n        }?;\n        endgroup!();\n        Ok(())\n    }\n\n    fn build(&self) -> anyhow::Result<()> {\n        run_process(\n            \"mdbook\",\n            &[\"build\"],\n            None,\n            Some(self.path),\n            \"mdbook should build the book successfully\",\n        )\n    }\n\n    fn open(&self, args: &OpenArgs) -> anyhow::Result<()> {\n        run_process(\n            \"mdbook\",\n            &[\"serve\", \"--open\", \"--port\", &args.port.to_string()],\n            None,\n            Some(self.path),\n            \"mdbook should open the book successfully\",\n        )\n    }\n}\n"
  },
  {
    "path": "xtask/src/commands/build.rs",
    "content": "use std::collections::HashMap;\n\nuse tracel_xtask::prelude::{clap::ValueEnum, *};\n\nuse crate::{ARM_NO_ATOMIC_PTR_TARGET, ARM_TARGET, NO_STD_CRATES, WASM32_TARGET};\n\n#[macros::extend_command_args(BuildCmdArgs, Target, None)]\npub struct BurnBuildCmdArgs {\n    /// Build in CI mode which excludes unsupported crates.\n    #[arg(long)]\n    pub ci: bool,\n}\n\npub(crate) fn handle_command(\n    mut args: BurnBuildCmdArgs,\n    env: Environment,\n    context: Context,\n) -> anyhow::Result<()> {\n    match context {\n        Context::NoStd => {\n            [\n                \"Default\",\n                WASM32_TARGET,\n                ARM_TARGET,\n                ARM_NO_ATOMIC_PTR_TARGET,\n            ]\n            .iter()\n            .try_for_each(|build_target| {\n                let mut build_args = vec![\"--no-default-features\"];\n                let mut env_vars = HashMap::new();\n                if *build_target != \"Default\" {\n                    build_args.extend(vec![\"--target\", *build_target]);\n                }\n\n                let mut crates = NO_STD_CRATES.to_vec();\n\n                if *build_target == ARM_NO_ATOMIC_PTR_TARGET {\n                    // Temporarily remove `burn-autodiff` from building with the\n                    // target `thumbv6m-none-eabi` as it requires enabling the\n                    // `arbitrary_self_types` feature for the\n                    // `clone_if_require_grad` method of\n                    // `burn-autodiff::graph::Node`\n                    crates.retain(|&v| v != \"burn-autodiff\");\n\n                    env_vars.insert(\n                        \"RUSTFLAGS\",\n                        \"--cfg portable_atomic_unsafe_assume_single_core\",\n                    );\n                }\n                helpers::custom_crates_build(\n                    crates,\n                    build_args,\n                    Some(env_vars),\n                    None,\n                    &format!(\"no-std with target {}\", *build_target),\n                )\n            })?;\n            Ok(())\n        }\n        Context::Std => {\n            if args.ci {\n                // Exclude crates that are not supported on CI\n                args.exclude.extend(vec![\n                    \"burn-cuda\".to_string(),\n                    \"burn-rocm\".to_string(),\n                    \"burn-tch\".to_string(),\n                ]);\n                if std::env::var(\"DISABLE_WGPU\").is_ok() {\n                    args.exclude.extend(vec![\"burn-wgpu\".to_string()]);\n                };\n            }\n            // Build workspace\n            base_commands::build::handle_command(args.try_into().unwrap(), env, context)?;\n            // Specific additional commands to test specific features\n            // burn-dataset\n            helpers::custom_crates_build(\n                vec![\"burn-dataset\"],\n                vec![\"--all-features\"],\n                None,\n                None,\n                \"std with all features\",\n            )?;\n            Ok(())\n        }\n        Context::All => Context::value_variants()\n            .iter()\n            .filter(|ctx| **ctx != Context::All)\n            .try_for_each(|ctx| {\n                handle_command(\n                    BurnBuildCmdArgs {\n                        target: args.target.clone(),\n                        exclude: args.exclude.clone(),\n                        only: args.only.clone(),\n                        ci: args.ci,\n                        release: args.release,\n                        features: args.features.clone(),\n                        no_default_features: args.no_default_features,\n                    },\n                    env.clone(),\n                    ctx.clone(),\n                )\n            }),\n    }\n}\n"
  },
  {
    "path": "xtask/src/commands/doc.rs",
    "content": "use tracel_xtask::prelude::*;\n\npub(crate) fn handle_command(\n    mut args: DocCmdArgs,\n    env: Environment,\n    ctx: Context,\n) -> anyhow::Result<()> {\n    if args.get_command() == DocSubCommand::Build {\n        args.exclude\n            .extend(vec![\"burn-cuda\".to_string(), \"burn-rocm\".to_string()]);\n    }\n\n    // Execute documentation command on workspace\n    base_commands::doc::handle_command(args.clone(), env, ctx)?;\n\n    // Specific additional commands to build other docs\n    if args.get_command() == DocSubCommand::Build {\n        // burn-dataset\n        helpers::custom_crates_doc_build(\n            vec![\"burn-dataset\"],\n            vec![\"--all-features\"],\n            None,\n            None,\n            \"All features\",\n        )?;\n    }\n    Ok(())\n}\n"
  },
  {
    "path": "xtask/src/commands/mod.rs",
    "content": "pub(crate) mod books;\npub(crate) mod build;\npub(crate) mod doc;\npub(crate) mod test;\npub(crate) mod validate;\n"
  },
  {
    "path": "xtask/src/commands/test.rs",
    "content": "use tracel_xtask::{\n    prelude::{clap::ValueEnum, *},\n    utils::{\n        process::{ExitSignal, ProcessExitError},\n        workspace::WorkspaceMember,\n    },\n};\n\nuse crate::NO_STD_CRATES;\n\n#[cfg(unix)]\nuse std::os::unix::process::ExitStatusExt;\n\n#[macros::extend_command_args(TestCmdArgs, Target, TestSubCommand)]\npub struct BurnTestCmdArgs {\n    /// Test in CI mode which excludes unsupported crates.\n    #[arg(long)]\n    pub ci: CiTestType,\n}\n\n#[allow(clippy::enum_variant_names)]\n#[derive(Debug, Clone, ValueEnum, PartialEq)]\npub enum CiTestType {\n    GithubRunner,\n    GithubMacRunner,\n    GcpCudaRunner,\n    GcpVulkanRunner,\n    GcpWgpuRunner,\n}\n\nfn handle_backend_tests(\n    mut args: TestCmdArgs,\n    backend: &str,\n    env: Environment,\n    context: Context,\n) -> anyhow::Result<()> {\n    args.target = Target::AllPackages;\n    args.only.push(\"burn-backend-tests\".to_string());\n    args.no_default_features = true;\n\n    let mut features = vec![String::from(backend)];\n    if !matches!(context, Context::NoStd) {\n        features.push(\"std\".into())\n    }\n    args.features = Some(features);\n\n    base_commands::test::handle_command(args, env, context)\n}\n\nfn handle_wgpu_test(member: &str, args: &TestCmdArgs) -> anyhow::Result<()> {\n    #[cfg(unix)]\n    let filter_err = |e: &&ProcessExitError| {\n        e.status.signal() == Some(11) || matches!(e.signal, Some(ExitSignal { code: 11, .. }))\n    };\n    #[cfg(not(unix))]\n    let filter_err = |e: &&ProcessExitError| matches!(e.signal, Some(ExitSignal { code: 11, .. }));\n\n    let workspace_member = WorkspaceMember {\n        name: member.into(),\n        path: \"\".into(), // unused\n    };\n\n    if let Err(err) = base_commands::test::run_unit_test(&workspace_member, args) {\n        let should_ignore = err\n            .downcast_ref::<ProcessExitError>()\n            .filter(filter_err)\n            // Failed to execute unit test for '{member}'\n            .map(|e| e.message.contains(member))\n            .unwrap_or(false);\n\n        if should_ignore {\n            // Ignore intermittent successful failures\n            // https://github.com/gfx-rs/wgpu/issues/2949\n            // https://github.com/KhronosGroup/Vulkan-ValidationLayers/issues/4391\n            eprintln!(\"⚠️ Ignored SIGSEGV in wgpu test\");\n        } else {\n            return Err(err);\n        }\n    }\n    Ok(())\n}\n\npub(crate) fn handle_command(\n    mut args: BurnTestCmdArgs,\n    env: Environment,\n    context: Context,\n) -> anyhow::Result<()> {\n    match context {\n        Context::NoStd => {\n            [\"Default\"].iter().try_for_each(|test_target| {\n                let mut test_args = vec![\"--no-default-features\"];\n                if *test_target != \"Default\" {\n                    test_args.extend(vec![\"--target\", *test_target]);\n                }\n                helpers::custom_crates_tests(\n                    NO_STD_CRATES.to_vec(),\n                    handle_test_args(&test_args, args.release),\n                    None,\n                    None,\n                    \"no-std\",\n                )\n            })?;\n            handle_backend_tests(args.clone().try_into().unwrap(), \"ndarray\", env, context)?;\n\n            Ok(())\n        }\n        Context::Std => {\n            // 1) Tests with default features\n            // ------------------------------\n            match args.ci {\n                CiTestType::GithubRunner => {\n                    // Exclude crates that are not supported on CI\n                    args.exclude.extend(vec![\n                        \"burn-cpu\".to_string(),\n                        \"burn-cuda\".to_string(),\n                        \"burn-rocm\".to_string(),\n                        // \"burn-router\" uses \"burn-wgpu\" for the tests.\n                        \"burn-router\".to_string(),\n                        \"burn-tch\".to_string(),\n                        \"burn-wgpu\".to_string(),\n                        // dqn-agent example relies on gym-rs dependency which requires SDL2.\n                        // It would be good to remove the gym-rs dependency in the future.\n                        \"dqn-agent\".to_string(),\n                        // Requires wgpu runtime\n                        \"burn-cubecl-fusion\".to_string(),\n                    ]);\n\n                    // Burn remote tests don't work on windows for now\n                    #[cfg(target_os = \"windows\")]\n                    {\n                        args.exclude.extend(vec![\"burn-remote\".to_string()]);\n                    };\n\n                    base_commands::test::handle_command(\n                        args.clone().try_into().unwrap(),\n                        env.clone(),\n                        context.clone(),\n                    )?;\n\n                    handle_backend_tests(\n                        args.clone().try_into().unwrap(),\n                        \"ndarray\",\n                        env,\n                        context,\n                    )?;\n                }\n                CiTestType::GithubMacRunner => {\n                    handle_backend_tests(\n                        args.clone().try_into().unwrap(),\n                        \"metal\",\n                        env.clone(),\n                        context.clone(),\n                    )?;\n\n                    args.target = Target::AllPackages;\n                    args.only.push(\"burn-wgpu\".to_string());\n                    args.features\n                        .get_or_insert_with(Vec::new)\n                        .push(\"metal\".to_string());\n\n                    base_commands::test::handle_command(\n                        args.clone().try_into().unwrap(),\n                        env,\n                        context,\n                    )?;\n                }\n                CiTestType::GcpCudaRunner => {\n                    handle_backend_tests(args.clone().try_into().unwrap(), \"cuda\", env, context)?;\n                }\n                CiTestType::GcpVulkanRunner => {\n                    handle_backend_tests(args.clone().try_into().unwrap(), \"vulkan\", env, context)?;\n\n                    args.target = Target::AllPackages;\n                    let mut args_vulkan: TestCmdArgs = args.clone().try_into().unwrap();\n                    args_vulkan.features = Some(vec![\"test-vulkan\".into()]);\n                    handle_wgpu_test(\"burn-core\", &args_vulkan)?;\n                    handle_wgpu_test(\"burn-optim\", &args_vulkan)?;\n                    handle_wgpu_test(\"burn-nn\", &args_vulkan)?;\n                    handle_wgpu_test(\"burn-vision\", &args_vulkan)?;\n                }\n                CiTestType::GcpWgpuRunner => {\n                    handle_backend_tests(args.clone().try_into().unwrap(), \"wgpu\", env, context)?;\n                    // \"burn-router\" uses \"burn-wgpu\" for the tests.\n                    args.target = Target::AllPackages;\n                    let mut args_wgpu = args.clone().try_into().unwrap();\n                    handle_wgpu_test(\"burn-wgpu\", &args_wgpu)?;\n                    handle_wgpu_test(\"burn-router\", &args_wgpu)?;\n                    handle_wgpu_test(\"burn-cubecl-fusion\", &args_wgpu)?;\n\n                    args_wgpu.features = Some(vec![\"test-wgpu\".into()]);\n                    handle_wgpu_test(\"burn-core\", &args_wgpu)?;\n                    handle_wgpu_test(\"burn-optim\", &args_wgpu)?;\n                    handle_wgpu_test(\"burn-nn\", &args_wgpu)?;\n                    handle_wgpu_test(\"burn-vision\", &args_wgpu)?;\n                }\n            }\n\n            // 2) Specific additional commands to test specific features\n            // ---------------------------------------------------------\n            match args.ci {\n                CiTestType::GithubRunner => {\n                    // burn-dataset\n                    helpers::custom_crates_tests(\n                        vec![\"burn-dataset\"],\n                        handle_test_args(&[\"--all-features\"], args.release),\n                        None,\n                        None,\n                        \"std all features\",\n                    )?;\n\n                    // burn-core\n                    helpers::custom_crates_tests(\n                        vec![\"burn-core\"],\n                        handle_test_args(\n                            &[\"--features\", \"test-tch,record-item-custom-serde\"],\n                            args.release,\n                        ),\n                        None,\n                        None,\n                        \"std with features: test-tch,record-item-custom-serde\",\n                    )?;\n\n                    // burn-vision\n                    helpers::custom_crates_tests(\n                        vec![\"burn-vision\"],\n                        handle_test_args(&[\"--features\", \"test-cpu\"], args.release),\n                        None,\n                        None,\n                        \"std cpu\",\n                    )?;\n\n                    // burn-train vision (LPIPS, DISTS metrics)\n                    helpers::custom_crates_tests(\n                        vec![\"burn-train\"],\n                        handle_test_args(&[\"--features\", \"vision\"], args.release),\n                        None,\n                        None,\n                        \"std vision\",\n                    )?;\n\n                    // burn-nn (pretrained and local tests)\n                    let mut nn_features = \"pretrained\".to_string();\n                    // If the \"CI\" environment variable is missing, we are running locally.\n                    if std::env::var(\"CI\").is_err() {\n                        nn_features.push_str(\",test-local\");\n                    }\n                    helpers::custom_crates_tests(\n                        vec![\"burn-nn\"],\n                        handle_test_args(&[\"--features\", &nn_features], args.release),\n                        None,\n                        None,\n                        &format!(\"std burn-nn with features: {}\", nn_features),\n                    )?;\n                }\n                CiTestType::GcpCudaRunner => (),\n                CiTestType::GcpVulkanRunner | CiTestType::GcpWgpuRunner => (), // handled in tests above\n                CiTestType::GithubMacRunner => {\n                    // burn-ndarray\n                    helpers::custom_crates_tests(\n                        vec![\"burn-ndarray\"],\n                        handle_test_args(&[\"--features\", \"blas-accelerate\"], args.release),\n                        None,\n                        None,\n                        \"std blas-accelerate\",\n                    )?;\n\n                    // burn-train vision (LPIPS, DISTS metrics)\n                    helpers::custom_crates_tests(\n                        vec![\"burn-train\"],\n                        handle_test_args(&[\"--features\", \"vision\"], args.release),\n                        None,\n                        None,\n                        \"std vision\",\n                    )?;\n                    helpers::custom_crates_tests(\n                        vec![\"burn-core\"],\n                        handle_test_args(&[\"--features\", \"test-metal\"], args.release),\n                        None,\n                        None,\n                        \"std metal\",\n                    )?;\n                    helpers::custom_crates_tests(\n                        vec![\"burn-vision\"],\n                        handle_test_args(&[\"--features\", \"test-metal\"], args.release),\n                        None,\n                        None,\n                        \"std metal\",\n                    )?;\n                }\n            }\n            Ok(())\n        }\n        Context::All => Context::value_variants()\n            .iter()\n            .filter(|ctx| **ctx != Context::All)\n            .try_for_each(|ctx| {\n                handle_command(\n                    BurnTestCmdArgs {\n                        command: args.command.clone(),\n                        target: args.target.clone(),\n                        exclude: args.exclude.clone(),\n                        only: args.only.clone(),\n                        threads: args.threads,\n                        jobs: args.jobs,\n                        ci: args.ci.clone(),\n                        features: args.features.clone(),\n                        no_default_features: args.no_default_features,\n                        release: args.release,\n                        test: args.test.clone(),\n                        force: args.force,\n                        no_capture: args.no_capture,\n                    },\n                    env.clone(),\n                    ctx.clone(),\n                )\n            }),\n    }\n}\n\nfn handle_test_args<'a>(args: &'a [&'a str], release: bool) -> Vec<&'a str> {\n    let mut args = args.to_vec();\n    if release {\n        args.push(\"--release\");\n    }\n    args\n}\n"
  },
  {
    "path": "xtask/src/commands/validate.rs",
    "content": "use tracel_xtask::prelude::*;\n\nuse crate::commands::{\n    build::BurnBuildCmdArgs,\n    test::{BurnTestCmdArgs, CiTestType},\n};\n\npub fn handle_command(\n    args: &ValidateCmdArgs,\n    env: Environment,\n    context: Context,\n) -> anyhow::Result<()> {\n    let target = Target::Workspace;\n    let exclude = vec![];\n    let only = vec![];\n\n    if context == Context::NoStd || context == Context::All {\n        // =================\n        // no-std validation\n        // =================\n        info!(\"Run validation for no-std execution environment...\");\n\n        #[cfg(target_os = \"linux\")]\n        {\n            // build\n            super::build::handle_command(\n                BurnBuildCmdArgs {\n                    target: target.clone(),\n                    exclude: exclude.clone(),\n                    only: only.clone(),\n                    ci: true,\n                    release: args.release,\n                    features: args.features.clone(),\n                    no_default_features: args.no_default_features,\n                },\n                env.clone(),\n                Context::NoStd,\n            )?;\n\n            // tests\n            super::test::handle_command(\n                BurnTestCmdArgs {\n                    target: target.clone(),\n                    exclude: exclude.clone(),\n                    only: only.clone(),\n                    threads: None,\n                    jobs: None,\n                    command: Some(TestSubCommand::All),\n                    ci: CiTestType::GithubRunner,\n                    features: None,\n                    no_default_features: false,\n                    force: false,\n                    no_capture: false,\n                    release: args.release,\n                    test: None,\n                },\n                env.clone(),\n                Context::NoStd,\n            )?;\n        }\n    }\n\n    if context == Context::Std || context == Context::All {\n        // ==============\n        // std validation\n        // ==============\n        info!(\"Run validation for std execution environment...\");\n\n        // checks\n        [\n            CheckSubCommand::Audit,\n            CheckSubCommand::Format,\n            CheckSubCommand::Lint,\n            CheckSubCommand::Typos,\n        ]\n        .iter()\n        .try_for_each(|c| {\n            base_commands::check::handle_command(\n                CheckCmdArgs {\n                    target: target.clone(),\n                    exclude: exclude.clone(),\n                    only: only.clone(),\n                    command: Some(c.clone()),\n                    ignore_audit: args.ignore_audit,\n                    features: args.features.clone(),\n                    no_default_features: args.no_default_features,\n                    ignore_typos: args.ignore_typos,\n                },\n                env.clone(),\n                context.clone(),\n            )\n        })?;\n\n        // build\n        super::build::handle_command(\n            BurnBuildCmdArgs {\n                target: target.clone(),\n                exclude: exclude.clone(),\n                only: only.clone(),\n                ci: true,\n                release: args.release,\n                features: args.features.clone(),\n                no_default_features: args.no_default_features,\n            },\n            env.clone(),\n            Context::Std,\n        )?;\n\n        // tests\n        super::test::handle_command(\n            BurnTestCmdArgs {\n                target: target.clone(),\n                exclude: exclude.clone(),\n                only: only.clone(),\n                threads: None,\n                jobs: None,\n                command: Some(TestSubCommand::All),\n                ci: CiTestType::GithubRunner,\n                features: None,\n                no_default_features: false,\n                release: args.release,\n                test: None,\n                force: false,\n                no_capture: false,\n            },\n            env.clone(),\n            Context::Std,\n        )?;\n\n        // documentation\n        [DocSubCommand::Build, DocSubCommand::Tests]\n            .iter()\n            .try_for_each(|c| {\n                super::doc::handle_command(\n                    DocCmdArgs {\n                        target: target.clone(),\n                        exclude: exclude.clone(),\n                        only: only.clone(),\n                        command: Some(c.clone()),\n                        features: args.features.clone(),\n                        no_default_features: args.no_default_features,\n                    },\n                    env.clone(),\n                    context.clone(),\n                )\n            })?;\n    }\n\n    Ok(())\n}\n"
  },
  {
    "path": "xtask/src/main.rs",
    "content": "mod commands;\n\n#[macro_use]\nextern crate log;\n\nuse std::time::Instant;\nuse tracel_xtask::prelude::*;\n\n// no-std\nconst WASM32_TARGET: &str = \"wasm32-unknown-unknown\";\nconst ARM_TARGET: &str = \"thumbv7m-none-eabi\";\nconst ARM_NO_ATOMIC_PTR_TARGET: &str = \"thumbv6m-none-eabi\";\nconst NO_STD_CRATES: &[&str] = &[\n    \"burn\",\n    \"burn-autodiff\",\n    \"burn-core\",\n    \"burn-std\",\n    \"burn-backend\",\n    \"burn-tensor\",\n    \"burn-ndarray\",\n    \"burn-no-std-tests\",\n];\n\n#[macros::base_commands(\n    Bump,\n    Check,\n    Compile,\n    Coverage,\n    Doc,\n    Dependencies,\n    Fix,\n    Publish,\n    Validate,\n    Vulnerabilities\n)]\npub enum Command {\n    /// Run commands to manage Burn Books.\n    Books(commands::books::BooksArgs),\n    /// Build Burn in different modes.\n    Build(commands::build::BurnBuildCmdArgs),\n    /// Test Burn.\n    Test(commands::test::BurnTestCmdArgs),\n}\n\nfn main() -> anyhow::Result<()> {\n    let start = Instant::now();\n    let (args, environment) = init_xtask::<Command>(parse_args::<Command>()?)?;\n\n    if args.context == Context::NoStd {\n        // Install additional targets for no-std execution environments\n        rustup_add_target(WASM32_TARGET)?;\n        rustup_add_target(ARM_TARGET)?;\n        rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET)?;\n    }\n\n    match args.command {\n        Command::Books(cmd_args) => cmd_args.parse(),\n        Command::Build(cmd_args) => {\n            commands::build::handle_command(cmd_args, environment, args.context)\n        }\n        Command::Doc(cmd_args) => {\n            commands::doc::handle_command(cmd_args, environment, args.context)\n        }\n        Command::Test(cmd_args) => {\n            commands::test::handle_command(cmd_args, environment, args.context)\n        }\n        Command::Validate(cmd_args) => {\n            commands::validate::handle_command(&cmd_args, environment, args.context)\n        }\n        _ => dispatch_base_commands(args, environment),\n    }?;\n\n    let duration = start.elapsed();\n    info!(\n        \"\\x1B[32;1mTime elapsed for the current execution: {}\\x1B[0m\",\n        format_duration(&duration)\n    );\n\n    Ok(())\n}\n"
  }
]